In [2]:
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AdamW,  get_linear_schedule_with_warmup

import numpy as np
from tqdm.notebook import tqdm

In [3]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

# Dataset

In [4]:
# Decode explicitly as UTF-8: the data is scraped Reddit text and may contain
# non-ASCII characters, while open()'s default encoding is platform-dependent
# (e.g. cp1252 on Windows), which can raise UnicodeDecodeError.
with open("dataset/reddit_fixed.json", "r", encoding="utf-8") as read_file:
    jokes = json.load(read_file)

In [5]:
# Wrap every joke body in literal start/end markers so the model can learn where a
# joke begins and ends. These are plain text (no tokenizer.add_special_tokens call
# registers them), so the tokenizer splits them into ordinary sub-word pieces.
# The guard makes this cell idempotent: re-running it does not stack a second pair
# of markers onto bodies that were already wrapped.
for joke in jokes:
    body = joke['body']
    if not (body.startswith("<SOS>") and body.endswith("<EOS>")):
        joke['body'] = "<SOS>{0}<EOS>".format(body)

In [6]:
# Sanity check: inspect one record (keys, wrapped body, score) via rich display.
jokes[0]

{'body': '<SOS>I hate how you cant even say black paint anymore\nNow I have to say "Leroy can you please paint the fence?"<EOS>', 'id': '5tz52q', 'score': 1}


In [7]:
class JokesDataset(Dataset):
    """Map-style dataset over a list of joke records.

    Each record must be a mapping with at least the keys 'body' and 'score';
    indexing returns the pair ``(body, score)``.
    """

    def __init__(self, jokes):
        # Keep a reference to the caller's list (no copy is made).
        self.jokes = jokes

    def __len__(self):
        return len(self.jokes)

    def __getitem__(self, idx):
        record = self.jokes[idx]
        body, score = record['body'], record['score']
        return body, score

In [8]:
# Fixed seed for the shuffle so the order of training examples is reproducible
# across kernel restarts.
SEED = 42

dataset = JokesDataset(jokes)
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    # RandomSampler draws its permutation from this generator instead of the
    # global RNG, so the data order no longer changes run to run.
    generator=torch.Generator().manual_seed(SEED),
)

# Training

In [9]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
# GPT-2 ships without a pad token; reuse EOS so batched tokenization can pad.
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)
model.train()

NUM_EPOCHS = 100        # must match the number of epochs the training loop runs
WARMUP_STEPS = 10000
LEARNING_RATE = 1e-4

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to the transformers implementation's defaults so the
# update rule is the same as before.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

# The scheduler is stepped once per batch, so the schedule length is
# epochs * batches-per-epoch. Passing -1 (as before) makes the post-warmup decay
# factor max(0, (-1 - step) / 1) == 0, i.e. the learning rate is frozen at 0
# for all training after warmup.
total_training_steps = NUM_EPOCHS * len(data_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_training_steps,
)


In [None]:
for epoch in range(100):  # total epochs; the LR schedule's num_training_steps must cover all of them
    # Accumulate on-device to avoid a host sync every step; read back once per epoch.
    epoch_loss = torch.zeros((), device=device)
    # Name the batch `batch_jokes`, not `jokes`: reusing `jokes` would overwrite the
    # full joke list loaded earlier and break any re-run of the earlier cells.
    # Scores are not used by the language-modelling objective.
    for batch_jokes, _scores in tqdm(data_loader, desc=f'epoch {epoch}'):
        encoding = tokenizer(batch_jokes, return_tensors='pt', padding=True, truncation=True)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        # Causal-LM fine-tuning: GPT2LMHeadModel shifts labels internally, so the
        # labels are the inputs themselves. Padding positions are set to -100 so the
        # loss ignores them (only matters once batch_size > 1).
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # the LR schedule is defined per batch
        # Every model parameter is registered with the optimizer, so this clears all grads.
        optimizer.zero_grad()
        epoch_loss += loss.detach()
    print(f'Epoch {epoch}: average loss {epoch_loss.item() / len(data_loader):.4f}')

HBox(children=(FloatProgress(value=0.0, max=194553.0), HTML(value='')))