<a href="https://colab.research.google.com/github/pranavkarnani/StoryGenerator/blob/pranav/GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! pip install transformers



In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
import torch.nn.functional as F
import torch.nn as nn
import csv

In [2]:
import pandas as pd
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, random_split, DataLoader, RandomSampler, SequentialSampler

In [3]:
from tqdm.auto import tqdm

In [4]:
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [5]:
RANDOM_SEED = 73
BATCH_SIZE = 1

EPOCHS = 4
SAMPLE_EVERY = 10000

MAX_INPUT_SEQUENCE_LENGTH = 600

In [6]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'pad_token': '<PAD>', 'sep_token': '<SEP>'}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)

In [7]:
data = pd.read_csv("/content/drive/MyDrive/refined.csv")

In [8]:
# data = data.dropna()
# data.to_csv('refined.csv')

In [9]:
len(tokenizer)

50261

In [55]:
data.loc[0, 'storyline']

'The pigs elevate themselves to positions of leadership and set aside special food items, ostensibly for their personal health. <SEP> However, the ideals which Snowball discussed, including stalls with electric lighting, heating and running water are forgotten, with Napoleon advocating that the happiest animals live simple lives. <SEP> Mr Frederick, one of the neighbouring farmers, attacks the farm, using blasting powder to blow up the restored windmill.'

In [56]:
class StoryOutlineDataset(Dataset):

    def __init__(self, data, tokenizer, max_input_length):

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        self.data = data
        self.labels_attn = []

        for i in tqdm(range(len(self.data))):
            text = self.data.loc[i, 'text']
            outline = self.data.loc[i, 'storyline'].split(' ')
            outline = " ".join(outline[:100]).replace("<SEP>", "")

            input = outline + "<SEP>" + text

            encodings_dict_story = tokenizer('<BOS> ' + input + ' <EOS>',
                                     truncation=True,
                                     max_length=max_input_length,
                                     padding=True
                                    )
            
            # encodings_dict_outline = tokenizer(outline,
            #                          truncation=True,
            #                          max_length=max_input_length,
            #                          padding='max_length'
            #                         )

            # self.input_ids.append(torch.tensor(encodings_dict_outline['input_ids']))
            # self.attn_masks.append(torch.tensor(encodings_dict_outline['attention_mask']))
            self.input_ids.append(torch.tensor(encodings_dict_story['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict_story['attention_mask']))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ind):
        return self.input_ids[ind], self.attn_masks[ind]

In [57]:
story_dataset = StoryOutlineDataset(data.loc[0:50], tokenizer, MAX_INPUT_SEQUENCE_LENGTH)

  0%|          | 0/51 [00:00<?, ?it/s]

In [58]:
from torch.utils.data import random_split

In [59]:
def train_val_split(split, dataset):
    train_size = int(split * len(dataset))
    val_size = len(dataset) - train_size
    return train_size, val_size

In [60]:
train_size, val_size = train_val_split(0.8, story_dataset)
train_dataset, val_dataset = random_split(story_dataset, [train_size, val_size])

In [61]:
torch.cuda.manual_seed_all(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7fbf903bec50>

In [62]:
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = True)

In [63]:
learning_rate = 5e-4
eps = 1e-8
warmup_steps = 100

In [66]:
tokenizer.encode("<SEP>")

[50260]

In [67]:
configuration = GPT2Config(vocab_size=len(tokenizer), n_positions = MAX_INPUT_SEQUENCE_LENGTH, 
                           activation_function = "gelu_new", resid_pdrop = 0.1, embd_pdrop = 0.2,
                           attn_pdrop = 0.2, output_attentions = True, output_hidden_states = True)

model_config = configuration.from_pretrained('gpt2', output_hidden_states=True)

In [68]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.config = model_config
model.resize_token_embeddings(len(tokenizer))

Embedding(50261, 768)

In [69]:
import time
import datetime
scaler = torch.cuda.amp.GradScaler()

In [70]:
def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [71]:
# model.resize_token_embeddings(len(tokenizer))

model.cuda()
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=eps)
total_steps = len(train_loader) * EPOCHS
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 100,
                                                                 T_mult = 3,
                                                                 eta_min = 1e-7)



In [72]:
mse_loss = nn.MSELoss()

In [163]:
def format_out_texts(text):
    t_map = tokenizer.special_tokens_map
    for key in t_map:
        text = text.replace(t_map[key], '')
    return text

def inference(val_loader):
    model.eval()

    for i, batch in enumerate(val_loader):

        if i % 100 == 0:
            lens = np.array([])
            input_ids = batch[0].numpy()
            attn_masks = batch[1].numpy()

            truncated_input = []
            truncated_attention_mask = []
            for i, input_id in enumerate(input_ids):
                context_index = np.where(input_id == 50260)[0][0]
                truncated_input.append(input_id[:context_index+1])
                truncated_attention_mask.append(attn_masks[i][:context_index+1])
                lens = np.append(lens, context_index+1)
    
            max_len = int(np.amax(lens))

            padded_tokens = []
            for tok_ids in truncated_input:
                
                padded_tokens.append(list(tok_ids) + [0] * (max_len - len(tok_ids)))
                
            padded_tokens = torch.LongTensor(padded_tokens).to(device)
            attn_mask = np.zeros(padded_tokens.shape)
            
            for ix, lengths in enumerate(lens):
                print(ix)
                print(lengths)
                attn_mask[ix][:int(lengths)] = 1

            attn_mask = torch.tensor(attn_mask).long().to(device)

    story_ids = model.generate(padded_tokens, attention_mask=attn_mask,
                            num_beams=5,
                            max_length=800,
                            temperature=0.9,
                            remove_invalid_values = True,
                            top_k=50,
                            do_sample=True)

    raw_stories = [tokenizer.decode(story) for story in story_ids]
    output_texts = list(map(format_out_texts, raw_stories))
    print(output_texts)
    return output_texts

In [164]:
# import ERLoss
# from ERLoss import get_er

In [165]:
def train(ep, train_loader):

    total_train_loss = 0

    for step, batch in enumerate(tqdm(train_loader)):

        model.train() 

        b_input_ids = batch[0]
        b_masks = batch[1].to(device)

        labels = b_input_ids.clone().numpy()

        for i, text in enumerate(b_input_ids.numpy()):
            context_index = np.where(text == 50260)[0][0]
            labels[i][:context_index+1] = -100

        model.zero_grad()

        b_input_ids = b_input_ids.to(device)
        labels = torch.tensor(labels).to(device)

        with torch.cuda.amp.autocast():
        
            outputs = model(b_input_ids,
                        attention_mask=b_masks,
                        labels = labels,
                        token_type_ids=None)
            
            

            # outputs_label = model(b_labels,
            #                 labels = b_labels, 
            #                 attention_mask = b_labels_mask,
            #                 token_type_ids = None)

            
            loss = outputs[0]

            # attention = torch.stack(outputs[3])[:,:,-1,:]
            # last_layer_attns_avg_over_heads = attention.mean(dim=0)

            # attention_target = torch.stack(outputs_label[3])[:,:,-1,:]
            # last_layer_attns_avg_over_heads_target = attention_target.mean(dim=0)

            # loss1 = mse_loss(last_layer_attns_avg_over_heads, last_layer_attns_avg_over_heads_target)

                # logits = outputs[1][:,0,:]

                # target_logit = torch.zeros(size=(BATCH_SIZE, 1, len(tokenizer)))

                # for story in range(BATCH_SIZE):
                #     target_logit[story, 0, b_labels[BATCH_SIZE, time]] = 1

                # nn.CrossEntropyLoss(logits, target_logit)

            # story_logits = torch.argmax(logits, dim = 2)
            
            # actual_stories = [tokenizer.decode(story) for story in b_labels]
            # raw_stories = [tokenizer.decode(story) for story in story_logits]

        batch_loss = loss

        # for i in range(len(raw_stories)):

        #     er_target = get_er(actual_stories[i])
        #     er_generate = get_er(raw_stories[i])

        #     target = torch.FloatTensor().cuda()
        #     inp = torch.FloatTensor().cuda()

        #     for token in tokenizer.encode(er_target):
        #         target = torch.cat((target, model.transformer.wte.weight[token].unsqueeze(0)), dim = 0)

        #     for token in tokenizer.encode(er_generate):
        #         inp = torch.cat((inp, model.transformer.wte.weight[token].unsqueeze(0)), dim = 0)

        #     if inp.shape[0] < target.shape[0]:
        #         for i in range(target.shape[0] - inp.shape[0]):
        #             inp = torch.cat((inp, model.transformer.wte.weight[50259].unsqueeze(0)), dim = 0)

        #     else:
        #         for i in range(inp.shape[0] - target.shape[0]):
        #             target = torch.cat((target, model.transformer.wte.weight[50259].unsqueeze(0)), dim = 0)

        #     loss1 += mse_loss(torch.flatten(inp), torch.flatten(target))

        # if ep == 1:
        #     batch_loss = 0.7*loss + 0.3*loss1
        # elif ep >= 2:
        #     batch_loss = 0.5*loss + 0.5*loss1
        # else:

        total_train_loss += batch_loss
        scaler.scale(batch_loss).backward() 
        scaler.step(optimizer) 
        scaler.update()
        scheduler.step()

    avg_train_loss = total_train_loss / len(train_loader)       

    print(f'Average Training Loss: {avg_train_loss}.')


def validate(val_dataloader, file_name):

    model.eval()
    total_eval_loss = 0

    for idx, batch in enumerate(val_dataloader):
        b_input_ids = batch[0]
        b_masks = batch[1].to(device)

        labels = b_input_ids.clone().numpy()

        for i, text in enumerate(b_input_ids.numpy()):
            context_index = np.where(text == 50260)[0][0]
            labels[i][:context_index+1] = -100

        b_input_ids = b_input_ids.to(device)
        labels = torch.tensor(labels).to(device)

        with torch.no_grad():        
            outputs  = model(b_input_ids,  
                            attention_mask=b_masks,
                            labels=labels)

            loss = outputs[0]

        batch_loss = loss
        total_eval_loss += batch_loss   

    avg_val_loss = total_eval_loss / len(val_dataloader)
    inference(val_dataloader)

    print(f'Validation loss: {avg_val_loss}.')
    torch.save(model.state_dict(), '/content/' + file_name)
    return model

In [166]:
for epoch_i in range(0, EPOCHS):
    print(f'Epoch {epoch_i + 1} of {EPOCHS}')
    train(epoch_i, train_loader)
    validate(val_loader, '/drive/MyDrive/model.pth')

Epoch 1 of 4


  0%|          | 0/40 [00:00<?, ?it/s]

Average Training Loss: 0.06558763235807419.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0
51.0
[' Serge learns that the time research facility Chronopolis created El Nido thousands of years ago after a catastrophic experimental failure drew it to the past.  Lynx travels with Harle, a mysterious, playful girl dressed like a harlequin.  In Los Angeles in November 2019, ex-police officer Rick Deckard is detained by officer Gaff and brought to his former supervisor, Bryant.\nDeckard, whose job as a "Blade Runner" was to track down bioengineered beings known as replicas and "retire" (a euphemism for killing) them, is informed that four have come to Earth illegally.\nAs Tyrell Corporation Nexus-6 models, they have only a four-year lifespan and may have come to Earth to try to extend their lives.\nDeckard watches a video of a Blade Runner named Holden administering the "Voight-Kampff" test designed to distinguish replicas from humans based on their emotional response to questions.\nThe test subject, Holden, shoots Holden after Holden asks about Leon\'s mother.\nBryant wants Deck

  0%|          | 0/40 [00:00<?, ?it/s]

Average Training Loss: 0.048859477043151855.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0
119.0
[" The effectiveness of the technique is demonstrated to a group of VIPs, who watch as Alex collapses before a bully and abases himself before a scantily-clad young woman whose presence has aroused his predatory sexual inclinations.  Two years into his term, he has obtained a job in one of the prison chapels playing religious music on the stereo to accompany the Sunday religious services.  Two policemen come to Alex's rescue, but turn out to be Dim and Billyboy, a former rival gang leader.  The technique is a form of aversion therapy, in which Alex is injected with  -of, anesthetic agents and sent to kill him.\nHe is revived by the aid of Erica Burgoyne, daughter of the local police Chief Constable.\nA group of thugs, led by her backup singers, are seen parading through the streets and pushing white residents aside on the sidewalks.\nBlack occupation soldiers are seen parading through the streets and pushing white residents aside on the sidewalks.\nA group of thugs, led by her 

  0%|          | 0/40 [00:00<?, ?it/s]

Average Training Loss: 0.031587790697813034.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0
51.0


KeyboardInterrupt: ignored

In [41]:
a = 'Creatures like us that can anticipate possible futures and make contingency plans have an evolutionary advantage. '

In [None]:
torch.save(model, '/content/drive/MyDrive/model.pth')

In [42]:
encodings_dict_outline = tokenizer(a,
                truncation=True,
                max_length=MAX_INPUT_SEQUENCE_LENGTH,
                padding='max_length')

In [43]:
story_ids = model.generate(torch.tensor([encodings_dict_outline['input_ids']]).cuda(),
                            attention_mask = torch.tensor([encodings_dict_outline['attention_mask']]).cuda(),
                            num_beams=20,
                            max_length=800,
                            temperature=0.9,
                            top_k=50,
                            do_sample=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [47]:
tokenizer.decode(story_ids[0])

'Creatures like us that can anticipate possible futures and make contingency plans have an evolutionary advantage.  <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <P