In [9]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sympy import simplify_logic

from deep_logic.utils.base import validate_network
from deep_logic.utils.relunn import get_reduced_model
from deep_logic import fol

# Seed the PyTorch and NumPy global RNGs with the same value so that repeated
# runs of the notebook draw the same random numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [10]:
# Read the input table (x) and the label table (y); in both files the first
# CSV column is used as the row index.
x, y = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('dsprites_c_train.csv', 'dsprites_y_train.csv')
)

In [11]:
# Attribute families describing each sample; the bare expression on the last
# line echoes the list as the cell output.
base_concepts = [
    'color',
    'shape',
    'scale',
    'rotation',
    'x_pos',
    'y_pos',
]
base_concepts

['color', 'shape', 'scale', 'rotation', 'x_pos', 'y_pos']

In [12]:
# Human-readable names for the input columns, grouped by attribute family
# (1 + 3 + 6 + 8 + 16 + 16 = 50 names in total).
# NOTE(review): presumably listed in the same order as the CSV columns - confirm.
colors = ['white']
shapes = ['square', 'ellipse', 'heart']
scale = ['very small', 'small', 's-medium', 'b-medium', 'big', 'very big']
# Rotations go from 0° to 35° in 5° steps.
rotation = [f'{angle}°' for angle in range(0, 40, 5)]
# Positions use even indices 0, 2, ..., 30 on both axes.
x_pos = [f'x{pos}' for pos in range(0, 32, 2)]
y_pos = [f'y{pos}' for pos in range(0, 32, 2)]
concepts = [*colors, *shapes, *scale, *rotation, *x_pos, *y_pos]

In [13]:
# Inputs: every column of the input table as a float32 tensor.
x_train = torch.tensor(x.to_numpy(), dtype=torch.float)

# Target: element-wise sum of label columns 0 and 7, shaped as a column vector
# (n_samples, 1).
# NOTE(review): the sum exceeds 1 if both columns are ever set for the same
# row - confirm the two labels are mutually exclusive.
label_matrix = y.to_numpy()
y_train = torch.tensor(label_matrix[:, 0] + label_matrix[:, 7], dtype=torch.float)[:, None]

# NOTE(review): there is no held-out split here; x_test is the very same
# tensor object as x_train.
x_test = x_train

print(x_train.shape)
print(y_train.shape)
# Echo the raw input table as the cell output.
x

torch.Size([5530, 50])
torch.Size([5530, 1])


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5525,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5526,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5527,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5528,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [14]:
# Hyper-parameters of the training run, collected in one place.
HIDDEN_SIZES = (50, 30)
LEARNING_RATE = 0.01
N_EPOCHS = 1000
L1_WEIGHT = 0.002  # strength of the L1 penalty applied to each Linear weight matrix
LOG_EVERY = 100    # report training accuracy once every LOG_EVERY epochs

# Fully connected ReLU network with two hidden layers and one sigmoid output.
# The Linear layers are created in input-to-output order, so the seeded
# initialisation is consumed in the same sequence on every run.
layers = [
    torch.nn.Linear(x_train.size(1), HIDDEN_SIZES[0]),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_SIZES[0], HIDDEN_SIZES[1]),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_SIZES[1], 1),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
model.train()
# NOTE(review): not read anywhere in this cell - check whether later cells use it.
need_pruning = True
for epoch in range(N_EPOCHS):
    # Full-batch forward pass.
    optimizer.zero_grad()
    y_pred = model(x_train)

    # MSE data term, then one L1 weight penalty per Linear layer added in
    # layer order.
    loss = torch.nn.functional.mse_loss(y_pred, y_train)
    for layer in model.children():
        if isinstance(layer, torch.nn.Linear):
            loss = loss + L1_WEIGHT * torch.norm(layer.weight, 1)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    # Only report on logging epochs.
    if epoch % LOG_EVERY != 0:
        continue

    # A row counts as correct when every thresholded output equals its label.
    # The predictions used here were computed before this epoch's update.
    y_pred_d = y_pred > 0.5
    n_correct = y_pred_d.eq(y_train).all(dim=1).sum().item()
    accuracy = n_correct / y_train.size(0)
    print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

Epoch 0: train accuracy: 0.8873
Epoch 100: train accuracy: 0.8873
Epoch 200: train accuracy: 1.0000
Epoch 300: train accuracy: 1.0000
Epoch 400: train accuracy: 1.0000
Epoch 500: train accuracy: 1.0000
Epoch 600: train accuracy: 1.0000
Epoch 700: train accuracy: 1.0000
Epoch 800: train accuracy: 1.0000
Epoch 900: train accuracy: 1.0000


# Local explanations

In [15]:
np.set_printoptions(precision=2, suppress=True)

# Only the first 52 training samples (indices 0..51) are inspected.
N_INSPECTED = 52

outputs = []
for i, (xin, yin) in enumerate(zip(x_train[:N_INSPECTED], y_train[:N_INSPECTED])):
    # Build the sample-specific reduced model and grab the weights of its
    # first Linear layer for display.
    model_reduced = get_reduced_model(model, xin)
    for module in model_reduced.children():
        if isinstance(module, torch.nn.Linear):
            wa = module.weight.detach().numpy()
            break

    output = model_reduced(xin)
    outputs.append(output)

    # Explain only samples predicted positive whose thresholded prediction
    # also equals the label.
    predicted_positive = output > 0.5
    if not (predicted_positive and predicted_positive == yin):
        continue

    local_explanation = fol.relunn.explain_semi_local(
        model, x_train, y_train, xin, concept_names=concepts
    )
    print(f'Input {(i+1)}')
    print(f'\tx={xin.detach().numpy()}')
    print(f'\ty={output.detach().numpy()}')
    print(f'\tw={wa}')
    print(f'\tExplanation: {local_explanation}')
    print()

Input 1
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[0.91]
	w=[[-0.  0.  0. -0.  0.  0. -0. -0.  0. -0. -0.  0.  0.  0. -0. -0. -0. -0.
  -0.  0. -0.  0.  0.  0.  0. -0.  0. -0.  0.  0. -0.  0.  0. -0. -0. -0.
   0. -0.  0. -0. -0.  0.  0.  0. -0. -0. -0. -0. -0. -0.]]
	Explanation: square & very small

Input 7
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[0.91]
	w=[[-0.  0.  0. -0.  0.  0. -0. -0.  0. -0. -0.  0.  0.  0. -0. -0. -0. -0.
  -0.  0. -0.  0.  0.  0.  0. -0.  0. -0.  0.  0. -0.  0.  0. -0. -0. -0.
   0. -0.  0. -0. -0.  0.  0.  0. -0. -0. -0. -0. -0. -0.]]
	Explanation: square & very small

Input 14
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.


# Combine local explanations

In [16]:
# Merge the per-sample explanations into a single formula. The call also
# returns that formula's predictions and a counter of the local explanations.
# NOTE(review): topk_explanations=2 presumably keeps the two most frequent
# explanations - confirm against the deep_logic documentation.
global_explanation, predictions, counter = fol.combine_local_explanations(
    model,
    x_train,
    y_train,
    topk_explanations=2,
    concept_names=concepts,
)

# Fraction of samples where the formula's prediction equals the label.
ynp = y_train[:, 0].detach().numpy()
matches = predictions == ynp
accuracy = np.sum(matches) / len(ynp)
print(f'Accuracy using the formula "{global_explanation}": {accuracy:.4f}')

Accuracy using the formula "(square & very small) | (ellipse & small)": 1.0000


In [17]:
# Show the four most frequent local explanations with their counts.
# NOTE(review): in the recorded output the same conjunction appears under two
# literal orderings (e.g. 'small & ellipse' and 'ellipse & small'), so the keys
# look like raw strings; normalising term order would merge them - confirm.
counter.most_common(4)

[('small & ellipse', 249),
 ('square & very small', 228),
 ('very small & square', 57),
 ('ellipse & small', 53)]

In [22]:
# Widen the view to the seven most frequent local explanations, which also
# shows the entries ranked below the top four.
counter.most_common(7)

[('small & ellipse', 249),
 ('square & very small', 228),
 ('very small & square', 57),
 ('ellipse & small', 53),
 ('y18 & ~heart & 15° & very small', 4),
 ('y2 & ~heart & x26 & very small', 2),
 ('~heart & x20 & 15° & very small', 2)]

In [19]:
# Side-by-side table of the formula's predictions and the training labels,
# both flattened to one value per row.
comparison_columns = {
    'predictions': predictions.ravel(),
    'labels': y_train.detach().numpy().ravel(),
}
pd.DataFrame(comparison_columns)

Unnamed: 0,predictions,labels
0,True,1.0
1,False,0.0
2,False,0.0
3,False,0.0
4,False,0.0
...,...,...
5525,False,0.0
5526,False,0.0
5527,False,0.0
5528,False,0.0
