In [11]:
import json
import os
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.nn import functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import AdamW,  get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
import numpy as np
from tqdm.notebook import tqdm

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# Dataset

In [5]:
with open("dataset/reddit_fixed.json", "r") as read_file:
    jokes = json.load(read_file)

In [6]:
jokes_5_score = []
for joke in jokes:
    if joke['score'] == 5:
        jokes_5_score.append(joke)

In [32]:
class JokesDataset(Dataset):
    def __init__(self, jokes, tokenizer, max_length):

        self.jokes = jokes
        self.input_ids = []
        self.attn_masks = []

        for joke in jokes:

            encodings_dict = tokenizer('<SOS> '+ joke['body'] + ' <EOS>', truncation=True, max_length=max_length, padding='max_length')

            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
        
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx] 

In [33]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<SOS>', eos_token='<EOS>')
tokenizer.pad_token = tokenizer.eos_token

dataset = JokesDataset(jokes_5_score, tokenizer, max_length=768)

train_idx, valid_idx= train_test_split(
np.arange(len(dataset)),
test_size=0.1,
shuffle=True)

train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

dataloaders = {'train': torch.utils.data.DataLoader(dataset, batch_size=4, sampler=train_sampler),
'val': torch.utils.data.DataLoader(dataset, batch_size=2,sampler=valid_sampler)}

dataset_sizes = {'train': len(train_idx), 'val': len(valid_idx)}

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


# Training

In [66]:
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained('gpt2', config=configuration)
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=5e-4, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10000, num_training_steps = -1)


In [67]:
for epoch in range(3):
    epoch_loss_train = 0
    model.train()
    for i, batch in enumerate(tqdm(dataloaders['train'])):  
        input_ids = batch[0].to(device)
        labels = batch[0].to(device)
        masks = batch[1].to(device)

        outputs = model(input_ids, labels=labels, attention_mask=masks)
        
        loss = outputs[0]  

        batch_loss = loss.item()
        epoch_loss_train += batch_loss

        loss.backward()

        optimizer.step()

        scheduler.step() 
    
    model.eval()
    epoch_loss_val = 0

    for batch in tqdm(dataloaders['val']):
        input_ids = batch[0].to(device)
        labels = batch[0].to(device)
        masks = batch[1].to(device)

        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=masks)
        
            loss = outputs[0]

        batch_loss = loss.item()
        epoch_loss_val += batch_loss 

    print('Average train loss: {}'.format(epoch_loss_train/len(dataloaders['train'])))
    print('Average val loss: {}'.format(epoch_loss_val/len(dataloaders['val'])))
    torch.save(model.state_dict(), '/content/drive/My Drive/SavedModels/GPT-2_jokes3.h5')

HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=297.0), HTML(value='')))


Average train loss: 0.8900212166325442
Average val loss: 0.5970960970378484


HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=297.0), HTML(value='')))


Average train loss: 0.5768181447783063
Average val loss: 0.7118179524185682


HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=297.0), HTML(value='')))


Average train loss: 0.7828975106621611
Average val loss: 0.9642804497821564


In [None]:
state_dict = torch.load('/content/drive/My Drive/SavedModels/GPT-2_jokes3.h5')
model.load_state_dict(state_dict)

In [65]:
model.eval()
generated = torch.tensor(tokenizer.encode('<SOS>')).unsqueeze(0)
generated = generated.to(device)

print(generated)

sample_outputs = model.generate(
                                generated, 
                                do_sample=True,   
                                top_k=50, 
                                max_length = 768,
                                top_p=0.95, 
                                num_return_sequences=3
                                )

for i, sample_output in enumerate(sample_outputs):
    print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True).replace('\n',' ')))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[50257]], device='cuda:0')
0: What did the blind man say to the deaf man? You're not going anywhere.


1: If you're not a fan of a certain type of music, you're probably not a fan of me.


2: Did you hear about the new iPhone? It's called the "iPhone 7."


