In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer
from datasets import Dataset
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

In [3]:
def load_and_process_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    dialogues = [line.strip().split('|||') for line in lines if line.strip()]
    # Process each dialogue pair
    processed_examples = []
    for dialogue in dialogues:
        customer, agent = dialogue
        combined_text = customer + tokenizer.eos_token + agent + tokenizer.eos_token
        tokenized = tokenizer(combined_text, truncation=True, max_length=512, padding='max_length')
        # Ensure the 'labels' are correctly set for loss computation
        tokenized['labels'] = tokenized['input_ids'].copy()
        processed_examples.append(tokenized)

    # Creating a DataFrame
    df = pd.DataFrame(processed_examples)

    train_df, eval_df = train_test_split(df, test_size=0.1)
    train_dataset = Dataset.from_pandas(train_df)
    eval_dataset = Dataset.from_pandas(eval_df)

    return train_dataset, eval_dataset

In [4]:
# Custom collate function to ensure dynamic padding
def custom_collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    labels = [item['labels'] for item in batch]

    input_ids_padded = pad_sequence([torch.tensor(seq) for seq in input_ids], 
                                    batch_first=True, padding_value=tokenizer.pad_token_id)
    labels_padded = pad_sequence([torch.tensor(seq) for seq in labels], 
                                 batch_first=True, padding_value=tokenizer.pad_token_id)

    return {'input_ids': input_ids_padded, 'labels': labels_padded}


In [9]:
# Load and process the dataset
train_dataset, eval_dataset = load_and_process_data('../../resources/next_sentence_prediction_gpt2/nsp_fine_tuning_dataset.txt')


In [11]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=custom_collate_fn,  # Use the custom collate function
)


In [12]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "gpt2_nsp_trainer.ipynb"


In [13]:
import wandb
wandb.login(key="a9433891c8b370b61ad36f3b5c379d41d79ff4dd")


[34m[1mwandb[0m: Currently logged in as: [33msaket-singh1[0m ([33mmphasis[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\saket.singh1\.netrc


True

In [14]:
# Train the model
trainer.train()

  2%|▏         | 10/540 [01:58<1:45:41, 11.96s/it]

{'loss': 12.0575, 'grad_norm': 237.56495666503906, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.06}


  4%|▎         | 20/540 [04:35<2:16:47, 15.78s/it]

{'loss': 8.9151, 'grad_norm': 238.3504638671875, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.11}


  6%|▌         | 30/540 [07:09<2:08:32, 15.12s/it]

{'loss': 3.1803, 'grad_norm': 65.6487045288086, 'learning_rate': 3e-06, 'epoch': 0.17}


  7%|▋         | 40/540 [09:28<1:39:43, 11.97s/it]

{'loss': 0.7488, 'grad_norm': 11.113787651062012, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.22}


  9%|▉         | 50/540 [11:16<1:28:22, 10.82s/it]

{'loss': 0.3448, 'grad_norm': 3.6761839389801025, 'learning_rate': 5e-06, 'epoch': 0.28}


 11%|█         | 60/540 [13:03<1:22:46, 10.35s/it]

{'loss': 0.2315, 'grad_norm': 2.043060064315796, 'learning_rate': 6e-06, 'epoch': 0.33}


 13%|█▎        | 70/540 [14:50<1:24:14, 10.75s/it]

{'loss': 0.2009, 'grad_norm': 1.8377227783203125, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.39}


 15%|█▍        | 80/540 [16:37<1:22:34, 10.77s/it]

{'loss': 0.1736, 'grad_norm': 1.7176202535629272, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.44}


 17%|█▋        | 90/540 [18:42<1:48:20, 14.45s/it]

{'loss': 0.153, 'grad_norm': 1.579350233078003, 'learning_rate': 9e-06, 'epoch': 0.5}


 19%|█▊        | 100/540 [20:49<1:22:16, 11.22s/it]

{'loss': 0.1336, 'grad_norm': 1.5734918117523193, 'learning_rate': 1e-05, 'epoch': 0.56}


 20%|██        | 110/540 [22:35<1:16:12, 10.63s/it]

{'loss': 0.1276, 'grad_norm': 1.5299124717712402, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.61}


 22%|██▏       | 120/540 [24:23<1:15:17, 10.76s/it]

{'loss': 0.1206, 'grad_norm': 1.6226211786270142, 'learning_rate': 1.2e-05, 'epoch': 0.67}


 24%|██▍       | 130/540 [26:09<1:12:35, 10.62s/it]

{'loss': 0.1092, 'grad_norm': 1.6439725160598755, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.72}


 26%|██▌       | 140/540 [27:55<1:11:17, 10.69s/it]

{'loss': 0.1011, 'grad_norm': 1.667094111442566, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.78}


 28%|██▊       | 150/540 [29:41<1:09:15, 10.65s/it]

{'loss': 0.1048, 'grad_norm': 1.3565764427185059, 'learning_rate': 1.5e-05, 'epoch': 0.83}


 30%|██▉       | 160/540 [31:27<1:06:57, 10.57s/it]

{'loss': 0.1039, 'grad_norm': 2.023925542831421, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.89}


 31%|███▏      | 170/540 [33:07<1:00:43,  9.85s/it]

{'loss': 0.0946, 'grad_norm': 1.4914634227752686, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.94}


 33%|███▎      | 180/540 [34:42<55:57,  9.33s/it]  

{'loss': 0.1179, 'grad_norm': 1.8545736074447632, 'learning_rate': 1.8e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 180/540 [35:51<55:57,  9.33s/it]

{'eval_loss': 0.08096234500408173, 'eval_runtime': 69.6161, 'eval_samples_per_second': 0.575, 'eval_steps_per_second': 0.287, 'epoch': 1.0}


 35%|███▌      | 190/540 [38:31<1:36:54, 16.61s/it]

{'loss': 0.0982, 'grad_norm': 1.7289223670959473, 'learning_rate': 1.9e-05, 'epoch': 1.06}


 37%|███▋      | 200/540 [41:11<1:30:48, 16.02s/it]

{'loss': 0.0752, 'grad_norm': 1.6542977094650269, 'learning_rate': 2e-05, 'epoch': 1.11}


 39%|███▉      | 210/540 [43:48<1:23:04, 15.10s/it]

{'loss': 0.0865, 'grad_norm': 1.53231942653656, 'learning_rate': 2.1e-05, 'epoch': 1.17}


 41%|████      | 220/540 [46:22<1:24:04, 15.76s/it]

{'loss': 0.0845, 'grad_norm': 1.2956771850585938, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.22}


 43%|████▎     | 230/540 [49:08<1:29:11, 17.26s/it]

{'loss': 0.0839, 'grad_norm': 1.6261184215545654, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.28}


 44%|████▍     | 240/540 [51:53<1:21:42, 16.34s/it]

{'loss': 0.0781, 'grad_norm': 1.7511905431747437, 'learning_rate': 2.4e-05, 'epoch': 1.33}


 46%|████▋     | 250/540 [54:31<1:18:49, 16.31s/it]

{'loss': 0.0778, 'grad_norm': 1.5387507677078247, 'learning_rate': 2.5e-05, 'epoch': 1.39}


 48%|████▊     | 260/540 [57:15<1:16:07, 16.31s/it]

{'loss': 0.087, 'grad_norm': 2.093186378479004, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.44}


 50%|█████     | 270/540 [59:58<1:13:28, 16.33s/it]

{'loss': 0.0767, 'grad_norm': 1.6479960680007935, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.5}


 52%|█████▏    | 280/540 [1:02:41<1:10:50, 16.35s/it]

{'loss': 0.0838, 'grad_norm': 1.4221104383468628, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.56}


 54%|█████▎    | 290/540 [1:05:05<51:03, 12.25s/it]  

{'loss': 0.0858, 'grad_norm': 1.5121551752090454, 'learning_rate': 2.9e-05, 'epoch': 1.61}


 56%|█████▌    | 300/540 [1:06:53<43:24, 10.85s/it]

{'loss': 0.0753, 'grad_norm': 1.7284377813339233, 'learning_rate': 3e-05, 'epoch': 1.67}


 57%|█████▋    | 310/540 [1:08:41<41:40, 10.87s/it]

{'loss': 0.0803, 'grad_norm': 1.8121699094772339, 'learning_rate': 3.1e-05, 'epoch': 1.72}


 59%|█████▉    | 320/540 [1:10:36<45:56, 12.53s/it]

{'loss': 0.0719, 'grad_norm': 1.4553824663162231, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.78}


 61%|██████    | 330/540 [1:13:01<43:32, 12.44s/it]

{'loss': 0.0777, 'grad_norm': 1.8122836351394653, 'learning_rate': 3.3e-05, 'epoch': 1.83}


 63%|██████▎   | 340/540 [1:14:49<36:06, 10.83s/it]

{'loss': 0.0768, 'grad_norm': 1.6358875036239624, 'learning_rate': 3.4000000000000007e-05, 'epoch': 1.89}


 65%|██████▍   | 350/540 [1:16:19<27:42,  8.75s/it]

{'loss': 0.0732, 'grad_norm': 1.6336723566055298, 'learning_rate': 3.5e-05, 'epoch': 1.94}


 67%|██████▋   | 360/540 [1:17:46<25:44,  8.58s/it]

{'loss': 0.0787, 'grad_norm': 1.7487150430679321, 'learning_rate': 3.6e-05, 'epoch': 2.0}


                                                   
 67%|██████▋   | 360/540 [1:18:49<25:44,  8.58s/it]

{'eval_loss': 0.06654295325279236, 'eval_runtime': 62.7231, 'eval_samples_per_second': 0.638, 'eval_steps_per_second': 0.319, 'epoch': 2.0}


 69%|██████▊   | 370/540 [1:20:22<29:42, 10.49s/it]  

{'loss': 0.0682, 'grad_norm': 1.6938284635543823, 'learning_rate': 3.7e-05, 'epoch': 2.06}


 70%|███████   | 380/540 [1:21:58<24:50,  9.31s/it]

{'loss': 0.068, 'grad_norm': 1.4387884140014648, 'learning_rate': 3.8e-05, 'epoch': 2.11}


 72%|███████▏  | 390/540 [1:23:32<23:48,  9.53s/it]

{'loss': 0.068, 'grad_norm': 1.4701826572418213, 'learning_rate': 3.9000000000000006e-05, 'epoch': 2.17}


 74%|███████▍  | 400/540 [1:25:08<22:33,  9.67s/it]

{'loss': 0.0647, 'grad_norm': 1.3522177934646606, 'learning_rate': 4e-05, 'epoch': 2.22}


 76%|███████▌  | 410/540 [1:26:44<20:43,  9.56s/it]

{'loss': 0.0679, 'grad_norm': 1.149408221244812, 'learning_rate': 4.1e-05, 'epoch': 2.28}


 78%|███████▊  | 420/540 [1:28:21<19:17,  9.65s/it]

{'loss': 0.0681, 'grad_norm': 1.8512886762619019, 'learning_rate': 4.2e-05, 'epoch': 2.33}


 80%|███████▉  | 430/540 [1:29:56<16:36,  9.06s/it]

{'loss': 0.06, 'grad_norm': 1.32083261013031, 'learning_rate': 4.3e-05, 'epoch': 2.39}


 81%|████████▏ | 440/540 [1:31:28<15:24,  9.24s/it]

{'loss': 0.0643, 'grad_norm': 1.3609890937805176, 'learning_rate': 4.4000000000000006e-05, 'epoch': 2.44}


 83%|████████▎ | 450/540 [1:33:00<13:56,  9.30s/it]

{'loss': 0.0732, 'grad_norm': 1.6107959747314453, 'learning_rate': 4.5e-05, 'epoch': 2.5}


 85%|████████▌ | 460/540 [1:34:33<12:20,  9.25s/it]

{'loss': 0.052, 'grad_norm': 1.3616284132003784, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.56}


 87%|████████▋ | 470/540 [1:36:11<11:33,  9.91s/it]

{'loss': 0.0616, 'grad_norm': 1.3776196241378784, 'learning_rate': 4.7e-05, 'epoch': 2.61}


 89%|████████▉ | 480/540 [1:37:51<09:53,  9.89s/it]

{'loss': 0.0619, 'grad_norm': 1.5759397745132446, 'learning_rate': 4.8e-05, 'epoch': 2.67}


 91%|█████████ | 490/540 [1:39:28<08:06,  9.73s/it]

{'loss': 0.057, 'grad_norm': 1.2912721633911133, 'learning_rate': 4.9e-05, 'epoch': 2.72}


 93%|█████████▎| 500/540 [1:41:14<07:13, 10.84s/it]

{'loss': 0.0719, 'grad_norm': 1.6529687643051147, 'learning_rate': 5e-05, 'epoch': 2.78}


 94%|█████████▍| 510/540 [1:43:07<05:29, 10.97s/it]

{'loss': 0.0694, 'grad_norm': 1.7314153909683228, 'learning_rate': 3.7500000000000003e-05, 'epoch': 2.83}


 96%|█████████▋| 520/540 [1:44:56<03:36, 10.84s/it]

{'loss': 0.0662, 'grad_norm': 1.8247193098068237, 'learning_rate': 2.5e-05, 'epoch': 2.89}


 98%|█████████▊| 530/540 [1:46:46<01:49, 10.99s/it]

{'loss': 0.0691, 'grad_norm': 1.5836316347122192, 'learning_rate': 1.25e-05, 'epoch': 2.94}


100%|██████████| 540/540 [1:48:35<00:00, 10.92s/it]

{'loss': 0.0636, 'grad_norm': 1.2164207696914673, 'learning_rate': 0.0, 'epoch': 3.0}


                                                   
100%|██████████| 540/540 [1:49:57<00:00, 12.22s/it]

{'eval_loss': 0.06193813681602478, 'eval_runtime': 81.9981, 'eval_samples_per_second': 0.488, 'eval_steps_per_second': 0.244, 'epoch': 3.0}
{'train_runtime': 6600.2927, 'train_samples_per_second': 0.164, 'train_steps_per_second': 0.082, 'train_loss': 0.5489876355285997, 'epoch': 3.0}





TrainOutput(global_step=540, training_loss=0.5489876355285997, metrics={'train_runtime': 6600.2927, 'train_samples_per_second': 0.164, 'train_steps_per_second': 0.082, 'train_loss': 0.5489876355285997, 'epoch': 3.0})

In [16]:
# Save the fine-tuned model
model_path = '../../resources/next_sentence_prediction_gpt2/gpt2_finetuned_nsp'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


('../../resources/next_sentence_prediction_gpt2/gpt2_finetuned\\tokenizer_config.json',
 '../../resources/next_sentence_prediction_gpt2/gpt2_finetuned\\special_tokens_map.json',
 '../../resources/next_sentence_prediction_gpt2/gpt2_finetuned\\vocab.json',
 '../../resources/next_sentence_prediction_gpt2/gpt2_finetuned\\merges.txt',
 '../../resources/next_sentence_prediction_gpt2/gpt2_finetuned\\added_tokens.json')