This notebook contains a PyTorch implementation of the paper [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709) by Chen et al.

**Goal: adapt this SimCLR implementation to train on a dataset of 100x100 images.**
One option is to build a CIFAR-10-compatible dataset, but that requires customizing the CIFAR-10 class.
The other option is to load the data directly instead of going through CIFAR-10. I will try the latter.  

In [0]:
import os
import pickle

import numpy as np
from tqdm import tqdm_notebook as tqdm
from PIL import Image

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as tfs
from torchvision.datasets import *
from torchvision.datasets.utils import check_integrity
from torchvision.models import *

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # last expression of the cell: shows which device was selected

In [0]:
# All three pipelines end with the same ToTensor + Normalize step.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# rather than CIFAR-10's own - confirm this is intended.

# Training (contrastive) pipeline: random crop resized to 32x32, random
# horizontal flip and colour jitter before tensor conversion.
tf_tr = tfs.Compose([
    tfs.RandomResizedCrop(size=32),
    tfs.RandomHorizontalFlip(),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Deterministic pipeline (resize only) used for the labelled training pass.
tf_de = tfs.Compose([
    tfs.Resize(size=32),
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Test pipeline: same steps as tf_de.
tf_te = tfs.Compose([
    tfs.Resize(size=32),
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [0]:
class CustomCIFAR10(CIFAR10):
    """CIFAR10 variant that yields two augmented views per training image.

    With ``train=True`` an item is a single tensor holding two independent
    applications of ``self.transform`` to the same image, stacked along a new
    first dimension; no label is returned. With ``train=False`` the item is
    whatever the parent class returns.
    """

    def __init__(self, **kwds):
        # Keyword-only pass-through to the CIFAR10 constructor.
        super().__init__(**kwds)

    def __getitem__(self, idx):
        if self.train:
            pil_img = Image.fromarray(self.data[idx]).convert('RGB')
            # The transform is invoked twice so each view is produced by its
            # own call.
            views = [self.transform(pil_img) for _ in range(2)]
            return torch.stack(views)

        return super().__getitem__(idx)

In [0]:
class CustomCervicalCIFAR10(CIFAR10):
    """CIFAR10 subclass that loads CIFAR-format pickle batches itself.

    The constructor reads every batch file named in ``train_list`` (or
    ``test_list`` when ``train=False``) from ``root/base_folder``, stacks the
    pixel data into one ``(N, 32, 32, 3)`` array in ``self.data`` and collects
    the labels in ``self.targets``.

    Args:
        root: directory that contains ``base_folder``.
        train: load the training batches when True, the test batches otherwise.
        transform: stored via the base-class initializer.
        target_transform: stored via the base-class initializer.
        download: accepted for signature compatibility only; it is ignored and
            nothing is downloaded.

    Raises:
        RuntimeError: if the metadata file fails its integrity check.

    NOTE(review): the image size is hard-coded to 32x32 in the reshape below,
    while the notebook's stated goal is 100x100 images - confirm before using
    this class with other data.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        # super(CIFAR10, self) starts the lookup *after* CIFAR10 in the MRO, so
        # CIFAR10.__init__ is bypassed and only its parent's initializer runs.
        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        batch_list = self.train_list if self.train else self.test_list

        self.data = []
        self.targets = []

        # Load the pickled numpy arrays. The per-file checksum is not verified.
        for file_name, _checksum in batch_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # this class at batch files you trust.
                entry = pickle.load(f, encoding='latin1')
            self.data.append(entry['data'])
            if 'labels' in entry:
                self.targets.extend(entry['labels'])
            else:
                self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        # Convert NCHW -> NHWC. (The original called the misspelled `.tranpose`,
        # which raises AttributeError.)
        self.data = self.data.transpose((0, 2, 3, 1))

        self._load_meta()

    def _load_meta(self):
        """Read the class names from the metadata pickle into ``self.classes``
        and build the ``self.class_to_idx`` lookup."""
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            # SECURITY: same pickle caveat as for the batch files.
            data = pickle.load(infile, encoding='latin1')
        self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}





In [0]:
# All three datasets read from (and download to) the same 'data' directory.
cifar_args = dict(root='data', download=True)

# Two-view training set for the contrastive stage.
ds_tr = CustomCIFAR10(train=True, transform=tf_tr, **cifar_args)
# Labelled training split with the deterministic transform.
ds_de = CIFAR10(train=True, transform=tf_de, **cifar_args)
# Labelled test split.
ds_te = CIFAR10(train=False, transform=tf_te, **cifar_args)

In [0]:
# Batches of 256 everywhere; the two training loaders shuffle, the test loader
# keeps the dataset order.
dl_tr = DataLoader(dataset=ds_tr, batch_size=256, shuffle=True)
dl_de = DataLoader(dataset=ds_de, batch_size=256, shuffle=True)
dl_te = DataLoader(dataset=ds_te, batch_size=256, shuffle=False)

In [0]:
# ResNet-50 with randomly initialised weights. Calling resnet50() with no
# arguments gives the same result as the former `pretrained=False`, without
# using the `pretrained` keyword that newer torchvision releases deprecate.
model = resnet50()
# Replace the stem convolution with a 3x3, stride-1, padding-1 convolution and
# turn the max-pool into a no-op.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

In [0]:
# Width of the features feeding the classifier; reused later when the
# classification head is attached.
ch = model.fc.in_features

# Swap the classifier for a two-layer MLP head: Linear -> ReLU -> Linear,
# keeping the width at `ch` throughout.
model.fc = nn.Sequential(
    nn.Linear(in_features=ch, out_features=ch),
    nn.ReLU(),
    nn.Linear(in_features=ch, out_features=ch),
)

model.to(device)
model.train()

In [0]:
def pair_cosine_similarity(x, eps=1e-8):
    """Return the cosine similarity between every pair of rows of ``x``.

    Args:
        x: 2-D tensor with one embedding per row.
        eps: lower bound applied to the product of the two norms, so all-zero
            rows divide by ``eps`` instead of by zero.

    Returns:
        Square tensor ``s`` with
        ``s[i, j] = (x[i] . x[j]) / max(||x[i]|| * ||x[j]||, eps)``.
    """
    n = x.norm(p=2, dim=1, keepdim=True)
    return (x @ x.t()) / (n * n.t()).clamp(min=eps)


def nt_xent(x, t=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Rows of ``x`` must be ordered so that rows ``2k`` and ``2k + 1`` are the two
    views of the same sample (a positive pair); every other row acts as a
    negative for them.

    Args:
        x: 2-D tensor of embeddings with an even number (>= 2) of rows.
        t: temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: the mean over all rows ``i`` of
        ``-log( exp(s[i, pos(i)] / t) / sum_{k != i} exp(s[i, k] / t) )``.

    Raises:
        ValueError: if ``x`` does not have an even, non-zero number of rows.

    Differences from the previous implementation:
      * the per-row ``-log`` terms are averaged (the NT-Xent definition); the
        old code returned ``-log`` of the averaged probabilities instead;
      * the self-similarity term is masked out explicitly rather than by
        subtracting ``exp(1 / t)``, which was wrong for all-zero rows;
      * index tensors are created on ``x``'s device and the computation uses
        ``logsumexp`` rather than explicit ``exp`` / ``log``.
    """
    n = x.size(0)
    if n < 2 or n % 2 != 0:
        raise ValueError(
            'nt_xent expects an even, non-zero number of rows '
            '(consecutive rows form positive pairs), got %d' % n)

    logits = pair_cosine_similarity(x) / t

    # Remove each row's similarity with itself from the softmax denominator.
    self_mask = torch.eye(n, dtype=torch.bool, device=x.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Index of each row's positive partner: 0<->1, 2<->3, ...
    rows = torch.arange(n, device=x.device)
    partner = rows + 1 - 2 * (rows % 2)

    per_row = torch.logsumexp(logits, dim=1) - logits[rows, partner]
    return per_row.mean()

In [0]:
# Adam over every model parameter for the contrastive stage.
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [0]:
model.train()

# Checkpoint filename prefix. No visible cell defines `path`, so on a fresh
# kernel the original `path + ...` raised NameError at the first checkpoint
# (after 10 full epochs). Use `path` when an earlier cell defined it, otherwise
# save into the current working directory.
ckpt_prefix = globals().get('path', '')

for i in range(100):
    # c: number of embeddings seen this epoch, s: running mean of the loss.
    c, s = 0, 0
    pBar = tqdm(dl_tr)
    for data in pBar:
        # Each item of ds_tr is a stack of two views, so a batch is
        # (B, 2, C, H, W). Flattening the first two dims keeps the two views of
        # an image on adjacent rows, which is the layout nt_xent pairs up.
        d = data.size()
        x = data.view(d[0]*2, d[2], d[3], d[4]).to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = nt_xent(p)
        # Running mean weighted by the number of rows in each batch.
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Train: '+str(round(float(s),3)))
        loss.backward()
        optimizer.step()
    # Save the weights every 10th epoch.
    if (i+1) % 10 == 0:
        torch.save(model.state_dict(), ckpt_prefix+'cifar10-rn50-mlp-b256-t0.5-e'+str(i+1)+'.pt')

In [0]:
# Freeze every parameter currently in the model: none of them will receive
# gradients from here on.
for param in model.parameters():
    param.requires_grad_(False)

In [0]:
# Replace the head with a single linear layer that outputs one logit per
# dataset class. A freshly constructed nn.Linear has trainable parameters.
num_classes = len(ds_de.classes)
model.fc = nn.Linear(in_features=ch, out_features=num_classes)
model.to(device)

In [0]:
# Cross-entropy objective and a fresh Adam optimizer for the first supervised
# pass (learning rate 0.003).
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=3e-3)

In [0]:
model.train()
for i in range(5):
    # c: samples seen so far this epoch, s: running mean of the loss.
    c, s = 0, 0
    pBar = tqdm(dl_de)
    for data in pBar:
        x = data[0].to(device)
        y = data[1].to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = criterion(p, y)
        # Update the sample-weighted running mean before stepping.
        n = len(p)
        s = (s * c + loss.item() * n) / (c + n)
        c += n
        pBar.set_description(f'Train: {round(float(s), 3)}')
        loss.backward()
        optimizer.step()

In [0]:
# Second supervised pass: new Adam optimizer with a smaller learning rate
# (0.0001) and a fresh cross-entropy criterion.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [0]:
model.train()
for i in range(5):
    # c: samples seen so far this epoch, s: running mean of the loss.
    c, s = 0, 0
    pBar = tqdm(dl_de)
    for data in pBar:
        x = data[0].to(device)
        y = data[1].to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = criterion(p, y)
        # Update the sample-weighted running mean before stepping.
        n = len(p)
        s = (s * c + loss.item() * n) / (c + n)
        c += n
        pBar.set_description(f'Train: {round(float(s), 3)}')
        loss.backward()
        optimizer.step()

In [0]:
model.eval()
# c: samples seen so far, s: running mean of the test loss.
c, s = 0, 0
pBar = tqdm(dl_te)
# Pure evaluation: disable autograd so no graph is recorded for the forward
# passes. The computed loss values are unchanged.
with torch.no_grad():
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        p = model(x)
        loss = criterion(p, y)
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Test: '+str(round(float(s),3)))

In [0]:
model.eval()
y_pred, y_true = [], []
pBar = tqdm(dl_te)
# Inference only: run without autograd so no graph is recorded. Outputs
# produced under no_grad do not require grad, so they can be converted to
# numpy directly.
with torch.no_grad():
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        p = model(x)
        y_pred.append(p.cpu().numpy())
        y_true.append(y.cpu().numpy())
# One array of model outputs (one row per sample) and one array of labels.
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)

In [0]:
# Top-1 accuracy: fraction of samples whose highest-scoring class is the label.
np.mean(y_pred.argmax(axis=1) == y_true)