In [1]:
from ltn_imp.automation.knowledge_base import KnowledgeBase
import torch
import pandas as pd

In [2]:
!poetry run poe download-medical-datasets

[37mPoe =>[0m [94mmkdir -p examples/medical/datasets[0m
[37mPoe =>[0m [94mcurl -L -o examples/medical/datasets/pima_indians_imputed.csv https://raw.githubusercontent.com/ChristelSirocchi/hybrid-ML/main/pima_indians_imputed.csv[0m
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 33428  100 33428    0     0   331k      0 --:--:-- --:--:-- --:--:--  333k


In [3]:
import random
import numpy as np
import torch

# Set the random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True

In [4]:
from sklearn.model_selection import train_test_split
test_data = pd.read_csv('datasets/pima_indians_imputed.csv').astype(float)
x_train, x_test = train_test_split(test_data, test_size=0.5, random_state=42)
x_train.to_csv('datasets/train.csv')
x_test.to_csv('datasets/test.csv')

In [5]:
import torch

def predict(model, x):
    model.eval()  # Ensure the model is in evaluation mode
    with torch.no_grad():  # No need to track gradients
        # Ensure x is a tensor and has the right dtype
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)
        elif x.dtype != torch.float32:
            x = x.float()
        
        # Forward pass through the model
        probs = model(x)
        
        # Apply binary classification threshold at 0.5
        preds = (probs > 0.5).float()
    return preds

def compute_metrics(model, data_loader):
    correct = 0
    total = 0
    
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    true_negatives = 0
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation
        for data, labels in data_loader:
            # Ensure data and labels are the correct dtype
            if not isinstance(data, torch.Tensor):
                data = torch.tensor(data, dtype=torch.float32)
            elif data.dtype != torch.float32:
                data = data.float()
            
            if not isinstance(labels, torch.Tensor):
                labels = torch.tensor(labels, dtype=torch.float32)
            elif labels.dtype != torch.float32:
                labels = labels.float()
            
            # Get predictions
            preds = predict(model, data)
            
            # Squeeze predictions and labels to remove dimensions of size 1
            predicted_labels = preds.squeeze()
            true_labels = labels.squeeze()

            # Ensure the shapes match before comparison
            if predicted_labels.shape != true_labels.shape:
                true_labels = true_labels.view_as(predicted_labels)
            
            # Count correct predictions
            correct += (predicted_labels == true_labels).sum().item()
            total += true_labels.size(0)
            
            # Calculate TP, FP, FN, TN
            true_positives += ((predicted_labels == 1) & (true_labels == 1)).sum().item()
            false_positives += ((predicted_labels == 1) & (true_labels == 0)).sum().item()
            false_negatives += ((predicted_labels == 0) & (true_labels == 1)).sum().item()
            true_negatives += ((predicted_labels == 0) & (true_labels == 0)).sum().item()
    
    accuracy = correct / total if total > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    specificity = true_negatives / (true_negatives + false_positives) if (true_negatives + false_positives) > 0 else 0

    print(f"True Positives: {true_positives}, False Positives: {false_positives}, False Negatives: {false_negatives}, True Negatives: {true_negatives}")
    print()
    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Specificity: {specificity:.4f}")

In [6]:
kb = KnowledgeBase("medical_config.yaml")

In [7]:
x, y = next(iter(kb.loaders[0]))
test_data = pd.DataFrame(x)
high = test_data[(test_data[5] > 29) & (test_data[1] > 125)]
low = test_data[(test_data[5] <= 25) & (test_data[1] <= 100)]

In [8]:
kb.predicates["Diabetic"](torch.tensor(high.values, dtype=torch.float32)).mean()

tensor(0.0191, grad_fn=<MeanBackward0>)

In [9]:
kb.predicates["Diabetic"](torch.tensor(low.values, dtype=torch.float32)).mean()

tensor(0.0353, grad_fn=<MeanBackward0>)

In [10]:
compute_metrics(kb.predicates["Diabetic"], kb.loaders[0])

True Positives: 0, False Positives: 0, False Negatives: 138, True Negatives: 246

Accuracy: 0.6406, Precision: 0.0000, Recall: 0.0000, Specificity: 1.0000


In [11]:
compute_metrics(kb.predicates["Diabetic"], kb.loaders[1])

True Positives: 0, False Positives: 0, False Negatives: 130, True Negatives: 254

Accuracy: 0.6615, Precision: 0.0000, Recall: 0.0000, Specificity: 1.0000


In [12]:
kb.optimize(num_epochs=1001, log_steps=500, lr=0.001)

['∀ person.(((y == diabetes) -> Diabetic(person)))', '∀ person.(((y == healthy) -> ~(Diabetic(person))))', '∀ person.((((person[BMI] < 26) & (person[Glucose] < 101)) -> ~(Diabetic(person))))', '∀ person.((((person[BMI] > 29) & (person[Glucose] > 125)) -> Diabetic(person)))', '∀ person.((((person[Glucose] > 143.5) & (person[DiabetesPedigreeFunction] > 0.32)) -> Diabetic(person)))', '∀ person.((((person[Glucose] > 143.5) & ((person[DiabetesPedigreeFunction] <= 0.32) & (person[BMI] <= 31.40))) -> ~(Diabetic(person))))', '∀ person.((((person[Glucose] > 143.5) & ((person[DiabetesPedigreeFunction] <= 0.32) & (person[BMI] > 31.40))) -> Diabetic(person)))', '∀ person.((((person[Glucose] <= 143.5) & ((person[Pregnancies] <= 7.5) & (person[BMI] <= 45.44))) -> ~(Diabetic(person))))', '∀ person.((((person[Glucose] <= 143.5) & ((person[Pregnancies] <= 7.5) & (person[BMI] > 45.44))) -> Diabetic(person)))', '∀ person.((((person[Glucose] <= 143.5) & ((person[Pregnancies] > 7.5) & (person[DiabetesPedig

In [13]:
compute_metrics(kb.predicates["Diabetic"], kb.loaders[0])

True Positives: 42, False Positives: 21, False Negatives: 96, True Negatives: 225

Accuracy: 0.6953, Precision: 0.6667, Recall: 0.3043, Specificity: 0.9146


In [14]:
compute_metrics(kb.predicates["Diabetic"], kb.loaders[1])

True Positives: 43, False Positives: 24, False Negatives: 87, True Negatives: 230

Accuracy: 0.7109, Precision: 0.6418, Recall: 0.3308, Specificity: 0.9055


In [15]:
kb.predicates["Diabetic"](torch.tensor(high.values, dtype=torch.float32)).mean()

tensor(0.4788, grad_fn=<MeanBackward0>)

In [16]:
kb.predicates["Diabetic"](torch.tensor(low.values, dtype=torch.float32)).mean().round()

tensor(0., grad_fn=<RoundBackward0>)