In [1]:
import pickle
import collections as col
import numpy as np
import random
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import math
import pdb
from sklearn import metrics
import time
import os
import PIL

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
import torch.optim as optim
from torchvision.models import resnet34, resnet50
import torch.nn.functional as F
import models
import layers
import utilities.reading_images as reading_images
from utilities.loading import get_single_image
from torchvision.utils import save_image

%matplotlib inline

In [2]:
# Okay so I've figured out that it's the Batchnorm layer that's causing this. 
# I basically trained a model on Imagewoof and saved several checkpoints. 
# At epoch 30, the model gave me 0.06 loss. That should correlate with close to 100% train accuracy, 
# especially since the classes are balanced. 
# It did not. Train and val accuracies are all ~15%. 
# I then used this saved model and printed its loss in eval mode. It was 4.095. Obviously way off. 
# Several people online have run into the same problem: BatchNorm behaves drastically differently
# between .train() and .eval() mode.
# Soumith suggests increasing the momentum parameter. 

In [3]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back the very object it was given.

    It owns no parameters or buffers, so it can stand in for any sub-module
    that should be switched off without touching the surrounding network.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # No computation at all -- the input is returned as-is (same object).
        return x

In [4]:
class PixelCNN(nn.Module):
    """Residual stack of small convolutions applied over a grid of latents.

    ``latent_dim`` is accepted by the constructor but never read: every
    convolution is fixed at 256 input and 256 output channels.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # The position of each layer inside the Sequential is part of the
        # state-dict keys (model.0, model.3, model.5, model.7), so the order
        # below must not change.
        stages = [
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(),
            nn.ConstantPad2d((1, 1, 0, 0), 0),  # one zero column left and right
            nn.Conv2d(256, 256, (1, 3)),
            nn.ConstantPad2d((0, 0, 0, 1), 0),  # one zero row at the bottom
            nn.Conv2d(256, 256, (2, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, (1, 1)),
        ]

        self.relu = nn.ReLU()
        self.model = nn.Sequential(*stages)

    def forward(self, latents):
        # latents: [B, C, H, W]; the padding cancels the kernel extents, so
        # the spatial size of the residual branch matches its input.
        running = latents

        # The same weights are applied five times, each pass adding its
        # output back onto the running sum.
        passes_left = 5
        while passes_left > 0:
            running = running + self.model(running)
            passes_left -= 1

        return self.relu(running)

In [5]:
def train_raster_patchify(img, size = 80, overlap = 32):
    '''
    Cut an image into a 6 x 6 grid of augmented training patches.

    Patches are visited left-to-right, top-to-bottom, with top-left corners
    spaced 32 pixels apart (so the defaults assume img is (3, 240, 240)).
    For every patch:
      * one of the three channels is drawn at random and copied into all
        three output channels;
      * with probability 1/2 the patch is mirrored along its last axis.

    `overlap` is not used: the 32-pixel step and the 6 x 6 grid are fixed.
    Returns a list of 36 tensors, one per patch.
    '''
    step, per_side = 32, 6
    extent = step * per_side

    patches = []
    for top in range(0, extent, step):
        for left in range(0, extent, step):
            # Channel is drawn before the flip decision; keep this order so a
            # seeded run consumes random numbers in the same sequence.
            picked = np.random.randint(3)
            plane = img[picked, top:top + size, left:left + size]
            tiled = np.repeat(np.expand_dims(plane, axis=0), 3, axis=0)
            if np.random.randint(2):
                tiled = np.flip(tiled, axis=2)
            # np.flip yields a negatively-strided view; make it contiguous
            # before handing it to torch.
            patches.append(torch.tensor(np.ascontiguousarray(tiled)))

    return patches

In [6]:
def val_raster_patchify(img, size = 80, overlap = 32):
    '''
    Cut an image into a 6 x 6 grid of patches, with no augmentation.

    Patches are emitted left-to-right, top-to-bottom; top-left corners sit
    32 pixels apart, so the defaults assume img is (3, 240, 240).
    Every patch keeps all channels and is a plain slice of `img`.

    `overlap` is not used: the 32-pixel step and the 6 x 6 grid are fixed.
    '''
    corners = [32 * k for k in range(6)]
    return [
        img[:, top:top + size, left:left + size]
        for top in corners
        for left in corners
    ]

In [7]:
def val_collate_fn(img_list):
    """Collate (image, label) samples into per-image patch stacks for evaluation.

    Each image is cut up by ``val_raster_patchify`` and its patches are
    stacked into one tensor. Returns ``(patches, labels)`` as two parallel
    Python lists -- the images are NOT merged into a single batch tensor.
    """
    collated = [
        (torch.stack(val_raster_patchify(image)), target)
        for image, target in img_list
    ]
    patches = [stack for stack, _ in collated]
    labels = [target for _, target in collated]
    return patches, labels

def train_collate_fn(img_list):
    """Collate (image, label) samples into per-image patch stacks for training.

    Each image is cut up by ``train_raster_patchify`` (which applies its
    random per-patch augmentation) and its patches are stacked into one
    tensor. Returns ``(patches, labels)`` as two parallel Python lists --
    the images are NOT merged into a single batch tensor.
    """
    collated = [
        (torch.stack(train_raster_patchify(image)), target)
        for image, target in img_list
    ]
    patches = [stack for stack, _ in collated]
    labels = [target for _, target in collated]
    return patches, labels

# train_transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.RandomCrop(240),
#     transforms.ColorJitter(brightness=(0.55, 1), contrast=(0.5, 1), saturation=(0.5, 1), hue=0.1),
#     transforms.ToTensor(),
# #     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
# ])

# Deterministic preprocessing shared by the training and validation sets:
# resize, take a 240x240 centre crop (the size the patchify helpers assume),
# then convert to a tensor. No normalisation or colour augmentation is applied here.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(240),
    transforms.ToTensor(),
])



# Training split.
# NOTE(review): absolute, cluster-specific path -- adjust before running elsewhere.
trainset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagenette2-320/train/',
    transform = data_transform
)

# Batches of 32 images; train_collate_fn turns each image into a stack of
# randomly augmented patches and returns Python lists rather than one tensor.
train_dl = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=train_collate_fn)

# Validation split, same preprocessing as the training split.
valset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagenette2-320/val/',
    transform = data_transform
)

# NOTE(review): validation batches are shuffled as well; this changes only the
# order in which images are visited.
val_dl = DataLoader(valset, batch_size=32, shuffle=True, collate_fn=val_collate_fn)

In [8]:
def remove_batchnorm(model):
    """Strip the normalisation layers from the front of a ResNet, in place.

    What gets replaced by ``Identity``:
      * the stem norm ``model.bn1``;
      * ``bn1``/``bn2``/``bn3`` of the first 3 blocks of ``layer1`` and the
        first 4 blocks of ``layer2`` (the block counts used when this is
        called on ``resnet50()``);
      * the norm inside the shortcut (``downsample[1]``) of the first block
        of each of those two stages;
      * ``layer3`` and ``layer4`` entirely -- these stages are dropped
        wholesale rather than having their norms stripped one by one.

    The model is modified in place; nothing is returned.
    """
    model.bn1 = Identity()

    # (stage, number of leading blocks whose norms are removed)
    for stage, block_count in ((model.layer1, 3), (model.layer2, 4)):
        for position in range(block_count):
            block = stage[position]
            for norm_name in ("bn1", "bn2", "bn3"):
                setattr(block, norm_name, Identity())
            if position == 0:
                # Only the first block of a stage has its shortcut norm replaced.
                block.downsample[1] = Identity()

    model.layer3 = Identity()

    model.layer4 = Identity()
    
#     model.layer4[0].bn1 = Identity()
#     model.layer4[0].bn2 = Identity()
#     model.layer4[0].bn3 = Identity()
#     model.layer4[0].downsample[1] = Identity()
    
#     model.layer4[1].bn1 = Identity()
#     model.layer4[1].bn2 = Identity()
#     model.layer4[1].bn3 = Identity()
    
#     model.layer4[2].bn1 = Identity()
#     model.layer4[2].bn2 = Identity()
#     model.layer4[2].bn3 = Identity()




    
#     model.layer4[0].bn1 = Identity()
#     model.layer4[0].bn2 = Identity()
#     model.layer4[0].downsample[1] = Identity()
#     model.layer4[1].bn1 = Identity()
#     model.layer4[1].bn2 = Identity()
#     model.layer4[2].bn1 = Identity()
#     model.layer4[2].bn2 = Identity()

In [24]:
class CPC_Linear(nn.Module):
    """Small classification head on top of a truncated ResNet-50 patch encoder.

    The encoder is ``resnet50()`` with its ``fc`` layer replaced by
    ``Identity`` and then passed through ``remove_batchnorm``. Every image
    arrives as a stack of patches; the per-patch feature vectors are laid out
    as a 512 x 6 x 6 grid (so 36 patches with 512 features each are expected),
    normalised, reduced to 25 channels by a 1x1 convolution, flattened, and
    classified by three linear layers (50 -> 50 -> 10 outputs).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state-dict keys and the order in which random initialisation is drawn.
        backbone = resnet50()
        backbone.fc = Identity()
        remove_batchnorm(backbone)
        self.encoder = backbone

        self.bn = nn.BatchNorm2d(512)
        self.conv_1 = nn.Conv2d(512, 25, (1, 1))
        # Not used by forward(); only referenced by a commented-out alternative head.
        self.avg_pool = nn.AvgPool2d(6, 6)

        self.bn_50_1 = nn.BatchNorm1d(50)
        self.bn_50_2 = nn.BatchNorm1d(50)

        self.lin_1 = nn.Linear(25*6*6, 50)
        self.relu = nn.ReLU()
        self.lin_2 = nn.Linear(50, 50)
        self.lin_3 = nn.Linear(50, 10)

    def forward(self, x, device):
        """Classify a batch given as an iterable of per-image patch stacks.

        ``x`` yields one patch tensor per image; each is moved to ``device``
        before encoding. Returns the logits produced by ``lin_3``.
        """
        per_image = []
        for img_patches in x:
            feats = self.encoder(img_patches.to(device)).squeeze()
            # (patches, features) -> (1, features, patches) -> (1, 512, 6, 6)
            grid = feats.unsqueeze(0).permute(0, 2, 1).reshape(1, 512, 6, 6)
            per_image.append(grid)

        # Stacking gives [B, 1, 512, 6, 6]; drop the singleton axis.
        Z = torch.stack(per_image).squeeze(1)

        hidden = self.conv_1(self.bn(Z)).view(-1, 25*6*6)
        for linear, norm in ((self.lin_1, self.bn_50_1), (self.lin_2, self.bn_50_2)):
            hidden = self.relu(norm(linear(hidden)))

        return self.lin_3(hidden)

In [25]:
def one_epoch(dl, model, loss_func, optimizer, device, phase = 'train'):
    """Run one pass over ``dl`` in either training or evaluation mode.

    Args:
        dl: iterable of ``(x, labels)`` batches; ``labels`` is a sequence of
            integer class ids and ``x`` is whatever ``model`` accepts.
        model: module called as ``model(x, device)`` and returning logits.
        loss_func: callable ``(logits, labels) -> scalar loss tensor``.
        optimizer: optimizer stepped once per batch; only used when training.
        device: device the labels are moved to (and handed to the model).
        phase: ``'train'`` trains the model; any other value evaluates it.

    Returns:
        ``'train'``: the mean batch loss.
        otherwise:   ``(correct, mean_loss)`` where ``correct`` is the number
                     of correctly classified samples (a 0-dim integer tensor,
                     or the int 0 if ``dl`` produced no batches).
    """
    training = phase == 'train'
    if training:
        model.train()
    else:
        model.eval()

    losses = []
    correct = 0
    # Gradients are only needed while training. Switching autograd off for
    # evaluation avoids building (and holding on to) a graph per batch.
    with torch.set_grad_enabled(training):
        for i, (x, labels) in enumerate(dl):
            if training:
                optimizer.zero_grad()

            labels = torch.from_numpy(np.stack(labels)).to(device)
            preds_logit = model(x, device)
            loss = loss_func(preds_logit, labels)
            losses.append(loss.item())

            if training:
                loss.backward()
                optimizer.step()

                if i % 50 == 0:
                    print("Batch: {}/{}, Loss: {}".format(i, len(dl), loss.item()))
            else:
                preds_label = torch.argmax(preds_logit, dim=1)
                # Tensor reduction instead of Python's builtin sum(), which
                # would iterate over the batch one element at a time.
                correct += (preds_label == labels).sum()

    if training:
        return np.mean(losses)
    # Every non-training phase was evaluated above, so report its results
    # (previously anything other than 'val' silently returned None).
    return correct, np.mean(losses)

In [26]:
# torch.cuda.set_device(7)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = CPC_Linear()
# pretrained_dict = torch.load('pretrained_imagewoof_bn_0.5_20.pt')
# model_dict = model.state_dict()

# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(pretrained_dict) 
# # 3. load the new state dict
# model.load_state_dict(model_dict)

# model = model.to(device)
# loss_func = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr = 1e-3, eps=1e-8)

# correct, average_loss = one_epoch(train_dl, model, loss_func, optimizer, device, phase = 'val')

# print(correct)
# print(average_loss)

In [27]:
# loss

In [28]:
def run_epochs(epoch_num, gpu_index=4, checkpoint_path='paper_self_supervised_rc_best_val.pt'):
    """Train the CPC_Linear head on a frozen, pre-trained encoder.

    Args:
        epoch_num: number of epochs to run.
        gpu_index: CUDA device to select when CUDA is available (ignored on
            CPU-only machines).
        checkpoint_path: state-dict file with the pre-trained weights. Only
            entries whose key also exists in a fresh ``CPC_Linear`` are
            loaded; all other entries are ignored.

    After every epoch the accuracy on the training and validation loaders is
    printed, and after epochs 1, 10, 20 and 30 the model's state dict is
    written to ``pretrained_imagewoof_batch_norm_frozen_<epoch>.pt``.
    """
    # Only touch CUDA when it exists; otherwise set_device() raises and the
    # CPU fallback below could never be reached.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(gpu_index)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = CPC_Linear()
    # Load onto the CPU first so the checkpoint does not have to be restored
    # onto whichever device it happened to be saved from.
    pretrained_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    model = model.to(device)
    loss_func = nn.CrossEntropyLoss()

    # ----------------------------------------------
    # FREEZE ENCODER
    # Done before the optimizer is built so that it only ever holds the
    # parameters that are actually trained.
    for param in model.encoder.parameters():
        param.requires_grad = False

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr = 1e-3, eps=1e-8)

    for i in range(epoch_num):
        epoch_loss = one_epoch(train_dl, model, loss_func, optimizer, device, phase = 'train')
        print("Average Epoch {} Loss: {}".format(i, epoch_loss))
        correct, _ = one_epoch(train_dl, model, loss_func, optimizer, device, phase = 'val')
        print("Train Accuracy: {}".format(1. * correct / len(trainset)))
        correct, _ = one_epoch(val_dl, model, loss_func, optimizer, device, phase = 'val')
        print("Validation Accuracy: {}".format(1. * correct / len(valset)))

        if i in [1, 10, 20, 30]:
            torch.save(model.state_dict(), "pretrained_imagewoof_batch_norm_frozen_{}.pt".format(i))
        
    

In [29]:
# Using CPC-pretrained weights, fine-tuning all. Validation Accuracy (best):  38%, within 23 epochs. 
# Using CPC-pretrained weights, frozen encoder. Validation Accuracy: 
# They use resnet-50 as the backbone. Maybe switch to that? 
# But I'm doing resnet-22 for each patch of the image. 

# I changed to resnet34. Frozen Encoder, validation accuracy: 16% through 9 epochs. 
# Random Color dropping (RC), Frozen Encoder, validation accuracy: 16% through 9 epochs. 
# NO pre-training, random initialized encoder + FROZEN. validation accuracy: 13% through 25 epochs. 
# NO pre-training, random initialized encoder + FINE_TUNE ALL. validation accuracy: 30% through 33 epochs.
# Pre-trained rc_15_epochs, Frozen Encoder. Validation accuracy: 
# Pre-trained rc_15_epochs, Fine-tune all. Validation accuracy.

# self-supervised. Frozen Encoder. Train/Val Accuracy:              36%/31%. 
# self-supervised 1024 (64), color jitter. Frozen Encoder. Train/Val:     43%/39% (Through 40 epochs). 
# self-supervised 2048 (64), color jitter. Frozen Encoder. Train/Val:     44%/41% (Through 40 epochs). 
# self-supervised 2048 (32), color jitter. Frozen Encoder. Train/Val:     46%/41% (Through 40 epochs). 

# self-supervised 2048 but train 1024? (24), Frozen Encoder. Train/Val:   50%/45% (Through 25 epochs)
# self-supervised 2048 but train 1024? (16), Frozen Encoder. Train/Val:   52%/48% (Through 40 epochs)

# self-supervised 1024 but train 1024 (16), Frozen Encoder. Train/Val:    44%/40% (Thru 40 epochs)

# self-supervised 1024 but train 512 (7), Frozen Encoder. Train/Val:      54%/50% (Thru 40 epochs)
# self-supervised 2048 but train 1024 (7), Frozen Encoder. Train/Val:     53%/48% (Thru 40 epochs)
# self-supervised 2048 but train 512 (7), Frozen Encoder. Train/Val:      54%/50% (Thru 40 epochs)

# self-supervised 2048 but train 1024 (12, bs 48), Frozen enc. Train/Val: 56%/47% (Thru 40 epochs)
# self-supervised 2048 but train 512 (12, bs 48), Frozen enc. Train/Val:  53%/47% (Thru 40 epochs)

# self-supervised 2048 but train 1024 (8, bs 48), Frozen Enc. Train/Val:  55%/50% (Thru 40 epochs)
# self-supervised 2048 but train 512 (8, bs 48), Frozen Enc. Train/Val:   54%/50% (Thru 40 epochs)

# self-supervised 2048 but train 512 (7, bs 48), Frozen Enc. Train/Val:   53%/50% (Thru 31 epochs)

# self-supervised 2048 --> 512 --> 50 (7, bs 48), Frozen Enc. Train/Val:  68%/55% (Thru 10 epochs)
# self-supervised 2048 --> 512 --> 50 (7, bs 48) (dropout 0.3),Train/Val: 72%/56% (Thru 40 epochs)
# self-supervised 2048 -> 512 -> 50 -> 50 (7, bs 48) (dropout 0.1),Train/Val: 70%/56% (Thru 18 epochs)
# self-supervised 2048 ->512-> -> 25 -> 50 -> 50 (7, bs 48) (dropout 0.1),Train/Val: 82%/56% (Thru 40 epochs)





In [None]:
# Kick off the 40-epoch run; run_epochs prints the per-batch losses and per-epoch accuracies.
run_epochs(40)

Batch: 0/296, Loss: 2.3252148628234863
Batch: 50/296, Loss: 2.200488805770874
Batch: 100/296, Loss: 1.8522738218307495
Batch: 150/296, Loss: 1.7442010641098022
Batch: 200/296, Loss: 1.855542778968811
Batch: 250/296, Loss: 1.538062572479248
Average Epoch 0 Loss: 1.874844025518443
Train Accuracy: 0.4758686423301697
Validation Accuracy: 0.4366878867149353
Batch: 0/296, Loss: 1.7733978033065796
Batch: 50/296, Loss: 1.4563891887664795
Batch: 100/296, Loss: 1.4562599658966064
Batch: 150/296, Loss: 1.563275933265686
Batch: 200/296, Loss: 1.4059970378875732
Batch: 250/296, Loss: 1.3049472570419312
Average Epoch 1 Loss: 1.59215742188531
Train Accuracy: 0.5221248269081116
Validation Accuracy: 0.4807642996311188
Batch: 0/296, Loss: 1.6640311479568481
Batch: 50/296, Loss: 1.2082107067108154
Batch: 100/296, Loss: 1.5915204286575317
Batch: 150/296, Loss: 1.7755053043365479
Batch: 200/296, Loss: 1.3111125230789185
Batch: 250/296, Loss: 1.436311960220337
Average Epoch 2 Loss: 1.4693560253929447
Train 

In [5]:
class ResNet(nn.Module):
    """
    Adapted from torchvision ResNet, converted to v2

    Layout: first convolution -> max pool -> a list of residual stages ->
    final normalisation -> ReLU. The input goes straight from the first
    convolution into the pool; the only normalisation/activation applied by
    this class itself is the final one (the blocks come from the project's
    ``layers`` module and are not shown here).

    Attributes exposed after construction: ``num_filters``, ``growth_factor``,
    ``block`` and ``num_filter_last_seq`` (channel count of the output).
    """

    def __init__(self,
                 input_channels, num_filters,
                 first_layer_kernel_size, first_layer_conv_stride,
                 blocks_per_layer_list, block_strides_list, block_fn,
                 first_layer_padding=0,
                 first_pool_size=None, first_pool_stride=None, first_pool_padding=0,
                 growth_factor=2, norm_class="batch", num_groups=1):
        """Build the network.

        Args:
            input_channels: channels of the input image.
            num_filters: output channels of the first convolution and the
                ``planes`` value of the first stage.
            first_layer_kernel_size, first_layer_conv_stride,
            first_layer_padding: settings of the first convolution.
            blocks_per_layer_list, block_strides_list: zipped together; each
                pair gives the number of blocks and the stride of one stage.
            block_fn: ``"normal"`` or ``"bottleneck"`` (see ``_resolve_block``).
            first_pool_size, first_pool_stride, first_pool_padding: passed
                unchanged to ``nn.MaxPool2d``.
                NOTE(review): the ``None`` defaults are forwarded as-is;
                callers presumably always provide real values -- confirm.
            growth_factor: ``planes`` is multiplied by this after every stage.
            norm_class, num_groups: forwarded to the blocks and to
                ``layers.resolve_norm_layer`` for the final norm.
        """
        super(ResNet, self).__init__()
        self.first_conv = nn.Conv2d(
            in_channels=input_channels, out_channels=num_filters,
            kernel_size=first_layer_kernel_size,
            stride=first_layer_conv_stride,
            padding=first_layer_padding,
            bias=False,
        )
        # Diff: padding=SAME vs. padding=0
        self.first_pool = nn.MaxPool2d(
            kernel_size=first_pool_size,
            stride=first_pool_stride,
            padding=first_pool_padding,
        )
        self.norm_class = norm_class
        self.num_groups = num_groups

        block = self._resolve_block(block_fn)
        self.layer_list = nn.ModuleList()
        current_num_filters = num_filters
        # Running input-channel count; _make_layer reads and updates it.
        self.inplanes = num_filters
        for i, (num_blocks, stride) in enumerate(zip(
                blocks_per_layer_list, block_strides_list)):
            self.layer_list.append(self._make_layer(
                block=block,
                planes=current_num_filters,
                blocks=num_blocks,
                stride=stride,
            ))
            current_num_filters *= growth_factor

        # current_num_filters was grown once more after the last stage, so it
        # is divided back before being scaled by the block's expansion.
        self.final_bn = layers.resolve_norm_layer(
            # current_num_filters // growth_factor
            current_num_filters // growth_factor * block.expansion,
            norm_class,
            num_groups
        )
        self.relu = nn.ReLU()
        # Applies _layer_init to every sub-module created above.
        self.initialize()

        # Expose attributes for downstream dimension computation
        self.num_filters = num_filters
        self.growth_factor = growth_factor
        self.block = block
        self.num_filter_last_seq = current_num_filters // growth_factor * block.expansion

    def forward(self, x, return_intermediate=False):
        """Run the network on ``x``.

        With ``return_intermediate=True`` the result is ``(h, intermediate)``,
        where ``intermediate`` lists the pooled stem output followed by the
        output of every stage (all taken before the final norm and ReLU).
        Otherwise only ``h`` is returned.
        """
        intermediate = []
        h = self.first_conv(x)
        h = self.first_pool(h)

        if return_intermediate:
            intermediate.append(h)
        for i, layer in enumerate(self.layer_list):
            h = layer(h)
            if return_intermediate:
                intermediate.append(h)

        h = self.final_bn(h)
        h = self.relu(h)

        if return_intermediate:
            return h, intermediate
        else:
            return h

    @classmethod
    def _resolve_block(cls, block_fn):
        """Map a block name to its class in ``layers``; unknown names raise KeyError."""
        if block_fn == "normal":
            return layers.BasicBlockV2_dbt
        elif block_fn == "bottleneck":
            return layers.BottleneckV2_dbt
        else:
            raise KeyError(block_fn)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and update ``self.inplanes``.

        Only the first block receives ``stride`` and the ``downsample`` branch;
        the remaining blocks are built from the updated ``self.inplanes``.
        """
        # The usual "only when the shape changes" condition is commented out,
        # so a projection shortcut is built for every stage.
        # NOTE(review): the shortcut always uses nn.BatchNorm2d, regardless of
        # self.norm_class -- confirm this is intended when norm_class is "group".
        # downsample = None
        # if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

        layers_ = [
            block(self.inplanes, planes, stride, downsample, self.norm_class, self.num_groups)
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers_.append(block(self.inplanes, planes, norm_class=self.norm_class, num_groups=self.num_groups))

        return nn.Sequential(*layers_)

    def initialize(self):
        """Apply ``_layer_init`` to every module in the network (including self)."""
        for m in self.modules():
            self._layer_init(m)

    @classmethod
    def _layer_init(cls, m):
        """Initialise one module: Kaiming-normal (fan_out, relu) for Conv2d weights,
        weight=1 / bias=0 for BatchNorm2d and GroupNorm; other modules are left alone."""
        if isinstance(m, nn.Conv2d):
            # From original
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #             nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    @classmethod
    def from_parameters(cls, parameters):
        """Alternative constructor taking a mapping of settings.

        Looked up with ``[...]`` (must be present): ``input_channels``,
        ``num_filters``, ``first_layer_kernel_size``,
        ``first_layer_conv_stride``, ``blocks_per_layer_list``,
        ``block_strides_list``, ``block_fn``, ``first_pool_size``,
        ``first_pool_stride``.
        Looked up with ``.get`` (default in parentheses):
        ``first_layer_padding`` (0), ``first_pool_padding`` (0),
        ``growth_factor`` (2), ``norm_class`` ("batch"), ``num_groups`` (1).
        """
        return cls(
            input_channels=parameters["input_channels"],
            num_filters=parameters["num_filters"],
            first_layer_kernel_size=parameters["first_layer_kernel_size"],
            first_layer_conv_stride=parameters["first_layer_conv_stride"],
            first_layer_padding=parameters.get("first_layer_padding", 0),
            blocks_per_layer_list=parameters["blocks_per_layer_list"],
            block_strides_list=parameters["block_strides_list"],
            block_fn=parameters["block_fn"],
            first_pool_size=parameters["first_pool_size"],
            first_pool_stride=parameters["first_pool_stride"],
            first_pool_padding=parameters.get("first_pool_padding", 0),
            growth_factor=parameters.get("growth_factor", 2),
            norm_class=parameters.get("norm_class", "batch"),
            num_groups=parameters.get("num_groups", 1)
        )
    


In [6]:
class ResNet_22(nn.Module):
    """3-channel ``ResNet`` backbone followed by global average pooling.

    With the defaults this is five stages of two "normal" blocks each,
    starting at 16 filters and using group normalisation with 8 groups.

    ``attention`` and ``hidden_size`` are accepted but not used. ``dropout``
    only configures ``self.dropout``; ``forward`` applies neither
    ``self.dropout`` nor ``self.relu``.
    """

    def __init__(
            self,
            attention=False,
            dropout=0.0,
            hidden_size=256,

            # resnet hyperparameters
            #         input_channels=1,
            first_layer_kernel_size=7,
            first_layer_conv_stride=2,
            first_pool_size=3,
            first_pool_stride=2,
            first_layer_padding=0,
            first_pool_padding=0,
            growth_factor=2,

            # resnet22 settings
            num_filters=16,
            blocks_per_layer_list=[2, 2, 2, 2, 2],
            block_strides_list=[1, 2, 2, 2, 2],
            block_fn="normal",
            norm_class="group",
            num_groups=8,

            num_image_slices_per_net=1,
    ):
        super(ResNet_22, self).__init__()

        self.num_image_slices_per_net = num_image_slices_per_net

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

        # The input channel count is fixed at 3 here rather than being a
        # constructor argument.
        self.resnet = ResNet(
            input_channels=3,
            first_layer_kernel_size=first_layer_kernel_size,
            first_layer_conv_stride=first_layer_conv_stride,
            first_pool_size=first_pool_size,
            first_pool_stride=first_pool_stride,
            num_filters=num_filters,
            blocks_per_layer_list=blocks_per_layer_list,
            block_strides_list=block_strides_list,
            block_fn=block_fn,
            first_layer_padding=first_layer_padding,
            first_pool_padding=first_pool_padding,
            growth_factor=growth_factor,
            norm_class=norm_class,
            num_groups=num_groups,
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # use avgpool rather than torch.mean

    def forward(self, x):
        """Encode ``x`` and pool each feature map down to a single value.

        Returns a tensor of shape [B, C, 1, 1], where C is
        ``self.resnet.num_filter_last_seq``.
        """
        h = self.resnet(x)
        return self.avgpool(h)