In [36]:
import os
import os.path as osp
import pickle
import sklearn
from tqdm import tqdm
from ipywidgets import interact 
from IPython.display import clear_output

import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import clip
from  torchvision import transforms as T

from artemis.emotions import ARTEMIS_EMOTIONS
import EmotionPredictor.data_tools as dt
import  EmotionPredictor.visual as visual
from EmotionPredictor.data_tools import get_loaders
from EmotionPredictor.training import Trainer, SLP
from EmotionPredictor.visual import save_fig

%config InlineBackend.figure_format ='retina'

# Layer analysis

## Feature extraction
In this section, a forward hook is attached to every bottleneck block of each layer of CLIP's ResNet-50 in order to save its averaged output.

In [38]:
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load CLIP's RN50 checkpoint and switch it to evaluation mode; the second
# element returned by clip.load is not used in this notebook.
resnet_clip, _ = clip.load("RN50")
resnet_clip = resnet_clip.eval()

# Visual tower of CLIP and a name -> sub-module lookup of its direct children.
resnet_clip50 = resnet_clip.visual
resnet50_dict = {
    child_name: child for child_name, child in resnet_clip50.named_children()
}

In [39]:
def retrieve_backup(backup_id):
    """Return the first module-level object whose ``id()`` equals ``backup_id``.

    ``HooksBackup`` prints its own id when it is created, so this helper lets
    you get the instance back from that number, provided the object is still
    bound to a global name.

    Args:
        backup_id: integer as returned by ``id(obj)``.

    Returns:
        The first matching value found in ``globals()``.

    Raises:
        IndexError: if no global object has this id (same exception type as
            before, now with an explicit message).
    """
    # Snapshot the values so the scan cannot be disturbed by a global being
    # (re)bound while iterating.
    for candidate in list(globals().values()):
        if id(candidate) == backup_id:
            return candidate
    raise IndexError(f"no global object with id {backup_id}")

class HooksBackup:
    """Registry that remembers every hook and every hook manager handed to it.

    It announces its own ``id`` on creation so the instance can be located
    again later.
    """

    def __init__(self):
        self._hook_managers = []
        self._backup = {}
        print(f"backup at id : {id(self)}")

    @property
    def backup(self):
        """Mapping ``layer -> list of hooks`` filled through ``add_hook``."""
        return self._backup

    @property
    def hook_managers(self):
        """List of the managers registered through ``add_hook_manager``."""
        return self._hook_managers

    def add_hook(self, layer, hook):
        """Record ``hook`` under ``layer``, creating the entry when needed."""
        self._backup.setdefault(layer, []).append(hook)

    def add_hook_manager(self, manager):
        """Keep a reference to ``manager`` so ``flush`` can reach it."""
        self._hook_managers.append(manager)

    def flush(self):
        """Ask every registered manager to remove its hooks, then forget them.

        The per-layer ``backup`` mapping is left untouched.
        """
        registered = self._hook_managers
        for manager in registered:
            manager.remove_hooks()
        # Rebind to a fresh list instead of clearing in place.
        self._hook_managers = []
        
            
# Notebook-wide registry used by HooksManager; creating it prints its id.
_hooks_backup = HooksBackup()

backup at id : 140314852755728


In [40]:
class HooksManager():
    """Owns the forward hooks attached to one layer.

    On construction the given hook functions are registered on ``layer`` and
    the manager adds itself to the global ``_hooks_backup`` registry.
    """

    def __init__(self, layer_name, layer, hook_functions = None ):
        if hook_functions is None:
            hook_functions = []
        self.name = layer_name
        self.layer = layer
        self.outputs = {}
        self._hook_functions = hook_functions
        self._reset_hooks()
        self.set_hooks()
        _hooks_backup.add_hook_manager(self)

    @property
    def hook_functions(self):
        """Hook functions currently attached through this manager."""
        return self._hook_functions

    @hook_functions.setter
    def hook_functions(self, new_hook_functions):
        # Swap the functions first, then detach the old handles and
        # register the new set.
        self._hook_functions = new_hook_functions
        self.remove_hooks()
        self.set_hooks()

    def set_hooks(self):
        """Register every hook function on the layer and record the handles."""
        for hook_function in self._hook_functions:
            handle = self.layer.register_forward_hook(hook_function)
            self._hooks.append(handle)
            _hooks_backup.add_hook(self.layer, handle)

    def remove_hooks(self):
        """Detach every handle this manager created and forget them."""
        for handle in self._hooks:
            handle.remove()
        self._reset_hooks()

    def _reset_hooks(self):
        """Start over with an empty list of handles."""
        self._hooks = []

def store_transformed_outputs(function, target, key="image"):
    """Wrap a hook-style function so that its result is written into ``target``.

    Args:
        function: callable invoked as ``function(layer, input, output)``.
        target: mutable mapping receiving the result; the entry
            ``target[key]`` is overwritten on every call.
        key: entry of ``target`` that receives the result. Defaults to
            ``"image"``, the key this notebook reads back later.

    Returns:
        A callable ``hooked(layer, input, output)`` suitable for
        ``register_forward_hook``.
    """
    def hooked(layer, input, output):
        # Deliberately returns None: a forward hook that returns a value
        # would replace the module's output.
        target[key] = function(layer, input, output)
    return hooked
    
def average2d(tensor):
    """Average a 4-d tensor over its last two dimensions.

    Args:
        tensor: object with a 4-element ``shape`` and a ``mean`` method that
            accepts a tuple of dimensions.

    Returns:
        The mean over dimensions 2 and 3, i.e. a result with the first two
        dimensions of ``tensor``.

    Raises:
        AssertionError: if ``tensor`` is not 4-dimensional.
    """
    # The message previously referenced an undefined name, so a bad input
    # raised NameError instead of this AssertionError.
    assert len(tensor.shape) == 4, f"expected 4d input got {tensor.shape}"
    return tensor.mean((2,3))

def average_my_output(layer, inp, output):
    """Hook-style adapter: ignore ``layer`` and ``inp``, average ``output``.

    Returns whatever ``average2d`` returns for ``output``.
    """
    pooled = average2d(output)
    return pooled

def save_batch(batch_number, batch, save_path):
    """Pickle ``batch`` to ``<save_path>/batch<batch_number>.bin``.

    Args:
        batch_number: value used in the file name.
        batch: any picklable object.
        save_path: destination directory; it is created, together with any
            missing parent directories, when it does not exist yet.
    """
    # makedirs(exist_ok=True) copes with nested paths and with the directory
    # appearing between the check and the creation, unlike exists() + mkdir().
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f"batch{batch_number}.bin"), "wb") as f:
        pickle.dump(batch, f)
   
    


In [41]:
# One entry per bottleneck: where its features are saved ("data_path") and
# the latest averaged output written by its forward hook ("image").
stored_forward_pass = {}
layer_names = [f"layer{i}" for i in range(1,5)]
data_path = "../data/wikiart_embeddings/clip_training/RN50_layers/"
# makedirs also creates the missing parents of this multi-level path, which
# a plain mkdir cannot do.
os.makedirs(data_path, exist_ok=True)

# Detach hooks left over from a previous run before registering new ones.
_hooks_backup.flush()
for layer_name in layer_names:
    for bottle_name, bottle_neck in resnet50_dict[layer_name].named_children():
        name = layer_name + "_" + bottle_name
        path = osp.join(data_path, name)
        os.makedirs(path, exist_ok=True)
        stored_forward_pass[name] = {"data_path": path,
                                     "image": None}
        hook_functions = [store_transformed_outputs(average_my_output, stored_forward_pass[name])]
        # The manager registers itself in _hooks_backup, so no reference is kept here.
        HooksManager(name, bottle_neck, hook_functions)

In [42]:
# One storage dict per top-level layer; each forward hook writes the averaged
# layer output under the "image" key.
target_storage = {name: {} for name in layer_names}

# Only detach the layer-level hooks created by an earlier run of this cell.
# A global `_hooks_backup.flush()` would also remove the bottleneck hooks that
# fill `stored_forward_pass`, and the extraction loop would then save `None`.
for manager in _hooks_backup.hook_managers:
    if manager.name in layer_names:
        manager.remove_hooks()

for name in layer_names:
    hook_functions = [store_transformed_outputs(average_my_output, target_storage[name])]
    HooksManager(name, resnet50_dict[name], hook_functions)

In [43]:
# Run every pre-processed batch through the visual tower; the forward hooks
# fill `stored_forward_pass`, and one copy of the batch is saved per bottleneck
# with its "image" entry replaced by that bottleneck's averaged output.
# NOTE(review): batches are read from "../_data/..." while features are written
# under "../data/..." - confirm both locations are intended.
for subset in ["val","test", "rest", "train"]:
    print(f"creating {subset} set")
    dataloader = dt.Pickle_data_loader(f"../_data/preprocessed/img_size_224/{subset}/")
    for i, batch in tqdm(enumerate(dataloader)):
        # Only the hook side effects matter: no autograd graph is needed, and
        # the stored features should not carry one into the pickles.
        with torch.no_grad():
            resnet_clip50(batch["image"])
        for layer, storage in stored_forward_pass.items():
            features = storage["image"]
            if features is None:
                # The hook never fired (e.g. it was flushed): fail loudly
                # instead of silently pickling a batch without features.
                raise RuntimeError(
                    f"no output captured for {layer}; re-register its forward hook"
                )
            batch["image"] = features
            save_batch(i, batch, osp.join(storage["data_path"], subset))

creating val set


100it [00:23,  4.29it/s]


creating test set


199it [00:35,  5.53it/s]


creating rest set


18it [00:02,  6.44it/s]


creating train set


3372it [09:27,  5.94it/s]


## Layer visualisation

In [13]:
dataset_path = "../data/wikiart_embeddings/clip_training/RN50_layers/"
checkpoint_path = "../neural_checkpoints/clip_training/RN50_layers/"
models = {}
base_layers = layer_names = [f"layer{i}" for i in range(1,5)] # ["layer1","layer2","layer3","layer4"]
# List the per-bottleneck feature folders. The listing previously used an
# undefined name (`resnet_layer_path`); `dataset_path` is the directory the
# models below are trained from. Filtering, unlike list.remove, does not fail
# when a top-level "layerN" folder is absent.
layers = [entry for entry in os.listdir(dataset_path) if entry not in base_layers]
def train_on_dataset(path):
    """Build a ``Trainer`` with this notebook's fixed settings for ``path``.

    The input width of the single-layer model is read from dimension 1 of the
    "image" entry of the first training batch; the output size is 9.
    """
    loaders = get_loaders(path)
    first_batch = loaders["train"].load_batch(0)
    input_shape = first_batch["image"].shape[1]

    classifier = SLP(input_size = input_shape, output_size=9).to(device)
    criterion = nn.BCEWithLogitsLoss()
    return Trainer(
        model = classifier,
        loss_fn = criterion,
        optimizer_fn = torch.optim.Adam,
        lr = 10**-2,
        data_loaders = loaders,
        device = device,
    )

# One trainer per bottleneck; resume from its checkpoint when one is on disk.
for layer in layers:
    models[layer] = train_on_dataset(osp.join(dataset_path, layer))
    print(layer + " created.")
    checkpoint_file = osp.join(checkpoint_path, layer)
    # A direct existence test avoids re-listing the directory for every layer
    # and does not fail when the checkpoint directory has not been created yet.
    if osp.exists(checkpoint_file):
        # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
        state_dict = torch.load(checkpoint_file, map_location=device)
        models[layer].model.load_state_dict(state_dict)


layer3_4 created.
layer1_2 created.
layer2_1 created.
layer3_1 created.
layer2_0 created.
layer2_3 created.
layer2_2 created.
layer1_0 created.
layer4_2 created.
layer3_0 created.
layer3_2 created.
layer4_0 created.
layer3_5 created.
layer3_3 created.
layer1_1 created.
layer4_1 created.


In [None]:
# Train and evaluate every model with the same four-step learning-rate list,
# wiping the cell output after each one.
for model in models.values():
    learning_rates = [10**-2, 10**-3, 10**-4, 10**-5]
    model.train_eval(4, lrs = learning_rates)
    clear_output()

In [45]:
# Build the report of every model without displaying figures, then clear
# whatever was printed along the way.
for model in models.values():
    model.create_report(ARTEMIS_EMOTIONS, show_fig = False)
clear_output()

In [18]:
# Make sure the checkpoint directory exists before writing into it.
os.makedirs(checkpoint_path, exist_ok=True)
# Save every model under its layer name.
for key, model in models.items():
    model.save_model(osp.join(checkpoint_path, key))

In [21]:
# Pair every layer name with the metrics of its model, split the pairs into
# an index and a tuple of rows, and build a table sorted by layer name.
metric_pairs = ((key, model.metrics()) for key, model in models.items())
index, vals = zip(*metric_pairs)
res_layers = pd.DataFrame(vals, index = index).sort_index()
res_layers.head(2)

Unnamed: 0,agreement_threshold,confusion_matrix,precision_recall_fscore_support,recall,f1_score,accuracy
layer1_0,0.5,amusement awe contentment e...,"([0.3786407766990291, 0.36065573770491804, 0.5...",0.193657,0.203543,0.532584
layer1_1,0.5,amusement awe contentment e...,"([0.391304347826087, 0.36649214659685864, 0.57...",0.195471,0.205955,0.53523


In [34]:
# Position of each metric inside a "precision_recall_fscore_support" entry.
prf2index = {
    metric_name: position
    for position, metric_name in enumerate(("precision", "recall", "fscore"))
}

@interact(metric = prf2index.keys())
def plot_layers(metric):
    """Plot the selected metric for every emotion across the rows of ``res_layers``.

    Args:
        metric: one of the keys of ``prf2index``; it selects which element of
            each ``precision_recall_fscore_support`` entry is plotted.
    """
    # Keep only the selected metric from each entry ...
    per_metric = res_layers["precision_recall_fscore_support"].map(
        lambda prfs: prfs[prf2index[metric]]
    )
    # ... then spread it into one column per emotion.
    scores = pd.DataFrame(
        {
            emotion: per_metric.map(lambda values, i=i: values[i])
            for i, emotion in enumerate(ARTEMIS_EMOTIONS)
        },
        index=per_metric.index,
    )

    ax = scores.plot(figsize=(16, 9))
    colors_emotions = visual.emotion_colors()
    markers = ['>', '+', '.', ',', 'o', 'v', 'x', 'X', 'D', '|']
    for i, line in enumerate(ax.get_lines()):
        line.set_marker(markers[i])
        line.set_color(colors_emotions[i])
    ax.legend(ax.get_lines(), scores.columns, loc='upper left')
    ax.grid(axis="y", ls=":")
    # Label the axis with the metric actually shown; it used to read
    # "f1 score" whatever the selection.
    ax.set_ylabel(metric)

interactive(children=(Dropdown(description='metric', options=('precision', 'recall', 'fscore'), value='precisi…