In [75]:
from transformers import GPT2Tokenizer, GPTNeoForCausalLM, Trainer, TrainingArguments, AutoConfig
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TextGenerationPipeline

In [76]:
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = '[pad]'

In [77]:
df = pd.read_csv('training_data.csv')
df

Unnamed: 0.1,Unnamed: 0,input_text,output_label
0,0,janu: brow I’m donna do it janu: I j got home ...,K cool
1,1,janu: I’m donna email Or Nguyen janu: W about ...,To or fine
2,2,janu: W about it YALL submit Pranav Kannepalli...,It’s all good
3,3,"janu: I seem so reckless about this, sorry. I’...",Oh ok
4,4,Pranav Kannepalli: To or fine Pranav Kannepall...,I’m on a plane in
...,...,...,...
11874,224,Krishna Chintalapudi: Tomorrow evening is poss...,Went the link to your email
11875,225,"Pranav Kannepalli: sounds good, what time work...",Krishna uncle?
11876,226,Krishna Chintalapudi: 530? Pranav Kannepalli: ...,sent a new link
11877,227,Pranav Kannepalli: sure that works for me Pran...,"Krishna uncle, could we meet today or tomorrow..."


In [78]:
def tokenize_text():
    input_ids = tokenizer(df['input_text'].astype(str).to_list(), padding='max_length', truncation=True, max_length=5, return_tensors='pt').input_ids
    output_ids = tokenizer(df['output_label'].astype(str).to_list(), padding='max_length', truncation=True, max_length=5, return_tensors='pt').input_ids
    return {"input_ids": input_ids, "output_ids": output_ids}

# Apply tokenization
tokenized_data = tokenize_text()

# Tokenized dataset
tokenized_data

{'input_ids': tensor([[13881,    84,    25,  4772,   314],
         [13881,    84,    25,   314,   447],
         [13881,    84,    25,   370,   546],
         ...,
         [   42, 37518,  2616,   609,   600],
         [   47,  2596,   615,   509,  1236],
         [   47,  2596,   615,   509,  1236]]),
 'output_ids': tensor([[   42,  3608, 50256, 50256, 50256],
         [ 2514,   393,  3734, 50256, 50256],
         [ 1026,   447,   247,    82,   477],
         ...,
         [34086,   257,   649,  2792, 50256],
         [   42, 37518,  2616,  7711,    11],
         [   82,  3733,   922, 50256, 50256]])}

In [79]:
class KittsLLMDataset(Dataset):
    def __init__(self, tokenized_data):
        self.data = tokenized_data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            "input_ids": self.data['input_ids'][idx].squeeze(0),
            "labels": self.data['output_ids'][idx].squeeze(0)
        }
    
train_data = KittsLLMDataset(tokenized_data)
train_data.__getitem__(0)

{'input_ids': tensor([13881,    84,    25,  4772,   314]),
 'labels': tensor([   42,  3608, 50256, 50256, 50256])}

In [80]:
model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M')

In [81]:
training_args = TrainingArguments(
    output_dir='./pretrained_results', 
    num_train_epochs=3, 
    per_device_train_batch_size=8, 
    logging_dir='./pretrained_logs',
    save_steps=10_000,
)

In [82]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data
)

In [83]:
trainer.train()

100%|██████████| 3/3 [00:04<00:00,  1.66s/it]

{'train_runtime': 4.9038, 'train_samples_per_second': 1.224, 'train_steps_per_second': 0.612, 'train_loss': 6.415890375773112, 'epoch': 3.0}





TrainOutput(global_step=3, training_loss=6.415890375773112, metrics={'train_runtime': 4.9038, 'train_samples_per_second': 1.224, 'train_steps_per_second': 0.612, 'total_flos': 15305103360.0, 'train_loss': 6.415890375773112, 'epoch': 3.0})

In [91]:
input_text = "janu: Hi, What's Up? Pranav Kannepalli: "
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_length=36)
print(outputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[13881,    84,    25, 15902,    11,  1867,   338,  3205,    30,  1736,
           272,   615,   509,  1236,   538, 36546,    25,   220,   198,   198,
            40,  1101,   257,  1310, 10416,    13,   314,  1101,   407,  1654,
           644,   262,  1917,   318,    13,   198]])
janu: Hi, What's Up? Pranav Kannepalli: 

I'm a little confused. I'm not sure what the problem is.

