In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from torchvision import models
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

#[TODO]
from network.modules import get_resnet, freeze_, unfreeze_
from network.modules.transformations import TransformsRelic
from network.modules.sync_batchnorm import convert_model


import os
import click
import time
import numpy as np
import copy

from con_losses import SupConLoss, ReLICLoss, BarlowTwinsLoss
from network import mnist_net,res_net, generator
import data_loader
from main_base import evaluate

import matplotlib.pyplot as plt


# Current user's home directory. On POSIX, os.path.expanduser('~') returns $HOME when it
# is set and otherwise falls back to the password database, so this does not raise
# KeyError in environments without $HOME (as os.environ['HOME'] would).
HOME = os.path.expanduser('~')

  from .autonotebook import tqdm as notebook_tqdm


In [76]:
# ---- Experiment configuration ------------------------------------------------
backbone = 'resnet18'       # backbone name passed to get_resnet
# A real bool, not the string 'True': any non-empty string (including 'False') is truthy,
# so a string flag cannot be switched off reliably.
pretrained = True
ckpt = "saved-model/cifar10/base_resnet18_True_128_run0/best.pkl"  # source-model checkpoint
projection_dim = 128        # passed to res_net.ConvNet as projection_dim
lr = 1e-3                   # Adam learning rate for src_opt
batchsize = 128             # samples per batch for both loaders
nbatch = 100                # number of batches one pass over trloader yields (via RandomSampler)

In [63]:
# Build the source network and restore its trained weights from `ckpt`.
# The `pretrained` flag only affects the initial backbone weights; load_state_dict
# (strict by default) overwrites them with the checkpoint below.
encoder = get_resnet(backbone, pretrained)
n_features = encoder.fc.in_features  # input width of the backbone's fc layer
output_dim = 10  # CIFAR-10 has 10 classes; TODO: derive from the dataset instead of hard-coding
src_net = res_net.ConvNet(encoder, projection_dim, n_features, output_dim).cuda()
# Deserialize onto the CPU first; load_state_dict then copies the tensors into the
# CUDA-resident parameters. This avoids device errors when the checkpoint was saved from a
# different GPU, and avoids a transient second copy of the weights on the GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
saved_weight = torch.load(ckpt, map_location='cpu')
src_net.load_state_dict(saved_weight['cls_net'])
src_opt = optim.Adam(src_net.parameters(), lr=lr)

In [64]:
# Independent deep copy of the source network: freezing or modifying it below does not
# affect `src_net` itself.
src_net_copy = copy.deepcopy(src_net)

In [65]:
def _apply_to_blocks(which, model, fn, verb):
    '''
    Apply `fn` to the sub-modules of `model` selected by `which` (shared by freeze/unfreeze).
    --which (selector: "all" / "encoder" / "heads"; None is a no-op [str or None])
    --model (network to modify [nn.Module]; "encoder" uses model.encoder, "heads" uses
             model.cls_head_src, model.cls_head_tgt and model.pro_head)
    --fn (callable applied to each selected module, e.g. freeze_ or unfreeze_)
    --verb (word used in the error message: "Freeze" or "Unfreeze" [str])
    Raises ValueError for any other selector. All targets are looked up before `fn` is
    applied, so a missing head attribute raises without leaving the model half-modified.
    '''
    if which is None:
        return
    if which == "all":
        targets = [model]
    elif which == "encoder":
        targets = [model.encoder]
    elif which == "heads":
        targets = [model.cls_head_src, model.cls_head_tgt, model.pro_head]
    else:
        raise ValueError(f"Please {verb} Either all/encoder/heads")
    for module in targets:
        fn(module)


def freeze(freeze, model):
    '''
    Freeze part of a model in place (via the project's freeze_ helper).
    --freeze (which block to freeze: "all" / "encoder" / "heads"; None does nothing [str or None])
    --model (the network itself, not its name [nn.Module])
    Raises ValueError for any other value of `freeze`.
    '''
    _apply_to_blocks(freeze, model, freeze_, "Freeze")


def unfreeze(unfreeze, model):
    '''
    Unfreeze part of a model in place (via the project's unfreeze_ helper).
    --unfreeze (which block to unfreeze: "all" / "encoder" / "heads"; None does nothing [str or None])
    --model (the network itself, not its name [nn.Module])
    Raises ValueError for any other value of `unfreeze`.
    '''
    _apply_to_blocks(unfreeze, model, unfreeze_, "Unfreeze")

In [70]:
# Freeze only the backbone of the copy; the heads are not touched by this call.
freeze("encoder", model=src_net_copy)

In [71]:
# Sanity check: after freezing the encoder, its first conv weight must not require gradients.
conv1_trainable = src_net_copy.encoder.layer1[0].conv1.weight.requires_grad
assert not conv1_trainable, "encoder conv1 weight should be frozen"
conv1_trainable

False

In [73]:
# Sanity check: the encoder's bn1 weight must be frozen as well.
# NOTE(review): requires_grad=False does not stop BatchNorm running statistics from being
# updated while the module is in train() mode -- call .eval() on the encoder if those must stay fixed.
bn1_trainable = src_net_copy.encoder.layer1[0].bn1.weight.requires_grad
assert not bn1_trainable, "encoder bn1 weight should be frozen"
bn1_trainable

False

In [72]:
# Sanity check: freezing only the encoder must leave pro_head trainable.
pro_head_trainable = src_net_copy.pro_head[0].weight.requires_grad
assert pro_head_trainable, "pro_head should remain trainable after freezing only the encoder"
pro_head_trainable

True

# Comparing the Frozen Copy with the Source Net

In [74]:
# CIFAR-10 train/test splits from the project's data_loader helper.
# autoaug=None presumably disables automatic augmentation on the training split -- confirm in data_loader.
trset = data_loader.load_cifar10(split='train', autoaug=None)
teset = data_loader.load_cifar10(split='test')

In [77]:
# Training loader: draw nbatch * batchsize indices with replacement, so one pass over
# trloader yields exactly nbatch batches regardless of the dataset size.
trloader = DataLoader(
    trset,
    batch_size=batchsize,
    num_workers=8,
    sampler=RandomSampler(trset, replacement=True, num_samples=nbatch * batchsize),
)
# Test loader: fixed order for reproducible evaluation.
# NOTE(review): drop_last=True skips the final partial batch (len(teset) % batchsize
# samples), so metrics computed from this loader do not cover the full test set -- confirm
# this is intended.
teloader = DataLoader(teset, batch_size=batchsize, num_workers=8, shuffle=False, drop_last=True)

In [79]:
# Fetch a single random training batch for the forward-pass comparison below. Only one batch
# is needed, so take it directly rather than iterating (and copying to the GPU) all nbatch batches.
x, y = next(iter(trloader))
x, y = x.cuda(), y.cuda()

In [92]:
# Run the same batch through the source network and its encoder-frozen copy.
# NOTE(review): nothing above calls .eval(), so presumably both modules are in train() mode and
# these calls update BatchNorm running statistics; wrap in torch.no_grad() and call .eval()
# first if this is meant as a pure comparison.
p, z = src_net(x, mode='train')
p_oracle, z_oracle = src_net_copy(x, mode='train')
