In [9]:
import os

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
set_seed(3407)

from chargpt import get_config, CharDataset

In [10]:
model_path = os.path.join('./out/chargpt', "model.pt")

config = get_config()
setup_logging(config)
set_seed(config.system.seed)

# construct the training dataset
text = open('train.txt', 'r').read() # don't worry we won't run out of file handles
train_dataset = CharDataset(config.data, text)

# construct the model
config.model.vocab_size = train_dataset.get_vocab_size()
config.model.block_size = train_dataset.get_block_size()
model = GPT(config.model)
model.load_state_dict(torch.load(model_path))

data has 445476 characters, 76 unique.
number of parameters: 2.71M


<All keys matched successfully>

In [11]:
def generate(model, context):
    x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to("cpu")
    y = model.generate(x, 25, temperature=1.0, do_sample=True, top_k=1)[0]
    completion = ''.join([train_dataset.itos[int(i)] for i in y])
    return completion.split("\n")[0]

In [12]:
# Example generation
generate(model, "Mollembaum-")

'Mollembaum-Mollembäume'

In [13]:
# Evaluate model performance on test set
pairs = []
with open("test.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()
    for line in lines: 
        try:
            sg, pl = line.strip("\n").split("-")
        except ValueError: 
            # Alpha-Version-Alpha-Versionen
            tokens = line.strip("\n").split("-")
            n = len(tokens) // 2
            sg, pl = "-".join(tokens[:n]), "-".join(tokens[n:])
        pairs.append((sg, pl))

In [14]:
def get_suffix(string1, string2):
    # Generated by ChatGPT
    # Check if string2 is formed by adding a suffix to string1
    if string2.startswith(string1):
        # Return the suffix
        suffix = string2[len(string1):]
        return suffix
    else:
        # Return None if string2 is not formed by adding a suffix to string1
        return None

# TODO remove
test_set = pairs[:100]
N = len(test_set)
correct = 0
incorrect = []
for pair in test_set:
    sg, pl = pair
    context = sg+"-"
    prediction = generate(model, context)
    prediction = get_suffix(context, prediction)
    if prediction == pl:
        correct += 1
    else: 
        incorrect.append((sg, pl, prediction))

print(f"Correct: {correct} out of {N} or {correct/N}")
print(f"Incorrect: {len(incorrect)}")
for sg, pl, prediction in incorrect:
    print(f"SG: {sg}, PL: {pl}, PRED: {prediction}")

Correct: 79 out of 100 or 0.79
Incorrect: 21
SG: Faible, PL: Faibles, PRED: Faiblen
SG: Fußsoldat, PL: Fußsoldaten, PRED: Fußsoldate
SG: Slawin, PL: Slawinnen, PRED: Schelawinnen
SG: Sturzkampfgeschwader, PL: Sturzkampfgeschwader, PRED: Sturzkampfgese
SG: Entgelt, PL: Entgelte, PRED: Entgelt
SG: Nahrung, PL: Nahrungen, PRED: Nachrungen
SG: Kolonialwarenladen, PL: Kolonialwarenläden, PRED: Kolonialwarenlare
SG: Leitungscode, PL: Leitungscodes, PRED: Leitungscoden
SG: Athletenherz, PL: Athletenherzen, PRED: Athethetenherzen
SG: Furie, PL: Furien, PRED: Furier
SG: Vorhut, PL: Vorhuten, PRED: Vorhüte
SG: Penny, PL: Pence, PRED: Penys
SG: Bart, PL: Bärte, PRED: Barten
SG: Lutscher, PL: Lutscher, PRED: Wutscher
SG: Alterthum, PL: Alterthümer, PRED: Alterthume
SG: Lustgarten, PL: Lustgärten, PRED: Lustgarten
SG: Dreißigeck, PL: Dreißigecke, PRED: Dreißigecks
SG: Fußballfan, PL: Fußballfans, PRED: Fußballfäne
SG: Farbintensität, PL: Farbintensitäten, PRED: Farbintensintäten
SG: Futon, PL: Futo