In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer,GPT2Config, GPT2Model
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm
import torch
import gc 
import torch
from torch.utils.data import Dataset
import json

In [2]:
!nvidia-smi

Wed Apr 10 16:45:02 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 ...    Off | 00000000:01:00.0  On |                  N/A |
| N/A   60C    P8              15W / 115W |     59MiB /  6144MiB |     40%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(device)

cuda


### Tokenizer/Model

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

In [5]:
tokenizer

GPT2Tokenizer(name_or_path='gpt2-medium', vocab_size=50257, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [6]:
tokenizer.add_special_tokens({"pad_token": "<pad>",
                                "bos_token": "<sos>",
                                "eos_token": "<eos>"})
tokenizer.add_tokens(["<bot>",'<per>'])

2

In [7]:
model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
model.to(torch.bfloat16)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50262, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50262, bias=False)
)

In [8]:
model.num_parameters()

354828288

### Dataset

In [9]:
from torch.utils.data import Dataset
import json

In [10]:
class ChatData(Dataset):
    def __init__(self, path:str, tokenizer):
        self.file = json.load(open(path, "r"))

        
        self.data = [] # To store the dialogue 

        # Reading the dialogues 
        for idx, dialog in enumerate(self.file):
            self.data.append([]) # new dialogue
            for txt in dialog['dialog']:
                self.data[idx].append(txt['text'])
        # Formating the conversations 
        self.X = []
        int1 = ''
        int2 = ''
        for idx, dialog in enumerate(self.data):
            if len(dialog)==0:
                continue
            if len(dialog) % 2 == 0:
                int1 = '<per>'
                int2 = '<bot>'
            else:
                int1 = '<bot>'
                int2 = '<per>'
            tmp_str = '<sos>'
            for k,utt in enumerate(dialog):
                if k%2 == 0:
                    tmp_str += int1 + utt
                else:
                    tmp_str += int2 + utt
            tmp_str += '<eos>'
            # print(len(tmp_str))
            self.X.append(tmp_str)

        # max = 0
        # for dialog in self.X:
        #     if len(dialog)>max:
        #         max= len(dialog)
        # print(max)
    
        
        # Tokenization of the dataset 
        self.X_encoded = tokenizer(self.X,max_length=512, truncation=True, padding="max_length", return_tensors="pt")
        
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']
        
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.input_ids[idx], self.attention_mask[idx])

In [11]:
chatData = ChatData("./chat_data.json", tokenizer)
chatData =  DataLoader(chatData, batch_size=2,shuffle=True)

In [12]:
ex = next(iter(chatData))

In [13]:
tokenizer.batch_decode(ex[0])

["<sos> <bot> i am not sure what that is. <per> What are you not sure of? <bot> i am not sure. i am a very experienced person. <per> I bet you are. Tell me more about yourself <bot> i am a very experienced, but i would rather be a librarian. <per> Why librarian? <bot> i am not sure. i am sure they are. <per> You're sure they are what? <bot> i think they are fda regulated. <per> Librarians are fda regulated? <bot> no i have not. but i have heard that before. <per> Do you know what is fda? <bot> i think i am fda regulated. <per> What is fda? <bot> i think so. i think i will add some of my food to eat. <per> What do you like to eat? <bot> i do not think i have voluntarily tried it <per> You don't eat at all? <bot> i think i am available for that. <per> Do you drink! <bot> i am not a very good person. i am not very good at it. <per> You're not good at drinking? Why? <bot> i think i am not sure. i am sure it will. <per> It will what? <bot> i am not sure. i am a very experienced person. <per

### Training 

In [14]:
def train(chatData, model, optim,epochs=5):
    model.train()
    for i in range(epochs):
        for X, a in tqdm.tqdm(chatData):
            garbage_collect()
            X = X.to(device)
            a = a.to(device)
            # Ignoring the padding token by setting its label to-100
            labels = X.clone()
            labels[labels == tokenizer.pad_token_id] = -100

            optim.zero_grad()
            loss = model(X, attention_mask=a, labels=labels).loss
            loss.backward()
            optim.step()
            loss.detach()
        print(f'Epoch {i}, Loss = {loss.item()}')
        torch.save(model.state_dict(), "model_state.pt")
        # print(infer("<sos> <per>hello how are you<bot>"))

In [15]:
def infer(inp):
    inp = tokenizer(inp, return_tensors="pt")
    X = inp["input_ids"].to(device)
    a = inp["attention_mask"].to(device)
    output = model.generate(X, attention_mask=a )
    output = tokenizer.decode(output[0])
    return output


In [16]:
optim = Adam(model.parameters(), lr=1e-4)

In [19]:
print("training .... ")
train(chatData, model, optim,epochs=5)

training .... 


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [01:19<00:00,  2.68it/s]


Epoch 0, Loss = 4.96875


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [01:19<00:00,  2.70it/s]


Epoch 1, Loss = 5.0


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [01:17<00:00,  2.76it/s]


Epoch 2, Loss = 3.390625


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [01:18<00:00,  2.73it/s]


Epoch 3, Loss = 1.453125


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [01:17<00:00,  2.76it/s]


Epoch 4, Loss = 2.203125


In [18]:
def garbage_collect():
    """
        Clear the memory of the GPU and CPU 
    """
    torch.cuda.empty_cache()
    gc.collect()

In [39]:
garbage_collect()

In [None]:
while True:
  inp = input()
  print(infer(inp))

In [21]:
print(infer("<sos> <per>hello how are you<bot>"))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<sos>  <per> hello how are you <bot> I am a new person. <bot> hello <bot>  how are you
