# Model and data load

In [1]:
import os
# Switch the working directory to the parent of the notebook's folder.
# The relative paths used further down ("./assets/...", "./introspection/")
# are resolved against this directory; presumably the `src.*` imports rely on
# it as well -- TODO confirm.
os.chdir('..')
# Echo the resulting working directory as the cell output (sanity check).
os.getcwd()

'/Users/pavelpopov/mlp_nn'

In [83]:
from omegaconf import OmegaConf, DictConfig
import torch
import pandas as pd

from captum.attr import IntegratedGradients, NoiseTunnel, Saliency, visualization as viz
import matplotlib.pyplot as plt
import numpy as np

from src.utils import set_project_name, set_run_name, validate_config, get_resume_params
from src.data import data_factory, data_postfactory
from src.dataloader import dataloader_factory, cross_validation_split
from src.model import model_config_factory, model_factory

In [39]:
def load_model_and_data(path):
    """Rebuild the best-scoring model of an experiment and its dataloaders.

    Reads ``general_config.yaml`` and ``runs.csv`` from ``path``, picks the
    row of ``runs.csv`` with the highest ``test_score``, and converts the row
    label into a (fold ``k``, ``trial``) pair using ``cfg.mode.n_trials``.
    The data and model are then recreated through the project factories and
    the weights stored for that fold/trial are loaded into the model.

    Args:
        path: Experiment log directory as a string without a trailing slash
            (file names are appended with ``+``).

    Returns:
        Tuple ``(model, dataloaders)``: the model with the checkpointed
        weights loaded, and whatever ``dataloader_factory`` returns for the
        selected fold and trial.
    """
    cfg = OmegaConf.load(path+'/general_config.yaml')
    runs = pd.read_csv(path+'/runs.csv')

    # The label of the best row encodes fold and trial as
    # ``k * n_trials + trial`` -- presumably runs.csv is written fold-major
    # with a default integer index; verify against the training script.
    best_row = runs["test_score"].idxmax()
    k, trial = divmod(best_row, cfg.mode.n_trials)
    print("k ", k)
    print("trial ", trial)

    # Factory calls are kept in their original order.
    raw_data = data_factory(cfg)
    model_cfg = model_config_factory(cfg, k)
    prepared_data = data_postfactory(cfg, model_cfg, raw_data)
    dataloaders = dataloader_factory(cfg, prepared_data, k=k, trial=trial)

    model = model_factory(cfg, model_cfg)

    def keep_storage(storage, loc):
        # Return the storage untouched, i.e. do not move it to the device
        # the checkpoint was saved from.
        return storage

    weights_file = path+f"/k_{k:02d}/trial_{trial:04d}/best_model.pt"
    state = torch.load(weights_file, map_location=keep_storage)
    model.load_state_dict(state)

    return model, dataloaders

# Introspection

In [174]:
class Introspector:
    """Compute Captum attributions for a model and save them as plots.

    For every requested attribution method and every distinct label value,
    the samples carrying that label are attributed towards that label. The
    result is written to disk as a heat map ("colormap") and as a bar chart
    ("barchart"): once for the whole group of samples and once for each of
    the first ``cutoff`` samples of the group.

    Files are written below ``{save_path}{method}/colormap/`` and
    ``{save_path}{method}/barchart/``. ``save_path`` is used as a plain
    string prefix, so it must end with a path separator to denote a folder.
    """

    # Attribution methods understood by `get_grads`.
    _KNOWN_METHODS = ("saliency", "ig", "ignt")
    # One output sub-folder per plot type.
    _PLOT_KINDS = ("colormap", "barchart")

    def __init__(self, model, features, labels, methods, save_path) -> None:
        """Store the inputs and create the output folders.

        Args:
            model: Model handed to the Captum attribution classes; it must
                also provide ``zero_grad()``.
            features: Tensor of input samples, indexed along the first axis
                by a boolean mask built from ``labels``.
            labels: Tensor of labels, one per sample; its unique values are
                used as attribution targets.
            methods: Collection of method names out of ``"saliency"``,
                ``"ig"`` and ``"ignt"``.
            save_path: String prefix for all output files (include the
                trailing separator).
        """
        self.methods = methods
        self.save_path = save_path
        self.model = model

        self.features = features
        self.labels = labels

        # Only folders of recognised, requested methods are created.
        for method in self._KNOWN_METHODS:
            if method in self.methods:
                for kind in self._PLOT_KINDS:
                    os.makedirs(f"{self.save_path}{method}/{kind}", exist_ok=True)

    def run(self, cutoff=1, percentile=0.9):
        """Run introspection, save results.

        Args:
            cutoff: Maximum number of individual samples per target that get
                their own plots. Targets with fewer samples simply produce
                fewer single-sample plots.
            percentile: Fraction of components kept in the bar charts; see
                `plot_histograms`.

        Raises:
            ValueError: If ``self.methods`` contains an unknown method name.
        """
        targets = torch.unique(self.labels)

        # Attribute in evaluation mode so that layers such as dropout do not
        # inject randomness into the gradients; the previous mode is restored
        # afterwards so the caller's model is left as it was.
        was_training = getattr(self.model, "training", False)
        if was_training:
            self.model.eval()
        try:
            for method in self.methods:
                for target in targets:
                    self._run_for_target(method, target, cutoff, percentile)
        finally:
            if was_training:
                self.model.train()

    def _run_for_target(self, method, target, cutoff, percentile):
        """Attribute all samples labelled ``target`` and write their plots."""
        mask = self.labels == target
        features = self.features[mask]
        features.requires_grad = True

        grads = self.get_grads(method, features, target)
        grads = grads.cpu().detach().numpy()
        detached_features = features.cpu().detach().numpy()

        # Target 0 is drawn in blue, every other target in red.
        color = "blue" if target == 0 else "red"
        colormap_dir = f"{self.save_path}{method}/colormap"
        barchart_dir = f"{self.save_path}{method}/barchart"

        print(f"\tPlotting generalized saliency maps using '{method}' with target {target}")
        self.plot_colormaps(
            grads,
            detached_features,
            filepath=f"{colormap_dir}/general_{target}.png",
            color=color,
        )
        self.plot_histograms(
            grads,
            filepath=f"{barchart_dir}/general_{target}.png",
            percentile=percentile,
            color=color,
        )

        print(f"\tPlotting single sample saliency maps using '{method}' with target {target}")
        # Never index past the number of samples available for this target.
        n_single = min(cutoff, grads.shape[0])
        for i in range(n_single):
            # Re-add the leading sample axis so the plot helpers always
            # receive 3-D arrays.
            feature = detached_features[i][np.newaxis, :, :]
            grad = grads[i][np.newaxis, :, :]

            self.plot_colormaps(
                grad,
                feature,
                filepath=f"{colormap_dir}/target_{target}_idx_{i:04d}.png",
                color=color,
            )
            self.plot_histograms(
                grad,
                filepath=f"{barchart_dir}/target_{target}_idx_{i:04d}.png",
                percentile=percentile,
                color=color,
            )

    def plot_colormaps(self, grads, features, filepath, color):
        """Save a heat map of the attributions as a PNG file.

        Args:
            grads: 3-D attribution array; its axes are reversed before being
                passed to Captum's image visualisation.
            features: Input array matching ``grads``, passed along as the
                "original image".
            filepath: Destination of the PNG file.
            color: Accepted for interface compatibility; the heat map always
                uses the "inferno" colormap and ignores this value.
        """
        fig, axs = plt.subplots(1, 1, figsize=(13, 5))
        try:
            _ = viz.visualize_image_attr(
                np.transpose(grads, (2, 1, 0)),
                np.transpose(features, (2, 1, 0)),
                method="heat_map",
                cmap="inferno",
                show_colorbar=False,
                plt_fig_axis=(fig, axs),
                use_pyplot=False,
            )
            fig.savefig(
                filepath,
                format="png",
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    @staticmethod
    def _summarize_components(grads, percentile):
        """Sum attributions per component and keep only the strongest ones.

        ``grads`` is summed over its first two axes, leaving one value per
        entry of the last axis. The ``int(n * percentile)`` values with the
        largest magnitude are kept; all others are set to zero.

        Returns:
            1-D NumPy array with one (possibly zeroed) value per component.
        """
        data = np.sum(grads, axis=(0, 1))
        order = np.argsort(np.abs(data))
        n_components = order.shape[0]
        n_keep = min(max(int(n_components * percentile), 0), n_components)
        # Slice by the number of dropped entries: ``order[:-n_keep]`` would
        # select nothing for ``n_keep == 0`` instead of everything.
        data[order[: n_components - n_keep]] = 0
        return data

    def plot_histograms(self, grads, filepath, percentile=0.9, color="blue"):
        """Save a bar chart of the summed attribution per component.

        Args:
            grads: 3-D attribution array; values are summed over the first
                two axes, giving one bar per entry of the last axis.
            filepath: Destination of the PNG file.
            percentile: Fraction (0..1) of components, ranked by absolute
                summed attribution, that keep their value; the rest are
                drawn as zero.
            color: Bar color.
        """
        data = self._summarize_components(grads, percentile)
        n_components = data.shape[0]

        fig, ax = plt.subplots()
        try:
            ax.bar(
                range(n_components),
                data,
                align="center",
                color=color,
            )
            ax.set_xlim([0, n_components])
            ax.grid(False)

            fig.savefig(
                filepath,
                format="png",
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)

    def get_grads(self, method, features, target):
        """Returns gradients according to method.

        Args:
            method: ``"saliency"`` (Captum ``Saliency``), ``"ig"``
                (``IntegratedGradients`` with an all-zero baseline) or
                ``"ignt"`` (``IntegratedGradients`` wrapped in a
                ``NoiseTunnel`` using "smoothgrad_sq", 5 samples and a
                standard deviation of 0.2).
            features: Input tensor to attribute.
            target: Output index the attribution is computed for.

        Returns:
            Attribution tensor as returned by Captum; convergence deltas are
            discarded.

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        if method == "saliency":
            saliency = Saliency(self.model)
            self.model.zero_grad()
            grads = saliency.attribute(features, target=target)
        elif method == "ig":
            ig = IntegratedGradients(self.model)
            self.model.zero_grad()
            grads, _ = ig.attribute(
                inputs=features,
                target=target,
                baselines=torch.zeros_like(features),
                return_convergence_delta=True,
            )
        elif method == "ignt":
            ig = IntegratedGradients(self.model)
            nt = NoiseTunnel(ig)
            self.model.zero_grad()
            grads, _ = nt.attribute(
                inputs=features,
                target=target,
                baselines=torch.zeros_like(features),
                return_convergence_delta=True,
                nt_type="smoothgrad_sq",
                nt_samples=5,
                stdevs=0.2,
            )
        else:
            raise ValueError(f"'{method}' methods is not recognized")

        return grads

# Load the best model and run introspection

In [156]:
# Restore the best model of this experiment together with its dataloaders.
path = "./assets/logs/rerun_all-exp-mlp_defHP-fbirn"
model, data = load_model_and_data(path)

# Set to False to append further samples (see the note in the branch below).
load_only_test = True

# Every dataset item is indexed as (features, label, ...): gather both parts
# of each test sample into separate lists.
dataset = data["test"].dataset
features = [item[0] for item in dataset]
labels = [item[1] for item in dataset]

if not load_only_test:
    # NOTE(review): this reads the "test" split a second time, so enabling
    # the branch duplicates the test samples; presumably another split was
    # intended -- confirm which keys `data` provides.
    dataset = data["test"].dataset
    for sample in dataset:
        features.append(sample[0])
        labels.append(sample[1])

# Turn the per-sample lists into single batched tensors.
features = torch.stack(features)
labels = torch.stack(labels)

k  3
trial  7
Loaded model config:
dropout: 0.11
hidden_size: 150
num_layers: 0
lr: 0.00027
input_size: 53
output_size: 2



In [175]:
# Produce plots for all three supported attribution methods; results are
# written below ./introspection/.
methods = ["saliency", "ig", "ignt"]
introspector = Introspector(
    model=model,
    features=features,
    labels=labels,
    methods=methods,
    save_path="./introspection/",
)
introspector.run()

	Plotting generalized saliency maps using 'saliency' with target 0
	Plotting single sample saliency maps using 'saliency' with target 0
	Plotting generalized saliency maps using 'saliency' with target 1
	Plotting single sample saliency maps using 'saliency' with target 1
	Plotting generalized saliency maps using 'ig' with target 0
	Plotting single sample saliency maps using 'ig' with target 0
	Plotting generalized saliency maps using 'ig' with target 1
	Plotting single sample saliency maps using 'ig' with target 1
	Plotting generalized saliency maps using 'ignt' with target 0
	Plotting single sample saliency maps using 'ignt' with target 0
	Plotting generalized saliency maps using 'ignt' with target 1
	Plotting single sample saliency maps using 'ignt' with target 1
