In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from tqdm import tqdm
from tensorboardX import SummaryWriter

import os
import pprint
import argparse
import time
import json
import timm
from functools import partial

   
from PIL import Image
import sys
import cv2

In [2]:
# dataset and models
from dataset import ChexpertSmall, extract_patient_ids
from torchvision.models import densenet121, resnet152
#from models.efficientnet import construct_model
#from models.attn_aug_conv import DenseNet, ResNet, Bottleneck
from vit_pytorch import ViT

CheXpert-v1.0-small


In [3]:
!pip install tensorboardX

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com


In [4]:
def grad_cam(model, x, hooks, cls_idx=None):
    """ cf CheXpert: Test Results / Visualization; visualize final conv layer, using grads of final linear layer as weights,
    and performing a weighted sum of the final feature maps using those weights.
    cf Grad-CAM https://arxiv.org/pdf/1610.02391.pdf """

    model.eval()
    model.zero_grad()

    # register backward hooks
    conv_features, linear_grad = [], []
    forward_handle = hooks['forward'].register_forward_hook(lambda module, in_tensor, out_tensor: conv_features.append(out_tensor))
    backward_handle = hooks['backward'].register_backward_hook(lambda module, grad_input, grad_output: linear_grad.append(grad_input))

    # run model forward and create a one hot output for the given cls_idx or max class
    outputs = model(x)
    if not cls_idx: cls_idx = outputs.argmax(1)
    one_hot = F.one_hot(cls_idx, outputs.shape[1]).float().requires_grad_(True)

    # run model backward
    one_hot.mul(outputs).sum().backward()

    # compute weights; cf. Grad-CAM eq 1 -- gradients flowing back are global-avg-pooled to obtain the neuron importance weights
    weights = linear_grad[0][2].mean(1).view(1, -1, 1, 1)
    # compute weighted combination of forward activation maps; cf Grad-CAM eq 2; linear combination over channels
    cam = F.relu(torch.sum(weights * conv_features[0], dim=1, keepdim=True))

    # normalize each image in the minibatch to [0,1] and upscale to input image size
    cam = cam.clone()  # avoid modifying tensor in-place
    with torch.no_grad():
        def norm_ip(t, min, max):
            t.clamp_(min=min, max=max)
            t.add_(-min).div_(max - min + 1e-5)

        for t in cam:  # loop over mini-batch dim
            norm_ip(t, float(t.min()), float(t.max()))

    cam = F.interpolate(cam, x.shape[2:], mode='bilinear', align_corners=True)

    # cleanup
    forward_handle.remove()
    backward_handle.remove()
    model.zero_grad()

    return cam

In [5]:
def visualize_grad_cam(model, dataloader, grad_cam_hooks, output_dir):
    attr_names = dataloader.dataset.attr_names
    # 1. run through model to compute logits and grad-cam
    imgs, labels, scores, masks, idxs = [], [], [], [], []
    for x, target, idx in dataloader:
        imgs += [x]
        labels += [target]
        idxs += idx.tolist()
        x = x.to(device)
        scores += [model(x).cpu()]
        masks  += [grad_cam(model, x, grad_cam_hooks).cpu()]
    imgs, labels, scores, masks = torch.cat(imgs), torch.cat(labels), torch.cat(scores), torch.cat(masks)

    # 2. renormalize images and convert everything to numpy for matplotlib
    imgs.mul_(0.0349).add_(0.5330)
    imgs = imgs.permute(0,2,3,1).data.numpy()
    labels = labels.data.numpy()
    patient_ids = extract_patient_ids(dataloader.dataset, idxs)
    masks = masks.permute(0,2,3,1).data.numpy()
    probs = scores.sigmoid().data.numpy()

    # 3. make column grid of [model probs table, original image, grad-cam image] for each attr + other categories
    for attr, vis_idxs in zip(dataloader.dataset.vis_attrs, dataloader.dataset.vis_idxs):
        fig, axs = plt.subplots(3, 3, figsize=(4 * imgs.shape[1]/100, 3.3 * imgs.shape[2]/100), dpi=100, frameon=False)
        fig.suptitle(attr)
        for i, idx in enumerate(vis_idxs):
            offset = idxs.index(idx)
            visualize_one(model, imgs[offset], masks[offset], labels[offset], patient_ids[offset], probs[offset], attr_names, axs[i], output_dir)

        filename = 'vis_{}_step_{}.png'.format(attr.replace(' ', '_'), 100)
        plt.savefig(os.path.join(output_dir, 'vis', filename), dpi=100)
        plt.close()

In [6]:
def visualize_att_rollout(model, dataloader, output_dir):
    attr_names = dataloader.dataset.attr_names
    # 1. run through model to compute logits and grad-cam
    imgs, labels, scores, masks, idxs = [], [], [], [], []
    for x, target, idx in dataloader:
        imgs += [x]
        labels += [target]
        idxs += idx.tolist()
        x = x.to(device)
        scores += [model(x).cpu()]
        attention_rollout = VITAttentionRollout(model, head_fusion='max', discard_ratio=0.9)
        masks += [attention_rollout(x)]
    imgs, labels, scores, masks = torch.cat(imgs), torch.cat(labels), torch.cat(scores), torch.cat(masks)

    # 2. renormalize images and convert everything to numpy for matplotlib
    imgs.mul_(0.0349).add_(0.5330)
    imgs = imgs.permute(0,2,3,1).data.numpy()
    labels = labels.data.numpy()
    patient_ids = extract_patient_ids(dataloader.dataset, idxs)
    masks = masks.permute(0,2,3,1).data.numpy()
    probs = scores.sigmoid().data.numpy()

    # 3. make column grid of [model probs table, original image, grad-cam image] for each attr + other categories
    for attr, vis_idxs in zip(dataloader.dataset.vis_attrs, dataloader.dataset.vis_idxs):
        fig, axs = plt.subplots(3, 3, figsize=(4 * imgs.shape[1]/100, 3.3 * imgs.shape[2]/100), dpi=100, frameon=False)
        fig.suptitle(attr)
        for i, idx in enumerate(vis_idxs):
            offset = idxs.index(idx)
            visualize_one(model, imgs[offset], masks[offset], labels[offset], patient_ids[offset], probs[offset], attr_names, axs[i], output_dir)

        filename = 'vis_{}_step_{}.png'.format(attr.replace(' ', '_'), 100)
        plt.savefig(os.path.join(output_dir, 'vis', filename), dpi=100)
        plt.close()

In [7]:
def visualize_one(model, img, mask, label, patient_id, prob, attr_names, axs, output_dir):
    """ display [table of model vs ground truth probs | original image | grad-cam mask image] in a given suplot axs """
    # sort data by prob high to low
    sort_idxs = prob.argsort()[::-1]
    label = label[sort_idxs]
    prob = prob[sort_idxs]
    names = [attr_names[i] for i in sort_idxs]
    # 1. left -- show table of ground truth and predictions, sorted by pred prob high to low
    axs[0].set_title(patient_id)
    data = np.stack([label, prob.round(3)]).T
    axs[0].table(cellText=data, rowLabels=names, colLabels=['Ground truth', 'Pred. prob'],
                 rowColours=plt.cm.Greens(0.5*label),
                 cellColours=plt.cm.Greens(0.5*data), cellLoc='center', loc='center')
    axs[0].axis('tight')
    # 2. middle -- show original image
    axs[1].set_title('Original image', fontsize=10)
    axs[1].imshow(img.squeeze(), cmap='gray')
    # 3. right -- show heatmap over original image with predictions
    axs[2].set_title('Top class activation \n{}: {:.4f}'.format(names[0], prob[0]), fontsize=10)
    axs[2].imshow(img.squeeze(), cmap='gray')
    axs[2].imshow(mask.squeeze(), cmap='jet', alpha=0.5)

    for ax in axs: ax.axis('off')

In [8]:
def vis_attn(x, patient_ids, idxs, attn_layers, output_dir, batch_element=0):
    H, W = x.shape[2:]
    nh = attn_layers[0].nh

    # select which pixels to visualize -- e.g. select virtices of a center square of side 1/3 of the image dims
    pix_to_vis = lambda h, w: [(h//3, w//3), (h//3, int(2*w/3)), (int(2*h/3), w//3), (int(2*h/3), int(2*w/3))]
    window = 30  # take mean attn around the pix_to_vis in a window of size ws

    for j, l in enumerate(attn_layers):
        # visualize attention maps (rows for each head; columns for each pixel)
        fig, axs = plt.subplots(nh+1, 4, figsize=(3,3/4*(1+nh)), frameon=False)
        fig.suptitle(patient_ids[batch_element], fontsize=8)
        # display target image; highlight pixel
        for ax, (ph, pw) in zip(axs[0], pix_to_vis(H,W)):
            image = x.clone().detach().mul_(0.0349).add_(0.5330)  # renormalize
            image[:,:,ph-window:ph+window,pw-window:pw+window] = torch.tensor([1., 215/255, 0]).view(1,3,1,1)   # add yellow pixel on the pix_to_vis for visualization
            ax.imshow(image[batch_element].permute(1,2,0).numpy())
            ax.axis('off')
        # display attention maps
        # get attention weights tensor for the batch element
        attn = l.weights.data[batch_element]
        # reshape attn tensor and select the pixels to visualize
        h = w = int(np.sqrt(attn.shape[-1]))
        ws = max(1, int(window * h/H))  # scale window to feature map size
        attn = attn.reshape(nh, h, w, h, w)
        for i, (ph, pw) in enumerate(pix_to_vis(h,w)):
            for h in range(nh):
                axs[h+1, i].imshow(attn[h, ph-ws:ph+ws, pw-ws:pw+ws, :, :].mean([0,1]).cpu().numpy())
                axs[h+1, i].axis('off')


        filename = 'attn_image_idx_{}_{}_layer_{}.png'.format(idxs[batch_element], batch_element, j)
        fig.subplots_adjust(0,0,1,0.95,0.05,0.05)
        plt.savefig(os.path.join(output_dir, 'vis', filename))
        plt.close()
    

In [9]:
def fetch_dataloader(resize, mode, mini_data):
    assert mode in ['train', 'valid', 'vis']
    data_path = '../'
    batch_size = 1
    transforms = T.Compose([
        T.Resize(resize) if resize else T.Lambda(lambda x: x),
        T.CenterCrop(320 if not resize else resize),
        lambda x: torch.from_numpy(np.array(x, copy=True)).float().div(255).unsqueeze(0),   # tensor in [0,1]
        T.Normalize(mean=[0.5330], std=[0.0349]),                                           # whiten with dataset mean and st
        lambda x: x.expand(3,-1,-1)
#        T.Resize((args.resize, args.resize)),
#        T.RandomHorizontalFlip(),
#        T.ToTensor(),
        ])                                                       # expand to 3 channels

    dataset = ChexpertSmall(data_path, mode, transforms, mini_data)

    return DataLoader(dataset, batch_size, shuffle=(mode=='train'), pin_memory=(device.type=='cuda'),
                      num_workers=0 if mode=='valid' else 16)  # since evaluating the valid_dataloader is called inside the
                                                              # train_dataloader loop, 0 workers for valid_dataloader avoids
                                                              # forking (cf torch dataloader docs); else memory sharing gets clunky


In [10]:
n_classes = len(ChexpertSmall.attr_names)
print(n_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vis_dataloader = fetch_dataloader(resize=224, mode='vis', mini_data= 200)

5
../CheXpert-v1.0-small/valid.csv


  cpuset_checked))


In [11]:
model = densenet121().to(device)

In [12]:
model.classifier = nn.Linear(model.classifier.in_features, out_features=n_classes).to(device)
grad_cam_hooks = {'forward': model.features, 'backward': model.classifier}


In [14]:
output_dir = './results/'
visualize_grad_cam(model, vis_dataloader, grad_cam_hooks, output_dir)



In [39]:
def rollout(attentions, discard_ratio, head_fusion):
    print(attentions)
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                raise "Attention head fusion type Not supported"

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            a = a / a.sum(dim=-1)

            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0 , 1 :]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1)**0.5)
    mask = mask.reshape(width, width)
    mask = mask / torch.max(mask)
    return mask    

class VITAttentionRollout:
    def __init__(self, model, attention_layer_name='attend', head_fusion="mean",
        discard_ratio=0.9):
        self.model = model.to(device)
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)

        self.attentions = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def __call__(self, input_tensor):
        self.attentions = []
        with torch.no_grad():
            output = self.model(input_tensor.to(device))

        return rollout(self.attentions, self.discard_ratio, self.head_fusion)

In [16]:
def grad_rollout(attentions, gradients, discard_ratio):
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):                
            weights = grad
            attention_heads_fused = (attention*weights).mean(axis=1)
            attention_heads_fused[attention_heads_fused < 0] = 0

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
            #indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            a = a / a.sum(dim=-1)
            result = torch.matmul(a, result)
    
    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0 , 1 :]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1)**0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask    

class VITAttentionGradRollout:
    def __init__(self, model, attention_layer_name='attend', discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                module.register_backward_hook(self.get_attention_gradient)

        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor, category_index):
        self.model.zero_grad()
        output = self.model(input_tensor)
        category_mask = torch.zeros(output.size())
        category_mask[:, category_index] = 1
        loss = (output*category_mask).sum()
        loss.backward()

        return grad_rollout(self.attentions, self.attention_gradients,
            self.discard_ratio)

In [52]:
def fetch_dataloader(resize, mode, mini_data):
    assert mode in ['train', 'valid', 'vis']
    data_path = '../'
    batch_size = 16
    transforms = T.Compose([
        T.Resize(resize) if resize else T.Lambda(lambda x: x),
        T.CenterCrop(320 if not resize else resize),
        lambda x: torch.from_numpy(np.array(x, copy=True)).float().div(255).unsqueeze(0),   # tensor in [0,1]
        T.Normalize(mean=[0.5330], std=[0.0349]),                                           # whiten with dataset mean and st
#         lambda x: x.expand(3,-1,-1)
#        T.Resize((args.resize, args.resize)),
#        T.RandomHorizontalFlip(),
#        T.ToTensor(),
        ])                                                       # expand to 3 channels

    dataset = ChexpertSmall(data_path, mode, transforms, mini_data)

    return DataLoader(dataset, batch_size, shuffle=(mode=='train'), pin_memory=(device.type=='cuda')
                      num_workers=0 if mode=='valid' else 16)  # since evaluating the valid_dataloader is called inside the
                                                              # train_dataloader loop, 0 workers for valid_dataloader avoids
                                                              # forking (cf torch dataloader docs); else memory sharing gets clunky


In [53]:
n_classes = len(ChexpertSmall.attr_names)
print(n_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
vis_dataloader = fetch_dataloader(resize=256, mode='vis', mini_data= 20)

5
cuda
../CheXpert-v1.0-small/valid.csv


  cpuset_checked))


In [57]:
output_dir = './results/2022-05-27_02-52-52'
model = ViT(image_size = 256, patch_size = 16, num_classes = 5, dim = 512, depth = 6, heads = 8, channels = 1, mlp_dim = 1024).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
scheduler = None
visualize_att_rollout(model, vis_dataloader, output_dir)

OSError: [Errno 12] Cannot allocate memory

In [None]:
print(model.transformer.layers[0][0].fn.attend)

In [50]:
class attention_hooks():
    def __init__(self, model, attention_layer_name='attend'):
        self.model = model
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)

        self.attentions = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def __call__(self):
        self.attentions = []
        return self.attentions

In [56]:
attn_hooks = attention_hooks(model)
for x, _, idxs in vis_dataloader:
    model(x.to(device))
    patient_ids = extract_patient_ids(vis_dataloader.dataset, idxs)
    # visualize stored attention weights for each image
    print(attn_hooks())
    for i in range(len(x)): vis_attn(x, patient_ids, idxs, attn_hooks(), args, i)


OSError: [Errno 12] Cannot allocate memory

In [55]:
torch.cuda.empty_cache()