In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sympy import simplify_logic

from deep_logic.utils.base import validate_network
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic import fol
import deep_logic as dl

# Seed PyTorch's and NumPy's global random number generators with 0
# (PyTorch first, then NumPy) so that subsequent runs are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(0)

In [2]:
def _read_indexed_csv(path):
    """Read a CSV file into a DataFrame, using its first column as the index."""
    return pd.read_csv(path, index_col=0)


# "c" files hold the model inputs, "y" files hold the targets.
x_train = _read_indexed_csv('dsprites/dsprites_c_train.csv')
y_train = _read_indexed_csv('dsprites/dsprites_y_train.csv')
x_test = _read_indexed_csv('dsprites/dsprites_c_test.csv')
y_test = _read_indexed_csv('dsprites/dsprites_y_test.csv')

In [3]:
# Names of the six concept groups; the bare expression below displays the list.
base_concepts = [
    'color',
    'shape',
    'scale',
    'rotation',
    'x_pos',
    'y_pos',
]
base_concepts

['color', 'shape', 'scale', 'rotation', 'x_pos', 'y_pos']

In [4]:
# Human-readable names for every concept value, one list per concept group.
colors = ['white']
shapes = ['square', 'ellipse', 'heart']
scale = ['very small', 'small', 's-medium', 'b-medium', 'big', 'very big']
# Rotations run from 0° to 35° in steps of 5°.
rotation = [f'{degrees}°' for degrees in range(0, 40, 5)]
# Positions are labelled with the even numbers 0..30 on each axis.
x_pos = [f'x{offset}' for offset in range(0, 32, 2)]
y_pos = [f'y{offset}' for offset in range(0, 32, 2)]

# Flat list of all concept names, concatenated group by group in the order below.
concepts = [
    name
    for group in (colors, shapes, scale, rotation, x_pos, y_pos)
    for name in group
]

In [5]:
# Replace the pandas frames with float32 tensors holding the same values.
# NOTE: the names are rebound to a different type, so this cell can only be
# run once per kernel session (a tensor has no `.values` attribute).
x_train, x_test, y_test = (
    torch.tensor(frame.values, dtype=torch.float)
    for frame in (x_train, x_test, y_test)
)

# Show the resulting shapes, one per line.
print(x_train.shape, x_test.shape, y_test.shape, sep='\n')

torch.Size([20993, 50])
torch.Size([5530, 50])
torch.Size([5530, 18])


In [6]:
# Convert the training targets to a float32 tensor.
y_train = torch.tensor(y_train.values, dtype=torch.float)
# The number of classes is the number of target columns.
n_classes = y_train.shape[1]
print(y_train.shape)
print(n_classes)
# Column-wise totals of the targets (displayed as the cell output).
y_train.sum(dim=0)

torch.Size([20993, 18])
18


tensor([1167., 1173., 1163., 1170., 1179., 1161., 1164., 1194., 1189., 1164.,
        1149., 1199., 1194., 1159., 1121., 1140., 1112., 1195.])

In [7]:
# Re-seed both RNGs so that re-running this cell on its own starts from the
# same random state.
torch.manual_seed(0)
np.random.seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
x_train = x_train.to(device)
y_train = y_train.to(device)

# Fully connected ReLU network: n_features -> 20 -> 10 -> 5 -> n_classes.
layers = [
    torch.nn.Linear(x_train.size(1), 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, n_classes),
    # An explicit `dim` avoids PyTorch's "implicit dimension choice for softmax"
    # deprecation warning. dim=-1 normalises over the class axis both for 2-D
    # batches (as in the training loop below) and for a single 1-D sample.
    torch.nn.Softmax(dim=-1),
]
model = torch.nn.Sequential(*layers).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# NOTE(review): binary cross-entropy is applied to softmax outputs here;
# confirm this pairing is intended.
loss_form = torch.nn.BCELoss()
model.train()
# Pruning is disabled as written: with this flag False the prune_features
# branch below is never taken.
need_pruning = False
for epoch in range(6000):
    # forward pass (full-batch: the whole training set at once)
    optimizer.zero_grad()
    y_pred = model(x_train)
    # Compute Loss
    loss = loss_form(y_pred, y_train)

    # L1 penalty on weights and bias of the FIRST linear layer only:
    # the `break` leaves the loop right after that layer has been handled.
    for module in model.children():
        if isinstance(module, torch.nn.Linear):
            loss += 0.0001 * torch.norm(module.weight, 1)
            loss += 0.0001 * torch.norm(module.bias, 1)
            break

    # backward pass
    loss.backward()
    optimizer.step()

    # Prune at most once, and only after epoch 3000.
    if epoch > 3000 and need_pruning:
        prune_features(model, n_classes, device)
        need_pruning = False

    # compute accuracy (every 500 epochs) by comparing arg-max class indices
    if epoch % 500 == 0:
        y_pred_d = torch.argmax(y_pred, dim=1)
        y_train_d = torch.argmax(y_train, dim=1)
        accuracy = y_pred_d.eq(y_train_d).sum().item() / y_train.size(0)
        print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')


  input = module(input)


Epoch 0: train accuracy: 0.0534
Epoch 500: train accuracy: 1.0000
Epoch 1000: train accuracy: 1.0000
Epoch 1500: train accuracy: 1.0000
Epoch 2000: train accuracy: 1.0000
Epoch 2500: train accuracy: 1.0000
Epoch 3000: train accuracy: 1.0000
Epoch 3500: train accuracy: 1.0000
Epoch 4000: train accuracy: 1.0000
Epoch 4500: train accuracy: 1.0000
Epoch 5000: train accuracy: 1.0000
Epoch 5500: train accuracy: 1.0000


# Local explanations

In [13]:
x_test = x_test.to(device)
y_test = y_test.to(device)
np.set_printoptions(precision=2, suppress=True)
outputs = []
# Explain the first three test samples (the loop stops once i exceeds 1).
for i, (xin, yin) in enumerate(zip(x_test, y_test)):
    model_reduced = get_reduced_model(model, xin).to(device)
    # Keep the weight matrix of the first linear layer of the reduced model
    # (only needed by the optional debug print below).
    for module in model_reduced.children():
        if isinstance(module, torch.nn.Linear):
            wa = module.weight.cpu().detach().numpy()
            break
    output = model_reduced(xin)

    pred_class = torch.argmax(output)
    # `yin` is the target row paired with `xin` by the zip above.
    true_class = torch.argmax(yin)

    # generate local explanation only if the prediction is correct
    # (check currently disabled: an explanation is produced for every sample)
    #if pred_class.eq(true_class):
    local_explanation = fol.relunn.explain_local(model, x_test, y_test, xin, true_class, True,
                                                 False, concepts, device)
    print(f'Input {(i+1)}')
    print(f'\tx={xin.cpu().detach().numpy()}')
    # First "y" line: the target of THIS test sample. The original printed
    # y_train[i], i.e. the label of an unrelated training sample.
    print(f'\ty={yin.cpu().detach().numpy()}')
    # Second "y" line: the output of the reduced model for this sample.
    print(f'\ty={output.cpu().detach().numpy()}')
    #print(f'\tw={wa}')
    print(f'\tExplanation: {local_explanation}')
    print()
    outputs.append(output)
    if i > 1:
        break

Python 3.8.5 (default, Sep  3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 7.19.0

Python 3.8.5 (default, Sep  3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] on win32
ERROR! Session/line number was not unique in database. History logging moved to new session 127


tensor([[2.4799e-01, 1.5274e+01, 8.3626e-03, 2.9867e+00, 6.3794e+00, 1.5824e+00,
         1.6263e-02, 8.2743e-01, 2.2112e+01, 2.1042e-01, 3.2466e-03, 1.2069e-02,
         2.8885e-02, 1.3216e-03, 1.2234e-02, 7.8642e-04, 7.6051e-03, 4.8397e-03,
         3.4731e-03, 6.3829e-03, 2.2688e-03, 2.1170e-04, 3.5201e-03, 6.7903e-03,
         4.0503e-03, 2.4064e-02, 3.8355e-03, 5.1106e-03, 8.1686e-03, 3.5161e-03,
         9.8357e-03, 1.2965e-02, 6.4965e-03, 1.9232e-02, 1.4675e-04, 1.3629e-02,
         1.1343e-02, 1.9440e-03, 3.4923e-03, 3.4671e-04, 5.2135e-03, 6.2713e-03,
         9.9784e-04, 1.7147e-02, 6.2565e-03, 1.3670e-02, 8.8014e-03, 6.9079e-03,
         8.0553e-03, 4.0583e-03],
        [3.7659e-01, 2.7694e+01, 9.6622e-03, 2.1661e+01, 1.4869e+01, 2.7241e+01,
         1.6233e-02, 1.5196e+00, 1.7714e-01, 7.9669e-01, 1.3957e-02, 6.8783e-03,
         8.3452e-03, 5.3155e-03, 8.7228e-03, 7.7247e-04, 9.7830e-04, 2.8356e-02,
         1.6911e-02, 2.4140e-03, 1.1366e-02, 2.3551e-03, 4.0370e-02, 3.8906

tensor([[1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')

KeyboardInterrupt: 

# Combine local explanations

In [None]:
# Combine local explanations computed on the training data for class 0,
# requesting topk_explanations=1. The call returns three values.
global_explanation, predictions, counter = fol.combine_local_explanations(
    model,
    x=x_train,
    y=y_train,
    target_class=0,
    topk_explanations=1,
    device=device,
)

In [None]:
# Evaluate the class-0 explanation on the test split.
accuracy, preds = fol.base.test_explanation(
    global_explanation, 0, x_test, y_test
)
# Rewrite the formula using the readable concept names before printing it.
final_formula = fol.base.replace_names(
    global_explanation, concepts
)
print(
    f'Accuracy when using the formula "{final_formula}": {accuracy:.4f}\n'
)

In [None]:
# For every class: combine the local explanations obtained on the training
# data (topk_explanations=2), print the returned counter, and - when a
# non-empty explanation comes back - evaluate it on the test split.
# The loop index is the class index, so no separate enumerate counter is
# needed, and the unused arg-max of y_train is no longer computed here.
for target_class in range(n_classes):
    global_explanation, predictions, counter = fol.combine_local_explanations(
        model,
        x=x_train,
        y=y_train,
        target_class=target_class,
        topk_explanations=2,
        device=device,
    )
    print(target_class, counter)
    if global_explanation:
        accuracy, preds = fol.base.test_explanation(global_explanation, target_class, x_test, y_test)
        final_formula = fol.base.replace_names(global_explanation, concepts)
        print(f'Class {target_class} - Global explanation: "{final_formula}" - Accuracy: {accuracy:.4f}')

In [None]:
y_train_d = torch.argmax(y_train, dim=1)
for target_class in range(n_classes):
    # Explanation requested with concept names; this is the one printed.
    global_explanation = fol.explain_global(
        model,
        n_classes,
        target_class=target_class,
        concept_names=concepts,
        device=device,
    )

    # Explanation requested without concept names; this is the one evaluated.
    explanation = fol.relunn.explain_global(model, n_classes, target_class, device=device)

    # Skip the results listed below instead of scoring them.
    if explanation in ['False', 'True', 'The formula is too complex!']:
        continue

    # NOTE(review): accuracy is measured on the training split here, whereas
    # the cells above evaluate on x_test / y_test - confirm this is intended.
    accuracy, _ = fol.relunn.test_explanation(
        explanation, target_class, x_train.cpu(), y_train.cpu()
    )
    print(f'Class {target_class} - Global explanation: "{global_explanation}" - Accuracy: {accuracy:.4f}')