In [11]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm
import torch
from torch.utils.data import Dataset
import json

In [12]:
class ChatData(Dataset):
    def __init__(self, path:str, tokenizer):
        self.data = json.load(open(path, "r"))

        self.X = []
        for i in self.data:
            for j in i['dialog']:
                self.X.append(j['text'])

        for idx, i in enumerate(self.X):
            try:
                self.X[idx] = "<startofstring> "+i+" <bot>: "+self.X[idx+1]+" <endofstring>"
            except:
                break

        self.X = self.X[:5000]
        
        print(self.X[0])

        self.X_encoded = tokenizer(self.X,max_length=40, truncation=True, padding="max_length", return_tensors="pt")
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.input_ids[idx], self.attention_mask[idx])

In [33]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = "<pad>"
tokenizer.add_special_tokens({"bos_token": "<startofstring>",
                                "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<bot>:"])

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))

model = model.to(device)


In [34]:
chatData = ChatData("./chat_data.json", tokenizer)
chatData =  DataLoader(chatData, batch_size=64)

model.train()

optim = Adam(model.parameters(), lr=1e-3)


<startofstring> I love iphone! i just bought new iphone! <bot>: Thats good for you, i'm not very into new tech <endofstring>


In [36]:
def infer(inp):
    inp = "<startofstring> "+inp+" <bot>: "
    inp = tokenizer(inp, return_tensors="pt")
    X = inp["input_ids"].to(device)
    a = inp["attention_mask"].to(device)
    output = model.generate(X, attention_mask=a )
    output = tokenizer.decode(output[0])
    return output

In [16]:
def train(chatData, model, optim):

    epochs = 12

    for i in tqdm.tqdm(range(epochs)):
        c = 1
        for X, a in chatData:
            X = X.to(device)
            a = a.to(device)
            optim.zero_grad()
            loss = model(X, attention_mask=a, labels=X).loss
            loss.backward()
            optim.step()
            print(f"epoch {i+1} batch {c} loss : {loss.item()}")
            c+=1
        torch.save(model.state_dict(), "model_state.pt")
        print("--------------------------------------------------")
        print(infer("hey"))
        print("--------------------------------------------------")

In [17]:
print("training .... ")
train(chatData, model, optim)

training .... 


  0%|          | 0/12 [00:00<?, ?it/s]

epoch 1 batch 1 loss : 93.96897888183594
epoch 1 batch 2 loss : 97.72197723388672
epoch 1 batch 3 loss : 7.878519535064697
epoch 1 batch 4 loss : 6.835137844085693
epoch 1 batch 5 loss : 6.4496331214904785
epoch 1 batch 6 loss : 6.296763896942139
epoch 1 batch 7 loss : 5.942861080169678
epoch 1 batch 8 loss : 5.4540324211120605
epoch 1 batch 9 loss : 5.476049423217773
epoch 1 batch 10 loss : 5.293665409088135
epoch 1 batch 11 loss : 4.425160884857178
epoch 1 batch 12 loss : 5.463339805603027
epoch 1 batch 13 loss : 4.793827056884766
epoch 1 batch 14 loss : 4.37387752532959
epoch 1 batch 15 loss : 3.198378562927246
epoch 1 batch 16 loss : 3.140977144241333
epoch 1 batch 17 loss : 4.76219367980957
epoch 1 batch 18 loss : 4.2240891456604
epoch 1 batch 19 loss : 4.558444499969482
epoch 1 batch 20 loss : 4.790134906768799
epoch 1 batch 21 loss : 4.475421905517578
epoch 1 batch 22 loss : 4.625162601470947
epoch 1 batch 23 loss : 3.4931232929229736
epoch 1 batch 24 loss : 3.7245404720306396
e

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  8%|▊         | 1/12 [14:43<2:41:58, 883.54s/it]

--------------------------------------------------
<startofstring> hey  <bot>:  <endofstring><|endoftext|>
--------------------------------------------------
epoch 2 batch 1 loss : 2.5716030597686768
epoch 2 batch 2 loss : 2.831345558166504
epoch 2 batch 3 loss : 2.5584511756896973
epoch 2 batch 4 loss : 2.516566514968872
epoch 2 batch 5 loss : 2.5286357402801514
epoch 2 batch 6 loss : 2.772916555404663
epoch 2 batch 7 loss : 2.4544241428375244
epoch 2 batch 8 loss : 2.253345012664795
epoch 2 batch 9 loss : 2.530212879180908
epoch 2 batch 10 loss : 2.76501727104187
epoch 2 batch 11 loss : 2.1326985359191895
epoch 2 batch 12 loss : 3.179919481277466
epoch 2 batch 13 loss : 2.4909563064575195
epoch 2 batch 14 loss : 2.133256673812866
epoch 2 batch 15 loss : 1.5712839365005493
epoch 2 batch 16 loss : 1.5368993282318115
epoch 2 batch 17 loss : 3.087360143661499
epoch 2 batch 18 loss : 2.49908447265625
epoch 2 batch 19 loss : 2.976935863494873
epoch 2 batch 20 loss : 3.557049036026001
epoch

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 17%|█▋        | 2/12 [28:17<2:20:24, 842.47s/it]

--------------------------------------------------
<startofstring> hey  <bot>:  <endofstring><|endoftext|>
--------------------------------------------------
epoch 3 batch 1 loss : 2.167609691619873
epoch 3 batch 2 loss : 2.477191209793091
epoch 3 batch 3 loss : 2.260333299636841
epoch 3 batch 4 loss : 2.1466963291168213
epoch 3 batch 5 loss : 2.1125950813293457
epoch 3 batch 6 loss : 2.1777217388153076
epoch 3 batch 7 loss : 2.039309024810791
epoch 3 batch 8 loss : 1.769187331199646
epoch 3 batch 9 loss : 2.0679585933685303
epoch 3 batch 10 loss : 2.3294169902801514
epoch 3 batch 11 loss : 1.768523931503296
epoch 3 batch 12 loss : 2.6785929203033447
epoch 3 batch 13 loss : 2.069905996322632
epoch 3 batch 14 loss : 1.6633081436157227
epoch 3 batch 15 loss : 1.351298451423645
epoch 3 batch 16 loss : 1.1858493089675903
epoch 3 batch 17 loss : 2.536935567855835
epoch 3 batch 18 loss : 1.9535653591156006
epoch 3 batch 19 loss : 2.466430902481079
epoch 3 batch 20 loss : 3.0759754180908203
e

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


--------------------------------------------------


 25%|██▌       | 3/12 [41:44<2:03:57, 826.34s/it]

<startofstring> hey  <bot>:   <bot>:  Hi, i am a huge gamer <endofstring><|endoftext|>
--------------------------------------------------
epoch 4 batch 1 loss : 1.781527042388916
epoch 4 batch 2 loss : 2.164747953414917
epoch 4 batch 3 loss : 1.9329111576080322
epoch 4 batch 4 loss : 1.81088125705719
epoch 4 batch 5 loss : 1.7368439435958862
epoch 4 batch 6 loss : 1.6519947052001953
epoch 4 batch 7 loss : 1.706878662109375
epoch 4 batch 8 loss : 1.495866298675537
epoch 4 batch 9 loss : 1.7077736854553223
epoch 4 batch 10 loss : 1.9112392663955688
epoch 4 batch 11 loss : 1.495490550994873
epoch 4 batch 12 loss : 2.2232449054718018
epoch 4 batch 13 loss : 1.7365000247955322
epoch 4 batch 14 loss : 1.3113024234771729
epoch 4 batch 15 loss : 1.1036690473556519
epoch 4 batch 16 loss : 0.9549606442451477
epoch 4 batch 17 loss : 2.069035768508911
epoch 4 batch 18 loss : 1.5575538873672485
epoch 4 batch 19 loss : 2.0563242435455322
epoch 4 batch 20 loss : 2.5653860569000244
epoch 4 batch 21 lo

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


--------------------------------------------------


 33%|███▎      | 4/12 [55:27<1:49:59, 824.91s/it]

<startofstring> hey  <bot>:   <bot>: <bot>:  Hi, how are doing? <endofstring><|endoftext|>
--------------------------------------------------
epoch 5 batch 1 loss : 1.5162222385406494
epoch 5 batch 2 loss : 1.8197280168533325
epoch 5 batch 3 loss : 1.678934097290039
epoch 5 batch 4 loss : 1.560239315032959
epoch 5 batch 5 loss : 1.4872900247573853
epoch 5 batch 6 loss : 1.366298794746399
epoch 5 batch 7 loss : 1.4529272317886353
epoch 5 batch 8 loss : 1.2559969425201416
epoch 5 batch 9 loss : 1.4105192422866821
epoch 5 batch 10 loss : 1.6551755666732788
epoch 5 batch 11 loss : 1.2214432954788208
epoch 5 batch 12 loss : 1.9017025232315063
epoch 5 batch 13 loss : 1.4584952592849731
epoch 5 batch 14 loss : 1.0310217142105103
epoch 5 batch 15 loss : 0.9085278511047363
epoch 5 batch 16 loss : 0.7953083515167236
epoch 5 batch 17 loss : 1.7287845611572266
epoch 5 batch 18 loss : 1.2755613327026367
epoch 5 batch 19 loss : 1.6845321655273438
epoch 5 batch 20 loss : 2.0785109996795654
epoch 5 ba

 33%|███▎      | 4/12 [1:01:43<2:03:26, 925.81s/it]


KeyboardInterrupt: 

In [37]:
print("infer from model : ")
while True:
  inp = input()
  if inp.lower() == 'exit':
    print("Terminating the program...")
    break

  print(infer(inp))

infer from model : 


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<startofstring> hey  <bot>:  <startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring><startofstring>
Terminating the program...
