In [1]:
import os

import torch, numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch.amp import autocast
import matplotlib.pyplot as plt

from Training.Trainer.BaseTrainer import BaseTrainer
from Utils.configs import conf
from Utils.model_utils import *  # NOTE(review): presumably provides get_model_architecture
from Training.Data.Modules.custom_loader import CustomLoader 

In [None]:
# Output folder shared by the per-sample plots and the statistics log.
results_dir = os.getcwd() + "/Results/Experiments/01_logit_evolution"
os.makedirs(results_dir, exist_ok=True)

# Experiment configuration and data.
config = conf['simpleDNNTMnist']
data_loader = CustomLoader(config["dataset"], True, config["batch_size"], shuffle_test=True)

# Number of samples handed to the trainer for logit tracking.
num_samples = 1000

In [3]:
def track_model_activations(model, samples, labels):
    """Return each sample's logits reordered so that the target-class logit comes first.

    The forward pass runs under ``torch.no_grad()``. When ``model`` is an
    ``nn.Module`` it runs in eval mode, so dropout/batch-norm layers (if any) neither
    perturb the snapshot nor update running statistics. The module's previous
    train/eval mode is restored afterwards.

    Args:
        model: callable accepting ``apply_softmax=False``; assumed to return raw
            logits of shape (N, C) (TODO confirm for all architectures used).
        samples: model input for N samples.
        labels: target class index per sample; any numeric dtype, cast to long.

    Returns:
        CPU tensor of shape (N, C). Column 0 holds the target logit and columns
        1..C-1 hold the non-target logits in their original class order.
    """
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            logits = model(samples, apply_softmax=False)
    finally:
        if was_training:
            model.train()

    n_samples, n_classes = logits.shape
    # Build the index tensors on the logits' device so CPU/GPU placement never mixes.
    rows = torch.arange(n_samples, device=logits.device)
    labels = labels.long().to(logits.device)

    target = logits[rows, labels]
    # The mask drops exactly one (target) entry per row, so the remaining logits
    # reshape cleanly to (N, C - 1) while keeping their class order.
    keep = torch.ones_like(logits, dtype=torch.bool)
    keep[rows, labels] = False
    others = logits[keep].reshape(n_samples, n_classes - 1)

    return torch.cat((target.unsqueeze(1), others), dim=1).cpu()

def accuracy(Y_hat, Y, averaged=True):
    """Score argmax predictions against integer labels.

    ``Y_hat`` is flattened to (-1, num_classes) and ``Y`` to (-1,). Each row's argmax
    (cast to ``Y``'s dtype) is compared with its label. The result is a float32
    tensor of 1.0 (correct) / 0.0 (wrong) per row, or its mean when ``averaged``
    is True.
    """
    flat_scores = Y_hat.reshape(-1, Y_hat.shape[-1])
    predicted = flat_scores.argmax(dim=1).to(Y.dtype)
    hits = predicted.eq(Y.reshape(-1)).to(torch.float32)
    if averaged:
        return hits.mean()
    return hits

def activation_evolution(activations_tracker, num_plots=10, out_dir=None):
    """Save one PDF per tracked sample showing how its logits evolve over epochs.

    Args:
        activations_tracker: tensor indexed as [epoch, sample, logit], with the
            target logit at position 0 of the last axis (the layout produced by
            ``track_model_activations``).
        num_plots: maximum number of samples to plot. The first
            ``min(num_plots, num_samples)`` samples are drawn, so a small tracker
            no longer raises IndexError.
        out_dir: destination folder; defaults to
            ``<cwd>/Results/Experiments/01_logit_evolution``.

    Each figure shows the target logit (thick blue), the mean of the non-target
    logits (thick coral), their min-max band and every individual non-target curve.
    Files are named ``sample_{i}.pdf``.
    """
    if out_dir is None:
        out_dir = os.getcwd() + "/Results/Experiments/01_logit_evolution"

    cmap = plt.get_cmap("Oranges")
    num_logits = activations_tracker.shape[2]
    # One shade per non-target logit, moving further along the colormap for later positions.
    colors = [cmap(0.3 + 0.5 * j / num_logits) for j in range(1, num_logits)]
    epochs = np.arange(activations_tracker.shape[0])

    for i in range(min(num_plots, activations_tracker.shape[1])):
        sample_i = activations_tracker[:, i, :]
        others = sample_i[:, 1:]

        fig, ax = plt.subplots()
        ax.plot(epochs, sample_i[:, 0].numpy(),
                color='lightblue', linewidth=4, label="Target neuron")

        # With a single logit there are no non-target curves (and no colors) to draw.
        if others.shape[1] > 0:
            ax.plot(epochs, others.mean(dim=1).numpy(),
                    color="coral", linewidth=4, label="Non-target neurons (mean)")
            ax.fill_between(
                epochs,
                others.min(dim=1).values.numpy(),
                others.max(dim=1).values.numpy(),
                color=colors[0],
                alpha=0.2,  # Transparency (0=invisible, 1=opaque)
                label="Non-target neurons (variation area)",
            )
            for j in range(others.shape[1]):
                ax.plot(epochs, others[:, j].numpy(),
                        color=colors[j],
                        alpha=0.7,  # Slight transparency for lines
                        linewidth=1.5)

        ax.set_xlabel("Epochs")
        ax.set_ylabel("Logits")
        ax.legend()

        fig.savefig(os.path.join(out_dir, f"sample_{i}.pdf"),
                    bbox_inches='tight', format='pdf')
        # Close explicitly so looping over many samples does not accumulate figures.
        plt.close(fig)

def activation_evolution_statistics(activations_tracker, mean_acc, log_path=None):
    """Append the mean epoch-to-epoch logit change and the model accuracy to a log file.

    Args:
        activations_tracker: tensor indexed as [epoch, sample, logit] with the target
            logit at position 0 of the last axis.
        mean_acc: accuracy value written verbatim on the "Model Accuracy" line.
        log_path: file to append to; defaults to
            ``<cwd>/Results/Experiments/01_logit_evolution/logits_evolution.txt``.

    Four lines are appended (accuracy, mean target change, mean non-target change,
    blank). The text format is unchanged because the notebook's parsing cell
    reads the value after the final ':' of each line.
    """
    if log_path is None:
        log_path = os.getcwd() + "/Results/Experiments/01_logit_evolution/logits_evolution.txt"

    # Entry t of the diff is tracker[t + 1] - tracker[t]: the change made by one epoch.
    activ_diff = torch.diff(activations_tracker, dim=0)
    target_change = torch.mean(activ_diff[:, :, 0]).item()
    other_change = torch.mean(activ_diff[:, :, 1:]).item()

    with open(log_path, "a+") as f:
        f.write(f"Model Accuracy: {mean_acc} \n")
        f.write(f'For all samples, the mean of best value changes as: {target_change} \n')
        f.write(f'For all samples, the mean of other values change as: {other_change} \n')
        f.write("\n")
        f.write("\n")

In [None]:
class SimpleTrainer(BaseTrainer):
    """Trainer that snapshots the logits of a fixed sample set after every epoch.

    Attributes read here but not assigned in this class (``model``, ``optimizer``,
    ``scheduler``, ``scaler``, ``device``, ``epochs``, ``train_load``,
    ``test_load``, ``train_logs``, ``forward``, ``X``, ``labels``) are presumably
    provided by ``BaseTrainer``; verify against its implementation.
    """

    def __init__(self, model, data, config, num_samples):

        # NOTE(review): BaseTrainer is called as (model, config, data, ...), a different
        # order from this constructor's (model, data, config, ...); confirm its signature.
        super().__init__(model, config, data, num_samples, save_results=False)

        # Device *type* for autocast ("cuda"/"cpu"); tensors are moved with self.device.
        # NOTE(review): assumes self.device resolves to the same device type.
        self.dev = 'cuda' if torch.cuda.is_available() else "cpu"
        self.loss = nn.CrossEntropyLoss()

    def _tracked_logits(self):
        """Target-first logits of the tracked samples, taken in eval mode without autograd."""
        # Eval mode keeps dropout/batch-norm (if present) from perturbing the snapshot or
        # updating running statistics. no_grad avoids building a graph over every sample.
        self.model.eval()
        with torch.no_grad():
            return track_model_activations(self.model, self.X, self.labels).detach().cpu()

    def _evaluate(self, epoch):
        """Mean test accuracy (float64) over the samples actually yielded by test_load."""
        self.model.eval()
        test_bar = tqdm(self.test_load,
                        desc=f"Epoch: {epoch}/{self.epochs}",
                        total=len(self.test_load))
        # Collect the per-sample results instead of filling a buffer sized to the dataset.
        # If the loader yields fewer samples (e.g. drop_last), the leftover zeros would
        # otherwise be counted as wrong predictions.
        per_sample = [self.validate_on_batch(batch).cpu().numpy() for batch in test_bar]
        if not per_sample:
            return float("nan")
        return np.mean(np.concatenate(per_sample), dtype=np.float64)

    def train(self):
        """Train for ``self.epochs`` epochs, snapshotting tracked logits after each one.

        Prints and logs (``self.train_logs``) the test accuracy per epoch. At the end
        it passes the logit history and the last epoch's accuracy to
        ``activation_evolution_statistics``. The accuracy is NaN if no epoch ran.
        """
        initial = self._tracked_logits()
        # Shape (epochs + 1, num_tracked_samples, num_classes). The class count comes
        # from the model output instead of being hard-coded to 10.
        activ_changes = torch.zeros((self.epochs + 1, *initial.shape))
        activ_changes[0] = initial

        final_acc = float("nan")
        for epoch in range(1, self.epochs + 1):

            self.model.train()
            train_bar = tqdm(self.train_load,
                             desc=f"Epoch: {epoch}/{self.epochs}",
                             total=len(self.train_load))
            for batch in train_bar:
                self.train_on_batch(batch)

            self.scheduler.step()
            activ_changes[epoch] = self._tracked_logits()

            final_acc = self._evaluate(epoch)
            print(f"Epoch {epoch}, Accuracy: {final_acc}")
            self.train_logs.append(final_acc)

        # To also save the per-sample plots: activation_evolution(activ_changes)
        activation_evolution_statistics(activ_changes, final_acc)

    def apply_loss(self, y_pred, y):
        """Cross-entropy between flattened logits and flattened class-index targets."""
        y_pred = y_pred.reshape((-1, y_pred.shape[-1]))
        # CrossEntropyLoss needs int64 class indices. The previous check
        # `y[0].dtype == torch.int or torch.float32` was always true, and it crashed on an
        # empty batch. Casting unconditionally keeps that effective behavior.
        y = y.reshape((-1,)).long()

        return self.loss(y_pred, y)

    def train_on_batch(self, batch):
        """One mixed-precision optimisation step on an (images, labels) batch; returns the loss."""
        images, labels = batch[0], batch[1]
        images = images.to(self.device)
        labels = labels.to(self.device)

        self.optimizer.zero_grad()
        with autocast(device_type=self.dev):  # Enable mixed-precision

            output = self.forward(images)
            l = self.apply_loss(output, labels)

        # Gradient scaling (via self.scaler) around backward/step for mixed precision.
        self.scaler.scale(l).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return l

    def validate_on_batch(self, batch):
        """Per-sample 0/1 correctness (float32 tensor) for an (images, labels) batch."""
        images, labels = batch[0], batch[1]
        images = images.to(self.device)
        labels = labels.to(self.device)

        with torch.no_grad():
            output = self.forward(images)

        return accuracy(output, labels, averaged=False)

In [None]:
# Repeat the experiment with independently initialised models. Every run appends
# one entry to logits_evolution.txt, which the following cells aggregate.
NUM_RUNS = 10
layers = config['model_layers']

# The statistics log is opened in append mode, so remove entries left by earlier
# executions of this notebook. Otherwise they would be mixed into the aggregates below.
log_file = os.getcwd() + "/Results/Experiments/01_logit_evolution/logits_evolution.txt"
if os.path.exists(log_file):
    os.remove(log_file)

for run in range(NUM_RUNS):
    # Seed torch per run: runs differ from each other, but model init is reproducible.
    torch.manual_seed(run)
    model = get_model_architecture(config['model_name'], layers + [config["num_classes"]])
    trainer = SimpleTrainer(model, data_loader, config, num_samples)
    trainer.train()

In [3]:
# Read back the per-run values appended by activation_evolution_statistics.
log_file = os.getcwd() + "/Results/Experiments/01_logit_evolution/logits_evolution.txt"
with open(log_file, 'r') as f:
    lines = f.readlines()


def _parse_log_values(log_lines, key):
    """Floats after the last ':' on every line containing ``key``.

    ``float()`` ignores surrounding whitespace. This replaces the old ``[1:-2]``
    slicing, which relied on an exact " \\n" suffix and dropped a digit from a
    final line that had no trailing newline.
    """
    return [float(line.rsplit(":", 1)[1]) for line in log_lines if key in line]


acc_scores = _parse_log_values(lines, "Model Accuracy")
targ_values = _parse_log_values(lines, "best value")
notarg_values = _parse_log_values(lines, "other values")

In [4]:
# Mean and sample standard deviation (ddof=1) across runs for each logged quantity.
summary_rows = (
    ("Mean Accuracy: ", "STD Accuracy: ", acc_scores),
    ("Mean Logit Target Update: ", "STD Logit Target Update: ", targ_values),
    ("Mean Logit Non-Target Update: ", "STD Logit Non-Target Update: ", notarg_values),
)
for mean_label, std_label, values in summary_rows:
    print(mean_label, np.mean(values))
    print(std_label, np.std(values, ddof=1))

Mean Accuracy:  0.8831605351170568
STD Accuracy:  0.053661651794216995
Mean Logit Target Update:  0.653751540184021
STD Logit Target Update:  0.03661135019573068
Mean Logit Non-Target Update:  -0.06500565763562918
STD Logit Non-Target Update:  0.029020444148339045
