In [1]:
%load_ext autoreload
%autoreload 2
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from model_manager import ModelManager
from model_mixing import ModelMixing
from model_utils import get_model
from config import Config

In [2]:
# Single source of truth for the pretrained checkpoint: the tokenizer and the
# base/target models below must come from the same checkpoint, so the name is
# defined once instead of being repeated per call.
BASE_MODEL_NAME = "EleutherAI/gpt-neo-125M"
saved_model_path = os.path.join("models", "awsw_main")
device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
# Model loaded from the local "models/awsw_main" directory.
main_model = AutoModelForCausalLM.from_pretrained(saved_model_path)
# Two independent copies of the pretrained checkpoint. The second return value of
# get_model is discarded (presumably its tokenizer — confirm in model_utils).
# NOTE(review): presumably target_model is the one ModelMixing overwrites while
# base_model stays pristine — confirm in model_mixing.
base_model, _ = get_model(BASE_MODEL_NAME)
target_model, _ = get_model(BASE_MODEL_NAME)
# model_manager wraps target_model, so later generations reflect any in-place
# changes made to target_model.
model_manager = ModelManager(model=target_model, tokenizer=tokenizer, device=device)
print("Loaded base and main model to CPU")

Loaded base and main model to CPU


In [3]:
# Fixed evaluation prompts. Each entry is a ``(past, prompt)`` pair that is
# passed as ``model_manager.say(past, prompt)``: ``past`` is the prior dialogue
# context and ``prompt`` is the player's next line. The inline tags
# (<p>, <msg>, <d>, <scn>) presumably encode speaker/scene markup from the
# training data — confirm there.
# Order matters: ``is_good`` in the mixing cell only scores the first entry.
prompts = [
    (
        '<p><msg>c "Hey Remy!"<d><scn>park2<msg>Ry "Hey!"',
        "How are you?",
    ),
    (
        '<p><msg>c "I was with Lorem today."<d><scn>park2<msg>Ad "Very nice."',
        "What do you think of Lorem?",
    ),
    (
        '<p><msg>m "In Tatsu park, Adine and I sat down."',
        "Oh my god, Adine. What is this?",
    ),
    (
        "<p><msg>m \"I sat down on a chair in Anna's lab.\"",
        "What will we do here?",
    ),
]

def sample_test(model_manager, test_prompts=None, top_k=50, top_p=0.7):
    """Print a default-decoding reply and a sampled reply for each prompt pair.

    For every ``(past, prompt)`` pair, calls ``model_manager.say`` twice: once
    with no extra arguments, and once with ``top_k``/``top_p`` sampling
    parameters. Each pair's output block ends with a ``----------`` separator.

    Args:
        model_manager: object exposing ``say(past, prompt, **kwargs)``, such as
            the ``ModelManager`` built in the loading cell.
        test_prompts: iterable of ``(past, prompt)`` tuples. Defaults to the
            module-level ``prompts`` list.
        top_k: ``top_k`` forwarded to ``say`` for the "[sampled]" reply.
        top_p: ``top_p`` forwarded to ``say`` for the "[sampled]" reply.

    Returns:
        None. Output goes to stdout.
    """
    if test_prompts is None:
        test_prompts = prompts
    for past, prompt in test_prompts:
        print(f"Prompt: {prompt}")
        # No sampling kwargs: uses whatever decoding ``say`` does by default.
        reply = model_manager.say(past, prompt)
        print(f"Reply: {reply}")
        reply = model_manager.say(past, prompt, top_k=top_k, top_p=top_p)
        print(f"Reply [sampled]: {reply}")
        print("-" * 10)
# Baseline: run the prompt suite through model_manager (which wraps target_model)
# before the mixing cell runs.
sample_test(model_manager=model_manager)

Prompt: How are you?
Reply: park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2
Reply [sampled]: park2<msg>Remy<msg>C<scn>Remy<scn>Remy<scn>C<scn>Remy<scn>C<scn>Remy<scn>C<scn>Remy<scn>C<scn>C<scn>Remy<scn>C<scn>C<scn>
----------
Prompt: What do you think of Lorem?
Reply: park2<msg>Ad "Very nice."<p><msg>c "What do you think of Lorem?"<d><scn>park2<msg>Ad "Very nice."<p><msg>c "What do you think of Lorem?"<d><scn>park2<msg>Ad "Very nice."<p><msg>
Reply [sampled]: park2<msg>Ad "Very nice."<p><msg>c "What do you think of Lorem?"<d><scn>park2<msg>Ad "Very nice."<p><msg>c "What do you think of Lorem?"<d><scn>park2<msg>Ad "Very nice."<p><msg>
----------
Prompt: Oh my god, Adine. What is this?
Reply: c "Oh my god, Adine. What is this?"<d><scn>c "Oh my god, Adine. What is this?"<d><scn>c "Oh my god, Adine. What is this?"<d><scn>c "Oh my god, Adine. What is this?"<

In [4]:
def is_good(model, expected_prefix='park2<msg>Ry "I\'m fine'):
    """Acceptance check passed to ``ModelMixing``: is the current mix good enough?

    Generates a reply for the first ``(past, prompt)`` pair in ``prompts`` (the
    Remy greeting) through the global ``model_manager``. It prints the reply so
    mixing progress shows in the cell output, then checks whether the reply
    starts with ``expected_prefix``.

    Only the first prompt is scored, because the expected prefix is specific to
    that greeting. The previous loop returned on its first iteration anyway;
    this makes that explicit.

    Args:
        model: the candidate model handed in by ``ModelMixing``.
            NOTE(review): currently unused. The reply comes from
            ``model_manager``, which wraps ``target_model``. This is only
            correct if ``ModelMixing`` mixes into ``target_model`` in place —
            confirm.
        expected_prefix: reply prefix that counts as success. Defaults to the
            previously hard-coded value.

    Returns:
        bool: True if the reply starts with ``expected_prefix``. False
        otherwise, including when ``prompts`` is empty.
    """
    if not prompts:
        return False
    past, prompt = prompts[0]
    reply = model_manager.say(past, prompt)
    print(reply)
    return reply.startswith(expected_prefix)
# Seed handed to ModelMixing, named so it can be tuned from one place.
# Presumably it drives ModelMixing's random choices, so a fixed value keeps runs
# reproducible — confirm in model_mixing.
MIXING_SEED = 3443
# is_good is the acceptance callback. It reads replies via model_manager, which
# wraps target_model.
model_mixing = ModelMixing(base_model, main_model, target_model, is_good, seed=MIXING_SEED)
model_mixing.start_mixing()

park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2<msg>Ry "How are you?"<p><msg>c "How are you?"<d><scn>park2
park2<msg>Remy<msg>c "Hey!"<p><msg>c "Hey!"<d><scn>park2<msg>Remy<msg>c "Hey!"<p><msg>c "Hey!"<d><scn>park2<msg>Remy<msg>c "Hey!"<p><msg>c "Hey!"<d><scn
park2<msg>Remy "Good, I'm fine."<p><msg>c "You're fine, I'm fine."<d><scn>park2<msg>Remy "You're fine, I'm fine."<p><msg>c "You're fine, I'm fine."<d><scn>park2<msg>Remy "You're fine, I'm
park2<msg>Remy "Good, I'm fine."<p><msg>c "You're fine, I'm fine."<d><scn>park2<msg>Remy "You're fine, I'm fine."<p><msg>c "You're fine, I'm fine."<d><scn>park2<msg>Remy "You're fine, I'm
park2<msg>R "Good, I'm fine."<p><msg>c "You're fine, I'm fine."<p><msg>c "You're fine, I'm fine."<p><msg>c "You're fine, I'm fine."<p><msg>c "You're fine, I'm fine."<p><msg>c "You're fine, I'm
park2<msg>R "Good, thanks."<p><msg>c "You're welcome."<d><scn>park2<msg>R "You're welcome."<p><

In [5]:
# Short free-form commands (presumably roleplay prompts) sent with an empty
# `past` context, to spot-check the model after the mixing cell has run.
test_rps = [
    "Visit Lorem",
    "Meet with Lorem",
    "Visit Adine",
    "Fight Maverick",
    "Fight Adine",
    "Attack Adine",
]
# Same sampling values sample_test uses for its "[sampled]" replies.
RP_SAMPLING = {"top_k": 50, "top_p": 0.7}
for rp in test_rps:
    rp_reply = model_manager.say("", rp, **RP_SAMPLING)
    print(f"[Pytorch] {rp} -> {rp_reply}")
    print("-" * 10)

# Re-run the fixed prompt suite for comparison with the pre-mixing baseline.
sample_test(model_manager)

[Pytorch] Visit Lorem -> lorem<msg>c "This is Lorem. I'm afraid of the wild, but I'm so afraid of the wild that I'll just have to share it with you."<|endoftext|>
----------
[Pytorch] Meet with Lorem -> lorem<msg>c "This is Lorem. We have just about two years of knowledge."<d><scn>lorem<msg>c "This is Lorem. We have just about two years of knowledge."<d><scn>lorem<msg>c "This is Lorem. We have just about two years of knowledge."<d><scn>lorem<msg>c "This is Lorem. We have just about two years of knowledge."<d><scn>lorem
----------
[Pytorch] Visit Adine -> loremapt<msg>Ad "Hey, [player_name]. I'm with you."<p><msg>c "Hey, [player_name]. I'm with you."<p><msg>c "Hey, [player_name]. I'm with you."<|endoftext|>
----------
[Pytorch] Fight Maverick -> o<msg>m "He's just out here, and it's a nice place."<p><msg>c "You're right. I'm getting too excited about this. I'm going to ruin you."<d><scn>o<msg>c "This is an interesting time for me."<p><msg>c "I'm leaving."<p><msg>c "You're right. I'm lea