In [22]:
import transformers
import textwrap
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os
import sys
from typing import List
import json
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training
)
 
import fire
import torch
from datasets import load_from_disk, load_dataset, Dataset
import pandas as pd
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn.model_selection import train_test_split
from pylab import rcParams

In [23]:
base_model_id = "mistralai/Mistral-7B-v0.1"
cache_dir = "/data/sambhav" 


tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left"
    )

tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)


In [39]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model_id,  quantization_config=bnb_config,torch_dtype = torch.bfloat16, cache_dir=cache_dir,device_map={'':0})

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [40]:
model.device

device(type='cuda', index=0)

In [41]:
CUTOFF_LEN = 4096
def generate_prompt(data_point):
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write an output that appropriately completes the request. 
### Instruction:
Find the node degree of given node in the following input graph.
### Input:
{data_point["Input"]}{data_point['Instruction']} 
### Output:
{data_point["Output"]}"""

def tokenize(prompt, add_eos_token=True):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < CUTOFF_LEN
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
 
    result["labels"] = result["input_ids"].copy()
 
    return result
 
def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    tokenized_full_prompt = tokenize(full_prompt)
    return tokenized_full_prompt

In [42]:
def generate_test_prompt(data_point):
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write an output that appropriately completes the request. 
### Instruction:
Find the node degree of given node in the following input graph.
### Input:
{data_point["Input"]}{data_point['Instruction']} 
### Output:
"""

In [43]:
config = LoraConfig(
    r=8,
    lora_alpha=16,         ####Try 8
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

In [44]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
        
    )


print_trainable_parameters(model)

trainable params: 262410240 || all params: 3752071168 || trainable%: 6.993743675173274


In [45]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# print_trainable_parameters(model)
model = get_peft_model(model, config)
print_trainable_parameters(model)

trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705


In [48]:
BATCH_SIZE = 1
MICRO_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
# TRAIN_STEPS = len(train_data)//BATCH_SIZE   ## 1 epoch  = 28K
OUTPUT_DIR = "/data/sambhav/LLM4Graph/experiments/Mistral/"
model.config.use_cache = False
columns_to_remove = ["Output", "Instruction", "Input"]
# data=load_dataset('json', data_files=file_path)['train']
# train_data = (
#         data.map(generate_and_tokenize_prompt)
# )
# train_data = train_data.remove_columns(columns_to_remove)

In [49]:
for epoch in range(1):
    dataset_index = epoch%78
    print(f"Epoch number: {(epoch//78) +1},   Dataset number: {dataset_index}")
    file_path = f"/mnt/data/shared/sambhav/node_degree_partial_dataset_{dataset_index}.json"
    data=load_dataset('json', data_files=file_path)['train']

    
    
    train_data = (
        data.map(generate_and_tokenize_prompt)
    )
    train_data = train_data.remove_columns(columns_to_remove)
    
    training_arguments = transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        # per_device_eval_batch_size=1,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        # max_steps=TRAIN_STEPS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        # logging_steps=10000,
        optim="adamw_torch",
        num_train_epochs=1,
        # evaluation_strategy="epoch",
        save_strategy="epoch",
        # eval_steps=20000,
        # save_steps=20000,
        output_dir=OUTPUT_DIR+f"step_{epoch}",
        # load_best_model_at_end=True,
        report_to="none"
    )

    data_collator = transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    )
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        args=training_arguments,
        data_collator=data_collator
    )
    trainer.train()

Epoch number: 1,   Dataset number: 0


OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB (GPU 2; 79.17 GiB total capacity; 624.84 MiB already allocated; 20.88 MiB free; 626.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [50]:
torch.cuda.current_device()

0

In [21]:
# trainer = transformers.Trainer(
#     model=model,
#     train_dataset=train_data,
#     args=training_arguments,
#     data_collator=data_collator
# )

# # old_state_dict = model.state_dict
# # model.state_dict = (
# #     lambda self, *_, **__: get_peft_model_state_dict(
# #         self, old_state_dict()
# #     )
# # ).__get__(model, type(model))
 
# # model = torch.compile(model)
 
# trainer.train()
# # model.save_pretrained(OUTPUT_DIR)

1

In [18]:
# # prompt=generate_test_prompt(train_data[0])
# prompt="""Below is an instruction that describes a task, paired with an input that provides further context. Write an output that appropriately completes the request. 
# ### Instruction:
# Find the node degree of given node in the following input graph.
# ### Input:
# An undirected graph has 45 nodes - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43 and 44.
# Node 0 is connected to nodes 2, 5, 6, 7, 16, 19, 21, 23, 27, 28, 29, 34, 35, 40 and 42.
# Node 1 is connected to nodes 3, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 40 and 44.
# Node 2 is connected to nodes 0, 4, 6, 7, 8, 9, 12, 14, 15, 20, 22, 23, 24, 25, 26, 28, 29, 30, 40 and 41.
# Node 3 is connected to nodes 1, 5, 6, 9, 10, 14, 15, 16, 18, 19, 21, 22, 23, 24, 25, 28, 30, 34, 38 and 43.
# Node 4 is connected to nodes 2, 6, 7, 8, 9, 10, 12, 13, 17, 18, 19, 21, 22, 23, 24, 25, 29, 31, 32, 33, 36, 37 and 42.
# Node 5 is connected to nodes 0, 1, 3, 6, 7, 9, 10, 11, 12, 15, 17, 19, 20, 21, 22, 25, 26, 27, 28, 30, 31, 33, 35, 36, 38, 39, 40, 41, 42, 43 and 44.
# Node 6 is connected to nodes 0, 1, 2, 3, 4, 5, 8, 9, 10, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 42, 43 and 44.
# Node 7 is connected to nodes 0, 1, 2, 4, 5, 8, 10, 15, 16, 20, 21, 23, 25, 26, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 41 and 43.
# Node 8 is connected to nodes 2, 4, 6, 7, 14, 15, 16, 20, 22, 24, 32, 33, 35, 38 and 39.
# Node 9 is connected to nodes 1, 2, 3, 4, 5, 6, 14, 15, 18, 22, 26, 35, 38, 40 and 44.
# Node 10 is connected to nodes 1, 3, 4, 5, 6, 7, 14, 15, 16, 19, 20, 21, 22, 23, 24, 30, 31, 33, 36, 37, 38, 42 and 43.
# Node 11 is connected to nodes 5, 15, 16, 19, 21, 22, 23, 25, 28, 31, 32, 33, 37, 38 and 44.
# Node 12 is connected to nodes 1, 2, 4, 5, 16, 17, 18, 20, 21, 25, 31, 33, 34, 38, 41 and 42.
# Node 13 is connected to nodes 1, 4, 6, 15, 18, 21, 23, 24, 25, 29, 38, 41, 42 and 44.
# Node 14 is connected to nodes 1, 2, 3, 6, 8, 9, 10, 15, 16, 21, 22, 23, 25, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39 and 44.
# Node 15 is connected to nodes 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 13, 14, 16, 19, 21, 24, 28, 29, 31, 32, 33, 35, 37, 38, 39, 40 and 43.
# Node 16 is connected to nodes 0, 1, 3, 6, 7, 8, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22, 23, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 40 and 44.
# Node 17 is connected to nodes 4, 5, 6, 12, 16, 18, 22, 23, 24, 25, 26, 27, 28, 29, 31, 36 and 42.
# Node 18 is connected to nodes 1, 3, 4, 9, 12, 13, 16, 17, 22, 23, 24, 25, 28, 31, 33, 41 and 43.
# Node 19 is connected to nodes 0, 1, 3, 4, 5, 6, 10, 11, 15, 16, 20, 21, 22, 24, 27, 30, 34, 35, 36, 38 and 42.
# Node 20 is connected to nodes 1, 2, 5, 6, 7, 8, 10, 12, 16, 19, 21, 22, 23, 27, 28, 30, 31, 34, 40, 42 and 43.
# Node 21 is connected to nodes 0, 1, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 19, 20, 24, 25, 26, 28, 29, 31, 34, 36, 38 and 44.
# Node 22 is connected to nodes 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 20, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43 and 44.
# Node 23 is connected to nodes 0, 2, 3, 4, 6, 7, 10, 11, 13, 14, 16, 17, 18, 20, 22, 24, 25, 26, 27, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39 and 43.
# Node 24 is connected to nodes 1, 2, 3, 4, 6, 8, 10, 13, 15, 17, 18, 19, 21, 22, 23, 25, 26, 28, 30, 31, 32, 33, 35, 36, 38, 40 and 42.
# Node 25 is connected to nodes 1, 2, 3, 4, 5, 7, 11, 12, 13, 14, 17, 18, 21, 22, 23, 24, 28, 29, 30, 33, 35, 36, 38, 39, 41, 42, 43 and 44.
# Node 26 is connected to nodes 1, 2, 5, 7, 9, 16, 17, 21, 22, 23, 24, 27, 28, 29, 31, 32, 33, 35, 36, 41, 42 and 44.
# Node 27 is connected to nodes 0, 1, 5, 6, 16, 17, 19, 20, 22, 23, 26, 28, 29, 31, 35, 38 and 42.
# Node 28 is connected to nodes 0, 1, 2, 3, 5, 6, 7, 11, 14, 15, 16, 17, 18, 20, 21, 22, 24, 25, 26, 27, 29, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41 and 44.
# Node 29 is connected to nodes 0, 1, 2, 4, 6, 7, 13, 14, 15, 16, 17, 21, 22, 23, 25, 26, 27, 28, 30, 31, 33, 36, 38, 39 and 41.
# Node 30 is connected to nodes 2, 3, 5, 6, 7, 10, 16, 19, 20, 22, 24, 25, 29, 33, 36, 38, 42 and 44.
# Node 31 is connected to nodes 1, 4, 5, 6, 7, 10, 11, 12, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 32, 33, 37, 38, 40 and 44.
# Node 32 is connected to nodes 4, 6, 7, 8, 11, 14, 15, 16, 22, 23, 24, 26, 31, 36, 39 and 41.
# Node 33 is connected to nodes 1, 4, 5, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 22, 23, 24, 25, 26, 28, 29, 30, 31, 34, 36, 37, 38, 40 and 43.
# Node 34 is connected to nodes 0, 1, 3, 6, 12, 16, 19, 20, 21, 23, 28, 33, 36 and 38.
# Node 35 is connected to nodes 0, 1, 5, 6, 7, 8, 9, 14, 15, 16, 19, 22, 23, 24, 25, 26, 27, 28, 36, 38, 41, 42, 43 and 44.
# Node 36 is connected to nodes 1, 4, 5, 6, 7, 10, 14, 16, 17, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 34, 35, 38, 39, 40, 41 and 44.
# Node 37 is connected to nodes 1, 4, 6, 7, 10, 11, 14, 15, 22, 23, 28, 31, 33 and 44.
# Node 38 is connected to nodes 1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 33, 34, 35, 36, 41, 42, 43 and 44.
# Node 39 is connected to nodes 5, 6, 7, 8, 14, 15, 22, 23, 25, 28, 29, 32, 36 and 44.
# Node 40 is connected to nodes 0, 1, 2, 5, 9, 15, 16, 20, 22, 24, 28, 31, 33 and 36.
# Node 41 is connected to nodes 2, 5, 7, 12, 13, 18, 22, 25, 26, 28, 29, 32, 35, 36, 38, 42 and 44.
# Node 42 is connected to nodes 0, 4, 5, 6, 10, 12, 13, 17, 19, 20, 22, 24, 25, 26, 27, 30, 35, 38 and 41.
# Node 43 is connected to nodes 3, 5, 6, 7, 10, 15, 18, 20, 22, 23, 25, 33, 35, 38 and 44.
# Node 44 is connected to nodes 1, 5, 6, 9, 11, 13, 14, 16, 21, 22, 25, 26, 28, 30, 31, 35, 36, 37, 38, 39, 41 and 43.
# What is the degree of Node 20? 
# ### Output:"""
# print(prompt)

In [19]:
# val_input=tokenizer(prompt, return_tensors = "pt")

# device="cuda:0"
# val_input.to(device)
# ft_model.eval()
# with torch.no_grad():
#     generated_ids = ft_model.generate(**val_input, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)
# decoded = tokenizer.batch_decode(generated_ids)
# print(decoded[0])

In [20]:
# checkpoint_dir = "/data/sambhav/LLM4Graph/experiments/Mistral/checkpoint-100/"
# from peft import PeftModel
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
# base_model = AutoModelForCausalLM.from_pretrained(base_model_id,  quantization_config=bnb_config,torch_dtype = torch.bfloat16, cache_dir=cache_dir,device_map={'':0})
# print_trainable_parameters(base_model)
# ft_model = PeftModel.from_pretrained(base_model,checkpoint_dir)
# print_trainable_parameters(ft_model)
