import matplotlib
matplotlib.use('agg')  # headless backend; must be selected before importing pyplot

import os
import os.path
import argparse
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.autograd import Variable
from torch.utils.data.sampler import Sampler
import torch.backends.cudnn as cudnn
from tqdm import tqdm, trange
import numpy as np
import pdb
import torch.nn.functional as F
#import visdom
import csv

from model.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from model.wide_resnet import WideResNet
from utils.lr_scheduler import ReduceLROnPlateau, ExponentialLR, StepLR, MultiStepLR
from utils.validation_set_split import validation_split
from utils.transforms import HolePunch

import matplotlib.pyplot as plt
import seaborn as sns

model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wideresnet']
dataset_names = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', metavar='D', default='cifar10',
                    choices=dataset_names)
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
# BUGFIX: help text previously read "enables CUDA training" although the flag
# disables it.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# BUGFIX: help text previously claimed "(default: 1)" but the default is 0.
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: 0)')
parser.add_argument('--sample', type=int, default=0, metavar='S',
                    help='sample to plot')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

cudnn.benchmark = True  # autotune conv kernels; faster for fixed-shape inputs
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution, padding 1, no bias (a BatchNorm layer always follows)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """ResNet basic block: conv-BN-ReLU-conv-BN plus an identity/projection
    shortcut, with a final ReLU after the addition."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Use a 1x1 projection shortcut whenever the spatial size or channel
        # count changes; otherwise the shortcut is the identity.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet whose forward pass returns the intermediate feature
    maps of conv blocks 1-4 plus the average-pooled features (NOT the class
    logits) — this script uses it to inspect feature activations.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Kept even though forward() never returns logits, so that checkpoints
        # trained with a classifier head still load via load_state_dict.
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block in a layer may downsample; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = F.avg_pool2d(out4, 4)
        # BUGFIX: the original also ran the linear head here and discarded the
        # result; that dead computation is removed. Return value is unchanged.
        return [out1, out2, out3, out4, out5]


# Image Preprocessing: per-channel mean/std normalization (constants are
# dataset pixel statistics rescaled from [0, 255] to [0, 1]).
if args.dataset == 'svhn':
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
        std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])
# Dataset selection: test split only — this script just probes activations.
if args.dataset == 'cifar10':
    num_classes = 10
    test_dataset = datasets.CIFAR10(root='../../data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    test_dataset = datasets.CIFAR100(root='../../data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    test_dataset = datasets.SVHN(root='../../data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)

# Data Loader (Input Pipeline)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=2)


def test(loader, sample=7):
    """Collect sorted per-channel peak activations for one image.

    Runs the global `cnn` on the FIRST batch of `loader` only, takes the
    spatial max of every channel in each of the 5 returned feature maps,
    keeps row `sample` of the batch, and returns a list of 5 numpy arrays,
    each sorted in descending order of magnitude.
    """
    cnn.eval()  # BN uses running mean/var in eval mode
    max_activations = [[], [], [], [], []]
    # BUGFIX: replaces the removed Variable(..., volatile=True) API with
    # torch.no_grad() to disable autograd during inference.
    with torch.no_grad():
        for images, labels in loader:
            if args.dataset == 'svhn':
                # NOTE(review): shifts SVHN labels 1..10 -> 0..9; modern
                # torchvision already returns 0-based labels — confirm the
                # installed version. (labels are unused by this analysis.)
                labels = labels.type_as(torch.LongTensor()).view(-1) - 1
            images = images.cuda()
            feature_maps = cnn(images)
            for i, feature_map in enumerate(feature_maps):
                # BUGFIX: the original chained .max(2)[0].max(3)[0], which is
                # out of range once the first max drops a dim (keepdim=False);
                # reduce dim 3 first, then dim 2: spatial max per channel.
                peak = feature_map.max(dim=3)[0].max(dim=2)[0]
                # shape = [batch, n_feature_maps]; keep only row `sample`.
                max_activations[i].append(
                    peak.view(feature_map.size(0), -1)[sample:sample + 1])
            break  # only the first batch is needed

    for i in range(len(max_activations)):
        max_activations[i] = torch.cat(max_activations[i])  # [batch, n_feature_maps]
        max_activations[i] = torch.sort(max_activations[i], descending=True, dim=1)[0]
        max_activations[i] = torch.mean(max_activations[i], dim=0).view(-1)  # [n_feature_maps]
        max_activations[i] = torch.sort(max_activations[i], descending=True, dim=0)[0]
        max_activations[i] = max_activations[i].data.cpu().numpy()
    cnn.train()
    return max_activations


cnn = ResNet(BasicBlock, [2, 2, 2, 2], 10)
cnn = torch.nn.DataParallel(cnn)
cnn.load_state_dict(torch.load('checkpoints/baseline_86.pt'))
cnn = cnn.cuda()

# First model: plot sorted activation magnitudes for conv blocks 2-4
# (feature-map indices 1-3).
max_activations = test(test_loader, args.sample)
for i in range(1, 4):
    plt.figure(str(i))
    # BUGFIX: plt.bar's first argument was renamed from `left` to `x` in
    # matplotlib 2.2; passing it positionally works on both old and new APIs.
    ax = plt.bar(range(len(max_activations[i])), max_activations[i],
                 label='Cutout', color=sns.color_palette('deep')[0],
                 width=1.0, alpha=0.8)

# Second model: overlay on the same figures, then finalize and save each.
cnn.load_state_dict(torch.load('checkpoints/baseline_84.pt'))
cnn = cnn.cuda()
max_activations = test(test_loader, args.sample)
for i in range(1, 4):
    plt.figure(str(i))
    ax = plt.bar(range(len(max_activations[i])), max_activations[i],
                 label='Baseline', color=sns.color_palette('deep')[2],
                 width=1.0, alpha=0.8)
    plt.tick_params(labelsize=12)
    # BUGFIX: the `xmin` keyword alias was removed from xlim; a single
    # positional argument sets the left limit on all matplotlib versions.
    plt.xlim(0)
    plt.ylabel("Magnitude of activation", fontsize=16)
    plt.xlabel("Feature activations (sorted by magnitude)", fontsize=16)
    plt.legend(fontsize=14)
    plt.savefig('conv_block_' + str(i + 1) + '_activations_' + str(args.sample) + '.pdf')
plt.show()  # no-op under the 'agg' backend; kept for interactive runs