In [1]:
import torch
import torchvision
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms, datasets
import time
import os
import copy
import gc
import torch.optim as optim
from torch.optim import lr_scheduler

In [2]:
# Compute device shared by the rest of the notebook: the first CUDA GPU when
# PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
class EmbeddingGenerator(nn.Module):
    """Turns one concept map into a unit-length embedding plus its pre-normalisation energy."""

    def __init__(self, activation_features_channels=200, activation_features_size=(14, 14), embedding_dim=32):
        """
        Parameters
        ----------
        activation_features_channels : int, optional
            Accepted for interface compatibility only; this module does not use it.
            The default is 200.
        activation_features_size : tuple of int, optional
            (height, width) of a single concept map; their product is the number of
            input features of the linear layer. The default is (14, 14).
        embedding_dim : int, optional
            Width of the produced embedding. The default is 32.

        Returns
        -------
        None.
        """
        super().__init__()
        in_features = activation_features_size[0] * activation_features_size[1]
        self.embeddings = nn.Sequential(nn.Linear(in_features, embedding_dim),
                                        nn.Tanh(),)

    def forward(self, x):
        """Flatten a concept map, embed it, and L2-normalise the embedding.

        Parameters
        ----------
        x : torch.Tensor
            Batch of concept maps; everything after the first dimension is flattened
            and must contain height * width values.

        Returns
        -------
        x : torch.Tensor
            Embeddings of shape (batch, embedding_dim), L2-normalised along dim 1.
        weights : torch.Tensor
            Squared L2 norm of each embedding *before* normalisation, shape (batch, 1).
        """
        x = torch.flatten(x, 1)
        x = self.embeddings(x)
        # Taken before normalisation, otherwise every entry would be 1.
        weights = torch.sum(torch.square(x), dim=-1, keepdim=True)
        x = F.normalize(x, dim=1)
        return x, weights
    
class Relevance(nn.Module):
    """Single linear unit that scores the concatenated embeddings of one class."""

    def __init__(self, proto_per_class = 10,embed_out_shape=(32,32)):
        """
        Parameters
        ----------
        proto_per_class : int, optional
            Number of embeddings concatenated per class. The default is 10.
        embed_out_shape : tuple of int, optional
            Only the first entry is read; it is used as the width of one embedding.
            The default is (32, 32).
        """
        super().__init__()
        embedding_width = embed_out_shape[0]
        self.rel = nn.Linear(in_features=embedding_width * proto_per_class, out_features=1)

    def forward(self,inputs):
        """Return the raw (pre-activation) relevance, one value per row of ``inputs``."""
        return self.rel(inputs)

In [24]:
class MACENetwork(nn.Module):
    """MACE interpretation head driven by prototype ("this looks like that") similarities.

    For every prototype the squared L2 distance to each spatial position of the
    incoming feature map is converted into a similarity map. The spatially pooled
    similarity of each prototype is embedded by the embedding generator of its
    class, each class gets a sigmoid relevance from its concatenated embeddings,
    and the relevance-weighted embeddings are fed to a dense layer.
    """

    # Width of one per-prototype embedding. It has to agree with the output width of
    # EmbeddingGenerator and with the first entry of Relevance's embed_out_shape default.
    EMBEDDING_DIM = 32

    def __init__(self, activation_features_channels = 512 , num_classes = 20, proto_per_class = 10, first_dense_dim = 4096):
        """
        Parameters
        ----------
        activation_features_channels : int, optional
            Channel count of the feature map passed to ``forward``. The default is 512.
        num_classes : int, optional
            Number of classes. The default is 20.
        proto_per_class : int, optional
            Number of prototypes owned by each class. The default is 10.
        first_dense_dim : int, optional
            Output width of the final dense layer. The default is 4096.
        """
        super(MACENetwork , self).__init__()

        self.activation_features_channels = activation_features_channels
        self.num_classes = num_classes
        self.proto_per_class = proto_per_class
        self.num_prototypes = num_classes * proto_per_class

        # 1x1-convolution concept map generator. No method of this class calls it;
        # it is kept so that the parameter set (and saved state dicts) stay unchanged.
        self.feature_extract = nn.Sequential(nn.Conv2d(in_channels=self.activation_features_channels,out_channels=self.num_prototypes,
                                                       kernel_size=(1,1)), nn.ReLU())

        # One embedding generator per class. Its input is a pooled (1 x 1) activation.
        self.protonet = nn.ModuleList([EmbeddingGenerator(activation_features_size=(1, 1)) for _ in range(self.num_classes)])

        # One relevance scorer per class.
        self.concept_relevance  = nn.ModuleList([Relevance(self.proto_per_class) for _ in range(self.num_classes)])

        # Despite the attribute name this is an element-wise sigmoid, not a softmax.
        self.softmax = nn.Sigmoid()

        # Final dense layer (output module)
        self.final_dense = nn.Sequential(nn.Linear(self.EMBEDDING_DIM*self.num_prototypes , first_dense_dim), nn.ReLU())

        #------------------------This looks like that----------------------------
        # Learnable prototypes, initialised uniformly in [0, 1).
        self.prototype_vectors = nn.Parameter(torch.rand(self.num_prototypes, self.activation_features_channels, 1, 1),
                                              requires_grad=True)

        # do not make this just a tensor,
        # since it will not be moved automatically to gpu
        # It is a constant all-ones kernel and is meant to stay frozen.
        self.ones = nn.Parameter(torch.ones(self.num_prototypes, self.activation_features_channels, 1, 1),
                                 requires_grad=False)

    def _l2_convolution(self, x):
        '''
        apply self.prototype_vectors as l2-convolution filters on input x

        Uses ||x - p||^2 = ||x||^2 - 2 <x, p> + ||p||^2 at every spatial position.
        Returns a tensor of shape (batch, num_prototypes, H, W) with non-negative entries.
        '''
        # ||x||^2 per position: convolving x^2 with an all-ones 1x1 kernel sums over channels.
        x2 = x ** 2
        x2_patch_sum = F.conv2d(input=x2, weight=self.ones)

        # ||p||^2 per prototype, shape (num_prototypes,), reshaped to
        # (num_prototypes, 1, 1) so that it broadcasts over batch and space.
        p2 = self.prototype_vectors ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        p2_reshape = p2.view(-1, 1, 1)

        # <x, p> per position and prototype.
        xp = F.conv2d(input=x, weight=self.prototype_vectors)
        intermediate_result = - 2 * xp + p2_reshape  # use broadcast
        # x2_patch_sum and intermediate_result are of the same shape.
        # relu clips the small negative values that rounding can produce.
        distances = F.relu(x2_patch_sum + intermediate_result)

        return distances

    def distance_2_similarity(self, distances, prototype_activation_function='log', epsilon=1e-4):
        """Convert distances to similarities.

        'log' returns log((d + 1) / (d + epsilon)); 'linear' returns -d.

        Raises
        ------
        ValueError
            If ``prototype_activation_function`` is neither 'log' nor 'linear'.
        """
        if prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + epsilon))
        if prototype_activation_function == 'linear':
            return -distances
        raise ValueError(
            "prototype_activation_function must be 'log' or 'linear', got {!r}".format(prototype_activation_function))

    def similarity_of_concepts(self, conv_features):
        """Return ``(prototype_activations, concept_maps)`` for a feature map.

        ``concept_maps`` has shape (batch, num_prototypes, H, W);
        ``prototype_activations`` is its spatial maximum, shape (batch, num_prototypes, 1, 1).
        """
        distances = self._l2_convolution(conv_features)
        concept_maps = self.distance_2_similarity(distances)
        # Global max pooling of the similarity map (largest similarity over all positions).
        prototype_activations = F.max_pool2d(concept_maps,
                                      kernel_size=(concept_maps.size()[2],
                                                   concept_maps.size()[3]))

        return prototype_activations, concept_maps
        #-------------------------------End--------------------------------------

    def get_relevance_part(self):
        """Return the ModuleList of per-class relevance scorers."""
        return self.concept_relevance

    def _embed_and_score(self, prototype_activations):
        """Embed every pooled prototype activation and score each class.

        Shared by ``forward`` and ``get_vis_local``. Returns
        ``(class_prototypes, class_protoweight, class_protoweighted, relevances, concepts)``
        where the first three are lists (one entry per class) of lists (one entry per
        prototype), ``relevances`` is the sigmoid relevance of shape (batch, num_classes)
        and ``concepts`` is a list with one concatenated embedding tensor per class.
        """
        per_prototype = torch.unbind(prototype_activations, dim = 1)

        class_prototypes = [[] for _ in range(self.num_classes)]
        class_protoweight = [[] for _ in range(self.num_classes)]
        class_protoweighted = [[] for _ in range(self.num_classes)]

        for i, activation in enumerate(per_prototype):
            # Prototypes are laid out class after class.
            ci = i//self.proto_per_class

            prototype, weight = self.protonet[ci](activation)
            class_prototypes[ci].append(prototype)
            class_protoweight[ci].append(weight)
            # No pruning is applied, so this list mirrors class_prototypes.
            class_protoweighted[ci].append(prototype)

        relevances, concepts = [], []
        for i, cprototypes in enumerate(class_protoweighted):
            # (batch_size, embedding_size * prototypes per class)
            cps = torch.cat(cprototypes, 1)
            relevances.append(self.concept_relevance[i](cps))
            concepts.append(cps)

        # (batch_size, num_classes)
        relevances = self.softmax(torch.cat(relevances,1))
        return class_prototypes, class_protoweight, class_protoweighted, relevances, concepts

    def forward(self, inputs):
        """Run the interpretation head on a feature map.

        Returns
        -------
        tuple
            ``(class_prototypes, class_protoweight, relevances, dense_layer, class_protoweighted)``;
            ``dense_layer`` is the output of the final dense layer applied to the
            relevance-weighted embeddings.
        """
        prototype_activations, _ = self.similarity_of_concepts(inputs)
        (class_prototypes, class_protoweight, class_protoweighted,
         relevances, concepts) = self._embed_and_score(prototype_activations)

        # (batch_size, embedding_size * num_prototypes)
        concepts = torch.cat(concepts,1)

        # Repeat each class relevance across that class's slice of `concepts`.
        # `expand` lives on the same device as `relevances`, so no global device is needed.
        per_class_width = concepts.shape[1] // self.num_classes
        expanded_relevances = torch.unsqueeze(relevances,-1).expand(-1, -1, per_class_width)
        expanded_relevances = torch.reshape(expanded_relevances, concepts.shape)

        # Weight the concepts
        weighted_concepts = expanded_relevances * concepts
        dense_layer = self.final_dense(weighted_concepts)

        return class_prototypes, class_protoweight, relevances, dense_layer, class_protoweighted


    def get_vis_local(self,inputs):
        """Like ``forward`` but for visualisation: returns the spatial concept maps and skips the dense layer.

        Returns
        -------
        tuple
            ``(concept_maps_to_return, class_prototypes, class_protoweight, relevances, class_protoweighted)``;
            ``concept_maps_to_return`` is a tuple with one (batch, H, W) similarity map per prototype.
        """
        prototype_activations, concept_maps = self.similarity_of_concepts(inputs)
        concept_maps_to_return = torch.unbind(concept_maps, dim = 1)

        (class_prototypes, class_protoweight, class_protoweighted,
         relevances, _) = self._embed_and_score(prototype_activations)

        return concept_maps_to_return, class_prototypes, class_protoweight, relevances, class_protoweighted

In [25]:
class ApplyMACE(nn.Module):
    """Frozen base classifier with a trainable MACE interpretation head attached.

    Forward hooks capture two intermediate activations of the base model: the output
    of ``features[29]`` (fed to the MACE head) and the output of ``classifier[1]``
    (the target the head's dense layer is compared against by the training code).
    """

    def __init__(self, activation_features_channels = 512,num_classes = 20, proto_per_class = 10, first_dense_dim = 4096):
        """
        Parameters are forwarded unchanged to ``MACENetwork``. The base model is read
        from 'sgd_vgg16_finetune_checkpoint.pth' through ``load_checkpoint``.
        """
        super().__init__()
        self.basemodel = load_checkpoint('sgd_vgg16_finetune_checkpoint.pth')
        self.interpret_layer = MACENetwork(activation_features_channels,num_classes, proto_per_class, first_dense_dim)

        # Filled by the forward hooks below on every pass through the base model.
        self.layer_outputs = {}
        def get_activation(name):
            def hook(module, input, output):
                self.layer_outputs[name] = output
            return hook
        self.basemodel.features[29].register_forward_hook(get_activation('convolution_output'))
        self.basemodel.classifier[1].register_forward_hook(get_activation('first_fully_connected_layer_output'))

        # The base model is only explained, never trained.
        for param in self.basemodel.parameters():
            param.requires_grad = False

        # The interpretation head keeps the requires_grad flags MACENetwork assigned.
        # Forcing every flag to True here would also un-freeze MACENetwork.ones, the
        # constant all-ones kernel of the squared-norm convolution, and the optimizer
        # would then train it.
        # NOTE(review): nn.Module.train() on this module also switches basemodel back
        # to training mode; keep basemodel in eval mode if train() is ever called.

    def get_features(self, inputs):
        """Run the base model and return ``(conv features, first dense activation, softmax output)``."""
        final_output = self.basemodel(inputs)
        final_output = F.softmax(final_output,dim=1)
        features =  self.layer_outputs['convolution_output']
        first_dense = self.layer_outputs['first_fully_connected_layer_output']
        return features, first_dense, final_output

    #get_softmax only for vgg model
    def get_softmax(self, inputs):
        """Apply ``basemodel.classifier[3]`` through ``[6]`` to ``inputs`` and softmax the result over dim 1."""
        for i in range(3,7):
            inputs=self.basemodel.classifier[i](inputs)
        result = F.softmax(inputs,dim=1)
        return result

    def forward(self, inputs):
        """
        Returns
        -------
        tuple
            ``(class_prototypes, class_protoweight, relevances, dense_layer_predict,
            first_dense, final_predict, final_output, class_protoweighted)`` where
            ``final_output`` is the base model's softmax output and ``final_predict``
            is the softmax obtained by passing the head's dense output through the
            remaining base classifier layers.
        """
        features, first_dense, final_output = self.get_features(inputs)
        class_prototypes, class_protoweight, relevances, dense_layer_predict, class_protoweighted = self.interpret_layer(features)
        final_predict = self.get_softmax(dense_layer_predict)
        return class_prototypes, class_protoweight, relevances, dense_layer_predict, first_dense, final_predict, final_output, class_protoweighted

    def get_concept_maps(self, inputs):
        """Return the head's concept (similarity) maps for ``inputs``.

        The maps are the second value returned by
        ``interpret_layer.similarity_of_concepts`` (the head has no
        ``get_concept_maps`` method of its own).
        """
        features, _, _ = self.get_features(inputs)
        _, concept_maps = self.interpret_layer.similarity_of_concepts(features)
        return concept_maps

    def get_vis_local_maps(self,inputs):
        """Return ``interpret_layer.get_vis_local`` evaluated on the base model's conv features."""
        features, _, _ = self.get_features(inputs)
        return self.interpret_layer.get_vis_local(features)

In [26]:
def load_checkpoint(filepath, map_location=None):
    """Load a frozen, eval-mode model from a checkpoint file.

    The checkpoint must be a dict holding the pickled model object under 'model'
    and its weights under 'state_dict'.

    Parameters
    ----------
    filepath : str
        Path of the checkpoint file.
    map_location : optional
        Forwarded to ``torch.load``. When left as None, tensors are restored to the
        device they were saved from if CUDA is available and remapped to the CPU
        otherwise, so a GPU-saved checkpoint can still be opened on a CPU-only machine.

    Returns
    -------
    torch.nn.Module
        The model with every parameter's ``requires_grad`` set to False, in eval mode.
    """
    import inspect  # only needed here; kept local so the cell is self-contained

    if map_location is None and not torch.cuda.is_available():
        map_location = torch.device('cpu')

    load_kwargs = {'map_location': map_location}
    # The checkpoint stores a whole pickled nn.Module, which weights-only loading
    # rejects. The flag is passed only where torch.load knows it.
    # SECURITY: full unpickling can execute arbitrary code - load trusted files only.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    checkpoint = torch.load(filepath, **load_kwargs)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model

In [27]:
def get_anchors(batch_size, num_concepts):
    """Pick one random anchor row per concept.

    Rows are assumed to be laid out concept after concept, ``batch_size`` rows each,
    so concept ``k`` owns rows ``[k * batch_size, (k + 1) * batch_size)``.

    Returns
    -------
    torch.Tensor
        int64 tensor of shape (num_concepts,); entry ``k`` lies inside concept ``k``'s rows.
    """
    # Random position inside a concept's block of rows.
    within_batch = torch.randint(low=0, high=batch_size, size=(num_concepts,), dtype=torch.int64)
    # Zero-based block offsets. The deprecated, end-inclusive torch.range(1, n) started
    # at 1, which shifted every anchor by one block and pushed the last one out of range.
    concept_offsets = batch_size * torch.arange(num_concepts, dtype=torch.int64)
    return within_batch + concept_offsets

def pairwise_dist_map(anchors, embeddings):
    """Squared Euclidean distance from every anchor to every embedding.

    ``anchors`` is iterated along its first dimension; row ``i`` of the result holds
    the distances from anchor ``i`` to each row of ``embeddings``.
    """
    rows = []
    for anchor in anchors.unbind(0):
        difference = anchor - embeddings
        rows.append(difference.square().sum(dim=1))
    return torch.stack(rows)
    
def sample_negative(avg_positive, masked_negative):
    """For each row, return the smallest negative distance that exceeds the positive distance.

    Parameters
    ----------
    avg_positive : torch.Tensor
        One positive distance per row, shape (rows,).
    masked_negative : torch.Tensor
        Candidate negative distances, shape (rows, candidates).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (rows,). A row with no candidate larger than its
        positive distance (or with a non-finite minimum) yields 0.
    """
    def get_negative(panchor, nanchors):
        candidates = torch.masked_select(nanchors, torch.gt(nanchors, panchor))
        # Created from `nanchors` so that it shares device and dtype with the minimum
        # and can be stacked with it.
        fallback = nanchors.new_zeros(())
        # torch.min raises on an empty tensor, so the empty case is handled explicitly.
        if candidates.numel() == 0:
            return fallback
        semi_hard_negative = torch.min(candidates)
        return semi_hard_negative if torch.isfinite(semi_hard_negative) else fallback

    pairs = zip(torch.unbind(avg_positive), torch.unbind(masked_negative))
    yy = torch.stack([get_negative(panchor, nanchors) for panchor, nanchors in pairs])
    return yy.type(torch.float32)

In [28]:
def custom_triplet_loss(local_prototypes, margin = 0.2):
    """Memory efficient triplet loss over concepts.

    One random anchor is drawn per concept. Positives are the embeddings of the same
    concept across the batch, negatives the embeddings of every other concept.

    Parameters
    ----------
    local_prototypes : torch.Tensor
        Shape (batch_size, embedding_size, num_concepts).
    margin : float, optional
        Triplet margin. The default is 0.2.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    batch_size, embedding_size, num_concepts = local_prototypes.shape
    dev = local_prototypes.device

    # Rows are concept-major: first all batch entries of concept 0, then concept 1, ...
    # SHAPE = (BATCH_SIZE * NUM_CONCEPTS, EMBEDDING_SIZE)
    transposed_prototype = local_prototypes.permute(2,0,1)
    embeddings = torch.reshape(transposed_prototype, (batch_size*num_concepts, embedding_size))

    # MASKS = (NUM_CONCEPTS, NUM_CONCEPTS*BATCH_SIZE); entry [k, j] is True when
    # embedding row j belongs to concept k.
    concept_ids = torch.arange(num_concepts, device=dev)
    row_concept = concept_ids.repeat_interleave(batch_size)
    positive_mask = torch.eq(concept_ids.unsqueeze(-1), row_concept.unsqueeze(0))
    negative_mask = torch.logical_not(positive_mask)

    # DIMENSION OF ANCHORS = (NUM_CONCEPTS, EMBEDDING_SIZE)
    # Only the random within-batch position is taken from get_anchors; the per-concept
    # row offset is rebuilt here so that the indices are valid integer row numbers
    # whatever offset or dtype get_anchors applied.
    within_batch = get_anchors(batch_size, num_concepts).to(dev).long() % batch_size
    anchor_indices = within_batch + batch_size * concept_ids
    anchors = embeddings[anchor_indices]

    # PAIRWISE DISTANCE = (NUM_CONCEPTS, NUM_CONCEPTS*BATCH_SIZE)
    pairwise_distance = pairwise_dist_map(anchors, embeddings)

    # POSITIVE DISTANCES PER CONCEPT AND THEIR MEAN
    masked_positive = torch.masked_select(pairwise_distance, positive_mask)
    masked_positive = torch.reshape(masked_positive, (num_concepts, batch_size))
    avg_positive = torch.mean(masked_positive, dim = 1)

    # SEMI HARD NEGATIVE
    masked_negative = torch.masked_select(pairwise_distance, negative_mask)
    masked_negative = torch.reshape(masked_negative, (num_concepts, batch_size*(num_concepts - 1)))
    hardest_negative = torch.unsqueeze(sample_negative(avg_positive, masked_negative), -1)

    # Hinge. clamp accepts a Python scalar bound, which torch.maximum does not.
    loss = torch.mean(torch.clamp(masked_positive - hardest_negative + margin, min=0))

    return loss

In [29]:
def new_triplet_loss(prototypes, margin = 0.2):
    """Memory efficient triplet loss.

    Parameters
    ----------
    prototypes : torch.Tensor
        Shape (batch_size, number_of_prototypes, embedding_size).
    margin : float, optional
        Triplet margin. The default is 0.2.

    Returns
    -------
    torch.Tensor
        Shape (batch_size,): the hinge value averaged over prototypes for each sample.
    """
    # FIND HARDEST POSITIVE ACROSS BATCH
    # Same prototype, different samples: distances of shape (prototypes, batch, batch).
    p_prototypes = prototypes.permute(1,0,2)
    p_prototypes = torch.unsqueeze(p_prototypes, -1)
    p_t = p_prototypes.permute(0, 3, 2, 1)

    pairwise = torch.sum(torch.square(p_prototypes - p_t), dim = 2)
    # Diagonals are 0. For each row take the one with the max.
    hardest_p,_ = torch.max(pairwise, dim = -1)               # (number_of_prototypes, batch_size)
    hardest_p = hardest_p.permute(1,0)                       # (batch_size, number_of_prototypes)

    # FIND HARDEST NEGATIVE ACROSS CONCEPTS
    # Same sample, different prototypes: distances of shape (batch, prototypes, prototypes).
    n_prototypes = torch.unsqueeze(prototypes, -1)
    n_t = n_prototypes.permute(0, 3, 2, 1)

    pairwise = torch.sum(torch.square(n_prototypes - n_t), dim = 2)
    # Semi-hard: only negatives farther away than the hardest positive are eligible.
    semi_masked = torch.gt(pairwise, torch.unsqueeze(hardest_p, -1))
    not_semi_masked =  torch.logical_not(semi_masked)
    # Ineligible entries (including the zero diagonal) are replaced by 4 + 1e-6 so that
    # the row minimum ignores them; 4 is the largest squared distance between two
    # unit vectors - assumes the embeddings are L2-normalised.
    not_semi_masked = not_semi_masked.type(torch.float32)
    eye = (4 + 1e-6)*(not_semi_masked)
    semi_masked = semi_masked.type(torch.float32)
    pairwise = eye + pairwise*semi_masked

    hardest_n , _ = torch.min(pairwise, dim = -1)       # (batch_size, number_of_prototypes)

    # Element-wise hinge. The zero tensor is built from the operand itself, so it always
    # matches its device and dtype without relying on a global device variable.
    violation = hardest_p - hardest_n + margin
    x = torch.max(violation, torch.zeros_like(violation))   # (batch_size, number_of_prototypes)

    loss = torch.mean(x, dim = -1)
    return loss

In [30]:
# Training pipeline: random crop resized to 224, random horizontal flip, tensor
# conversion, then per-channel normalisation with the mean/std given below.
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One sub-folder per class below the root directory.
train_dataset = datasets.ImageFolder('./DATA/output/train', transform=data_transform)
train_dataset_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
# Number of batches per epoch.
print('train_dataset_loader', len(train_dataset_loader))

train_dataset_loader 374


In [31]:
# Evaluation pipeline: a fixed resize to 224 x 224 with no random augmentation,
# followed by the same tensor conversion and normalisation as for training.
data_transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = datasets.ImageFolder('./DATA/output/test', transform=data_transform_test)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=32)
# Number of batches in one pass over the test set.
print('test_dataset_loader', len(test_dataset_loader))

test_dataset_loader 48


In [32]:
#def sparse_categorical_accuracy(y_true, y_pred):
#    return K.cast(K.equal(K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx())), K.floatx())

def sparse_categorical_accuracy(y_true, y_pred):
    """Count the rows of ``y_pred`` whose arg-max over the last dim equals ``y_true``.

    Returns the number of matches as a Python float (a count, not a ratio);
    callers divide by the number of samples themselves.
    """
    predicted_class = y_pred.argmax(dim=-1).float()
    matches = (y_true == predicted_class).float()
    return matches.sum().item()

In [60]:
# Build the frozen base model together with the MACE head (all default sizes).
meta_model = ApplyMACE()
# Move every parameter to `device`. This is the cell's last expression, so the
# notebook prints the module tree below it.
meta_model.to(device)

ApplyMACE(
  (basemodel): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      

In [61]:
# Force a garbage-collection pass; the displayed value is the number returned by gc.collect().
gc.collect()

572

In [62]:
# Mean squared error, used by train_step between the base model's first dense
# activation and the MACE head's dense output.
mse = torch.nn.MSELoss()
# NOTE: despite its name this is binary cross-entropy (BCELoss), not a KL divergence.
# Both of its arguments must hold values in [0, 1].
kl = torch.nn.BCELoss()

In [63]:
# Adam over every parameter of meta_model. The base model's parameters are passed too,
# but ApplyMACE sets their requires_grad to False, so they receive no gradient.
optimizer = torch.optim.Adam(meta_model.parameters(),lr=1e-3)
# Earlier alternative (SGD with a step learning-rate schedule), kept for reference:
#optimizer = optim.SGD(meta_model.parameters(), lr=1.0, momentum=0.9)
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [64]:
# Running counts of matching predictions; train_step / test_step add to them and
# the epoch loop resets them.
train_acc = test_acc = rev_acc = 0
# Running number of samples seen in the current epoch.
train_total = test_total = 0

# Running sums of the individual training losses and the number of batches they cover.
train_mse_dense = train_kl = train_revelance = train_loss_triplet = 0
num_batch = 0

In [65]:
def train_step(images, labels):
    """Run one optimisation step on a batch and update the global running metrics.

    ``labels`` is only used for its batch size: the training targets are the base
    model's own outputs, and the accumulated "accuracy" counts how often the MACE
    prediction agrees with the base model's arg-max class.
    """
    global train_acc, train_total, train_mse_dense, train_kl, train_revelance, train_loss_triplet, num_batch

    (class_prototypes, class_protoweight, relevances, dense_layer_predict,
     first_dense, final_predict, final_output, class_protoweighted) = meta_model(images)
    optimizer.zero_grad()

    # MSE between the base model's first dense activation and the head's dense output
    mse_loss = mse(first_dense, dense_layer_predict)

    # `kl` (a BCELoss object) against the base model's output distribution,
    # once for the final prediction and once for the relevances
    kl_loss = kl(final_predict, final_output)
    revelance_loss = kl(relevances, final_output)

    # One triplet term per class, each scaled by 0.1 and computed with margin 1.
    # Stacking gives (batch_size, embedding_size, num_prototypes); the permute puts
    # prototypes before the embedding dimension.
    per_class_terms = [
        0.1 * new_triplet_loss(torch.stack(cp, dim = 2).permute(0,2,1), 1)
        for cp in class_prototypes
    ]
    triplet_loss = torch.mean(sum(per_class_terms), dim=0)

    # TOTAL LOSS
    loss = mse_loss + kl_loss + revelance_loss + triplet_loss

    loss.backward()
    optimizer.step()

    # Running sums for the per-epoch report
    num_batch += 1
    train_mse_dense += mse_loss.item()
    train_kl += kl_loss.item()
    train_revelance += revelance_loss.item()
    train_loss_triplet += triplet_loss.item()

    base_prediction = torch.argmax(final_output,dim=1)
    train_acc += sparse_categorical_accuracy(base_prediction, final_predict)
    train_total += labels.size(0)
    
def test_step(images, labels):
    """Evaluate one batch and update the global running test metrics.

    ``test_acc`` counts agreement between the MACE prediction and the base model's
    arg-max class, ``rev_acc`` does the same for the arg-max relevance. ``labels`` is
    only used for its batch size.
    """
    global test_acc, rev_acc, test_total

    # Nothing is back-propagated here, so autograd is switched off to avoid building
    # (and holding in memory) a graph for every evaluation batch.
    with torch.no_grad():
        class_prototypes, class_protoweight, relevances, dense_layer_predict, first_dense, final_predict, final_output, class_protoweighted = meta_model(images)
        base_prediction = torch.argmax(final_output,dim=1)
        test_acc += sparse_categorical_accuracy(base_prediction, final_predict)
        rev_acc += sparse_categorical_accuracy(base_prediction, relevances)
    test_total += labels.size(0)

In [67]:
# Per-epoch history of the averaged training losses.
train_kl_loss = []
train_mse_dense_loss = []
train_relevance_loss = []
train_loss_triplet_loss = []

# Snapshot of the weights belonging to the best training agreement seen so far.
best_model_wts = copy.deepcopy(meta_model.state_dict())
best_acc = 0.0

EPOCHS = 30
for epoch in range(EPOCHS):
    # One pass over the training data; train_step updates the global running metrics.
    for image, label in train_dataset_loader:
        image = image.to(device)
        label = label.to(device)
        train_step(image, label)

    # One pass over the test data.
    for image, label in test_dataset_loader:
        image = image.to(device)
        label = label.to(device)
        test_step(image, label)

    # Averages for this epoch.
    epoch_mse = train_mse_dense/num_batch
    epoch_kl = train_kl/num_batch
    epoch_triplet = train_loss_triplet/num_batch
    epoch_relevance = train_revelance/num_batch
    epoch_train_acc = train_acc/train_total

    train_kl_loss.append(epoch_kl)
    train_mse_dense_loss.append(epoch_mse)
    train_relevance_loss.append(epoch_relevance)
    train_loss_triplet_loss.append(epoch_triplet)

    template = 'Eh {}, MSE Loss:{:.3f}, KL Loss:{:.3f}, Triplet Loss:{:.3f}, Relevance Loss:{:.3f}, Train Acc:{:.3f}, Test Acc:{:.3f}, Rev Acc:{:.3f}'
    print(template.format(epoch+1,
                          epoch_mse,
                          epoch_kl,
                          epoch_triplet,
                          epoch_relevance,
                          epoch_train_acc,
                          test_acc/test_total,
                          rev_acc/test_total))

    # Model selection uses the training agreement, not the test agreement.
    if epoch_train_acc > best_acc:
        best_acc = epoch_train_acc
        best_model_wts = copy.deepcopy(meta_model.state_dict())

    # Reset the metrics for the next epoch
    train_acc = test_acc = rev_acc = 0
    train_total = test_total = 0

    train_mse_dense = train_kl = train_revelance = train_loss_triplet = 0
    num_batch = 0

Eh 1, MSE Loss:0.210, KL Loss:0.058, Triplet Loss:0.295, Relevance Loss:0.164, Train Acc:0.801, Test Acc:0.846, Rev Acc:0.344
Eh 2, MSE Loss:0.208, KL Loss:0.056, Triplet Loss:0.293, Relevance Loss:0.161, Train Acc:0.803, Test Acc:0.831, Rev Acc:0.384
Eh 3, MSE Loss:0.208, KL Loss:0.055, Triplet Loss:0.292, Relevance Loss:0.158, Train Acc:0.815, Test Acc:0.863, Rev Acc:0.415
Eh 4, MSE Loss:0.207, KL Loss:0.053, Triplet Loss:0.293, Relevance Loss:0.155, Train Acc:0.819, Test Acc:0.853, Rev Acc:0.426
Eh 5, MSE Loss:0.207, KL Loss:0.050, Triplet Loss:0.291, Relevance Loss:0.152, Train Acc:0.834, Test Acc:0.860, Rev Acc:0.452
Eh 6, MSE Loss:0.203, KL Loss:0.050, Triplet Loss:0.288, Relevance Loss:0.150, Train Acc:0.831, Test Acc:0.869, Rev Acc:0.443
Eh 7, MSE Loss:0.206, KL Loss:0.049, Triplet Loss:0.287, Relevance Loss:0.148, Train Acc:0.837, Test Acc:0.865, Rev Acc:0.467
Eh 8, MSE Loss:0.203, KL Loss:0.048, Triplet Loss:0.285, Relevance Loss:0.147, Train Acc:0.842, Test Acc:0.862, Rev Ac

In [None]:
#lr=5e-3
# Best training agreement reached during the loop above.
best = best_acc
print(best)
PATH = './this_looks_88acc.pth'
# Saves the weight snapshot taken at the best epoch. The optimizer state is the one
# at the time this cell runs (after the last epoch), not the one from the best epoch.
torch.save({'model_state_dict':best_model_wts,'optimizer_state_dict':optimizer.state_dict()}, PATH)

```
concept_input torch.Size([64, 512, 14, 14])
CM torch.Size([64, 200, 14, 14])
CMlen 200
CM_embed_input -> torch.Size([64, 14, 14])
Embed_input -> torch.Size([64, 196])
embed_out -> torch.Size([64, 32])
prototype torch.Size([64, 32])
cprototype shape -> 10 torch.Size([64, 32])
cps shape-> torch.Size([64, 320])
relev--> torch.Size([64, 1])
cprototype shape -> 10 torch.Size([64, 32])
R -> torch.Size([64, 20])
expanded R  -> torch.Size([64, 20, 1])
1 concept (before cat) torch.Size([64, 320])
concepts after cat -> torch.Size([64, 6400])
```