In [1]:
datasets = ['CUB', 'Derm7pt', 'RIVAL10']
use_dataset = datasets[1]

In [2]:
import os
import sys

notebook_dir = os.getcwd()
project_root_path = os.path.dirname(notebook_dir)
sys.path.insert(0, project_root_path)

from src.config import CUB_CONFIG, DERM7PT_CONFIG, RIVAL10_CONFIG  # noqa: E402
from src.config import PROJECT_ROOT  # noqa: E402
import numpy as np  # noqa: E402

In [3]:
if use_dataset == 'CUB':
    config_dict = CUB_CONFIG
    DATASET_PATH =  os.path.join(PROJECT_ROOT, 'output', 'CUB')
elif use_dataset == 'Derm7pt':
    config_dict = DERM7PT_CONFIG
    DATASET_PATH =  os.path.join(PROJECT_ROOT, 'output', 'Derm7pt')
else:
    config_dict = RIVAL10_CONFIG
    DATASET_PATH =  os.path.join(PROJECT_ROOT, 'output', 'RIVAL10')

# Load and Transform Data

In [4]:
# INSTANCE-BASED CUB MODEL

# C_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_train_instance.npy'))
# C_hat_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_hat_sigmoid_train_instance.npy'))
# one_hot_Y_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'Y_train_instance.npy'))

# C_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_test_instance.npy'))
# C_hat_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_hat_sigmoid_test_instance.npy'))
# one_hot_Y_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'Y_test_instance.npy'))

In [5]:
C_hat_train = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_train.npy'))
one_hot_Y_train = np.load(os.path.join(DATASET_PATH, 'Y_train.npy'))

C_hat_test = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_test.npy'))
one_hot_Y_test = np.load(os.path.join(DATASET_PATH, 'Y_test.npy'))

if use_dataset == 'Derm7pt':
    C_hat_val = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_val.npy'))
    one_hot_Y_val = np.load(os.path.join(DATASET_PATH, 'Y_val.npy'))

    C_hat_train = np.concatenate((C_hat_train, C_hat_val), axis=0)
    one_hot_Y_train = np.concatenate((one_hot_Y_train, one_hot_Y_val), axis=0)

class_level_concepts = np.load(os.path.join(DATASET_PATH, 'class_level_concepts.npy'))

In [6]:
Y_train = np.argmax(one_hot_Y_train, axis=1)
Y_test = np.argmax(one_hot_Y_test, axis=1)

In [7]:
C_train = []
for y in Y_train:
    C_train.append(class_level_concepts[y])

C_train = np.array(C_train)

In [8]:
from sklearn.utils import shuffle

C_hat_train, C_train, one_hot_Y_train, Y_train = shuffle(C_hat_train, C_train, one_hot_Y_train, Y_train, random_state=42)

In [9]:
# unique, counts = np.unique(Y_train, return_counts=True)
# for label, count in zip(unique, counts):
#     print(f"Label {label}: {count} instances")

In [10]:
# unique, counts = np.unique(Y_test, return_counts=True)
# for label, count in zip(unique, counts):
#     print(f"Label {label}: {count} instances")

# Classic Models

## Logistic Regression

In [11]:
from sklearn.linear_model import LogisticRegression

lin_model = LogisticRegression(max_iter=1000)
lin_model.fit(C_hat_train, Y_train)
print(f"Logistic Regression Test accuracy: {lin_model.score(C_hat_test, Y_test)}")

Logistic Regression Test accuracy: 0.6632911392405063


In [12]:
np.unique(lin_model.predict(C_hat_test))

array([0, 1, 2, 3, 4])

Save incorrectly classified instances for intervention experiments.

In [13]:
Y_pred = lin_model.predict(C_hat_test)
wrong_indices = np.where(Y_test != Y_pred)[0]

C_hat_wrong = C_hat_test[wrong_indices]
Y_wrong = Y_test[wrong_indices]

print(C_hat_wrong.shape)
print(Y_wrong.shape)

output_dir = os.path.join(PROJECT_ROOT, 'output', 'intervention', use_dataset)
os.makedirs(output_dir, exist_ok=True)
np.save(os.path.join(output_dir, 'C_hat_linear.npy'), C_hat_wrong)
np.save(os.path.join(output_dir, 'Y_linear.npy'), Y_wrong)

(133, 19)
(133,)


In [14]:
import joblib
model_dir = os.path.join(PROJECT_ROOT, 'models', use_dataset)
os.makedirs(model_dir, exist_ok=True)

joblib.dump(lin_model, os.path.join(model_dir, 'lin_model.joblib'))

['/Users/pb/Documents/career/lab_ujm/hybrid-cbm-prototype-model/models/Derm7pt/lin_model.joblib']

In [None]:
# np.mean(np.abs(lin_model.coef_), axis=1)

array([0.41021713, 0.40787597, 0.40497543, 0.3868018 , 0.42114473])

## k-NN

In [15]:
from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier()
model.fit(C_hat_train, Y_train)
print(f"k-NN Test accuracy: {model.score(C_hat_test, Y_test)}")

k-NN Test accuracy: 0.6


In [16]:
np.unique(model.predict(C_hat_test))

array([0, 1, 2, 3, 4])

## Decision Tree

In [17]:
from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier()
model.fit(C_hat_train, Y_train)
print(f"Decision Tree Test accuracy: {model.score(C_hat_test, Y_test)}")

Decision Tree Test accuracy: 0.5670886075949367


In [18]:
np.unique(model.predict(C_hat_test))

array([0, 1, 2, 3, 4])

## MLP

In [19]:
from sklearn.neural_network import MLPClassifier

mlp = MLPClassifier(hidden_layer_sizes=(512,256, 128), max_iter=1000)
mlp.fit(C_hat_train, Y_train)
print(f"MLP Test accuracy: {mlp.score(C_hat_test, Y_test)}")

MLP Test accuracy: 0.660759493670886


In [20]:
np.unique(model.predict(C_hat_test))

array([0, 1, 2, 3, 4])

# Accuracy Using Class-Level Concepts

In [21]:
# Calculate differences between each test instance and all class-level concepts
distances = []

for i, test_instance in enumerate(C_hat_test):
    # Calculate absolute differences to each class concept
    # This gives element-wise differences between probabilities and binary values
    instance_diffs = np.abs(test_instance - class_level_concepts)

    # Sum the differences along concept dimension to get total deviation for each class
    instance_distances = np.sum(instance_diffs, axis=1)

    # Find the minimum distance
    min_distance = np.min(instance_distances)
    # Find which class has the minimum distance
    min_class = np.argmin(instance_distances)
    distances.append((min_distance, min_class))

# Convert to numpy arrays for easier analysis
min_distances = np.array([d[0] for d in distances])
predicted_classes = np.array([d[1] for d in distances])

# Calculate accuracy
accuracy = np.mean(predicted_classes == Y_test)
print(f"Accuracy using absolute difference to class concepts: {accuracy:.4f}")

Accuracy using absolute difference to class concepts: 0.6481


# Prototype-Based Model


In [22]:
import torch
from torch.utils.data import TensorDataset, DataLoader

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


## Create Dataloaders

In [24]:
val_split_ratio = 0.2
random_seed = 42

if use_dataset == 'Derm7pt':
    X_train = torch.tensor(C_hat_train, dtype=torch.float32)
    Y_train = torch.tensor(one_hot_Y_train, dtype=torch.float32)
else:
    # C_hat_train, C_hat_val, Y_train_np, Y_val_np = train_test_split(C_hat_train, one_hot_Y_train, test_size=val_split_ratio, random_state=random_seed)
    X_train = torch.tensor(C_hat_train, dtype=torch.float32)
    Y_train = torch.tensor(one_hot_Y_train, dtype=torch.float32)

X_test = torch.tensor(C_hat_test, dtype=torch.float32, device=device)
Y_test = torch.tensor(one_hot_Y_test, dtype=torch.float32, device=device)

# DATALOADERS
batch_size = 64
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(X_test, Y_test)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Learn Prototypes

In [25]:
from src.models import PrototypeClassifier

num_concepts = config_dict['N_TRIMMED_CONCEPTS']
num_classes = config_dict['N_CLASSES']

model = PrototypeClassifier(num_concepts, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
lambda_binary = 0.01
lambda_L1 = 0.001

In [26]:
# train and test
from tqdm import tqdm
from src.training import train_epoch

num_epochs = 50
best_acc, best_epoch = 0, 0

tqdm_loader = tqdm(range(num_epochs), desc="Training Prototypes", leave=True)
for epoch in tqdm_loader:
    train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, lambda_binary, lambda_L1, device=device)
    if train_accuracy > best_acc:
        best_acc = train_accuracy
        best_epoch = epoch
    tqdm_loader.set_postfix({"Train Acc": f"{train_accuracy:.2f}%", "Train Loss": f"{train_loss:.4f}"})

print(f"Best accuracy of {best_acc:.2f}% achieved at epoch {best_epoch}")

KeyboardInterrupt: 

In [None]:
real_labels = Y_test.argmax(dim=1)
predictions = model.predict(X_test)
(predictions == real_labels).sum().item()/len(predictions)

0.9895951570185395

In [None]:
wrong_indices = (predictions != real_labels).nonzero(as_tuple=True)[0]

C_hat_wrong = X_test[wrong_indices]
Y_wrong = real_labels[wrong_indices]

print(C_hat_wrong.shape)
print(Y_wrong.shape)

output_dir = os.path.join(PROJECT_ROOT, 'output', 'intervention', use_dataset)

torch.save(C_hat_wrong, os.path.join(output_dir, 'C_hat_clc.npy'))
torch.save(Y_wrong, os.path.join(output_dir, 'Y_clc.npy'))

torch.Size([55, 18])
torch.Size([55])


In [None]:
# save the model
torch.save(model, os.path.join(model_dir, 'clc_model.pth'))

In [None]:
# np.unique(predictions.cpu().numpy())

In [None]:
from sklearn.metrics import classification_report

y_true = real_labels.cpu().numpy()
y_pred = predictions.cpu().numpy()

print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       533
           1       0.99      0.99      0.99       534
           2       1.00      0.99      0.99       532
           3       0.99      0.99      0.99       530
           4       0.99      1.00      0.99       531
           5       0.99      0.99      0.99       533
           6       0.99      1.00      1.00       530
           7       0.99      0.98      0.98       529
           8       0.99      1.00      0.99       532
           9       0.97      0.98      0.98       502

    accuracy                           0.99      5286
   macro avg       0.99      0.99      0.99      5286
weighted avg       0.99      0.99      0.99      5286



In [None]:
close_to_zero = (torch.sum((model.get_sigmoid_prototypes() < 0.1) | (model.get_sigmoid_prototypes() > 0.9)) / (200*112)).cpu().numpy()
print(f"{close_to_zero*100}% of the values are close to 0 or 1")

0.8035714626312256% of the values are close to 0 or 1


# Class-level vs Learned

In [None]:
# print(class_level_concepts)

# Overall sparsity (fraction of zeros)
overall_sparsity = np.mean(class_level_concepts == 0)
print(f"Overall sparsity: {overall_sparsity:.4f}")

# # Sparsity per row (fraction of zeros in each row)
# row_sparsity = np.mean(class_level_concepts == 0, axis=1)
# print("Sparsity per row:", row_sparsity)

Overall sparsity: 0.7500


In [None]:
Prototypes = model.get_binary_prototypes()
Prototypes = Prototypes.cpu().detach().numpy()
# print(Prototypes)

# Overall sparsity (fraction of zeros)
overall_sparsity = np.mean(Prototypes == 0)
print(f"Overall sparsity: {overall_sparsity:.4f}")

# Sparsity per row (fraction of zeros in each row)
# row_sparsity = np.mean(Prototypes == 0, axis=1)
# print("Sparsity per row:", row_sparsity)

Overall sparsity: 0.7500


In [None]:
# Prototypes = model.get_sigmoid_prototypes()
# Prototypes = Prototypes.cpu().detach().numpy()
# print(Prototypes)

In [None]:
close_to_half = (torch.sum((model.get_sigmoid_prototypes() > 0.4) & (model.get_sigmoid_prototypes() < 0.6)) / (200*112)).cpu().numpy()
print(f"{close_to_half*100}% of the values are close to 0.5")

0.0% of the values are close to 0.5


In [None]:
class_level_concepts

array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1],
       [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0]])

In [None]:
print(model.get_binary_prototypes().cpu().detach().numpy())

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1.]
 [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0.]]


In [None]:
def inverse_sigmoid(x):
    epsilon = 1e-7
    x = np.clip(x, epsilon, 1 - epsilon)
    return np.log(x / (1 - x))

In [None]:
if use_dataset == 'Derm7pt':
    new_prototypes = model.get_sigmoid_prototypes().cpu().detach().numpy()
    # Apply inverse sigmoid to convert binary class-level concepts to logits
    for i in range(5):
        new_prototypes[i] = inverse_sigmoid(class_level_concepts[i])

    new_prototypes = torch.tensor(new_prototypes, device=device, dtype=torch.float32)
    print(new_prototypes)

    model.modify_prototypes(new_prototypes)

In [None]:
model.get_binary_prototypes().cpu().detach().numpy()

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
        0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 1.],
       [1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 1.],
       [1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0.,
        0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 1.],
       [0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
        1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
        1., 0.],
       [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1.,
        1., 0.]], dtype=float32)

In [None]:
# save the model
torch.save(model, os.path.join(model_dir, 'clc_model.pth'))

# MY OLD CODE

In [None]:
# # --- Plotting ---
# from matplotlib import pyplot as plt

# plt.figure(figsize=(10, 5))
# epochs_range = range(1, epochs + 1)
# plt.plot(epochs_range, train_losses, label='Training Loss', marker='o', linestyle='-')
# plt.plot(epochs_range, val_losses, label='Validation Loss', marker='x', linestyle='--')
# plt.title('Training and Validation Loss Over Epochs')
# plt.xlabel('Epoch')
# plt.ylabel('Average Loss')
# plt.legend()
# plt.grid(True)
# plt.show()

# # Optional: Plot validation accuracy as well
# plt.figure(figsize=(10, 5))
# plt.plot(epochs_range, val_accuracies, label='Validation Accuracy', marker='s', linestyle='-', color='green')
# plt.title('Validation Accuracy Over Epochs')
# plt.xlabel('Epoch')
# plt.ylabel('Accuracy (%)')
# plt.legend()
# plt.grid(True)
# plt.show()

In [None]:
# prototypes = []
# for y in Y_train:
#     prototypes.append(final_binary_prototypes[y])

# prototypes = np.array(prototypes)

In [None]:
# # Function to find the closest concept vector and predict the label
# def predict_nearest_concept(instance, reference_concepts, reference_labels):
#     distances = np.sum(np.abs(reference_concepts - instance), axis=1)
#     min_idx = np.argmin(distances)
#     return reference_labels[min_idx]

# # Use prototypes as reference concepts and evaluate on C_hat_test
# correct_predictions = 0
# total_predictions = len(C_hat_test)

# for i, test_instance in enumerate(C_hat_test):
#     predicted_label = predict_nearest_concept(test_instance, prototypes, Y_train)
#     true_label = Y_test[i]

#     if predicted_label == true_label:
#         correct_predictions += 1

# # Calculate and print accuracy
# accuracy = correct_predictions / total_predictions
# print(f"\nOverall accuracy using prototype-based nearest neighbor: {accuracy:.4f}")