In [1]:
import yaml
import torch
import nltk
from glob import glob
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from chatbot_files.data import Dialogues
from chatbot_files.utils import set_seed

# Make sure the NLTK data packages this notebook relies on are available locally.
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('punkt')
nltk.download('punkt_tab')

# `set_seed` is provided by `chatbot_files.utils` (see the imports at the top).

# Read the run configuration. The context manager guarantees the file handle
# is closed even if YAML parsing fails; the encoding is pinned so the result
# does not depend on the platform's default locale.
with open('config.yml', encoding='utf-8') as config_file:
    args = yaml.safe_load(config_file)

# Seed before any model or tokenizer is created, then echo the configuration.
set_seed(args['seed'])
print(args)


def load_tokenizer(args):
    """Load the pretrained tokenizer and register the dialogue control tokens.

    ``<bos>`` is registered as the BOS token and ``<speaker1>`` /
    ``<speaker2>`` as additional special tokens. The ids of the four control
    tokens are then written back into ``args``.

    Args:
        args: Configuration mapping. ``args['model_name']`` selects the
            pretrained tokenizer. The mapping is mutated in place: the keys
            ``sp1_id``, ``sp2_id``, ``bos_id`` and ``eos_id`` are set.

    Returns:
        The tokenizer with the extra special tokens added.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(args['model_name'])

    speaker_tokens = ['<speaker1>', '<speaker2>']
    tokenizer.add_special_tokens({
        'bos_token': '<bos>',
        'additional_special_tokens': speaker_tokens,
    })

    # Build a *separate* list for the id lookup. Extending `speaker_tokens`
    # in place would also modify the list object handed to the tokenizer
    # above, which may still be referenced by it.
    control_tokens = [*speaker_tokens, '<bos>', '<eos>']

    # Look the ids up token-by-token rather than calling `encode` on a list
    # of strings, so each entry is guaranteed to map to exactly one id.
    # NOTE(review): '<eos>' is never registered by this function, so its id
    # is whatever the tokenizer returns for a token outside its vocabulary
    # (presumably the unknown-token id, `<|endoftext|>` for GPT-2) - confirm
    # this is the intended end-of-sequence id.
    sp1_id, sp2_id, bos_id, eos_id = tokenizer.convert_tokens_to_ids(control_tokens)

    # Expose the new token ids to the rest of the pipeline via args.
    args['sp1_id'] = sp1_id
    args['sp2_id'] = sp2_id
    args['bos_id'] = bos_id
    args['eos_id'] = eos_id

    return tokenizer

def load_model(args, tokenizer, device):
    """Build the GPT-2 language model sized for the given tokenizer.

    Args:
        args: Configuration mapping; ``args["model_name"]`` names the
            pretrained weights to load.
        tokenizer: Tokenizer whose length determines the embedding size.
        device: Device the model is moved to.

    Returns:
        The model on ``device`` with its token embeddings resized to
        ``len(tokenizer)``.
    """
    pretrained = GPT2LMHeadModel.from_pretrained(args["model_name"])
    model = pretrained.to(device)

    # Resize the embedding matrix to match the tokenizer's current length.
    vocab_size = len(tokenizer)
    model.resize_token_embeddings(vocab_size)
    return model

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
args['device'] = device

# Report the selected device between two separator rules.
print("--"*50, f'Using device: {device}', "--"*50, sep="\n")

# Build the tokenizer first: the model's embedding size depends on it.
tokenizer = load_tokenizer(args=args)
model = load_model(args=args, tokenizer=tokenizer, device=device)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/rahulbharti/Preojects/college-chatbot-
[nltk_data]     gpt2/venv/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/rahulbharti/Preojects/college-chatbot-
[nltk_data]     gpt2/venv/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /home/rahulbharti/Preojects/college-chatbot-
[nltk_data]     gpt2/venv/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/rahulbharti/Preojects/college-chatbot-
[nltk_data]     gpt2/venv/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


{'structure_dataset_dir': './process_data/structred_data', 'corpus_dataset_dir': './process_data/corpus_data', 'train_frac': 0.85, 'model_name': 'gpt2', 'seed': 8459, 'lr': 2e-05, 'warmup_ratio': 0.1, 'batch_size': 1, 'num_epochs': 10, 'max_len': 100, 'max_history': 5, 'models_dir': './models', 'stop_command': 'bye', 'top_p': 0.9, 'top_k': 50, 'temperature': 0.9, 'mode': 'train', 'checkpoint': 'None', 'model_dir': './models'}
----------------------------------------------------------------------------------------------------
Using device: cpu
----------------------------------------------------------------------------------------------------
