In [2]:
from pdb import Restart
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
import numpy as np

from captum.attr import LayerGradCam

sys.path.append('../../../dre')
from datasets import make_dataset
from mixuploss import MixupLoss
from networks import ResNet

In [3]:
# Pin this process to GPU 0. This has to happen before CUDA is first
# initialised, which is why it lives in the earliest config cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

# Fall back to the CPU when no GPU is visible.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Baseline model: project ResNet wrapped in DataParallel. The checkpoint is
# presumably saved from a DataParallel-wrapped model (keys prefixed 'module.'),
# so the wrapper is applied before loading.
net_base = nn.DataParallel(ResNet())
# Swap in a 10-way classifier head before loading so its shape matches the
# head stored in the checkpoint (load_state_dict is strict by default).
net_base.module.network.fc = nn.Linear(net_base.module.network.fc.in_features, 10)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# without it torch.load tries to restore tensors onto their original CUDA device.
net_base.load_state_dict(torch.load('../../ckpts/baseline_model.pth', map_location=device))
net_base = net_base.to(device)
net_base.eval();  # trailing ';' suppresses the full model repr in the cell output

In [None]:
# Our (DRE) model: same architecture and loading procedure as the baseline.
# The checkpoint is presumably saved from a DataParallel-wrapped model
# (keys prefixed 'module.'), so the wrapper is applied before loading.
net_ours = nn.DataParallel(ResNet())
# Swap in a 10-way classifier head before loading so its shape matches the
# head stored in the checkpoint (load_state_dict is strict by default).
net_ours.module.network.fc = nn.Linear(net_ours.module.network.fc.in_features, 10)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# without it torch.load tries to restore tensors onto their original CUDA device.
net_ours.load_state_dict(torch.load('../../ckpts/dre_model.pth', map_location=device))
net_ours = net_ours.to(device)
net_ours.eval();  # trailing ';' suppresses the full model repr in the cell output

In [82]:
# Model-input preprocessing for explanation. It must be deterministic:
# Grad-CAM maps are computed on `transform` outputs but overlaid on
# `transform_orig` outputs, so both pipelines need identical geometry.
# Random crops/flips/jitter here would misalign every overlay and make the
# explained predictions change from run to run.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics, as used by the model's input pipeline.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Display-only copy: same geometry as `transform`, without normalization,
# so it can be passed straight to imshow.
transform_orig = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [83]:
# Each TerraIncognita location is opened twice: once with `transform`
# (network input) and once with `transform_orig` (display image).
dset38, dset43, dset46, dset100 = (
    torchvision.datasets.ImageFolder(root, transform=transform)
    for root in (
        '../../../data/terra_incognita/location_38/',
        '../../../data/terra_incognita/location_43/',
        '../../../data/terra_incognita/location_46/',
        '../../../data/terra_incognita/location_100/',
    )
)

dset38_orig, dset43_orig, dset46_orig, dset100_orig = (
    torchvision.datasets.ImageFolder(root, transform=transform_orig)
    for root in (
        '../../../data/terra_incognita/location_38/',
        '../../../data/terra_incognita/location_43/',
        '../../../data/terra_incognita/location_46/',
        '../../../data/terra_incognita/location_100/',
    )
)

In [84]:
# Class labels as a sorted array, matching the index order torchvision's
# ImageFolder assigns (sorted folder names).
class_names = np.array(sorted(dset38.classes))
class_names

array(['bird', 'bobcat', 'cat', 'coyote', 'dog', 'empty', 'opossum',
       'rabbit', 'raccoon', 'squirrel'], dtype='<U8')

In [85]:
bs = 16
imgsz = 224

# Unshuffled loaders: a location's loader and its *_orig twin read the same
# folder in the same order, so their batches line up sample by sample.
loader38, loader43, loader46, loader100 = (
    torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=False, num_workers=8)
    for ds in (dset38, dset43, dset46, dset100)
)

loader38_orig, loader43_orig, loader46_orig, loader100_orig = (
    torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=False, num_workers=8)
    for ds in (dset38_orig, dset43_orig, dset46_orig, dset100_orig)
)

In [87]:
# Grad-CAM attributors taken at the final residual stage (layer4) of each model.
layer_gc_base, layer_gc_ours = (
    LayerGradCam(model, model.module.network.layer4)
    for model in (net_base, net_ours)
)

In [88]:
def attr_scale(x):
    """Min-max rescale ``x`` to the range [0, 1].

    Works for any array-like supporting ``min``/``max`` and arithmetic
    (torch tensors, numpy arrays). A constant input, such as an all-zero
    Grad-CAM map, has zero range; plain min-max scaling would compute 0/0
    and return NaNs. In that case an all-zero array of the same
    shape/type is returned instead.
    """
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        # x - x_min is exactly zero everywhere and keeps x's type and device.
        return x - x_min
    return (x - x_min) / span

def explain(image, image_orig, pred_base, pred_ours, count, imgsz=224, target_env='L38', mode='ood', out_dir='./erm'):
    """Save a side-by-side Grad-CAM comparison of the baseline and our model.

    Draws three panels (the display image, the baseline Grad-CAM overlay
    and our model's Grad-CAM overlay) and writes them to
    ``{out_dir}/{mode}/{target_env}/{class_names[pred_ours]}/{count}_{class_names[pred_base]}.jpeg``,
    creating the directory if needed.

    Args:
        image: network input for one sample; reshaped to (1, 3, imgsz, imgsz).
        image_orig: the same sample as produced by the display transform;
            reshaped to (3, imgsz, imgsz) and shown as-is.
        pred_base: class index (int or one-element tensor) explained for the baseline.
        pred_ours: class index (int or one-element tensor) explained for our model.
        count: running sample index, used in the output file name.
        imgsz: spatial size of the inputs; attribution maps are upsampled to it.
        target_env: environment tag used in the title and output path.
        mode: split tag used in the title and output path.
        out_dir: root directory for the saved figures.

    Reads the notebook globals ``device``, ``layer_gc_base``, ``layer_gc_ours``
    and ``class_names``.
    """
    pred_base = int(pred_base)
    pred_ours = int(pred_ours)

    image = image.reshape(1, 3, imgsz, imgsz).to(device)
    # The display copy never goes through a network, so convert it once on the CPU.
    image_np = image_orig.reshape(3, imgsz, imgsz).detach().cpu().numpy().transpose(1, 2, 0)

    attr_base = layer_gc_base.attribute(image, target=pred_base)
    attr_ours = layer_gc_ours.attribute(image, target=pred_ours)

    # Upsample the coarse layer4 maps to input resolution, then rescale to [0, 1].
    attribution_base = attr_scale(F.interpolate(attr_base, size=imgsz, mode='bilinear').squeeze())
    attribution_ours = attr_scale(F.interpolate(attr_ours, size=imgsz, mode='bilinear').squeeze())

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('jet', 256)
    heatmap_base = cmap(attribution_base.detach().cpu().numpy(), alpha=0.5)
    heatmap_ours = cmap(attribution_ours.detach().cpu().numpy(), alpha=0.5)

    fig, axes = plt.subplots(1, 3, figsize=(4, 2), dpi=200)
    panels = [
        (str(mode) + ' ' + str(target_env), None),
        ('baseline', heatmap_base),
        ('ours', heatmap_ours),
    ]
    for axis, (title, heatmap) in zip(axes, panels):
        axis.imshow(image_np)
        if heatmap is not None:
            axis.imshow(heatmap)  # semi-transparent overlay on the display image
        axis.set_title(title)
        axis.axis('off')

    # Folder = our model's predicted class. File name = running index plus the
    # baseline's prediction, so baseline mistakes are visible from the file
    # listing. The original code repeated pred_ours here.
    name_ours = str(class_names[pred_ours])
    name_base = str(class_names[pred_base])
    save_dir = os.path.join(out_dir, str(mode), str(target_env), name_ours)
    os.makedirs(save_dir, exist_ok=True)  # savefig does not create missing directories
    fig.savefig(os.path.join(save_dir, '{}_{}.jpeg'.format(count, name_base)))
    plt.close(fig)  # figures are created in a loop; close to avoid accumulating memory

In [None]:
# Visualize a single class: save Grad-CAM comparisons for every image of
# `target_class` that our model classifies correctly.
target_class = 2  # index into class_names -> 'cat'

# The environment tag is kept next to the loaders it describes so the saved
# folder name matches the data actually visualized.
# NOTE(review): this cell previously passed target_env='L38' while iterating
# location 46; confirm whether location 46 (tagged here) or location 38 was intended.
env_tag = 'L46'
loader, loader_orig = loader46, loader46_orig

count = 0
for (inputs, targets), (inputs_orig, targets_orig) in tqdm(zip(loader, loader_orig), total=len(loader)):
    # Both loaders are unshuffled views of the same folder, so batches must line up.
    assert torch.equal(targets, targets_orig), 'model-input and display loaders are out of sync'

    # Predictions only need a forward pass. Grad-CAM inside explain() needs
    # autograd, so it runs outside this context.
    with torch.no_grad():
        inputs_dev = inputs.to(device)
        preds_base = net_base(inputs_dev).argmax(dim=1).cpu()
        preds_ours = net_ours(inputs_dev).argmax(dim=1).cpu()

    for i in range(len(inputs)):
        if targets[i] == target_class and preds_ours[i] == targets[i]:
            explain(inputs[i], inputs_orig[i], preds_base[i], preds_ours[i], count, target_env=env_tag, mode='ood')
        count += 1  # counts every sample seen, not just the saved ones