In [2]:
import os
from pathlib import Path

# Work from the parent of the kernel's starting directory; later cells use paths
# relative to it (e.g. '../local/out/elpp/').
# The starting directory is recorded on the first run only, so re-running this cell is
# idempotent instead of climbing one more level each time. A bare `cd ..` also relied on
# IPython automagic, which IPython only applies to single-line cells.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)
# Echo the new working directory, as `%cd` does.
print(Path.cwd())

/home/smaug/ownCloud/praca/reasonable-embeddings/src


In [3]:
from pathlib import Path

# Output directory for this experiment's files, relative to the current working
# directory; created (with any missing parents) if it does not exist yet.
base_dir = Path('..', 'local', 'out', 'elpp')
base_dir.mkdir(exist_ok=True, parents=True)

In [4]:
import lzma
import dill

# Load the xz-compressed, dill-serialised `reasoners` collection from base_dir.
# dill.load (like pickle) can execute arbitrary code: only load files this project wrote.
with lzma.open(base_dir / 'reasoners.dill.xz', mode='rb') as compressed:
    reasoners = dill.load(compressed)

In [5]:
import numpy as np
from src.reasoner import ReasonerHead, EmbeddingLayer, train
from src.utils import timestr, paramcount
import torch as T
from src.elpp.gen import split_dataset

# --- experiment configuration ---
seed = 2022            # torch seed, re-applied before each model is built below
ts = timestr()

emb_size = 10
hidden_size = 16
epoch_count = 15
test_epoch_count = 10  # NOTE(review): not referenced in this cell
batch_size = 32

# complexity threshold -> trained head, its encoders, and the data splits used for it
artifacts = {}

for complexity_threshold in range(2, 20 + 1):
    print("Complexity threshold", complexity_threshold)

    # A new generator with the same fixed seed on every iteration, so each threshold's
    # split starts from an identical random state.
    split_rng = np.random.default_rng(seed=0xbeef)
    training, validation, test = split_dataset(reasoners, split_rng, complexity_threshold=complexity_threshold)

    # Re-seed torch right before building the models so initialisation is reproducible
    # per threshold (assumes the modules draw initial weights from torch's global RNG).
    T.manual_seed(seed)
    trained_reasoner = ReasonerHead(emb_size=emb_size, hidden_size=hidden_size)
    # One embedding layer per entry of `reasoners`, sized to its concept and role counts.
    encoders = [
        EmbeddingLayer(emb_size=emb_size, n_concepts=kb.n_concepts, n_roles=kb.n_roles)
        for kb in reasoners
    ]

    print(f'created reasoner with {paramcount(trained_reasoner)} parameters')
    print(f'created {len(encoders)} encoders with {paramcount(encoders[0])} parameters each')

    train_logger = train(training, validation, trained_reasoner, encoders,
                         epoch_count=epoch_count, batch_size=batch_size)

    artifacts[complexity_threshold] = dict(
        reasoner=trained_reasoner,
        encoders=encoders,
        training=training,
        validation=validation,
        test=test,
    )

Complexity threshold 2
Training 20064 #pos 10032
Validation 4976 #pos 2488
Test 32982 #pos 32982
created reasoner with 3293 parameters
created 40 encoders with 1440 parameters each
train epoch 00/15 | batch 628/627 | loss 0.6962 | val loss 0.6963 | acc 0.4956 | f1 0.6627 | prec 0.4978 | recall 0.9912 | roc auc 0.3924 | pr auc 0.4452 | elapsed 3.13s
train epoch 01/15 | batch 628/627 | loss 0.6763 | val loss 0.6502 | acc 0.5860 | f1 0.6532 | prec 0.5620 | recall 0.7797 | roc auc 0.6974 | pr auc 0.7622 | elapsed 7.38s
train epoch 02/15 | batch 628/627 | loss 0.6204 | val loss 0.5904 | acc 0.6698 | f1 0.6185 | prec 0.7323 | recall 0.5354 | roc auc 0.7203 | pr auc 0.7826 | elapsed 7.65s
train epoch 03/15 | batch 628/627 | loss 0.5646 | val loss 0.5552 | acc 0.6845 | f1 0.6085 | prec 0.8016 | recall 0.4904 | roc auc 0.7321 | pr auc 0.7947 | elapsed 7.56s
train epoch 04/15 | batch 628/627 | loss 0.5276 | val loss 0.5392 | acc 0.6923 | f1 0.6202 | prec 0.8101 | recall 0.5024 | roc auc 0.7412 |

In [6]:
# Swap each trained torch module for its state_dict before serialising;
# the training/validation/test splits are stored as-is.
tmp = {}
for threshold, parts in artifacts.items():
    tmp[threshold] = {
        'reasoner': parts['reasoner'].state_dict(),
        'encoders': [encoder.state_dict() for encoder in parts['encoders']],
        'training': parts['training'],
        'validation': parts['validation'],
        'test': parts['test'],
    }

exp1_path = base_dir / 'exp1.dill.xz'
with lzma.open(exp1_path, mode='wb') as f:
    dill.dump(tmp, f)

In [7]:
from tqdm import tqdm
from src.reasoner import eval_batch

import pandas as pd

# One row per (threshold, test example): ground truth, the prediction thresholded at 0.5,
# the raw prediction, and the length of the decoded shortest proof as "complexity".
rows = []

for complexity_threshold, components in tqdm(artifacts.items()):
    idx_te, X_te, y_te = components['test']
    with T.no_grad():
        # Only the third value returned by eval_batch (per-example predictions) is used.
        _, _, Y_te_good = eval_batch(components['reasoner'], components['encoders'], X_te, y_te, idx_te)

    for i in range(len(idx_te)):
        idx, axiom = idx_te[i], X_te[i]
        expected, predicted = y_te[i], Y_te_good[i]
        # axiom[1] and axiom[2] are the two arguments passed to the proof decoder.
        proof = reasoners[idx].decode_shortest_proof(axiom[1], axiom[2])
        complexity = len(proof)
        label = int(predicted >= .5)
        rows.append([complexity_threshold, idx, complexity, axiom, expected, label, predicted])

100%|██████████| 19/19 [00:34<00:00,  1.80s/it]


In [8]:
# Column order matches the row layout built in the evaluation loop.
result_columns = ["Complexity threshold", "KB", "Complexity", "Axiom", "Expected", "Predicted", "Raw predicted"]
df = pd.DataFrame(rows, columns=result_columns)
# Persist the per-example results next to the other experiment outputs.
df.to_feather(base_dir / 'exp1.feather')
df


Unnamed: 0,Complexity threshold,KB,Complexity,Axiom,Expected,Predicted,Raw predicted
0,2,0,4,"(0, 1, 9)",1,0,0.211107
1,2,0,6,"(0, 1, 50)",1,1,0.842637
2,2,0,5,"(0, 1, 55)",1,1,0.691905
3,2,0,7,"(0, 1, 63)",1,0,0.357871
4,2,0,5,"(0, 1, 65)",1,1,0.593294
...,...,...,...,...,...,...,...
445205,20,39,24,"(0, 89, 44)",1,0,0.015694
445206,20,39,22,"(0, -2, 49)",1,1,0.569123
445207,20,39,24,"(0, -2, 56)",1,1,0.506917
445208,20,39,25,"(0, -2, 86)",1,0,0.094137
