In [1]:
import os
import torch
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from torch.amp import autocast, GradScaler
import numpy as np
from tqdm import tqdm

from Training.Data.Modules.custom_loader import CustomLoader 
from Utils.model_utils import *; 
from Utils.configs import conf
from XAI_Method.rand_samples import RandomSamples
from XAI_Method.causal_effect import Causal
from XAI_Applications.evaluation import deletion_auc
from XAI_Applications.evaluation import benchmark_on_batch
from XAI_Applications.average_drop import evaluate_pixel_erasure
from Training.Utils.train_utils import accuracy

In [2]:
# --- Experiment setup: output directory, model factory, data loaders, reference samples ---
os.makedirs(os.getcwd() + "/Results/Experiments/02_breaking_linearity", exist_ok=True)

config = conf['simpleDNNTMnist']


def init_model(rand=True):
    """Build a fresh model from the experiment configuration.

    The layer list handed to ``get_model_architecture`` is ``config['model_layers']``
    with ``config['num_classes']`` appended as the output width. ``rand`` is
    forwarded unchanged (presumably toggles random initialisation -- confirm in
    Utils.model_utils).
    """
    return get_model_architecture(config['model_name'], config['model_layers'] + [config["num_classes"]], rand)


model = init_model()
data_loader = CustomLoader(config["dataset"], True, config["batch_size"], shuffle_test=True)
train_loader, test_loader = data_loader.train_load, data_loader.evalu_load

# Reference samples and their initial relevance scores R. These are computed
# before the model is moved to the training device, exactly as before.
rand_samples = RandomSamples(model, test_loader, 6, num_samples=64, classes=10, save_r_scores='last')
R = rand_samples.artificial_step()

X, labels, const_model = rand_samples.X, rand_samples.lbls, rand_samples.const_model

# Use the GPU when present and fall back to the CPU otherwise, mirroring the
# device selection done inside simpleDNNTrainerOld.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Chained assignment keeps the cell from echoing the module repr as its output.
model = model.to(device).train()

In [None]:
class simpleDNNTrainerOld():

    def __init__(self, model, X, labels, lr=.01):

        self.model = model
        self.dev = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.dev)

        self.loss_red = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(model.parameters(), lr=lr)
        # self.scaler = GradScaler(enabled=True)

        self.X, self.labels = X, labels

    def apply_loss(self, y_pred, y):

        y_pred = y_pred.reshape((-1, y_pred.shape[-1]))
        y = y.reshape((-1,))

        if y[0].dtype == torch.int: y = y.long()

        return self.loss_red(y_pred, y)

    def train_on_batch(self, batch):

        self.model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        # Get data from batch, calculate loss and update
        images, labels = batch[0], batch[1]
        images = images.to(self.device); labels = labels.to(self.device)

        self.optimizer.zero_grad()
        with autocast(device_type=self.dev):  # Enable mixed-precision

            output = self.model(images)
            l = self.apply_loss(output, labels)

        self.scaler.scale(l).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        
        self.upd_model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        R = self.update_R()   
        self.model.load_state_dict(self.upd_model_state)   

        return R

    def validate_on_batch(self, batch):

        images, labels = batch[0], batch[1]
        images = images.to(self.device)
        labels = labels.to(self.device)

        with torch.no_grad():
            output = self.model(images)

        acc = accuracy(output, labels, averaged=False)
        return acc
    
    def update_R(self):
        
        causal = Causal(self.model, self.model_state, self.X, self.labels)
        return causal.update()
    
def update(model, model_state, X, labels):
    """Return the relevance update ``Causal`` computes for ``model`` against ``model_state``.

    All four arguments are handed to the ``Causal`` constructor unchanged; the
    result of its ``update()`` method is returned as is.
    """
    return Causal(model, model_state, X, labels).update()

def check_differences(model, R, X, labels, lr=0.1, epochs=5):
    """Train ``model`` and accumulate relevance updates at three granularities.

    Starting from the same initial scores ``R``, three variants are built:

    * ``R1`` -- one update per mini-batch step,
    * ``R2`` -- one update per epoch (weights at epoch start vs. epoch end),
    * ``R3`` -- a single update from the initial to the final weights.

    Iterates over the notebook-level ``train_loader`` and ``test_loader`` and
    prints the test accuracy after every epoch.

    Args:
        model: Network to train (already on the training device).
        R: Initial relevance scores; cloned, never modified in place.
        X: Reference inputs forwarded to ``Causal``.
        labels: Labels belonging to ``X``.
        lr: Learning rate for the SGD trainer.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(R1, R2, R3)``.
    """

    def snapshot(net):
        # Detached copy of all weights, independent of later training steps.
        return {k: v.detach().clone() for k, v in net.state_dict().items()}

    # Apply two steps for the trainer, get models
    trainer = simpleDNNTrainerOld(model, X, labels, lr)

    R1 = R.clone()
    R2 = R.clone()
    R3 = R.clone()

    start_model = snapshot(trainer.model)

    for epoch in range(epochs):

        model_before = snapshot(trainer.model)

        train_bar = tqdm(train_loader,
                         desc=f"Epoch: {epoch+1}/{epochs}",
                         total=len(train_loader))
        for batch in train_bar:
            R1 += trainer.train_on_batch(batch)

        # Snapshot the trained weights here instead of reading
        # trainer.upd_model_state, which does not exist when the loader is empty.
        model_after = snapshot(trainer.model)
        R2 += update(trainer.model, model_before, X, labels)
        trainer.model.load_state_dict(model_after)

        test_bar = tqdm(test_loader,
                        desc=f"Epoch: {epoch+1}/{epochs}",
                        total=len(test_loader))

        metrics = np.zeros(len(test_loader.dataset))
        offset = 0

        for batch in test_bar:
            acc = trainer.validate_on_batch(batch)

            bs = batch[0].size(0)
            metrics[offset: offset + bs] = acc.cpu().numpy()
            offset += bs

        print(f"Epoch {epoch+1}, Accuracy: {np.mean(metrics)}")

    # R3 is accumulated on top of the initial scores like R1 and R2 (it used to be
    # overwritten with the bare increment, discarding R). The trained weights are
    # restored afterwards, as is done after every other update call.
    final_state = snapshot(trainer.model)
    R3 += update(trainer.model, start_model, X, labels)
    trainer.model.load_state_dict(final_state)

    return R1, R2, R3

In [12]:
def calc_deletion_AUC(model, X, labels, R1, R2, R3, k):
    """Compute the mean Deletion AUC of three attribution variants and save it.

    For every sample in ``X`` and each of ``R1`` ("compl"), ``R2`` ("skip_epoch")
    and ``R3`` ("skip_all"), ``deletion_auc`` is evaluated with the notebook-level
    ``config['baseline_val']`` and 200 steps. A sample whose evaluation raises is
    recorded as NaN (with a warning) and ignored by the NaN-aware mean.

    The three averages are written to ``run_{k}/auc_compare.txt``, replacing any
    previous content so the file always holds exactly the latest results.

    Args:
        model: Network under evaluation; switched to eval mode.
        X: Iterable of input samples.
        labels: Labels indexable in parallel with ``X``.
        R1, R2, R3: Attribution scores, indexable as ``R[i, :]``.
        k: Run index used in the output directory name.
    """
    import warnings

    model.eval()

    # Evaluate on the device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        dev = first_param.device
    else:
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    results = {"compl": [], "skip_epoch": [], "skip_all": []}

    for method, scores in zip(results, (R1, R2, R3)):
        for i, x in enumerate(X):
            # The guard has to cover the deletion_auc call itself; wrapping only
            # the list append could never trigger the NaN fallback.
            try:
                auc_value, _ = deletion_auc(model, x.to(dev), scores[i, :].to(dev),
                                            labels[i], baseline_val=config['baseline_val'], steps=200)
            except Exception as e:
                warnings.warn(f"deletion_auc failed for method '{method}', sample {i}: {e!r}")
                auc_value = np.nan

            results[method].append(auc_value)

    res_pth = os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{k}"
    os.makedirs(res_pth, exist_ok=True)

    # "w" rather than "a+": appended re-runs left stale lines at the top of the
    # file, which is where the aggregation step reads from.
    with open(f"{res_pth}/auc_compare.txt", "w") as f:
        for method, auc_list in results.items():
            avg_auc = np.nanmean(auc_list)
            f.write(f"Average Deletion AUC for {method:20s}: {avg_auc:.4f} \n")

def calc_avg_drop(model, X, labels, R1, R2, R3, k):
    """Compute the Average Drop score of three attribution variants and save it.

    ``evaluate_pixel_erasure`` is called once per variant -- ``R1`` ("compl"),
    ``R2`` ("skip_epoch"), ``R3`` ("skip_all") -- and the scores are written to
    ``run_{k}/ad_compare.txt``, replacing any previous content so the file always
    holds exactly the latest results.

    Args:
        model: Network under evaluation; switched to eval mode.
        X: Input samples forwarded to ``evaluate_pixel_erasure``.
        labels: Labels belonging to ``X``.
        R1, R2, R3: Attribution scores for the three variants.
        k: Run index used in the output directory name.
    """
    model.eval()

    results = {}
    for method, r in zip(("compl", "skip_epoch", "skip_all"), (R1, R2, R3)):
        results[method] = evaluate_pixel_erasure(model, X, labels, r)

    res_pth = os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{k}"
    os.makedirs(res_pth, exist_ok=True)

    # "w" rather than "a+": appended re-runs left stale lines at the top of the
    # file, which is where the aggregation step reads from.
    with open(f"{res_pth}/ad_compare.txt", "w") as f:
        for method, ad_score in results.items():
            f.write(f"Average Drop for {method:20s}: {ad_score:.4f} \n")

def plot(X, samples, R1, R2, R3, k):
    """Save the input and its ``R1`` / ``R2`` attribution maps as PDFs for the given samples.

    For the i-th entry of ``samples`` three 28x28 images are written into the
    ``run_{k}`` result directory: ``sample_{i}``, ``complete_{i}`` and
    ``skip_ep_{i}``. File names use the position ``i``, not the sample index.
    ``R3`` is accepted for signature compatibility but is not plotted.
    """
    for i, idx in enumerate(samples):
        panels = (
            (f'sample_{i}', X[idx, :]),
            (f'complete_{i}', R1[idx, :]),
            (f'skip_ep_{i}', R2[idx, :]),
        )

        for title, image in panels:
            plt.imshow(image.reshape(28, 28))
            plt.axis('off')
            plt.savefig(os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{k}/{title}.pdf",
                        bbox_inches='tight', format='pdf')
            # Close the implicit figure so repeated calls do not pile up open figures.
            plt.close()

In [None]:
def run(k, epochs=10):
    """Execute one complete experiment run and store its results under ``run_{k}``.

    A fresh model is trained with ``check_differences``; the three resulting
    attribution variants are then scored with Deletion AUC and Average Drop,
    ``benchmark_on_batch`` is run for the 'GS' and 'IG' methods, and the learning
    rate is recorded in ``setup.txt``. Attribution maps of the first 20 samples
    are plotted only for run index 10.

    Args:
        k: Run index; selects the output directory.
        epochs: Number of training epochs.
    """
    lr = 0.05
    model = init_model().to(device).train()

    R1, R2, R3 = check_differences(model, R, X, labels, lr=lr, epochs=epochs)

    calc_deletion_AUC(model, X, labels, R1, R2, R3, k)

    # Positional arguments 0 and 0.6 are passed through as in the original call;
    # their meaning is defined by benchmark_on_batch (see XAI_Applications.evaluation).
    bench_dir = os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{k}"
    benchmark_on_batch(model, X, labels, ['GS', 'IG'], 0, bench_dir, 0.6)

    calc_avg_drop(model, X, labels, R1, R2, R3, k)

    setup_file = os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{k}/setup.txt"
    with open(setup_file, "w") as f:
        f.write(f"Learning Rate: {lr}")

    if k != 10:
        return

    plot(X, range(20), R1, R2, R3, k)

# Repeat the experiment ten times (run indices 0-9) so that mean scores for
# Deletion AUC and Average Drop can be computed across runs.
for k in range(10):
    run(k)

# Extra run with index 10 and two epochs: the only run for which `run` plots
# the resulting attribution maps.
run(10, epochs=2)

Epoch: 1/2: 100%|██████████| 374/374 [00:04<00:00, 87.95it/s]
Epoch: 1/2: 100%|██████████| 94/94 [00:00<00:00, 127.94it/s]

Epoch 1, Accuracy: 0.2834448160535117



Epoch: 2/2: 100%|██████████| 374/374 [00:04<00:00, 88.20it/s]
Epoch: 2/2: 100%|██████████| 94/94 [00:00<00:00, 126.51it/s]


Epoch 2, Accuracy: 0.3933110367892977


In [None]:
def _read_scores(path):
    """Parse a results file with lines of the form ``<text> <method> : <value>``.

    Returns a dict mapping the last word before the final colon (the method
    name) to the float after it. If a method occurs several times, e.g. because
    results were appended by repeated runs, the last occurrence wins, so the
    most recent results are used.
    """
    scores = {}
    with open(path, 'r') as f:
        for line in f:
            if ":" not in line:
                continue
            label, value = line.rsplit(":", 1)
            if not label.strip():
                continue
            scores[label.split()[-1]] = float(value)
    return scores


# Per-run scores for the three attribution variants
# (xt = "compl", xep = "skip_epoch", xl = "skip_all").
xt_auc, xep_auc, xl_auc = [], [], []
xt_ad, xep_ad, xl_ad = [], [], []

for i in range(10):

    run_dir = os.getcwd() + f"/Results/Experiments/02_breaking_linearity/run_{i}"

    auc_scores = _read_scores(f"{run_dir}/auc_compare.txt")
    xt_auc.append(auc_scores["compl"])
    xep_auc.append(auc_scores["skip_epoch"])
    xl_auc.append(auc_scores["skip_all"])

    ad_scores = _read_scores(f"{run_dir}/ad_compare.txt")
    xt_ad.append(ad_scores["compl"])
    xep_ad.append(ad_scores["skip_epoch"])
    xl_ad.append(ad_scores["skip_all"])

with open(os.getcwd() + "/Results/Experiments/02_breaking_linearity/results.txt", "w+") as f:

    f.write('------------------ Deletion AUC ------------------ \n\n')
    f.write(f"Average Deletion AUC for X_train: {np.mean(xt_auc)} \n")
    f.write(f"Std for Deletion AUC for X_train: {np.std(xt_auc, ddof=1)} \n\n")
    f.write(f"Average Deletion AUC for X_steps: {np.mean(xep_auc)} \n")
    f.write(f"Std for Deletion AUC for X_steps: {np.std(xep_auc, ddof=1)} \n\n")
    f.write(f"Average Deletion AUC for X_linear: {np.mean(xl_auc)} \n")
    f.write(f"Std for Deletion AUC for X_linear: {np.std(xl_auc, ddof=1)} \n\n")

    f.write('------------------ Average_drop ------------------ \n\n')
    f.write(f"Average Drop for X_train: {np.mean(xt_ad)} \n")
    f.write(f"Average Drop for X_steps: {np.mean(xep_ad)} \n")
    # Stray apostrophe removed from the label so it matches the lines above.
    f.write(f"Average Drop for X_linear: {np.mean(xl_ad)} \n")
