In [1]:
!pip install lightly==1.4.7

Collecting lightly==1.4.7
  Downloading lightly-1.4.7-py3-none-any.whl.metadata (5.0 kB)
Collecting hydra-core>=1.0.0 (from lightly==1.4.7)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly-utils~=0.0.0 (from lightly==1.4.7)
  Downloading lightly_utils-0.0.2-py3-none-any.whl (6.4 kB)
Collecting omegaconf<2.4,>=2.2 (from hydra-core>=1.0.0->lightly==1.4.7)
  Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting antlr4-python3-runtime==4.9.* (from hydra-core>=1.0.0->lightly==1.4.7)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Downloading lightly-1.4.7-py3-none-any.whl (647 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import torch
import torchvision
import torch.nn as nn
from lightly.models.modules.heads import SimCLRProjectionHead

class SimCLRModel(nn.Module):
    """SimCLR-style model: a dilated ResNet-50 backbone followed by a SimCLR projection head.

    The backbone is built with ``replace_stride_with_dilation=[False, True, True]``
    so that the same layers can later be reused as the encoder of a DeepLabV3+ model.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``. Defaults to True
            (original behaviour). Pass False when the backbone weights are going to
            be overwritten anyway (e.g. by DINO weights) to skip that download.
        out_dim: output dimension of the projection head (default 128).

    NOTE: the order in which submodules are assigned below determines the order of
    ``state_dict()`` keys. Weight-loading code may map checkpoint tensors onto this
    model by position, so do not reorder these assignments.
    """

    def __init__(self, pretrained=True, out_dim=128):
        super().__init__()
        backbone = torchvision.models.resnet50(
            pretrained=pretrained,
            replace_stride_with_dilation=[False, True, True],
        )
        # First four children of torchvision's ResNet form the stem
        # (conv1, bn1, relu, maxpool); the classification head (fc) is dropped.
        self.initial = nn.Sequential(*list(backbone.children())[:4])
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        self.avgpool = backbone.avgpool

        # Width of the pooled backbone features, read from the discarded fc layer.
        hidden_dim = backbone.fc.in_features
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)

    def backbone_forward(self, x):
        """Run the stem and layer1..layer4, returning the un-pooled feature map.

        Args:
            x: image batch — presumably (N, 3, H, W); TODO confirm with callers.

        Returns:
            The layer4 feature map (before average pooling), suitable for reuse as
            dense features by a segmentation head.
        """
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        """Return SimCLR projections of shape (N, out_dim) for the input batch."""
        features = self.backbone_forward(x)
        pooled = self.avgpool(features).flatten(start_dim=1)
        return self.projection_head(pooled)

In [4]:
# DINO self-supervised ResNet-50 weights from torch.hub (requires network access).
dino_resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

# Custom model whose backbone receives the DINO weights; the projection head
# keeps its fresh initialisation.
model = SimCLRModel()

# Map DINO tensors onto SimCLRModel by position. This relies on SimCLRModel
# registering its backbone modules in the same order as torchvision's ResNet
# (stem, layer1..layer4), so the first len(dino_state) model keys are the
# backbone keys. The count is derived instead of hard-coded (previously 318),
# and every pair is validated so a misalignment fails loudly rather than being
# hidden by strict=False below.
from collections import OrderedDict

dino_state = dino_resnet50.state_dict()
model_state = model.state_dict()
backbone_keys = list(model_state.keys())[:len(dino_state)]
if len(backbone_keys) != len(dino_state):
    raise ValueError(
        f"model has {len(backbone_keys)} tensors but DINO checkpoint has {len(dino_state)}"
    )

for (src_key, src_tensor), dst_key in zip(dino_state.items(), backbone_keys):
    dst_tensor = model_state[dst_key]
    # Same tensor kind (weight / bias / running_mean / ...) and same shape required.
    same_kind = src_key.rsplit('.', 1)[-1] == dst_key.rsplit('.', 1)[-1]
    if not same_kind or src_tensor.shape != dst_tensor.shape:
        raise ValueError(
            f"DINO tensor {src_key} {tuple(src_tensor.shape)} does not match "
            f"model tensor {dst_key} {tuple(dst_tensor.shape)}"
        )

test_soln = OrderedDict(zip(backbone_keys, dino_state.values()))
load_result = model.load_state_dict(test_soln, strict=False)

# Only the projection head should be left without pretrained weights.
assert not load_result.unexpected_keys, load_result.unexpected_keys
assert all(k.startswith('projection_head.') for k in load_result.missing_keys), load_result.missing_keys
load_result

Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_resnet50_pretrain.pth
100%|██████████| 90.0M/90.0M [00:00<00:00, 164MB/s] 
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 119MB/s] 


_IncompatibleKeys(missing_keys=['projection_head.layers.0.weight', 'projection_head.layers.1.weight', 'projection_head.layers.1.bias', 'projection_head.layers.1.running_mean', 'projection_head.layers.1.running_var', 'projection_head.layers.3.weight', 'projection_head.layers.4.weight', 'projection_head.layers.4.bias', 'projection_head.layers.4.running_mean', 'projection_head.layers.4.running_var'], unexpected_keys=[])

In [5]:
# Persist the combined weights: DINO backbone plus the untrained projection head.
checkpoint_path = "dino_resnet50.pt"
torch.save(model.state_dict(), checkpoint_path)