# DeepProbLog #

This project uses DeepProbLog, an extension of ProbLog that integrates Probabilistic Logic Programming with Deep Learning.

The old repository of the project is hosted on Bitbucket: https://bitbucket.org/problog/deepproblog.git

We set the seed of every random number generator we use (`random`, `numpy` and `torch`) to make the results reproducible.
Here the seed is `0`.

In [1]:
SEED = 0

In [2]:
import random
import torch
import numpy

numpy.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x235ca86b370>
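
To illustrate what seeding buys us, here is a minimal sketch that uses only the standard-library `random` module, the same generator that later shuffles the MNIST indices. The helper name `shuffled_indices` is hypothetical and not part of the notebook.

```python
import random

SEED = 0


def shuffled_indices(n: int, seed: int) -> list:
    """Return a shuffled copy of range(n) after reseeding the global generator."""
    random.seed(seed)
    indices = list(range(n))
    random.shuffle(indices)
    return indices


first = shuffled_indices(10, SEED)
second = shuffled_indices(10, SEED)
print(first == second)  # the same seed gives the same permutation
```

The same idea applies to `numpy` and `torch`: once their generators are seeded, the order of the generated queries and the initial network weights no longer change between runs on the same machine and library versions.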

## MNIST Digit Octal division ##

In this experiment, the task is to predict the quotient, written in octal, of two multi-digit numbers,
each given as a list of MNIST digit images.

First, we create a ProbLog file containing the logic part of the program. The file will be saved as `datasetPL/octal.pl`.

```prolog
:- use_module('datasetPL/utils.py').

nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).

number([],Result,Result).
number([H|T],Acc,Result) :- 
    digit(H,Nr), 
    Acc2 is Nr+10*Acc,
    number(T,Acc2,Result).
number(X,Y) :- number(X,0,Y).

octal_div(X,Y,Z) :- number(X,X2), number(Y,Y2), oct_div(X2, Y2, Z1), Z is Z1.
```
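
The `number/3` predicate folds a list of digits into a base-10 integer using an accumulator. A rough Python equivalent, assuming the digits are already known (in the ProbLog program each one comes from the neural predicate `digit/2`), could look like this. The name `to_number` is hypothetical.

```python
def to_number(digits, acc=0):
    """Mirror of number/3: fold a list of digits into a base-10 integer."""
    if not digits:                 # number([],Result,Result).
        return acc
    head, *tail = digits           # number([H|T],Acc,Result) :- ...
    return to_number(tail, head + 10 * acc)  # Acc2 is Nr+10*Acc


print(to_number([3, 7]))     # 37
print(to_number([0, 4, 2]))  # 42
```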

The `use_module` directive in ProbLog lets us delegate part of the program's logic to an external Python module.
Here, a Python function receives two base-10 integers and directly returns their quotient written in octal.
When the second number is zero, an `ERROR` value is returned, which makes the statement automatically false.

Then, we create the query files for both train and test, connecting MNIST images to instances of the `octal_div` Prolog predicate.

In [3]:
def oct_div(n1: int, n2: int):
    """
    Return the quotient n1 / n2 written in octal (as an int made of octal digits).
    Return None if n2 is zero or does not divide n1 exactly.
    """
    # The modulo test avoids the float comparison n1 / n2 == n1 // n2
    if n2 == 0 or n1 % n2 != 0:
        return None

    # oct() yields a string such as "0o12"; drop the "0o" prefix
    return int(oct(n1 // n2)[2:])
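
A few worked values may help to see what "octal division" means here. The block restates the function in compact form so that it runs on its own; for non-negative integers it is meant to behave like the cell above.

```python
def oct_div(n1: int, n2: int):
    """Compact restatement of the notebook function, for illustration only."""
    if n2 == 0 or n1 % n2 != 0:
        return None
    return int(oct(n1 // n2)[2:])


print(oct_div(9, 3))   # 3    (9 / 3 = 3, which is 3 in octal)
print(oct_div(8, 1))   # 10   (8 / 1 = 8, which is 10 in octal)
print(oct_div(72, 8))  # 11   (72 / 8 = 9, which is 11 in octal)
print(oct_div(7, 2))   # None (2 does not divide 7)
print(oct_div(5, 0))   # None (division by zero)
```

Note that the result is a Python `int` whose decimal digits spell the octal representation, which is why quotients of 8 or more show up as `10`, `11`, and so on in the query files.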

In [4]:
from torchvision.datasets import MNIST


datasets = {
    "train": MNIST(root="data/MNIST", train=True, download=True),
    "test": MNIST(root="data/MNIST", train=False, download=True),
}


# dataset name is train or test
# op is a (lambda) function for the operation to be learned
# length is the number of digits to be used
# out is the output file name
def generate_examples(dataset_name: str, op, length: int, out: str):
    dataset = datasets[dataset_name]
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    i = iter(indices)
    examples = []
    while True:
        try:
            examples.append(next_example(i, dataset, op, length))
            # exception is raised when all digits in dataset have been used
        except StopIteration:
            break
    save_examples(dataset_name, examples, out)


def next_example(i, dataset, op, length):
    nr1, n1 = next_number(i, dataset, length)
    nr2, n2 = next_number(i, dataset, length)
    res = op(n1, n2)
    # Resample until the second number is non-zero and divides the first exactly
    while res is None:
        nr1, n1 = next_number(i, dataset, length)
        nr2, n2 = next_number(i, dataset, length)
        res = op(n1, n2)
    return nr1, nr2, res, n1, n2


def next_number(i, dataset, nr_digits):
    n = 0
    nr = []
    for _ in range(nr_digits):
        x = next(i)
        _, c = dataset[x]  # c is the digit that the image represents
        n = (
            n * 10 + c
        )  # the number is incrementally built from the sequence of its digits
        nr.append(str(x))  # nr is the list of ids of the digit images
    return nr, n


def save_examples(dataset_name, examples, out):
    with open(out, "w") as f:
        for example in examples:
            # number encoded as e.g. (test(9150),test(6809),test(1586))
            args1 = tuple("{}({})".format(dataset_name, e) for e in example[0])
            args2 = tuple("{}({})".format(dataset_name, e) for e in example[1])
            # example encoded as e.g.
            # octal_div([test(7215),test(9001)], [test(6072),test(1802)], 22).
            f.write(
                "octal_div([{}], [{}], {}).\n".format(
                    ",".join(args1), ",".join(args2), example[2]
                )
            )
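
To make the output format concrete, the following sketch builds a single query line in the same shape that `save_examples` writes. The helper `format_query` is hypothetical; the image ids and the result are taken from the comment in the code above.

```python
def format_query(dataset_name, ids1, ids2, result):
    """Build one query line in the same shape that save_examples writes."""
    args1 = ",".join("{}({})".format(dataset_name, i) for i in ids1)
    args2 = ",".join("{}({})".format(dataset_name, i) for i in ids2)
    return "octal_div([{}], [{}], {}).".format(args1, args2, result)


line = format_query("test", ["7215", "9001"], ["6072", "1802"], 22)
print(line)
# octal_div([test(7215),test(9001)], [test(6072),test(1802)], 22).
```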

### Dataset Creation ###

To generate a new set of queries, run the following commands:

In [None]:
generate_examples(
    "train", oct_div, 1, "datasetPL/train.txt"
)  # generate queries of one-digit numbers
generate_examples(
    "test", oct_div, 2, "datasetPL/test.txt"
)  # generate queries of two-digit numbers
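
Because `next_example` rejects every pair whose second number is zero or is not an exact divisor of the first, only a fraction of the candidate pairs ends up in the query files. The sketch below counts the accepted pairs by brute force. It assumes all digit values are equally likely, which is only approximately true for MNIST, and the helper name `count_valid_pairs` is hypothetical.

```python
def count_valid_pairs(nr_digits: int):
    """Count (n1, n2) pairs with n2 != 0 and n2 dividing n1 exactly."""
    limit = 10 ** nr_digits
    valid = sum(
        1
        for n1 in range(limit)
        for n2 in range(1, limit)  # starting at 1 excludes division by zero
        if n1 % n2 == 0
    )
    return valid, limit * limit


valid, total = count_valid_pairs(1)
print(valid, total)  # 32 100
```

Under that assumption, roughly a third of the one-digit candidate pairs are kept. Calling `count_valid_pairs(2)` gives the corresponding count for the two-digit test queries.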

Train and test queries look like this:

```
octal_div([train(38999)], [train(37649)], 3).
octal_div([train(28340)], [train(58230)], 5).
octal_div([train(35361)], [train(22662)], 1).
...

octal_div([test(4157),test(6134)], [test(8634),test(1227)], 3).
octal_div([test(8553),test(1698)], [test(5902),test(7520)], 4).
octal_div([test(5688),test(8399)], [test(7533),test(6777)], 1).
...
```
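
The structure of a query line can be made explicit by parsing it back into its parts. This is only an illustrative sketch based on regular expressions; it is not how `deepproblog.data_loader.load` reads the files, and the names `parse_query` and `parse_terms` are hypothetical.

```python
import re

QUERY_RE = re.compile(r"octal_div\(\[(.*?)\],\s*\[(.*?)\],\s*(\d+)\)\.")
TERM_RE = re.compile(r"(\w+)\((\d+)\)")


def parse_terms(text: str):
    """Turn 'test(4157),test(6134)' into [('test', 4157), ('test', 6134)]."""
    return [(name, int(idx)) for name, idx in TERM_RE.findall(text)]


def parse_query(line: str):
    """Split a query line into two lists of (dataset, image id) and the result."""
    match = QUERY_RE.fullmatch(line.strip())
    if match is None:
        raise ValueError("not an octal_div query: {!r}".format(line))
    first, second, result = match.groups()
    return parse_terms(first), parse_terms(second), int(result)


print(parse_query("octal_div([test(4157),test(6134)], [test(8634),test(1227)], 3)."))
# ([('test', 4157), ('test', 6134)], [('test', 8634), ('test', 1227)], 3)
```

Each `(dataset, image id)` pair is exactly the information the neural predicate needs to look up the corresponding MNIST image.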

We can now define a Python class implementing a standard CNN for MNIST images, and a neural predicate that maps the image id (as found in the query) to the corresponding image and sends it to the neural net.

In [5]:
import torch.nn as nn
from torch.autograd import Variable


class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.MaxPool2d(2, 2),  # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5),  # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2),  # 16 8 8 -> 16 4 4
            nn.BatchNorm2d(16),  # This layer can be safely removed
            nn.ReLU(True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N),
            nn.Softmax(1),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x


def neural_predicate(network, i):
    global mnist_train_data
    global mnist_test_data
    # i is something like train(2764) or test(4052)
    dataset = str(i.functor)
    i = int(i.args[0])
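    if dataset == "train":
        d, l = mnist_train_data[i]
    elif dataset == "test":
        d, l = mnist_test_data[i]
    else:
        # Without this branch an unknown functor would fail later with a NameError
        raise ValueError("unknown dataset: {}".format(dataset))
    d = Variable(d.unsqueeze(0))  # add a batch dimension
    output = network.net(d)
    return output.squeeze(0)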
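
The shape comments in the encoder, and the `16 * 4 * 4` input size of the classifier, follow from the usual output-size formula for convolution and pooling layers. The following sketch reproduces that arithmetic in plain Python, assuming 28 x 28 MNIST inputs, no padding, and the default stride of 1 for the convolutions. The helper `conv_out` is hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                        # MNIST images are 28 x 28
side = conv_out(side, kernel=5)  # Conv2d(1, 6, 5)  -> 24
side = conv_out(side, 2, 2)      # MaxPool2d(2, 2)  -> 12
side = conv_out(side, kernel=5)  # Conv2d(6, 16, 5) -> 8
side = conv_out(side, 2, 2)      # MaxPool2d(2, 2)  -> 4
flat = 16 * side * side          # 16 channels of 4 x 4 values
print(side, flat)  # 4 256
```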

We can now load the MNIST data, the queries, and the ProbLog file.
Note that `mnist_train_data` and `mnist_test_data` are global variables used inside `neural_predicate`.

In [6]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from deepproblog.data_loader import load

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
mnist_train_data = MNIST(
    root="data/MNIST", train=True, download=True, transform=transform
)
mnist_test_data = MNIST(
    root="data/MNIST", train=False, download=True, transform=transform
)

train_queries = load("datasetPL/train.txt")
test_queries = load("datasetPL/test.txt")

with open("datasetPL/octal.pl") as f:
    problog_string = f.read()

# We print the problog file
print(problog_string)

:- use_module('datasetPL/utils.py').

nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).

number([],Result,Result).
number([H|T],Acc,Result) :- 
    digit(H,Nr), 
    Acc2 is Nr+10*Acc,
    number(T,Acc2,Result).
number(X,Y) :- number(X,0,Y).

octal_div(X,Y,Z) :- number(X,X2), number(Y,Y2), oct_div(X2, Y2, Z1), Z is Z1.
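
As a side note on the `transform` used when loading the images: `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(value - mean) / std` for each value. A per-pixel sketch in plain Python, with the hypothetical helpers `to_unit` and `normalize`, shows the resulting range.

```python
def to_unit(pixel: int) -> float:
    """What ToTensor does to one 8-bit pixel: scale [0, 255] to [0, 1]."""
    return pixel / 255


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What Normalize((0.5,), (0.5,)) does per value: (value - mean) / std."""
    return (value - mean) / std


for pixel in (0, 255):
    print(pixel, normalize(to_unit(pixel)))
# 0 -1.0
# 255 1.0
```

With these settings the network therefore sees inputs in the range [-1, 1].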


Finally, we can create the network and the DeepProbLog model with the network as neural predicate, and train it as a standard torch model.

In [7]:
from deepproblog.train import train_model
from deepproblog.network import Network
from deepproblog.model import Model


network = MNIST_Net()

# Network is a DeepProbLog class that wraps a PyTorch network and interfaces with ProbLog
net = Network(network, "mnist_net", neural_predicate)

We can choose between two optimizers for the network that classifies the MNIST digits, Adam or SGD with momentum:

In [8]:
net.optimizer = torch.optim.Adam(network.parameters(), lr=0.001)

In [None]:
net.optimizer = torch.optim.SGD(
    network.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-1
)

In [9]:
from deepproblog.optimizer import Optimizer, SGD


def test(model):
    acc = model.accuracy(test_queries, test=True, verbose=False)
    return [("accuracy", acc)]


# Model is a DeepProbLog class that combines reasoning via the ProbLog code
# and neural processing via a list of Network objects
model = Model(problog_string, [net], caching=False)

We can likewise choose between two optimizers for the model that handles the reasoning via the ProbLog code:

In [None]:
optimizer = Optimizer(model, 2)

In [10]:
optimizer = SGD(model, accumulation=2, param_lr=0.001)

We define a name for the log file:

In [None]:
log_name = f"test_ADAM_SGD_VERBOSE_{SEED}_"

We train our model and save the log file to the current folder:

In [11]:
logger = train_model(
    model, train_queries, 2, optimizer, test_iter=500, test=test, snapshot_iter=5000
)
logger.write_to_file(log_name)

Training for 2 epochs (19494 iterations).
Wrong octal_div([test(620), test(1610)],[test(1638), test(36)],0) vs octal_div([test(620), test(1610)],[test(1638), test(36)],1)
Wrong octal_div([test(1595), test(2190)],[test(1674), test(660)],5) vs octal_div([test(1595), test(2190)],[test(1674), test(660)],1)
Wrong octal_div([test(2714), test(3911)],[test(9316), test(5854)],0) vs octal_div([test(2714), test(3911)],[test(9316), test(5854)],3)
Wrong octal_div([test(8182), test(9922)],[test(2687), test(7447)],136) vs octal_div([test(8182), test(9922)],[test(2687), test(7447)],0)
Wrong octal_div([test(6153), test(4050)],[test(4862), test(6085)],33) vs octal_div([test(6153), test(4050)],[test(4862), test(6085)],0)
Wrong octal_div([test(4429), test(6281)],[test(5075), test(9613)],1) vs octal_div([test(4429), test(6281)],[test(5075), test(9613)],0)
Wrong octal_div([test(8973), test(8175)],[test(7655), test(9283)],11) vs octal_div([test(8973), test(8175)],[test(7655), test(9283)],1)
Wrong octal_div([



Wrong octal_div([test(3242), test(3266)],[test(7751), test(6991)],0) vs octal_div([test(3242), test(3266)],[test(7751), test(6991)],1)
Wrong octal_div([test(4234), test(9779)],[test(5538), test(8625)],1) vs octal_div([test(4234), test(9779)],[test(5538), test(8625)],0)
Wrong octal_div([test(6754), test(9982)],[test(2761), test(2768)],5) vs octal_div([test(6754), test(9982)],[test(2761), test(2768)],0)
Wrong octal_div([test(8338), test(4720)],[test(3514), test(9181)],0) vs octal_div([test(8338), test(4720)],[test(3514), test(9181)],3)
Wrong octal_div([test(5269), test(3395)],[test(9310), test(6144)],30) vs octal_div([test(5269), test(3395)],[test(9310), test(6144)],1)
Wrong octal_div([test(1650), test(1811)],[test(701), test(3772)],1) vs octal_div([test(1650), test(1811)],[test(701), test(3772)],0)
Wrong octal_div([test(9172), test(7791)],[test(9911), test(240)],16) vs octal_div([test(9172), test(7791)],[test(9911), test(240)],0)
Wrong octal_div([test(7632), test(9936)],[test(2942), tes