<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/03-natural-language-process/DialoGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DialoGPT

## 0. Info

### Paper
* title: DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation
* authors: Yizhe Zhang et al.
* url: https://arxiv.org/abs/1911.00536

### Features
* data: Korean dialogue data from AI Hub

### References
* https://huggingface.co/microsoft/DialoGPT-medium
* https://github.com/xcapt0/gpt2_chatbot

## 1. Setup

In [None]:
import easydict
from glob import glob
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Adafactor
from datasets import load_dataset

In [None]:
# Run configuration: base checkpoint, training device, and number of optimizer steps.
cfg = easydict.EasyDict({
    'model_name': 'EleutherAI/polyglot-ko-1.3b',
    'device': 'cuda:2',
    'num_training_steps': 50000,
})

## 2. Data

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map raw dialogue records to ``(input_ids, labels)`` pairs for causal-LM training.

    Each record's ``'text'`` field holds one dialogue whose turns are separated by
    ``turn_sep`` (``chr(1000)`` by default). The first turn is the context and is
    masked out of the loss with ``-100``. Every later turn is a training target.
    Following DialoGPT, every turn, including the last one, is terminated with the
    tokenizer's EOS token, so the model also learns to emit EOS at the end of a reply.

    Args:
        data: indexable collection of records that have a ``'text'`` key.
        tokenizer: callable returning a mapping with ``'input_ids'``; must expose
            ``eos_token``.
        turn_sep: string separating the turns inside ``'text'``.
    """

    def __init__(self, data, tokenizer, turn_sep=chr(1000)):
        self.data = data
        self.tokenizer = tokenizer
        self.turn_sep = turn_sep

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return token ids and labels (context positions set to -100) for record ``idx``.

        A single-turn record has no target turns, so all of its labels are -100.
        """
        src, *tgt = self.data[idx]['text'].split(self.turn_sep)
        eos = self.tokenizer.eos_token

        src_input_ids = self.tokenizer(src + eos)['input_ids']
        # Terminate every target turn with EOS, including the final one. Joining the
        # turns with EOS would leave the last reply unterminated, so the model would
        # never learn when to stop generating.
        tgt_input_ids = self.tokenizer(''.join(turn + eos for turn in tgt))['input_ids']

        input_ids = src_input_ids + tgt_input_ids
        labels = [-100] * len(src_input_ids) + tgt_input_ids
        return input_ids, labels

    
def pad_seq(seq, value, max_length):
    """Fit ``seq`` to exactly ``max_length`` items.

    Longer sequences are cut from the right; shorter ones are right-padded with
    ``value``. The input sequence itself is not modified (a sliced copy is used).
    """
    head = seq[:max_length]
    return head + [value] * (max_length - len(head))
    
    
def collate_fn(batch, pad_token_id=2, max_length=256):
    """Pad a list of ``(input_ids, labels)`` pairs into two equal-length tensors.

    Sequences are right-padded to the longest item in the batch, capped at
    ``max_length`` (longer items are truncated from the right). ``input_ids`` are
    padded with ``pad_token_id`` and labels with -100, so padding is ignored by the loss.

    Args:
        batch: sequence of ``(input_ids, labels)`` list pairs, as produced by ``Dataset``.
        pad_token_id: id written into padded ``input_ids`` positions. The default 2 keeps
            the original hard-coded value. NOTE(review): presumably the tokenizer's
            EOS/pad id; confirm it, or bind ``tokenizer.pad_token_id`` via
            ``functools.partial``.
        max_length: upper bound on the padded sequence length.

    Returns:
        ``(input_ids, labels)`` integer tensors of shape ``(len(batch), length)``.
    """
    input_ids, labels = zip(*batch)
    length = min(max(len(ids) for ids in input_ids), max_length)

    input_ids = torch.tensor([pad_seq(ids, pad_token_id, length) for ids in input_ids])
    labels = torch.tensor([pad_seq(lab, -100, length) for lab in labels])
    return input_ids, labels

In [None]:
# Tokenizer that matches the base checkpoint in cfg.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=cfg.model_name)

In [None]:
# Load every *.txt dialogue file; the 'text' builder yields one row per line,
# which Dataset then splits into turns.
# Sort because glob() returns matches in arbitrary order; this keeps row order reproducible.
files = sorted(glob('/mnt/dialog-ko/*.txt'))
if not files:
    # Fail early with a clear message instead of an obscure load_dataset error.
    raise FileNotFoundError("no *.txt dialogue files found under /mnt/dialog-ko")
data = load_dataset('text', data_files=files)['train']

In [None]:
# Wrap the raw rows so each item becomes an (input_ids, labels) pair.
dataset = Dataset(data=data, tokenizer=tokenizer)

In [None]:
# Shuffled batches of 4 dialogues, padded/truncated to a common length by collate_fn.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Sanity check: pull one batch and inspect the collated tensor shapes.
input_ids, labels = next(iter(dataloader))
tuple(t.shape for t in (input_ids, labels))

## 3. Train

In [None]:
# Load the base causal LM, switch it to training mode, and move it to the training device.
model = AutoModelForCausalLM.from_pretrained(cfg.model_name)
model.train()
_ = model.to(cfg.device)

In [None]:
# Adafactor with its relative-step schedule, so no explicit learning rate is set.
# The arguments below spell out the library defaults.
# Alternative kept for reference:
#   torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
optimizer = Adafactor(
    model.parameters(),
    lr=None,
    scale_parameter=True,
    relative_step=True,
    warmup_init=False,
)

In [None]:
# Train for a fixed number of optimizer steps (not epochs), cycling through the
# dataloader and checkpointing to ./dialogpt along the way.
dataiter = iter(dataloader)
pbar = tqdm(range(1, cfg.num_training_steps + 1))
skipped = 0  # batches that had no supervised target tokens
for st in pbar:
    try:
        input_ids, labels = next(dataiter)
    except StopIteration:
        # Epoch boundary: restart the (reshuffled) dataloader.
        dataiter = iter(dataloader)
        input_ids, labels = next(dataiter)
    input_ids, labels = input_ids.to(cfg.device), labels.to(cfg.device)

    # The model shifts labels internally, so label position 0 never contributes to the
    # loss. If no remaining label is a real token (e.g. single-turn lines, or a long
    # source whose target was truncated away), the mean cross-entropy is 0/0 = NaN.
    # Backpropagating it would push NaN gradients into the weights, so skip the batch.
    if (labels[:, 1:] != -100).any():
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({'loss': loss.item(), 'skipped': skipped})
    else:
        skipped += 1

    # Checkpoint periodically (even on skipped steps) and always at the final step.
    if st % 1000 == 0 or st == cfg.num_training_steps:
        tokenizer.save_pretrained('dialogpt')
        model.save_pretrained('dialogpt')

## 4. Test

In [None]:
# Reload the checkpoint saved by the training loop and freeze it for inference.
tokenizer = AutoTokenizer.from_pretrained('dialogpt')
model = AutoModelForCausalLM.from_pretrained('dialogpt')
model.eval()
_ = model.requires_grad_(False)

In [None]:
# Five-turn interactive chat. Each user message is EOS-terminated and appended to the
# running history, and the model continues that history with up to 32 new tokens.
for step in range(5):
    user_text = input(">> User:")
    new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt')
    if step == 0:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    chat_history_ids = model.generate(bot_input_ids, max_new_tokens=32, repetition_penalty=4.0)
    # Decode only the newly generated tokens (everything after the prompt).
    reply_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    print(f"DialoGPT: {tokenizer.decode(reply_ids, skip_special_tokens=True)}")