In [1]:
import pickle
import collections as col
import numpy as np
import random
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import math
import pdb
from sklearn import metrics
import time
import os
import PIL

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from pytorch_datasets import DiagnosticInpainted
import models
import layers
import utilities.reading_images as reading_images
from utilities.loading import get_single_image
from torchvision.utils import save_image

%matplotlib inline

In [2]:
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

In [3]:
class PixelCNN(nn.Module):
    def __init__(self, latent_dim):
        super(PixelCNN, self).__init__()
        
        # Conv2d: (input_channels, output_channels, kernel_size, padding)
        
        self.relu = nn.ReLU()
        
        self.model = nn.Sequential(
            nn.Conv2d(latent_dim, latent_dim, (1, 1)),
            nn.ReLU(),
            nn.ConstantPad2d((1, 1, 0, 0), 0),
            nn.Conv2d(latent_dim, latent_dim, (1, 3)),
            nn.ConstantPad2d((0, 0, 0, 1), 0),
            nn.Conv2d(latent_dim, latent_dim, (2, 1)),
            nn.ReLU(),
            nn.Conv2d(latent_dim, latent_dim, (1, 1))
        )

    def forward(self, latents):
        
        # latents: [B, C, H, W]
        cres = latents
        
        for _ in range(5):
            c = self.model(cres)
            cres = cres + c
        cres = self.relu(cres)
        return cres      

In [4]:
boom = torch.tensor(np.arange(20).reshape(2, 10))
boom

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

In [5]:
boom.reshape(2, 2, 5)

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]])

In [6]:
np.repeat(6, 10)

array([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])

In [7]:
class CPC_loss(nn.Module):
    
    def __init__(self):
        super(CPC_loss, self).__init__()
        self.pixel_cnn = PixelCNN(512)
#         self.conv = nn.Conv2d(512, 64, kernel_size = (1, 1))
        self.conv_1 = nn.Conv2d(512, 512, kernel_size = (1, 1))
        self.conv_2 = nn.Conv2d(512, 512, kernel_size = (1, 1))
        self.conv_3 = nn.Conv2d(512, 512, kernel_size = (1, 1))
        self.loss_func = nn.CrossEntropyLoss()
    
    def forward(self, latents, device, target_dim = 64, steps_to_ignore = 2, steps_to_predict = 3, emb_scale = 0.1):
        # latents: [B, D, H, W]
        # aka:     [B, 512, 6, 6]
        loss = 0.0
        latents = latents.to(device)
        context = self.pixel_cnn(latents) # These are the c's (apply pixelCNN to Z's)
#         targets = self.conv(latents)
        targets = latents
        
        batch_dim, target_dim, col_dim, row_dim = targets.shape
        targets = targets.reshape(-1, target_dim)
        
        # Trying to do the arbitrary context vector
        index = np.random.choice(a = [0, 1, 2])
        context = context[:, :, index, :].unsqueeze(3) # [2, 512, 6, 1]
        
        
        preds_1 = self.conv_1(context).reshape(-1, 512) * emb_scale
        preds_2 = self.conv_2(context).reshape(-1, 512) * emb_scale
        preds_3 = self.conv_3(context).reshape(-1, 512) * emb_scale
        
        logits_1 = torch.matmul(preds_1, targets.permute(1, 0)) # 12 by 512, 512 by 72 --> 12 by 72
        logits_2 = torch.matmul(preds_2, targets.permute(1, 0))
        logits_3 = torch.matmul(preds_3, targets.permute(1, 0))
        
        total_elements = batch_dim * row_dim
        b = np.arange(total_elements) / (row_dim)
        b = b.astype(int)
        col = np.arange(total_elements) % (row_dim)

        labels_1 = b * col_dim * row_dim + 3 * row_dim + col
        labels_2 = labels_1 + 6
        labels_3 = labels_2 + 6
        
        loss += self.loss_func(logits_1, torch.LongTensor(labels_1).to(device))
        loss += self.loss_func(logits_2, torch.LongTensor(labels_2).to(device))
        loss += self.loss_func(logits_3, torch.LongTensor(labels_3).to(device))
        
        
        
        
#         for i in range(steps_to_ignore, steps_to_predict):
# #             pdb.set_trace()
#             col_dim_i = col_dim - i - 1  # 6 - 2 - 1 = 3
#             total_elements = batch_dim * col_dim_i * row_dim
#             preds_i = self.conv_preds(context)
#             preds_i = preds_i[:, :, :(i+1), :] * emb_scale   # [B, 64, 6, 6] ---> [B, 64, 3, 6]
#             preds_i = preds_i.reshape(-1, target_dim)
            
#             logits = torch.matmul(preds_i, targets.permute(1, 0)) # 18 by 64, 64 by 36 ---> 18 by 36
            
#             b = np.arange(total_elements) / (col_dim_i * row_dim)
#             b = b.astype(int)
#             col = np.arange(total_elements) % (col_dim_i * row_dim)
#             labels = b * col_dim * row_dim + (i + 1) * row_dim + col
#             labels = torch.LongTensor(labels).to(device)
#             logits = logits.to(device)
            
#             rand = np.random.choice(a=[False, True], size = (logits.shape[0],))
#             logits = logits[rand, :]
#             labels = labels[rand]
#             loss += self.loss_func(logits, labels)
        return loss
            
            
        
        
        

In [8]:
np.arange(36)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35])

In [9]:
total_elements = 1 * 3 * 6
b = np.arange(18) / (3 * 6)
b = b.astype(int)
b

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [10]:
col = np.arange(18) % (3 * 6)
col

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17])

In [11]:
labels = b * 6 * 6 + 3*6 + col
labels

array([18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
       35])

In [12]:
# Implementation in the paper is unclear.
# I'm going to go with WF. 

# NCE Loss
# Questions: Is the dimension of Z (B*patches) or (B). 
#            I think it's (B, 6, 6, 4096)

class CPCLossNCE(nn.Module):
    
    def nce_loss(self, z_hat, pos_scores, negative_samples, mask_mat, device, epoch_num, batch_num):
        
        z_hat = z_hat.to(device)
        pos_scores = pos_scores.to(device)
        negative_samples = negative_samples.to(device)
        mask_mat = mask_mat.to(device)
                
        # (b, 1)
        pos_scores = pos_scores.float()
        batch_size, emb_dim = z_hat.size()
        nb_feat_vectors = negative_samples.size(1) // batch_size # 36 of them, if 6 by 6 wireframes. 
        
        # (b, b) -> (b, b, nb_feat_vectors)
        # all zeros with ones in diagonal tensor... (ie: b1 b1 are all 1s, b1 b2 are all zeros)
        mask_pos = mask_mat.unsqueeze(dim=2).expand(-1, -1, nb_feat_vectors).float()
        
        # negative mask
        mask_neg = 1. - mask_pos
        
        # ----------------------
        # ALL SCORES computation
        # (visualize in your mind a batch size of 2, 36-length segments) 
        # (b, dim) x (dim, nb_feats*b) -> (b, b, nb_feats)
        raw_scores = torch.mm(z_hat, negative_samples)
        raw_scores = raw_scores.reshape(batch_size, batch_size, nb_feat_vectors).float()
        
        # EXTRACT NEGATIVE SCORES
        # (batch_size, batch_size, nb_feat_vectors)
        # HE'S TAKING THE NEGATIVE SAMPLES FROM THE OTHER MINIBATCHES
        # A GIVEN Z_HAT IS ONLY MULTIPLIED BY Z'S FROM OTHER MINIBATCHES
        neg_scores = (mask_neg * raw_scores)
        # ----------------------
        
        # (b, b, nb_feat_vectors) -> (batch_size, batch_size * nb_feat_vectors) 
        neg_scores = neg_scores.reshape(batch_size, -1)
        mask_neg = mask_neg.reshape(batch_size, -1)
        
        # STABLE SOFTMAX
        # (n_batch_gpu, 1)
        neg_maxes = torch.max(neg_scores, dim=1, keepdim=True)[0]
        
        # DENOMINATOR
        # sum over only negative samples (none from the diagonal)
        neg_sumexp = (mask_neg * torch.exp(neg_scores - neg_maxes)).sum(dim=1, keepdim=True)
        all_logsumexp = torch.log(torch.exp(pos_scores - neg_maxes) + neg_sumexp)
        
        # NUMERATOR
        # compute numerators for the NCE log-softmaxes
        pos_shiftexp = pos_scores - neg_maxes
        
        # FULL NCE
#         if epoch_num > 2 and batch_num == 100:
#             pdb.set_trace()
        nce_scores = pos_shiftexp - all_logsumexp
        nce_scores = -nce_scores.mean()
        
#         if np.isnan(nce_scores.cpu().detach().numpy()):
#             pdb.set_trace()
#             print('boom - nceloss')
            
        
        return nce_scores
        
        
        
    def forward(self, Z, C, W_list, device, epoch_num, batch_num):
        '''
        param Z: latent vecs (B, D, H, W)
        param C: context vecs (B, D, H, W)
        param W_list: list of k-1 W projections
        '''
        
        # (b, dim, w, h)
        batch_size, emb_dim, h, w = Z.size()
        
        # (10 x 10 identity matrix)
        diag_mat = torch.eye(batch_size)
        diag_mat = diag_mat.float()
        
        losses = []
        # calculate loss for each k
        
        # Below operations preserve raster order (for B, D, H, W) = (1, 5, 2, 2) check.
        # Z_neg holds all z vecs. 
        Z_neg = Z.permute(1, 0, 2, 3).reshape(emb_dim, -1)
        
        
        for i in range(0, h-1):
            for j in range(0, w):
                cij = C[:, :, i, j]   # B by D
                
                for k in range(i+1, h): # predict on all vectors in the same column, but below current wireframe. 
                    Wk = W_list[str(k)]
                    
                    z_hat_ikj = Wk(cij)
                    zikj = Z[:, :, k, j]
                    
                    # BATCH DOT PRODUCT
                    # (b, d) x (b, d) -> (b, 1)
                    pos_scores = torch.bmm(z_hat_ikj.unsqueeze(1), zikj.unsqueeze(2))
                    pos_scores = pos_scores.squeeze(-1).squeeze(-1)
                    
                    loss = self.nce_loss(z_hat_ikj, pos_scores, Z_neg, diag_mat, device, epoch_num, batch_num)
                    if np.isinf(loss.item()):
                        pdb.set_trace()
                        print("inf -- inside inner for loop")
                    if np.isnan(loss.item()):
                        pdb.set_trace()
                        print("inside inner for loop")
                    losses.append(loss)
                    
                    
        losses = torch.stack(losses)
        loss = losses.mean()
#         if np.isnan(loss.item()):
#             pdb.set_trace()
#             print('boom')
        return loss           
        
        

In [13]:
boom = torch.rand(2, 2)

In [14]:
boom = boom.unsqueeze(0)

In [15]:
np.repeat(boom, 3, axis=0).shape

torch.Size([3, 2, 2])

In [16]:
def train_raster_patchify(img, size = 80, overlap = 32):
    '''
    Left-to-right, top to bottom.
    Assumes img is (3, 240, 240).
    '''
    patches = []
     
    h = -32
    w = -32
    for i in range(6):
        h = h + 32
        for j in range(6):
            w = w + 32
            channel = np.random.randint(3)
            processed_img = np.repeat(np.expand_dims(img[channel, h:h+size, w:w+size], axis=0), 3, axis=0)
            if np.random.randint(2):
                processed_img = np.flip(processed_img, axis=2)
            patches.append(torch.tensor(np.ascontiguousarray(processed_img)))
        w = -32
            
    return patches


def val_raster_patchify(img, size = 80, overlap = 32):
    '''
    Left-to-right, top to bottom.
    Assumes img is (3, 240, 240).
    '''
    patches = []
     
    h = -32
    w = -32
    for i in range(6):
        h = h + 32
        for j in range(6):
            w = w + 32
            patches.append(img[:, h:h+size, w:w+size])
        w = -32
            
    return patches
    

In [17]:
# plt.imshow(patches[35].permute(1, 2, 0))

# plt.imshow(trainset[0][0].permute(1, 2, 0))
# plt.scatter(80,80,color='r')
# plt.scatter(80+32,80,color='r')
# plt.scatter(80+64,80,color='r')
# plt.scatter(80+96,80,color='r')
# plt.scatter(80+128,80,color='r')
# plt.scatter(80+160,80,color='r')

In [18]:
def train_collate_fn(img_list):
    patches = []
    for (img, label) in img_list:
        img_patches = train_raster_patchify(img)
        patches.append(torch.stack(img_patches))
        
    return patches

def val_collate_fn(img_list):
    patches = []
    for (img, label) in img_list:
        img_patches = val_raster_patchify(img)
        patches.append(torch.stack(img_patches))
        
    return patches

In [20]:
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(240),
    transforms.ToTensor(),
#     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

trainset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagenette2-320/train/',
    transform = data_transform
)

train_dl = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=train_collate_fn)

valset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagenette2-320/val/',
    transform = data_transform
)

val_dl = DataLoader(valset, batch_size=32, shuffle=True, collate_fn=val_collate_fn)

In [21]:
def remove_batchnorm(model):
    model.bn1 = Identity()
    model.layer1[0].bn1 = Identity()
    model.layer1[0].bn2 = Identity()
    model.layer1[1].bn1 = Identity()
    model.layer1[1].bn2 = Identity()
    model.layer1[2].bn1 = Identity()
    model.layer1[2].bn2 = Identity()
    
    model.layer2[0].bn1 = Identity()
    model.layer2[0].bn2 = Identity()
    model.layer2[0].downsample[1] = Identity()
    model.layer2[1].bn1 = Identity()
    model.layer2[1].bn2 = Identity()
    model.layer2[2].bn1 = Identity()
    model.layer2[2].bn2 = Identity()
    model.layer2[3].bn1 = Identity()
    model.layer2[3].bn2 = Identity()
    
    model.layer3[0].bn1 = Identity()
    model.layer3[0].bn2 = Identity()
    model.layer3[0].downsample[1] = Identity()
    model.layer3[1].bn1 = Identity()
    model.layer3[1].bn2 = Identity()
    model.layer3[2].bn1 = Identity()
    model.layer3[2].bn2 = Identity()
    model.layer3[3].bn1 = Identity()
    model.layer3[3].bn2 = Identity()
    model.layer3[4].bn1 = Identity()
    model.layer3[4].bn2 = Identity()
    model.layer3[5].bn1 = Identity()
    model.layer3[5].bn2 = Identity()
    
    model.layer4[0].bn1 = Identity()
    model.layer4[0].bn2 = Identity()
    model.layer4[0].downsample[1] = Identity()
    model.layer4[1].bn1 = Identity()
    model.layer4[1].bn2 = Identity()
    model.layer4[2].bn1 = Identity()
    model.layer4[2].bn2 = Identity()

In [22]:
class CPC(nn.Module):
    def __init__(self):
        super(CPC, self).__init__()
        self.encoder = torchvision.models.resnet34()
        self.encoder.fc = Identity()
        remove_batchnorm(self.encoder)
        self.pixel_cnn = PixelCNN(512)
        self.nce_loss = CPC_loss()
#         self.nce_loss = CPCLossNCE()
        
#         # W transforms (k > 0)
#         self.W_list = {}
#         for k in range(1, 6):
#             w = torch.nn.Linear(512, 512)
#             self.W_list[str(k)] = w

#         self.W_list = nn.ModuleDict(self.W_list)
        

    def forward(self, x, device, epoch_num, batch_num):
        Z = []
        C = []
        for img_patches in x:
            img_patches = img_patches.to(device)
            z = self.encoder(img_patches).squeeze()
            z = z.unsqueeze(0).permute(0, 2, 1).reshape(1, 512, 6, 6)
            Z.append(z)
            c = self.pixel_cnn(z)
            C.append(c)

        Z = torch.stack(Z).squeeze(1)
        C = torch.stack(C).squeeze(1)

        loss = self.nce_loss(Z, device)
        
        
        return loss

In [23]:
def one_epoch(dl, model, optimizer, device, epoch_num, phase = 'train'):
    if phase == 'train':
        model.train()
    else:
        model.eval()
#         for m in model.modules():
#             if isinstance(m, nn.BatchNorm2d):
#                 m.track_running_stats = False
    losses = []
    for i, x in enumerate(dl):
        if phase == 'train':
            optimizer.zero_grad()

        loss = model(x, device, epoch_num, i)
        losses.append(loss.item())

        if phase == 'train': 
            loss.backward()
            optimizer.step()
            
            if i % 50 == 0:
                print("Batch: {}/{}, Loss: {}".format(i, len(dl), loss.item())) 
            
    
    return np.mean(losses)

In [24]:
def run_epochs(epoch_num):
    torch.cuda.set_device(6)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = CPC().to(device)
    optimizer = optim.Adam(model.parameters(), lr = 2e-4, weight_decay=1e-5, eps=1e-8)
    
#     pretrained_dict = torch.load('self_supervised_rc_15.pt')
#     model_dict = model.state_dict()

#     # 1. filter out unnecessary keys
#     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
#     # 2. overwrite entries in the existing state dict
#     model_dict.update(pretrained_dict) 
#     # 3. load the new state dict
#     model.load_state_dict(model_dict)
    
    best_val_loss = 1000000
    for i in range(epoch_num):
        print("Started epoch {}\n".format(i))
        avg_train_loss = one_epoch(train_dl, model, optimizer, device, i, phase = 'train')
        print("Average Epoch {} Loss: {}\n".format(i, avg_train_loss))
        avg_val_loss = one_epoch(val_dl, model, optimizer, device, i, phase = 'val')
        print("Validation Loss: {}\n".format(avg_val_loss))
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "paper_self_supervised_rc_best_val.pt")
            print("\nSaved model with best validation loss: {}".format(best_val_loss))
        
        if i in [1, 10, 20, 25]:
            torch.save(model.state_dict(), "paper_self_supervised_rc_{}.pt".format(i))
    

In [25]:
run_epochs(30)

Started epoch 0

Batch: 0/296, Loss: 17646.0859375
Batch: 50/296, Loss: 10.218378067016602
Batch: 100/296, Loss: 8.624032020568848
Batch: 150/296, Loss: 3.7366786003112793
Batch: 200/296, Loss: 4.514866828918457
Batch: 250/296, Loss: 0.9217681884765625
Average Epoch 0 Loss: 66.88577281633341

Validation Loss: 0.9694324612496344


Saved model with best validation loss: 0.9694324612496344
Started epoch 1

Batch: 0/296, Loss: 1.1164277791976929
Batch: 50/296, Loss: 0.5281332731246948
Batch: 100/296, Loss: 2.533245801925659
Batch: 150/296, Loss: 0.3566368520259857
Batch: 200/296, Loss: 0.11818763613700867
Batch: 250/296, Loss: 0.01738981530070305
Average Epoch 1 Loss: 0.756940320571507

Validation Loss: 0.3884073811757371


Saved model with best validation loss: 0.3884073811757371
Started epoch 2

Batch: 0/296, Loss: 0.4992157220840454
Batch: 50/296, Loss: 0.019082389771938324
Batch: 100/296, Loss: 3.908580780029297
Batch: 150/296, Loss: 0.15137018263339996
Batch: 200/296, Loss: 0.15695790

Batch: 50/296, Loss: 2.471792140568141e-05
Batch: 100/296, Loss: 8.232084655901417e-05
Batch: 150/296, Loss: 0.023032132536172867
Batch: 200/296, Loss: 0.08631636947393417
Batch: 250/296, Loss: 0.159549281001091
Average Epoch 22 Loss: 0.1330854403047686

Validation Loss: 0.11809169286265742

Started epoch 23

Batch: 0/296, Loss: 0.29722481966018677
Batch: 50/296, Loss: 1.4901161193847656e-08
Batch: 100/296, Loss: 0.00805171113461256
Batch: 150/296, Loss: 0.0018004218582063913
Batch: 200/296, Loss: 0.0004298919520806521
Batch: 250/296, Loss: 0.0004460994969122112
Average Epoch 23 Loss: 0.012502065027002896

Validation Loss: 0.00909582914934446

Started epoch 24

Batch: 0/296, Loss: 0.00017516003572382033
Batch: 50/296, Loss: 5.002009493182413e-05
Batch: 100/296, Loss: 7.654617547814269e-06
Batch: 150/296, Loss: 6.163492798805237e-06
Batch: 200/296, Loss: 1.9868215517249155e-08
Batch: 250/296, Loss: 3.223461317247711e-05
Average Epoch 24 Loss: 0.007457826524026214

Validation Loss: 0.009

In [96]:
a = torch.tensor(6)
a.repeat(5)

tensor([6, 6, 6, 6, 6])

In [None]:
len(trainset)

In [19]:
trainset[50][0].shape

torch.Size([3, 240, 240])

In [4]:
class ResNet(nn.Module):
    """
    Adapted from torchvision ResNet, converted to v2
    """

    def __init__(self,
                 input_channels, num_filters,
                 first_layer_kernel_size, first_layer_conv_stride,
                 blocks_per_layer_list, block_strides_list, block_fn,
                 first_layer_padding=0,
                 first_pool_size=None, first_pool_stride=None, first_pool_padding=0,
                 growth_factor=2, norm_class="batch", num_groups=1):
        super(ResNet, self).__init__()
        self.first_conv = nn.Conv2d(
            in_channels=input_channels, out_channels=num_filters,
            kernel_size=first_layer_kernel_size,
            stride=first_layer_conv_stride,
            padding=first_layer_padding,
            bias=False,
        )
        # Diff: padding=SAME vs. padding=0
        self.first_pool = nn.MaxPool2d(
            kernel_size=first_pool_size,
            stride=first_pool_stride,
            padding=first_pool_padding,
        )
        self.norm_class = norm_class
        self.num_groups = num_groups

        block = self._resolve_block(block_fn)
        self.layer_list = nn.ModuleList()
        current_num_filters = num_filters
        self.inplanes = num_filters
        for i, (num_blocks, stride) in enumerate(zip(
                blocks_per_layer_list, block_strides_list)):
            self.layer_list.append(self._make_layer(
                block=block,
                planes=current_num_filters,
                blocks=num_blocks,
                stride=stride,
            ))
            current_num_filters *= growth_factor

        self.final_bn = layers.resolve_norm_layer(
            # current_num_filters // growth_factor
            current_num_filters // growth_factor * block.expansion,
            norm_class,
            num_groups
        )
        self.relu = nn.ReLU()
        self.initialize()

        # Expose attributes for downstream dimension computation
        self.num_filters = num_filters
        self.growth_factor = growth_factor
        self.block = block
        self.num_filter_last_seq = current_num_filters // growth_factor * block.expansion

    def forward(self, x, return_intermediate=False):
        intermediate = []
        h = self.first_conv(x)
        h = self.first_pool(h)

        if return_intermediate:
            intermediate.append(h)
        for i, layer in enumerate(self.layer_list):
            h = layer(h)
            if return_intermediate:
                intermediate.append(h)

        h = self.final_bn(h)
        h = self.relu(h)

        if return_intermediate:
            return h, intermediate
        else:
            return h

    @classmethod
    def _resolve_block(cls, block_fn):
        if block_fn == "normal":
            return layers.BasicBlockV2_dbt
        elif block_fn == "bottleneck":
            return layers.BottleneckV2_dbt
        else:
            raise KeyError(block_fn)

    def _make_layer(self, block, planes, blocks, stride=1):
        # downsample = None
        # if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

        layers_ = [
            block(self.inplanes, planes, stride, downsample, self.norm_class, self.num_groups)
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers_.append(block(self.inplanes, planes, norm_class=self.norm_class, num_groups=self.num_groups))

        return nn.Sequential(*layers_)

    def initialize(self):
        for m in self.modules():
            self._layer_init(m)

    @classmethod
    def _layer_init(cls, m):
        if isinstance(m, nn.Conv2d):
            # From original
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #             nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    @classmethod
    def from_parameters(cls, parameters):
        return cls(
            input_channels=parameters["input_channels"],
            num_filters=parameters["num_filters"],
            first_layer_kernel_size=parameters["first_layer_kernel_size"],
            first_layer_conv_stride=parameters["first_layer_conv_stride"],
            first_layer_padding=parameters.get("first_layer_padding", 0),
            blocks_per_layer_list=parameters["blocks_per_layer_list"],
            block_strides_list=parameters["block_strides_list"],
            block_fn=parameters["block_fn"],
            first_pool_size=parameters["first_pool_size"],
            first_pool_stride=parameters["first_pool_stride"],
            first_pool_padding=parameters.get("first_pool_padding", 0),
            growth_factor=parameters.get("growth_factor", 2),
            norm_class=parameters.get("norm_class", "batch"),
            num_groups=parameters.get("num_groups", 1)
        )
    


In [5]:
class ResNet_22(nn.Module):

    def __init__(
            self,
            attention=False,
            dropout=0.0,
            hidden_size=256,

            # resnet hyperparameters
            #         input_channels=1,
            first_layer_kernel_size=7,
            first_layer_conv_stride=2,
            first_pool_size=3,
            first_pool_stride=2,
            first_layer_padding=0,
            first_pool_padding=0,
            growth_factor=2,

            # resnet22 settings
            num_filters=16,
            blocks_per_layer_list=[2, 2, 2, 2, 2],
            block_strides_list=[1, 2, 2, 2, 2],
            block_fn="normal",
            norm_class="group",
            num_groups=8,

            num_image_slices_per_net=1,
    ):
        super(ResNet_22, self).__init__()

        self.num_image_slices_per_net = num_image_slices_per_net

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

        self.resnet = ResNet(
            input_channels=3,
            first_layer_kernel_size=first_layer_kernel_size,
            first_layer_conv_stride=first_layer_conv_stride,
            first_pool_size=first_pool_size,
            first_pool_stride=first_pool_stride,
            num_filters=num_filters,
            blocks_per_layer_list=blocks_per_layer_list,
            block_strides_list=block_strides_list,
            block_fn=block_fn,
            first_layer_padding=first_layer_padding,
            first_pool_padding=first_pool_padding,
            growth_factor=growth_factor,
            norm_class=norm_class,
            num_groups=num_groups,
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # use avgpool rather than torch.mean
        
    def forward(self, x):
        batch_size = x.shape[0]

        h = self.resnet(x)
        # Shape of pooled_h is [4, 256, 1, 1]
        pooled_h = self.avgpool(h)
        return pooled_h

In [16]:
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
    

model = ResNet_22()
model.eval()

boom = model(trainset[0][0].unsqueeze(0))

NameError: name 'ResNet_22' is not defined

In [57]:
model = resnet34()

In [59]:
model.fc = Identity()

In [163]:
boom.shape

torch.Size([1, 256, 1, 1])