<a href="https://colab.research.google.com/github/stepan-fukalov/ml/blob/master/gpt2_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install transformers -q

In [4]:
import json

import gdown
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [5]:
def default_device():
  '''Return "cuda" if a GPU is available, otherwise "cpu".'''
  return 'cuda' if torch.cuda.is_available() else 'cpu'

device = default_device()
device

'cuda'

In [6]:
def to_device(tensor, device=device):
  '''Recursively move a tensor, module, or list/tuple of them to the chosen device.'''
  if isinstance(tensor, (list, tuple)):
    return [to_device(x, device) for x in tensor]
  return tensor.to(device, non_blocking=True)

In [7]:
chat_data_url = "https://drive.google.com/file/d/13ARq6cCISHx48dF86nr65rV6HddRERbn/view?usp=sharing"
gdown.download(url=chat_data_url,
               quiet=True,
               fuzzy=True)

'chat_data.json'

In [55]:
def make_prompt(question, answer):
  '''Wrap one question/answer pair in the special tokens the model is trained on.'''
  return f"<startofstring> {question} <bot>: {answer} <endofstring>"

class ChatData(Dataset):
    def __init__(self,
                 path: str,
                 tokenizer):
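        # Read the dialogs; the context manager closes the file afterwards.
        with open(path, "r") as f:
            self.data = json.load(f)

        # Turn every pair of consecutive turns into one training string.
        self.X = []
        for record in self.data:
            texts = [turn["text"] for turn in record["dialog"]]
            self.X += [make_prompt(q, a) for q, a in zip(texts, texts[1:])]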

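        # Pad or truncate every string to 40 tokens; longer pairs lose their tail,
        # including the closing <endofstring> marker.
        self.X_encoded = tokenizer(self.X,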
                                   max_length=40,
                                   truncation=True,
                                   padding="max_length",
                                   return_tensors="pt")
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.input_ids[idx], self.attention_mask[idx])
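
To make the pairing in `ChatData.__init__` concrete, here is a minimal sketch on a made-up dialog. The sample turns are hypothetical, and `make_prompt` is repeated so the snippet runs on its own. Each pair of consecutive turns becomes one training string, so a dialog with n turns yields n-1 strings.

```python
def make_prompt(question, answer):
    return "<startofstring> " + question + " <bot>: " + answer + " <endofstring>"

# Hypothetical dialog in the shape the loader reads: a list of {"text": ...} turns.
sample = {"dialog": [{"text": "Hi"},
                     {"text": "Hello, how are you?"},
                     {"text": "Fine, thanks"}]}

texts = [turn["text"] for turn in sample["dialog"]]
# Pair every turn with the turn that follows it.
pairs = [make_prompt(q, a) for q, a in zip(texts, texts[1:])]

for p in pairs:
    print(p)
```

Note that every turn except the first and last appears twice: once as the answer and once as the next question.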

In [56]:
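def infer(inp):
    '''Generate a reply to `inp` with the global model and tokenizer.'''
    inp = "<startofstring> "+inp+" <bot>: "
    inp = tokenizer(inp, return_tensors="pt")
    X = to_device(inp["input_ids"])
    a = to_device(inp["attention_mask"])
    was_training = model.training
    model.eval()  # switch off dropout while generating
    output = model.generate(X, attention_mask=a)
    model.train(was_training)  # restore the previous mode for the training loop
    return tokenizer.decode(output[0])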

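def train(chatData, model, optim):
    '''Fine-tune for a fixed number of epochs, saving and sampling after each one.'''
    epochs = 10

    for epoch in tqdm(range(epochs)):
        for X, a in chatData:
            X = to_device(X)
            a = to_device(a)
            optim.zero_grad()
            # Causal LM loss: the model shifts the labels internally,
            # so the labels are the input ids themselves.
            loss = model(X, attention_mask=a, labels=X).loss
            loss.backward()
            optim.step()
        # Checkpoint the weights and print a sample reply once per epoch.
        torch.save(model.state_dict(), "model_state.pt")
        print(infer("Hello"))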
# Register the chat markers as tokens so each one maps to a single id.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>",
                              "bos_token": "<startofstring>",
                              "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<bot>:"])

# Grow the embedding matrix to cover the newly added tokens.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))
model.to(device)

chatData = ChatData("chat_data.json", tokenizer)
chatData = DataLoader(chatData, batch_size=64)

model.train()

# Adam with PyTorch's default learning rate (1e-3).
optim = Adam(model.parameters())
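
The training loop passes `labels=X`, so padded positions contribute to the loss too, which may be why the replies printed during training end in long runs of `<pad>`. A possible refinement, not part of the original notebook, is to set the labels at padded positions to -100, the index that the Hugging Face language-modeling loss ignores. The helper name `mask_pad_labels` and the toy ids below are hypothetical.

```python
import torch

def mask_pad_labels(input_ids, attention_mask, ignore_index=-100):
    """Copy input_ids and replace padded positions with ignore_index."""
    labels = input_ids.clone()               # keep the model inputs untouched
    labels[attention_mask == 0] = ignore_index
    return labels

# Toy batch: three real tokens followed by two padding positions.
input_ids = torch.tensor([[11, 12, 13, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

labels = mask_pad_labels(input_ids, attention_mask)
print(labels)
```

Inside the loop this would replace `labels=X` with `labels=mask_pad_labels(X, a)`. Whether it improves the replies on this dataset is untested here.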

In [57]:
train(chatData, model, optim)

  0%|          | 0/10 [00:00<?, ?it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 10%|█         | 1/10 [00:39<05:56, 39.56s/it]

<startofstring> Hello <bot>: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 20%|██        | 2/10 [01:15<05:01, 37.65s/it]

<startofstring> Hello <bot>: I am a huge fan of the very good <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 30%|███       | 3/10 [01:52<04:20, 37.15s/it]

<startofstring> Hello <bot>: I am a huge gamer <bot>: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 40%|████      | 4/10 [02:29<03:41, 36.98s/it]

<startofstring> Hello <bot>: i am a huge gamer <bot>: ok <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 50%|█████     | 5/10 [03:09<03:10, 38.17s/it]

<startofstring> Hello <bot>: Hello, how are you? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 60%|██████    | 6/10 [03:46<02:30, 37.70s/it]

<startofstring> Hello <bot>: Hi, how are you doing? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 70%|███████   | 7/10 [04:22<01:52, 37.34s/it]

<startofstring> Hello <bot>: i am a huge gamer <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 80%|████████  | 8/10 [04:59<01:14, 37.09s/it]

<startofstring> Hello <bot>: Hello <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 90%|█████████ | 9/10 [05:39<00:38, 38.08s/it]

<startofstring> Hello <bot>: Hi, how are you? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
100%|██████████| 10/10 [06:16<00:00, 37.68s/it]

<startofstring> Hello <bot>: Hello, how are you? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
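
The repeated warning in the output above suggests that `generate` still uses 50256, the original GPT-2 end-of-text id, rather than the ids of the `<pad>` and `<endofstring>` tokens added to the tokenizer. One way to address this, sketched here under that assumption, is to pass the tokenizer's ids explicitly. `generation_kwargs` is a hypothetical helper, and the stand-in tokenizer uses made-up ids so the snippet runs without downloading GPT-2. `pad_token_id`, `eos_token_id`, and `max_new_tokens` are existing `generate` arguments.

```python
from types import SimpleNamespace

def generation_kwargs(tokenizer, max_new_tokens=20):
    """Build keyword arguments for model.generate from the tokenizer's special tokens."""
    return {
        "pad_token_id": tokenizer.pad_token_id,   # silences the fallback warning
        "eos_token_id": tokenizer.eos_token_id,   # stop once <endofstring> is produced
        "max_new_tokens": max_new_tokens,         # cap on the length of the reply
    }

# Stand-in object with made-up ids; in the notebook the real tokenizer would be passed.
fake_tokenizer = SimpleNamespace(pad_token_id=50257, eos_token_id=50259)
kwargs = generation_kwargs(fake_tokenizer)
print(kwargs)
```

In `infer` the call would then read `model.generate(X, attention_mask=a, **generation_kwargs(tokenizer))`.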

In [61]:
inp = input()
print(inp)
print(infer(inp))

Hello, how are you?


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, how are you?
<startofstring> Hello, how are you? <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
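
The decoded string above still contains the prompt and the special markers. A small post-processing step can keep only the reply, as in this minimal sketch. `extract_reply` is a hypothetical helper that relies only on the marker strings used in this notebook.

```python
def extract_reply(decoded, bot_tag="<bot>:", end_tag="<endofstring>", pad_tag="<pad>"):
    """Return the text after the first bot tag, cut at the next special marker."""
    # Everything before the first "<bot>:" is the prompt; drop it.
    _, _, reply = decoded.partition(bot_tag)
    # Cut at whichever marker comes first after the reply starts.
    for marker in (end_tag, bot_tag, pad_tag):
        reply = reply.split(marker, 1)[0]
    return reply.strip()

raw = "<startofstring> Hello, how are you? <bot>: Hi <endofstring> <pad> <pad>"
print(extract_reply(raw))  # prints "Hi"
```

Cutting at a second `<bot>:` also handles outputs in which the model starts another turn instead of emitting `<endofstring>`.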
