In [1]:
import os

# Optional: switch to the project directory (presumably so the relative
# 'data/...' paths and the `data` package import below resolve).
os.chdir('/app/dpl/')

from data.dataset import DPLDataset, TorchDataset

# --- Run identifiers: only used to label / locate saved results -------------
method = "DPL"
name = "nurse_enters_{}".format(method)
save_path = r"results/"

# --- Reproducibility and training hyper-parameters ---------------------------
seed = 12345
batch_size = 20
epochs = 10

# --- Data: train/test metadata CSVs wrapped as a DeepProbLog dataset ---------
# The third argument ('difference') is forwarded to DPLDataset unchanged;
# its meaning is defined in data/dataset.py.
dpl_set = DPLDataset(
    r'data/metadata/one_sample_per_class.csv',
    r'data/metadata/test_80.csv',
    'difference',
)

# Split handles used by the training cell.
dpl_train, dpl_test = dpl_set.train_set, dpl_set.test_set

In [2]:
import math
import os
import random
from json import dumps

import torch

# NOTE: torch.utils.data.DataLoader was previously imported here too, but it was
# immediately shadowed by DeepProbLog's DataLoader below, so it is not needed.
from deepproblog.dataset import DataLoader
from deepproblog.engines import ExactEngine
from deepproblog.evaluate import get_confusion_matrix
from deepproblog.model import Model
from deepproblog.network import Network

from dpl.lenet import LeNet
from dpl.utils.train import train_model

# Seed the RNGs used in this process; the DeepProbLog DataLoader below is also
# given `seed` explicitly for its shuffling.
torch.manual_seed(seed)
random.seed(seed)

# Iterations per epoch. Uses ceil on the assumption that the loader keeps the final
# partial batch (deepproblog's DataLoader slices index ranges) — TODO confirm.
print("Number of iterations: {}".format(math.ceil(len(dpl_train) / batch_size)))

# load PyTorch net
network = LeNet()

# Wrap it as a DeepProbLog network and give it a torch optimizer.
# The name "nop_net" is presumably the neural predicate referenced in dpl/model.pl.
net = Network(network, "nop_net", batching=True)
net.optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# initialize prolog model and use exact inference (with caching of compiled circuits)
model = Model(r"dpl/model.pl", [net])
model.set_engine(ExactEngine(model), cache=True)

# register tensor sources the Prolog program reads inputs from
model.add_tensor_source("train", dpl_train)
model.add_tensor_source("test", dpl_test)

# batch the training queries for DeepProbLog
loader = DataLoader(dpl_train, batch_size=batch_size, shuffle=True, seed=seed)

# Make sure output directories exist before anything is written to them
# (save_state / write_to_file do not create missing parent directories on a fresh checkout).
snapshot_dir = "snapshot"
os.makedirs(save_path, exist_ok=True)
os.makedirs(snapshot_dir, exist_ok=True)

# train the model, evaluating on the test set, then persist the weights
train = train_model(model, loader, epochs, log_iter=10, profile=0, save_path=save_path, with_negatives=False, test_set=dpl_test)
model.save_state(os.path.join(snapshot_dir, name + ".pth"))

# Hyper-parameter logging was disabled in the original run; uncomment to re-enable.
# train.logger.comment(dumps(model.get_hyperparameters()))

# final test accuracy goes into the training log, which is written under save_path
accuracy = get_confusion_matrix(model, dpl_test, verbose=1).accuracy()
train.logger.comment("Accuracy {}".format(accuracy))
train.logger.write_to_file(os.path.join(save_path, name))



Number of itertions: 576
Caching ACs
Training  for 10 epoch(s)
Epoch 1
Iteration:  10 	s:5.5546 	Average Loss:  2.7206279881298543
Iteration:  20 	s:6.1575 	Average Loss:  2.67634539604187
Iteration:  30 	s:5.9753 	Average Loss:  2.430575928092003
Iteration:  40 	s:6.2564 	Average Loss:  2.072080947831273
Iteration:  50 	s:6.3043 	Average Loss:  1.9073792410083115
Iteration:  60 	s:5.8843 	Average Loss:  1.6041076024062932
Iteration:  70 	s:6.4464 	Average Loss:  1.5120235323905944
Iteration:  80 	s:5.9467 	Average Loss:  1.3133365706307814
Iteration:  90 	s:5.9242 	Average Loss:  1.1421344115398824
Iteration:  100 	s:6.4632 	Average Loss:  0.978945555887185
Iteration:  110 	s:5.8880 	Average Loss:  0.949367146345321
Iteration:  120 	s:6.5955 	Average Loss:  0.9839849069248885
Iteration:  130 	s:5.9486 	Average Loss:  0.9928636368480511
Iteration:  140 	s:5.9061 	Average Loss:  0.8048189422814176
Iteration:  150 	s:5.8995 	Average Loss:  0.7131362963700667
Iteration:  160 	s:6.7649 	Av

100%|███████████████████████████████████████████████████████████| 1998/1998 [03:41<00:00,  9.03it/s]


Test: 	 Accuracy: 86.4% 	 Avg loss: 0.477355
         	  	   	   	  	  	   	  	   	  	  	Actual	  	  	   	   	  	  	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	  0	-6	-4	     3	-8	-9	 -1	  1	 4	-7	-5	 9	 8
         	-3	108	  2	 0	 0	  2	 0	  0	 1	 0	     1	 1	 0	  3	  3	 0	 1	 1	 0	 0
         	-2	  1	106	 0	 0	  0	 1	  1	 1	 1	     0	 2	 0	  3	  6	 1	 3	 1	 0	 0
         	 7	  0	  0	47	 1	  5	 1	  0	 0	 0	     2	 0	 0	  0	  1	 0	 0	 0	 1	 0
         	 6	  0	  1	 1	60	  1	 2	  1	 0	 0	     1	 0	 0	  1	  5	 3	 0	 1	 0	 0
         	 2	  2	  0	 1	 3	121	 0	  0	 0	 1	     2	 0	 0	  0	  5	 3	 0	 0	 0	 0
         	 5	  0	  1	 0	 1	  0	71	  3	 0	 0	     0	 0	 0	  0	  0	 2	 0	 0	 0	 0
         	 0	  1	  2	 1	 0	  1	 5	147	 2	 0	     1	 0	 1	  4	  4	 1	 0	 3	 0	 1
         	-6	  1	  2	 0	 0	  0	 0	  4	62	 1	     0	 2	 0	  0	  2	 0	 1	 0	 0	 0
Predicted	-4	  2	  2	 0	 0	  0	 0	  1	 3	91	     0	 0	 0	  0	  6	 0	 1	 2	 0	 0
         	 3	  2	  2	 2	 1	  1	 0	  1	 1	 0	   107	 0	 0	  3	  3	 3	 0	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [02:18<00:00, 14.44it/s]


Test: 	 Accuracy: 90.9% 	 Avg loss: 0.295184
         	  	   	   	  	  	   	  	   	  	   	Actual	   	  	  	   	  	  	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	  1	-6	  0	    -4	  3	-8	-9	 -1	 4	-7	-5	 9	 8
         	-3	112	  1	 0	 0	  1	 0	  1	 1	  0	     0	  1	 0	 0	  3	 0	 1	 3	 0	 0
         	-2	  1	110	 0	 0	  0	 1	  0	 1	  2	     1	  1	 1	 0	  1	 1	 3	 0	 0	 0
         	 7	  0	  0	52	 1	  3	 0	  1	 0	  0	     0	  2	 0	 0	  0	 0	 0	 0	 1	 0
         	 6	  0	  1	 1	60	  0	 0	  1	 0	  0	     0	  1	 0	 0	  0	 2	 0	 0	 0	 0
         	 2	  1	  0	 0	 4	128	 0	  3	 0	  1	     0	  4	 0	 0	  0	 3	 0	 0	 0	 0
         	 5	  0	  1	 0	 1	  0	74	  2	 0	  1	     0	  0	 0	 0	  0	 1	 0	 0	 0	 0
         	 1	  0	  0	 0	 0	  0	 2	365	 0	  2	     3	  1	 0	 0	  1	 1	 1	 0	 0	 0
         	-6	  1	  3	 0	 0	  0	 0	  0	69	  1	     1	  0	 1	 0	  0	 0	 1	 1	 0	 0
Predicted	 0	  0	  2	 0	 0	  0	 3	  2	 1	153	     0	  1	 0	 0	  3	 0	 0	 2	 0	 0
         	-4	  3	  3	 0	 0	  0	 0	  4	 3	  2	    95	  0	 0	 0	  

100%|███████████████████████████████████████████████████████████| 1998/1998 [02:12<00:00, 15.13it/s]


Test: 	 Accuracy: 94.1% 	 Avg loss: 0.210218
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	117	  0	 0	 0	  1	 0	 0	  0	 0	     0	 0	 0	  2	 0	 0	  0	 2	 0	 0
         	-2	  1	118	 0	 0	  0	 0	 1	  1	 0	     1	 1	 0	  0	 0	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  1	 2	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	65	  0	 0	 0	  0	 0	     0	 0	 0	  0	 4	 0	  5	 0	 0	 0
         	 2	  0	  0	 0	 3	131	 0	 0	  0	 1	     2	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	75	 0	  1	 0	     0	 0	 0	  0	 1	 0	  0	 0	 0	 0
         	-6	  0	  1	 0	 0	  0	 0	74	  1	 1	     0	 1	 0	  0	 0	 0	  0	 1	 0	 0
         	 0	  0	  1	 0	 0	  0	 1	 0	163	 0	     0	 0	 1	  2	 1	 0	  2	 4	 0	 0
Predicted	-4	  1	  1	 0	 0	  0	 0	 2	  0	95	     0	 0	 0	  1	 0	 0	  2	 0	 0	 0
         	 3	  1	  2	 0	 0	  0	 0	 0	  0	 0	   117	 0	 0	  1	 0	 0	  2	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [02:07<00:00, 15.64it/s]


Test: 	 Accuracy: 95.2% 	 Avg loss: 0.171488
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	121	  0	 0	 0	  1	 0	 1	  0	 0	     1	 0	 0	  2	 0	 0	  1	 1	 0	 0
         	-2	  0	118	 0	 0	  1	 0	 0	  0	 0	     1	 1	 0	  0	 1	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  2	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	65	  0	 0	 0	  1	 0	     1	 0	 0	  0	 1	 0	  2	 0	 0	 0
         	 2	  0	  0	 0	 2	131	 0	 0	  0	 0	     2	 0	 0	  0	 0	 0	  2	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	78	 0	  0	 0	     1	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	-6	  0	  2	 0	 0	  0	 0	74	  0	 0	     0	 1	 0	  1	 0	 1	  0	 0	 0	 0
         	 0	  0	  1	 0	 0	  0	 2	 1	166	 0	     0	 0	 1	  2	 0	 0	  2	 1	 0	 0
Predicted	-4	  1	  1	 0	 0	  1	 0	 2	  0	98	     0	 0	 0	  0	 0	 0	  4	 1	 0	 0
         	 3	  1	  1	 0	 0	  0	 0	 0	  0	 0	   115	 0	 0	  1	 0	 0	  1	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [02:02<00:00, 16.29it/s]


Test: 	 Accuracy: 95.3% 	 Avg loss: 0.158021
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	122	  0	 0	 0	  1	 0	 0	  0	 0	     0	 0	 0	  1	 0	 0	  0	 2	 0	 0
         	-2	  0	120	 0	 0	  0	 0	 0	  0	 0	     1	 1	 0	  1	 1	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  1	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	65	  0	 0	 0	  1	 0	     0	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	 2	  0	  0	 0	 3	131	 0	 0	  0	 1	     3	 0	 0	  0	 0	 0	  3	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	77	 0	  1	 0	     1	 0	 0	  0	 1	 0	  0	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	75	  1	 0	     0	 2	 0	  0	 0	 0	  0	 1	 0	 0
         	 0	  0	  0	 0	 0	  1	 2	 0	161	 0	     0	 0	 0	  1	 0	 0	  2	 2	 0	 0
Predicted	-4	  0	  2	 0	 0	  1	 0	 2	  1	96	     0	 0	 0	  0	 0	 0	  2	 0	 0	 0
         	 3	  1	  1	 0	 0	  0	 0	 0	  0	 0	   116	 0	 0	  2	 0	 0	  1	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [01:57<00:00, 16.95it/s]


Test: 	 Accuracy: 95.7% 	 Avg loss: 0.142931
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	121	  0	 0	 0	  1	 0	 0	  0	 0	     1	 0	 0	  1	 0	 0	  1	 0	 0	 0
         	-2	  0	121	 0	 0	  1	 1	 0	  1	 0	     1	 0	 0	  1	 0	 3	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  2	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	65	  0	 0	 0	  2	 0	     1	 0	 0	  0	 1	 0	  2	 0	 0	 0
         	 2	  0	  0	 0	 2	132	 0	 0	  0	 1	     2	 0	 0	  0	 0	 0	  3	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	78	 0	  0	 0	     0	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	77	  1	 0	     0	 1	 0	  0	 0	 0	  0	 0	 0	 0
         	 0	  0	  0	 0	 0	  0	 0	 0	163	 0	     1	 0	 0	  1	 0	 0	  1	 1	 0	 0
Predicted	-4	  0	  1	 0	 0	  0	 0	 0	  0	96	     0	 0	 0	  1	 0	 0	  1	 0	 0	 0
         	 3	  1	  2	 0	 0	  1	 0	 0	  0	 0	   116	 0	 0	  1	 0	 0	  1	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [02:01<00:00, 16.39it/s]


Test: 	 Accuracy: 95.8% 	 Avg loss: 0.142190
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	121	  0	 0	 0	  1	 0	 2	  0	 0	     1	 0	 0	  2	 0	 0	  0	 0	 0	 0
         	-2	  0	122	 0	 0	  0	 1	 0	  2	 0	     1	 0	 0	  0	 0	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  2	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	65	  0	 0	 0	  0	 0	     1	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	 2	  1	  0	 0	 2	133	 0	 0	  0	 0	     3	 0	 0	  0	 2	 0	  1	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	77	 0	  1	 0	     1	 0	 0	  0	 1	 0	  0	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	72	  0	 0	     0	 1	 0	  0	 0	 1	  0	 0	 0	 0
         	 0	  0	  1	 0	 0	  0	 1	 0	162	 0	     1	 0	 1	  1	 0	 0	  1	 1	 0	 0
Predicted	-4	  0	  1	 0	 0	  0	 0	 2	  0	97	     0	 0	 0	  0	 0	 0	  2	 0	 0	 0
         	 3	  1	  1	 0	 0	  0	 0	 0	  0	 0	   114	 0	 0	  1	 0	 0	  2	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [01:55<00:00, 17.34it/s]


Test: 	 Accuracy: 96.5% 	 Avg loss: 0.126783
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	121	  0	 0	 0	  1	 0	 0	  0	 0	     0	 0	 0	  1	 0	 0	  0	 0	 0	 0
         	-2	  0	122	 0	 0	  1	 0	 0	  1	 0	     0	 1	 0	  0	 0	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  1	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 1	 0
         	 6	  0	  0	 0	66	  0	 0	 0	  1	 0	     0	 0	 0	  0	 1	 0	  2	 0	 0	 0
         	 2	  0	  0	 0	 2	134	 0	 0	  0	 1	     2	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	79	 0	  0	 0	     0	 0	 0	  0	 1	 0	  0	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	77	  0	 0	     0	 1	 0	  0	 0	 0	  0	 0	 0	 0
         	 0	  0	  0	 0	 0	  0	 0	 0	163	 0	     1	 0	 1	  2	 0	 1	  1	 1	 0	 0
Predicted	-4	  0	  1	 0	 0	  0	 0	 1	  0	96	     0	 0	 0	  1	 0	 0	  2	 0	 0	 0
         	 3	  1	  1	 0	 0	  0	 0	 0	  0	 0	   118	 0	 0	  1	 0	 0	  1	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [01:52<00:00, 17.73it/s]


Test: 	 Accuracy: 96.0% 	 Avg loss: 0.135782
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	120	  0	 0	 0	  0	 0	 0	  0	 0	     0	 0	 0	  1	 0	 0	  0	 0	 0	 0
         	-2	  0	122	 0	 0	  0	 0	 0	  0	 0	     0	 0	 0	  0	 0	 2	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  2	 0	 0	  0	 0	     1	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  1	 0	65	  0	 0	 0	  1	 0	     1	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	 2	  1	  0	 0	 2	131	 0	 0	  0	 1	     2	 0	 0	  0	 0	 0	  3	 0	 0	 0
         	 5	  0	  1	 0	 0	  0	78	 0	  0	 0	     1	 0	 0	  0	 1	 0	  0	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	75	  0	 0	     0	 1	 0	  0	 0	 0	  0	 1	 0	 0
         	 0	  0	  0	 0	 0	  0	 1	 1	163	 0	     0	 0	 0	  2	 1	 0	  1	 1	 0	 0
Predicted	-4	  0	  1	 0	 0	  2	 0	 1	  0	96	     0	 0	 0	  1	 0	 0	  2	 0	 0	 0
         	 3	  1	  0	 0	 0	  0	 0	 0	  0	 0	   115	 0	 0	  1	 0	 0	  0	 0	 

100%|███████████████████████████████████████████████████████████| 1998/1998 [01:50<00:00, 18.00it/s]


Test: 	 Accuracy: 95.7% 	 Avg loss: 0.136372
         	  	   	   	  	  	   	  	  	   	  	Actual	  	  	   	  	  	   	  	  	  
         	  	 -3	 -2	 7	 6	  2	 5	-6	  0	-4	     3	-8	-9	 -1	 4	-7	  1	-5	 9	 8
         	-3	120	  0	 0	 0	  0	 0	 0	  0	 0	     0	 0	 0	  1	 0	 0	  0	 1	 0	 0
         	-2	  0	121	 0	 0	  0	 0	 0	  0	 0	     0	 1	 0	  0	 0	 3	  0	 0	 0	 0
         	 7	  0	  0	54	 0	  2	 0	 0	  0	 0	     0	 0	 0	  0	 0	 0	  1	 0	 0	 0
         	 6	  0	  0	 0	66	  0	 0	 0	  1	 0	     0	 0	 0	  0	 1	 0	  2	 0	 0	 0
         	 2	  1	  0	 0	 1	134	 0	 0	  0	 1	     2	 0	 0	  0	 0	 0	  2	 0	 0	 0
         	 5	  0	  0	 0	 0	  0	77	 0	  0	 0	     0	 0	 0	  0	 1	 0	  1	 0	 0	 0
         	-6	  0	  0	 0	 0	  0	 0	73	  2	 1	     0	 1	 0	  0	 0	 0	  0	 0	 0	 0
         	 0	  0	  2	 0	 0	  1	 1	 1	162	 0	     1	 0	 0	  1	 1	 1	  1	 2	 0	 0
Predicted	-4	  0	  1	 0	 0	  0	 0	 1	  1	95	     0	 0	 0	  1	 0	 0	  1	 0	 0	 0
         	 3	  2	  0	 0	 0	  0	 1	 0	  0	 0	   119	 0	 0	  1	 0	 0	  1	 0	 