# AlexNet Places 365 Neuron Disentanglement

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import copy

import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns

import torch
import torchvision
from torch import nn

from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import Dataset, Subset, DataLoader

from sklearn.decomposition import PCA

In [None]:
import os

# Work from the compexp `vision` directory: the relative entries added to
# sys.path below ('' and './my') are resolved against the working directory.
os.chdir('/notebooks/compexp/vision/')
print(os.getcwd())

import sys

# compexp: an empty sys.path entry means "the current working directory".
# NOTE(review): `settings` and `loader` are presumably found through it - confirm.
sys.path.append('')
import settings
from loader.model_loader import loadmodel


# The rosettastone package is looked up in the local `./my` directory.
sys.path.append('./my')
from rosettastone import maxact
from rosettastone.disentanglement import disentanglenet, splitting_operation, succeeding_operation
from rosettastone.utils import InfiniteDataLoader, freezer, forward_up_to

In [None]:
import wandb

# Log in to Weights & Biases; the sweep cells further down rely on wandb.
wandb.login()

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Data & Model

In [None]:
# Build the network with compexp's loader, passing None as its only argument.
# NOTE(review): which architecture/weights are loaded is presumably decided by
# `settings` - confirm against loader.model_loader.
model = loadmodel(None)

def normalize_image(rgb_image):
    """Convert an image into the normalized float tensor fed to the model.

    The image is read as a float32 ``(H, W, C)`` array, its second axis is
    reversed, a fixed three-value offset is subtracted along the last axis,
    the last axis is moved to the front and everything is divided by
    ``255.0 * 0.224``.

    Args:
        rgb_image: Anything ``np.array`` can turn into an ``(H, W, 3)`` array
            (the offset has three entries, so the last axis must be 3).

    Returns:
        A float32 ``torch.Tensor`` laid out as ``(3, H, W)``.
    """
    channel_offsets = [109.5388, 118.6897, 124.6901]
    scale = 255.0 * 0.224

    # np.array already returns a private copy, so the in-place subtraction
    # below never touches the caller's data.
    pixels = np.array(rgb_image, dtype=np.float32)

    # NOTE(review): reversing axis 1 of an (H, W, C) array mirrors the image
    # left-to-right; it does not reorder the colour channels. Confirm this is
    # intended before changing it - results computed downstream depend on it.
    mirrored = pixels[:, ::-1]
    mirrored -= channel_offsets

    channels_first = mirrored.transpose((2, 0, 1))

    # A further reversal of the channel axis (RGB -> BGR) at this point was
    # tried and left out: accuracy was better without it, for reasons that
    # were not established (it may only matter for ADE20K).

    # The array has a negative stride at this point, which torch.from_numpy
    # rejects, hence the copy.
    tensor = torch.from_numpy(channels_first.copy())
    return tensor.div_(scale)

# Places365 validation split (small-image variant); every sample is passed
# through normalize_image when it is loaded.
place365_dataset = torchvision.datasets.Places365(
    root='/notebooks/compexp/vision/my/data',
    split='val',
    small=True,
    transform=transforms.Lambda(normalize_image),
)

# Fixed (unshuffled) order, so positions in harvested activations line up
# with dataset indices.
place365_dataloader = torch.utils.data.DataLoader(
    place365_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=8,
)

# Last expression of the cell: displays the model's layer listing.
model

## Neuron on the Table: Last Conv2D layer @ 31
## Polysemantic: Buildings & People

In [None]:
# Index of the neuron (channel) studied in the 'last_conv' layer below.
NEURON = 31
# Number of top-activating images requested for that neuron.
TOP_K = 50

In [None]:
# Harvest activations over the Places365 loader for the layer selected by the
# getter (model.features[10]), stored under the key 'last_conv'. No gradients
# are needed here.
with torch.no_grad():
    orig_activations = maxact.harvest_activations(model, place365_dataloader, {'last_conv': lambda m: m.features[10]})

# Visualize the neuron with top_k=TOP_K, without the cropped and lucent views.
maxact.visualize_neuron('last_conv', NEURON, model, orig_activations, place365_dataset,
                        with_cropped=False, with_lucent=False, top_k=TOP_K)
    
# Six values are returned; three are kept.
# NOTE(review): presumably the top activations, the dataset indices of the
# corresponding images, and their layer representations - confirm in maxact.
(top_img_acts,
 top_img_indice,
 _,
 top_reprs_last_conv,
 _, _) = maxact.get_neuron_max_activations(orig_activations, place365_dataset, 'last_conv', NEURON, top_k=TOP_K)

# Select the rows at top_img_indice, average each over dims 2 and 3, then
# project the resulting per-image vectors onto two principal components.
top_last_conv_repers = (PCA(n_components=2, random_state=0)
                        .fit_transform(orig_activations['last_conv'][top_img_indice, :]
                                       .mean(dim=(2, 3))))

# Drop the reference so the harvested activations can be freed.
del orig_activations

### Manual labeling of the top-50 max-activation images into **concepts**.

In [None]:
# manually extracted
y_concept = np.array(['building'] * 9 + ['people'] * 2
                     + ['building'] + ['people'] + ['building']
                     + ['people']*3 + ['building'] * 3
                     + ['other'] + ['people'] + ['building'] * 3
                     + ['people'] * 7 + ['building'] * 3 + ['people']
                     + ['building'] * 3 + ['people'] +['building'] * 2
                     + ['people'] * 2 + ['building'] * 3 + ['people'] * 2 + ['building'])


first_concept = 'building'
second_concept = 'people'

concepts = {first_concept, second_concept}

concept_datasets = {c: Subset(place365_dataset, top_img_indice[y_concept == c])
                    for c in concepts}

concept_dataloaders = {c: InfiniteDataLoader(ds, batch_size=8)
                      for c, ds in concept_datasets.items()}

concept_probs = {c: DataLoader(ds, batch_size=len(ds))
                for c, ds in concept_datasets.items()}

### Representations and Separation

#### Are the concepts linearly separated in the last conv layer representation? Yes!

In [None]:
# Scatter the two PCA components of the last-conv representations, coloured by
# the manual concept label (the trailing ';' hides the returned object in the notebook).
sns.scatterplot(x=top_last_conv_repers[:, 0], y=top_last_conv_repers[:, 1], hue=y_concept);

#### Are the concepts linearly separated in the penultimate conv layer representation? Yes!

In [None]:
# Harvest activations for the layer selected by the getter (model.features[9]),
# stored under the key 'penultimate_conv'. apply_fn averages over dims 2 and 3.
# NOTE(review): apply_fn is presumably applied to each harvested batch - confirm in maxact.
with torch.no_grad():
    orig_activations = maxact.harvest_activations(model, place365_dataloader, {'penultimate_conv': lambda m: m.features[9]},
                                                 apply_fn=lambda x:x.mean(dim=(2,3)))
    
# Project the rows of the top images onto two principal components.
top_penultimate_conv_repers = PCA(n_components=2, random_state=0).fit_transform(orig_activations['penultimate_conv'][top_img_indice, :])

# Drop the reference so the harvested activations can be freed.
del orig_activations

# Scatter the two components, coloured by the manual concept label.
sns.scatterplot(x=top_penultimate_conv_repers[:, 0], y=top_penultimate_conv_repers[:, 1], hue=y_concept)

#### Are the concepts linearly separated in the first conv layer representation? Yes!

In [None]:
# Harvest activations for the layer selected by the getter (model.features[0]),
# stored under the key 'first_conv'. apply_fn averages over dims 2 and 3.
# NOTE(review): apply_fn is presumably applied to each harvested batch - confirm in maxact.
with torch.no_grad():
    orig_activations = maxact.harvest_activations(model, place365_dataloader, {'first_conv': lambda m: m.features[0]},
                                                  apply_fn=lambda x:x.mean(dim=(2,3)))
    
# Project the rows of the top images onto two principal components.
top_first_conv_repers = PCA(n_components=2, random_state=0).fit_transform(orig_activations['first_conv'][top_img_indice, :])

# Drop the reference so the harvested activations can be freed.
del orig_activations

# Scatter the two components, coloured by the manual concept label.
sns.scatterplot(x=top_first_conv_repers[:, 0], y=top_first_conv_repers[:, 1], hue=y_concept)

## Define `DisentangleNet` Class

In [None]:
class AlexDisentangleNet(nn.Module):
    """Hold an untouched copy of a network next to a copy with one neuron split.

    Two deep copies of ``net`` are kept: ``orig_net`` is left as is, while in
    ``splitted_net`` the layer ``features[10]`` is replaced by the result of
    ``splitting_operation`` and ``classifier[1]`` by the result of
    ``succeeding_operation``. Throughout this class the added neuron is read
    as the last channel (index ``-1``) of the split layer.

    Args:
        net: Module exposing ``features``, ``avgpool`` and ``classifier``,
            where ``features[10]`` has ``out_channels``, ``weight`` and
            ``bias`` and ``classifier[1]`` has ``weight`` and ``bias``.
        neuron: Index of the channel of ``features[10]`` to split.
        new_neuron_noise_std: Forwarded to ``splitting_operation``.
            NOTE(review): presumably the std of noise added to the new
            neuron's parameters - confirm in rosettastone.
        device: Forwarded to both splitting helpers.
    """

    def __init__(self, net, neuron, new_neuron_noise_std, device):
        super().__init__()

        self.orig_net = copy.deepcopy(net)
        self.splitted_net = copy.deepcopy(net)

        self.neuron = neuron
        self.device = device
        self.new_neuron_noise_std = new_neuron_noise_std

        # Filled in by _prepare_splitted().
        self.orig_layer_forward_fn = None
        self.splitted_layer_forward_fn = None
        self.layer_getter = None

        # Which copy forward() runs when no explicit choice is given.
        self.active_net = 'orig'

        self._prepare_splitted()

    def _prepare_splitted(self):
        """Freeze both copies, then install the split layers in ``splitted_net``."""
        # Freezing happens first; the two layers installed below are new
        # objects returned by the helpers, so they are not touched by it.
        freezer(self.orig_net)
        freezer(self.splitted_net)

        self.splitted_net.features[10], _ = splitting_operation(
            self.orig_net.features[10],
            self.neuron,
            self.new_neuron_noise_std,
            self.device)

        self.splitted_net.classifier[1], _ = succeeding_operation(
            self.orig_net.classifier[1],
            self.neuron,
            self.orig_net.features[10].out_channels,
            self.device)

        self.orig_layer_forward_fn = forward_up_to(self.orig_net.features, 10)
        self.splitted_layer_forward_fn = forward_up_to(self.splitted_net.features, 10)
        self.layer_getter = lambda m: m.features[10]

        # Eval mode turns dropout off. It also switches any BatchNorm layer to
        # inference behaviour, which is not desirable here.
        self.eval()

    def forward(self, x, active_net=None):
        """Run ``x`` through one of the two copies.

        Args:
            x: Input batch.
            active_net: ``'orig'`` or ``'splitted'``; ``None`` falls back to
                ``self.active_net``.

        Returns:
            The output of the selected network.

        Raises:
            AssertionError: If the selection is neither ``'orig'`` nor
                ``'splitted'``.
        """
        if active_net is None:
            active_net = self.active_net
        assert active_net in ('orig', 'splitted'), (
            "active_net must be 'orig' or 'splitted'")
        if active_net == 'orig':
            return self.orig_net(x)
        return self.splitted_net(x)

    @staticmethod
    def _first_linear_output(net, x):
        """Run ``x`` through ``net`` up to and including ``classifier[1]``."""
        out = net.features(x)
        out = net.avgpool(out)
        out = torch.flatten(out, 1)
        # classifier[0] (dropout) is deliberately skipped.
        return net.classifier[1](out)

    def generate_losses(self, x_indff, x_cnpt1, x_cnpt2):
        """Compute the loss terms used to train the split.

        Args:
            x_indff: Batch on which both copies should keep agreeing.
            x_cnpt1: Batch of images of the first concept.
            x_cnpt2: Batch of images of the second concept.

        Returns:
            A 7-tuple, in this order:

            * ``indff_loss`` - element-wise squared difference between the
              two copies' ``classifier[1]`` outputs on ``x_indff``,
              flattened to ``(batch, -1)``;
            * ``spc11_loss`` - per-sample mean squared activation of channel
              ``self.neuron`` on ``x_cnpt1``;
            * ``spc22_loss`` - the same for the last channel on ``x_cnpt2``;
            * ``spc12_loss`` - the same for the last channel on ``x_cnpt1``;
            * ``spc21_loss`` - the same for channel ``self.neuron`` on
              ``x_cnpt2``;
            * ``wd1_loss`` - mean squared weight of channel ``self.neuron``
              of the split layer;
            * ``wd2_loss`` - mean squared weight of its last channel.
        """
        x_indff_orig = self._first_linear_output(self.orig_net, x_indff)
        x_indff_splitted = self._first_linear_output(self.splitted_net, x_indff)

        indff_loss = torch.flatten(
            F.mse_loss(x_indff_orig, x_indff_splitted, reduction='none'), 1)

        # Activations of the split layer for each concept batch.
        x_cnpt1 = self.splitted_layer_forward_fn(x_cnpt1)
        x_cnpt2 = self.splitted_layer_forward_fn(x_cnpt2)

        spc11_loss = (x_cnpt1[:, self.neuron, :, :]**2).mean(dim=(1, 2))
        spc12_loss = (x_cnpt1[:, -1, :, :]**2).mean(dim=(1, 2))

        spc21_loss = (x_cnpt2[:, self.neuron, :, :]**2).mean(dim=(1, 2))
        spc22_loss = (x_cnpt2[:, -1, :, :]**2).mean(dim=(1, 2))

        split_weight = self.splitted_net.features[10].weight
        wd1_loss = (split_weight[self.neuron, :, :, :]**2).mean()
        wd2_loss = (split_weight[-1, :, :, :]**2).mean()

        # Note the order: spc22 comes before spc12.
        return (indff_loss,
                spc11_loss,
                spc22_loss,
                spc12_loss,
                spc21_loss,
                wd1_loss,
                wd2_loss)

    def equality_report(self):
        """Print element-wise comparisons between the original and split layers.

        Each printed line starts with the relation its label names (``==`` or
        ``!=``) followed by a boolean tensor:

        * on ``==`` lines the value is whether ALL compared elements are
          equal;
        * on ``!=`` lines the value is whether ANY compared element is equal.
        """
        neuron = self.neuron
        orig_conv = self.orig_net.features[10]
        split_conv = self.splitted_net.features[10]
        orig_fc = self.orig_net.classifier[1]
        split_fc = self.splitted_net.classifier[1]

        # Number of classifier input columns fed by one conv channel
        # (6 * 6 = 36 for AlexNet), derived from the layer shapes.
        span = orig_fc.weight.shape[1] // orig_conv.out_channels
        lo, hi = span * neuron, span * (neuron + 1)

        print()
        print('==', (orig_conv.bias[:neuron] == split_conv.bias[:neuron]).all())
        print('==', (orig_conv.bias[neuron + 1:] == split_conv.bias[neuron + 1:-1]).all())
        print('==', (orig_conv.bias[neuron] == split_conv.bias[neuron]).all())
        print('==', (orig_conv.bias[neuron] == split_conv.bias[-1]).all())
        print()
        # The '==' checks below use .all(): with .any() a single coinciding
        # element would be reported as equality of the whole slice.
        if neuron:
            print('==', (orig_fc.weight[:, :lo] == split_fc.weight[:, :lo]).all())
        print('==', (orig_fc.weight[:, hi:] == split_fc.weight[:, hi:-span]).all())
        print('!=', (orig_fc.weight[:, lo:hi] == split_fc.weight[:, lo:hi]).any())
        print('!=', (split_fc.weight[:, lo:hi] == split_fc.weight[:, -span:]).any())
        print('==', (orig_fc.bias == split_fc.bias).all())
        print()

## Run the Disentanglement!

In [None]:
# Single disentanglement run for NEURON with hand-picked hyperparameters.
# The same Places365 loader is passed for both loader arguments.
# NOTE(review): the roles of the individual arguments (e.g. which loader is
# train vs. test, what alpha/beta weight) are defined in
# rosettastone.disentanglement - confirm there.
disentanglenet(AlexDisentangleNet, model, NEURON,
               place365_dataset, place365_dataloader, place365_dataloader,
               first_concept, second_concept, concept_probs, concept_dataloaders,
               epochs=1, lr=1e-3, alpha=1e-3, beta=1e-2, new_neuron_noise_std=1e-4,
               device=device)

### Hyperparameter Tuning

In [None]:
# Weights & Biases sweep definition: random search over the listed values,
# maximizing the logged metric 'test_accuracy_splitted'.
sweep_config = {
    'method': 'random',  # alternative: 'grid'
    'metric': {
      'name': 'test_accuracy_splitted',
      'goal': 'maximize'
    },
    'parameters': {
        'epochs': {
            'values': [1, 2, 5, 10]
        },
        'alpha': {
            'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5]
        },
        'beta': {
            'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5]
        },
        'lr': {
            'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5]
        },
        # 0 is included here, unlike in the other grids.
        'new_neuron_noise_std': {
            'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5, 0]
        }
    }
}

# Register the sweep under the "disentanglement" project and keep its id for the agent.
sweep_id = wandb.sweep(sweep_config, project="disentanglement")

In [None]:
def sweep_trainer():
    """Run one non-verbose disentanglement of NEURON; used as the sweep callback.

    NOTE(review): the hyperparameters below are fixed in code. Confirm that
    `disentanglenet` replaces them with the values chosen by the wandb sweep;
    otherwise every sweep run trains with identical settings.
    """
    hyperparameters = dict(
        epochs=1,
        lr=1e-3,
        alpha=1e-3,
        beta=1e-2,
        new_neuron_noise_std=1e-4,
    )
    disentanglenet(
        AlexDisentangleNet,
        model,
        NEURON,
        place365_dataset,
        place365_dataloader,
        place365_dataloader,
        first_concept,
        second_concept,
        concept_probs,
        concept_dataloaders,
        **hyperparameters,
        verbose=False,
        device=device,
    )

In [None]:
# Start a sweep agent that calls sweep_trainer for the runs of the sweep identified by sweep_id.
wandb.agent(sweep_id, sweep_trainer)