In [1]:
import os

# Move to the repository root so that the `src.*` imports and the relative
# `local/...` paths used by the following cells resolve.
# The change of directory is guarded so that re-running this cell does not
# climb two more directories each time.
# NOTE(review): assumes the repository root is recognisable by its `src/`
# directory and that the notebook's own folder does not contain one - confirm.
if not os.path.isdir('src'):
    os.chdir('../..')

# Show where we ended up (last expression is displayed by the notebook).
os.getcwd()

/home/wtaisner/PycharmProjects/reasonable-embeddings


In [2]:
from pathlib import Path

# Directory (relative to the current working directory) used by the cells
# below for reading inputs and writing results of this experiment.
base_dir = Path('local/out/elpp/exp1')

# Create the directory together with any missing parents; a no-op if it
# already exists.
base_dir.mkdir(exist_ok=True, parents=True)

base_dir

PosixPath('local/out/elpp/exp1')

In [4]:
import lzma
import dill

# Read the xz-compressed, dill-serialised `reasoners` object from `base_dir`.
# NOTE: dill (like pickle) can execute arbitrary code while loading, so only
# open archives that come from a trusted source.
with lzma.open(base_dir / 'reasoners_400_concepts_4_roles_1_proofs.dill.xz', 'rb') as archive:
    reasoners = dill.load(archive)

In [5]:
import numpy as np
from src.reasoner import ReasonerHead, EmbeddingLayer, train
from src.utils import timestr, paramcount
import torch as T
from src.elpp.gen import split_dataset

# --- Experiment configuration -------------------------------------------
seed = 2022          # passed to T.manual_seed before the models are created
ts = timestr()

emb_size = 32
hidden_size = 16
epoch_count = 15
test_epoch_count = 10
batch_size = 32

# Maps each complexity threshold to the models trained for it and the data
# splits that were used.
artifacts = {}

for complexity_threshold in range(2, 21):

    print("Complexity threshold", complexity_threshold)

    # A fresh generator with the same fixed seed is created for every
    # threshold before the data set is split.
    training, validation, test = split_dataset(
        reasoners,
        np.random.default_rng(seed=0xbeef),
        complexity_threshold=complexity_threshold,
    )

    # Re-seed torch before the models are instantiated.
    T.manual_seed(seed)
    trained_reasoner = ReasonerHead(emb_size=emb_size, hidden_size=hidden_size)
    # One embedding layer per entry of `reasoners`, sized from that entry's
    # concept and role counts.
    encoders = [
        EmbeddingLayer(emb_size=emb_size, n_concepts=kb.n_concepts, n_roles=kb.n_roles)
        for kb in reasoners
    ]

    print(f'created reasoner with {paramcount(trained_reasoner)} parameters')
    print(f'created {len(encoders)} encoders with {paramcount(encoders[0])} parameters each')

    train_logger = train(
        training,
        validation,
        trained_reasoner,
        encoders,
        epoch_count=epoch_count,
        batch_size=batch_size,
    )

    # Keep everything needed to evaluate or export this threshold later.
    artifacts[complexity_threshold] = {
        'reasoner': trained_reasoner,
        'encoders': encoders,
        'training': training,
        'validation': validation,
        'test': test,
    }

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


Complexity threshold 2
Training 56596 #pos 28298
Validation 14108 #pos 7054
Test 203 #pos 203
created reasoner with 53409 parameters
created 40 encoders with 17024 parameters each
train epoch 00/15 | batch 1770/1769 | loss 0.6961 | val loss 0.6960 | acc 0.5000 | f1 0.6667 | prec 0.5000 | recall 1.0000 | roc auc 0.6820 | pr auc 0.7176 | elapsed 9.17s
train epoch 01/15 | batch 1770/1769 | loss 0.5658 | val loss 0.4943 | acc 0.7353 | f1 0.6415 | prec 0.9935 | recall 0.4736 | roc auc 0.7847 | pr auc 0.8354 | elapsed 25.39s
train epoch 02/15 | batch 1308/1769 | loss 0.4694 | elapsed 19.14s

KeyboardInterrupt: 

In [None]:
# Build a serialisable snapshot of `artifacts`: the reasoner and every encoder
# are replaced by their `state_dict()`, the data splits are stored unchanged.
tmp = {
    threshold: {
        'reasoner': components['reasoner'].state_dict(),
        'encoders': [encoder.state_dict() for encoder in components['encoders']],
        'training': components['training'],
        'validation': components['validation'],
        'test': components['test'],
    }
    for threshold, components in artifacts.items()
}

# Write the snapshot as an xz-compressed dill file into `base_dir`.
with lzma.open(base_dir / 'exp1.dill.xz', 'wb') as archive:
    dill.dump(tmp, archive)

In [None]:
from tqdm import tqdm
from src.reasoner import eval_batch

import pandas as pd

# One record per (complexity threshold, test axiom); rebuilt from scratch on
# every run of this cell.
rows = []

for complexity_threshold, components in tqdm(artifacts.items()):
    # Evaluation only, so gradient tracking is disabled.
    with T.no_grad():
        idx_te, X_te, y_te = components['test']
        _, _, Y_te_good = eval_batch(components['reasoner'], components['encoders'], X_te, y_te, idx_te)

    for i, idx in enumerate(idx_te):
        axiom, expected, predicted = X_te[i], y_te[i], Y_te_good[i]
        # Complexity is the length of the shortest proof decoded by
        # `reasoners[idx]` for this axiom.
        shortest_proof = reasoners[idx].decode_shortest_proof(axiom[1], axiom[2])
        complexity = len(shortest_proof)
        # The raw score is thresholded at 0.5 for the hard prediction and is
        # also stored as-is.
        rows.append([complexity_threshold, idx, complexity, axiom, expected, int(predicted >= 0.5), predicted])

100%|██████████| 19/19 [00:23<00:00,  1.26s/it]


In [None]:
# Turn the collected evaluation records into a table, store it as a Feather
# file inside `base_dir`, and display it.
df = pd.DataFrame(
    data=rows,
    columns=[
        "Complexity threshold",
        "KB",
        "Complexity",
        "Axiom",
        "Expected",
        "Predicted",
        "Raw predicted",
    ],
)
df.to_feather(base_dir / 'exp1.feather')
df


Unnamed: 0,Complexity threshold,KB,Complexity,Axiom,Expected,Predicted,Raw predicted
0,2,0,4,"(0, 1, 9)",1,0,0.306261
1,2,0,6,"(0, 1, 50)",1,1,0.924735
2,2,0,5,"(0, 1, 55)",1,0,0.333867
3,2,0,7,"(0, 1, 63)",1,1,0.729814
4,2,0,5,"(0, 1, 65)",1,0,0.283852
...,...,...,...,...,...,...,...
445205,20,39,24,"(0, 89, 44)",1,0,0.003385
445206,20,39,22,"(0, -2, 49)",1,0,0.003953
445207,20,39,24,"(0, -2, 56)",1,0,0.055731
445208,20,39,25,"(0, -2, 86)",1,0,0.024811
