<a href="https://colab.research.google.com/github/pranavkarnani/StoryGenerator/blob/pranav/GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install transformers



In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
import torch.nn.functional as F
import torch.nn as nn
import csv

In [3]:
import pandas as pd
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, random_split, DataLoader, RandomSampler, SequentialSampler

In [4]:
from tqdm.auto import tqdm

In [5]:
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [6]:
RANDOM_SEED = 73
BATCH_SIZE = 1

EPOCHS = 4
SAMPLE_EVERY = 100

MAX_INPUT_SEQUENCE_LENGTH = 800

In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'pad_token': '<PAD>', 'sep_token': '<SEP>'}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)

In [8]:
data = pd.read_csv("/content/drive/MyDrive/refined.csv")

In [9]:
# data = data.dropna()
# data.to_csv('refined.csv')

In [10]:
len(tokenizer)

50261

In [11]:
class StoryOutlineDataset(Dataset):

    def __init__(self, data, tokenizer, max_input_length):

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        self.data = data

        for i in tqdm(range(len(self.data))):
            text = self.data.loc[i, 'text']
            outline = self.data.loc[i, 'storyline']

            encodings_dict_story = tokenizer('<BOS> ' + text + ' <EOS>',
                                     truncation=True,
                                     max_length=max_input_length,
                                     padding='max_length'
                                    )
            
            encodings_dict_outline = tokenizer('<BOS> ' + outline + ' <EOS>',
                                     truncation=True,
                                     max_length=max_input_length,
                                     padding='max_length'
                                    )

            self.input_ids.append(torch.tensor(encodings_dict_outline['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict_outline['attention_mask']))
            self.labels.append(torch.tensor(encodings_dict_story['input_ids']))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ind):
        return self.input_ids[ind], self.attn_masks[ind], self.labels[ind]

In [32]:
story_dataset = StoryOutlineDataset(data, tokenizer, MAX_INPUT_SEQUENCE_LENGTH)

  0%|          | 0/108828 [00:00<?, ?it/s]

In [33]:
from torch.utils.data import random_split

In [34]:
def train_val_split(split, dataset):
    train_size = int(split * len(dataset))
    val_size = len(dataset) - train_size
    return train_size, val_size

In [35]:
train_size, val_size = train_val_split(0.8, story_dataset)
train_dataset, val_dataset = random_split(story_dataset, [train_size, val_size])

In [36]:
torch.cuda.manual_seed_all(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7fd0615d26f0>

In [37]:
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = True)

In [38]:
learning_rate = 5e-4
eps = 1e-8
warmup_steps = 100

In [39]:
configuration = GPT2Config(vocab_size=len(tokenizer), n_positions = MAX_INPUT_SEQUENCE_LENGTH, 
                           activation_function = "gelu_new", resid_pdrop = 0.1, embd_pdrop = 0.2,
                           attn_pdrop = 0.2, eos_token_id = 50256, pad_token_id = 50256)

In [40]:
model_config = configuration.from_pretrained('gpt2', output_hidden_states=True)

In [41]:
import time
import datetime
scaler = torch.cuda.amp.GradScaler()

In [42]:
def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [43]:
model = GPT2LMHeadModel.from_pretrained('gpt2', config=model_config)
model.resize_token_embeddings(len(tokenizer))

model.cuda()
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=eps)

total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)



In [44]:
def format_out_texts(text):
    t_map = tokenizer.special_tokens_map
    for key in t_map:
        text = text.replace(t_map[key], '')
    return text

def inference(input_id, attn_mask ,tokenizer):
    model.eval()

    story_ids = model.generate(input_id,
                            attention_mask = attn_mask,
                            num_beams=20,
                            max_length=800,
                            temperature=0.9,
                            top_k=50,
                            do_sample=True)
    
    raw_stories = [tokenizer.decode(story) for story in story_ids]
    output_texts = list(map(format_out_texts, raw_stories))
    print(output_texts)
    return output_texts

In [45]:
import ERLoss
from ERLoss import get_er

In [46]:
mse_loss = nn.MSELoss()

In [48]:
def train(train_loader):

    total_train_loss = 0

    for step, batch in enumerate(tqdm(train_loader)):
        
        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()
        model.train()        

        with torch.cuda.amp.autocast():
            
            outputs = model(b_input_ids,
                            labels=b_labels, 
                            attention_mask=b_masks,
                            token_type_ids=None)

            
            loss = outputs[0]
            logits = outputs[1]

            story_logits = torch.argmax(logits, dim = 2)
            
            actual_stories = [tokenizer.decode(story) for story in b_labels]
            raw_stories = [tokenizer.decode(story) for story in story_logits]


        loss1 = 0

        for i in range(len(raw_stories)):

            er_target = get_er(actual_stories[i])
            er_generate = get_er(raw_stories[i])

            target = torch.FloatTensor().cuda()
            inp = torch.FloatTensor().cuda()

            for token in tokenizer.encode(er_target):
                target = torch.cat((target, model.transformer.wte.weight[token].unsqueeze(0)), dim = 0)

            for token in tokenizer.encode(er_generate):
                inp = torch.cat((inp, model.transformer.wte.weight[token].unsqueeze(0)), dim = 0)

            if inp.shape[0] < target.shape[0]:
                for i in range(target.shape[0] - inp.shape[0]):
                    inp = torch.cat((inp, model.transformer.wte.weight[50259].unsqueeze(0)), dim = 0)

            else:
                for i in range(inp.shape[0] - target.shape[0]):
                    target = torch.cat((target, model.transformer.wte.weight[50259].unsqueeze(0)), dim = 0)

            loss1 += mse_loss(torch.flatten(inp), torch.flatten(target))

        if i == 1:
            batch_loss = 0.7*loss + 0.3*loss1
        elif i >= 2:
            batch_loss = 0.5*loss + 0.5*loss1
        else:
            batch_loss = loss

        total_train_loss += batch_loss

        if step % SAMPLE_EVERY == 0 and step != 0:
            inference(b_input_ids, b_masks, tokenizer)

        scaler.scale(batch_loss).backward() 
        scaler.step(optimizer) 
        scaler.update()

    avg_train_loss = total_train_loss / len(train_loader)       

    print(f'Average Training Loss: {avg_train_loss}.')


def validate(val_dataloader, file_name):

    model.eval()
    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in val_dataloader:
        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        b_labels = batch[2].to(device)

        with torch.no_grad():        

            outputs  = model(b_input_ids,  
                                attention_mask=b_masks,
                                labels=b_labels)

            loss = outputs[0]

        batch_loss = loss.item()
        total_eval_loss += batch_loss        

    avg_val_loss = total_eval_loss / len(val_dataloader)

    print(f'Validation loss: {avg_val_loss}.')
    torch.save(model.state_dict(), '/content/' + file_name)
    return model

In [49]:
for epoch_i in range(0, EPOCHS):
    print(f'Epoch {epoch_i + 1} of {EPOCHS}')
    train(train_loader)
    validate(val_loader, 'model.pth')

Epoch 1 of 4


  0%|          | 0/87062 [00:00<?, ?it/s]

TypeError: ignored

In [123]:
import spacy
from spacy import displacy

from spacy.matcher import Matcher 
from spacy.tokens import Span 

import nltk
nltk.download("punkt")

from nltk.tokenize import word_tokenize

nlp = spacy.load('en_core_web_sm')

def get_entity_pairs(sentences):
    entity_pairs = []
    for i in sentences:
        entity_pairs.append(get_entities(i))
    return entity_pairs

def get_entities(sent):

  ent1 = ""
  ent2 = ""

  prv_tok_dep = ""    
  prv_tok_text = ""   

  prefix = ""
  modifier = ""
  
  for tok in nlp(sent):
    
    if tok.dep_ != "punct":
      
      if tok.dep_ == "compound":
        prefix = tok.text
        
        if prv_tok_dep == "compound":
          prefix = prv_tok_text + " "+ tok.text
      

      if tok.dep_.endswith("mod") == True:
        modifier = tok.text

        if prv_tok_dep == "compound":
          modifier = prv_tok_text + " "+ tok.text
      
      if tok.dep_.find("subj") == True:
        ent1 = modifier +" "+ prefix + " "+ tok.text
        prefix = ""
        modifier = ""
        prv_tok_dep = ""
        prv_tok_text = ""      

      if tok.dep_.find("obj") == True:
        ent2 = modifier +" "+ prefix +" "+ tok.text
        
      prv_tok_dep = tok.dep_
      prv_tok_text = tok.text

  return [ent1.strip(), ent2.strip()]


def get_relation(sent):

  doc = nlp(sent)

  matcher = Matcher(nlp.vocab)

  pattern = [{'DEP':'ROOT'}, 
            {'DEP':'prep','OP':"?"},
            {'DEP':'agent','OP':"?"},  
            {'POS':'ADJ','OP':"?"}] 

  matcher.add("matching_1", None, pattern) 

  matches = matcher(doc)
  k = len(matches) - 1

  print(k, len(doc))
  if len(doc) > k and k > 0:
    span = doc[matches[k][1]:matches[k][2]] 
    return(span.text)
  else:
    return ''

def get_relations(sentences):
    relations = [get_relation(i) for i in sentences]
    return relations


def get_er(story):
    sentences = story.split(".")
    entity_pairs = get_entity_pairs(sentences)
    relations = get_relations(sentences)
    sequence = ""
    for i in range(len(entity_pairs)):
        if relations[i] != '':
            sequence += entity_pairs[i][0] + ' '
            sequence += relations[i] + ' '
            sequence += entity_pairs[i][1] + ' '
    return sequence

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
