In [15]:
"""Imports"""
import os, sys, shutil, json
import pandas as pd
import numpy as np
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision

import matplotlib.pyplot as plt
import matplotlib.cm as cm

!pip install opencv-python
import cv2
!pip install nopdb
import nopdb

# Custom
import custom_modules as CXR

Defaulting to user installation because normal site-packages is not writeable


In [92]:
"""Set GPU"""
# Pin the notebook to one physical GPU. PCI_BUS_ID ordering makes the index below match
# `nvidia-smi` numbering. These variables only take effect if set before CUDA is first
# initialised in this kernel.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # <------ Be sure to set the right GPU!!!

# Fall back to CPU so the notebook still runs (slowly) when no GPU is visible, instead of
# failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [72]:
"""Config"""
# Project-wide settings (label names, data root, ...) are read from cfg.json.
# os.path.expanduser only expands a leading '~'; str.replace('~', ...) would also rewrite
# any '~' appearing elsewhere in the path.
with open(os.path.expanduser('/cis/home/zmurphy/code/transformer-radiographs/cfg.json'), 'r') as f:
    cfg = json.load(f)

# Look the label list up once so 'labels' and 'n_labels' cannot drift apart.
chexnet_labels = cfg['labels_chexnet_14_standard']

model_args = {
    'model_state': '/cis/home/zmurphy/code/data/results/final/DeiT_lr0.05_bs16_optSGD_wd1e-05_sch_step_pp3_bp5_trtrain_all.txt_vatest.txt_tfhflip_nllayer_do0.0_1624113464_model.pt',
    'labels_set': 'chexnet-14-standard',
    'labels': chexnet_labels,
    'n_labels': len(chexnet_labels),
    'batch_size': 16,
    'data_dir': cfg['data_dir'],
    'dataset': 'nihcxr14',
    'test_file': 'test_correct.txt',
    'use_parallel': 'y',
    'num_workers': 12,
    'img_size': 224,
    'print_batches': False,
    'scratch_dir': '/export/gaon1/data/zmurphy/transformer-cxr',
    'results_dir': '/export/gaon1/data/zmurphy/transformer-cxr/results/final'
}

# ImageNet per-channel (RGB) mean and std; the plotting cells use them to undo the input
# normalisation before display.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])


In [401]:
"""Get model and set state"""
torch.hub.set_dir('base_model_states')
# DeiT-B backbone; its pretrained weights are replaced below by the fine-tuned checkpoint.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)

# Swap the classification head for a multi-label head (one sigmoid per finding). Size it
# from the backbone's own head and the configured label set rather than hard-coding
# 768 / 14, so it stays consistent with model_args.
model.head = nn.Sequential(
  nn.Linear(in_features=model.head.in_features, out_features=model_args['n_labels']),
  nn.Sigmoid(),
)
model.load_state_dict(torch.load(model_args['model_state'], map_location=torch.device('cpu')))

# Inference only: eval mode disables dropout / stochastic depth.
model.eval()

# Model to GPU (or CPU, depending on `device`)
model = model.to(device)

Using cache found in ./export/gaon1/data/zmurphy/transformer-cxr/results/facebookresearch_deit_main





In [361]:
"""Load data"""
dataset_root = os.path.join(model_args['data_dir'], model_args['dataset'])
test_data = CXR.CXRDataset(images_list=os.path.join(dataset_root, model_args['test_file']),
                            dataset=model_args['dataset'],
                            images_dir=os.path.join(dataset_root, 'images'),
                            image_paths=os.path.join(dataset_root, 'image_paths.txt'),
                            labels_file=os.path.join(dataset_root, 'labels.csv'),
                            labels=model_args['labels'],
                            transform='none',
                            op='test',
                            img_size=model_args['img_size'])

# Shuffle so the visualised batch is a random sample of the test set, but seed the shuffle
# so the same images are drawn on every Restart & Run All.
SHUFFLE_SEED = 0
testLoader = DataLoader(test_data, batch_size=model_args['batch_size'],
                        pin_memory=True, shuffle=True,
                        num_workers=model_args['num_workers'],
                        generator=torch.Generator().manual_seed(SHUFFLE_SEED))


Test set: starting load
Using image path file
Using no transforms
Loaded 271 images


In [402]:
"""Get attention map"""
def predict(x, model):
  """Return ``model(x)`` computed with autograd disabled (no graph is kept)."""
  with torch.no_grad():
    return model(x)

def _attention_rollout(attn_wts_mean):
  """Attention rollout over a stack of head-averaged attention matrices.

  attn_wts_mean: tensor assumed to be (n_layers, n_tokens, n_tokens).
  Returns a list whose entry k is the rollout through layers 0..k: each layer's matrix is
  mixed 50/50 with the identity and the layers are chained by matrix product.
  """
  eye = torch.eye(attn_wts_mean.shape[-1])
  rollout = [0.5*attn_wts_mean[0] + 0.5*eye]
  for layer in range(1, attn_wts_mean.shape[0]):
    rollout.append(torch.matmul(0.5*attn_wts_mean[layer] + 0.5*eye, rollout[-1]))
  return rollout


def get_attention_map(x, model, layer_to_get=-1, cmap=plt.get_cmap('jet', 256), num_prefix_tokens=1):
  """Compute a colourised attention-rollout map for one image.

  Args:
    x: input batch; only sample 0's attention is used, so pass a batch of one
      (e.g. ``img.unsqueeze(0)``). Assumed (1, C, H, W); H and W set the output size.
    model: ViT-style model with ``model.blocks[i].attn`` whose ``forward`` keeps the
      attention weights in a local named ``attn`` (read via nopdb).
    layer_to_get: index of the layer whose cumulative rollout is returned (default: last).
    cmap: matplotlib colormap applied to the [0, 1]-scaled map.
    num_prefix_tokens: leading non-patch tokens to skip (1 = class token; distilled DeiT
      adds a distillation token, so use 2 there).

  Returns:
    (attn_map, yhat): attn_map is an (H, W, 4) RGBA float array from ``cmap``; yhat is the
    model output for ``x``.

  Raises:
    ValueError: if the remaining patch tokens do not form a square grid.
  """
  x = x.to(device)

  # Capture each block's attention weights. One nopdb capture per block, so the model is
  # run once per block.
  # NOTE(review): every block shares the same Attention.forward code; confirm nopdb
  # filters captures by the bound instance, otherwise each capture holds the last block's call.
  attn_wts = []
  for block in model.blocks:
    with nopdb.capture_call(block.attn.forward) as attn_call:
      yhat = predict(x, model)
    attn_wts.append(attn_call.locals['attn'][0])
  # No squeeze here: with a single head, squeeze(1) would drop the head axis and the
  # mean below would then average over tokens instead.
  attn_wts = torch.stack(attn_wts)

  # Average over attention heads (dim 1), then roll out on the CPU.
  attn_wts_mean = attn_wts.mean(dim=1).cpu()
  attn_rollout = _attention_rollout(attn_wts_mean)[layer_to_get]

  # Row 0 is the first (class) token's attention over all tokens; drop the prefix tokens
  # and lay the patch tokens back out on their square grid.
  patch_attn = attn_rollout[0, num_prefix_tokens:]
  n_patches = patch_attn.numel()
  grid_size = int(round(np.sqrt(n_patches)))
  if grid_size * grid_size != n_patches:
    raise ValueError(f'{n_patches} patch tokens do not form a square grid; '
                     f'check num_prefix_tokens={num_prefix_tokens}')
  attn_map = patch_attn.reshape(grid_size, grid_size).numpy()

  # Min-max scale to [0, 1]; skip the division for a constant map (avoids NaNs).
  attn_map = attn_map - attn_map.min()
  peak = attn_map.max()
  if peak > 0:
    attn_map = attn_map / peak

  # Upsample to the input resolution (cv2 dsize is (width, height)) and colourise.
  height, width = x.shape[-2], x.shape[-1]
  attn_map = cmap(cv2.resize(attn_map, dsize=(width, height)), alpha=1)

  return attn_map, yhat

In [None]:
"""Loop through input images"""
target_dir = 'attn_maps'
# Start from an empty output directory so maps from earlier runs don't linger.
if os.path.exists(target_dir):
  shutil.rmtree(target_dir)
os.makedirs(target_dir)

# Only visualise a single batch.
x, y, file = next(iter(testLoader))

for i in range(x.shape[0]):
  file_name = os.path.basename(file[i])

  # Attention map and prediction for this image (passed as a batch of one)
  attn_map, yhat = get_attention_map(x[i].unsqueeze(0), model)

  # Reverse the ImageNet mean & std for display; clamp because imshow expects floats in [0, 1].
  im = (x[i] * std[:, None, None] + mean[:, None, None]).clamp(0, 1)

  # Plot: image on the left, attention map on the right
  fig, ax = plt.subplots(1, 2)
  fig.patch.set_facecolor('white')
  ax[0].imshow(im.numpy().transpose(1, 2, 0))
  ax[1].imshow(attn_map)
  for axis in ax:
    axis.axis('off')

  # Title: file name plus the positive ground-truth labels (iterates the configured label
  # list instead of a hard-coded 14).
  labs = [lab for lab, val in zip(model_args['labels'], y[i]) if val == 1]
  fig.suptitle(file_name + '\n' + (', '.join(labs) if labs else 'No findings'), y=0.9)

  # Save, then show; the inline backend closes the figure after showing it, so figures
  # don't accumulate in memory across the batch.
  fig.tight_layout()
  fig.savefig(os.path.join(target_dir, file_name))
  plt.show()

In [None]:
"""Check ImageNet model"""
from torchvision import transforms
from PIL import Image

# Compare against an ImageNet-pretrained ViT. It is loaded into its own variable so the
# fine-tuned CXR `model` used above is neither silently reused here nor overwritten.
IMAGENET_MODELS = {
  'deit_base': ('facebookresearch/deit:main', 'deit_base_patch16_224'),  # DeiT-B
  'dino_vitb16': ('facebookresearch/dino:main', 'dino_vitb16'),  # DINO, performs best
}
# Distilled DeiT-B ('deit_base_distilled_patch16_224') adds a distillation token, so the
# patch tokens start at [0,2:] instead of [0,1:].
IMAGENET_MODEL = 'deit_base'

torch.hub.set_dir('base_model_states')
hub_repo, hub_entry = IMAGENET_MODELS[IMAGENET_MODEL]
imagenet_model = torch.hub.load(hub_repo, hub_entry, pretrained=True).to(device).eval()

tfms = transforms.Compose([
  transforms.ToTensor(),
  transforms.Resize((224, 224), transforms.functional.InterpolationMode.BILINEAR),
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

file_name = 'zebra2'
# Close the file handle after use; convert('RGB') ensures 3 channels for grayscale/CMYK/RGBA JPEGs.
with Image.open(os.path.join('google_images', file_name + '.jpeg')) as pil_img:
  img = tfms(pil_img.convert('RGB'))

attn_map, yhat = get_attention_map(img.unsqueeze(0), imagenet_model, layer_to_get=-1)

# Reverse the ImageNet normalisation for display; clamp to imshow's [0, 1] float range.
im = (img * std[:, None, None] + mean[:, None, None]).clamp(0, 1)

# Plot: image on the left, attention map on the right
fig, ax = plt.subplots(1, 2)
fig.patch.set_facecolor('white')
ax[0].imshow(im.numpy().transpose(1, 2, 0))
ax[1].imshow(attn_map)
ax[0].axis('off')
ax[1].axis('off')

# Save under a directory named after the model being inspected (created if missing).
fig.tight_layout()
out_dir = IMAGENET_MODEL
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, file_name + '.png'))