In [1]:
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from network.studentNet import CNN_RIS
# Restrict the CUDA devices visible to this process to device index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Device string used for tensors below: GPU when one is available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class GradCam:
    """Grad-CAM overlay generator for a model evaluated child-by-child.

    The model's direct children are applied in ``named_children()`` order up to
    and including ``target_layer``. That layer's output, together with the
    gradient flowing back into it, is used to build the activation map.

    NOTE(review): this assumes the model's forward pass up to ``target_layer``
    is exactly "apply each child in registration order" -- confirm against the
    network definition.
    """

    def __init__(self, model, target_layer='dense3'):
        """
        Args:
            model: ``torch.nn.Module`` to explain; it is switched to eval mode.
            target_layer: name of the direct child module whose output is
                visualised. Defaults to ``'dense3'``.
        """
        self.model = model.eval()
        self.target_layer = target_layer
        self.feature = None
        self.gradient = None

    def save_gradient(self, grad):
        """Tensor hook: store the gradient w.r.t. the target layer's output."""
        self.gradient = grad

    def __call__(self, x):
        """Compute one heat-map overlay per image in ``x``.

        Args:
            x: batch tensor; each ``x[i]`` is fed to the model on its own and
                the last two dimensions are used as height and width.
                Assumes 3-channel RGB images (the overlay is added to a
                3-channel colormap) -- TODO confirm for other callers.

        Returns:
            Tensor stacking one overlay per input image, each produced by
            ``transforms.ToTensor`` from a uint8 RGB array.

        Raises:
            ValueError: if the model has no direct child named
                ``self.target_layer``.
        """
        # cv2.resize takes its target size as (width, height).
        image_size = (x.size(-1), x.size(-2))
        heat_maps = []
        for i in range(x.size(0)):
            # Min-max rescale the input image to [0, 1] so it can be displayed.
            img = x[i].detach().cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)

            # Forward through the children until the target layer is reached.
            feature = x[i].unsqueeze(0)
            for name, module in self.model.named_children():
                feature = module(feature)
                if name == self.target_layer:
                    feature.register_hook(self.save_gradient)
                    self.feature = feature
                    break
            else:
                # Without this check the code below would fail on a missing
                # gradient or silently reuse one from an earlier call.
                raise ValueError(
                    "model has no direct child module named %r" % (self.target_layer,)
                )

            # Scalar score to differentiate: the largest pooled sigmoid activation.
            # Assumes the target layer outputs a feature map of at least 5x5 -- TODO confirm.
            classes = torch.sigmoid(feature)
            classes = F.avg_pool2d(classes, kernel_size=5).view(classes.size(0), -1)
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            one_hot.backward()

            # Channel weights are the spatial mean of the gradient; the map is
            # the ReLU of the weighted channel sum.
            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            mask = F.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.detach().cpu().numpy(), image_size)
            mask = mask - np.min(mask)
            if np.max(mask) != 0:
                mask = mask / np.max(mask)

            # applyColorMap returns BGR while the input image is RGB. Convert the
            # colormap to RGB *before* blending so the photo's red and blue
            # channels are not swapped in the overlay.
            heat_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
            heat_map = np.float32(cv2.cvtColor(heat_map, cv2.COLOR_BGR2RGB))
            cam = heat_map + np.float32(np.uint8(img.transpose((1, 2, 0)) * 255))
            cam = cam - np.min(cam)
            if np.max(cam) != 0:
                cam = cam / np.max(cam)
            heat_maps.append(transforms.ToTensor()(np.uint8(255 * cam)))
        heat_maps = torch.stack(heat_maps)
        return heat_maps

def load_pretrained_model(model, pretrained_dict):
    """Copy matching entries of ``pretrained_dict`` into ``model``.

    Keys of ``pretrained_dict`` that are absent from the model's own state
    dict are ignored; entries of the model that have no counterpart in
    ``pretrained_dict`` keep their current values.

    Args:
        model: module exposing ``state_dict()`` / ``load_state_dict()``.
        pretrained_dict: mapping from parameter name to value.
    """
    merged_state = model.state_dict()
    for key, value in pretrained_dict.items():
        # Only overwrite entries the model actually has.
        if key in merged_state:
            merged_state[key] = value
    model.load_state_dict(merged_state)
    
NUM_CLASSES = 7
# Place the student network on `device` rather than calling .cuda()
# unconditionally, so the notebook also runs on a machine without a GPU.
snet = CNN_RIS(num_classes=NUM_CLASSES).to(device)

# The cells below save their overlays back into 'picture/' with these suffixes.
# Skipping them keeps a second run from treating generated overlays as inputs.
GENERATED_SUFFIXES = ('_0000.jpg', '_FewShot.jpg', '_0890.jpg')
files = sorted(
    name for name in os.listdir('picture/')
    if not name.endswith(GENERATED_SUFFIXES)
)

In [2]:
# Grad-CAM overlays for the student checkpoint stored under
# results/RAF_MultiTeacher_OurDiversity_0.0_0.0_KD3; saved as '<image>_0000.jpg'.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
scheckpoint = torch.load('results/RAF_MultiTeacher_OurDiversity_0.0_0.0_KD3/Student_Test_model.t7',
                         map_location=device)
load_pretrained_model(snet, scheckpoint['snet'])
snet.to(device)

# Loop-invariant objects are built once instead of once per image.
to_tensor = transforms.Compose([
    transforms.Resize((44, 44)),
    transforms.ToTensor()
])
normalize = transforms.Normalize([0.59003043, 0.4573948, 0.40749523], [0.2465465, 0.22635746, 0.22564183])
grad_cam = GradCam(snet)

for filename in files:
    IMAGE_NAME = os.path.join('picture/', filename)

    # Close the file promptly and force 3 channels: the normalisation constants
    # are per-RGB-channel, so grayscale or RGBA inputs would otherwise fail.
    with Image.open(IMAGE_NAME) as pil_img:
        torch_img = to_tensor(pil_img.convert('RGB')).to(device)
    test_image = normalize(torch_img)[None]

    feature_image = grad_cam(test_image).squeeze(dim=0)
    feature_image = transforms.ToPILImage()(feature_image)

    # Plain forward pass; its outputs are not used for the saved overlay and are
    # kept only so they can be inspected after the cell runs.
    with torch.no_grad():
        rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(test_image)

    # splitext strips only the final extension, so names containing extra dots
    # (e.g. 'a.1.jpg' and 'a.2.jpg') no longer collapse onto the same output file.
    feature_image.save(os.path.splitext(IMAGE_NAME)[0] + '_0000.jpg', quality=100)

In [3]:
# Grad-CAM overlays for the student checkpoint stored under
# results/RAF_MultiTeacher_Few-Shot_KD3; saved as '<image>_FewShot.jpg'.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
scheckpoint = torch.load('results/RAF_MultiTeacher_Few-Shot_KD3/Student_Test_model.t7',
                         map_location=device)
load_pretrained_model(snet, scheckpoint['snet'])
snet.to(device)

# Loop-invariant objects are built once instead of once per image.
to_tensor = transforms.Compose([
    transforms.Resize((44, 44)),
    transforms.ToTensor()
])
normalize = transforms.Normalize([0.59003043, 0.4573948, 0.40749523], [0.2465465, 0.22635746, 0.22564183])
grad_cam = GradCam(snet)

for filename in files:
    IMAGE_NAME = os.path.join('picture/', filename)

    # Close the file promptly and force 3 channels: the normalisation constants
    # are per-RGB-channel, so grayscale or RGBA inputs would otherwise fail.
    with Image.open(IMAGE_NAME) as pil_img:
        torch_img = to_tensor(pil_img.convert('RGB')).to(device)
    test_image = normalize(torch_img)[None]

    feature_image = grad_cam(test_image).squeeze(dim=0)
    feature_image = transforms.ToPILImage()(feature_image)

    # Plain forward pass; its outputs are not used for the saved overlay and are
    # kept only so they can be inspected after the cell runs.
    with torch.no_grad():
        rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(test_image)

    # splitext strips only the final extension, so names containing extra dots
    # (e.g. 'a.1.jpg' and 'a.2.jpg') no longer collapse onto the same output file.
    feature_image.save(os.path.splitext(IMAGE_NAME)[0] + '_FewShot.jpg', quality=100)

In [4]:
# Grad-CAM overlays for the student checkpoint stored under
# results/RAF_MultiTeacher_OurDiversity_0.8_9.0_KD3; saved as '<image>_0890.jpg'.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
scheckpoint = torch.load('results/RAF_MultiTeacher_OurDiversity_0.8_9.0_KD3/Student_Test_model.t7',
                         map_location=device)
load_pretrained_model(snet, scheckpoint['snet'])
snet.to(device)

# Loop-invariant objects are built once instead of once per image.
to_tensor = transforms.Compose([
    transforms.Resize((44, 44)),
    transforms.ToTensor()
])
normalize = transforms.Normalize([0.59003043, 0.4573948, 0.40749523], [0.2465465, 0.22635746, 0.22564183])
grad_cam = GradCam(snet)

for filename in files:
    IMAGE_NAME = os.path.join('picture/', filename)

    # Close the file promptly and force 3 channels: the normalisation constants
    # are per-RGB-channel, so grayscale or RGBA inputs would otherwise fail.
    with Image.open(IMAGE_NAME) as pil_img:
        torch_img = to_tensor(pil_img.convert('RGB')).to(device)
    test_image = normalize(torch_img)[None]

    feature_image = grad_cam(test_image).squeeze(dim=0)
    feature_image = transforms.ToPILImage()(feature_image)

    # Plain forward pass; its outputs are not used for the saved overlay and are
    # kept only so they can be inspected after the cell runs.
    with torch.no_grad():
        rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(test_image)

    # splitext strips only the final extension, so names containing extra dots
    # (e.g. 'a.1.jpg' and 'a.2.jpg') no longer collapse onto the same output file.
    feature_image.save(os.path.splitext(IMAGE_NAME)[0] + '_0890.jpg', quality=100)