In [1]:
# default_exp models.conditionalnet

# Implementation of a conditional similarity network

In [53]:
#export
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models

In [55]:
#export
# Module-level ResNet-18 from torchvision, created with pretrained=True; the first
# run downloads the checkpoint into the local torch hub cache (see the cell output).
# NOTE(review): presumably intended as the `emb_extractor` backbone for
# ConditionalSimNet, but nothing in the visible cells uses it -- confirm.
resnet18 = models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [6]:
#export
class ConditionalSimNet(nn.Module):
    """Embedding extractor followed by one multiplicative mask per similarity condition.

    The masks are stored as the rows of an ``nn.Embedding`` so that they can be
    looked up by condition index in ``forward``.
    """

    def __init__(self, emb_extractor, num_sim_condition, embedding_size, learned_mask=True):
        """
        emb_extractor: base model that extracts the embeddings from the input
        num_sim_condition: number of similarity conditions being considered
        embedding_size: size of each mask. Must be the same size as the output from `emb_extractor`
        learned_mask: flag to know if the conditional mask should be learned/adjusted during training.
            When False, each condition gets a fixed 0/1 mask that selects its own
            contiguous slice of `embedding_size // num_sim_condition` dimensions.
        """
        super().__init__()
        self.emb_extractor = emb_extractor
        self.learned_mask = learned_mask

        # One mask row per condition, whichever kind of mask is used.
        self.masks = nn.Embedding(num_sim_condition, embedding_size)

        if learned_mask:
            # use a normal distribution (mean 0.9, std 0.7) to init the weights
            self.masks.weight.data.normal_(0.9, 0.7)
        else:
            # use equal spacing in the embedding space to define the masks;
            # any remainder dimensions at the end are left unselected (all zeros)
            mask_len = embedding_size // num_sim_condition
            fixed_masks = torch.zeros(num_sim_condition, embedding_size)
            for i in range(num_sim_condition):
                fixed_masks[i, i * mask_len:(i + 1) * mask_len] = 1
            # no gradients for the masks
            self.masks.weight = nn.Parameter(fixed_masks, requires_grad=False)

    def forward(self, inp, sim_condition):
        """Embed `inp` and mask the embedding with the mask(s) selected by `sim_condition`.

        inp: input passed unchanged to `emb_extractor`
        sim_condition: integer index tensor used to look up rows of `self.masks`

        Returns a 4-tuple: (masked embedding, L1 norm of the selected mask,
        L2 norm of the raw embedding, L2 norm of the masked embedding).
        """
        embedding = self.emb_extractor(inp)
        mask = self.masks(sim_condition)
        # NOTE(review): the previous comment here talked about keeping values
        # between 0 and 1, but no clamping op was ever applied; if learned masks
        # are meant to be non-negative (e.g. via ReLU), add it here -- confirm.
        # The mask is applied for both learned and fixed masks; previously the
        # fixed-mask path never assigned `masked_embedding` and crashed on return.
        masked_embedding = embedding * mask
        return masked_embedding, mask.norm(1), embedding.norm(2), masked_embedding.norm(2)

In [9]:
# hide
# Small random 2x3 tensor reused by the norm sanity checks in the cells below.
c = torch.rand(2, 3)
c

tensor([[0.6983, 0.7668, 0.1128],
        [0.6742, 0.4703, 0.9971]])

In [34]:
# hide
from IPython.display import Image

# Show the remote illustration inline at a fixed 500x500 size.
Image(
    url="https://miro.medium.com/max/2546/1*zMLv7EHYtjfr94JOBzjqTA.png",
    width=500,
    height=500,
)

In [54]:
# hide
# List the individual scalar entries of `c` in row-major order.
print(list(c.flatten()))

[tensor(0.6983), tensor(0.7668), tensor(0.1128), tensor(0.6742), tensor(0.4703), tensor(0.9971)]


In [49]:
# hide
# Hand-rolled L2 norm (square root of the summed squares) printed next to
# torch's built-in `norm(2)`; the two values should agree.
print("L2 regularization \n")
print(
    torch.tensor([entry ** 2 for entry in c.flatten()])
    .sum()
    .sqrt()
)
print(c.norm(2))

L2 regularization 

tensor(1.6608)
tensor(1.6608)


In [52]:
# hide
# Hand-rolled L1 norm (sum of absolute values) printed next to torch's
# built-in `norm(1)`; the two values should agree.
print("L1 regularization \n")
print(
    torch.tensor([torch.abs(entry) for entry in c.flatten()])
    .sum()
)
print(c.norm(1))

L1 regularization 

tensor(3.7195)
tensor(3.7195)
