<a href="https://colab.research.google.com/github/sky1515/Rick-Bot/blob/main/Rick_Bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip -q install transformers

In [2]:
import glob
import logging
import os
import pickle
import random
import re
import shutil
from tqdm.notebook import tqdm, trange
from pathlib import Path

import pandas as pd
import numpy as np

import torch
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import ( MODEL_WITH_LM_HEAD_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Configuration: module-level logger, the config classes registered in
# transformers' MODEL_WITH_LM_HEAD_MAPPING, and their `model_type` strings.
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = [*MODEL_WITH_LM_HEAD_MAPPING.keys()]
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)

In [3]:
class Args:
  """Plain attribute container for the fine-tuning / evaluation settings.

  Other cells read these as ``args.<name>`` (e.g. ``args.cache_dir``,
  ``args.model_type``, ``args.overwrite_cache``). Any default can be
  overridden by keyword: ``Args(num_train_epochs=5)``.
  """

  def __init__(self, **overrides):
    """Populate the defaults, then apply keyword ``overrides``.

    Raises:
      TypeError: if an override names a setting that has no default below
        (guards against typos such as ``num_train_epoch``).
    """
    # Paths and model identifiers.
    self.output_dir = 'output-small'
    self.model_type = 'gpt2'
    self.model_name_or_path = 'microsoft/DialoGPT-small'
    self.config_name = 'microsoft/DialoGPT-small'
    self.tokenizer_name = 'microsoft/DialoGPT-small'
    self.cache_dir = 'cached'
    self.block_size = 512

    # What to run.
    self.do_train = True
    self.do_eval = True
    self.evaluate_during_training = False

    # Optimisation.
    self.per_gpu_train_batch_size = 4
    self.per_gpu_eval_batch_size = 4
    self.gradient_accumulation_steps = 1
    self.learning_rate = 5e-5
    self.weight_decay = 0.0
    self.adam_epsilon = 1e-8
    self.max_grad_norm = 1.0
    self.num_train_epochs = 3
    self.max_steps = -1
    self.warmup_steps = 0

    # Logging / checkpointing.
    self.logging_steps = 1000
    self.save_steps = 3500
    self.save_total_limit = None
    self.eval_all_checkpoints = False

    # Runtime / caching.
    self.no_cuda = False
    self.overwrite_output_dir = True
    self.overwrite_cache = True
    self.should_continue = False
    self.seed = 42
    self.local_rank = -1  # presumably -1 = not distributed; confirm against the training loop
    self.fp16 = False
    # Apex AMP optimisation level: 'O0'..'O3' with the LETTER O (was '01', digit zero).
    self.fp16_opt_level = 'O1'

    for name, value in overrides.items():
      if not hasattr(self, name):
        raise TypeError('Unknown Args setting: {!r}'.format(name))
      setattr(self, name, value)

args = Args()


In [4]:
# Raw script: one row per spoken line. The path assumes Google Drive is
# mounted at /content/drive (Colab).
DATASET_CSV = Path('/content/drive/MyDrive/Colab Notebooks/Dataset/RickAndMortyScripts.csv')
rick_df = pd.read_csv(DATASET_CSV)
rick_df.head()

Unnamed: 0,index,season no.,episode no.,episode name,name,line
0,0,1,1,Pilot,Rick,Morty! You gotta come on. Jus'... you gotta co...
1,1,1,1,Pilot,Morty,"What, Rick? What’s going on?"
2,2,1,1,Pilot,Rick,"I got a surprise for you, Morty."
3,3,1,1,Pilot,Morty,It's the middle of the night. What are you tal...
4,4,1,1,Pilot,Rick,"Come on, I got a surprise for you. Come on, h..."


In [5]:
# Sliding window over the script: for each line i (starting at index n) keep
# line i followed by the n lines before it, most recent first.
# NOTE(review): windows are not reset at episode boundaries, so the first lines
# of an episode receive context from the end of the previous one.
n = 7
all_lines = rick_df['line'].tolist()
contexted = [all_lines[i - n:i + 1][::-1] for i in range(n, len(all_lines))]

# First column is the line to predict, then the line just before it, then older context.
columns = ['response', 'context'] + ['context/' + str(k) for k in range(n - 1)]
df = pd.DataFrame.from_records(contexted, columns=columns)
df.head()

Unnamed: 0,response,context,context/0,context/1,context/2,context/3,context/4,context/5
0,"What do you think of this... flying vehicle, M...","We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...,"I got a surprise for you, Morty.","What, Rick? What’s going on?",Morty! You gotta come on. Jus'... you gotta co...
1,"Yeah, Rick... I-it's great. Is this the surprise?","What do you think of this... flying vehicle, M...","We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...,"I got a surprise for you, Morty.","What, Rick? What’s going on?"
2,Morty. I had to... I had to do it. I had— I ha...,"Yeah, Rick... I-it's great. Is this the surprise?","What do you think of this... flying vehicle, M...","We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...,"I got a surprise for you, Morty."
3,What?! A bomb?!,Morty. I had to... I had to do it. I had— I ha...,"Yeah, Rick... I-it's great. Is this the surprise?","What do you think of this... flying vehicle, M...","We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...
4,We're gonna drop it down there just get a whol...,What?! A bomb?!,Morty. I had to... I had to do it. I had— I ha...,"Yeah, Rick... I-it's great. Is this the surprise?","What do you think of this... flying vehicle, M...","We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h..."


In [7]:
# Hold out 10% of the windows for evaluation; seeding the split with args.seed
# makes the partition identical across Restart & Run All.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=args.seed)

def construct_conv(row, tokenizer, eos=True):
  """Flatten one conversation row into a single list of token ids.

  ``row`` holds utterances newest-first (response, then older context), so it
  is reversed to put the conversation in chronological order before encoding.

  Args:
    row: iterable of utterance strings (e.g. a DataFrame row), newest first.
    tokenizer: object providing ``encode(str) -> list[int]`` and ``eos_token_id``.
    eos: if True (default), append ``tokenizer.eos_token_id`` after every
      utterance to mark turn boundaries; if False, concatenate ids directly.
      (Previously this flag was accepted but ignored.)

  Returns:
    list[int]: token ids of all utterances, oldest utterance first.
  """
  conv = []
  for utterance in reversed(list(row)):
    conv.extend(tokenizer.encode(utterance))
    if eos:
      conv.append(tokenizer.eos_token_id)
  return conv

class ConvData(Dataset):
  """Torch dataset of flattened, EOS-separated conversation token ids.

  Each row of ``df`` becomes one example via ``construct_conv``. The list of
  examples is pickled to a file under ``args.cache_dir`` and reloaded from
  there on later runs unless ``args.overwrite_cache`` is set.
  """

  def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):
    # Subtract the number of special tokens the tokenizer adds to a single
    # sequence. Newer transformers releases expose `model_max_length` instead
    # of the removed `max_len`, so prefer it and fall back for old versions.
    max_len = getattr(tokenizer, 'model_max_length', None)
    if max_len is None:
      max_len = tokenizer.max_len
    block_size = block_size - (max_len - tokenizer.max_len_single_sentence)
    # NOTE(review): block_size only names the cache file; examples are NOT
    # truncated to it, so long conversations pass through at full length.

    directory = args.cache_dir
    cached_features_file = os.path.join(directory, args.model_type + "_cached_lm_" + str(block_size))
    # NOTE(review): the file name does not depend on `df`, so train and eval
    # datasets built with the same settings share one cache file; that is only
    # safe while args.overwrite_cache is True.

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
      logger.info("Loading features from cached file %s", cached_features_file)
      # pickle.load can execute arbitrary code: only load caches this notebook wrote.
      with open(cached_features_file, 'rb') as handle:
        self.examples = pickle.load(handle)
    else:
      logger.info("Creating features from dataset file at %s", directory)
      self.examples = [construct_conv(row, tokenizer) for _, row in df.iterrows()]

      # open(..., 'wb') fails if the cache directory is missing, so create it first.
      if directory:
        os.makedirs(directory, exist_ok=True)
      logger.info("Saving features into cached file %s", cached_features_file)
      with open(cached_features_file, 'wb') as handle:
        pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

  def __len__(self):
    """Number of conversation examples."""
    return len(self.examples)

  def __getitem__(self, item):
    """Return example ``item`` as a 1-D ``torch.long`` tensor of token ids."""
    return torch.tensor(self.examples[item], dtype=torch.long)


In [None]:
# Caching and storing of data checkpoints
