In [1]:
import torch
import torchvision
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms, datasets
import time
import os
import copy
import gc
import torch.optim as optim
from torch.optim import lr_scheduler
import PIL
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device once for the whole notebook: the first CUDA GPU
# when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class EmbeddingGenerator(nn.Module):
    """Turns a single concept map into a unit-length 32-dimensional embedding."""

    def __init__(self, activation_features_channels=200, activation_features_size=(14, 14)):
        """
        Parameters
        ----------
        activation_features_channels : int, optional
            Not used by this module; kept so existing callers keep working.
            The default is 200.
        activation_features_size : tuple of int, optional
            (height, width) of one concept map. Their product is the input
            width of the linear layer. The default is (14, 14).

        Returns
        -------
        None.
        """
        super().__init__()
        in_features = activation_features_size[0] * activation_features_size[1]
        # A deeper 128-64-32 MLP used to be assigned to ``self.embeddings`` and
        # was then immediately overwritten by this single layer, so only this
        # layer ever took effect. The dead construction has been removed; the
        # parameter names (``embeddings.0.weight`` / ``embeddings.0.bias``) are
        # unchanged.
        self.embeddings = nn.Sequential(nn.Linear(in_features, 32),
                                        nn.Tanh(),)

    def forward(self, x):
        """Embed a batch of concept maps.

        The input is flattened from dimension 1 onwards before the linear
        layer, so any trailing layout whose flattened width equals the
        configured height * width is accepted.

        Returns
        -------
        x : tensor
            The embedding after L2 normalisation (``F.normalize`` defaults).
        weights : tensor
            Squared L2 norm of the embedding *before* normalisation, with the
            reduced dimension kept (last dimension of size 1).
        """
        x = torch.flatten(x, 1)
        x = self.embeddings(x)
        # Computed before normalisation, otherwise it would always be ~1.
        weights = torch.sum(torch.square(x), dim=-1, keepdim=True)
        x = F.normalize(x)
        return x, weights
    
class Relevance(nn.Module):
    """Single linear layer that scores one class's concatenated prototypes.

    The layer's input width is ``embed_out_shape[0] * proto_per_class`` (only
    the first entry of ``embed_out_shape`` is read) and it emits one value per
    sample.
    """

    def __init__(self, proto_per_class = 10,embed_out_shape=(32,32)):
        super().__init__()
        embedding_width = embed_out_shape[0]
        self.rel = nn.Linear(in_features=embedding_width * proto_per_class, out_features=1)

    def forward(self,inputs):
        """Return the raw (un-squashed) relevance score for ``inputs``."""
        return self.rel(inputs)

In [4]:
class MACENetwork(nn.Module):
    """ MACE module

    Pipeline, as implemented below:

    1. ``feature_extract``: a 1x1 convolution + ReLU producing
       ``num_classes * proto_per_class`` concept maps from the input features.
    2. Each concept map is global-max-pooled to one value and passed through
       the embedding generator of its class (``protonet``), giving one
       prototype embedding per concept.
    3. ``concept_relevance``: per class, the concatenated prototypes are scored
       by a linear layer; scores are squashed with a sigmoid.
    4. ``final_dense``: prototypes are scaled by their class relevance and
       mapped to ``first_dense_dim`` features.
    """

    # Width of one prototype embedding. Has to equal the output width of
    # EmbeddingGenerator and the embedding width Relevance is built with.
    EMBEDDING_DIM = 32

    def __init__(self, activation_features_channels = 512 , num_classes = 20, proto_per_class = 10, first_dense_dim = 4096):
        """
        Parameters
        ----------
        activation_features_channels : int
            Channel count of the incoming feature tensor.
        num_classes : int
            Number of classes; one embedding generator and one relevance
            layer is created per class.
        proto_per_class : int
            Number of concepts (prototypes) per class.
        first_dense_dim : int
            Output width of the final dense layer.
        """
        super(MACENetwork , self).__init__()

        self.activation_features_channels = activation_features_channels
        self.num_classes = num_classes
        self.proto_per_class = proto_per_class
        self.num_prototypes = num_classes * proto_per_class

        # concept map generator module
        self.feature_extract = nn.Sequential(nn.Conv2d(in_channels=self.activation_features_channels,out_channels=self.num_prototypes,
                                                       kernel_size=(1,1)), nn.ReLU())

        # maps a concept map into a local prototype(embedding generator);
        # size (1, 1) because concept maps are max-pooled to a single value first
        self.protonet = nn.ModuleList([EmbeddingGenerator(activation_features_size=(1, 1)) for _ in range(self.num_classes)])

        # concept relevance module
        self.concept_relevance  = nn.ModuleList([Relevance(self.proto_per_class) for _ in range(self.num_classes)])

        # NOTE: despite the attribute name this is a sigmoid, so relevances are
        # independent per class and do not sum to one. The name is kept for
        # compatibility with code that accesses ``.softmax``.
        self.softmax = nn.Sigmoid()

        # Final Dense Layer (output module)
        self.final_dense = nn.Sequential(nn.Linear(self.EMBEDDING_DIM*self.num_prototypes , first_dense_dim), nn.ReLU())

    def get_relevance_part(self):
        """Return the ModuleList holding one relevance layer per class."""
        return self.concept_relevance

    def _embed_and_score(self, concept_maps):
        """Shared tail of ``forward`` and ``get_vis_local``.

        Parameters
        ----------
        concept_maps : tensor
            Output of ``feature_extract``, indexed as
            (batch, num_prototypes, height, width).

        Returns
        -------
        class_prototypes, class_protoweight, class_protoweighted : list of lists
            Indexed ``[class][concept]``; normalised embeddings, their
            pre-normalisation squared norms, and the embeddings again (a
            separate list holding the same tensors, since pruning is disabled).
        relevances : tensor
            (batch, num_classes) sigmoid relevances.
        concepts : list of tensors
            Per class, the prototypes concatenated along dimension 1.
        """
        # Global max pool: one activation per concept map.
        pooled = F.max_pool2d(concept_maps, kernel_size=(concept_maps.size()[2], concept_maps.size()[3]))
        # One (batch, 1, 1) tensor per prototype; the embedding generator
        # flattens its input, so no further reshaping is needed here.
        per_prototype = torch.unbind(pooled, dim = 1)

        class_prototypes = [[] for _ in range(self.num_classes)]
        class_protoweight = [[] for _ in range(self.num_classes)]
        class_protoweighted = [[] for _ in range(self.num_classes)]

        for i, concept_map in enumerate(per_prototype):
            # Consecutive blocks of proto_per_class maps belong to one class.
            ci = i//self.proto_per_class

            prototype, weight = self.protonet[ci](concept_map)
            class_prototypes[ci].append(prototype)
            class_protoweight[ci].append(weight)

            #if self.prune[ci][i % self.proto_per_class]:
            #    prototype = 0 * prototype
            class_protoweighted[ci].append(prototype)

        relevances, concepts = [], []
        for ci, cprototypes in enumerate(class_protoweighted):
            # (batch_size, embedding_size * prototypes per class)
            cps = torch.cat(cprototypes, 1)
            relevances.append(self.concept_relevance[ci](cps))
            concepts.append(cps)

        # (Batch_size, num_classes)
        relevances = self.softmax(torch.cat(relevances,1))
        return class_prototypes, class_protoweight, class_protoweighted, relevances, concepts

    def forward(self, inputs):
        """Run the full module.

        Returns
        -------
        tuple
            ``(class_prototypes, class_protoweight, relevances, dense_layer,
            class_protoweighted)`` where ``dense_layer`` is the output of
            ``final_dense`` applied to the relevance-weighted prototypes.
        """
        concept_maps = self.feature_extract(inputs)
        class_prototypes, class_protoweight, class_protoweighted, relevances, concepts = self._embed_and_score(concept_maps)

        # (Batch_size, embedding_size * num prototypes)
        concepts = torch.cat(concepts,1)

        # Repeat each class relevance across that class's slice of `concepts`.
        # `expand` follows the dtype/device of `relevances`, so this no longer
        # depends on the notebook-level `device` matching the module's device.
        per_class_width = concepts.shape[1] // self.num_classes
        expanded_relevances = torch.unsqueeze(relevances,-1).expand(-1, -1, per_class_width)
        expanded_relevances = torch.reshape(expanded_relevances, concepts.shape)

        # Weight the concepts
        weighted_concepts = expanded_relevances * concepts
        dense_layer = self.final_dense(weighted_concepts)

        return class_prototypes, class_protoweight, relevances, dense_layer, class_protoweighted

    def get_concept_maps(self, inputs):
        """Return the spatial concept maps for ``inputs``.

        Added because ``ApplyMACE.get_concept_maps`` delegates to a method of
        this name. The result is the same object ``get_vis_local`` returns as
        its first element: a tuple with one (batch, height, width) tensor per
        prototype. NOTE(review): the originally intended return value is not
        visible in this notebook - confirm this is what callers expect.
        """
        return torch.unbind(self.feature_extract(inputs), dim = 1)

    def get_vis_local(self,inputs):
        """Like ``forward`` but for visualisation: skips the final dense layer
        and additionally returns the un-pooled concept maps.

        Returns
        -------
        tuple
            ``(concept_maps, class_prototypes, class_protoweight, relevances,
            class_protoweighted)``; ``concept_maps`` is a tuple with one
            (batch, height, width) tensor per prototype.
        """
        concept_maps = self.feature_extract(inputs)
        concept_maps_to_return = torch.unbind(concept_maps, dim = 1)
        class_prototypes, class_protoweight, class_protoweighted, relevances, _ = self._embed_and_score(concept_maps)

        return concept_maps_to_return, class_prototypes, class_protoweight, relevances, class_protoweighted

In [5]:
class ApplyMACE(nn.Module):
    """Frozen base classifier with a trainable MACE interpretation layer.

    The base model is loaded from a checkpoint and frozen. Forward hooks on
    ``basemodel.features[30]`` and ``basemodel.classifier[1]`` capture the
    activations that feed, and are compared against, the MACE layer.
    The module indices assume a VGG-style ``features`` / ``classifier`` layout
    (the checkpoint name suggests VGG16 - TODO confirm for other checkpoints).
    """

    def __init__(self, activation_features_channels = 512,num_classes = 20, proto_per_class = 10, first_dense_dim = 4096,
                 basemodel_checkpoint = 'sgd_vgg16_finetune_checkpoint.pth'):
        """
        Parameters
        ----------
        activation_features_channels, num_classes, proto_per_class, first_dense_dim
            Forwarded unchanged to ``MACENetwork``.
        basemodel_checkpoint : str, optional
            Path handed to ``load_checkpoint`` to build the base model. The
            default is the path that used to be hard-coded here.
        """
        super().__init__()
        self.basemodel = load_checkpoint(basemodel_checkpoint)
        self.interpret_layer = MACENetwork(activation_features_channels,num_classes, proto_per_class, first_dense_dim)

        # Filled by the forward hooks below on every pass through basemodel.
        self.layer_outputs = {}
        def get_activation(name):
            def hook(module, input, output):
                self.layer_outputs[name] = output
            return hook
        self.basemodel.features[30].register_forward_hook(get_activation('convolution_output'))
        self.basemodel.classifier[1].register_forward_hook(get_activation('first_fully_connected_layer_output'))

        # Only the interpretation layer is trainable.
        for param in self.basemodel.parameters():
            param.requires_grad = False

        for params in self.interpret_layer.parameters():
            params.requires_grad = True

    def get_features(self, inputs):
        """Run the base model and return the hooked activations.

        Returns
        -------
        tuple
            ``(features, first_dense, final_output)``: the two hooked
            activations and the base model's output after softmax over dim 1.
        """
        final_output = self.basemodel(inputs)
        final_output = F.softmax(final_output,dim=1)
        features =  self.layer_outputs['convolution_output']
        first_dense = self.layer_outputs['first_fully_connected_layer_output']
        return features, first_dense, final_output

    #get_softmax only for vgg model
    def get_softmax(self, inputs):
        """Push ``inputs`` through ``basemodel.classifier[3]`` .. ``[6]`` and
        apply softmax over dim 1."""
        for i in range(3,7):
            inputs=self.basemodel.classifier[i](inputs)
        result = F.softmax(inputs,dim=1)
        #result = inputs
        return result

    def forward(self, inputs):
        """Return the MACE outputs together with the base model's outputs.

        Returns
        -------
        tuple
            ``(class_prototypes, class_protoweight, relevances,
            dense_layer_predict, first_dense, final_predict, final_output,
            class_protoweighted)``. ``final_predict`` is the prediction
            obtained by feeding the MACE dense output through the tail of the
            base classifier; ``final_output`` is the base model's own
            prediction.
        """
        features, first_dense, final_output = self.get_features(inputs)
        class_prototypes, class_protoweight, relevances, dense_layer_predict, class_protoweighted = self.interpret_layer(features)
        final_predict = self.get_softmax(dense_layer_predict)
        return class_prototypes, class_protoweight, relevances, dense_layer_predict, first_dense, final_predict, final_output, class_protoweighted

    def get_concept_maps(self, inputs):
        """Return the spatial concept maps for ``inputs``.

        This previously delegated to ``interpret_layer.get_concept_maps``,
        which MACENetwork does not define in this notebook, so every call
        raised AttributeError. It now returns the first element of
        ``get_vis_local`` (one map per prototype).
        NOTE(review): confirm this is the return value callers expect.
        """
        features, _, _ = self.get_features(inputs)
        return self.interpret_layer.get_vis_local(features)[0]

    def get_vis_local_maps(self,inputs):
        """Return ``interpret_layer.get_vis_local`` evaluated on the hooked
        convolutional features of ``inputs``."""
        features, _, _ = self.get_features(inputs)
        return self.interpret_layer.get_vis_local(features)

In [6]:
def load_checkpoint(filepath, map_location=None):
    """Rebuild a frozen, eval-mode model from a checkpoint file.

    The checkpoint must be a dict with a ``'model'`` entry (the pickled model
    object) and a ``'state_dict'`` entry that is loaded into that model.

    SECURITY: because the checkpoint contains a pickled model object, loading
    it can execute arbitrary code. Only load files you trust.

    Parameters
    ----------
    filepath : str or file-like
        Passed to ``torch.load``.
    map_location : optional
        Passed to ``torch.load``; e.g. ``'cpu'`` to open a checkpoint saved on
        a GPU on a machine without one. ``None`` keeps the previous behaviour.

    Returns
    -------
    The model, with every parameter's ``requires_grad`` set to False and the
    model switched to eval mode.
    """
    try:
        # Newer PyTorch releases default to weights_only=True, which refuses
        # to unpickle a whole model object; request full unpickling explicitly.
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the `weights_only` argument.
        checkpoint = torch.load(filepath, map_location=map_location)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model

In [8]:
# Build the model, move it to the selected device and put it in inference mode.
meta_model = ApplyMACE()
meta_model.to(device)
meta_model.eval()

# map_location lets a checkpoint saved on a GPU be restored on whatever
# `device` resolved to (including a CPU-only machine).
ch_pt = torch.load('./max_pool_92acc.pth', map_location=device)
meta_model.load_state_dict(ch_pt['model_state_dict'])

<All keys matched successfully>

In [9]:
# Force a garbage-collection pass after loading the checkpoint; the value the
# cell displays is gc.collect()'s return value (unreachable objects found).
gc.collect()

20

In [10]:
# Resize to 224x224, convert to a tensor and normalise with the same
# mean/std that visualize() later inverts.
data_transform = transforms.Compose([transforms.Resize((224,224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])

# NOTE(review): the variables are named "train" but the root is the 'val' split - confirm this is intended.
train_dataset = datasets.ImageFolder(root='./DATA/output/splitted_train_data/val', transform=data_transform)

# shuffle must be off: visualize() indexes the stacked concept maps with
# dataset indices (taken from train_dataset.targets), so the batches have to
# arrive in dataset order. With shuffle=True each mask was paired with an
# unrelated image.
train_dataset_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 64, shuffle=False, num_workers=2)

print('train_dataset_loader',len(train_dataset_loader))

train_dataset_loader 57


In [11]:
def get_maps(train_set_ims):
    """Run every batch through ``meta_model.get_vis_local_maps`` and stack the results.

    Uses the notebook-level ``meta_model`` and ``device``. Rows of the
    returned tensors follow the order in which ``train_set_ims`` yields
    samples, so it must iterate in dataset order if the results are later
    indexed by dataset index.

    Parameters
    ----------
    train_set_ims : iterable
        Yields ``(image_batch, label_batch)`` pairs; labels are ignored.

    Returns
    -------
    concept_map : tensor, (num_prototypes, num_samples, H, W)
    class_prototype : tensor, (num_classes, proto_per_class, num_samples, embedding)
    class_protoweight : tensor, (num_classes, proto_per_class, num_samples)
    relevance : tensor, (num_samples, num_classes)
    """
    # Pure inference: without no_grad every batch output would keep its
    # autograd graph alive (the interpretation layer's parameters require
    # grad), holding far more memory than the results themselves.
    with torch.no_grad():
        all_outputs = [meta_model.get_vis_local_maps(image.to(device)) for image, label in train_set_ims]
    print(len(all_outputs))
    print(len(all_outputs[0]))
    print(len(all_outputs[0][1]))
    print(len(all_outputs[0][1][0]))
    print((all_outputs[0][1][0][0].shape))
    for i in all_outputs[0][1][0]:
        print(i.shape)

    # Batches are concatenated along the sample axis in each case.
    concept_map = torch.cat([torch.stack(x[0]) for x in all_outputs], 1)
    print('concept_map',concept_map.size())

    class_prototype = torch.cat([torch.stack([torch.stack(i) for i in x[1]]) for x in all_outputs], 2)
    print('class_prototype',class_prototype.shape)

    class_protoweight = torch.cat([torch.stack([torch.stack(i) for i in x[2]]) for x in all_outputs], 2)
    print('class_protoweight',class_protoweight.shape)

    # Drop only the trailing size-1 axis; a bare squeeze() would also remove
    # any other axis that happens to have length 1 (e.g. a single sample).
    class_protoweight = torch.squeeze(class_protoweight, -1)
    print('class_protoweight',class_protoweight.shape)

    relevance = torch.cat([x[3] for x in all_outputs])
    print('relevance',relevance.shape)
    return concept_map, class_prototype, class_protoweight, relevance

# Run the whole loader through the model once and keep the stacked results as
# notebook-level variables; visualize() reads `concept_maps` from here.
concept_maps, class_prototypes, class_protoweights, relevances = get_maps(train_dataset_loader)

torch.Size([64, 3, 224, 224])
57
5
20
10
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
torch.Size([64, 32])
concept_map torch.Size([200, 3592, 7, 7])
class_prototype torch.Size([20, 10, 3592, 32])
class_protoweight torch.Size([20, 10, 3592, 1])
class_protoweight torch.Size([20, 10, 3592])
relevance torch.Size([3592, 20])


In [12]:
DIRECTORY = 'project'
# Folder that visualize() writes its PNG files into.
dir_v = 'visualizations_maxpool_after_clip/'
# exist_ok avoids the separate existence check (and the race between the
# check and the creation).
os.makedirs(dir_v, exist_ok=True)

In [None]:
def get_weights(layer_weights, prototype, i ,j):
    """Project prototypes onto concept ``j``'s slice of class ``i``'s relevance weights.

    Parameters
    ----------
    layer_weights : sequence
        Per-class relevance modules; ``layer_weights[i].rel.weight[0]`` is the
        weight vector covering all of that class's concepts back to back.
    prototype : tensor
        Embeddings whose last dimension is the embedding width.
    i : int
        Class index.
    j : int
        Concept index within the class.

    Returns
    -------
    tensor
        ``prototype @ w_j`` with size-1 dimensions squeezed away, where ``w_j``
        is the ``j``-th embedding-width chunk of the class's weight vector.
    """
    weights = layer_weights[i].rel.weight[0]
    # Chunk by the actual embedding width instead of a hard-coded 10 x 32
    # layout, so other prototype counts / embedding sizes work too.
    embedding_width = prototype.shape[-1]
    chunks = torch.split(weights, embedding_width, 0)
    return torch.matmul(prototype, chunks[j]).squeeze()

In [None]:
def visualize(num_classes=20, proto_per_class=10, topK=10,S=5): # train_dataset_loader
    """Save a masked image and the plain image for every concept of every class.

    For class ``i`` the first ``topK`` dataset images labelled ``i`` are used,
    in dataset order (no ranking by relevance is applied). For concept ``j``
    the image's concept map is turned into a mask that is 1 at its maximum
    cell(s) and 0.38 elsewhere, upsampled to 224x224 with nearest-neighbour
    interpolation and multiplied into the de-normalised image. Classes with
    fewer than 5 images are skipped.

    Reads the notebook-level ``concept_maps``, ``train_dataset``, ``device``
    and ``dir_v``. ``concept_maps`` must hold one row per dataset sample in
    dataset order (i.e. it must come from a non-shuffled pass over
    ``train_dataset``); otherwise masks are paired with the wrong images.

    Parameters
    ----------
    num_classes : int
        Number of classes to iterate over.
    proto_per_class : int
        Number of concepts per class; also the stride used to locate a
        class's concept maps.
    topK : int
        Maximum number of images written per concept.
    S : int
        Currently unused; kept so existing calls keep working.

    Output files (per class i, concept j, image k):
    ``class{i}_concept{j}_{k}_m_.png`` (masked) and
    ``class{i}_concept{j}_{k}_i_.png`` (plain), inside ``dir_v``.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    # Inverse of Normalize(mean, std): x * std + mean.
    denormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    # Built once instead of once per saved image.
    upsample = transforms.Resize((224, 224), interpolation=PIL.Image.NEAREST)
    min_images_per_class = 5

    all_targets = train_dataset.targets

    for i in range(num_classes):
        indices = [idx for idx, target in enumerate(all_targets) if target == i]
        if len(indices) < min_images_per_class:
            continue

        # Guard against classes with fewer than topK images.
        shown = indices[:topK]

        # Load and de-normalise each image once per class. Previously every
        # image of the class was loaded twice for every single concept.
        images = []
        for idx in shown:
            image = denormalize(train_dataset[idx][0].to(device))
            images.append(torch.clamp(image, 0, 1))

        for j in range(proto_per_class):
            # concept_maps is laid out class-major: (num_prototypes, samples, H, W)
            c_m = concept_maps[i*proto_per_class + j]

            for k, idx in enumerate(shown):
                image = images[k]
                # 1 where the map attains its maximum, 0.38 everywhere else,
                # so the rest of the image stays dimly visible.
                cmi = c_m[idx]
                cmi = cmi.eq(torch.max(cmi)).float()
                cmi = torch.clamp(cmi,0.38,1)

                cmi = torch.unsqueeze(cmi, 0)
                map_img = upsample(cmi)*image

                stem = 'class' + str(i) + '_concept' + str(j) + '_' + str(k)
                map_name = os.path.join(dir_v, stem + '_m_' + '.png')
                im_name  = os.path.join(dir_v, stem + '_i_' + '.png')

                torchvision.utils.save_image(map_img, map_name)
                torchvision.utils.save_image(image, im_name)

# Write the masked/plain PNG pairs for all 20 classes with topK=10
# (proto_per_class and S keep their defaults).
visualize(num_classes=20, topK=10)