In [4]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from tqdm import tqdm

from utils.config import opt
from data.dataset import TestDataset

# 1. Model

In [3]:
class ConvNet(nn.Module):
    """Forward feature extractor that records selected intermediate activations.

    The layer stack is built by `load_vgg16_extractor()`. While running the
    stack, `forward` also records what is needed to invert every max-pooling
    step later on (the pooling indices and the spatial size of the pooling
    input).
    """

    def __init__(self):
        super().__init__()
        # Positions inside `self.extractor` whose outputs are kept by `forward`.
        self.extraction_point = [15, 22, 29]
        self.extractor = load_vgg16_extractor()

    def forward(self, x):
        """Run `x` through the stack, collecting activations and pooling metadata.

        Args:
            x: input tensor fed to the first layer of `self.extractor`.

        Returns:
            A tuple `(features, indices, in_shapes)` of dicts keyed by layer
            position:
              * features  - output of each layer listed in `extraction_point`;
              * indices   - index tensor returned by each `nn.MaxPool2d` layer;
              * in_shapes - `x.shape[2:]` of the tensor entering each
                `nn.MaxPool2d` layer.
        """
        features = {}
        indices = {}
        in_shapes = {}

        for position, module in enumerate(self.extractor):
            wanted = position in self.extraction_point

            if not wanted and isinstance(module, nn.MaxPool2d):
                # Remember the pre-pooling size, then keep the indices the
                # pooling layer hands back next to its output.
                in_shapes[position] = x.shape[2:]
                x, pool_idx = module(x)
                indices[position] = pool_idx
                continue

            x = module(x)
            if wanted:
                features[position] = x

        return features, indices, in_shapes

def load_vgg16_extractor():
    """Build the frozen forward layer stack from torchvision's VGG16.

    Takes `models.vgg16().features` without its final module, swaps every
    `nn.MaxPool2d` for a 2x2 / stride-2 pooling layer that also returns the
    pooling indices, and disables gradients on every parameter.

    Returns:
        nn.Sequential containing the resulting layers in their original order.
    """
    backbone = models.vgg16()
    layers = []
    # [:-1] leaves out the last module of the `features` stack.
    for layer in list(backbone.features)[:-1]:
        if isinstance(layer, nn.MaxPool2d):
            # return_indices=True makes the layer yield (output, indices).
            layer = nn.MaxPool2d(2, stride=2, return_indices=True)
        for param in layer.parameters():
            param.requires_grad = False
        layers.append(layer)
    return nn.Sequential(*layers)

class DeconvNet(nn.Module):
    """Runs recorded activations backwards through the mirrored layer stack.

    The stack comes from `load_vgg16_extractor_reverse()`; it is indexed the
    same way as the forward stack but is applied from the extraction layer
    down to layer 0.
    """

    def __init__(self):
        super().__init__()
        self.extractor = load_vgg16_extractor_reverse()

    def forward(self, features, indices, in_shapes):
        """Project each recorded activation back through the mirrored layers.

        Args:
            features: dict mapping a layer position to the activation
                recorded at that position.
            indices: dict mapping the position of each unpooling layer to the
                index tensor passed to it.
            in_shapes: dict mapping the position of each unpooling layer to
                the `output_size` passed to it.

        Returns:
            Dict with the same keys as `features`, holding the tensor obtained
            after applying layers `key, key - 1, ..., 0` in that order.
        """
        remapped = {}
        for key, feature in features.items():
            # Only the layers up to (and including) the extraction point are used.
            stack = list(self.extractor[:key + 1])
            for position in range(len(stack) - 1, -1, -1):
                layer = stack[position]
                if not isinstance(layer, nn.MaxUnpool2d):
                    feature = layer(feature)
                    continue
                # Unpooling needs the indices and input size recorded at the
                # same position during the forward pass.
                feature = layer(feature, indices[position],
                                output_size=in_shapes[position])
            remapped[key] = feature
        return remapped

def load_vgg16_extractor_reverse():
    """Build the frozen, mirrored counterpart of the forward layer stack.

    Starts from `models.vgg16().features` without its final module and
    replaces, position by position:
      * every `nn.MaxPool2d` with a 2x2 / stride-2 `nn.MaxUnpool2d`;
      * every `nn.Conv2d` with a bias-free 3x3 `nn.ConvTranspose2d` whose
        input/output channel counts are swapped.
    All other layers are kept, and gradients are disabled on every parameter.

    Returns:
        nn.Sequential with one layer per position of the original stack.
    """
    def mirror(layer):
        # Map a forward layer to the layer used when walking backwards.
        if isinstance(layer, nn.MaxPool2d):
            return nn.MaxUnpool2d(2, stride=2)
        if isinstance(layer, nn.Conv2d):
            return nn.ConvTranspose2d(layer.out_channels,
                                      layer.in_channels,
                                      kernel_size=3, stride=1,
                                      padding=1, bias=False)
        return layer

    backbone = models.vgg16()
    mirrored = [mirror(layer) for layer in list(backbone.features)[:-1]]
    for layer in mirrored:
        for param in layer.parameters():
            param.requires_grad = False
    return nn.Sequential(*mirrored)

# 2. Utility functions

In [5]:
def inverse_normalize(img):
    """Undo per-channel mean/std normalization on an image tensor.

    Computes `img * std + mean` channel-wise with the constants below (the
    usual ImageNet statistics). Every image in the batch is processed; the
    previous implementation only touched `img[0]`, leaving the rest of a
    multi-image batch normalized.

    Args:
        img: floating-point tensor shaped (N, 3, H, W) or (3, H, W).

    Returns:
        A new tensor with the same shape as `img`; `img` itself is untouched.
    """
    # Shaped (3, 1, 1) so the constants broadcast over the spatial dimensions
    # and over an optional leading batch dimension.
    mean = img.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = img.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return img * std + mean

def get_most_activated(feature, n_pixel):
    """Keep the `n_pixel` largest activations of every channel, zero the rest.

    The ranking is done independently for each (batch, channel) pair over the
    flattened trailing dimensions. `feature` is modified in place and also
    returned.

    Compared with the earlier version this
      * handles every element of the batch instead of only index 0, and
      * no longer writes through `feature.flatten(...)`, which is a copy for
        non-contiguous inputs and silently dropped the update.

    Args:
        feature: tensor with at least 3 dimensions, e.g. (N, C, H, W).
        n_pixel: number of strongest positions to keep per channel. If it is
            at least the number of positions, nothing is zeroed.

    Returns:
        The same tensor object as `feature`, after masking.
    """
    flat = feature.flatten(start_dim=2)
    # Positions ordered from strongest to weakest activation per channel.
    ranking = flat.argsort(dim=-1, descending=True)
    losers = ranking[:, :, n_pixel:]

    # Mark every position that did not make the cut ...
    drop_mask = torch.zeros_like(flat, dtype=torch.bool)
    drop_mask.scatter_(2, losers, True)
    # ... and zero it directly on `feature`, so the result does not depend on
    # whether `flat` is a view or a copy.
    feature.masked_fill_(drop_mask.reshape(feature.shape), 0)
    return feature

# 3. Load pretrained model & test data

## 3.1. Faster R-CNN

In [6]:
# Read the Faster R-CNN checkpoint. map_location='cpu' makes loading
# independent of the device the tensors were saved from; both networks below
# are built on the CPU anyway.
# NOTE: torch.load unpickles the file - only open checkpoints you trust.
state_dict = torch.load('./frcnn.pth', map_location='cpu')

# Entries whose dotted key contains an 'extractor' component belong to the
# convolutional backbone; ConvNet exposes the same 'extractor.*' names.
extractor_weights = OrderedDict(
    (key, value) for key, value in state_dict.items()
    if 'extractor' in key.split('.')
)
convnet = ConvNet()
convnet.load_state_dict(extractor_weights)

# The transposed convolutions in DeconvNet are created with bias=False, so
# only the '*.weight' entries of the backbone are handed to it.
deconv_weights = OrderedDict(
    (key, value) for key, value in extractor_weights.items()
    if 'weight' in key.split('.')
)
deconvnet = DeconvNet()
deconvnet.load_state_dict(deconv_weights)

<All keys matched successfully>

## 3.2. Test dataset

In [7]:
# Test split configured through the project-wide `opt` object.
test_set = TestDataset(opt)
# One sample per batch, served in dataset order.
test_loader = DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# 4. Feature extraction

In [9]:
# 0-based positions of the test samples that get visualized.
sample_indices = (0, 6, 16, 21, 28)
# Set to an int to keep only that many strongest activations per channel
# before projecting back; None projects every activation.
n_pixel = None

writer = SummaryWriter()
for i, data in tqdm(enumerate(test_loader)):
    # Nothing left to visualize once the last requested sample has passed.
    if i > max(sample_indices):
        break

    if i not in sample_indices:
        continue

    img, _, _, _, _ = data
    # First tile of the grid: the input image with normalization undone.
    outputs = [inverse_normalize(img)]
    features, indices, in_shapes = convnet(img)

    if n_pixel is not None:
        for key in features.keys():
            features[key] = get_most_activated(features[key], n_pixel)

    features = deconvnet(features, indices, in_shapes)
    # min-max normalization
    for feature in features.values():
        low, high = feature.min(), feature.max()
        # A constant map would give 0 / 0 = NaN; clamp the range so it maps to 0.
        span = (high - low).clamp(min=torch.finfo(feature.dtype).eps)
        outputs.append((feature - low) / span)
    # concatenate all features
    outputs = torch.cat(outputs)
    outputs = torchvision.utils.make_grid(outputs, nrow=2)
    scope = 'all pixel' if n_pixel is None else f'top {n_pixel} pixel'
    writer.add_image(f'frcnn {i + 1}-th feature {scope}', outputs)

# Flush pending events to disk and release the event file.
writer.close()

29it [01:20,  2.77s/it]
