In [46]:
from torchvision import transforms
from torchvision.datasets import FashionMNIST, CIFAR10, CIFAR100, MNIST
from models.cleaner import NoiseCleaner
from models.predictor import Predictor

In [47]:
class TwoCropTransform:
    """Wrap a per-image transform so that one call yields two views of the image.

    The wrapped ``transform`` is invoked twice on the same input. When it is
    random (e.g. the RandomCrop/AutoAugment pipeline used in this notebook),
    the two results generally differ and form a positive pair for
    contrastive learning.
    """

    def __init__(self, transform):
        # Applied once per view on every call.
        self.transform = transform

    def __call__(self, x):
        """Return ``[first_view, second_view]``, each being ``self.transform(x)``."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]

In [48]:
# CIFAR-10 per-channel normalisation constants, shared by both pipelines.
# NOTE(review): these are the widely copied values; per-pixel stds over the
# CIFAR-10 training set are closer to (0.247, 0.243, 0.261) — confirm intent.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Deterministic pipeline (no augmentation). It is not used by the SupCon
# training cells shown here; presumably it is meant for evaluation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Random pipeline: repeated calls on the same image generally give different
# views, which TwoCropTransform turns into a positive pair.
augmented_transform = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# download=True only fetches the archive when it is missing or fails the
# integrity check; with download=False a fresh checkout raises instead.
train_dataset = CIFAR10(root='data', train=True, download=True,
                        transform=TwoCropTransform(augmented_transform))

In [49]:
from models import *
from models.dataset import *

In [50]:
import torch
from torch.utils.data import DataLoader

BATCH_SIZE = 1024

# shuffle=True: without it every epoch walks the dataset in the same fixed
# order, so each epoch sees identical batches (and identical in-batch
# negatives for the contrastive loss).
# pin_memory: train() copies batches with non_blocking=True, which can only
# overlap the host->device transfer when the host tensors are pinned.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

In [51]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import *
from torch.nn import functional as F

class SupConResNet(nn.Module):
    """ResNet-50 backbone + projection head for SupCon / SimCLR training.

    Args:
        head: ``'mlp'`` (Linear -> ReLU -> Linear) or ``'linear'`` projection head.
        feat_dim: dimensionality of the projected embedding.
        cifar_stem: if True, replace torchvision's ImageNet stem (7x7 stride-2
            conv followed by a stride-2 max-pool) with a 3x3 stride-1 conv and
            no pooling. On 32x32 inputs such as CIFAR-10, the ImageNet stem
            shrinks the image to 8x8 before the first residual stage. Defaults
            to False, which keeps the original architecture and its checkpoints.

    Raises:
        NotImplementedError: for an unknown ``head``.

    ``forward(x)`` returns L2-normalised embeddings of shape [batch, feat_dim].
    """
    def __init__(self, head='mlp', feat_dim=128, cifar_stem=False):
        super(SupConResNet, self).__init__()
        model_fun = resnet50()
        # Read the backbone width rather than hard-coding it (2048 for ResNet-50).
        dim_in = model_fun.fc.in_features
        # torchvision's ResNet already flattens the pooled features before `fc`,
        # so an identity `fc` exposes the [batch, dim_in] features directly.
        model_fun.fc = nn.Identity()
        if cifar_stem:
            model_fun.conv1 = nn.Conv2d(
                model_fun.conv1.in_channels, model_fun.conv1.out_channels,
                kernel_size=3, stride=1, padding=1, bias=False)
            model_fun.maxpool = nn.Identity()
        self.encoder = model_fun
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Encode ``x``, project it, and L2-normalise along the feature axis."""
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat
    
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    Args:
        temperature: divisor applied to every anchor/contrast dot product.
        contrast_mode: ``'all'`` uses every view as an anchor; ``'one'`` uses
            only the first view (``features[:, 0]``) as the anchor.
        base_temperature: the final loss is scaled by
            ``temperature / base_temperature``.
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. Trailing
                dimensions are flattened into one feature axis. The loss uses
                raw dot products and does not normalise the features itself.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dims, if both `labels`
                and `mask` are given, if the label count differs from bsz,
                or if `contrast_mode` is unknown.
        """
        # Build helper tensors on the exact device of `features` (e.g.
        # 'cuda:1'). A bare torch.device('cuda') means the *current* GPU and
        # causes a device mismatch when features live on another one.
        device = features.device

        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # reshape (unlike view) also accepts non-contiguous inputs.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised: only the other views of the same sample are positives.
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # [bsz, n_views, D] -> [n_views * bsz, D], grouped view by view.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability: subtract each row's max before exp
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask to [anchor_count * bsz, contrast_count * bsz]
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases (each anchor against itself)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob; the denominator excludes the self-pair
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [52]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# matching the torch.cuda.is_available() check in train(); a hard .cuda()
# fails on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SupConResNet().to(device)
# SupConLoss defines no parameters; .to() is kept only for symmetry.
criterion = SupConLoss(temperature=0.07).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, weight_decay=1e-4, momentum=0.9)

In [53]:
from tqdm import tqdm
def train():
    """Run one epoch of contrastive training over ``train_loader``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. Each batch is expected to be ``(views, labels)``, where
    ``views`` is a list with one batched image tensor per view (as yielded for
    the TwoCropTransform dataset). Any number of views is supported.

    Returns:
        The mean of the per-batch loss values over the epoch, or 0.0 if the
        loader produced no batches.
    """
    model.train()
    # Put inputs wherever the model's parameters live instead of
    # re-deciding the device here.
    device = next(model.parameters()).device
    total_loss = 0.0
    n_batches = 0

    for views, labels in tqdm(train_loader):
        # Concatenate all views along the batch axis: [n_views * bsz, C, H, W].
        images = torch.cat(views, dim=0).to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        bsz = labels.shape[0]

        features = model(images)
        # Undo the concatenation: [n_views * bsz, D] -> [bsz, n_views, D].
        features = torch.stack(torch.split(features, bsz, dim=0), dim=1)
        loss = criterion(features, labels)
        total_loss += loss.item()
        n_batches += 1

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss / n_batches if n_batches else 0.0

In [55]:
# Epochs are numbered 1..NUM_EPOCHS inclusive, so the upper bound needs +1.
NUM_EPOCHS = 500
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    print(f'epoch {epoch}, loss: {loss}')

100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 1, loss: 7.488406473276567


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 2, loss: 7.466656383203


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


epoch 3, loss: 7.453608892401871


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


epoch 4, loss: 7.433524735119878


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 5, loss: 7.4247791718463505


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 6, loss: 7.4150548467830735


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 7, loss: 7.400428032388493


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


epoch 8, loss: 7.393695101446035


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 9, loss: 7.385201405505745


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


epoch 10, loss: 7.3774237924692585


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


epoch 11, loss: 7.370231930090457


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 12, loss: 7.365513928082525


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 13, loss: 7.358673290330536


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


epoch 14, loss: 7.349098487776153


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 15, loss: 7.346024386736811


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 16, loss: 7.335449403646041


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


epoch 17, loss: 7.3312782754703445


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


epoch 18, loss: 7.328432637818006


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 19, loss: 7.31942875531255


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 20, loss: 7.301228387015207


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 21, loss: 7.266652331060293


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 22, loss: 7.2467862440615285


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 23, loss: 7.233137656231316


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 24, loss: 7.215645799831468


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 25, loss: 7.198462768476837


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 26, loss: 7.185923790445133


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 27, loss: 7.1779457111747895


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 28, loss: 7.166823445534219


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 29, loss: 7.158848421914237


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 30, loss: 7.142386047207579


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 31, loss: 7.127965975780876


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 32, loss: 7.121603323488819


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 33, loss: 7.099560377549152


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 34, loss: 7.070475111202318


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 35, loss: 7.047530728943494


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 36, loss: 7.0163883968275425


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 37, loss: 6.986596389692657


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 38, loss: 6.957078894790338


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 39, loss: 6.923497287594542


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 40, loss: 6.901412039386983


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 41, loss: 6.875406226333307


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 42, loss: 6.856286953906624


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 43, loss: 6.8365897451128275


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 44, loss: 6.810364470189931


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 45, loss: 6.787629536220005


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 46, loss: 6.768530874836202


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 47, loss: 6.745668498837218


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 48, loss: 6.728077878757399


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 49, loss: 6.701558842950938


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 50, loss: 6.685647497371751


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 51, loss: 6.663356547452966


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 52, loss: 6.6519860637431245


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 53, loss: 6.635970067004768


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 54, loss: 6.618380828779571


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 55, loss: 6.604814928405139


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 56, loss: 6.585687082640979


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 57, loss: 6.576900881163928


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 58, loss: 6.5554801687902335


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 59, loss: 6.540663767834099


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 60, loss: 6.533001958107461


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 61, loss: 6.516741188205018


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


epoch 62, loss: 6.496172106995875


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 63, loss: 6.484894158888836


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 64, loss: 6.4718626956550445


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 65, loss: 6.46280238093162


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


epoch 66, loss: 6.449887353546766


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 67, loss: 6.436789259618642


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


epoch 68, loss: 6.421223017634178


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 69, loss: 6.415432462886888


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


epoch 70, loss: 6.399190892978591


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


epoch 71, loss: 6.390789421237245


 76%|███████▌  | 37/49 [00:39<00:12,  1.08s/it]


KeyboardInterrupt: 