In [1]:
cd ../../

/home/wtaisner/PycharmProjects/reasonable-embeddings


In [2]:
from pathlib import Path

# Directory (relative to the current working directory) that this notebook
# reads its inputs from and writes its outputs to; created on first use.
base_dir = Path('local') / 'out' / 'elpp' / 'exp1'
base_dir.mkdir(exist_ok=True, parents=True)

In [3]:
import lzma
import dill

# Read the xz-compressed dill dump of the reasoners. Each loaded object exposes
# `n_concepts` / `n_roles`, which later cells use to size the embedding layers.
# NOTE(review): dill, like pickle, can execute code while loading - only open
# files produced by this project.
reasoners_path = base_dir / 'reasoners_200_concepts.dill.xz'
with lzma.open(reasoners_path, mode='rb') as f:
    reasoners = dill.load(f)

In [4]:
from src.reasoner import ReasonerHead
from src.reasoner import EmbeddingLayer

# Load the exp1 artifacts: a mapping whose values are dicts holding a
# 'reasoner' state dict and a list of 'encoders' state dicts.
# NOTE(review): dill, like pickle, can execute code while loading - only open
# files produced by this project.
with lzma.open(base_dir / 'exp1.dill.xz', mode='rb') as f:
    artifacts = dill.load(f)

# Layer sizes used to rebuild the modules before their weights are loaded.
emb_size = 32
hidden_size = 16

# Replace every stored state dict with a live module, in place, so that later
# cells can index `artifacts[...]['reasoner']` and get a ready-to-use head.
for key, components in artifacts.items():
    neural_reasoner = ReasonerHead(emb_size=emb_size, hidden_size=hidden_size)
    neural_reasoner.load_state_dict(components['reasoner'])
    components['reasoner'] = neural_reasoner

    # One embedding layer per loaded reasoner, sized from that reasoner.
    encoders = [
        EmbeddingLayer(emb_size=emb_size, n_concepts=r.n_concepts, n_roles=r.n_roles)
        for r in reasoners
    ]
    # Stored states are paired with the layers positionally.
    for state, encoder in zip(components['encoders'], encoders):
        encoder.load_state_dict(state)
    components['encoders'] = encoders

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


In [5]:
# Read the xz-compressed dill dump of the held-out reasoners that the
# following cells split into datasets and evaluate on.
# NOTE(review): dill, like pickle, can execute code while loading - only open
# files produced by this project.
test_reasoners_path = base_dir / 'test_reasoners.dill.xz'
with lzma.open(test_reasoners_path, mode='rb') as f:
    test_reasoners = dill.load(f)

In [6]:
from src.elpp.gen import split_dataset
import numpy as np

# One dataset split per complexity threshold 2..20. A fresh generator with the
# same fixed seed is created for every threshold, so each call to
# split_dataset starts from an identical random state.
splits = {
    k: split_dataset(
        test_reasoners,
        np.random.default_rng(seed=0xbeef),
        complexity_threshold=k,
    )
    for k in range(2, 21)
}

Training 10202 #pos 5101
Validation 2528 #pos 1264
Test 22480 #pos 22480
Training 11504 #pos 5752
Validation 2856 #pos 1428
Test 21665 #pos 21665
Training 12822 #pos 6411
Validation 3184 #pos 1592
Test 20842 #pos 20842
Training 13946 #pos 6973
Validation 3464 #pos 1732
Test 20140 #pos 20140
Training 14814 #pos 7407
Validation 3686 #pos 1843
Test 19595 #pos 19595
Training 15724 #pos 7862
Validation 3910 #pos 1955
Test 19028 #pos 19028
Training 16662 #pos 8331
Validation 4146 #pos 2073
Test 18441 #pos 18441
Training 17546 #pos 8773
Validation 4368 #pos 2184
Test 17888 #pos 17888
Training 18428 #pos 9214
Validation 4584 #pos 2292
Test 17339 #pos 17339
Training 19454 #pos 9727
Validation 4844 #pos 2422
Test 16696 #pos 16696
Training 20449 #pos 10229
Validation 5092 #pos 2547
Test 16069 #pos 16069
Training 21712 #pos 10868
Validation 5408 #pos 2707
Test 15270 #pos 15270
Training 22743 #pos 11460
Validation 5664 #pos 2854
Test 14531 #pos 14531
Training 23756 #pos 12146
Validation 5920 #pos 3

In [7]:
from joblib import Parallel, delayed
from src.reasoner import EmbeddingLayer, train, eval_batch
from src.utils import timestr
import torch as T

# Seed handed to T.manual_seed inside train_helper, right before the fresh
# embedding layers are created.
seed = 2022
# Value returned by src.utils.timestr (presumably a run timestamp - TODO
# confirm); not referenced by the cells visible here.
ts = timestr()

# Embedding width given to every EmbeddingLayer built in train_helper.
emb_size = 32
# Not used in this cell; the same value sized ReasonerHead when the exp1
# artifacts were rebuilt above.
hidden_size = 16
# Forwarded to train() as epoch_count.
epoch_count = 15
# Not referenced by the cells visible here.
test_epoch_count = 10
# Forwarded to train() as batch_size.
batch_size = 128

# Placeholder only: the mapping is re-created and filled from `results` in a
# later cell.
encoders = {}


def train_helper(complexity_threshold_j, complexity_threshold_k):
    """Train fresh embeddings for the test reasoners and score one test split.

    The reasoner head stored under ``artifacts[complexity_threshold_j]`` is
    passed to ``train`` with ``freeze_reasoner=True`` together with newly
    created embedding layers (one per entry of ``test_reasoners``), using the
    training part of ``splits[complexity_threshold_k]``. The test part of the
    same split is then evaluated with gradients disabled.

    Args:
        complexity_threshold_j: key into ``artifacts`` selecting the reasoner head.
        complexity_threshold_k: key into ``splits`` selecting the dataset split.

    Returns:
        A tuple ``(complexity_threshold_j, complexity_threshold_k, my_encoders, rows)``.
        ``rows`` holds one list per test example:
        ``[j, k, KB index, complexity, axiom, expected, thresholded prediction, raw prediction]``,
        where complexity is the length of what ``decode_shortest_proof`` returns
        for the axiom and the thresholded prediction is ``int(raw >= .5)``.
    """
    head = artifacts[complexity_threshold_j]["reasoner"]
    training, validation, test = splits[complexity_threshold_k]

    # Seed before building the layers so every (j, k) job starts from the same
    # torch random state.
    T.manual_seed(seed)
    my_encoders = [
        EmbeddingLayer(emb_size=emb_size, n_concepts=kb.n_concepts, n_roles=kb.n_roles)
        for kb in test_reasoners
    ]

    # The value returned by train() is not needed here.
    train(training, validation, head, my_encoders,
          epoch_count=epoch_count, batch_size=batch_size,
          freeze_reasoner=True, validate=False)

    with T.no_grad():
        idx_te, X_te, y_te = test
        _, _, Y_te_good = eval_batch(head, my_encoders, X_te, y_te, idx_te)

    rows = []
    for i, idx in enumerate(idx_te):
        axiom = X_te[i]
        expected = y_te[i]
        predicted = Y_te_good[i]
        proof = test_reasoners[idx].decode_shortest_proof(axiom[1], axiom[2])
        rows.append([
            complexity_threshold_j,
            complexity_threshold_k,
            idx,
            len(proof),
            axiom,
            expected,
            int(predicted >= .5),
            predicted,
        ])
    print(f"({complexity_threshold_j}, {complexity_threshold_k}) completed")

    return complexity_threshold_j, complexity_threshold_k, my_encoders, rows

# One train_helper job per (j, k) pair on the 2..20 threshold grid, with j
# varying slowest; n_jobs=-1 lets joblib use all available CPUs.
thresholds = range(2, 21)
jobs = [delayed(train_helper)(j, k) for j in thresholds for k in thresholds]
results = Parallel(n_jobs=-1)(jobs)

train epoch 00/15 | batch 80/80 | loss 0.5711 | elapsed 1.38s
train epoch 00/15 | batch 90/90 | loss 0.6133 | elapsed 1.90spoch 00/15 | batch 89/90 | loss 0.6139 | elapsed 1.88s
train epoch 01/15 | batch 80/80 | loss 0.5607 | elapsed 4.06s9s
train epoch 00/15 | batch 101/101 | loss 0.6607 | elapsed 3.84s
train epoch 01/15 | batch 90/90 | loss 0.5983 | elapsed 7.19s0s
train epoch 00/15 | batch 109/109 | loss 0.7040 | elapsed 4.51s
train epoch 02/15 | batch 80/80 | loss 0.5180 | elapsed 6.76s8s
train epoch 00/15 | batch 116/116 | loss 0.7052 | elapsed 4.89s
train epoch 01/15 | batch 101/101 | loss 0.6399 | elapsed 7.89s
train epoch 02/15 | batch 90/90 | loss 0.5475 | elapsed 8.37s1s
train epoch 03/15 | batch 80/80 | loss 0.4796 | elapsed 7.07s5s
train epoch 01/15 | batch 109/109 | loss 0.6781 | elapsed 9.82s
train epoch 00/15 | batch 123/123 | loss 0.7370 | elapsed 6.14s
train epoch 00/15 | batch 131/131 | loss 0.7455 | elapsed 5.08s
train epoch 02/15 | batch 101/101 | loss 0.5793 | elap

In [8]:
# Fold the per-(j, k) job outputs into one encoder lookup keyed by (j, k) and
# one flat list of result rows.
encoders, rows = {}, []

for j, k, jk_encoders, some_rows in results:
    rows.extend(some_rows)
    encoders[j, k] = jk_encoders

In [9]:
# Persist the splits together with the trained embeddings. Modules are stored
# as state dicts, keyed by the same (j, k) pairs as `encoders`.
encoder_states = {
    pair: [layer.state_dict() for layer in layers]
    for pair, layers in encoders.items()
}
tmp = {'splits': splits, 'encoders': encoder_states}

with lzma.open(base_dir / 'exp3.dill.xz', mode='wb') as f:
    dill.dump(tmp, f)

In [10]:
import pandas as pd

# Tabulate the collected rows; the column order mirrors the order in which
# train_helper fills each row. The frame is also written to a feather file.
column_names = [
    "Complexity threshold j",
    "Complexity threshold k",
    "KB",
    "Complexity",
    "Axiom",
    "Expected",
    "Predicted",
    "Raw predicted",
]
df = pd.DataFrame(rows, columns=column_names)
df.to_feather(base_dir / 'exp3.feather')
df


Unnamed: 0,Complexity threshold j,Complexity threshold k,KB,Complexity,Axiom,Expected,Predicted,Raw predicted
0,2,2,0,9,"(0, 4, 53)",1,0,0.437073
1,2,2,0,3,"(0, 7, 53)",1,0,0.275827
2,2,2,0,8,"(0, 9, 0)",1,0,0.043837
3,2,2,0,10,"(0, 9, 46)",1,0,0.168075
4,2,2,0,7,"(0, 9, 53)",1,1,0.511046
...,...,...,...,...,...,...,...,...
5909147,20,20,17,22,"(0, 96, 66)",1,1,0.917326
5909148,20,20,17,23,"(0, 98, 5)",1,1,0.822247
5909149,20,20,17,21,"(0, 98, 10)",1,0,0.407074
5909150,20,20,17,22,"(0, 98, 48)",1,1,0.878611
