In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from generalized_contrastive_loss.datasets import *
from torch.utils.data import DataLoader
from torchvision import models

  warn(f"Failed to load image Python extension: {e}")


In [2]:
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    The input is clamped from below at ``eps``, raised element-wise to the
    power ``p``, averaged over its whole ``(H, W)`` extent, and the average
    is raised to ``1 / p``.

    Args:
        p: Initial value of the pooling exponent; stored as a one-element
            ``nn.Parameter``.
        eps: Lower bound applied to the input before exponentiation.
        requires_grad: Whether the exponent ``p`` is trainable.
    """

    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        exponent = torch.ones(1) * p
        self.p = nn.Parameter(exponent, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the module's own exponent and clamp value."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply GeM pooling to ``x`` using the given ``p`` and ``eps``."""
        height, width = x.size(-2), x.size(-1)
        powered = x.clamp(min=eps).pow(p)
        # A kernel as large as the feature map yields one value per channel.
        pooled = F.avg_pool2d(powered, (height, width))
        return pooled.pow(1. / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps!s})"


class BaseNet(nn.Module):
    """Backbone wrapper that returns both local and global descriptors.

    ``forward`` indexes ``backbone[0]`` through ``backbone[7]``, so the
    backbone must be an indexable container (for example ``nn.Sequential``)
    with at least eight stages. The stages are expected to follow the ResNet
    layout: conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4.

    Args:
        backbone: Indexable module holding the eight stages described above.
        global_pool: ``"max"``, ``"avg"`` or ``"GeM"`` selects the pooling
            applied to both outputs; any other value disables pooling.
        poolkernel: Unused; kept for interface compatibility.
        norm: ``"L2"`` L2-normalises the global features; other values leave
            them untouched. Local features are never normalised.
        p: Exponent passed to ``GeM`` when ``global_pool == "GeM"``.
        num_clusters: Unused; kept for interface compatibility.

    Attributes set here: ``num_features`` (see ``__init__``),
    ``pretrained_cfg`` (empty dict) and ``num_classes`` (0).
    """

    def __init__(self, backbone, global_pool=None, poolkernel=7, norm=None, p=3, num_clusters=64):
        super().__init__()
        self.backbone = backbone

        # Heuristic: the leading dimension of the last parameter the backbone
        # yields is taken as the feature width. A backbone without parameters
        # gives ``None`` instead of failing on an unbound variable.
        last_param = None
        for last_param in self.backbone.parameters():
            pass
        self.num_features = None if last_param is None else last_param.size()[0]

        self.pretrained_cfg = {}
        self.num_classes = 0

        if global_pool == "max":
            self.pool = nn.AdaptiveMaxPool2d(output_size=(1, 1))
        elif global_pool == "avg":
            self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        elif global_pool == "GeM":
            self.pool = GeM(p=p)
        else:
            self.pool = None
        self.norm = norm

    def forward(self, x0):
        """Return ``(local_features, global_features)`` for the batch ``x0``.

        Local features are the output of stage 6 (layer3) and global features
        the output of stage 7 (layer4). When a pooling layer is configured,
        both are pooled and their two trailing singleton dimensions squeezed
        away. Without a pooling layer the raw feature maps are returned.
        """
        # Stages 0-5: conv1, bn1, relu, maxpool, layer1, layer2.
        for stage in range(6):
            x0 = self.backbone[stage](x0)

        # layer3 yields the local features; layer4 continues from them, so
        # the global branch is equivalent to running the whole backbone.
        local_features = self.backbone[6](x0)
        global_features = self.backbone[7](local_features)

        if self.pool is not None:
            # Call the module itself (not ``.forward``) so that any hooks
            # registered on the pooling layer are executed.
            global_features = self.pool(global_features).squeeze(-1).squeeze(-1)
            local_features = self.pool(local_features).squeeze(-1).squeeze(-1)

        if self.norm == "L2":
            global_features = nn.functional.normalize(global_features)
        return local_features, global_features


class SiameseNet(BaseNet):
    """Two-branch network that runs both inputs through one shared ``BaseNet``.

    The constructor arguments are forwarded unchanged to ``BaseNet``.
    """

    def __init__(self, backbone, global_pool=None, poolkernel=7,norm=None, p=3,num_clusters=64):
        super().__init__(
            backbone,
            global_pool,
            poolkernel,
            norm=norm,
            p=p,
            num_clusters=num_clusters,
        )

    def forward(self, x0, x1):
        """Return ``(out0, out1)``: the ``BaseNet.forward`` result of each input."""
        # The same weights process both inputs, first x0 and then x1.
        base_forward = super().forward
        first = base_forward(x0)
        second = base_forward(x1)
        return first, second

In [None]:
def create_dataloader(dataset, root_dir, idx_file, gt_file, image_t, batch_size, num_workers=4):
    """Build a ``DataLoader`` for the requested dataset kind.

    ``TestDataSet`` and ``SiameseDataSet`` are not defined in this file;
    presumably they come from the star import of
    ``generalized_contrastive_loss.datasets`` (verify).

    Args:
        dataset: ``"test"``, ``"soft_siamese"`` or ``"binary_siamese"``.
        root_dir: Passed through to the dataset constructor.
        idx_file: Passed through to the dataset constructor.
        gt_file: Passed to ``SiameseDataSet`` only; ignored for ``"test"``.
        image_t: Passed to the dataset as its ``transform``.
        batch_size: Batch size of the returned loader.
        num_workers: Number of loader worker processes (default 4).

    Returns:
        An unshuffled loader for ``"test"``; a shuffled loader for the two
        siamese kinds.

    Raises:
        ValueError: If ``dataset`` is not one of the supported kinds.
    """
    if dataset == "test":
        ds = TestDataSet(root_dir, idx_file, transform=image_t)
        return DataLoader(ds, batch_size=batch_size, num_workers=num_workers)

    # The siamese variants differ only in the ds_key handed to the dataset.
    if dataset == "soft_siamese":
        ds_key = "fov"
    elif dataset == "binary_siamese":
        ds_key = "sim"
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected 'test', 'soft_siamese' "
            "or 'binary_siamese'"
        )

    ds = SiameseDataSet(root_dir, idx_file, gt_file, ds_key=ds_key, transform=image_t)
    return DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)


def get_backbone(name):
    """Load a pretrained feature extractor and report its output channels.

    Args:
        name: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"resnet152"``, ``"densenet121"``, ``"densenet161"``,
            ``"vgg16"`` or ``"resnext"`` (the ``resnext101_32x8d_wsl`` model
            from ``facebookresearch/WSL-Images``, loaded via ``torch.hub``).

    Returns:
        ``(backbone, output_dim)``. ResNet/ResNeXt models are returned as an
        ``nn.Sequential`` of their children without the last two (avgpool
        and fc); DenseNet and VGG models are returned as their ``.features``
        module.

    Raises:
        ValueError: If ``name`` is not a supported architecture.
    """
    # Channel count of the final feature map for each architecture.
    resnet_dims = {"resnet18": 512, "resnet34": 512, "resnet50": 2048, "resnet152": 2048}
    features_dims = {"densenet161": 2208, "densenet121": 1024, "vgg16": 512}

    if name in resnet_dims:
        backbone = getattr(models, name)(pretrained=True)
        output_dim = resnet_dims[name]
        strip_head = True
    elif name in features_dims:
        backbone = getattr(models, name)(pretrained=True).features
        output_dim = features_dims[name]
        strip_head = False
    elif name == "resnext":
        backbone = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
        # Supposed to be ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']
        print(f" the layers of the resnext101_32x8d_wsl are: {backbone._modules.keys()}")
        output_dim = 2048
        strip_head = True
    else:
        supported = sorted([*resnet_dims, *features_dims, "resnext"])
        raise ValueError(f"Unknown backbone {name!r}; expected one of {supported}")

    if strip_head:
        # Drop the classification head (avgpool and fc) to keep the conv trunk.
        backbone = torch.nn.Sequential(*(list(backbone.children())[:-2]))
        print(f" the layers of the {name} are after removing the last two layers (avgpool and fc): {backbone._modules.keys()}")
    return backbone, output_dim


def create_model(name, pool, last_layer=None, norm=None, p_gem=3, num_clusters=64, mode="siamese"):
    """Build a network around the named backbone, freezing its early children.

    Args:
        name: Backbone name, forwarded to ``get_backbone``.
        pool: Pooling selector forwarded to the network constructor.
        last_layer: How many trailing backbone children stay trainable.
            ``None`` leaves every child trainable. For names containing
            ``"densenet"`` the value is doubled, and for names containing
            ``"vgg"`` it becomes ``last_layer * 8 - 2``, before use.
        norm: Normalisation selector forwarded to the network constructor.
        p_gem: GeM exponent forwarded to the network constructor as ``p``.
        num_clusters: Forwarded to the siamese and triplet constructors.
        mode: ``"siamese"`` builds a ``SiameseNet``, ``"triplet"`` a
            ``TripletNet``, anything else a ``BaseNet``.

    Returns:
        The constructed network.
    """
    backbone, _output_dim = get_backbone(name)
    layers = len(list(backbone.children()))
    print(f"Number of layers: {layers}")

    if last_layer is None:
        last_layer = layers
    elif "densenet" in name:
        last_layer = last_layer * 2
    elif "vgg" in name:
        last_layer = last_layer * 8 - 2

    # Children whose index is below this threshold get their gradients disabled.
    frozen_count = layers - last_layer
    for index, child in enumerate(backbone.children()):
        if index >= frozen_count:
            print(index, child._get_name(), "IS TRAINED")
            continue
        print(index, child._get_name(), "IS FROZEN")
        for weight in child.parameters():
            weight.requires_grad = False

    if mode == "siamese":
        return SiameseNet(backbone, pool, norm=norm, p=p_gem, num_clusters=num_clusters)
    if mode == "triplet":
        # NOTE(review): TripletNet is not defined in the visible part of this
        # file; confirm it is provided elsewhere.
        return TripletNet(backbone, pool, norm=norm, p=p_gem, num_clusters=num_clusters)
    return BaseNet(backbone, pool, norm=norm, p=p_gem)