In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import TensorDataset, DataLoader
from torchvision.models import resnet18
import torch.nn.functional as F
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

import lens
from lens.utils.base import validate_network, set_seed, tree_to_formula
from lens.utils.layer import prune_logic_layers
from lens import logic

from dSprites.dSprites_loader import load_dsprites, concept_filters, get_shape_scale
from dSprites.dSprites_style_I2C import i2c_style

results_dir = './results_ll/dsprites'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_rep = 10
tot_epochs = 10001
prune_epochs = 60001

In [2]:
def get_data(path):
    # Load dataset specified by schema
    dataset_schema = 'small_skip'
    # dataset_schema = 'big_skip'

    # Get filtered concept values
    shape_range, scale_range, rot_range, x_pos_range, y_pos_range = concept_filters(dataset_schema)

    # Define function for filtering out specified concept values only
    def c_filter_fn(concepts):
        in_shape_range = (concepts[1] in shape_range)
        in_scale_range = (concepts[2] in scale_range)
        in_rot_range = (concepts[3] in rot_range)
        in_x_range = (concepts[4] in x_pos_range)
        in_y_range = (concepts[5] in y_pos_range)

        return (in_shape_range and in_scale_range and in_rot_range and in_x_range and in_y_range)

    label_fn = get_shape_scale(shape_range, scale_range)
    
    # Load dataset
    x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names = load_dsprites(path,
                                                                               c_filter_fn=c_filter_fn,
                                                                               label_fn=label_fn,
                                                                               train_test_split_flag=True)


    return x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names

# Load the model, as well as input, label, and concept data
data = get_data(path='./dSprites/data/dsprites.npz')
x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names = data
x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(y_train)
x_val = torch.FloatTensor(x_val)
y_val = torch.FloatTensor(y_val)
x_test = torch.FloatTensor(x_test)
y_test = torch.FloatTensor(y_test)
c_train = torch.FloatTensor(c_train)
c_val = torch.FloatTensor(c_val)
c_test = torch.FloatTensor(c_test)
print("Data loaded successfully...")

x_train shape: (20993, 3, 64, 64)
c_train shape: (20993, 50)
y_train shape: (20993, 18)
Number of images in x_train 20993
Number of images in x_val 10341
Number of images in x_test 5530
Data loaded successfully...


In [3]:
args = {
    'models_dir': './models_ll/dSprites/', 
    'model_style': 'CBM',
    'seed': 0,
    'batch_size': 256,
}

if not os.path.isdir(args['models_dir']):
    os.makedirs(args['models_dir'])

train_dataset = TensorDataset(x_train, c_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=False)
val_dataset = TensorDataset(x_val, c_val, y_val)
valid_loader = DataLoader(val_dataset, batch_size=args['batch_size'], shuffle=False)
test_dataset = TensorDataset(x_test, c_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

In [4]:
n_concepts = c_train.size(1)
n_outputs = y_train.size(1)

In [5]:
colors = ['white']
shapes = ['square', 'ellipse', 'heart']
scale = ['very small', 'small', 's-medium', 'b-medium', 'big', 'very big']
rotation = ['0°', '5°', '10°', '15°', '20°', '25°', '30°', '35°']
x_pos = ['x0', 'x2', 'x4', 'x6', 'x8', 'x10', 'x12', 'x14', 'x16', 'x18', 'x20', 'x22', 'x24', 'x26', 'x28', 'x30']
y_pos = ['y0', 'y2', 'y4', 'y6', 'y8', 'y10', 'y12', 'y14', 'y16', 'y18', 'y20', 'y22', 'y24', 'y26', 'y28', 'y30']
concepts = colors + shapes + scale + rotation + x_pos + y_pos
len(concepts)

50

In [6]:
def l1_loss(model):
    loss = 0
    for module in model.children():
        if isinstance(module, lens.nn.XLogic):
            loss += torch.norm(module.weight, 1) + torch.norm(module.bias, 1)
            break
    return loss

In [15]:
def cbm_learning(net, dataloader, device, optimizer, criterion, epoch, n_epochs, train=False, base=False):
    running_loss = 0.0
    y_correct = 0
    c_correct = 0
    c_total = 0
    y_total = 0
    need_pruning = True
    prune_epochs = 30
    total_step = len(dataloader)
    for batch_idx, (data_, concepts_, target_) in enumerate(dataloader):
        data_, concepts_, target_ = data_.to(device), concepts_.to(device), target_.to(device)

        if train:
            optimizer.zero_grad()
        
        if base:
            c_preds = net[0](data_)
            loss = criterion(c_preds, concepts_)
        else:
            c_preds = net[0](data_)
            y_preds = net[1](c_preds)
            loss = criterion(y_preds, target_) + 0.00001 * l1_loss(net_top)

        if train:
            loss.backward()
            optimizer.step()
        
            if epoch > prune_epochs and need_pruning and not base:
                lens.utils.layer.prune_logic_layers(net[1].to(device), fan_in=9, device=device)
                need_pruning = False

        running_loss += loss.item()
        
        c_correct += torch.sum((c_preds > 0.5).eq(concepts_)).item()
        c_total += concepts_.size(0) * concepts_.size(1)
        if not base:
            y_correct += torch.sum((y_preds.argmax(dim=1)).eq(target_.argmax(dim=1))).item()
            y_total += target_.size(0)
            
        if (batch_idx) % 20 == 0 and train:
            print(f'Epoch [{epoch}/{n_epochs}], Step [{batch_idx}/{total_step}], Loss: {loss.item():.4f}')

    c_accuracy = c_correct / c_total
    if base:
        y_accuracy = 0
    else:
        y_accuracy = y_correct / y_total
    loss = running_loss / total_step

    return net, c_accuracy, y_accuracy, loss

In [16]:
net_base = resnet18(pretrained=False)
num_ftrs = net_base.fc.in_features
net_base.fc = torch.nn.Linear(num_ftrs, n_concepts)
net_top = torch.nn.Sequential(*[
    lens.nn.XLogic(n_concepts, 32, first=True),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(32, n_outputs),
    lens.nn.XLogic(n_outputs, n_outputs, top=True),
])
net = torch.nn.Sequential(*[net_base, net_top]).to(device)

file_name = './models_ll/dsprites/trained_model_l1'

In [18]:
for base in [True, False]:
    print(base)
    net_suffix = 'base' if base else 'top'
    if os.path.isfile(f'{file_name}_{net_suffix}.pt'):
        # load trained model
        net.load_state_dict(torch.load(f'{file_name}_{net_suffix}.pt'))

    else:
        set_seed(0)
        need_pruning = True
        
        criterion = torch.nn.MSELoss()

        if base:
            n_epochs = 20
            optimizer = torch.optim.AdamW(net.parameters(), lr=0.0001)
        else:
            n_epochs = 50
            optimizer = torch.optim.AdamW(net.parameters(), lr=0.001)
            for param in net[0].parameters():
                param.requires_grad = False
                
        valid_loss_min = np.Inf
        c_train_accuracy = 0
        val_loss = []
        val_acc = []
        val_c_acc = []
        test_loss = []
        test_acc = []
        test_c_acc = []
        train_loss = []
        train_acc = []
        train_c_acc = []
        epoch = 1
        for epoch in range(1, n_epochs + 1):
            print(f'Epoch {epoch}\n')

            net, c_accuracy, y_accuracy, loss = cbm_learning(net, train_loader,
                                                             device, optimizer, criterion,
                                                             epoch, n_epochs, train=True,
                                                             base=base)
            c_train_accuracy = c_accuracy
            train_c_acc.append(c_accuracy)
            train_acc.append(y_accuracy)
            train_loss.append(loss)
            print(f'\ntrain-loss: {loss:.4f}, train-c-acc: {c_accuracy:.4f}, train-cy-acc: {y_accuracy:.4f}')

            with torch.no_grad():
                net.eval()

                net, c_accuracy, y_accuracy, loss = cbm_learning(net, valid_loader,
                                                                 device, optimizer, criterion,
                                                                 epoch, n_epochs, train=False, 
                                                                 base=base)
                val_c_acc.append(c_accuracy)
                val_acc.append(y_accuracy)
                val_loss.append(loss)
                print(f'validation loss: {loss:.4f}, validation-c-acc: {c_accuracy:.4f}, validation-y-acc: {y_accuracy:.4f}')

                if loss < valid_loss_min:
                    valid_loss_min = loss
                    torch.save(net.state_dict(), f'{file_name}_{net_suffix}.pt')
                    print('Improvement-Detected, save-model')

                net, c_accuracy, y_accuracy, loss = cbm_learning(net, test_loader,
                                                                 device, optimizer, criterion,
                                                                 epoch, n_epochs, train=False, 
                                                                 base=base)
                test_c_acc.append(c_accuracy)
                test_acc.append(y_accuracy)
                test_loss.append(loss)
                print(f'test loss: {loss:.4f}, test-c-acc: {c_accuracy:.4f}, test-y-acc: {y_accuracy:.4f}\n')

            net.train()

        results = pd.DataFrame({
            'test_acc': test_acc,
            'test_c_acc': test_c_acc,
            'test_loss': test_loss,
            'val_acc': val_acc,
            'val_c_acc': val_c_acc,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'train_c_acc': train_c_acc,
            'train_loss': train_loss,
        })

True
False
Epoch 1

Epoch [1/50], Step [0/83], Loss: 0.2497
Epoch [1/50], Step [20/83], Loss: 0.2346
Epoch [1/50], Step [40/83], Loss: 0.2015
Epoch [1/50], Step [60/83], Loss: 0.0924
Epoch [1/50], Step [80/83], Loss: 0.0545

train-loss: 0.1670, train-c-acc: 0.9944, train-cy-acc: 0.0499
validation loss: 0.0547, validation-c-acc: 0.9786, validation-y-acc: 0.0586
Improvement-Detected, save-model
test loss: 0.0547, test-c-acc: 0.9791, test-y-acc: 0.0530

Epoch 2

Epoch [2/50], Step [0/83], Loss: 0.0548
Epoch [2/50], Step [20/83], Loss: 0.0547
Epoch [2/50], Step [40/83], Loss: 0.0546
Epoch [2/50], Step [60/83], Loss: 0.0541
Epoch [2/50], Step [80/83], Loss: 0.0537

train-loss: 0.0544, train-c-acc: 0.9944, train-cy-acc: 0.0547
validation loss: 0.0538, validation-c-acc: 0.9786, validation-y-acc: 0.0586
Improvement-Detected, save-model
test loss: 0.0539, test-c-acc: 0.9791, test-y-acc: 0.0530

Epoch 3

Epoch [3/50], Step [0/83], Loss: 0.0539
Epoch [3/50], Step [20/83], Loss: 0.0537
Epoch [3/50

Epoch [19/50], Step [80/83], Loss: 0.0017

train-loss: 0.0025, train-c-acc: 0.9944, train-cy-acc: 0.9962
validation loss: 0.0193, validation-c-acc: 0.9786, validation-y-acc: 0.8070
test loss: 0.0187, test-c-acc: 0.9791, test-y-acc: 0.8166

Epoch 20

Epoch [20/50], Step [0/83], Loss: 0.0017
Epoch [20/50], Step [20/83], Loss: 0.0022
Epoch [20/50], Step [40/83], Loss: 0.0022
Epoch [20/50], Step [60/83], Loss: 0.0018
Epoch [20/50], Step [80/83], Loss: 0.0016

train-loss: 0.0022, train-c-acc: 0.9944, train-cy-acc: 0.9966
validation loss: 0.0194, validation-c-acc: 0.9786, validation-y-acc: 0.8029
test loss: 0.0188, test-c-acc: 0.9791, test-y-acc: 0.8139

Epoch 21

Epoch [21/50], Step [0/83], Loss: 0.0016
Epoch [21/50], Step [20/83], Loss: 0.0020
Epoch [21/50], Step [40/83], Loss: 0.0020
Epoch [21/50], Step [60/83], Loss: 0.0017
Epoch [21/50], Step [80/83], Loss: 0.0015

train-loss: 0.0021, train-c-acc: 0.9944, train-cy-acc: 0.9969
validation loss: 0.0193, validation-c-acc: 0.9786, validation

Epoch [38/50], Step [20/83], Loss: 0.0011
Epoch [38/50], Step [40/83], Loss: 0.0013
Epoch [38/50], Step [60/83], Loss: 0.0010
Epoch [38/50], Step [80/83], Loss: 0.0008

train-loss: 0.0013, train-c-acc: 0.9944, train-cy-acc: 0.9964
validation loss: 0.0166, validation-c-acc: 0.9786, validation-y-acc: 0.8254
Improvement-Detected, save-model
test loss: 0.0158, test-c-acc: 0.9791, test-y-acc: 0.8331

Epoch 39

Epoch [39/50], Step [0/83], Loss: 0.0009
Epoch [39/50], Step [20/83], Loss: 0.0011
Epoch [39/50], Step [40/83], Loss: 0.0012
Epoch [39/50], Step [60/83], Loss: 0.0010
Epoch [39/50], Step [80/83], Loss: 0.0008

train-loss: 0.0012, train-c-acc: 0.9944, train-cy-acc: 0.9965
validation loss: 0.0166, validation-c-acc: 0.9786, validation-y-acc: 0.8249
Improvement-Detected, save-model
test loss: 0.0158, test-c-acc: 0.9791, test-y-acc: 0.8329

Epoch 40

Epoch [40/50], Step [0/83], Loss: 0.0009
Epoch [40/50], Step [20/83], Loss: 0.0010
Epoch [40/50], Step [40/83], Loss: 0.0012
Epoch [40/50], S

In [22]:
n = 1000
c_preds = net[0].cpu()(x_test[:n].cpu()).detach()
y_preds = net[1].cpu()(c_preds.cpu()).detach()
model_accuracy = accuracy_score(y_test[:n].cpu().detach().argmax(axis=1).numpy(), y_preds.argmax(axis=1).numpy())
print(f'Model\'s accuracy: {model_accuracy:.4f}')

Model's accuracy: 0.9500


In [20]:
# positive class
start = time.time()
class_explanation, class_explanations = lens.logic.explain_class(net[1].cpu(), 
                                                               c_preds.cpu(), 
                                                               y_val[:n].cpu(), 
                                                               binary=False, target_class=0,
                                                               topk_explanations=1)
elapsed_time = time.time() - start
explanation = logic.base.replace_names(class_explanation, concepts)
print(f'Class 1 - Global explanation: "{explanation}"')

Class 1 - Global explanation: "square & very small & ~ellipse & ~heart & ~small & ~s-medium & ~b-medium & ~big & ~very big"


In [21]:
if class_explanation:
    explanation = logic.base.replace_names(class_explanation, concepts)
    explanation_accuracy, y_formula = logic.base.test_explanation(class_explanation, 
                                                                  target_class=0, 
                                                                  x=c_test[:n], y=y_test[:n],
                                                                  metric=accuracy_score)
    
    c_preds = net[0].cpu()(x_test[:n].cpu()).detach()
    y_preds = net[1].cpu()(c_preds.cpu()).detach()
    explanation_fidelity = lens.logic.fidelity(y_formula, y_preds.argmax(axis=1).eq(0).numpy())
    explanation_complexity = lens.logic.complexity(class_explanation)

    print(f'Class 1 - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')
    print(f'Elapsed time {elapsed_time}')

Class 1 - Global explanation: "square & very small & ~ellipse & ~heart & ~small & ~s-medium & ~b-medium & ~big & ~very big" - Accuracy: 1.0000
Fidelity: "0.9980" - Complexity: "9"
Elapsed time 0.2170257568359375
