# Interpret CNN model

In [None]:
# default_exp interp.visual

In [None]:
# export
import ast
import json
import re
from pathlib import Path
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from ipywidgets import interact, interact_manual
from matplotlib import pyplot as plt
from PIL import Image

from unpackai.utils import url_2_text

In [None]:
from torchvision.models import resnet18
from torchvision import transforms as tfm

In [None]:
# Fetch the raw text of a gist that, per its file name, maps ImageNet class
# indices to label strings. The text is parsed into a dict in a later cell.
IMAGENET_CLASSES_TEXT = url_2_text(
    "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")

In [None]:
# The downloaded text is a Python dict literal ({index: "label", ...}).
# Parse it with ast.literal_eval instead of eval: the text comes from the
# network, and literal_eval only accepts literals, so it can never run code.
IMAGENET_CLASSES = ast.literal_eval(IMAGENET_CLASSES_TEXT)

In [None]:
# Spot-check one entry of the parsed index -> label mapping.
IMAGENET_CLASSES[463]

'bucket, pail'

In [None]:
# Pretrained ResNet-18 used for all the inference-only cells below.
# Switch to eval mode right away: in training mode BatchNorm normalizes with
# the statistics of the current (single-image) batch and overwrites its stored
# running statistics on every forward pass, which skews the predictions.
model = resnet18(pretrained=True, progress=True).eval()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /Users/salvor/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
# Sample images used to try the model out. iterdir() yields entries in
# arbitrary order, so sort them to make IMAGES[0] the same file on every run.
IMAGES = sorted(Path("../test/img/Nature/").iterdir())

In [None]:
# Sanity check: one image in, one row of 1000 class scores out.
# basic_trans (simple_to_tensor with return_batch=True) takes a file *path*
# and already returns a batch of one, so pass the path directly and do not
# add another leading dimension. Only the output shape is of interest here.
model(basic_trans(IMAGES[0])).shape

torch.Size([1, 1000])

In [None]:
def simple_to_tensor(
        size: int = 224,
        mean_: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        img_transforms: Optional[List] = None,
        tensor_tranforms: Optional[List] = None,
        return_batch: bool = False,
    ) -> Callable:
    """
    Build a function that loads an image file into a normalized tensor.

    The transform pipeline is, in order:
    Resize -> `img_transforms` -> ToTensor -> `tensor_tranforms` -> Normalize.

    Args:
        size: value passed to `tfm.Resize`.
        mean_: per-channel mean for `tfm.Normalize`;
            defaults to [0.485, 0.456, 0.406].
        std: per-channel standard deviation for `tfm.Normalize`;
            defaults to [0.229, 0.224, 0.225].
        img_transforms: extra transforms applied to the PIL image, after
            resizing and before conversion to a tensor. Defaults to none.
        tensor_tranforms: extra transforms applied to the tensor, before
            normalization. Defaults to none.
        return_batch: when True, the returned tensor gets an extra leading
            dimension of size 1 so it can be fed to a model directly.

    Returns:
        A callable taking an image path and returning the transformed tensor.
    """
    # None sentinels replace list defaults so that no list object is shared
    # between calls (or with the transforms that keep a reference to it).
    mean_ = [0.485, 0.456, 0.406] if mean_ is None else mean_
    std = [0.229, 0.224, 0.225] if std is None else std
    img_transforms = [] if img_transforms is None else img_transforms
    tensor_tranforms = [] if tensor_tranforms is None else tensor_tranforms

    trans = tfm.Compose([
        tfm.Resize(size),
        *img_transforms,
        tfm.ToTensor(),
        *tensor_tranforms,
        tfm.Normalize(mean=mean_, std=std)
    ])

    def to_tensor(path):
        """Open the image at `path`, force RGB, and run the transform pipeline."""
        with Image.open(str(path)) as img:
            tensor = trans(img.convert('RGB'))
        # Optionally prepend the batch dimension.
        return tensor[None, ...] if return_batch else tensor

    return to_tensor

In [None]:
# Default image-path -> tensor loader. return_batch=True adds a leading
# dimension of size 1, so the result can be passed straight to the model.
basic_trans = simple_to_tensor(return_batch=True)

In [None]:
def get_features(self, x):
    """
    Run `x` through the convolutional part of a network and return the result.

    `self` is the model object (e.g. the `resnet18` instance created in this
    notebook); it must expose `conv1`, `bn1`, `relu`, `maxpool` and
    `layer1` .. `layer4`. Those are applied in that order with gradient
    tracking disabled, and whatever `layer4` produces is returned. Any layers
    the model has after `layer4` are not applied.
    """
    pipeline = (
        self.conv1, self.bn1, self.relu, self.maxpool,
        self.layer1, self.layer2, self.layer3, self.layer4,
    )
    with torch.no_grad():
        for stage in pipeline:
            x = stage(x)
    return x

In [None]:
def to_8b(x):
    """
    Min-max rescale an array to 0..255 and return it as unsigned 8-bit integers.

    The smallest input value maps to 0 and the largest to 255, which is the
    range expected for an 8-bit grayscale ("L") image.

    Args:
        x: array-like of numbers (anything `np.asarray` accepts).

    Returns:
        A `np.uint8` array with the same shape as `x`. A constant input (no
        value range to stretch) yields all zeros; an empty input yields an
        empty array.
    """
    arr = np.asarray(x, dtype=np.float64)
    if arr.size == 0:
        return arr.astype(np.uint8)

    min_ = arr.min()
    span = arr.max() - min_
    if span == 0:
        # Avoid 0/0: a flat input carries no contrast to display.
        return np.zeros(arr.shape, dtype=np.uint8)

    # Normalize to [0, 1] first so the maximum is exactly 1.0, then scale to
    # 255 (not 256, which does not fit in 8 bits) and round to nearest.
    return np.rint(255 * ((arr - min_) / span)).astype(np.uint8)

In [None]:
def visualize(model, image_path):
    """
    Show a model's class ranking for one image and explore per-class heatmaps.

    Displays the image (resized to 224x224) and a table of every class sorted
    by descending model output. An interactive text box then filters that
    table by class name; when exactly one class matches, the image is plotted
    next to an overlay of the feature-map channel that contributes most to
    that class.

    Args:
        model: a network exposing the attributes read by `get_features`
            (`conv1` .. `layer4`) plus a final layer `fc` with a `weight`
            matrix, e.g. the torchvision `resnet18` used in this notebook.
            It is moved to the CPU. It is put in eval mode for the forward
            passes and its previous train/eval mode is restored afterwards.
        image_path: path of the image file to analyse.

    Note:
        Relies on `display`, which IPython/Jupyter provides as a builtin.
    """
    model.cpu()

    # Load the pixels eagerly and release the file handle. RGB matches what
    # basic_trans feeds the model and keeps imshow from colour-mapping
    # single-channel images.
    with Image.open(image_path) as opened:
        img = opened.convert('RGB')
    display(img.resize((224, 224)))
    tensor = basic_trans(image_path)

    # Run inference in eval mode so BatchNorm uses (and does not overwrite)
    # its stored running statistics; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            y_ = model(tensor)[0]
        features = get_features(model, tensor)[0]
    finally:
        model.train(was_training)

    # Class indices ordered from highest to lowest model output.
    significance = y_.argsort().flip(dims=(0,)).numpy()

    significance_df = pd.DataFrame(dict(
        name=list(map(IMAGENET_CLASSES.get, significance)),
        idx=significance,
        pred=y_.cpu().reshape(-1).numpy()[significance]
    ))
    display(significance_df)

    @interact
    def search_kw(kw=""):
        # `kw` is interpreted as a regular expression. The callback runs while
        # the user is typing, so a half-typed pattern must not raise.
        try:
            matches = significance_df.name.str.contains(kw)
        except re.error as err:
            print(f"invalid search pattern '{kw}': {err}")
            return
        sub_df = significance_df[matches]

        if len(sub_df) == 0:
            print(f"no such class '{kw}'")
        elif len(sub_df) == 1:
            cls = sub_df.to_dict(orient="records")[0]
            cls_name = cls['name']
            idx = int(cls['idx'])
            print(f"{cls_name}({idx})\tselected")
            with torch.no_grad():
                # The chosen class's weight for each channel times that
                # channel's activation map.
                feature_importance = model.fc.weight.data[
                    idx][:, None, None] * features

                # Channels ordered by total contribution, largest first.
                feature_rank = feature_importance.sum(-1).sum(-1).argsort().flip(dims=(0,))

            colormap = 'plasma'

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax2.imshow(img, alpha=.5)

            # Heatmap of the top-ranked channel, stretched to the image size.
            # PIL's resize takes (width, height), which is what img.size holds.
            heatmap = Image.fromarray(
                to_8b(feature_importance[feature_rank[0], ...].numpy()),
                mode="L").resize(img.size)
            ax2.imshow(heatmap, alpha=.5, cmap=colormap)

            ax1.imshow(img)
            plt.show()
        else:
            display(sub_df.head())