In [3]:
import os

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
set_seed(3407)

from chargpt import get_config, CharDataset

In [4]:
model_path = os.path.join('./out/chargpt', "model.pt")

config = get_config()
setup_logging(config)
set_seed(config.system.seed)

# construct the training dataset
text = open('data/gpt_train.txt', 'r').read() # don't worry we won't run out of file handles
train_dataset = CharDataset(config.data, text)

# construct the model
config.model.vocab_size = train_dataset.get_vocab_size()
config.model.block_size = train_dataset.get_block_size()
model = GPT(config.model)
model.load_state_dict(torch.load(model_path))

data has 447013 characters, 79 unique.
number of parameters: 2.70M


<All keys matched successfully>

In [5]:
def generate(model, context):
    x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to("cpu")
    y = model.generate(x, 25, temperature=1.0, do_sample=True, top_k=1)[0]
    completion = ''.join([train_dataset.itos[int(i)] for i in y])
    return completion.split("\n")[0]

In [6]:
# Example generation
generate(model, "Mollembaum-")

'Mollembaum-en'

In [7]:
# Evaluate model performance on test set
pairs = []
with open("data/gpt_test.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()
    for line in lines: 
        sg, pl = line.strip("\n").split(":")
        pairs.append((sg, pl))

In [8]:
def find_suffix(singular, plural): 
    last_char = singular[-1]
    suffix = []
    for i in range(len(plural)-1, -1, -1):
        if plural[i] == last_char:
            break
        else:
            suffix.append(plural[i])
    return "".join(suffix[::-1])

test_set = pairs[:100] # TODO remove
N = len(test_set)
correct = 0
incorrect = []
for pair in test_set:
    sg, pl = pair
    context = sg+":"
    prediction = generate(model, context)
    prediction = find_suffix(context, prediction)
    if prediction == pl:
        correct += 1
    else: 
        incorrect.append((sg, pl, prediction))

print(f"Correct: {correct} out of {N} or {correct/N}")
print(f"Incorrect: {len(incorrect)}")
for sg, pl, prediction in incorrect:
    print(f"SG: {sg}, PL: {pl}, PRED: {prediction}")

Correct: 0 out of 100 or 0.0
Incorrect: 100
SG: Urfassung, PL: Urfassungen, PRED: en
SG: Küchenreibe, PL: Küchenreiben, PRED: en
SG: Regelung, PL: Regelungen, PRED: en
SG: Volksverhetzung, PL: Volksverhetzungen, PRED: enen
SG: Locher, PL: Locher, PRED: en
SG: Mannheimer, PL: Mannheimer, PRED: en
SG: Gasplanet, PL: Gasplaneten, PRED: en
SG: Etappe, PL: Etappen, PRED: enen
SG: Malaysierin, PL: Malaysierinnen, PRED: en
SG: Doktorfisch, PL: Doktorfische, PRED: en
SG: Ölförderung, PL: Ölförderungen, PRED: en
SG: Theorbe, PL: Theorben, PRED: enen
SG: Baskenmütze, PL: Baskenmützen, PRED: en
SG: Landschaft, PL: Landschaften, PRED: en
SG: Patientin, PL: Patientinnen, PRED: en
SG: Thiophosphat, PL: Thiophosphate, PRED: en
SG: Abgasklappe, PL: Abgasklappen, PRED: en
SG: Antike, PL: Antiken, PRED: en
SG: Code, PL: Codes, PRED: en
SG: Reling, PL: Relings, PRED: n
SG: Lederschlaufe, PL: Lederschlaufen, PRED: en
SG: Mundschenk, PL: Mundschenke, PRED: 
SG: Summierung, PL: Summierungen, PRED: en
SG: En