In [30]:
import pickle
import collections as col
import numpy as np
import random
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import math
import pdb
from sklearn import metrics
import time
import os
import PIL

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
import torch.optim as optim
from torchvision.models import resnet18, resnet34
import torch.nn.functional as F
from pytorch_datasets import DiagnosticInpainted
import models
import layers
import utilities.reading_images as reading_images
from utilities.loading import get_single_image
from torchvision.utils import save_image

%matplotlib inline

In [6]:
class PixelCNN(nn.Module):
    def __init__(self, latent_dim):
        super(PixelCNN, self).__init__()
        
        # Conv2d: (input_channels, output_channels, kernel_size, padding)
        
        self.relu = nn.ReLU()
        
        self.model = nn.Sequential(
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(),
            nn.ConstantPad2d((1, 1, 0, 0), 0),
            nn.Conv2d(256, 256, (1, 3)),
            nn.ConstantPad2d((0, 0, 0, 1), 0),
            nn.Conv2d(256, 256, (2, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, (1, 1))
        )

    def forward(self, latents):
        
        # latents: [B, C, H, W]
        cres = latents
        
        for _ in range(5):
            c = self.model(cres)
            cres = cres + c
        cres = self.relu(cres)
        return cres      

In [7]:
boom = torch.tensor(np.arange(20).reshape(2, 10))
boom

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

In [8]:
boom.reshape(2, 2, 5)

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]])

In [52]:
# Implementation in the paper is unclear.
# I'm going to go with WF. 

# NCE Loss
# Questions: Is the dimension of Z (B*patches) or (B). 
#            I think it's (B, 6, 6, 4096)

class CPCLossNCE(nn.Module):
    
    def nce_loss(self, z_hat, pos_scores, negative_samples, mask_mat):
        
        z_hat = z_hat.to(device)
        pos_scores = pos_scores.to(device)
        negative_samples = negative_samples.to(device)
        mask_mat = mask_mat.to(device)
                
        # (b, 1)
        pos_scores = pos_scores.float()
        batch_size, emb_dim = z_hat.size()
        nb_feat_vectors = negative_samples.size(1) // batch_size # 36 of them, if 6 by 6 wireframes. 
        
        # (b, b) -> (b, b, nb_feat_vectors)
        # all zeros with ones in diagonal tensor... (ie: b1 b1 are all 1s, b1 b2 are all zeros)
        mask_pos = mask_mat.unsqueeze(dim=2).expand(-1, -1, nb_feat_vectors).float()
        
        # negative mask
        mask_neg = 1. - mask_pos
        
        # ----------------------
        # ALL SCORES computation
        # (visualize in your mind a batch size of 2, 36-length segments) 
        # (b, dim) x (dim, nb_feats*b) -> (b, b, nb_feats)
        raw_scores = torch.mm(z_hat, negative_samples)
        raw_scores = raw_scores.reshape(batch_size, batch_size, nb_feat_vectors).float()
        
        # EXTRACT NEGATIVE SCORES
        # (batch_size, batch_size, nb_feat_vectors)
        # HE'S TAKING THE NEGATIVE SAMPLES FROM THE OTHER MINIBATCHES
        # A GIVEN Z_HAT IS ONLY MULTIPLIED BY Z'S FROM OTHER MINIBATCHES
        neg_scores = (mask_neg * raw_scores)
        # ----------------------
        
        # (b, b, nb_feat_vectors) -> (batch_size, batch_size * nb_feat_vectors) 
        neg_scores = neg_scores.reshape(batch_size, -1)
        mask_neg = mask_neg.reshape(batch_size, -1)
        
        # STABLE SOFTMAX
        # (n_batch_gpu, 1)
        neg_maxes = torch.max(neg_scores, dim=1, keepdim=True)[0]
        
        # DENOMINATOR
        # sum over only negative samples (none from the diagonal)
        neg_sumexp = (mask_neg * torch.exp(neg_scores - neg_maxes)).sum(dim=1, keepdim=True)
        all_logsumexp = torch.log(torch.exp(pos_scores - neg_maxes) + neg_sumexp)
        
        # NUMERATOR
        # compute numerators for the NCE log-softmaxes
        pos_shiftexp = pos_scores - neg_maxes
        
        # FULL NCE
        nce_scores = pos_shiftexp - all_logsumexp
        nce_scores = -nce_scores.mean()
        
        return nce_scores
        
        
        
    def forward(self, Z, C, W_list):
        '''
        param Z: latent vecs (B, D, H, W)
        param C: context vecs (B, D, H, W)
        param W_list: list of k-1 W projections
        '''
        
        # (b, dim, w, h)
        batch_size, emb_dim, h, w = Z.size()
        
        # (10 x 10 identity matrix)
        diag_mat = torch.eye(batch_size)
        diag_mat = diag_mat.float()
        
        losses = []
        # calculate loss for each k
        
        # Below operations preserve raster order (for B, D, H, W) = (1, 5, 2, 2) check.
        # Z_neg holds all z vecs. 
        Z_neg = Z.permute(1, 0, 2, 3).reshape(emb_dim, -1)
        
        
        for i in range(0, h-1):
            for j in range(0, w):
                cij = C[:, :, i, j]   # B by D
                
                for k in range(i+1, h): # predict on all vectors in the same column, but below current wireframe. 
                    Wk = W_list[str(k)]
                    
                    z_hat_ikj = Wk(cij)
                    zikj = Z[:, :, k, j]
                    
                    # BATCH DOT PRODUCT
                    # (b, d) x (b, d) -> (b, 1)
                    pos_scores = torch.bmm(z_hat_ikj.unsqueeze(1), zikj.unsqueeze(2))
                    pos_scores = pos_scores.squeeze(-1).squeeze(-1)
                    
                    loss = self.nce_loss(z_hat_ikj, pos_scores, Z_neg, diag_mat)
                    losses.append(loss)
                    
                    
        losses = torch.stack(losses)
        loss = losses.mean()
        if np.isnan(loss.item()):
            pdb.set_trace()
            print('boom')
        return loss           
        
        

In [10]:
def raster_patchify(img, size = 80, overlap = 32):
    '''
    Left-to-right, top to bottom.
    Assumes img is (3, 240, 240).
    '''
    patches = []
     
    h = -32
    w = -32
    for i in range(6):
        h = h + 32
        for j in range(6):
            w = w + 32
            patches.append(img[:, h:h+size, w:w+size])
        w = -32
            
    return patches
    

In [14]:
patches = raster_patchify(trainset[0][0])
len(patches)

36

In [315]:
# plt.imshow(patches[35].permute(1, 2, 0))

# plt.imshow(trainset[0][0].permute(1, 2, 0))
# plt.scatter(80,80,color='r')
# plt.scatter(80+32,80,color='r')
# plt.scatter(80+64,80,color='r')
# plt.scatter(80+96,80,color='r')
# plt.scatter(80+128,80,color='r')
# plt.scatter(80+160,80,color='r')

In [12]:
def collate_fn(img_list):
    patches = []
    for (img, label) in img_list:
        img_patches = raster_patchify(img)
        patches.append(torch.stack(img_patches))
        
    return patches

In [60]:
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(240),
    transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagewoof320/imagewoof2-320/train/',
    transform = data_transform
)

train_dl = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_fn)

valset = datasets.ImageFolder(
    root = '/gpfs/data/geraslab/Vish/imagewoof320/imagewoof2-320/val/',
    transform = data_transform
)

val_dl = DataLoader(valset, batch_size=10, shuffle=False)

In [61]:
class CPC(nn.Module):
    def __init__(self):
        super(CPC, self).__init__()
        self.encoder = ResNet_22()
        self.pixel_cnn = PixelCNN(256)
        self.nce_loss = CPCLossNCE()
        
        # W transforms (k > 0)
        self.W_list = {}
        for k in range(1, 6):
            w = torch.nn.Linear(256, 256)
            self.W_list[str(k)] = w

        self.W_list = nn.ModuleDict(self.W_list).to(device)
        

    def forward(self, x):
        Z = []
        C = []
        for img_patches in x:
            img_patches = img_patches.to(device)
            z = self.encoder(img_patches).squeeze()
            z = z.unsqueeze(0).permute(0, 2, 1).reshape(1, 256, 6, 6)
            Z.append(z)
            c = self.pixel_cnn(z)
            C.append(c)

        Z = torch.stack(Z).squeeze(1)
        C = torch.stack(C).squeeze(1)

        loss = self.nce_loss(Z, C, W_list)
        
        
        return loss

In [62]:
torch.cuda.set_device(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CPC().to(device)
optimizer = optim.Adam(model.parameters(), lr = 2e-4, weight_decay=1e-5, eps=1e-8)

for i, x in enumerate(train_dl):
    loss = model(x)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if i % 10 == 0:
        print("Batch: {}/{}, Loss: {}".format(i, len(train_dl), loss.item()))
    

Batch: 0/283, Loss: 7.642627239227295
Batch: 10/283, Loss: 7.0149383544921875
Batch: 20/283, Loss: 6.997243881225586
Batch: 30/283, Loss: 6.954706192016602
Batch: 40/283, Loss: 6.826085090637207
Batch: 50/283, Loss: 6.640190601348877
Batch: 60/283, Loss: 6.352201461791992
Batch: 70/283, Loss: 6.296461582183838
Batch: 80/283, Loss: 6.019741535186768
Batch: 90/283, Loss: 6.025851249694824
Batch: 100/283, Loss: 5.869300842285156
Batch: 110/283, Loss: 5.870182514190674
Batch: 120/283, Loss: 5.592728614807129
Batch: 130/283, Loss: 5.567205905914307
Batch: 140/283, Loss: 5.3170061111450195
Batch: 150/283, Loss: 5.212189674377441


KeyboardInterrupt: 

In [157]:
len(trainset)

9025

In [160]:
trainset[50][0].shape

torch.Size([3, 240, 240])

In [4]:
class ResNet(nn.Module):
    """
    Adapted from torchvision ResNet, converted to v2
    """

    def __init__(self,
                 input_channels, num_filters,
                 first_layer_kernel_size, first_layer_conv_stride,
                 blocks_per_layer_list, block_strides_list, block_fn,
                 first_layer_padding=0,
                 first_pool_size=None, first_pool_stride=None, first_pool_padding=0,
                 growth_factor=2, norm_class="batch", num_groups=1):
        super(ResNet, self).__init__()
        self.first_conv = nn.Conv2d(
            in_channels=input_channels, out_channels=num_filters,
            kernel_size=first_layer_kernel_size,
            stride=first_layer_conv_stride,
            padding=first_layer_padding,
            bias=False,
        )
        # Diff: padding=SAME vs. padding=0
        self.first_pool = nn.MaxPool2d(
            kernel_size=first_pool_size,
            stride=first_pool_stride,
            padding=first_pool_padding,
        )
        self.norm_class = norm_class
        self.num_groups = num_groups

        block = self._resolve_block(block_fn)
        self.layer_list = nn.ModuleList()
        current_num_filters = num_filters
        self.inplanes = num_filters
        for i, (num_blocks, stride) in enumerate(zip(
                blocks_per_layer_list, block_strides_list)):
            self.layer_list.append(self._make_layer(
                block=block,
                planes=current_num_filters,
                blocks=num_blocks,
                stride=stride,
            ))
            current_num_filters *= growth_factor

        self.final_bn = layers.resolve_norm_layer(
            # current_num_filters // growth_factor
            current_num_filters // growth_factor * block.expansion,
            norm_class,
            num_groups
        )
        self.relu = nn.ReLU()
        self.initialize()

        # Expose attributes for downstream dimension computation
        self.num_filters = num_filters
        self.growth_factor = growth_factor
        self.block = block
        self.num_filter_last_seq = current_num_filters // growth_factor * block.expansion

    def forward(self, x, return_intermediate=False):
        intermediate = []
        h = self.first_conv(x)
        h = self.first_pool(h)

        if return_intermediate:
            intermediate.append(h)
        for i, layer in enumerate(self.layer_list):
            h = layer(h)
            if return_intermediate:
                intermediate.append(h)

        h = self.final_bn(h)
        h = self.relu(h)

        if return_intermediate:
            return h, intermediate
        else:
            return h

    @classmethod
    def _resolve_block(cls, block_fn):
        if block_fn == "normal":
            return layers.BasicBlockV2_dbt
        elif block_fn == "bottleneck":
            return layers.BottleneckV2_dbt
        else:
            raise KeyError(block_fn)

    def _make_layer(self, block, planes, blocks, stride=1):
        # downsample = None
        # if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

        layers_ = [
            block(self.inplanes, planes, stride, downsample, self.norm_class, self.num_groups)
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers_.append(block(self.inplanes, planes, norm_class=self.norm_class, num_groups=self.num_groups))

        return nn.Sequential(*layers_)

    def initialize(self):
        for m in self.modules():
            self._layer_init(m)

    @classmethod
    def _layer_init(cls, m):
        if isinstance(m, nn.Conv2d):
            # From original
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #             nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    @classmethod
    def from_parameters(cls, parameters):
        return cls(
            input_channels=parameters["input_channels"],
            num_filters=parameters["num_filters"],
            first_layer_kernel_size=parameters["first_layer_kernel_size"],
            first_layer_conv_stride=parameters["first_layer_conv_stride"],
            first_layer_padding=parameters.get("first_layer_padding", 0),
            blocks_per_layer_list=parameters["blocks_per_layer_list"],
            block_strides_list=parameters["block_strides_list"],
            block_fn=parameters["block_fn"],
            first_pool_size=parameters["first_pool_size"],
            first_pool_stride=parameters["first_pool_stride"],
            first_pool_padding=parameters.get("first_pool_padding", 0),
            growth_factor=parameters.get("growth_factor", 2),
            norm_class=parameters.get("norm_class", "batch"),
            num_groups=parameters.get("num_groups", 1)
        )
    


In [5]:
class ResNet_22(nn.Module):

    def __init__(
            self,
            attention=False,
            dropout=0.0,
            hidden_size=256,

            # resnet hyperparameters
            #         input_channels=1,
            first_layer_kernel_size=7,
            first_layer_conv_stride=2,
            first_pool_size=3,
            first_pool_stride=2,
            first_layer_padding=0,
            first_pool_padding=0,
            growth_factor=2,

            # resnet22 settings
            num_filters=16,
            blocks_per_layer_list=[2, 2, 2, 2, 2],
            block_strides_list=[1, 2, 2, 2, 2],
            block_fn="normal",
            norm_class="group",
            num_groups=8,

            num_image_slices_per_net=1,
    ):
        super(ResNet_22, self).__init__()

        self.num_image_slices_per_net = num_image_slices_per_net

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

        self.resnet = ResNet(
            input_channels=3,
            first_layer_kernel_size=first_layer_kernel_size,
            first_layer_conv_stride=first_layer_conv_stride,
            first_pool_size=first_pool_size,
            first_pool_stride=first_pool_stride,
            num_filters=num_filters,
            blocks_per_layer_list=blocks_per_layer_list,
            block_strides_list=block_strides_list,
            block_fn=block_fn,
            first_layer_padding=first_layer_padding,
            first_pool_padding=first_pool_padding,
            growth_factor=growth_factor,
            norm_class=norm_class,
            num_groups=num_groups,
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # use avgpool rather than torch.mean
        
    def forward(self, x):
        batch_size = x.shape[0]

        h = self.resnet(x)
        # Shape of pooled_h is [4, 256, 1, 1]
        pooled_h = self.avgpool(h)
        return pooled_h

In [162]:
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
    

model = ResNet_22()
model.eval()

boom = model(trainset[0][0].unsqueeze(0))

In [163]:
boom.shape

torch.Size([1, 256, 1, 1])