<a href="https://colab.research.google.com/github/vivek09thakur/ELSA/blob/main/Colab%20Notebook/Enhance_Lang_Seq_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from torch.utils.data import Dataset
import json

class TextData(Dataset):
    """Dataset of tokenized prompt/response pairs built from a dialogue JSON file.

    The JSON file is expected to hold a list of dialogue objects, each with a
    list of turns under the key ``'dialog'`` (the misspelled ``'dailog'`` is
    also accepted) where every turn has a ``'text'`` field. Consecutive turns
    are joined into training strings of the form
    ``"<sos>" + turn + "<model>" + next_turn + "<eos>"``.

    Args:
        path: Path of the JSON file to load.
        tokenizer: Callable invoked as
            ``tokenizer(texts, max_length=..., padding=True, truncation=True,
            return_tensors='pt')`` that returns a mapping with ``'input_ids'``
            and ``'attention_mask'`` entries.
        print_data: When true, print the first formatted sample (if any).
        max_samples: Maximum number of formatted samples to keep.
        max_length: Maximum token length handed to the tokenizer.

    Raises:
        KeyError: If a dialogue object has neither a ``'dialog'`` nor a
            ``'dailog'`` key, or a turn has no ``'text'`` key.
    """

    def __init__(self,path:str,tokenizer,print_data=False,max_samples=5000,max_length=50):
        # Use a context manager so the file handle is closed deterministically.
        with open(path, 'r', encoding='utf-8') as fh:
            self.data = json.load(fh)

        utterances = []
        for dialogue in self.data:
            # NOTE(review): the original code only read the misspelled key
            # 'dailog'; the correctly spelled 'dialog' is tried first and the
            # old spelling is kept as a fallback. Confirm against the data file.
            turns = dialogue['dialog'] if 'dialog' in dialogue else dialogue['dailog']
            utterances.extend(turn['text'] for turn in turns)

        # Pair every utterance with the one that follows it. The final
        # utterance has no reply, so it produces no sample (previously it was
        # left in the dataset as raw, unformatted text).
        # NOTE(review): pairs are built over the flattened list, so the last
        # turn of one dialogue is paired with the first turn of the next.
        self.X = [
            "<sos>" + prompt + "<model>" + reply + "<eos>"
            for prompt, reply in zip(utterances, utterances[1:])
        ][:max_samples]

        if print_data and self.X:
            print(self.X[0])

        self.X_encoded = tokenizer(
            self.X,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        """Return the number of formatted samples."""
        return len(self.X)

    def __getitem__(self,idx):
        """Return the ``(input_ids, attention_mask)`` pair for sample ``idx``."""
        return self.input_ids[idx],self.attention_mask[idx]

In [2]:
import tqdm
import torch

class GPT2FineTuner:
    """Small helper that fine-tunes a causal language model and chats with it.

    All constructor arguments are optional so ``GPT2FineTuner()`` keeps
    working; anything not supplied here can be assigned as an attribute later
    or is filled in by :meth:`train`.

    Args:
        model: Model used by :meth:`predict`; overwritten by :meth:`train`.
        tokenizer: Callable tokenizer with a ``decode`` method, used by
            :meth:`predict`.
        device: Device tensors are moved to. When ``None`` it is taken from
            the model's parameters on first use.
        optimizer: Optimizer used by :meth:`train` when none is passed there.
    """

    def __init__(self, model=None, tokenizer=None, device=None, optimizer=None):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.optimizer = optimizer

    def _resolve_device(self, model):
        """Return ``self.device``, deriving it from ``model`` if it is unset."""
        if self.device is None:
            self.device = next(model.parameters()).device
        return self.device

    def train(self,textData,model,epoch,save_model=False,save_path=None,optimizer=None):
        """Fine-tune ``model`` on ``textData`` for ``epoch`` passes.

        Args:
            textData: Iterable yielding ``(input_ids, attention_mask)`` tensors.
            model: Model called as ``model(X, attention_mask=a, labels=X)``
                whose result exposes a ``.loss`` tensor.
            epoch: Number of passes over ``textData``.
            save_model: When true, write ``model.state_dict()`` to
                ``save_path`` after every epoch.
            save_path: Destination for the saved state dict.
            optimizer: Optimizer to use. Falls back to the instance's
                optimizer, then to ``Adam(model.parameters(), lr=1e-3)``.

        Raises:
            ValueError: If ``save_model`` is true but ``save_path`` is ``None``.
        """
        # Validate up front instead of after the first batch has been trained.
        if save_model and save_path is None:
            raise ValueError("Please provide a save path")

        self.model = model
        device = self._resolve_device(model)

        if optimizer is None:
            optimizer = self.optimizer
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        self.optimizer = optimizer

        # Pretrained models may be in eval mode; enable training behaviour.
        model.train()

        for _ in tqdm.tqdm(range(epoch)):
            for X,a in textData:
                X = X.to(device)
                a = a.to(device)
                # zero_grad() returns None, so its result must not be
                # assigned back to the optimizer variable.
                optimizer.zero_grad()
                loss = model(X,attention_mask=a,labels=X).loss
                loss.backward()
                optimizer.step()

            # Checkpoint once per epoch rather than after every batch.
            if save_model:
                torch.save(model.state_dict(),save_path)

    def predict(self,text):
        """Generate a completion for ``text`` and return the decoded sequence.

        The prompt is wrapped as ``"<sos>" + text + "<model>"``; the returned
        string is the decoding of the whole generated sequence, prompt included.

        Raises:
            RuntimeError: If no model or tokenizer has been set on the instance.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "predict() needs both a model and a tokenizer; pass them to "
                "GPT2FineTuner(...) or assign .model / .tokenizer first"
            )

        device = self._resolve_device(self.model)
        self.model.eval()

        user_input = "<sos>" + text + "<model>"
        user_input_encoded = self.tokenizer(user_input,return_tensors='pt')
        X = user_input_encoded['input_ids'].to(device)
        a = user_input_encoded['attention_mask'].to(device)
        completion = self.model.generate(X,attention_mask=a,max_length=50)
        return self.tokenizer.decode(completion[0])

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Main run: fine-tune GPT-2 on the dialogue data, then chat with the tuned model**

In [12]:
from transformers import GPT2Tokenizer,GPT2LMHeadModel
from torch.utils.data import DataLoader
from torch.optim import Adam

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Register the markers used in the training strings so each is a single token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token':'<pad>','bos_token':'<sos>','eos_token':'<eos>'})
tokenizer.add_tokens(['<model>'])

# The embedding matrix must grow to cover the tokens added above.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
model.to(device)

textdata = TextData('/content/data.json',tokenizer,print_data=False)
# `ignore_index` is not a DataLoader argument and has been removed.
textDataLoader = DataLoader(textdata, batch_size=64)

tuned_model = GPT2FineTuner()
# predict() reads tokenizer, model and device from the instance, so attach them.
tuned_model.tokenizer = tokenizer
tuned_model.model = model
tuned_model.device = device

optim = Adam(model.parameters(), lr=1e-3)
# NOTE(review): the optimizer is stored on the tuner but is not passed to
# train() explicitly; verify that train() actually uses it.
tuned_model.optimizer = optim

# Train on the batched loader (previously the unbatched dataset was passed
# and the loader was never used).
tuned_model.train(textData=textDataLoader, model=model, epoch=100, save_model=True, save_path='saved_model.pt')

# Simple chat loop; Ctrl-C / end of input leaves it cleanly.
while True:
    try:
        text = input("<user> ")
    except (EOFError, KeyboardInterrupt):
        break
    completion = tuned_model.predict(text)
    print("<model> " + completion)

KeyError: ignored