Consider moving this notebook to the project's root directory to avoid import and path errors.

In [None]:
import torch
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from torch.amp import autocast, GradScaler
import os

from Training.Data.Modules.custom_loader import CustomLoader 
from Utils.model_utils import *; 
from Utils.configs import conf
from XAI_Method.rand_samples import RandomSamples
from XAI_Method.causal_effect_effic import Causal

In [None]:
if True:

    # Directory that `plot` writes its PDF figures into.
    os.makedirs(os.path.join(os.getcwd(), "Results", "Experiments", "02_breaking_linearity"), exist_ok=True)

    config = conf['simpleDNNTMnist']

    def init_model(rand=True):
        """Build the configured architecture; `rand` is forwarded to get_model_architecture."""
        return get_model_architecture(config['model_name'], config['model_layers'] + [config["num_classes"]], rand)

    model = init_model()
    data_loader = CustomLoader(config["dataset"], True, config["batch_size"], shuffle_test=True)
    train_loader, test_loader = data_loader.train_load, data_loader.evalu_load

    # NOTE(review): optimizer / loss / scaler are not used by the cells visible here;
    # simpleDNNTrainerOld builds its own. Kept in case later cells rely on them.
    optimizer = optim.SGD(model.parameters(), lr=1)
    loss = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=True)

    rand_samples = RandomSamples(model, test_loader, 6, num_samples=64, classes=10, save_r_scores='last')
    R = rand_samples.artificial_step()

    X, labels, const_model = rand_samples.X, rand_samples.lbls, rand_samples.const_model
    # Fall back to CPU when no GPU is present (same rule as simpleDNNTrainerOld);
    # a hard-coded 'cuda' device crashes on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()

In [3]:
class simpleDNNTrainerOld():

    def __init__(self, model, X, labels):

        self.model = model
        self.dev = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.dev)

        self.loss_red = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(model.parameters(), lr=1)
        self.scaler = GradScaler(enabled=True)

        self.X, self.labels = X, labels

    def apply_loss(self, y_pred, y):

        y_pred = y_pred.reshape((-1, y_pred.shape[-1]))
        y = y.reshape((-1,))

        if y[0].dtype == torch.int: y = y.long()

        return self.loss_red(y_pred, y)

    def train_on_batch(self, batch):

        self.model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        # Get data from batch, calculate loss and update
        images, labels = batch[0], batch[1]
        images = images.to(self.device); labels = labels.to(self.device)

        self.optimizer.zero_grad()
        with autocast(device_type=self.dev):  # Enable mixed-precision

            output = self.model(images)
            l = self.apply_loss(output, labels)

        self.scaler.scale(l).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        
        self.upd_model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        R = self.update_R()   
        self.model.load_state_dict(self.upd_model_state)   

        return self.model, R
    
    def update_R(self):
        
        causal = Causal(self.model, self.model_state, self.X, self.labels)
        return causal.update()
    
def update(model, model_state, X, labels):
    """Compute ``Causal(model, model_state, X, labels).update()`` in one call.

    Module-level equivalent of ``simpleDNNTrainerOld.update_R`` for callers
    that hold the model and a reference state dict directly.
    """
    return Causal(model, model_state, X, labels).update()

def check_differences(model, R, X, labels, break_pt=2, loader=None):
    """Compare summed per-step R updates with a single update over the same span.

    Run 1 trains ``model`` for ``break_pt`` batches with ``simpleDNNTrainerOld``.
    It adds each step's R update onto a copy of ``R``.
    Run 2 calls ``update`` once, using the trained model and the untouched
    starting weights. This mirrors how ``update_R`` is called for a single step.

    Args:
        model: trained in place by run 1.
        R: initial R; cloned, not modified.
        X, labels: probe data forwarded to ``Causal``.
        break_pt: number of training batches to run. ``None`` uses the whole
            loader; values <= 0 run no batches.
        loader: iterable of batches; defaults to the notebook-global ``train_loader``.

    Returns:
        (R1, R2): the accumulated result and the single-update result.
    """
    import itertools

    if loader is None:
        loader = train_loader  # notebook global from the setup cell

    trainer = simpleDNNTrainerOld(model, X, labels)

    # Snapshot the starting weights before any optimizer step.
    model_non_upd_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Run 1: stop after exactly break_pt batches. The former counter loop never
    # matched break_pt <= 0 and silently consumed the entire loader instead.
    stop = None if break_pt is None else max(int(break_pt), 0)
    R1 = R.clone()
    for batch in itertools.islice(loader, stop):
        model, R_upd = trainer.train_on_batch(batch)
        R1 += R_upd

    # Run 2: one update from the initial weights to the final weights.
    R2 = update(model, model_non_upd_state, X, labels)

    return R1, R2

def plot(X, samples, R1, R2, out_dir=None, img_shape=(28, 28)):
    """Save each selected sample and its two R maps as separate PDF images.

    For the k-th entry ``s`` of ``samples`` this writes three files to ``out_dir``:
    ``sample_k.pdf`` (``X[s, :]``), ``non_linear_k.pdf`` (``R1[s, :]``) and
    ``linear_k.pdf`` (``R2[s, :]``). ``k`` is the position in ``samples``, not
    the row index ``s``.

    Args:
        X, R1, R2: 2-D indexable arrays/tensors whose rows reshape to ``img_shape``.
        samples: row indices to plot.
        out_dir: target directory; defaults to
            ``<cwd>/Results/Experiments/02_breaking_linearity`` and is created if missing.
        img_shape: shape each row is reshaped to before display (default 28x28).
    """
    if out_dir is None:
        out_dir = os.path.join(os.getcwd(), "Results", "Experiments", "02_breaking_linearity")
    os.makedirs(out_dir, exist_ok=True)

    for k, s in enumerate(samples):
        for prefix, source in (("sample", X), ("non_linear", R1), ("linear", R2)):
            img = source[s, :]
            if isinstance(img, torch.Tensor):
                # imshow needs host memory and a tensor without an autograd graph.
                img = img.detach().cpu()

            # Explicit figure so the global pyplot state is left untouched.
            fig, ax = plt.subplots()
            ax.imshow(img.reshape(img_shape))
            ax.axis('off')
            fig.savefig(os.path.join(out_dir, f"{prefix}_{k}.pdf"),
                        bbox_inches='tight', format='pdf')
            plt.close(fig)

In [4]:
# R1: 100 accumulated per-batch R updates; R2: single update over the same span.
# Then save the sample / R1 / R2 images for a few test samples.
R1, R2 = check_differences(model=model, R=R, X=X, labels=labels, break_pt=100)
plot(X, samples=[1, 3, 4, 6], R1=R1, R2=R2)