In [None]:
# Shell out to `nvidia-smi` so the notebook output records which GPU (if any)
# this runtime exposes before the model is loaded.
!nvidia-smi

In [None]:
# `%pip` installs into the environment of the running kernel (unlike `!pip`,
# which uses whatever `pip` is first on the shell PATH).
# `datasets` is needed as well: this notebook imports `load_dataset` from it.
%pip install transformers datasets

In [None]:
import torch
import random
import numpy as np
import time
import datetime
import seaborn as sns
import pandas as pd
import os
import pathlib
from random import randrange

import matplotlib.pyplot as plt
%matplotlib inline

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import Trainer, TrainingArguments

In [None]:
seed = 29384
# Apply the seed to every RNG this notebook touches so that the random
# train/eval split and the sampling-based generation are repeatable after
# "Restart & Run All" (previously the seed was only handed to TrainingArguments).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU when one is available; fall back to the CPU instead of
# failing later at `model.to(device)` on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Pick the working directory: a local /opt/awsw checkout (e.g. inside Docker)
# wins; otherwise mount Google Drive and work out of MyDrive/endless_awsw.
if not os.path.isdir("/opt/awsw"):
    from google.colab import drive
    drive.mount('/content/drive')
    work_dir = os.path.join("/content", "drive", "MyDrive", "endless_awsw")
else:
    work_dir = os.path.join("/opt", "awsw")

# Checkpoints live in <work_dir>/models; create it (and parents) when missing.
models_dir = os.path.join(work_dir, "models")
pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True)

# distilgpt2 tokenizer with explicit BOS/EOS/PAD special tokens
# (the original cell carried a "gpt2-medium" note next to this call).
tokenizer = GPT2Tokenizer.from_pretrained(
    'distilgpt2',
    bos_token='<|startoftext|>',
    eos_token='<|endoftext|>',
    pad_token='<|pad|>',
)
# The model's pad_token_id is set to the tokenizer's EOS id.
model = GPT2LMHeadModel.from_pretrained(
    'distilgpt2',
    pad_token_id=tokenizer.eos_token_id,
)
print("Loading empty, pre-trained model.")

# Move to the target device, then grow the embedding matrix so it covers the
# special tokens that were added to the tokenizer above.
model.to(device)
model.resize_token_embeddings(len(tokenizer))
print("Model attached to GPU")

# Split data

In [None]:
# Read the raw story script from the current working directory.
# NOTE(review): the input is looked up relative to the CWD while the split
# files below go to `work_dir` — confirm where the input is meant to live.
with open("awsw_story_input.txt", encoding="utf-8") as f:
    data = f.read()
lines = data.split("\n")

# Bucket every usable line by its first space-separated token (the "command").
# Lines with fewer than two tokens are skipped. Lines of command "c" are
# tagged as player lines, everything else as dragon lines.
lines_by_command = {}
for line in lines:
    line = line.strip()
    line_split = line.split(" ")
    if len(line_split) <= 1:
        continue
    command = line_split[0]
    prefix = "PlayerReply " if command == "c" else "DragonReply "
    lines_by_command.setdefault(command, []).append(prefix + line)

train_lines = []
eval_lines = []

# A dedicated, seeded RNG keeps the split identical across kernel restarts
# without depending on (or disturbing) the global `random` state.
split_rng = random.Random(seed)

# For every non-player command, move up to 50 randomly chosen lines into the
# eval set, always leaving at least two lines of that command for training.
# NOTE(review): train_lines ends up grouped by command rather than in story
# order — confirm this is intended for a prompt/reply model.
for command, command_lines in lines_by_command.items():
    if command != "c":
        for _ in range(50):
            if len(command_lines) <= 2:
                break
            picked = split_rng.randrange(len(command_lines))
            eval_lines.append(command_lines.pop(picked))
    train_lines.extend(command_lines)

joined_eval_lines = "\n".join(eval_lines[:5])
print(f"eval_lines: {joined_eval_lines}")
joined_train_lines = "\n".join(train_lines[:5])
print(f"train_lines: {joined_train_lines}")

# Persist both splits next to the models, in `work_dir`, because that is where
# the dataset-loading cell reads data_train.txt / data_test.txt from.
# Existing files are kept as-is (delete them to force a re-split).
for split_name, split_lines in (("data_train.txt", train_lines), ("data_test.txt", eval_lines)):
    split_path = os.path.join(work_dir, split_name)
    if not os.path.isfile(split_path):
        with open(split_path, "w", encoding="utf-8") as f:
            f.writelines(f"{text}\n" for text in split_lines)

In [None]:
from datasets import load_dataset
# Load the two plain-text split files from work_dir as a 'text' dataset with
# a 'train' and a 'test' split.
_split_files = {
    split: os.path.join(work_dir, f"data_{split}.txt")
    for split in ("train", "test")
}
dataset = load_dataset('text', data_files=_split_files)
def encode(batch):
    """Tokenize a batch of text lines, terminating each with the EOS marker.

    Args:
        batch: mapping with a 'text' entry holding an iterable of strings.

    Returns:
        Whatever the module-level `tokenizer` returns for the list of
        EOS-terminated strings.
    """
    terminated_texts = []
    for text in batch['text']:
        terminated_texts.append(f"{text}<|endoftext|>")
    return tokenizer(terminated_texts)

def group_texts(examples, block_size=128):
    """Concatenate a batch of tokenized examples and cut it into fixed blocks.

    Every column in `examples` (a list of lists) is flattened into one long
    list and then split into consecutive chunks of exactly `block_size`
    items. The tail that does not fill a whole block is dropped; the length
    used for that truncation is taken from the first column.

    Args:
        examples: mapping of column name -> list of lists. Must contain an
            "input_ids" column.
        block_size: number of items per output chunk (default 128, the value
            that used to be hard-coded here).

    Returns:
        dict with the same columns, chunked, plus a "labels" column that is a
        shallow copy of the chunked "input_ids" list.

    Raises:
        ValueError: if `block_size` is not a positive integer.
    """
    from itertools import chain

    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # chain.from_iterable flattens in linear time; sum(list_of_lists, [])
    # re-copies the accumulated list on every addition (quadratic).
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder so every chunk has exactly block_size items.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Options shared by both mapping passes: batched processing, 1000 rows per
# batch, 4 worker processes.
_map_options = dict(batched=True, batch_size=1000, num_proc=4)

# Pass 1: tokenize every line and drop the raw "text" column.
dataset = dataset.map(encode, remove_columns=["text"], **_map_options)

# Pass 2: regroup the token streams into fixed-size training blocks.
dataset = dataset.map(group_texts, **_map_options)

# Display the resulting dataset as the cell output.
dataset

In [None]:
class AWSWTrainer(Trainer):
    """Trainer subclass with a notebook-local `compute_loss`.

    NOTE(review): `train()` in this notebook instantiates the stock `Trainer`,
    so this class is currently not used — confirm whether that is intended.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run the model on `inputs` and return its loss.

        When a label smoother is configured and `inputs` carries "labels",
        the labels are removed from `inputs` before the forward pass and the
        loss comes from the label smoother; otherwise the loss reported by
        the model itself is used.

        Args:
            model: callable invoked as `model(**inputs)`.
            inputs: dict of model inputs; "labels" may be popped from it.
            return_outputs: when true, return `(loss, outputs)`.

        Returns:
            The loss, or a `(loss, outputs)` tuple if `return_outputs` is set.
        """
        smoothing_applies = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if smoothing_applies else None

        outputs = model(**inputs)

        # Keep the past state around when the arguments ask for it.
        # TODO (carried over): this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            # Index instead of using `.loss`: the model may return a plain
            # tuple rather than a dict-like output object.
            if isinstance(outputs, dict):
                loss = outputs["loss"]
            else:
                loss = outputs[0]
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return (loss, outputs)
        return loss


def _latest_checkpoint_dir(output_dir):
    """Return the `checkpoint-<step>` sub-directory with the highest step, or None."""
    import re

    checkpoint_name = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for entry in os.listdir(output_dir):
        matched = checkpoint_name.match(entry)
        entry_path = os.path.join(output_dir, entry)
        if matched is None or not os.path.isdir(entry_path):
            continue
        step = int(matched.group(1))
        if step > best_step:
            best_step, best_path = step, entry_path
    return best_path


def train(model):
    """Fine-tune `model` on the notebook's dataset, resuming when possible.

    Uses the module-level `models_dir` (output/checkpoint directory), `seed`
    and `dataset` ('train' and 'test' splits). If `models_dir` already holds
    `checkpoint-<step>` directories, training resumes from the one with the
    highest step; otherwise it starts from the model as given.

    Args:
        model: the model handed to the Trainer.
    """
    training_args = TrainingArguments(
        models_dir,
        seed=seed,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=100,
        save_total_limit=2,
        save_steps=1000
    )
    # NOTE(review): this uses the stock Trainer, not the AWSWTrainer subclass
    # defined in this notebook — confirm which one is intended.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test']
    )
    # Only real checkpoint folders qualify as resume points. Picking "the most
    # recently modified sub-directory" could select an unrelated folder (e.g.
    # a logging directory) and modification times are not a reliable ordering.
    latest_checkpoint = _latest_checkpoint_dir(models_dir)
    if latest_checkpoint is not None:
        trainer.train(latest_checkpoint)
    else:
        trainer.train()

# Kick off (or resume) fine-tuning of the model loaded earlier in the notebook.
train(model)

In [None]:
def generate_reply(prompt):
    """Sample three dragon replies to a player line and print them.

    The player text is wrapped exactly like a player line in the training
    data (`PlayerReply c "<text>"`), followed by the EOS token and the
    `DragonReply` tag that the model is expected to continue.

    Uses the module-level `model`, `tokenizer` and `device`.

    Args:
        prompt: what the player says (plain text, without quotes).

    Returns:
        None. The encoded prompt and the decoded samples are printed.
    """
    model.eval()
    # Training lines for the player carry the "PlayerReply " prefix; the
    # prompt has to use the same format the model was fine-tuned on.
    prompt = f'PlayerReply c "{prompt}"{tokenizer.eos_token}DragonReply'
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)
    print(prompt, generated)

    sample_outputs = model.generate(
        generated,
        # Single unpadded prompt: every position is a real token. Passing the
        # mask explicitly avoids it being guessed from pad/EOS ids (the prompt
        # itself contains an EOS token).
        attention_mask=torch.ones_like(generated),
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        top_k=50,
        max_length=128,
        top_p=0.95,
        num_return_sequences=3
    )

    for i, sample_output in enumerate(sample_outputs):
        print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=False)))

print("What to say?")
# generate_reply prints its samples itself and returns None, so call it
# directly; wrapping it in print() only added a stray "None" to the output.
generate_reply(input())