In [1]:
import torch
import torch.nn as nn
import pandas as pd

from peft import PeftModel
from transformers import (AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, 
                          EarlyStoppingCallback, get_linear_schedule_with_warmup, AdamW, DataCollatorForSeq2Seq)
from transformers import GPT2Config, GPT2LMHeadModel, LlamaConfig, LlamaForCausalLM
from datasets import load_dataset, DatasetDict, Dataset
from peft import get_peft_model, LoraConfig, TaskType
from accelerate import Accelerator
import evaluate
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

import os
# os.environ['NCCL_DEBUG'] = 'INFO'
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['NCCL_P2P_DISABLE'] = '1'
# os.environ['ddp_backend'] = 'gloo'
# os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker,virbr,vmnet,vboxnet,wl,ww,ppp'
# os.environ["CUDA_VISIBLE_DEVICES"]="0"

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
a = 'b/c'
d = a.replace('/', '_')
print(d)
print(a)

b_c
b/c


In [2]:
import yaml
with open("./configs/plugin_config.yaml", "r") as file:
    config = yaml.safe_load(file)

In [4]:
config['plugin_model']

{'llama': {'name': None,
  'num_hidden_layers': 1,
  'vocab_size': 128256,
  'hidden_size': 512,
  'num_attention_heads': 4,
  'intermediate_size': 2048},
 'gpt2': {'name': None, 'n_layer': 1}}

In [2]:
model_type = 'llama' # gpt2
if(model_type == 'llama'):
    access_token = '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'
    base_model_name = "meta-llama/Llama-3.1-8B"
    # base_model_name = 'meta-llama/Llama-2-7b'

else:
    access_token = None
    base_model_name = 'gpt2-medium'

In [3]:
PROC_NAME = f'plugin_over_base_{model_type}'
FT_MODEL_NAME = 'ft_model_no_token_without_lora_train_hyperval_concatenated'
DATASET = 'e2e_nlg_cleaned'
DATASET_VERSION = None # 'webnlg_challenge_2017' # None

In [4]:
INPUT_SIZE = 64
TARGET_SIZE = 128

In [5]:
# Load the E2E NLG dataset from Hugging Face datasets library
if(DATASET_VERSION):
    dataset = load_dataset(DATASET, DATASET_VERSION)
else:
    dataset = load_dataset(DATASET)

In [6]:
def process_web_nlg(dataset, train_categories = ['Airport', 'Building', 'University', 'Monument', 'MeanOfTransportation'], 
                    test_categories = ['Artist', 'Politician', 'Athlete', 'ComicsCharacter', 'Astronaut', 'SportsTeam' ]):
        # Step 1: Filter the 'train', 'dev', and 'test' datasets by category
    def filter_by_category(example, categories):
        return example['category'] in categories

    # Filter the 'train' set to only include the "Airport" category
    train_dataset = dataset['train'].filter(lambda x: filter_by_category(x, train_categories))

    # Filter 'dev' and 'test' sets to only include the "Food" category
    dev_dataset = dataset['dev'].filter(lambda x: filter_by_category(x, test_categories))
    test_dataset = dataset['test'].filter(lambda x: filter_by_category(x, test_categories))

    combined_dev_test = concatenate_datasets([dev_dataset, test_dataset])
    combined_dev_test = combined_dev_test.train_test_split(test_size=0.5, seed=42)
    # Rename the splits for clarity
    combined_dev_test = DatasetDict({
        'validation': combined_dev_test['train'],  # Rename 'train' split as 'validation'
        'test': combined_dev_test['test']  # Keep 'test' split as 'test'
    })

    # Step 3: Select only one reference sentence from 'text' field based on 'comment'
    def select_good_comment(example):
        for i, comment in enumerate(example['lex']['comment']):
            if comment == 'good':
                return {'human_reference': example['lex']['text'][i]}  # Pick the sentence marked 'good'
        return {'human_reference': ''}  # Default to the first sentence if none are marked 'good'

    # Apply the function to each dataset
    train_dataset = train_dataset.map(select_good_comment)
    combined_dev_test['validation'] = combined_dev_test['validation'].map(select_good_comment)
    combined_dev_test['test'] = combined_dev_test['test'].map(select_good_comment)

    # Step 4: Join 'mtriple_set' list into a string separated by ';'
    def join_mtriple_set(example):
        return {'meaning_representation': ' ; '.join(example['modified_triple_sets']['mtriple_set'][0])}

    # Apply the function to each dataset
    train_dataset = train_dataset.map(join_mtriple_set)
    combined_dev_test['validation'] = combined_dev_test['validation'].map(join_mtriple_set)
    combined_dev_test['test'] = combined_dev_test['test'].map(join_mtriple_set)

    # Step 5: Retain only 'meaning_representation' and 'human_reference' fields
    train_dataset = train_dataset.remove_columns(
        [col for col in train_dataset.column_names if col not in ['meaning_representation', 'human_reference']])
    combined_dev_test['validation'] = combined_dev_test['validation'].remove_columns(
        [col for col in combined_dev_test['validation'].column_names if col not in ['meaning_representation', 'human_reference']])
    combined_dev_test['test'] = combined_dev_test['test'].remove_columns(
        [col for col in combined_dev_test['test'].column_names if col not in ['meaning_representation', 'human_reference']])

    # Step 6 removing empty human reference
    # Define a function to filter out rows where 'human_reference' is empty or None
    def filter_empty_human_reference(example):
        return example['human_reference'] is not None and example['human_reference'].strip() != ''

    def filter_empty_meaning_representation(example):
        return example['meaning_representation'] is not None and example['meaning_representation'].strip() != ''

    train_dataset = train_dataset.filter(filter_empty_human_reference)
    combined_dev_test['validation'] = combined_dev_test['validation'].filter(filter_empty_human_reference)
    combined_dev_test['test'] = combined_dev_test['test'].filter(filter_empty_human_reference)

    train_dataset = train_dataset.filter(filter_empty_meaning_representation)
    combined_dev_test['validation'] = combined_dev_test['validation'].filter(filter_empty_meaning_representation)
    combined_dev_test['test'] = combined_dev_test['test'].filter(filter_empty_meaning_representation)
    
    dataset = DatasetDict({
    'train': train_dataset,
    'validation': combined_dev_test['validation'],
    'test': combined_dev_test['test']
    })
    return dataset

def process_e2e_nlg_cleaned(dataset):
    return dataset

def process_common_gen(dataset):
    
    # Define a function to filter rows
    def filter_target_and_concepts(example):
        # Check if 'man' exists as a whole word in 'target'
        target_contains_man = bool(re.search(r'\bman\b', example['target']))

        # Ensure 'man' and 'woman' are not in 'concepts'
        concepts_does_not_contain_man_or_woman = all(c not in ['man', 'woman'] for c in example['concepts'])

        # Return True if the conditions are met
        return target_contains_man and concepts_does_not_contain_man_or_woman
    
    dataset['train'] = dataset['train'].filter(filter_target_and_concepts)

    def join_concepts(example):
        return {'meaning_representation': ' ; '.join(example['concepts'])}

    # Apply the function to each dataset
    dataset['train'] = dataset['train'].map(join_concepts)
    dataset['validation'] = dataset['validation'].map(join_concepts)
    dataset['test'] = dataset['test'].map(join_concepts)
    
    def ret_human_reference(example):
        return {'human_reference': example['target']}
    
    dataset['train'] = dataset['train'].map(ret_human_reference)
    dataset['validation'] = dataset['validation'].map(ret_human_reference)
    dataset['test'] = dataset['test'].map(ret_human_reference)
    
    dataset['train'] = dataset['train'].remove_columns(
        [col for col in dataset['train'].column_names if col not in ['meaning_representation', 'human_reference']])
    dataset['validation'] = dataset['validation'].remove_columns(
        [col for col in dataset['validation'].column_names if col not in ['meaning_representation', 'human_reference']])
    dataset['test'] = dataset['test'].remove_columns(
        [col for col in dataset['test'].column_names if col not in ['meaning_representation', 'human_reference']])
    
    return dataset

In [7]:
if(DATASET == 'e2e_nlg_cleaned'):
    dataset = process_e2e_nlg_cleaned(dataset)
elif(DATASET == 'web_nlg'):
    dataset = process_web_nlg(dataset)
elif(DATASET == 'common_gen'):
    dataset = process_common_gen(dataset)

In [8]:
def add_input_prompt(examples):
    inp = examples['meaning_representation']
    # for gpt2
    if(model_type == 'gpt2'):
        prefix_str = 'Given the following aspects of a restaurant, "'
        suffix_str = '", a natural language sentence describing the restuarant is: '
    # for llama
    elif(model_type == 'llama'):
        prefix_str = 'Question: Given the following attributes of a restaurant, "'
        suffix_str = '", how would you describe the restaurant based on the attributes? Just provide the description with no explanation.\nAnswer: '
    new_input = prefix_str + inp + suffix_str
    return {'meaning_representation': new_input}

In [9]:
if(DATASET == 'e2e_nlg_cleaned'):
    dataset['train'] = dataset['train'].map(add_input_prompt)
    dataset['validation'] = dataset['validation'].map(add_input_prompt)
    dataset['test'] = dataset['test'].map(add_input_prompt)

In [10]:
dataset['train']

Dataset({
    features: ['meaning_representation', 'human_reference'],
    num_rows: 33525
})

In [11]:
base_model_name

'meta-llama/Llama-3.1-8B'

In [12]:
# tokenizer = AutoTokenizer.from_pretrained(f'./models/{FT_MODEL_NAME}')
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token = access_token)

In [13]:
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = 'right'
# CHANGE THE PADDIGN TO LEFT FOR PREVIOUS BATCH FORM OF RESULTS
tokenizer.padding_side = 'left'


In [14]:
# tokenizer.add_special_tokens(tokenizer.special_tokens_map)

In [15]:
len(tokenizer.vocab), tokenizer.vocab_size

(128256, 128000)

In [16]:
# Load the fine-tuned model and tokenizer
# base_model_name = 'gpt2-medium'
# base_model_name = llama_model_id
if(model_type == 'gpt2'):
    config = GPT2Config(
            n_layer=1,         # Fewer transformer layers (default 12)
        )
    base_model = GPT2LMHeadModel(config=config)
else:
    # config = LlamaConfig(
    #         num_hidden_layers=1,         # Fewer transformer layers (default 12)
    #         vocab_size=len(tokenizer.vocab),
    #     )
    # base_model = LlamaForCausalLM(config=config)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, token = access_token, torch_dtype=torch.float16, 
                                                    #   device_map="auto", 
                                                      )
    base_model.to('cuda:0')
# base_model.resize_token_embeddings(len(tokenizer))
# tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = 'left'

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.57it/s]


In [17]:
# base_model.resize_token_embeddings(len(tokenizer))


In [18]:
# base_model.config.pad_token_id = base_model.config.eos_token_id

In [19]:
print(tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token)
print(tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.bos_token_id)
print(base_model.config.eos_token_id, base_model.config.pad_token_id, base_model.config.bos_token_id)

<|end_of_text|> <|end_of_text|> <|begin_of_text|>
128001 128001 128000
128001 None 128000


In [20]:
# model_ft = AutoModelForCausalLM.from_pretrained("../../models/peft_gpt2-medium_e2e")  # Path to your fine-tuned model
# Create the GPT-2 model with the smaller configuration
if('plugin_over_ft' in PROC_NAME):
# if(PROC_NAME == 'ft_model_pad_token'):
    # model_ft = PeftModel.from_pretrained(base_model, f'./models/{FT_MODEL_NAME}')
    model_ft = AutoModelForCausalLM.from_pretrained(f'./models/{FT_MODEL_NAME}')
elif('plugin_over_base' in PROC_NAME):
    model_ft = base_model
for param in model_ft.parameters():
    param.requires_grad = False
# model_ft = AutoModelForCausalLM.from_pretrained('./models/ft_model')

In [21]:
for name, param in model_ft.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)

lm_head.weight False


In [22]:
model_ft

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [23]:
# # Preprocess the dataset to include the meaning representation (MR) as input and human reference as target
# def preprocess_function(examples):
#     # Concatenate MR and human reference with a separator
#     inputs = [f"{mr}" for mr in examples["meaning_representation"]]
#     targets = [f"{ref}" for ref in examples["human_reference"]]
#     model_inputs = tokenizer(inputs, max_length=INPUT_SIZE, truncation=True, padding="max_length", return_tensors="pt")
#     labels = tokenizer(targets, max_length=TARGET_SIZE, truncation=True, padding="max_length", return_tensors="pt")
    
#     # Replace padding token id's of the labels by -100 so that it's ignored by the loss
#     labels["input_ids"] = [
#         [(label if label != tokenizer.pad_token_id else -100) for label in labels_seq] 
#         for labels_seq in labels["input_ids"]
#     ]
#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs

In [24]:
# Preprocess the dataset to include the meaning representation (MR) as input and human reference as target
def preprocess_function(examples):
    # Left-pad the meaning_representation and right-pad the human reference
    inputs = [f"{mr}" for mr in examples["meaning_representation"]]
    targets = [f"{ref}" for ref in examples["human_reference"]]
    
    # Right-pad the human references (default padding="max_length")
    labels = tokenizer(targets, max_length=TARGET_SIZE, truncation=True, padding="max_length", return_tensors="pt")
    
    # Left-pad the meaning_representation manually
    # We manually pad to the left by prepending padding tokens
    max_input_length = INPUT_SIZE
    tokenized_inputs = [tokenizer(mr, truncation=True)["input_ids"] for mr in inputs]
    left_padded_inputs = [
        [tokenizer.pad_token_id] * (max_input_length - len(input_seq)) + input_seq if len(input_seq) < max_input_length else input_seq[:max_input_length] 
        for input_seq in tokenized_inputs
    ]
    
    # Convert left-padded inputs to tensors and include attention mask
    model_inputs = {
        "input_ids": torch.tensor(left_padded_inputs),
        "attention_mask": torch.tensor([[0] * (max_input_length - len(input_seq)) + [1] * len(input_seq) for input_seq in tokenized_inputs])
    }

    # Replace padding token id's of the labels by -100 so that it's ignored by the loss
    labels["input_ids"] = [
        [(label if label != tokenizer.pad_token_id else -100) for label in labels_seq] 
        for labels_seq in labels["input_ids"]
    ]
    
    # Include the right-padded labels in the model inputs
    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs


In [25]:
def preprocess_function_batch(examples):
    # Extract the meaning representations (MR) and human references (target text) from the examples
    inputs = examples["meaning_representation"]
    targets = examples["human_reference"]
    
    # Tokenize the inputs (meaning representations)
    tokenized_inputs = tokenizer(
        inputs, 
        max_length=INPUT_SIZE, 
        truncation=True, 
        padding="max_length", 
        # return_tensors="pt"  # Use numpy for batch processing
    )
    
    # Tokenize the targets (human references)
    tokenized_targets = tokenizer(
        targets, 
        max_length=TARGET_SIZE, 
        truncation=True, 
        padding="max_length", 
        # return_tensors="pt"  # Use numpy for batch processing
    )
    
    # Concatenate input_ids (MR) and input_ids from the targets (human reference) into one sequence
    # This creates the full sequence: [MR, target] (all tokenized)
    concatenated_input_ids = [
        list(input_seq) + list(target_seq) for input_seq, target_seq in zip(tokenized_inputs["input_ids"], tokenized_targets["input_ids"])
    ]
    
    # Concatenate attention masks for both MR and target
    concatenated_attention_mask = [
        list(input_mask) + list(target_mask) for input_mask, target_mask in zip(tokenized_inputs["attention_mask"], tokenized_targets["attention_mask"])
    ]
    
    # Prepare the labels for loss computation:
    # We need to ignore the loss for the part corresponding to MR and only compute it for the target (human reference).
    
    labels = []
    for input_len, target_seq in zip([INPUT_SIZE] * len(inputs), tokenized_targets["input_ids"]):
        # Ignore loss for MR part by setting it to -100
        labels_seq = [-100] * input_len
        
        # For the target sequence, we keep the tokens, but set padding tokens to -100
        labels_seq += [token if token != tokenizer.pad_token_id else -100 for token in target_seq]
        
        labels.append(labels_seq)
    
    # Return the final dictionary containing input_ids, attention_mask, and labels
    return {
        "input_ids": torch.tensor(concatenated_input_ids),
        "attention_mask": torch.tensor(concatenated_attention_mask),
        "labels": torch.tensor(labels)
    }


In [26]:
def preprocess_function_preconcat(examples):
    # Extract the meaning representations (MR) and human references (target text) from the examples
    inputs = examples["meaning_representation"]
    targets = examples["human_reference"]
    
    sentences = []
    for inp, tar in zip(inputs, targets):
        sentences.append(inp + tar)
    
    # Tokenize the targets (human references)
    tokenized_targets = tokenizer(
        targets, 
        # add_special_tokens=False,
    )
    
    # Tokenize the inputs (meaning representations)
    tokenized_sentences = tokenizer(
        sentences, 
        max_length=INPUT_SIZE + TARGET_SIZE, 
        truncation=True, 
        padding="max_length", 
        # add_special_tokens=False,
        # return_tensors="pt"  # Use numpy for batch processing
    )
        
    labels = []
    for comb_seq, target_seq in zip(tokenized_sentences['input_ids'], tokenized_targets['input_ids']):
        label_seq = [-100]*len(comb_seq)
        label_seq[-len(target_seq):] = target_seq
        labels.append(label_seq)
        
    
    # Return the final dictionary containing input_ids, attention_mask, and labels
    return {
        "input_ids": torch.tensor(tokenized_sentences['input_ids']),
        "attention_mask": torch.tensor(tokenized_sentences['attention_mask']),
        "labels": torch.tensor(labels)
    }


In [27]:
# # Tokenize the dataset
# def tokenize_function(examples):
#     return tokenizer(examples["meaning_representation"], text_target=examples["human_reference"], padding="max_length", truncation=True, max_length=TARGET_SIZE)

In [28]:
# Apply tokenization to the dataset
# tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = dataset.map(preprocess_function_preconcat, batched=True, 
                                remove_columns=["meaning_representation", "human_reference"]
                               )
# tokenized_dataset = dataset.map(preprocess_function, batched=True, 
#                                 # remove_columns=["meaning_representation", "human_reference"]
#                                )

In [29]:
# tokenized_dataset['validation'], tokenized_dataset['hypervalidation'] = tokenized_dataset['validation'].train_test_split(
#     test_size=0.2, seed=42).values()

In [30]:
# # Convert the tokenized dataset to a pandas DataFrame for easier manipulation
# df = pd.DataFrame(tokenized_dataset)

# # Step 1: Identify unique MRs and group by MR
# grouped_by_mr = df.groupby('meaning_representation')

# # Step 2: Extract all unique MRs
# unique_mrs = df['meaning_representation'].unique()

# # Step 3: Perform train-test split on the unique MRs
# train_mrs, hypervalidation_mrs = train_test_split(unique_mrs, test_size=0.5, random_state=42)

# # Step 4: Create new DataFrames for train and hypervalidation based on the split MRs
# train_df = df[df['meaning_representation'].isin(train_mrs)]
# hypervalidation_df = df[df['meaning_representation'].isin(hypervalidation_mrs)]

# # Step 5: Convert back to the Dataset format for Hugging Face
# train_dataset = DatasetDict({"train": Dataset.from_pandas(train_df)})
# hypervalidation_dataset = DatasetDict({"hypervalidation": Dataset.from_pandas(hypervalidation_df)})

# tokenized_dataset = {}
# # Update tokenized_dataset with the new split
# tokenized_dataset['train'] = train_dataset['train'].remove_columns(["meaning_representation", "human_reference", "__index_level_0__"])

# tokenized_dataset['hypervalidation'] = hypervalidation_dataset['hypervalidation'].remove_columns(["meaning_representation", "human_reference", "__index_level_0__"])

In [31]:
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [32]:
def get_position_ids(input_ids, attention_mask):
    """
    Generate position IDs for left-padded sequences.
    
    Args:
        input_ids (torch.Tensor): Tensor of input token IDs (shape: [batch_size, seq_len]).
        attention_mask (torch.Tensor): Tensor of attention mask (shape: [batch_size, seq_len]).
        
    Returns:
        torch.Tensor: Position IDs (shape: [batch_size, seq_len]).
    """
    # Get the lengths of the non-padded tokens (i.e., count of '1's in the attention mask)
    seq_lengths = attention_mask.sum(dim=-1)

    # Create a tensor with position IDs starting from 0 for each non-padded token
    position_ids = torch.arange(input_ids.size(1), dtype=torch.long).unsqueeze(0).repeat(input_ids.size(0), 1).to(input_ids.device)

    # Adjust position IDs for each sequence to start from 0 after padding
    position_ids = position_ids - (input_ids.size(1) - seq_lengths).unsqueeze(-1)

    # Set position IDs for padding tokens to 0 (optional: you can use another value if needed)
    position_ids = torch.where(attention_mask == 1, position_ids, torch.zeros_like(position_ids))
    return position_ids.long()

# def get_position_ids(input_ids, attention_mask):
#     position_ids = torch.arange(input_ids.size(1)).expand_as(input_ids).to(input_ids.device)
#     position_ids = position_ids * attention_mask
#     return position_ids.long()

In [33]:
def get_position_ids_batch(input_ids, attention_mask, split_point=64):
    """
    Generate position IDs for left-padded sequences, where input_ids contains both the 
    meaning representation (MR) and human reference (target), concatenated together.
    
    Args:
        input_ids (torch.Tensor): Tensor of input token IDs (shape: [batch_size, seq_len]).
        attention_mask (torch.Tensor): Tensor of attention mask (shape: [batch_size, seq_len]).
        split_point (int): The index at which the meaning representation ends and the human reference begins.
        
    Returns:
        torch.Tensor: Position IDs (shape: [batch_size, seq_len]).
    """
    # Step 1: Process the meaning representation (MR) part of the input_ids
    # Create a tensor with position IDs starting from 0 for each non-padded token in the MR
    mr_attention_mask = attention_mask[:, :split_point]
    mr_position_ids = torch.arange(split_point, dtype=torch.long).unsqueeze(0).repeat(input_ids.size(0), 1).to(input_ids.device)
    
    # Get the lengths of the non-padded tokens in the MR
    mr_seq_lengths = mr_attention_mask.sum(dim=-1)

    # Adjust position IDs for MR to start from 0 after padding
    mr_position_ids = mr_position_ids - (split_point - mr_seq_lengths).unsqueeze(-1)
    
    # Set position IDs for padding tokens in MR to 0
    mr_position_ids = torch.where(mr_attention_mask == 1, mr_position_ids, torch.zeros_like(mr_position_ids))

    # Step 2: Process the human reference (target) part of the input_ids
    target_attention_mask = attention_mask[:, split_point:]
    target_seq_len = input_ids.size(1) - split_point
    target_position_ids = torch.arange(target_seq_len, dtype=torch.long).unsqueeze(0).repeat(input_ids.size(0), 1).to(input_ids.device)
    
    # Get the lengths of the non-padded tokens in the human reference
    target_seq_lengths = target_attention_mask.sum(dim=-1)

    # Adjust position IDs for human reference to start from 0 after padding
    target_position_ids = target_position_ids - (target_seq_len - target_seq_lengths).unsqueeze(-1)
    
    # print(target_position_ids.size())
    # print(mr_seq_lengths.size())
    # print(target_position_ids)
    # print(mr_seq_lengths)
    target_position_ids += mr_seq_lengths.unsqueeze(1).repeat(1, target_position_ids.size()[1])
    
    # Set position IDs for padding tokens in human reference to 0
    target_position_ids = torch.where(target_attention_mask == 1, target_position_ids, torch.zeros_like(target_position_ids))

    # Step 3: Concatenate the position IDs for MR and human reference
    position_ids = torch.cat([mr_position_ids, target_position_ids], dim=-1)
    
    return position_ids.long()


In [34]:
input_ids = torch.tensor(tokenized_dataset['train'][0]['input_ids']).reshape(1, -1)
attention_mask = torch.tensor(tokenized_dataset['train'][0]['attention_mask']).reshape(1, -1)
position_ids = get_position_ids(input_ids, attention_mask)
# position_ids = position_ids * attention_mask
print(input_ids)
print(attention_mask)
print(sum(attention_mask[0]))
print(position_ids)


tensor([[128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
         128001, 128001, 128001, 128001, 128001, 128001, 128001, 128000,  14924,
             25,  16644,    279,   2768,   8365,    315,    264,  10960,     11,
            330,    609,     58,    791,  36895,   1145,   8343,    941,     58,
          79217,   8221,   1

In [35]:
class CustomGPT2ModelBatch(GPT2LMHeadModel):
    def __init__(self, config, ft_model):
        super().__init__(config)
        self.config = config
        self.ft_model = ft_model

        # for param in self.base_model.parameters():
        #     param.requires_grad = False
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get the logits from the base GPT-2 model
        position_ids = get_position_ids(input_ids=input_ids, attention_mask=attention_mask)
        # print(input_ids)
        # print(attention_mask)
        # print(position_ids)
        # print(labels)
        outputs = super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        # outputs = super().forward(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # print('logits', logits.size())
        # Convert logits to probabilities (apply softmax)
        probabilities = nn.functional.softmax(logits, dim=-1)
        # print('probabilities', probabilities.size())
        
        # print('no error here')

        with torch.no_grad():
            outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
            # outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask)
            # print(type(outputs_base))
            logits_base = outputs_base.logits
            # Convert logits to probabilities (apply softmax)
            probabilities_base = nn.functional.softmax(logits_base, dim=-1)
        
        # print('no error here as well')
        
        # Multiply the probabilities by the custom vector
        # Ensure that the custom vector shape matches the logits shape [batch_size, seq_len, vocab_size]
        # For this example, assume the vector applies element-wise across the vocabulary dimension
        weighted_probabilities = probabilities * probabilities_base
        # weighted_probabilities = probabilities
        
        # Normalize the probabilities again to ensure they sum to 1
        normalized_probabilities = weighted_probabilities / weighted_probabilities.sum(dim=-1, keepdim=True)
        
        # Convert back to logits (unnormalized scores) by applying log after multiplying probabilities
        modified_logits = torch.log(normalized_probabilities + 1e-8)  # Add small constant to avoid log(0)
        
        # Use modified logits for loss calculation or generation
        if labels is not None:
            # print('reached here')
            # print('labels size', labels.size())
            # print('labels', labels)
            # Shift the logits and labels for computing the loss
            shift_logits = modified_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # print('shift_logits', shift_logits.size())
            # print('shift_labels', shift_labels.size())            
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss, modified_logits
        
        return modified_logits

In [36]:
class GPT2SmallBatch(GPT2LMHeadModel):
    def __init__(self, config, ft_model):
        super().__init__(config)
        self.ft_model = ft_model
        for param in self.ft_model.parameters():
            param.requires_grad = False
        
        # for name, param in self.base_model.named_parameters():
        #     if 'wte' not in name:
        #         param.requires_grad = False  # Freeze all layers except lm_head
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get the logits from the base GPT-2 model
        position_ids = get_position_ids(input_ids=input_ids, attention_mask=attention_mask)
        # print(input_ids)
        # print(attention_mask)
        # print(position_ids)
        # print(labels)
        outputs = super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        # outputs = super().forward(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # print('logits', logits.size())
        # Convert logits to probabilities (apply softmax)
        probabilities = nn.functional.softmax(logits, dim=-1)
        # print('probabilities', probabilities.size())
        
        # print('no error here')

        with torch.no_grad():
            outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
            # outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask)
            # print(type(outputs_base))
            logits_base = outputs_base.logits
            # Convert logits to probabilities (apply softmax)
            probabilities_base = nn.functional.softmax(logits_base, dim=-1)
        
        # print('no error here as well')
        
        # Multiply the probabilities by the custom vector
        # Ensure that the custom vector shape matches the logits shape [batch_size, seq_len, vocab_size]
        # For this example, assume the vector applies element-wise across the vocabulary dimension
        weighted_probabilities = probabilities * probabilities_base
        # weighted_probabilities = probabilities
        
        # Normalize the probabilities again to ensure they sum to 1
        normalized_probabilities = weighted_probabilities / weighted_probabilities.sum(dim=-1, keepdim=True)
        
        # Convert back to logits (unnormalized scores) by applying log after multiplying probabilities
        modified_logits = torch.log(normalized_probabilities + 1e-8)  # Add small constant to avoid log(0)
        
        # Use modified logits for loss calculation or generation
        if labels is not None:
            # print('reached here')
            # print('labels size', labels.size())
            # print('labels', labels)
            # Shift the logits and labels for computing the loss
            shift_logits = modified_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # print('shift_logits', shift_logits.size())
            # print('shift_labels', shift_labels.size())            
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss, modified_logits
        
        return modified_logits

In [37]:
class CustomGPT2Model(GPT2LMHeadModel):
    def __init__(self, config, ft_model):
        super().__init__(config)
        self.config = config
        self.ft_model = ft_model

        # for param in self.base_model.parameters():
        #     param.requires_grad = False
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        
        
        generated_ids = input_ids.clone()  # Start with the input prompt
        # finished_sequences = torch.zeros(input_ids.size(0), dtype=torch.bool).to(input_ids.device)
        
#         with torch.no_grad():
#             position_ids = get_position_ids(input_ids=generated_ids, attention_mask=attention_mask)
#         # print('position_ids', position_ids.size())
#         # print(position_ids)
#         outputs = super().forward(input_ids=generated_ids, attention_mask=attention_mask, position_ids=position_ids)
#         logits = outputs.logits[:, -1, :]  # Get logits of the last token
#         probs = torch.nn.functional.softmax(logits, dim=-1)

#         with torch.no_grad():
#             outputs_base = self.ft_model.forward(input_ids=generated_ids, attention_mask=attention_mask, position_ids=position_ids)
#             logits_base = outputs_base.logits[:, -1, :]  # Get logits of the last token
#             probs_base = torch.nn.functional.softmax(logits_base, dim=-1)

#         # print('middle', torch.max(probs_base, dim=-1))
#         probs = probs*probs_base
#         sum_probs = probs.sum(dim=-1, keepdim=True)

#         # Avoid division by zero by adding a small value (epsilon)
#         sum_probs = torch.clamp(sum_probs, min=1e-9)

#         # Re-normalize by dividing each probability by the sum of probabilities
#         probs = probs / sum_probs

#         next_token = torch.argmax(probs, dim=-1).unsqueeze(-1)

#         generated_ids = torch.cat((generated_ids, next_token), dim=-1)
#         # Extend the attention mask to include the newly generated token
#         new_attention_mask = torch.ones((attention_mask.shape[0], 1)).to(input_ids.device)
#         attention_mask = torch.cat((attention_mask, new_attention_mask), dim=-1)
        
#         modified_logits = torch.log(probs + 1e-8)
        modified_logits = torch.empty(input_ids.size(0), TARGET_SIZE, 50257).to(input_ids.device)  # Empty tensor of the desired final size
        with torch.no_grad():
            position_ids = get_position_ids(input_ids=generated_ids, attention_mask=attention_mask)
        # past_key_values = None
        # past_key_values_base = None

        
        # print('modified_logits', modified_logits.size())
 
        for step in range(TARGET_SIZE):
            # Get the model outputs (logits) for the current step
            # with torch.no_grad():
            #     position_ids = get_position_ids(input_ids=generated_ids, attention_mask=attention_mask)
            # print('generated_ids', generated_ids.size())
            # print('attention_mask', attention_mask.size())
            # print('position_ids', position_ids.size())
            # if(past_key_values):
            #     outputs = super().forward(input_ids=generated_ids[:, -1:], attention_mask=attention_mask[:, -1:], 
            #                               position_ids=position_ids[:, -1:], use_cache=True, past_key_values=past_key_values)
            # else:
            outputs = super().forward(input_ids=generated_ids, attention_mask=attention_mask, position_ids=position_ids)
            logits = outputs.logits[:, -1, :]  # Get logits of the last token
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # past_key_values = outputs.past_key_values

            with torch.no_grad():
                # if(past_key_values_base):
                #     outputs_base = self.ft_model.forward(input_ids=generated_ids[:, -1:], attention_mask=attention_mask[:, -1:], position_ids=position_ids[:, -1:], 
                #                                          use_cache=True, past_key_values=past_key_values_base)
                # else:
                outputs_base = self.ft_model.forward(input_ids=generated_ids, attention_mask=attention_mask, position_ids=position_ids)
                logits_base = outputs_base.logits[:, -1, :]  # Get logits of the last token
                probs_base = torch.nn.functional.softmax(logits_base, dim=-1)
                # past_key_values_base = outputs_base.past_key_values
                    
            # print('middle', torch.max(probs_base, dim=-1))
            probs = probs*probs_base
            sum_probs = probs.sum(dim=-1, keepdim=True)

            # Avoid division by zero by adding a small value (epsilon)
            sum_probs = torch.clamp(sum_probs, min=1e-9)

            # Re-normalize by dividing each probability by the sum of probabilities
            probs = probs / sum_probs
            
            next_token = torch.argmax(probs, dim=-1).unsqueeze(-1)
            
            generated_ids = torch.cat((generated_ids, next_token), dim=-1)
            # Extend the attention mask to include the newly generated token
            new_attention_mask = torch.ones((attention_mask.shape[0], 1)).to(input_ids.device)
            attention_mask = torch.cat((attention_mask, new_attention_mask), dim=-1)
            
            temp_logits = torch.log(probs + 1e-8) 
            
            # modified_logits = torch.cat((modified_logits, temp_logits), dim=1)
            modified_logits[:, step, :] = temp_logits
            
            last_values = position_ids[:, -1]  # This gets the last value of each row (shape: m)
            new_values = last_values + 1  # Increment each last value by 1
            new_values = new_values.unsqueeze(1)  # Reshape to (m, 1) to concatenate with the tensor
            position_ids = torch.cat([position_ids, new_values], dim=1) 
            
        # print('modified_logits', modified_logits.size())

            
#         # Get the logits from the base GPT-2 model
#         position_ids = get_position_ids(input_ids=input_ids, attention_mask=attention_mask)
#         outputs = super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
#         logits = outputs.logits
#         probabilities = nn.functional.softmax(logits, dim=-1)

#         with torch.no_grad():
#             outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
#             logits_base = outputs_base.logits
#             probabilities_base = nn.functional.softmax(logits_base, dim=-1)

#         weighted_probabilities = probabilities * probabilities_base
#         # weighted_probabilities = probabilities
        
#         # Normalize the probabilities again to ensure they sum to 1
#         normalized_probabilities = weighted_probabilities / weighted_probabilities.sum(dim=-1, keepdim=True)
        
#         # Convert back to logits (unnormalized scores) by applying log after multiplying probabilities
#         modified_logits = torch.log(normalized_probabilities + 1e-8)  # Add small constant to avoid log(0)
        
        # Use modified logits for loss calculation or generation
        if labels is not None:
#             shift_logits = modified_logits[..., :-1, :].contiguous()
#             shift_labels = labels[..., 1:].contiguous()
            
            ### IF BATCH THEN MOVE IT, OTHERWISE NOT
            
            shift_logits = modified_logits
            shift_labels = labels
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss, modified_logits
        
        return modified_logits

In [38]:
def get_base_probs(md, input_ids, attention_mask, position_ids):
    with torch.no_grad():
        position_ids_base = position_ids.to(md.device)
        input_ids_base = input_ids.to(md.device)
        attention_mask_base = attention_mask.to(md.device)
        # print('reached  here as well')
        # print('input_ids_base device', input_ids_base.device)
        # print('self.ft_model device', model_ft.device)
        outputs_base = md.forward(input_ids_base, attention_mask=attention_mask_base, position_ids=position_ids_base)
        # outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask)
        # print(type(outputs_base))
        logits_base = outputs_base.logits
        # Convert logits to probabilities (apply softmax)
        probabilities_base = nn.functional.softmax(logits_base, dim=-1)
    return probabilities_base


In [39]:
class CustomLlamaModelBatchSeparate(LlamaForCausalLM):
    def __init__(self, config, tokenizer):
        super().__init__(config)
        self.config = config
        super().resize_token_embeddings(len(tokenizer))
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get the logits from the base GPT-2 model
        position_ids = get_position_ids(input_ids=input_ids, attention_mask=attention_mask)
        # print('got position ids')
        # print('$$$', input_ids.device)
        # print(attention_mask)
        # print(position_ids)
        # print(labels)
        outputs = super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        # outputs = super().forward(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # print('logits', logits.size())
        # Convert logits to probabilities (apply softmax)
        probabilities = nn.functional.softmax(logits, dim=-1)
        # print('probabilities', probabilities.size())
        
        # print('no error here')
        probabilities_base = get_base_probs(input_ids, attention_mask, position_ids)
        
        probabilities_base = probabilities_base.to(probabilities.device)
        # print('no error here as well')
        
        # Multiply the probabilities by the custom vector
        # Ensure that the custom vector shape matches the logits shape [batch_size, seq_len, vocab_size]
        # For this example, assume the vector applies element-wise across the vocabulary dimension
        weighted_probabilities = probabilities * probabilities_base
        # weighted_probabilities = probabilities
        
        # Normalize the probabilities again to ensure they sum to 1
        normalized_probabilities = weighted_probabilities / weighted_probabilities.sum(dim=-1, keepdim=True)
        
        # Convert back to logits (unnormalized scores) by applying log after multiplying probabilities
        modified_logits = torch.log(normalized_probabilities + 1e-8)  # Add small constant to avoid log(0)
        
        # Use modified logits for loss calculation or generation
        if labels is not None:
            # print('reached here')
            # print('labels size', labels.size())
            # print('labels', labels)
            # Shift the logits and labels for computing the loss
            shift_logits = modified_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # print('shift_logits', shift_logits.size())
            # print('shift_labels', shift_labels.size())            
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss, modified_logits
        
        return modified_logits

In [40]:
class CustomLlamaModelBatch(LlamaForCausalLM):
    def __init__(self, config, ft_model, tokenizer):
        super().__init__(config)
        self.config = config
        self.ft_model = ft_model
        super().resize_token_embeddings(len(tokenizer))

        torch.nn.DataParallel(super(), device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
        self.ft_model.to('cpu')

        # for param in self.base_model.parameters():
        #     param.requires_grad = False

    # def to(self, device):
    #     super().to(device)  # Move only small_model to GPU
    #     self.ft_model.to('cpu')
    #     return self
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get the logits from the base GPT-2 model
        position_ids = get_position_ids(input_ids=input_ids, attention_mask=attention_mask)
        print('got position ids')
        print('$$$', input_ids.device)
        # print(attention_mask)
        # print(position_ids)
        # print(labels)
        outputs = super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        # outputs = super().forward(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # print('logits', logits.size())
        # Convert logits to probabilities (apply softmax)
        probabilities = nn.functional.softmax(logits, dim=-1)
        print('probabilities', probabilities.size())
        
        # print('no error here')

        with torch.no_grad():
            position_ids_base = position_ids.to('cpu')
            input_ids_base = input_ids.to('cpu')
            attention_mask_base = attention_mask.to('cpu')
            print('reached  here as well')
            print('input_ids_base device', input_ids_base.device)
            print('self.ft_model device', self.ft_model.device)
            outputs_base = self.ft_model.forward(input_ids_base, attention_mask=attention_mask_base, position_ids=position_ids_base)
            # outputs_base = self.ft_model.forward(input_ids, attention_mask=attention_mask)
            # print(type(outputs_base))
            logits_base = outputs_base.logits
            # Convert logits to probabilities (apply softmax)
            probabilities_base = nn.functional.softmax(logits_base, dim=-1)
        
        probabilities_base = probabilities_base.to(probabilities.device)
        print('no error here as well')
        
        # Multiply the probabilities by the custom vector
        # Ensure that the custom vector shape matches the logits shape [batch_size, seq_len, vocab_size]
        # For this example, assume the vector applies element-wise across the vocabulary dimension
        weighted_probabilities = probabilities * probabilities_base
        # weighted_probabilities = probabilities
        
        # Normalize the probabilities again to ensure they sum to 1
        normalized_probabilities = weighted_probabilities / weighted_probabilities.sum(dim=-1, keepdim=True)
        
        # Convert back to logits (unnormalized scores) by applying log after multiplying probabilities
        modified_logits = torch.log(normalized_probabilities + 1e-8)  # Add small constant to avoid log(0)
        
        # Use modified logits for loss calculation or generation
        if labels is not None:
            # print('reached here')
            # print('labels size', labels.size())
            # print('labels', labels)
            # Shift the logits and labels for computing the loss
            shift_logits = modified_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # print('shift_logits', shift_logits.size())
            # print('shift_labels', shift_labels.size())            
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss, modified_logits
        
        return modified_logits

In [41]:
# # Define a smaller configuration
# # config = GPT2Config(
# #     n_layer=1,         # Fewer transformer layers
# #     n_embd=64,        # Smaller hidden size (default is 768)
# #     n_head=1,          # Fewer attention heads (default is 12)
# #     resid_pdrop=0.2,   # Increase dropout to reduce overfitting
# #     attn_pdrop=0.2,    # Increase dropout in the attention layer
# #     embd_pdrop=0.2,    # Increase dropout in the embedding layer
# #     initializer_range=0.02,  # Smaller initialization range to stabilize training
# # )
if(model_type == 'gpt2'):
    config = GPT2Config(
        n_layer=1,         # Fewer transformer layers (default 12)
    )
elif(model_type == 'llama'):
    config = LlamaConfig(
        num_hidden_layers=1,         # Fewer transformer layers (default 12)
        vocab_size=len(tokenizer.vocab),
        hidden_size=512,
        num_attention_heads=4,
        intermediate_size=2048,

    )

In [42]:
if(model_type=='gpt2'):
    model = CustomGPT2ModelBatch(config, model_ft)
# model = GPT2SmallBatch.from_pretrained('gpt2', model_ft)
else:
    # model = CustomLlamaModelBatch(config, model_ft, tokenizer)
    model = CustomLlamaModelBatchSeparate(config, tokenizer)


In [43]:
for name, param in model.named_parameters():
    if(param.requires_grad):
        print(name)

model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.input_layernorm.weight
model.layers.0.post_attention_layernorm.weight
model.norm.weight
lm_head.weight


In [44]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5,6,7"

In [45]:
# # Define training arguments
# training_args = TrainingArguments(
#     output_dir="../../results/plugin",
#     evaluation_strategy="epoch",
#     learning_rate=5e-6,
#     per_device_train_batch_size=8,  # This batch size is per GPU
#     per_device_eval_batch_size=8,
#     num_train_epochs=10,
#     weight_decay=0.01,
#     logging_dir='../../logs/plugin',
#     logging_steps=1,
#     push_to_hub=False
# )

# Define training arguments
training_args = TrainingArguments(
    output_dir=f'./results/{PROC_NAME}',
    evaluation_strategy='epoch',
    # eval_steps=2,
    save_strategy="epoch",
    learning_rate=5e-6,
    per_device_train_batch_size=8,  # This batch size is per GPU
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=1,
    logging_dir=f'./logs/{PROC_NAME}',
    logging_steps=10,
    push_to_hub=False,
    fp16=True,  # Enable mixed-precision training for faster training
    report_to="none",
    # load_best_model_at_end=True,  # Required for early stopping
    # metric_for_best_model="eval_loss",  # Metric to determine the best model (optional)
    # greater_is_better=False,  # Set to False if lower metric is better (e.g., loss)
    save_total_limit=1,
    # no_cuda = True,
    # dataloader_num_workers=2,
)



In [46]:
# Accelerator for multi-GPU support
# train_dataset = tokenized_dataset['train']
# eval_dataset = tokenized_dataset['hypervalidation']

train_dataset = tokenized_dataset['validation']
eval_dataset = tokenized_dataset['test']

# model, train_dataset, eval_dataset = accelerator.prepare(model, train_dataset, eval_dataset)

In [47]:
# Define a collate function to convert lists to tensors on-the-fly
def collate_fn(batch):
    input_ids = [torch.tensor(item['input_ids'], dtype=torch.long) for item in batch]
    attention_mask = [torch.tensor(item['attention_mask'], dtype=torch.long) for item in batch]
    labels = [torch.tensor(item['labels'], dtype=torch.long) for item in batch]

    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
    }


In [48]:
optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)

total_steps = len(train_dataset) * training_args.num_train_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # Warm-up for the first 10% of steps
    num_training_steps=total_steps
)



In [49]:
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
# val_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)



# # Set the model to training mode
# model.train()

# # Tokenize some example data
# # inputs = tokenizer("Example input text", return_tensors="pt", add_special_tokens=True).to("cuda")
# # labels = tokenizer("Example output text", return_tensors="pt", add_special_tokens=True).to("cuda")

# # Initialize optimizer
# # optimizer = AdamW(model.parameters(), lr=5e-5)

# # Basic training loop
# for epoch in range(1):
#     total_loss = 0
#     for batch in tqdm(train_loader):
#         input_ids = batch['input_ids'].to(model.device)
#         attention_mask = batch['attention_mask'].to(model.device)
#         labels = batch['labels'].to(model.device)
        
#         optimizer.zero_grad()
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss, logits = outputs
#         total_loss+=loss.item()
#         # print(f"Batch {epoch} | Loss: {loss.item()}")
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#     print(f'loss after epoch {epoch} is: {total_loss/len(train_loader)}')

In [50]:

# Initialize the Trainer
trainer = Trainer(
    model=model,                       # The model with PEFT applied
    args=training_args,                     # Training arguments
    train_dataset=train_dataset, # Training data
    eval_dataset=eval_dataset, # Validation data
    # data_collator=data_collator,
    tokenizer = tokenizer,
    optimizers=(
        optimizer, 
        scheduler
        ),  # Pass optimizer and scheduler
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)]  # Add early stopping
)

In [51]:
# for name, param in check.named_parameters():
#     if(param.requires_grad):
#         print(name, param.requires_grad)

# # Compare the weights of the lm_head and the embedding layer
# print("Are lm_head and wte sharing the same weights?")
# print(model.lm_head.weight is model.transformer.wte.weight)

# # Compare the weights of the lm_head and the embedding layer
# print("Are lm_head and wte sharing the same weights?")
# print(model.ft_model.lm_head.weight is model.ft_model.embed_tokens.weight)

# # Compare the weights of the lm_head and the embedding layer
# print("Are lm_head and wte sharing the same weights?")
# print(base_model.lm_head.weight is base_model.transformer.wte.weight)

for name, param in model.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

lm_head.weight True


In [52]:
# Compare the weights of the lm_head and the embedding layer
# print("Are lm_head and wte sharing the same weights?")
# print(model.ft_model.lm_head.weight is model_ft.transformer.wte.weight)

In [53]:
# model_ft.transformer.wte.weight, model_ft.transformer.h[1].attn.c_attn.weight

In [54]:
# model.ft_model.transformer.wte.weight, model.ft_model.transformer.h[1].attn.c_attn.weight

In [55]:
# from torch.nn import DataParallel
# model = DataParallel(model)

In [56]:
# Fine-tune the model
trainer.train()

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Epoch,Training Loss,Validation Loss
1,2.4467,2.510737
2,2.3577,2.441068




TrainOutput(global_step=154, training_loss=2.4397337096078053, metrics={'train_runtime': 359.8245, 'train_samples_per_second': 23.895, 'train_steps_per_second': 0.428, 'total_flos': 691984877617152.0, 'train_loss': 2.4397337096078053, 'epoch': 2.0})

In [48]:
# model.tie_weights()

In [49]:
# accelerator.wait_for_everyone()  # Synchronize GPUs

In [50]:
# for name, param in check.named_parameters():
#     if(param.requires_grad):
#         print(name, param.requires_grad)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.lm_head.weight is model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.ft_model.lm_head.weight is model.ft_model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(base_model.lm_head.weight is base_model.transformer.wte.weight)

for name, param in model.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

Are lm_head and wte sharing the same weights?
True
Are lm_head and wte sharing the same weights?
True
Are lm_head and wte sharing the same weights?
True
transformer.wte.weight True
ft_model.transformer.wte.weight False


In [51]:
# Save the model
model.save_pretrained(f'./models/{PROC_NAME}_{FT_MODEL_NAME}')
tokenizer.save_pretrained(f'./models/{PROC_NAME}_{FT_MODEL_NAME}')

('./models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/tokenizer_config.json',
 './models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/special_tokens_map.json',
 './models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/vocab.json',
 './models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/merges.txt',
 './models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/added_tokens.json',
 './models/plugin_over_base_gpt2_small_full_ft_ft_model_no_token_without_lora_train_hyperval_concatenated/tokenizer.json')

In [None]:
# for name, param in check.named_parameters():
#     if(param.requires_grad):
#         print(name, param.requires_grad)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.lm_head.weight is model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.ft_model.lm_head.weight is model.ft_model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(base_model.lm_head.weight is base_model.transformer.wte.weight)

for name, param in model.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

In [53]:
check = AutoModelForCausalLM.from_pretrained(f'./models/test{PROC_NAME}_{FT_MODEL_NAME}')

In [54]:
check

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [55]:
# for name, param in check.named_parameters():
#     if(param.requires_grad):
#         print(name, param.requires_grad)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.lm_head.weight is model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(model.ft_model.lm_head.weight is model.ft_model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(base_model.lm_head.weight is base_model.transformer.wte.weight)

for name, param in model.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

Are lm_head and wte sharing the same weights?
True
Are lm_head and wte sharing the same weights?
True
Are lm_head and wte sharing the same weights?
True
transformer.wte.weight True
ft_model.transformer.wte.weight False


In [56]:
# for name, param in check.named_parameters():
#     if(param.requires_grad):
#         print(name, param.requires_grad)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(check.lm_head.weight is check.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(check.ft_model.lm_head.weight is check.ft_model.transformer.wte.weight)

# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(base_model.lm_head.weight is base_model.transformer.wte.weight)

for name, param in check.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

Are lm_head and wte sharing the same weights?
True
Are lm_head and wte sharing the same weights?


AttributeError: 'GPT2LMHeadModel' object has no attribute 'ft_model'

In [48]:
# Compare the weights of the lm_head and the embedding layer
print("Are lm_head and wte sharing the same weights?")
print(base_model.lm_head.weight is base_model.transformer.wte.weight)


Are lm_head and wte sharing the same weights?
True


In [49]:
for name, param in model.named_parameters():
    if('wte' in name or 'lm_head' in name):
        print(name, param.requires_grad)
    # if(param.requires_grad):
    #     print(name, param.requires_grad)

transformer.wte.weight True
ft_model.transformer.wte.weight False
