In [1]:
import argparse
import os
import subprocess

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torchlars import LARS
from tqdm import tqdm

from configs import get_datasets
from critic import LinearCritic
from evaluate import save_checkpoint,save_checkpoint2, encode_train_set, train_clf, test
from models import *
from scheduler import CosineAnnealingWithLinearRampLR
from augmentation import ManualNormalise, DifferentiableColourDistortionByTorch3
from torchvision import transforms




import torch.autograd as autograd

In [2]:
import torch
import numbers
import random
from torch import Tensor
import matplotlib.pyplot as plt
import numpy as np

In [3]:
# functions to show an image
def imshow(img):
    img = img.to('cpu')
    #img = img / 2 + 0.5     
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    

In [4]:
# config
### config ####
num_workers = 16
batch_size = 512
img_size = 32
temperature = 0.5
dataset = 'cifar10'
CACHED_MEAN_STD = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ((0.5071, 0.4865, 0.4409), (0.2009, 0.1984, 0.2023)),
        'stl10': ((0.4409, 0.4279, 0.3868), (0.2309, 0.2262, 0.2237)),
        'imagenet': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    }

device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 0.25
arch = 'resnet18'
momentum = 0.9
cosine_anneal = True
num_epochs = 1
test_freq = 1


In [5]:
lr = base_lr * (batch_size / 256)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
clf = None

print('==> Preparing data..')
trainset, testset, clftrainset, num_classes, stem = get_datasets(dataset)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                          num_workers=num_workers, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=num_workers,
                                         pin_memory=True)
clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=1000, shuffle=False, num_workers=num_workers,
                                             pin_memory=True)

# Model
print('==> Building model..')
##############################################################
# Encoder
##############################################################
if arch == 'resnet18':
    net = ResNet18(stem=stem)
elif arch == 'resnet34':
    net = ResNet34(stem=stem)
elif arch == 'resnet50':
    net = ResNet50(stem=stem)
else:
    raise ValueError("Bad architecture specification")
net = net.to(device)

##############################################################
# Critic
##############################################################
critic = LinearCritic(net.representation_dim, temperature=temperature).to(device)
#differentiable augmentation
s = 1.0
aug_by_torch_batch = DifferentiableColourDistortionByTorch3(0.8*s, 0.8*s, 0.8*s, 0.2*s)

if device == 'cuda':
    repr_dim = net.representation_dim
    net = torch.nn.DataParallel(net)
    net.representation_dim = repr_dim
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
base_optimizer = optim.SGD(list(net.parameters()) + list(critic.parameters()), lr=lr, weight_decay=1e-6,
                           momentum=momentum)
if cosine_anneal == True:
    scheduler = CosineAnnealingWithLinearRampLR(base_optimizer, num_epochs)
    encoder_optimizer = LARS(base_optimizer, trust_coef=1e-3)

##############################################################
# Training
##############################################################
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    critic.train()
    train_loss = 0
    t = tqdm(enumerate(trainloader), desc='Loss: **** ', total=len(trainloader), bar_format='{desc}{bar}{r_bar}')
    for batch_idx, (inputs, _, _) in t:
        x1, x2 = inputs
        x1, x2 = x1.to(device), x2.to(device)
        
        x1, x2 = aug_by_torch_batch(x1), aug_by_torch_batch(x2)
        x1, x2 = ManualNormalise(x1, dataset), ManualNormalise(x2, dataset) 

        encoder_optimizer.zero_grad()
        representation1, representation2 = net(x1), net(x2)
        raw_scores, pseudotargets = critic(representation1, representation2)
        loss = criterion(raw_scores, pseudotargets)
        loss.backward()
        encoder_optimizer.step()

        train_loss += loss.item()

        t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1)))
        
        

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
==> Building model..


# Gradient penalty 101

In [9]:
class DifferentiableColourDistortionByTorch_manual(nn.Module):
    
    '''
    need to input the parameter of color augmentation 
    '''
    
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    
    #### batch color augmentation forward #####
    
    def forward(self, x):
        '''
        Args:
            x: Input tensor batchsize x 32 x 32
        
        Returns:
            x_aug : color jittered image
            
        apply color jitter with prob 0.8
        apply random grayscale with prob 0.2
        
        '''
        batch_size = x.size()[0]
        
        p_jitter = torch.ones(batch_size) * 0.8
        jitter = torch.bernoulli(p_jitter)
        jitter = jitter.reshape(batch_size,1,1,1)

        p_gray = torch.ones(batch_size)* 0.2
        gray = torch.bernoulli(p_gray)
        gray = gray.reshape(batch_size,1,1,1)

        jitter = jitter.to(x.device)
        gray = gray.to(x.device)
        
        #random color jitter
        x_jitter = self.batch_colourjitter(x)
        x = x_jitter * jitter + x *(1-jitter)
        
        #random gray scale
        x_gray = self.batch_rgb_to_grayscale(x).unsqueeze(1)
        x = x_gray * gray + x* (1-gray)
       
        return x
    
    
    def batch_colourjitter(self, img: Tensor) -> Tensor:
        
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness_list = self.brightness
                img = self.batch_adjust_brightness(img, brightness_list)

            if fn_id == 1 and self.contrast is not None:
                contrast_list = self.contrast
                img = self.batch_adjust_contrast(img, contrast_list)

            if fn_id == 2 and self.saturation is not None:
                saturation_list = self.saturation
                img = self.batch_adjust_saturation(img, saturation_list)

            if fn_id == 3 and self.hue is not None:
                hue_list = self.hue
                img = self.batch_adjust_hue(img, hue_list)

        return img
        

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string
                
    ##### function from pytorch source code #####
    def _is_tensor_a_torch_image(self, x: Tensor) -> bool:
        return x.ndim >= 2
    
    def _blend(self, img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
        bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
        return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
    def batch_adjust_brightness(self,img: Tensor, brightness_list: list) -> Tensor:
        '''
        Batch x C x H x  W -> Batch x C x H x W
        '''
        B = img.size()[0]
        brightness_factor = brightness_list.reshape(B,1,1,1).to(img.device)
        return self._blend(img, torch.zeros_like(img), brightness_factor)
    
    def batch_rgb_to_grayscale(self,img: Tensor) -> Tensor:
        '''
        Batch x C x H x  W -> Batch x C x H x W
        '''
        if img.shape[1] != 3:
            raise TypeError('Input Image does not contain 3 Channels')

        img_tp = img.transpose(0,1)
        img_gray = (0.2989 * img_tp[0] + 0.5870 * img_tp[1] + 0.1140 * img_tp[2])
        return img_gray
    
    def batch_adjust_saturation(self,img: Tensor, saturation_list: list) -> Tensor:
        '''
        Batch x C x H x  W -> Batch x C x H x W
        '''

        B = img.size()[0]
        saturation_factor = saturation_list.reshape(B,1,1,1).to(img.device)

        return self._blend(img, self.batch_rgb_to_grayscale(img).unsqueeze(1), saturation_factor)
    
    def batch_adjust_contrast(self,img: Tensor, contrast_list: float) -> Tensor:
        '''
        Batch x C x H x  W -> Batch x C x H x W
        '''
        B = img.size()[0]
        contrast_list = contrast_list.reshape(B,1,1,1).to(img.device)

        #mean for each pic (over HxW points)
        img_gray = self.batch_rgb_to_grayscale(img)
        mean = torch.mean(img_gray.reshape(img_gray.shape[0],-1), dim = 1)


        return self._blend(img, mean.reshape([mean.size()[0],1,1,1]),  contrast_list)
    
    def batch_adjust_hue(self, img: Tensor, hue_list: list) -> Tensor:
        '''
        Batch x C x H x  W -> Batch x C x H x W
        '''

        B = img.size()[0]

        # generate tensor 
        one_tensor = torch.ones(B)
        zero_tensor = torch.zeros(B)
        cos_tensor = torch.cos(hue_list)
        sin_tensor = torch.sin(hue_list)

        #stack
        T_hue = torch.stack([one_tensor, zero_tensor, zero_tensor,
                     zero_tensor, cos_tensor, -sin_tensor,
                     zero_tensor, sin_tensor, cos_tensor]).transpose(0,1).reshape(B,3,3)

        T_yiq = torch.tensor([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321], [0.211, -0.523, 0.311]])
        T_rgb = torch.tensor([[1, 0.956, 0.621], [1, -0.272, -0.647], [1, -1.107, 1.705]])
        T_final = torch.matmul(torch.matmul(T_rgb, T_hue), T_yiq)
        T_final = T_final.to(img.device)

        #return T_rgb x T_hue x T_yiq x img

        return torch.matmul(T_final.unsqueeze(1).unsqueeze(1), img.transpose(1,-1).unsqueeze(-1)).squeeze(-1).transpose(1,-1)

    

In [10]:
def gen_lambda(B, brightness_bound, contrast_bound, saturation_bound, hue_bound):
    brightness_list = (torch.torch.rand(B)*(brightness_bound[1] - brightness_bound[0]) 
                   + brightness_bound[0]).requires_grad_(True)
    saturation_list = (torch.torch.rand(B)*(saturation_bound[1] - saturation_bound[0]) 
                       + saturation_bound[0]).requires_grad_(True)
    contrast_list   = (torch.torch.rand(B)*(contrast_bound[1] - contrast_bound[0]) 
                       + contrast_bound[0]).requires_grad_(True)
    hue_list        = ((torch.torch.rand(B)*(hue_bound[1]-hue_bound[0]) 
                       + hue_bound[0])* 3.1415* 2).requires_grad_(True)
    return brightness_list, saturation_list, contrast_list, hue_list

In [11]:
brightness_bound = [0.2, 1.8]
contrast_bound = [0.2, 1.8]
saturation_bound = [0.2, 1.8]
hue_bound = [-0.2, 0.2]
    
epoch = 0
print('\nEpoch: %d' % epoch)
net.train()
critic.train()
train_loss = 0
t = tqdm(enumerate(trainloader), desc='Loss: **** ', total=len(trainloader), bar_format='{desc}{bar}{r_bar}')
for batch_idx, (inputs, _, _) in t:
    x1, x2 = inputs
    x1, x2 = x1.to(device), x2.to(device)

    ##### colour augmentation #####
    B = x1.size()[0]
    brightness_list1, saturation_list1, contrast_list1, hue_list1 = gen_lambda(B, brightness_bound, 
                                                                               contrast_bound,
                                                                               saturation_bound,
                                                                               hue_bound)
    aug_manual1 = DifferentiableColourDistortionByTorch_manual(brightness = brightness_list1,
                                                              contrast = contrast_list1,
                                                              saturation = saturation_list1,
                                                              hue = hue_list1)

    brightness_list2, saturation_list2, contrast_list2, hue_list2 = gen_lambda(B, brightness_bound, 
                                                                               contrast_bound,
                                                                               saturation_bound,
                                                                               hue_bound)

    aug_manual2 = DifferentiableColourDistortionByTorch_manual(brightness = brightness_list2,
                                                              contrast = contrast_list2,
                                                              saturation = saturation_list2,
                                                              hue = hue_list2)
  
    x1, x2 = aug_manual1(x1), aug_manual2(x2)
    
    #####
    x1, x2 = ManualNormalise(x1, dataset), ManualNormalise(x2, dataset)

    encoder_optimizer.zero_grad()
    representation1, representation2 = net(x1), net(x2)
    raw_scores, pseudotargets = critic(representation1, representation2)
    loss = criterion(raw_scores, pseudotargets)

    ##### gradient penalty #####
    gradient_lambda = []
    lambda_list = [brightness_list1, saturation_list1, contrast_list1, hue_list1, 
                   brightness_list2, saturation_list2, contrast_list2, hue_list2 ]

    for aug_variable in lambda_list:
        gradients_augvar =  autograd.grad(outputs = loss,
                             inputs = aug_variable,
                             retain_graph = True,
                             grad_outputs = torch.ones_like(aug_variable).to(device))[0]
        gradient_lambda.append(gradients_augvar)

    gradient_lambda = torch.stack(gradient_lambda, dim = 0)

    # currently gradient_lamba.sum([0]).sqrt() has nan element? so I use norm instead
    gradient_penalty = gradient_lambda.sum([0]).norm(2)

    loss_gp = loss + 0.001*gradient_penalty
    loss_gp.backward()
    encoder_optimizer.step()
    train_loss += loss_gp.item()


Epoch: 0


Loss: **** ██████████| 98/98 [03:24<00:00,  2.09s/it]
