In [1]:
%load_ext autoreload
%autoreload 2

from model_utils import train_model, split_data, split_branches, get_model, set_pretrained_model_dropout, get_dataset
from config import Config
import json
import matplotlib.pyplot as plt
%matplotlib inline
import math
import random
import os
import datasets
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from model_manager import ModelManager

In [2]:
# Fixed seed so runs are repeatable; uncomment the next line to draw a fresh one.
# seed = random.randint(0, 2 ** 32 - 1)
seed = 3218885689
random.seed(seed)
# Seed PyTorch as well: sampling done outside of train_model draws from torch's
# RNG, which `random.seed` does not touch.
torch.manual_seed(seed)
datasets.logging.set_verbosity(datasets.logging.ERROR)
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_name = "cpu"
device = torch.device(device_name)
print(f"Will use {device_name} for training with seed: {seed}")

Will use cuda:0 for training with seed: 3218885689


In [3]:
# Run split_data on the raw story file located in the configured work directory.
story_input_path = os.path.join(Config.work_dir, "awsw_story_input.txt")
split_data(story_input_path)

In [4]:
# Values taken from an Optuna run; they take precedence over the defaults below.
optuna_result_attachement = dict(
    lr=0.001,
    scheduler='cosine_schedule_with_warmup',
    to_freeze_count=0,
    # to_freeze_gpt_blocks=11,
    warmup_factor=1,
)

# Training configuration handed to train_model: defaults first, then the Optuna
# overrides unpacked last so they win for "lr", "scheduler" and "warmup_factor".
# NOTE(review): "lr_end" and "power" were set for the polynomial schedule and are
# presumably ignored once the cosine schedule is selected — confirm in train_model.
config = {
    "lr": 6e-4,
    "warmup_factor": 0,
    "scheduler": "polynomial_decay_schedule_with_warmup",
    "lr_end": 2e-6,
    "power": 0.6,
    # "freeze_layer_rate": 1e-4,
    "freeze_from_steps": -1,
    "seed": seed,
    "num_epoch": 50,
    **optuna_result_attachement,
}

In [5]:
# Resume from a previously saved fine-tuned model when one exists; otherwise
# start from the base GPT-Neo checkpoint returned by get_model.
saved_model_path = os.path.join("models", "awsw_main")
# Older transformers releases save weights as `pytorch_model.bin`, newer ones
# default to `model.safetensors`; accept either so a saved model is not missed.
checkpoint_files = ("pytorch_model.bin", "model.safetensors")
has_saved_model = any(
    os.path.exists(os.path.join(saved_model_path, checkpoint_file))
    for checkpoint_file in checkpoint_files
)
if has_saved_model:
    print("Pretrained model loaded")
    # The tokenizer is always taken from the base checkpoint; only the model
    # weights come from the saved directory.
    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
    model = AutoModelForCausalLM.from_pretrained(saved_model_path)
else:
    print("Loaded empty model")
    model, tokenizer = get_model("EleutherAI/gpt-neo-125M")
model.to(device)
# set_pretrained_model_dropout(model.transformer.h[-1:], 0.05)

Loaded empty model


GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0): GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0, inplace=False)
            (resid_dropout): Dropout(p=0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, o

# Test before training on a pretrained model!

In [6]:
# Put the model in evaluation mode and wrap it for text generation.
model.eval()
model_manager = ModelManager(model=model, tokenizer=tokenizer, device=device)


def test_regular_sampler():
    """Print the model's raw continuation of a fixed, non-AWSW prompt.

    Called both before and after training so the two outputs can be compared.
    NOTE(review): the positional arguments 50 and 0.7 are presumably top_k and
    top_p — confirm against ModelManager.say_raw.
    """
    continuation = model_manager.say_raw("In my dreams, I'm a dragon", 50, 0.7)
    print(continuation)


test_regular_sampler()

In my dreams, I'm a dragon. I'm not afraid of dragons, or dragons that aren't dragons. I don't want to be scared of them. I don't want to be scared of the dragon that I love. I don't want to be scared of the dragon that I love.

I know I can't make it through the day without falling asleep on my bed, and I don't want to go to sleep, and I don't want to fall asleep with the world around me. I don't want to fall asleep with the world around me. I don't want to fall asleep with the world around


# Reviewing our dataset!

In [7]:
dataset = get_dataset(tokenizer)

# Decode and print the first two training samples.
print("Dataset demo snapshot:")
demo_idx = 0
for demo_idx, item in enumerate(dataset['train']):
    print(tokenizer.decode(item['input_ids']))
    if demo_idx >= 1:
        break

# Print the first run of consecutive samples containing a roleplay "Fight"
# action, plus the sample that immediately follows that run.
print("RP review!")
has_seen_rp = False
for item in dataset['train']:
    decoded = tokenizer.decode(item['input_ids'])
    if 'c "Fight ' not in decoded and not has_seen_rp:
        # Still searching for the first roleplay sample.
        continue
    print(decoded)
    if 'c "Fight ' not in decoded:
        # This was the follow-up sample after the roleplay run; stop here.
        break
    has_seen_rp = True

# Keep the notebook namespace free of the loop bookkeeping variables.
del demo_idx, has_seen_rp

  0%|          | 0/2 [00:00<?, ?it/s]

Dataset demo snapshot:


Token indices sequence length is longer than the specified maximum sequence length for this model (2159 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3376 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2085 > 2048). Running this sequence through the model will result in indexing errors


<p><msg>c "I imagine being smaller than the rest of the population would come with its own challenges."<d><scn>loremapt<msg>Lo "It's not that big of deal. If something is unreachable for me, I can fly!"<d><scn>loremapt<msg>Ip "This apartment was actually intended to house one dragon of a bigger size. That not only makes it fairly cheap, but it's also big enough for both of us."<|endoftext|><!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Trans
itional//EN" "http://www.w3.org/TR/html4/loose.dtd">
<!--NewPage-->
<HTML>
<HEAD>
<!-- Generated by javadoc (build 1.6.0_24) on Tue Oct  3 02:22:50 CEST 2010 -->
<META http-equiv="Content-Type" content="text/html; charset=US-ASCII" />
<TITLE>
<p><msg>c "What do you propose?"<d><scn>cafe<
RP review!


Token indices sequence length is longer than the specified maximum sequence length for this model (2085 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2083 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2159 > 2048). Running this sequence through the model will result in indexing errors


/./../machines/mysql/lib/mysql-mysql-connector-php/mysqli.c:31
msgid "No database found"
msgstr ""

#:../.././././././../../machines/mysql/lib/mysql-<p><msg>c "Fight Anna"<d><scn>cafe<msg>m "Anna barely avoids my attack and fell, but managed to get up and quickly punch me in the face, a soaring pain quickly came over my face"<|endoftext|> up to 1-3
 hours.<|endoftext|><p><msg>c "Well, no matter what you might think about humans, I can assure you that I am in no way special or supernatural."<d><scn>o2<msg>Ad "I disagree with that."<|endoftext|>, in the course of its activity with a wide range of biochemical and physicochemical stimuli, was found to be affected by its effect on gene expression. This hypothesis was supported by the observation that in transfected cells the DNA synthesis and gene activity of the RNA-dependent ribonucleotide exchange factor (RREF) mRNA both increased during the in vitro incubation


# Training

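The model is put in training mode and training begins. The `train_results` dictionary is passed to `train_model` and holds the recorded training data (such as the learning-rate and loss histories plotted below) once training is completed.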

In [8]:
# Switch to training mode, then train; `train_results` is handed to train_model
# so the recorded histories can be read back from it afterwards.
model.train()
train_results = dict()
train_model(model, tokenizer, dataset, config, train_results)

[0] set freeze_part_layers: True (freezing 0 out of 160 layers.)


Step,Training Loss
54,2.3561
108,2.1586
162,2.1144
216,2.0532
270,1.9319
324,1.9575
378,1.7713
432,1.8334
486,1.7072
540,1.7858


KeyboardInterrupt: 

In [9]:
# Back to evaluation mode (eval() returns the module itself), then persist the weights.
# NOTE(review): this writes to "models", while the loading cell looks for a
# checkpoint under models/awsw_main — confirm which location is intended.
model.eval().save_pretrained("models")

In [None]:
# Plot the recorded learning-rate and loss histories on two stacked axes.
fig, axs = plt.subplots(2)
fig.suptitle('Learning rate and loss')

axs[0].plot(train_results['learning_rate_history'])
axs[0].set_ylabel('Learning rate')

axs[1].plot(train_results['loss_history'])
axs[1].set_ylabel('Loss')
# The x axis is simply the position of each entry in the recorded history.
axs[1].set_xlabel('History entry')

fig.tight_layout()
# Render explicitly so the cell does not echo the Line2D repr of the last call.
plt.show()

# Testing

We created a few prompts, each made of a past (for context) and a present part (the player input), to see how the model reacts. This way, we can compare models across different iterations.
The first test reuses an old prompt to compare the pre-trained model with the one trained on AWSW. Did it manage to retain its data well? Is it still able to write things that have nothing to do with AWSW? (So we know we didn't overfit.)

In [10]:
# Re-run the same fixed prompt as before training to compare the outputs.
test_regular_sampler()

In my dreams, I'm a dragon. How does that sound?"<d><scn>black<msg>Lo "It sounds pretty simple. We take turns drawing cards and asking questions until we've both asked a number that we agree on beforehand. Whoever gets more right in the end wins."<p><msg>c "How many are we talking about?"<d><scn>black<msg>Lo "I don't know, exactly. It's probably a lot less than I expected."<p><msg>c "How many are we talking about?"<d><scn>black<msg>


**This test generates boring and repetitive** replies! That's because we aren't using a good sampling algorithm, but it does give us an indication of what the model has learned!

In [11]:
# Each entry pairs the past conversation (context) with the new player input.
prompts = [
    (
        '<p><msg>c "Hey Remy!"<d><scn>park2<msg>Ry "Hey!"',
        "How are you?",
    ),
    (
        '<p><msg>c "I was with Lorem today."<d><scn>park2<msg>Ad "Very nice."',
        "What do you think of Lorem?",
    ),
    (
        '<p><msg>m "In Tatsu park, Adine and I sat down."',
        "Oh my god, Adine. What is this?",
    ),
    (
        '<p><msg>m "I sat down on a chair in Anna\'s lab."',
        "What will we do here?",
    ),
]

# Ask the model for one reply per prompt using model_manager.say's default settings.
for past, prompt in prompts:
    reply = model_manager.say(past, prompt)
    print(f"Prompt: {prompt}\nReply: {reply}\n\n")

Prompt: How are you?
Reply: park2<msg>Ry "Well, I've got some pretty good ice cream on my hands, and I haven't heard that one before."<d><scn>park2<msg>Ry "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>


Prompt: What do you think of Lorem?
Reply: park2<msg>Lo "It's not exactly the same thing, but I heard that the author is very good at that sort of thing."<p><msg>c "I was with Lorem today."<d><scn>park2<msg>Ad "Very nice."<|endoftext|>


Prompt: Oh my god, Adine. What is this?
Reply: o2<msg>Ad "It's the most amazing time in human history. Why would I visit?"<p><msg>c "It would be a neat ability to have."<d><scn>o2<msg>Ad "But not all the time travel is done for you."<p><msg>c "It would be a neat ability to have, but I don't think so."


Prompt: What will we do here?
Reply: facin3<msg>An "I'll probably just leave them here. No point in doing anything with them."<p><msg>c "What will we do?"<d><scn>facin3<msg>An "I'll probably just leave them here. No point in doing anythi

# Sampling test

This is gonna be interesting!

In [12]:
# Ten rounds over every prompt with top-k / top-p sampling enabled, so the
# variety of replies can be compared between rounds.
for i in range(10):
    for past, prompt in prompts:
        reply = model_manager.say(past, prompt, top_k=50, top_p=0.7)
        print(f"[Test {i + 1}] -> Prompt: {prompt}\nReply: {reply}\n")
    # Visual separator between rounds.
    print("-------------")

[Test 1] -> Prompt: How are you?
Reply: park2<msg>Ry "Well, I've got some really good ice cream ideas for you."<d><scn>park2<msg>Ry "Oh, I see."<d><scn>park2<msg>Ry "I must have heard that."<d><scn>park2<msg>Ry "I know what you mean."<d><scn>park2<

[Test 1] -> Prompt: What do you think of Lorem?
Reply: park2<msg>Ad "I don't know. It's not hard for me to guess."<d><scn>park2<msg>Ad "Alright, what about him?"<p><msg>c "He sounds good."<p><msg>c "That sounds good."<p><msg>c "That sounds good."<p><msg

[Test 1] -> Prompt: Oh my god, Adine. What is this?
Reply: o2<msg>Ad "It's the most amazing time in human history, because it's the only time I ever had a chance to see the sun."<|endoftext|>

[Test 1] -> Prompt: What will we do here?
Reply: facin3<msg>An "I'll probably just leave them here. I'd love to talk to her, but I'm not sure if I can do that."<p><msg>c "That wouldn't be so bad."<d><scn>facin3<msg>An "W-Well, I do think you're not a big fan of "friendly", then."<p><msg>c "What are

-

# RP test
Testing out the injected roleplay actions

In [13]:
# Roleplay actions sent as player input with an empty past context.
test_rps = [
    "Visit Lorem",
    "Meet with Lorem",
    "Visit Adine",
    "Fight Maverick",
    "Fight Adine",
    "Attack Adine",
]

for rp in test_rps:
    rp_reply = model_manager.say("", rp, top_k=50, top_p=0.7)
    print(f'{rp} -> {rp_reply}')

Visit Lorem -> loremapt<msg>Lo "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>
Meet with Lorem -> loremapt<msg>Lo "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>
Visit Adine -> adineapt<msg>Ad "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>
Fight Maverick -> park3<msg>m "Mv "I didn't hesitate to tell her [player_name]!"<|endoftext|>
Fight Adine -> beach<msg>Ad "Fight Adine"<d><scn>beach<msg>Ad "I didn't kill anyone, I'm just doing what I can. That's what I'm trying to find out."<d><scn>beach<msg>Ad "What do you mean?"<|endoftext|>
Attack Adine -> beach<msg>Ad "Don't be. I can see that."<d><scn>beach<msg>Ad "I see. That sounds kinda complicated."<d><scn>beach<msg>Ad "Does that mean we shouldn't stay here for too long?"<p><msg>c "It won't affect me much, if at all."<p><msg>c "Pretty much, though that isn't necessarily true for all of us. Depending on the skin tone, people can be more
