In [None]:
import torch
import pandas as pd
import numpy as np

from ltn_imp.automation.knowledge_base import KnowledgeBase
from ltn_imp.automation.data_loaders import LoaderWrapper

## Data Preparation

In [None]:
# Fetch the classification datasets via the project's `poe` task (the task itself is
# defined in the project's poetry/poe configuration, not in this notebook).
# NOTE(review): IPython's `!` does not raise on a non-zero exit status, so a failed
# download will not stop "Run All" — check this cell's output before continuing.
!poetry run poe download-class-datasets

In [None]:
def compute_accuracy(loader, model, num_classes):
    """Print per-class and overall classification accuracy of ``model`` over ``loader``.

    The predicted class is the argmax over axis 1 of the model output. It is
    compared against the class labels yielded by the loader.

    Args:
        loader: Iterable of ``(data, labels)`` batches. ``labels`` is a torch
            tensor of class indices (assumed shape ``(batch,)`` — TODO confirm).
        model: Callable mapping ``data`` to a tensor of per-class scores, with
            classes along axis 1.
        num_classes: Number of classes. Labels outside ``range(num_classes)``
            are not counted.

    Returns:
        None. Results are printed. A class with no samples in ``loader`` (or an
        empty loader) is reported as ``n/a`` instead of a NaN produced by a 0/0
        division.

    NOTE(review): the model is not switched to eval mode here. If it contains
    dropout or batch-norm layers, results may vary between calls — confirm.
    """
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)

    # Evaluation only: no need to record an autograd graph for every batch.
    with torch.no_grad():
        for data, labels in loader:
            # .cpu() so this also works when the model output / labels live on a GPU.
            predictions = np.argmax(model(data).detach().cpu().numpy(), axis=1)
            labels = labels.cpu().numpy()

            for i in range(num_classes):
                class_mask = labels == i
                # Within class_mask every label equals i, so this counts hits for class i.
                class_correct[i] += np.sum(predictions[class_mask] == i)
                class_total[i] += np.sum(class_mask)

    # Per-class accuracy; guard against classes that never appear in the loader.
    for i in range(num_classes):
        if class_total[i] > 0:
            print(f'Accuracy for class {i}: {class_correct[i] / class_total[i]:.2f}')
        else:
            print(f'Accuracy for class {i}: n/a (no samples)')

    # Overall accuracy across all counted samples.
    total = np.sum(class_total)
    if total > 0:
        print(f'Overall accuracy: {np.sum(class_correct) / total:.2f}')
    else:
        print('Overall accuracy: n/a (no samples)')

In [None]:
# Seed the global torch and numpy RNGs *before* building the knowledge base.
# The "Iris" predicate is usable right after construction, so its weights are
# presumably initialised here. Any loader shuffling and the training in
# kb.optimize also draw from these RNGs. Seeding makes the before/after
# accuracies reproducible under Restart & Run All.
# NOTE(review): confirm KnowledgeBase doesn't use other RNGs (e.g. Python's `random`).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
kb = KnowledgeBase("multi_config.yaml")

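## My Implementation

Measure the per-class accuracy of the `Iris` predicate on the first loader, train the knowledge base, then measure it again to see the effect of training.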

In [None]:
# Baseline before training: per-class accuracy of the "Iris" predicate on the first loader.
compute_accuracy(
    loader=kb.loaders[0],
    model=kb.predicates["Iris"],
    num_classes=3,
)

In [None]:
# Train the knowledge base.
# NOTE(review): 801 rather than 800 presumably so the last epoch coincides with a
# logging step (every LOG_STEPS epochs) — confirm against kb.optimize.
NUM_EPOCHS = 801
LOG_STEPS = 200
kb.optimize(num_epochs=NUM_EPOCHS, log_steps=LOG_STEPS)

In [None]:
# After training: same evaluation as the baseline, for a direct comparison.
compute_accuracy(
    loader=kb.loaders[0],
    model=kb.predicates["Iris"],
    num_classes=3,
)

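## Computation Graph

Render the computation graph of the first rule, using the variable mapping built from a single batch of the first loader.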

In [None]:
# Take one batch from the first loader. partition_data presumably fills
# var_mapping in place: it starts empty and is then handed to comp_graph.
first_loader = kb.loaders[0]
batch = next(iter(first_loader))
var_mapping = {}
kb.partition_data(var_mapping, batch, first_loader)

# Last expression, so the notebook displays the first rule's computation graph.
kb.rules[0].comp_graph(var_mapping)