In [None]:
import glob
import math
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
tqdm.pandas()  # registers tqdm's progress_apply on pandas objects (used by the commented-out max_len scan)

In [None]:
class BrainyQuoteDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one quote per item for causal-LM fine-tuning.

    Each item is the quote followed by the tokenizer's EOS token. It is padded
    (and truncated if necessary) to exactly ``max_len`` tokens.

    Args:
        text_list: indexable sequence of quote strings.
        tokenizer: a Hugging Face tokenizer exposing ``encode_plus`` and
            ``eos_token``. Its ``pad_token`` is overwritten in place with
            ``eos_token`` so that ``padding="max_length"`` has a token to pad with.
        max_len: fixed token length of every returned item.
    """

    def __init__(self, text_list, tokenizer, max_len):
        self.text_list = text_list
        self.tokenizer = tokenizer
        # GPT-2 tokenizers have no pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_len = max_len

    def __getitem__(self, index):
        """Return the encoding of quote ``index`` as a dict of 1-D tensors of length ``max_len``."""
        text = self.text_list[index]
        encoding = self.tokenizer.encode_plus(
            text,
            self.tokenizer.eos_token,  # second segment: makes every quote end with EOS
            padding="max_length",
            truncation=True,  # without this, quotes longer than max_len break batch collation
            max_length=self.max_len,
            return_tensors="pt",
        )
        # return_tensors='pt' adds a leading batch dim of size 1; drop it so the
        # DataLoader collates to (batch_size, max_len) rather than (batch_size, 1, max_len).
        return {key: value.squeeze(0) for key, value in encoding.items()}

    def __len__(self):
        return len(self.text_list)

In [None]:
# Load every topic CSV into one frame and hold out 20% of the quotes for evaluation.
DATA_GLOB = "../input/brainyquote-topics/*.csv"
SEED = 42  # fixes the train/test split and torch's global RNG (DataLoader shuffling, dropout)
torch.manual_seed(SEED)

# glob returns files in arbitrary order; sort so the concatenation (and thus the split) is reproducible.
csv_paths = sorted(glob.glob(DATA_GLOB))
if not csv_paths:
    raise FileNotFoundError(f"No CSV files matched {DATA_GLOB!r}")

df = (
    pd.concat([pd.read_csv(path) for path in csv_paths], ignore_index=True)  # unique RangeIndex across files
    .dropna(subset=['title'])  # the tokenizer needs strings; missing quote text would crash it
    .reset_index(drop=True)
)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)

In [None]:
# Tokenizer and LM-head model must come from the same checkpoint name.
PRETRAINED_NAME = 'distilgpt2'
tokenizer = GPT2Tokenizer.from_pretrained(PRETRAINED_NAME)
model = GPT2LMHeadModel.from_pretrained(PRETRAINED_NAME)

In [None]:
# Optional one-off scan: records the longest tokenized quote in df['title'] in the
# global max_len. NOTE(review): presumably used to choose max_len in the next
# cell; confirm against the data. Tokenizes every row, so it can be slow.
# Uncomment to re-run.
# max_len = 0
# def find_max_len(text):
#     global max_len
#     curr_len = len(tokenizer.encode(text))
#     if curr_len > max_len:
#         max_len = curr_len
# df['title'].progress_apply(find_max_len)

In [None]:
max_len = 128      # token length every quote is padded to
batch_size = 8
num_workers = 4
train_dataset = BrainyQuoteDataset(train_df['title'].tolist(), tokenizer, max_len)
test_dataset = BrainyQuoteDataset(test_df['title'].tolist(), tokenizer, max_len)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# The evaluation loss is summed over all batches, so order does not matter; keep it deterministic.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Biases and LayerNorm parameters should not be weight-decayed. GPT-2 names its
# LayerNorms ln_1 / ln_2 / ln_f; the "LayerNorm.*" entries are BERT naming and
# match nothing here. Without "ln_", every GPT-2 LayerNorm weight was decayed.
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight", "ln_"]

# Split parameters into two groups in a single pass: with weight decay, and without.
decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    target = no_decay_params if any(nd in name for nd in no_decay) else decay_params
    target.append(param)

optimizer_parameters = [
    {'params': decay_params, 'weight_decay': 0.001},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
epochs = 5
# torch.optim.AdamW replaces the deprecated transformers AdamW. eps=1e-6 keeps
# that implementation's default epsilon.
optimizer = torch.optim.AdamW(optimizer_parameters, lr=5e-5, eps=1e-6)

num_training_steps = len(train_dataloader) * epochs
# NOTE(review): warming up over half of all steps is unusually long (0-10% is
# typical); value kept from the original.
warmup_fraction = 0.5
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=math.floor(num_training_steps * warmup_fraction),
    num_training_steps=num_training_steps,
)

In [None]:
class AverageMeter:
    """Track the latest value and a running weighted mean.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n``.
        count: running total of the weights ``n``.
        avg: ``sum / count``, refreshed on every ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch-mean loss and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def _lm_labels(input_ids, attention_mask):
    """Use input_ids as causal-LM labels, with padded positions set to -100 so the loss ignores them."""
    # Pad == EOS here, so without masking the model is also trained to predict
    # EOS after EOS across the padding. The one real trailing EOS has mask 1 and
    # is kept, so the model still learns to end a quote.
    return input_ids.masked_fill(attention_mask == 0, -100)


def _evaluate(model, dataloader, device):
    """Return the sum of per-batch LM losses of ``model`` over ``dataloader``, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            out = model(input_ids, labels=_lm_labels(input_ids, attention_mask), attention_mask=attention_mask)
            total += out[0].item()
    return total


best_loss = float('inf')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loss_ls = []  # per-step training loss, plotted in the next cell
avg_meter = AverageMeter()
model.to(device)
for epoch in range(epochs):
    model.train()
    train_loss = 0
    avg_meter.reset()
    progress = tqdm(train_dataloader)
    for data in progress:
        optimizer.zero_grad()
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        out = model(input_ids, labels=_lm_labels(input_ids, attention_mask), attention_mask=attention_mask)
        loss = out[0]
        loss.backward()
        optimizer.step()
        scheduler.step()
        step_loss = loss.item()
        train_loss += step_loss
        train_loss_ls.append(step_loss)
        avg_meter.update(step_loss, input_ids.shape[0])
        progress.set_postfix({'loss': avg_meter.avg})
    test_loss = _evaluate(model, test_dataloader, device)
    # Summed losses are comparable across epochs because the test loader has a fixed batch count.
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(model.state_dict(), 'best_brainyquotegpt2.pth')
    print(f"epoch: {epoch} train loss: {train_loss/len(train_dataloader)} test loss: {test_loss/len(test_dataloader)}")

In [None]:
# Per-step (not per-epoch) training loss; expect noise, so read the trend.
fig, ax = plt.subplots()
ax.plot(train_loss_ls)
ax.set(title="Training loss per optimizer step", xlabel="optimizer step", ylabel="LM loss")
plt.show()

In [None]:
# Sample one continuation of the prompt from the fine-tuned model.
model.eval()
prompt = "Inspiration is"
with torch.no_grad():
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,  # explicit mask instead of letting generate() guess
        pad_token_id=tokenizer.eos_token_id,    # GPT-2 has no pad token; same choice as the dataset's padding
        max_length=1024,  # hard cap; generation stops earlier if EOS is sampled
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Persist the final-epoch weights; the best-by-test-loss weights are saved separately during training.
final_weights_path = 'brainyquotegpt2.pth'
torch.save(model.state_dict(), final_weights_path)