In [1]:
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from ltn_imp.automation.knowledge_base import KnowledgeBase
from torch.utils.data import Dataset, DataLoader
from ltn_imp.automation.data_loaders import LoaderWrapper

## Data Prep

In [2]:
def create_dataset(n_samples, vector_size):
    assert n_samples % 2 == 0, "n_samples must be even"
    
    # Generate random vectors
    instances = np.random.rand(n_samples, vector_size)
    
    # Initialize lists to store pairs and labels
    pairs = []
    labels = []
    
    # Create same pairs
    for i in range(n_samples // 2):
        idx = np.random.choice(n_samples)
        pair = (instances[idx], instances[idx])
        pairs.append(pair)
        labels.append(1)
    
    # Create different pairs
    for _ in range(n_samples // 2):
        idx1, idx2 = np.random.choice(n_samples, 2, replace=False)
        pair = (instances[idx1], instances[idx2])
        pairs.append(pair)
        labels.append(0)
    
    # Convert to numpy arrays
    pairs = np.array(pairs)
    labels = np.array(labels)
    
    return pairs, labels

In [3]:
class SameDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.pairs[idx][0]
        y = self.pairs[idx][1]
        label = self.labels[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

In [4]:
# Parameters
n_samples = 1000  # Must be even
vector_size = 10

# Generate the dataset
pairs, labels = create_dataset(n_samples, vector_size)

# Create the dataset
dataset = SameDataset(pairs, labels)

# Create the DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

## Model and Methods

In [5]:
import torch.nn as nn

class SameModel(nn.Module):
    def __init__(self, input_size):
        super(SameModel, self).__init__()
        self.fc1 = nn.Linear(input_size * 2, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 2)
    
    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class LogitsToPredicate(torch.nn.Module):
    def __init__(self, logits_model):
        super(LogitsToPredicate, self).__init__()
        self.logits_model = logits_model
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x, z, y):
        logits = self.logits_model(x,z)
        probs = self.softmax(logits)
        # y is expected to be a one-hot encoded vector
        out = torch.sum(probs * y, dim=1, keepdim=True)
        return out

In [6]:
def compute_accuracy(loader, model):
    mean_accuracy = 0.0
    for x,y, labels in loader:
        predictions = model.logits_model(x,y).detach().numpy()
        predictions = np.argmax(predictions, axis=1)
        mean_accuracy += accuracy_score(labels, predictions)
    return mean_accuracy / len(loader)

## Optimize 

In [7]:
a = SameModel(input_size=vector_size) # type: ignore
model = LogitsToPredicate(a) # type: ignore
predicates = {"Classifier": model}
expression_1 = "all a. all b. (Classifier(a,b,c) )"
rules = [expression_1]

In [8]:
loader = LoaderWrapper(variables=["a", "b"], num_classes=2, target="c", loader=dataloader)

In [9]:
rule_to_data_loader_mapping={expression_1: [loader]} # type: ignore

In [10]:
kb = KnowledgeBase(rules=rules, predicates=predicates,
                   rule_to_data_loader_mapping=rule_to_data_loader_mapping,
                   quantifier_impls={"forall" : "pmean_error"})

In [11]:
compute_accuracy(model = model, loader= dataloader)

0.5087890625

In [28]:
kb.optimize(num_epochs=15, log_steps=5)

Epoch 1/15, Loss: 0.0034572482109069824

Epoch 6/15, Loss: 0.00011670589447021484

Epoch 11/15, Loss: 8.428096771240234e-05



In [29]:
compute_accuracy(model = model, loader= dataloader)

0.904296875

In [None]:
temp =