In [None]:
import os
import torch
import torch.nn as nn
import torch.cuda as cuda
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils

import numpy as np
from PIL import Image
import imageio
import pickle
import matplotlib.pyplot as plt
%matplotlib inline

from discriminator import Discriminator
from generator import Generator
from trainer import Trainer
from config import Config
import utils

In [None]:
# import cv2
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Fix the seed of every RNG this notebook touches (NumPy, PyTorch CPU, and all
# CUDA devices) so random operations are reproducible across runs.
SEED = 42069

for _seed_rng in (np.random.seed, torch.manual_seed, cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

In [None]:
# Single run configuration object; later cells read model sizes, checkpoint
# paths, batch size, and the target device from it.
config = Config()

In [None]:
# Epoch of the saved weights restored when config.pretrained is True.
CHECKPOINT_EPOCH = 1499


def _load_pretrained(model, name, epoch=CHECKPOINT_EPOCH):
    """Load ``config.checkpoint_path + 'models/<name>_<epoch>.pth'`` into ``model``.

    ``map_location=config.device`` remaps tensors that were saved on another
    device (e.g. a GPU training machine) onto the device this notebook uses,
    so the checkpoint also loads on a CPU-only machine.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        name: file stem, e.g. ``'generator'`` or ``'discriminator'``.
        epoch: epoch number embedded in the checkpoint file name.
    """
    path = config.checkpoint_path + 'models/{}_{}.pth'.format(name, epoch)
    model.load_state_dict(torch.load(path, map_location=config.device))


generator = Generator(z_dim=config.z_dim, num_classes=config.num_classes,
                      base_width=config.base_width,
                      base_filters=config.base_filters,
                      use_attention=config.use_attention)

discriminator = Discriminator(config.num_classes,
                              base_filters=config.base_filters,
                              use_attention=config.use_attention,
                              use_dropout=config.use_dropout)

# Weights are loaded into the bare (unwrapped) modules, before DataParallel.
if config.pretrained:
    _load_pretrained(generator, 'generator')
    _load_pretrained(discriminator, 'discriminator')

generator = generator.to(config.device)
discriminator = discriminator.to(config.device)

if config.data_parallel:
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)

In [None]:
# Get the dataloaders.
train_dataloader, test_dataloader = utils.get_dataloaders(
    config.train_root,
    config.test_root,
    batch_size=config.batch_size,
)  # train/test loaders built from the configured dataset roots

In [None]:
# Pull one batch from the test set and run it through the discriminator with
# visualize=True, which yields a score and an attention map for the batch.
# eval() switches off dropout (config.use_dropout) so the maps do not depend on
# random dropout masks.
discriminator.eval()
X, y = next(iter(test_dataloader))

# as_tensor accepts tensors or lists and avoids the legacy
# FloatTensor/LongTensor constructors.
X = torch.as_tensor(X, dtype=torch.float, device=config.device)
y = torch.as_tensor(y, dtype=torch.long, device=config.device)
batch_size = int(X.size(0))

# Inspection only: don't build an autograd graph for the (large) attention maps.
with torch.no_grad():
    real_score, attn_map = discriminator(X, y, visualize=True)

In [None]:
attn_map.shape  # expected (B, 64*64, 32*32): a 32x32 key map per 64x64 query position

In [None]:
def deprocess_img(img_tensor):
    """Convert a normalized channels-first image tensor to a displayable uint8 array.

    Undoes an ``x * 0.5 + 0.5`` normalization, clips to [0, 1], and scales to
    [0, 255] (values are truncated, not rounded, by the uint8 cast).

    Args:
        img_tensor: 3-D tensor laid out (C, H, W). Presumably normalized to
            [-1, 1] with mean=std=0.5 — TODO confirm against utils.get_dataloaders.
            May live on any device and may require grad.

    Returns:
        np.ndarray of dtype uint8 with shape (H, W, C).
    """
    # detach(): .numpy() raises on tensors that require grad (e.g. generator output).
    img = img_tensor.detach().cpu().numpy()

    img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow

    img = np.clip(img * 0.5 + 0.5, 0.0, 1.0)

    return (img * 255.0).astype(np.uint8)

In [None]:
# Attention on *generated* images. The real-batch `attn_map` describes X, not
# fake images, so generate a fake batch conditioned on the same labels y and
# compute its own attention map.
# NOTE(review): assumes the conditional generator is called as generator(z, y)
# with z of shape (batch_size, config.z_dim) — confirm against generator.py.
generator.eval()
discriminator.eval()
with torch.no_grad():
    z = torch.randn(batch_size, config.z_dim, device=config.device)
    fake_X = generator(z, y)
    fake_score, fake_attn_map = discriminator(fake_X, y, visualize=True)

# One device->host copy for the whole batch instead of one per image.
fake_attention_all = fake_attn_map.detach().cpu()
n_query, n_key = fake_attention_all.shape[1:]
q_side, k_side = int(round(n_query ** 0.5)), int(round(n_key ** 0.5))
assert q_side ** 2 == n_query and k_side ** 2 == n_key, 'attention grids must be square'

query_locations = [[16, 16], [32, 32], [50, 50]]

for i in range(batch_size):
    image = deprocess_img(fake_X[i])
    attention = fake_attention_all[i].view(q_side, q_side, k_side, k_side).numpy()

    fig, axes = plt.subplots(1, len(query_locations), figsize=(10, 10), squeeze=False)
    for ax, (row, col) in zip(axes[0], query_locations):
        # Upsample this query's key map to the image size (PIL takes (width, height)).
        amap_i = Image.fromarray(attention[row, col]).resize((image.shape[1], image.shape[0]))
        amap = np.array(amap_i)
        amap_i.close()

        ax.imshow(image)
        ax.imshow(amap, interpolation="bicubic", cmap='gray', alpha=0.6)
        ax.set_title('query ({}, {})'.format(row, col))
    plt.show()
    plt.close(fig)

In [None]:
# Overlay the discriminator's attention for a few query positions on each real image.
# One device->host copy for the whole batch instead of one per image.
attention_all = attn_map.detach().cpu()
n_query, n_key = attention_all.shape[1:]
q_side, k_side = int(round(n_query ** 0.5)), int(round(n_key ** 0.5))
assert q_side ** 2 == n_query and k_side ** 2 == n_key, 'attention grids must be square'

query_locations = [[16, 16], [32, 32], [50, 50]]

for i in range(batch_size):
    image = deprocess_img(X[i])
    attention = attention_all[i].view(q_side, q_side, k_side, k_side).numpy()

    fig, axes = plt.subplots(1, len(query_locations), figsize=(10, 10), squeeze=False)
    for ax, (row, col) in zip(axes[0], query_locations):
        # Upsample this query's key map to the image size (PIL takes (width, height)).
        amap_i = Image.fromarray(attention[row, col]).resize((image.shape[1], image.shape[0]))
        amap = np.array(amap_i)
        amap_i.close()

        ax.imshow(image)
        ax.imshow(amap, interpolation="bicubic", cmap='gray', alpha=0.6)
        ax.set_title('query ({}, {})'.format(row, col))
    plt.show()
    plt.close(fig)

In [None]:
# Invert the dataset's class-name -> index mapping so label indices (y) can be
# printed as class names.
idx_to_class = dict(zip(test_dataloader.dataset.class_to_idx.values(),
                        test_dataloader.dataset.class_to_idx.keys()))

In [None]:
# Attention overlay on each real image, plus its class name and a marker on the
# query position whose key map is shown.
# One device->host copy for the whole batch instead of one per image.
attention_all = attn_map.detach().cpu()
n_query, n_key = attention_all.shape[1:]
q_side, k_side = int(round(n_query ** 0.5)), int(round(n_key ** 0.5))
assert q_side ** 2 == n_query and k_side ** 2 == n_key, 'attention grids must be square'

query_locations = [[16, 16], [32, 32], [50, 50]]
labels = y.tolist()

for i in range(batch_size):
    image = deprocess_img(X[i])
    attention = attention_all[i].view(q_side, q_side, k_side, k_side).numpy()
    # Pixels per query-grid cell, e.g. 4 for a 256-px image over a 64x64 query grid.
    row_scale = image.shape[0] / q_side
    col_scale = image.shape[1] / q_side

    print(idx_to_class[labels[i]])
    fig, axes = plt.subplots(1, len(query_locations), figsize=(10, 10), squeeze=False)
    for ax, (row, col) in zip(axes[0], query_locations):
        # Upsample this query's key map to the image size (PIL takes (width, height)).
        amap_i = Image.fromarray(attention[row, col]).resize((image.shape[1], image.shape[0]))
        amap = np.array(amap_i)
        amap_i.close()

        ax.imshow(image)
        ax.imshow(amap, interpolation="bicubic", cmap='Greens', alpha=0.6)
        # scatter takes (x, y) = (column, row) in image pixels.
        ax.scatter(col * col_scale, row * row_scale)
        ax.set_title('query ({}, {})'.format(row, col))
    plt.show()
    plt.close(fig)