In [1]:
import numpy as np

In [41]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

In [42]:
class simCLR_FROD(nn.Module):
    """ResNet-style encoder with a SimCLR projection head and per-stage autoencoders.

    The encoder is cut into ``midlayers_num`` consecutive stages (stem, the two
    blocks of each of ``layer1``..``layer4``, and the pooling head). One
    autoencoder is attached to the channel-pooled output of each of the first
    nine stages, so a reconstruction error can be computed at several depths.

    Args:
        base_encoder: Callable invoked as ``base_encoder(pretrained=False)``. The
            returned module must expose ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1``..``layer4`` (two blocks each), ``avgpool`` and ``fc``.
            The hard-coded widths below assume a ResNet-18-like channel layout
            (64/128/256/512) -- TODO confirm before using another backbone.
        AE: Autoencoder class. Instances must provide ``recon_error(features)``.
            It is called with seven positional ints; presumably these are layer
            widths with the first one being the input width -- verify against
            the AE definition.
        projection_dim: Output size of the projection MLP.
    """

    def __init__(self, base_encoder, AE, projection_dim=128):
        super().__init__()
        # Backbone built by the caller-supplied factory, without pretrained weights.
        self.enc = base_encoder(pretrained=False)
        self.feature_dim = self.enc.fc.in_features

        # Customize for CIFAR10. Replace conv 7x7 with conv 3x3, and remove first
        # max pooling. See Section B.9 of SimCLR paper.
        self.enc.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.enc.maxpool = nn.Identity()
        self.enc.fc = nn.Identity()  # remove final fully connected layer.

        # MLP projection head applied on top of the encoder features.
        self.projection_dim = projection_dim
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim),
        )

        # Stage-wise view of the encoder. These are the *same* module objects as
        # the ones under ``self.enc`` (shared parameters), merely re-registered so
        # that a prefix of the network can be run with ``enc_midlayers[:k]``.
        self.enc_midlayers = nn.Sequential(
            nn.Sequential(self.enc.conv1, self.enc.bn1, self.enc.relu, self.enc.maxpool),
            self.enc.layer1[0],
            self.enc.layer1[1],
            self.enc.layer2[0],
            self.enc.layer2[1],
            self.enc.layer3[0],
            self.enc.layer3[1],
            self.enc.layer4[0],
            self.enc.layer4[1],
            nn.Sequential(self.enc.avgpool, self.enc.fc),
        )
        self.midlayers_num = len(self.enc_midlayers)

        # One autoencoder per stage output, in stage order: AE[k] is sized for
        # the output of enc_midlayers[k] (first argument follows the channel
        # counts 64, 64, 64, 128, 128, 256, 256, 512, 512). The final pooling
        # stage has no autoencoder.
        self.AE = nn.Sequential(
            AE(64, 32, 16, 8, 4, 0, 0),
            AE(64, 32, 16, 8, 4, 0, 0),
            AE(64, 32, 16, 8, 4, 0, 0),
            AE(128, 64, 32, 16, 8, 4, 0),
            AE(128, 64, 32, 16, 8, 4, 0),
            AE(256, 128, 64, 32, 16, 8, 4),
            AE(256, 128, 64, 32, 16, 8, 4),
            AE(512, 256, 128, 64, 32, 8, 4),
            AE(512, 256, 128, 64, 32, 8, 4),
        )

    def forward(self, x):
        """Return ``(feature, projection)`` for a batch of images.

        ``feature`` is the encoder output (the final ``fc`` is an identity) and
        ``projection`` is the projector MLP applied to it.
        """
        feature = self.enc(x)
        projection = self.projector(feature)
        return feature, projection

    def intermediate_features(self, x, index):
        """Run the first ``index`` stages and average over the trailing dims.

        Args:
            x: Input batch fed to the first stage.
            index: Number of leading stages of ``enc_midlayers`` to apply
                (slice semantics: ``enc_midlayers[:index]``, so negative values
                count from the end and ``0`` applies no stage at all).

        Returns:
            Tensor of shape ``(batch, channels)``: the stage output with all
            dimensions after the channel dimension flattened and averaged.
        """
        out_features = self.enc_midlayers[:index](x)
        # Collapse everything after (batch, channel) into one axis, then average it.
        out_features = out_features.view(out_features.size(0), out_features.size(1), -1)
        return out_features.mean(dim=2)

    def recon_error(self, x, index):
        """Reconstruction error of ``intermediate_features(x, index)``.

        The features produced after ``index`` stages are scored by the
        autoencoder built for that stage, i.e. ``self.AE[index - 1]``. Pairing
        them with ``self.AE[index]`` (as was done previously) selects the
        autoencoder of the *next* stage, which has a different input width
        whenever the channel count changes between the two stages.

        Args:
            x: Input batch fed to the first stage.
            index: Number of leading stages to apply, ``1..len(self.AE)``.
                Negative values count from the end of ``enc_midlayers``
                (``-1`` means "all stages but the last").

        Returns:
            Whatever ``AE.recon_error`` returns for the pooled features.

        Raises:
            IndexError: If ``index`` does not select a stage that has an
                autoencoder attached.
        """
        # Normalise negative (from-the-end) indices to a stage count.
        n_stages = index if index >= 0 else self.midlayers_num + index
        if not 1 <= n_stages <= len(self.AE):
            raise IndexError(
                "index {} does not select a stage with an autoencoder "
                "(valid: 1..{} or the equivalent negative values)".format(index, len(self.AE))
            )
        features = self.intermediate_features(x, n_stages)
        return self.AE[n_stages - 1].recon_error(features)

In [43]:
# Prefer the GPU this notebook was developed on, but fall back to the first
# GPU or the CPU so the cell still runs on machines with fewer devices.
PREFERRED_GPU = 7
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f'cuda:{PREFERRED_GPU}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# NOTE(review): `AE` is not defined in any cell shown here -- it must come from
# an earlier (deleted or out-of-order) cell; define or import it before this
# cell so the notebook survives Restart & Run All.
network = simCLR_FROD(models.resnet18, AE).to(device)
print(network)

simCLR_FROD(
  (enc): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Identity()
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (con

In [50]:
# Smoke test: pooled features after the first encoder stage for a random
# CIFAR-sized batch; only the resulting shape is displayed.
network.intermediate_features(
    torch.randn(128, 3, 32, 32).to(device),
    index=1,
).shape

torch.Size([128, 64])

In [45]:
# Number of stages in network.enc_midlayers (set in __init__ from its len()).
network.midlayers_num

10

In [48]:
# Smoke test: reconstruction error at index 8 for a random CIFAR-sized batch
# (one value per sample in the batch).
network.recon_error(
    torch.randn(128, 3, 32, 32).to(device),
    index=8,
)

tensor([13.9729, 13.5033, 13.3804, 12.9721, 12.9899, 12.8227, 13.3897, 13.1998,
        13.7892, 12.7826, 13.3194, 13.3273, 13.0100, 12.9278, 14.2812, 13.8439,
        13.5959, 13.2702, 13.7046, 12.9605, 13.3336, 12.9959, 12.9971, 13.1047,
        14.1510, 13.4758, 14.1476, 13.5983, 13.4270, 13.0585, 13.2651, 13.7201,
        14.1078, 13.6648, 13.0152, 14.2640, 14.3270, 12.4420, 12.5989, 12.9132,
        12.9987, 13.7687, 12.9021, 14.3158, 13.5300, 13.2175, 13.5919, 14.0193,
        13.2096, 12.8661, 14.1083, 13.1256, 14.3405, 13.4736, 14.1788, 13.3970,
        13.7272, 13.1166, 13.3756, 13.5783, 12.7398, 13.4358, 13.3844, 12.8670,
        13.6791, 14.1437, 13.2376, 13.5435, 14.1774, 13.0412, 13.4329, 14.0647,
        12.9474, 13.5686, 13.1864, 13.5179, 14.1768, 13.7969, 14.1786, 12.8989,
        14.1730, 13.1356, 13.3913, 13.0977, 13.8695, 13.8544, 13.7209, 13.5823,
        13.9350, 13.4777, 13.4713, 12.7683, 13.6696, 13.8028, 13.5839, 13.3189,
        13.0051, 13.2099, 14.2653, 13.17