In [None]:
cd ../..

In [None]:
from pathlib import Path
# Directory that holds this experiment's inputs and outputs, relative to the working directory.
base_dir = Path('local') / 'out' / 'elpp' / 'exp1'

# Create it on demand: missing parents are created and an existing directory is not an error.
base_dir.mkdir(exist_ok=True, parents=True)

# Echo the path as the cell's output.
base_dir

In [None]:
import lzma
import dill

# Load `reasoners` from the xz-compressed dill archive stored in the experiment directory.
# NOTE(review): dill, like pickle, can execute arbitrary code while loading -- only open
# archives from a trusted source.
reasoners_path = base_dir / 'reasoners_400_concepts_4_roles_1_proofs.dill.xz'
with lzma.open(reasoners_path, mode='rb') as archive:
    reasoners = dill.load(archive)

In [None]:
import numpy as np
from src.reasoner import ReasonerHead, EmbeddingLayer, train
from src.utils import timestr, paramcount
import torch as T
from src.elpp.gen import split_dataset

# --- Experiment configuration ---------------------------------------------------------
seed = 2022          # torch seed, re-applied before each model is constructed
split_seed = 0xbeef  # seed of the NumPy generator handed to split_dataset
ts = timestr()       # not used in this cell

emb_size = 32
hidden_size = 16
epoch_count = 15
test_epoch_count = 10  # not used in this cell
batch_size = 32

# Per-threshold results, keyed by complexity threshold.
artifacts = {}

# Train one reasoner head (plus one embedding layer per loaded reasoner) for every
# complexity threshold from 2 to 20 inclusive.
for complexity_threshold in range(2, 21):

    print("Complexity threshold", complexity_threshold)

    # A fresh generator with the same seed is built for every threshold, so each call to
    # split_dataset starts from an identical RNG state.
    training, validation, test = split_dataset(
        reasoners,
        np.random.default_rng(seed=split_seed),
        complexity_threshold=complexity_threshold,
    )

    # Re-seed torch before building the modules so every threshold starts from the same
    # torch RNG state.
    T.manual_seed(seed)
    trained_reasoner = ReasonerHead(emb_size=emb_size, hidden_size=hidden_size)
    # One embedding layer per reasoner, sized from that reasoner's concept and role counts.
    encoders = [
        EmbeddingLayer(emb_size=emb_size, n_concepts=reasoner.n_concepts, n_roles=reasoner.n_roles)
        for reasoner in reasoners
    ]

    print(f'created reasoner with {paramcount(trained_reasoner)} parameters')
    print(f'created {len(encoders)} encoders with {paramcount(encoders[0])} parameters each')

    train_logger = train(training, validation, trained_reasoner, encoders, epoch_count=epoch_count, batch_size=batch_size)

    artifacts[complexity_threshold] = {
        'reasoner': trained_reasoner,
        'encoders': encoders,
        'training': training,
        'validation': validation,
        'test': test,
        # Keep the value returned by train() for this threshold; previously it was
        # overwritten on every iteration and only the last one survived.
        'train_logger': train_logger,
    }

In [None]:
# Build a snapshot of `artifacts` for saving: the reasoner and the encoders are replaced by
# their state dicts, while the training/validation/test splits are carried over unchanged.
tmp = {
    threshold: {
        'reasoner': parts['reasoner'].state_dict(),
        'encoders': [encoder.state_dict() for encoder in parts['encoders']],
        'training': parts['training'],
        'validation': parts['validation'],
        'test': parts['test'],
    }
    for threshold, parts in artifacts.items()
}

# Write the snapshot as an xz-compressed dill archive into the experiment directory.
snapshot_path = base_dir / 'exp1.dill.xz'
with lzma.open(snapshot_path, mode='wb') as archive:
    dill.dump(tmp, archive)

In [None]:
from tqdm import tqdm
from src.reasoner import eval_batch

import pandas as pd

# One record per (complexity threshold, test axiom); reset here so re-running the cell
# does not accumulate duplicates.
rows = []

for complexity_threshold, components in tqdm(artifacts.items()):
    # Score the held-out split with gradient tracking switched off.
    with T.no_grad():
        idx_te, X_te, y_te = components['test']
        _, _, Y_te_good = eval_batch(components['reasoner'], components['encoders'], X_te, y_te, idx_te)

    for i, idx in enumerate(idx_te):
        axiom, expected, predicted = X_te[i], y_te[i], Y_te_good[i]
        # Complexity is the length of the shortest proof decoded by the axiom's own reasoner
        # for the axiom's elements at positions 1 and 2.
        complexity = len(reasoners[idx].decode_shortest_proof(axiom[1], axiom[2]))
        # The raw prediction is thresholded at 0.5 to obtain the 0/1 label.
        predicted_label = int(predicted >= 0.5)
        rows.append([complexity_threshold, idx, complexity, axiom, expected, predicted_label, predicted])

In [None]:
# Column names, in the same order as the fields of each entry in `rows`.
result_columns = [
    "Complexity threshold",
    "KB",
    "Complexity",
    "Axiom",
    "Expected",
    "Predicted",
    "Raw predicted",
]
df = pd.DataFrame(data=rows, columns=result_columns)

# Save the table in Feather format inside the experiment directory.
df.to_feather(base_dir / 'exp1.feather')

# Show the table as the cell's output.
df
