In [1]:
import torch
import ltn
import numpy as np
from ltn_imp.automation.knowledge_base import KnowledgeBase
from sklearn.metrics import accuracy_score
import numpy as np
from ltn_imp.fuzzy_operators.aggregators import SatAgg

## Data and Model Preparation

In [2]:
class ModelA(torch.nn.Module):
    """Small fully connected network mapping 2-feature inputs to two class logits.

    Architecture: Linear(2, 16) -> ELU -> Linear(16, 16) -> ELU -> Linear(16, 2).
    The output is raw logits; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        self.elu = torch.nn.ELU()
        self.layer1 = torch.nn.Linear(2, 16)
        self.layer2 = torch.nn.Linear(16, 16)
        self.layer3 = torch.nn.Linear(16, 2)

    def forward(self, x):
        """Return the two logits for ``x`` (last dimension of ``x`` must be 2)."""
        hidden = x
        # Both hidden layers share the same ELU activation.
        for hidden_layer in (self.layer1, self.layer2):
            hidden = self.elu(hidden_layer(hidden))
        return self.layer3(hidden)

class LogitsToPredicate(torch.nn.Module):
    """Turn a logits model into a predicate giving the probability of the class selected by ``y``.

    The wrapped model's logits are passed through a softmax over dimension 1,
    and the result is masked with ``y`` and summed, so the output has shape
    ``(batch, 1)``.
    """

    def __init__(self, logits_model):
        super().__init__()
        self.logits_model = logits_model
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x, y, training=True):
        """Return ``sum(softmax(logits_model(x)) * y, dim=1, keepdim=True)``.

        ``y`` is expected to be a one-hot encoded vector per sample.
        ``training`` is accepted but not used by this implementation.
        """
        class_probs = self.softmax(self.logits_model(x))
        # With a one-hot y the masked sum picks out the probability of the labelled class.
        return (class_probs * y).sum(dim=1, keepdim=True)

In [3]:
# Seed both RNGs used by this notebook so that the sampled points, every
# subsequent torch weight initialisation and the np.random-based batch
# shuffling are identical on each "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

nr_samples = 100
# Points drawn uniformly from [0, 1) x [0, 1).
dataset = torch.rand((nr_samples, 2))
# A point is labelled True when its squared distance to (.5, .5) is below .09,
# i.e. it lies inside the circle of radius 0.3 centred in the unit square.
labels_dataset = torch.sum(torch.square(dataset - torch.tensor([.5, .5])), dim=1) < .09

In [4]:
class DataLoader(object):
    """Minimal mini-batch iterator over index-aligned ``data`` and ``labels``.

    ``data`` must expose ``shape[0]`` and both containers must accept a list of
    integer indices (as torch tensors and numpy arrays do).

    Args:
        data: container of samples, indexed along its first dimension.
        labels: container of labels aligned with ``data``.
        batch_size: number of samples per batch; must be >= 1.
        shuffle: when True, the sample order is re-drawn with
            ``np.random.shuffle`` at the start of every iteration.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self,
                 data,
                 labels,
                 batch_size=1,
                 shuffle=True):
        # Fail early: a zero batch size only surfaces later as a division or
        # range() error, and a negative one silently yields no batches at all.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        """Return the number of batches per pass, counting a final partial batch."""
        return int(np.ceil(self.data.shape[0] / self.batch_size))

    def __iter__(self):
        """Yield ``(data_batch, labels_batch)`` tuples covering every sample once."""
        n = self.data.shape[0]
        idxlist = list(range(n))
        if self.shuffle:
            np.random.shuffle(idxlist)

        for start_idx in range(0, n, self.batch_size):
            # List slicing clamps at the end, so the last batch may be shorter.
            batch_indices = idxlist[start_idx:start_idx + self.batch_size]
            yield self.data[batch_indices], self.labels[batch_indices]
            
# The first 50 samples form the training split, the remaining ones the test split.
# batch_size=64 is larger than either split, so each loader yields one batch.
train_loader = DataLoader(
    data=dataset[:50],
    labels=labels_dataset[:50],
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    data=dataset[50:],
    labels=labels_dataset[50:],
    batch_size=64,
    shuffle=False,
)

In [5]:
def compute_accuracy(loader, model):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The predicted class of a sample is the argmax over axis 1 of ``model(data)``.
    Accuracy is computed over all samples at once (correct / total) rather than
    as a mean of per-batch accuracies, so a smaller final batch is not
    over-weighted.

    Args:
        loader: iterable yielding ``(data, labels)`` batches.
        model: callable mapping a batch of data to a 2-D tensor of class scores.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the loader yields no samples.
    """
    n_correct = 0
    n_total = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data, labels in loader:
            predictions = np.argmax(model(data).numpy(), axis=1)
            if isinstance(labels, torch.Tensor):
                labels = labels.detach().cpu().numpy()
            # Cast labels (e.g. booleans) to the integer dtype of the predicted class ids.
            targets = np.asarray(labels).reshape(-1).astype(predictions.dtype)
            n_correct += int(np.sum(predictions == targets))
            n_total += targets.shape[0]
    if n_total == 0:
        return 0.0
    return n_correct / n_total

## My Implementation (`ltn_imp` `KnowledgeBase`)

In [6]:
# `a` is the raw logits network; `model` wraps it so it can be called as Classifier(x, y).
a = ModelA() # type: ignore
model = LogitsToPredicate(logits_model=a) # type: ignore
predicates = dict(Classifier=model)

# The knowledge base holds a single universally quantified rule.
expression_1 = "all x. (Classifier(x,y))"
rules = [expression_1]

In [7]:
# Each rule is paired with the list of loaders that supply its data.
rule_to_data_loader_mapping = {
    expression_1: [train_loader],
}

In [8]:
# Associates the training loader with the variable name "x" used in the rule.
loader_to_variable = {
    train_loader: "x",
}

In [9]:
# Assemble the knowledge base from the rule, predicate and loader mappings defined above.
kb = KnowledgeBase(
    rules=rules,
    predicates=predicates,
    rule_to_data_loader_mapping=rule_to_data_loader_mapping,
    loader_to_variable=loader_to_variable,  # type: ignore
    target_variable="y",
    num_classes=2,
    quantifier_impls={"forall": "pmean_error"},
)

Shape: torch.Size([16, 2])
Requires Grad: True
--------------------------------------------------
Shape: torch.Size([16])
Requires Grad: True
--------------------------------------------------
Shape: torch.Size([16, 16])
Requires Grad: True
--------------------------------------------------
Shape: torch.Size([16])
Requires Grad: True
--------------------------------------------------
Shape: torch.Size([2, 16])
Requires Grad: True
--------------------------------------------------
Shape: torch.Size([2])
Requires Grad: True
--------------------------------------------------


In [10]:
# Test-split accuracy of the logits network before kb.optimize is run.
compute_accuracy(test_loader, a)

0.26

In [11]:
# Training-split accuracy of the logits network before kb.optimize is run.
compute_accuracy(train_loader, a)

0.28

In [12]:
# 801 epochs with a log every 200: the output below reports epochs 1, 201, 401, ...
num_epochs, log_every = 801, 200
kb.optimize(num_epochs, log_steps=log_every)

Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.42747730016708374 for class 0
Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.5713341236114502 for class 1
Epoch 1/801, Loss: 0.505735456943512

Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.4767915606498718 for class 0
Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.6507197618484497 for class 1
Epoch 201/801, Loss: 0.4448279142379761

Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.5964967012405396 for class 0
Rule <function ExpressionVisitor.visit_QuantifiedExpression.<locals>.<lambda> at 0x135d6fba0> produced output: 0.6684198975563049 for class 1
Epoch 401/801, Loss: 0.3692967891693115

Rule <function ExpressionVi

In [13]:
# Training-split accuracy after kb.optimize has run.
compute_accuracy(train_loader, a)

1.0

In [14]:
# Test-split accuracy after kb.optimize has run.
compute_accuracy(test_loader, a)

0.88

## LTN (baseline using the `ltn` library)

In [15]:
class ModelA(torch.nn.Module):
    """Binary classifier mapping 2-feature inputs to a single value in (0, 1).

    Architecture: Linear(2, 16) -> ELU -> Linear(16, 16) -> ELU -> Linear(16, 1) -> Sigmoid.

    NOTE(review): this re-uses the name ``ModelA`` of the two-logit network
    defined earlier in the notebook and shadows it from this cell onwards.
    """

    def __init__(self):
        super().__init__()
        self.sigmoid = torch.nn.Sigmoid()
        self.layer1 = torch.nn.Linear(2, 16)
        self.layer2 = torch.nn.Linear(16, 16)
        self.layer3 = torch.nn.Linear(16, 1)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        """Return the sigmoid-squashed score for ``x`` (last dimension of ``x`` must be 2)."""
        hidden = self.elu(self.layer1(x))
        hidden = self.elu(self.layer2(hidden))
        score = self.layer3(hidden)
        return self.sigmoid(score)

In [16]:
# Predicate backed by a freshly initialised sigmoid-output ModelA.
A = ltn.Predicate(ModelA()) # type: ignore

# Fuzzy operators used by the axioms below.
fuzzy_ops = ltn.fuzzy_ops
Not = ltn.Connective(fuzzy_ops.NotStandard())
Forall = ltn.Quantifier(fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
# NOTE(review): this rebinds the name `SatAgg` imported from ltn_imp at the top of the notebook.
SatAgg = fuzzy_ops.SatAgg()

In [17]:
def compute_sat_level(loader):
    """Return the mean, over the batches of ``loader``, of the aggregated satisfaction.

    For every batch the samples are split by label, and ``SatAgg`` combines
    two axioms: ``A`` holds for all positive examples and ``Not(A)`` holds for
    all negative examples.
    """
    total_sat = 0
    for data, labels in loader:
        positive_rows = torch.nonzero(labels)
        negative_rows = torch.nonzero(torch.logical_not(labels))
        positives = ltn.Variable("x_A", data[positive_rows])
        negatives = ltn.Variable("x_not_A", data[negative_rows])

        batch_sat = SatAgg(
            Forall(positives, A(positives)),
            Forall(negatives, Not(A(negatives)))
        )
        total_sat = total_sat + batch_sat

    return total_sat / len(loader)

def compute_accuracy(loader):
    """Return the fraction of samples in ``loader`` that predicate ``A`` classifies correctly.

    A sample is predicted positive when ``A.model(data)`` exceeds 0.5.
    Accuracy is computed over all samples at once (correct / total) rather than
    as a mean of per-batch accuracies, so a smaller final batch is not
    over-weighted.

    Args:
        loader: iterable yielding ``(data, labels)`` batches.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the loader yields no samples.
    """
    n_correct = 0
    n_total = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data, labels in loader:
            scores = A.model(data).numpy()
            predictions = np.where(scores > 0.5, 1., 0.).flatten()
            if isinstance(labels, torch.Tensor):
                labels = labels.detach().cpu().numpy()
            # Compare as floats so boolean or integer labels match the 1./0. predictions.
            targets = np.asarray(labels).reshape(-1).astype(predictions.dtype)
            n_correct += int(np.sum(predictions == targets))
            n_total += targets.shape[0]

    if n_total == 0:
        return 0.0
    return n_correct / n_total

In [18]:
optimizer = torch.optim.Adam(A.parameters(), lr=0.001)

# Maximise the aggregated satisfaction of the two axioms by minimising 1 - sat.
for epoch in range(801):
    train_loss = 0.0
    for batch_idx, (data, labels) in enumerate(train_loader):
        x_A = ltn.Variable("x_A", data[torch.nonzero(labels)]) # positive examples
        x_not_A = ltn.Variable("x_not_A", data[torch.nonzero(torch.logical_not(labels))]) # negative examples

        optimizer.zero_grad()
        sat_agg = SatAgg(
            Forall(x_A, A(x_A)),
            Forall(x_not_A, Not(A(x_not_A)))
        )
        loss = 1. - sat_agg
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    # Only report every 200th epoch.
    if epoch % 200 != 0:
        continue

    # Evaluated in this order: sat levels first, then accuracies, train before test.
    metrics = (
        epoch,
        train_loss,
        compute_sat_level(train_loader),
        compute_sat_level(test_loader),
        compute_accuracy(train_loader),
        compute_accuracy(test_loader),
    )
    print(" epoch %d | loss %.4f | Train Sat %.3f | Test Sat %.3f | Train Acc %.3f | Test Acc %.3f" % metrics)

    # Per-axiom satisfaction on the last training batch, using the updated weights.
    print()
    print(f"Positive { Forall(x_A, A(x_A)) }")
    print(f"Negative { Forall(x_not_A, Not(A(x_not_A)))}")
    print()

 epoch 0 | loss 0.5054 | Train Sat 0.495 | Test Sat 0.496 | Train Acc 0.720 | Test Acc 0.740

Positive LTNObject(value=tensor(0.4311, grad_fn=<RsubBackward1>), free_vars=[])
Negative LTNObject(value=tensor(0.5687, grad_fn=<RsubBackward1>), free_vars=[])

 epoch 200 | loss 0.4411 | Train Sat 0.559 | Test Sat 0.477 | Train Acc 0.720 | Test Acc 0.540

Positive LTNObject(value=tensor(0.6448, grad_fn=<RsubBackward1>), free_vars=[])
Negative LTNObject(value=tensor(0.4875, grad_fn=<RsubBackward1>), free_vars=[])

 epoch 400 | loss 0.3516 | Train Sat 0.649 | Test Sat 0.523 | Train Acc 0.880 | Test Acc 0.600

Positive LTNObject(value=tensor(0.7034, grad_fn=<RsubBackward1>), free_vars=[])
Negative LTNObject(value=tensor(0.6023, grad_fn=<RsubBackward1>), free_vars=[])

 epoch 600 | loss 0.3002 | Train Sat 0.700 | Test Sat 0.525 | Train Acc 0.900 | Test Acc 0.640

Positive LTNObject(value=tensor(0.7955, grad_fn=<RsubBackward1>), free_vars=[])
Negative LTNObject(value=tensor(0.6282, grad_fn=<RsubBa