![fig](../../img/kde_learning_single_representation.png)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import numpy as np
from sklearn.neighbors import KernelDensity

In [None]:
def train(model, train_loader, criterion, optimizer, device):
    """Run a single optimisation epoch over ``train_loader``.

    Args:
        model: network to train; switched to training mode first.
        train_loader: iterable of ``(inputs, targets)`` batches that also
            exposes ``.dataset`` (used to normalise the summed loss).
        criterion: loss function called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to.

    Returns:
        tuple: (sum of per-batch loss * batch size, divided by
        ``len(train_loader.dataset)``; top-1 accuracy in percent over the
        batches seen).
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # weight by batch size so the final value is a per-sample mean
        running_loss += batch_loss.item() * batch_x.size(0)
        predictions = logits.max(dim=1).indices
        n_seen += batch_y.size(0)
        n_correct += (predictions == batch_y).sum().item()

    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss, 100. * n_correct / n_seen


def evaluate(device, model, dataloader):
    """Compute top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking.

    Args:
        device: device each batch is moved to.
        model: classifier returning one logit row per input.
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        float: ``100 * correct / total``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).max(dim=1).indices
            seen += targets.size(0)
            hits += (predicted == targets).sum().item()
    return 100 * hits / seen

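In this experiment we use the representations of each sample extracted from different layers of the network: each sample is figuratively represented as a point in a high-dimensional space. The activations a sample produces in the earlier layers are pooled into a single set of values, a Gaussian kernel density estimate (KDE) is fitted to that set, and a new activation vector sampled from the KDE is injected into the input of the next layer.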

In [None]:

def extract_activations_layers(layers):
    """ Extract for each layer the activations

    Args:
        layers (np.array): shape (layer_activation, batch_size, number_of_neurons)

    Returns:
        np.array: shape (layer_activation, batch_size, number_of_activations)
    """

    return np.array([np.array([np.array(h) for h in l]) for l in layers])

def extract_activations_per_sample(layers, mask=False):
    """Regroup layer-major activations into one column vector per sample.

    For sample ``i`` the activations of every layer are concatenated
    (layer 0 first, then layer 1, ...) and reshaped into a column of shape
    ``(n_values, 1)``, the 2-D (n_samples, n_features) layout that
    ``KernelDensity.fit`` expects.

    Args:
        layers (array-like): shape (n_layers, batch_size, number_of_neurons).
            Lists are accepted and converted with ``np.asarray``.
        mask (bool): if True, drop exactly-zero activations from each
            sample's column.

    Returns:
        np.ndarray: without masking, a float array of shape
        (batch_size, n_layers * number_of_neurons, 1). With masking, the
        columns can have different lengths, so a 1-D object array of length
        batch_size is returned; its i-th entry has shape (n_nonzero_i, 1).
        A sample whose activations are all zero gets a (0, 1) column, which
        a downstream KDE fit will likely reject.
    """
    layers = np.asarray(layers)
    per_sample = [layers[:, i, :].reshape(-1, 1) for i in range(layers.shape[1])]
    if not mask:
        return np.array(per_sample)

    # Masking makes the columns ragged, so they cannot be stacked into a
    # regular ndarray; store each one in an object array instead.
    masked = np.empty(len(per_sample), dtype=object)
    for i, column in enumerate(per_sample):
        masked[i] = column[column[:, 0] != 0]
    return masked


def get_sampled_activations(activations, bandwidth=0.2, n_samples=None, random_state=None):
    """Fit one Gaussian KDE per sample and draw a fresh activation vector from it.

    Args:
        activations (iterable of np.ndarray): one 2-D array per sample, shaped
            (number_of_activations, 1) as ``KernelDensity.fit`` expects
            (n_samples, n_features).
        bandwidth (float): Gaussian kernel bandwidth.
        n_samples (int, optional): number of values drawn per sample. Defaults
            to the module-level ``n_neurons`` so the result matches the width
            of the hidden layers it is fed into.
        random_state (int | np.random.RandomState | None): seed or generator
            for the draws. ``None`` (default) uses NumPy's global generator.
            An int is turned into a single ``RandomState`` shared across all
            samples, so different samples do not receive identical draws.

    Returns:
        torch.Tensor: float32 CPU tensor of shape (batch_size, n_samples).
    """
    if n_samples is None:
        n_samples = n_neurons  # module-level hidden width defined elsewhere in this file
    activations = list(activations)
    if len(activations) == 0:
        # np.array([]) is 1-D, so squeeze(2) below would fail; return an empty batch.
        return torch.empty((0, n_samples), dtype=torch.float32)
    if random_state is not None and not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    draws = [
        KernelDensity(kernel="gaussian", bandwidth=bandwidth)
        .fit(a)
        .sample(n_samples=int(n_samples), random_state=random_state)
        for a in activations
    ]
    # Each draw is (n_samples, 1): stack to (batch, n_samples, 1), then drop the feature axis.
    return torch.from_numpy(np.array(draws, dtype="float32")).squeeze(2)

def wd(layers: list, bandwidth: float = 0.2, mask: bool = False):
    """Resample per-sample activations from a KDE fitted on earlier layers' outputs.

    This does not compute a weight-decay penalty. The activations recorded
    for each sample in ``layers`` are stacked
    (``extract_activations_layers``) and pooled across layers into one column
    per sample (``extract_activations_per_sample``). A Gaussian KDE is then
    fitted to each column and a new activation vector is drawn from it
    (``get_sampled_activations``).

    Args:
        layers (list): per-layer activation batches, each of shape
            (batch_size, number_of_neurons).
        bandwidth (float): Gaussian KDE bandwidth.
        mask (bool): drop exactly-zero activations before fitting the KDE.

    Returns:
        torch.Tensor: float32 CPU tensor of shape (batch_size, n_draws), where
        n_draws is ``get_sampled_activations``'s default sample count.
        Callers must move it to the device they need.
    """
    stacked = extract_activations_layers(layers)
    per_sample = extract_activations_per_sample(stacked, mask=mask)
    return get_sampled_activations(per_sample, bandwidth=bandwidth)

n_neurons = 64  # width of every hidden layer (also the default KDE draw count)


class MLPWD(nn.Module):
    """Four-layer MLP whose middle layers receive KDE-resampled activations.

    Layout: flatten -> l1 (784 -> n_neurons) -> GELU -> l2 (LinW, depth 0)
    -> GELU -> l3 (LinW, depth 1) -> GELU -> l4 (n_neurons -> 10 logits).
    The 784 inputs presumably correspond to flattened 28x28 images.
    Indexing / ``len`` on the model expose the two LinW layers.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(784, n_neurons)
        self.l2 = LinW(in_features=n_neurons, out_features=n_neurons, depth=0)
        self.l3 = LinW(in_features=n_neurons, out_features=n_neurons, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(n_neurons, 10)
        self.gelu = nn.GELU()
        # plain list: l2/l3 are already registered as submodules above
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        """Return class logits.

        Post-GELU outputs of l1 and l2 are recorded as detached NumPy copies
        (no gradient flows through them). l2 sees only l1's snapshot and l3
        sees both.
        """
        snapshots = []

        def _record(t):
            snapshots.append(t.detach().cpu().numpy())
            return t

        h = _record(self.gelu(self.l1(self.flatten(x))))
        h = _record(self.gelu(self.l2(h, snapshots)))
        h = self.gelu(self.l3(h, snapshots))
        return self.l4(h)

    def __getitem__(self, idx):
        """Return the ``idx``-th LinW layer."""
        return self.layers[idx]

    def __len__(self):
        """Number of LinW layers."""
        return len(self.layers)
    

class LinW(nn.Linear):
    """Linear layer whose input is modulated by KDE-resampled activations.

    ``forward(input, prev)`` draws a vector via ``wd(prev)`` from the
    activations of earlier layers and combines it with ``input`` according
    to ``mode`` before applying the affine map:

    * ``"shift"`` (default): ``linear(input + sampled)``
    * ``"scale"``:           ``linear(input * sampled)``
    * ``"replace"``:         ``linear(sampled)``

    When ``prev`` is ``None`` or empty the layer behaves like ``nn.Linear``.
    ``depth`` and ``layers`` are stored for inspection only; ``forward`` does
    not use them.
    """

    _MODES = ("shift", "scale", "replace")

    def __init__(self, in_features, out_features, depth, layers=None, mode="shift"):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        super(LinW, self).__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        self.mode = mode
        # Fresh list so instances never share a default list or alias the caller's list.
        self.layers = list(layers[:depth]) if layers else []

    def forward(self, input, prev=None):
        """Apply the modulated linear map.

        Args:
            input (torch.Tensor): layer input; its last dim must equal
                ``in_features`` (and the width of the ``wd`` draws).
            prev (list, optional): activation snapshots of earlier layers,
                forwarded to ``wd``.
        """
        if prev is None or len(prev) == 0:
            return F.linear(input, self.weight, self.bias)
        # Follow the input's device/dtype instead of a hard-coded 'cuda:0',
        # which crashed on the CPU fallback and on any other GPU.
        sampled = wd(prev).to(device=input.device, dtype=input.dtype)
        if self.mode == "shift":
            x = input + sampled
        elif self.mode == "scale":
            x = input * sampled
        else:
            x = sampled
        return F.linear(x, self.weight, self.bias)

EPOCHS = 10
BATCH_SIZE = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both MNIST splits use the same stateless ToTensor transform.
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=to_tensor)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=to_tensor)

train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False),
)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr_scheduler.step() runs once per epoch, so with step_size=10
# and EPOCHS=10 the 0.5 decay only fires after the last epoch and never
# affects training. Confirm this is intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

linw_lines = [f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]
print("LinW layers:", "\n".join(linw_lines), sep="\n\n")

for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(
        f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, '
        f'Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%'
    )
    lr_scheduler.step()
