In [None]:
import logging
import time

# Create logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

handler = logging.FileHandler(f'training.log')
# Create STDERR handler
# handler = logging.StreamHandler(sys.stderr)
# ch.setLevel(logging.DEBUG)

# Create formatter and add it to the handler
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Set STDERR handler as the only handler 
logger.handlers = [handler]

# SWaV: unsupervised training using clustering and swap prediction problem

In [1]:

import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ["downsample"]

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ["downsample"]

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
            self,
            block,
            layers,
            zero_init_residual=False,
            groups=1,
            widen=1,
            width_per_group=64,
            replace_stride_with_dilation=None,
            norm_layer=None,
            normalize=False,
            output_dim=0,
            hidden_mlp=0,
            nmb_prototypes=0,
            eval_mode=False,
    ):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.eval_mode = eval_mode
        self.padding = nn.ConstantPad2d(1, 0.0)

        self.inplanes = width_per_group * widen
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group

        # change padding 3 -> 2 compared to original torchvision code because added a padding layer
        num_out_filters = width_per_group * widen
        self.conv1 = nn.Conv2d(
            3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False
        )
        self.bn1 = norm_layer(num_out_filters)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, num_out_filters, layers[0])
        num_out_filters *= 2
        self.layer2 = self._make_layer(
            block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        num_out_filters *= 2
        self.layer3 = self._make_layer(
            block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        num_out_filters *= 2
        self.layer4 = self._make_layer(
            block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # normalize output features
        self.l2norm = normalize

        # projection head
        if output_dim == 0:
            self.projection_head = None
        elif hidden_mlp == 0:
            self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim)
        else:
            self.projection_head = nn.Sequential(
                nn.Linear(num_out_filters * block.expansion, hidden_mlp),
                nn.BatchNorm1d(hidden_mlp),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_mlp, output_dim),
            )

        # prototype layer
        self.prototypes = None
        if isinstance(nmb_prototypes, list):
            self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
        elif nmb_prototypes > 0:
            self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def forward_backbone(self, x):
        x = self.padding(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.eval_mode:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

    def forward_head(self, x):
        if self.projection_head is not None:
            x = self.projection_head(x)

        if self.l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)

        if self.prototypes is not None:
            return x, self.prototypes(x)
        return x

    def forward(self, inputs):
        if not isinstance(inputs, list):
            inputs = [inputs]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in inputs]),
            return_counts=True,
        )[1], 0)
        start_idx = 0
        for end_idx in idx_crops:
            _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]))
            # .to(device)
                                        #  cuda(non_blocking=True))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        return self.forward_head(output)


class MultiPrototypes(nn.Module):
    def __init__(self, output_dim, nmb_prototypes):
        super(MultiPrototypes, self).__init__()
        self.nmb_heads = len(nmb_prototypes)
        for i, k in enumerate(nmb_prototypes):
            self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False))

    def forward(self, x):
        out = []
        for i in range(self.nmb_heads):
            out.append(getattr(self, "prototypes" + str(i))(x))
        return out


def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet50w2(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs)


def resnet50w4(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs)


def resnet50w5(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs)


In [4]:
hidden_mlp = 2048
feat_dim = 128
nmb_prototypes = 3000 # default was 3000

model = resnet50(normalize=True,
        hidden_mlp=hidden_mlp,
        output_dim=feat_dim,
        nmb_prototypes=nmb_prototypes)

In [2]:
import random
from PIL import ImageFilter, Image
import torchvision
import numpy as np
import torchvision.transforms as transforms

class MultiCropCIFAR10(torchvision.datasets.CIFAR10):
    def __init__(
        self,
        root,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        transform=None,
        target_transform=None,
        size_dataset=-1,
        return_index=False,
        download=True
    ):
        super().__init__(root, transform=transform, target_transform=target_transform, download=download)
        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)
        if size_dataset >= 0:
            self.data = self.data[:size_dataset]
            self.targets = self.targets[:size_dataset]
        self.return_index = return_index

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        std = [0.228, 0.224, 0.225]
        trans = []
        for i in range(len(size_crops)):
            randomresizedcrop = transforms.RandomResizedCrop(
                size_crops[i],
                scale=(min_scale_crops[i], max_scale_crops[i]),
            )
            trans.extend([transforms.Compose([
                randomresizedcrop,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)])
            ] * nmb_crops[i])
        self.trans = trans

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        multi_crops = list(map(lambda trans: trans(image), self.trans))
        target = self.targets[index]
        if self.return_index:
            return index, multi_crops
        return multi_crops
    

class MultiCropSVHN(torchvision.datasets.SVHN):
    def __init__(
        self,
        root,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        transform=None,
        target_transform=None,
        size_dataset=-1,
        return_index=False,
        download=True
    ):
        super().__init__(root, transform=transform, target_transform=target_transform, download=download)
        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)
        if size_dataset >= 0:
            self.data = self.data[:size_dataset]
            self.targets = self.labels[:size_dataset]
        self.return_index = return_index

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        std = [0.228, 0.224, 0.225]
        trans = []
        for i in range(len(size_crops)):
            randomresizedcrop = transforms.RandomResizedCrop(
                size_crops[i],
                scale=(min_scale_crops[i], max_scale_crops[i]),
            )
            trans.extend([transforms.Compose([
                randomresizedcrop,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)])
            ] * nmb_crops[i])
        self.trans = trans

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        multi_crops = list(map(lambda trans: trans(image), self.trans))
        target = int(self.targets[index])
        if self.return_index:
            return index, multi_crops
        return multi_crops

class MultiCropSTL10(torchvision.datasets.STL10):
    def __init__(
        self,
        root,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        transform=None,
        target_transform=None,
        size_dataset=-1,
        return_index=False,
        download=True
    ):
        super().__init__(root, transform=transform, target_transform=target_transform, download=download)
        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)
        if size_dataset >= 0:
            self.data = self.data[:size_dataset]
            self.targets = self.labels[:size_dataset]
        self.return_index = return_index

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        std = [0.228, 0.224, 0.225]
        trans = []
        for i in range(len(size_crops)):
            randomresizedcrop = transforms.RandomResizedCrop(
                size_crops[i],
                scale=(min_scale_crops[i], max_scale_crops[i]),
            )
            trans.extend([transforms.Compose([
                randomresizedcrop,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)])
            ] * nmb_crops[i])
        self.trans = trans

    def __getitem__(self, index):
        image = Image.fromarray(np.transpose(self.data[index], (1, 2, 0)))
        multi_crops = list(map(lambda trans: trans(image), self.trans))
        target = self.targets[index]
        if self.return_index:
            return index, multi_crops
        return multi_crops



class PILRandomGaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image. Take the radius and probability of
    application as the parameter.
    This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = np.random.rand() <= self.prob
        if not do_it:
            return img

        return img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )


def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

In [3]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

# Define the parameters for the multi-crop dataset
data_path = 'image_datasets'
size_crops = [224, 96]
nmb_crops = [2, 6]
min_scale_crops = [0.25, 0.05]
max_scale_crops = [1.0, 0.3]
size_dataset = 100  # number of images to use in the dataset
return_index = False

# Create the multi-crop dataset
dataset = MultiCropCIFAR10(
    data_path,
    size_crops,
    nmb_crops,
    min_scale_crops,
    max_scale_crops,
    size_dataset=size_dataset,
    return_index=return_index,
    download=True
)

# Create a data loader for the dataset
batch_size = 4
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


dataset = MultiCropSVHN(
    data_path,
    size_crops,
    nmb_crops,
    min_scale_crops,
    max_scale_crops,
    size_dataset=size_dataset,
    return_index=return_index,
    download=True
)
batch_size = 256
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)



dataset = MultiCropSTL10(
    data_path,
    size_crops,
    nmb_crops,
    min_scale_crops,
    max_scale_crops,
    size_dataset=size_dataset,
    return_index=return_index,
    download=True
)
batch_size = 4
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


Files already downloaded and verified
Using downloaded and verified file: image_datasets/train_32x32.mat
Files already downloaded and verified


In [4]:
import os
import time
import math
import logging
import shutil
import torch.nn.functional as F

# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# directory_path = "./challange2"
# os.makedirs(directory_path, exist_ok=True)

# handler = logging.FileHandler('training-challange-2.log')

# # Create formatter and add it to the handler
# formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
# handler.setFormatter(formatter)

# # Set STDERR handler as the only handler 
# logger.handlers = [handler]

# use_fp16 = False
# dump_path = "./challange2"
# rank = 0
# world_size = 1
# epoch_queue_starts = 15
# crops_for_assign = [0,1]
# checkpoint_freq = 25
# temperature = 0.1
# freeze_prototypes_niters = 313
# epsilon = 0.05
# sinkhorn_iterations = 3
# epochs = 100

def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


class AverageMeter(object):
    """computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def train_pretrained_util(train_loader, model, optimizer, epoch, lr_schedule, queue, logger):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    crops_for_assign = [0,1]
    temperature = 0.1
    freeze_prototypes_niters = 313
    rank = 0

    model.train()
    use_the_queue = False

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]

        # normalize the prototypes
        with torch.no_grad():
            w = model.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            model.prototypes.weight.copy_(w)

        # ============ multi-res forward passes ... ============
        # print(inputs[0].shape)
        embedding, output = model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # ============ swav loss ... ============
        loss = 0
        for i, crop_id in enumerate(crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id: bs * (crop_id + 1)].detach()

                # time to use the queue
                if queue is not None:
                    if use_the_queue or not torch.all(queue[i, -1, :] == 0):
                        use_the_queue = True
                        out = torch.cat((torch.mm(
                            queue[i],
                            model.prototypes.weight.t()
                        ), out))
                    # fill the queue
                    queue[i, bs:] = queue[i, :-bs].clone()
                    queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]

                # get assignments
                q = distributed_sinkhorn(out)[-bs:]

            # cluster assignment prediction
            subloss = 0
            for v in np.delete(np.arange(np.sum(nmb_crops)), crop_id):
                x = output[bs * v: bs * (v + 1)] / temperature
                subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1))
            loss += subloss / (np.sum(nmb_crops) - 1)
        loss /= len(crops_for_assign)

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        loss.backward()
        # cancel gradients for the prototypes
        if iteration < freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        optimizer.step()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if rank ==0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.param_groups[0]["lr"],
                )
            )
    return (epoch, losses.avg), queue


@torch.no_grad()
def distributed_sinkhorn(out, epsilon=0.05, world_size=1, sinkhorn_iterations=3):
    Q = torch.exp(out / epsilon).t() # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1] * world_size # number of samples to assign
    K = Q.shape[0] # how many prototypes

    # make the matrix sums to 1
    sum_Q = torch.sum(Q)
    # dist.all_reduce(sum_Q)
    Q /= sum_Q

    for it in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        # dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B # the colomns must sum to 1 so that Q is an assignment
    return Q.t()


# main()

# Fine-Tuning for binary classification

In [5]:
def train_finetuned_util(model, device, train_loader, optimizer, epoch, logger, display=True):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    if display:
        logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss.item()))
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss.item()))

def test_finetuned_util(model, device, test_loader, logger):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

In [6]:
import torchvision.models as models

from numpy.random import RandomState
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset
from torchvision import datasets, transforms


class FineTunedSwav(nn.Module):
    def __init__(self, pre_trained, num_classes) -> None:
        super(FineTunedSwav, self).__init__()
        self.num_classes = num_classes
        model_children = list(pre_trained.children())[:-2]
        self.pre_trained = nn.Sequential(*model_children)
        
        # Replace the last two layers with a single fully connected layer
        last_layer_in_features = pre_trained.projection_head[0].in_features
        self.fc = nn.Sequential(
            nn.Linear(last_layer_in_features, self.num_classes),
            nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        x = self.pre_trained(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def train_finetuned(backbone_model, config, logger):
    f_config = config["fine_tuned"]

    crop = transforms.RandomResizedCrop(224)
    hflip = transforms.RandomHorizontalFlip()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])


    transform = transforms.Compose([
        crop,
        transforms.RandomApply([hflip], p=0.5),
        normalize,
    ])


    # We resize images to allow using imagenet pre-trained models, is there a better way?
    resize = transforms.Resize(224) 

    transform_val = transforms.Compose([resize, transforms.ToTensor(), transform]) #careful to keep this one same
    transform_train = transforms.Compose([resize, transforms.ToTensor(), transform]) 

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device) # you will really need gpu's for this part


    ##### Cifar Data
    cifar_data = datasets.CIFAR10(root=f_config['dataset_path'],train=True, transform=transform_train, download=True)
        
    #We need two copies of this due to weird dataset api 
    cifar_data_val = datasets.CIFAR10(root=f_config['dataset_path'],train=True, transform=transform_val, download=True)
        
    accs = []

    for seed in range(1, 5):
        prng = RandomState(seed)
        random_permute = prng.permutation(np.arange(0, 5000))
        classes =  prng.permutation(np.arange(0,10))
        indx_train = np.concatenate([np.where(np.array(cifar_data.targets) == classe)[0][random_permute[0:25]] for classe in classes[0:2]])
        indx_val = np.concatenate([np.where(np.array(cifar_data.targets) == classe)[0][random_permute[25:225]] for classe in classes[0:2]])

        train_data = Subset(cifar_data, indx_train)
        val_data = Subset(cifar_data_val, indx_val)

        print('Num Samples For Training %d Num Samples For Val %d'%(train_data.indices.shape[0],val_data.indices.shape[0]))
        logger.info('Num Samples For Training %d Num Samples For Val %d'%(train_data.indices.shape[0],val_data.indices.shape[0]))

        train_loader = torch.utils.data.DataLoader(train_data,
                                                    batch_size=f_config['batch_size'], 
                                                    shuffle=True)

        val_loader = torch.utils.data.DataLoader(val_data,
                                                batch_size=f_config['batch_size'], 
                                                shuffle=False)

        # ORIGINAL #
        # model = models.alexnet(pretrained=True)
        # model.classifier = nn.Linear(256 * 6 * 6, 10)
        ############
        model = FineTunedSwav(pre_trained=backbone_model, num_classes=f_config['num_classes'])
    
        optimizer = torch.optim.SGD(model.fc.parameters(), 
                                    lr=f_config['lr'], momentum=f_config['momentum'],
                                    weight_decay=f_config['weight_decay'])

        model.to(device)
        for epoch in range(f_config['epochs']):
            train_finetuned_util(model, device, train_loader, optimizer, epoch, logger=logger, display=epoch%(f_config['epochs']//10)==0)

        accs.append(test_finetuned_util(model, device, val_loader, logger=logger))

    accs = np.array(accs)
    logger.info('Acc over 5 instances: %.2f +- %.2f'%(accs.mean(),accs.std()))
    print('Acc over 5 instances: %.2f +- %.2f'%(accs.mean(),accs.std()))

    

In [7]:
import torch
import os
import time


def train_pretrained(data_loader, config, logger):
    device = config['device']
    p_config = config['pre_trained']
    # build model
    hidden_mlp = p_config['hidden_mlp']
    feat_dim = p_config['feat_dim']
    nmb_prototypes = p_config['num_prototypes'] # default was 3000

    model = resnet50(normalize=True,
            hidden_mlp=hidden_mlp,
            output_dim=feat_dim,
            nmb_prototypes=nmb_prototypes)
    
    # synchronize batch norm layers
    fix_random_seeds()

    base_lr = p_config['base_lr']
    wd = p_config['weight_decay']
    mu = p_config['momentum']
    
    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=base_lr,
        momentum=mu,
        weight_decay=wd,
    )
    base_lr = p_config['base_lr']
    wd = p_config['weight_decay']
    mu = p_config['momentum']
    start_warmup = p_config['start_warmup']
    warmup_epochs = p_config['warmup_epochs']
    final_lr = p_config['final_lr']
    epochs = p_config['epochs']
    warmup_lr_schedule = np.linspace(start_warmup, base_lr, len(data_loader) * warmup_epochs)
    iters = np.arange(len(data_loader) * (epochs - warmup_epochs))
    cosine_lr_schedule = np.array([final_lr + 0.5 * (base_lr - final_lr) * (1 + \
                         math.cos(math.pi * t / (len(data_loader) * (epochs - warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    dump_path = config['model_dir']
    rank = 0
    world_size = 1
    epoch_queue_starts = p_config['epoch_queue_starts']
    crops_for_assign = p_config['crops_for_assign']
    # build the queue
    queue_length = p_config['queue_length']
    batch_size = p_config['batch_size']
    queue = None
    queue_path = os.path.join(dump_path, "queue" + str(rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    queue_length -= queue_length % (batch_size * world_size)

    start_epoch = 0
    for epoch in range(start_epoch, epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # optionally starts a queue
        if queue_length > 0 and epoch >= epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(crops_for_assign),
                queue_length // world_size,
                feat_dim,
            ).to(device)

        # train the network
        scores, queue = train_pretrained_util(data_loader, model, optimizer, epoch, lr_schedule, queue, logger=logger)

        # save checkpoints
        save_dict = {
            "epoch": epoch + 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        chk_path = os.path.join(dump_path, "checkpoint.pth.tar")
        torch.save(
            save_dict,
            chk_path,
        )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
    
    return model, chk_path

def load_pretrained(model, saved_chk_pt_path):
    checkpoint = torch.load(saved_chk_pt_path)
    checkpoint_key_model = 'state_dict'
    checkpoint_key_optimizer = 'optimizer'
    model_weights = checkpoint[checkpoint_key_model]
    optimizer_state = checkpoint[checkpoint_key_optimizer]
    epoch = checkpoint['epoch']

    hidden_mlp = 2048
    feat_dim = 128
    nmb_prototypes = 512 # default was 3000

    model = resnet50(normalize=True,
                hidden_mlp=hidden_mlp,
                output_dim=feat_dim,
                nmb_prototypes=nmb_prototypes)
    model.load_state_dict(model_weights)

    return model


def get_new_logger(f_path):
    
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(filename=f_path)
    fh.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger


def pre_train_and_finetune(dataset_name:str = 'svhn'):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    
    config = {
        'dataset_name': dataset_name,
        'log_dir': 'challange2-logs',
        'device': device,
        'pre_trained': {
            'num_classes': 10,
            'hidden_mlp': 2048,
            'feat_dim': 128,
            'num_prototypes': 32,
            'batch_size': 256,
            'data_path': 'image_datasets',
            'size_crops': [224, 96],
            'nmb_crops': [2, 6],
            'min_scale_crops': [0.25, 0.05],
            'max_scale_crops': [1.0, 3.0],
            'size_dataset': 10000,
            'base_lr': 4.8,
            'weight_decay': 0.0005,
            'momentum': 0.9,
            'start_warmup': 0.0,
            'warmup_epochs': 10,
            'final_lr': 0.001,
            'epochs': 5,
            'epoch_queue_starts': 1,
            'crops_for_assign': [0,1],
            'queue_length': 1000
        },
        'fine_tuned': {
            'lr': 0.001,
            'dataset_path': 'image_datasets',
            'num_classes': 10,
            'epochs': 10,
            'batch_size': 32,
            'weight_decay': 0.0005,
            'momentum': 0.9,

        }
    }

    p_config = config['pre_trained']
    f_config = config['fine_tuned']
    # lr = f_config['lr']
    
    t_now = time.time()
    experiment_dir = f'{config["dataset_name"]}_{p_config["epochs"]}_{f_config["epochs"]}_{f_config["batch_size"]}'
    config['log_dir'] = os.path.join(experiment_dir, 'logs')
    config['model_dir'] = os.path.join(experiment_dir, 'models')
    f_name = f'{experiment_dir}_{t_now}__.log'
    
    log_dir = config['log_dir']
    model_dir = config['model_dir']
    f_path = os.path.join(log_dir, f_name)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    f_logger = get_new_logger(f_path)

    # Define the parameters for the multi-crop dataset
    p_config = config['pre_trained']
    data_path = p_config["data_path"]
    size_crops = p_config['size_crops']
    nmb_crops = p_config['nmb_crops']
    min_scale_crops = p_config['min_scale_crops']
    max_scale_crops = p_config['max_scale_crops']
    batch_size = p_config['batch_size']
    size_dataset = p_config['size_dataset']  # number of images to use in the dataset
    return_index = False

    dataset = None
    data_loader = None
    if dataset_name == 'svhn':
        # get the dataset
        dataset = MultiCropSVHN(
            data_path,
            size_crops,
            nmb_crops,
            min_scale_crops,
            max_scale_crops,
            size_dataset=size_dataset,
            return_index=return_index,
            download=True
        )
    
    if dataset_name == 'stl10':
        print('using stl10')
        dataset = MultiCropSTL10(
            data_path,
            size_crops,
            nmb_crops,
            min_scale_crops,
            max_scale_crops,
            size_dataset=size_dataset,
            return_index=return_index,
            download=True
        )
    
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
    # train the model
    model, saved_chk_pt_path = train_pretrained(data_loader, config, f_logger)
    # p_config = config['pre_trained']
    # # build model
    # hidden_mlp = p_config['hidden_mlp']
    # feat_dim = p_config['feat_dim']
    # nmb_prototypes = p_config['num_prototypes'] # default was 3000

    # model = resnet50(normalize=True,
    #         hidden_mlp=hidden_mlp,
    #         output_dim=feat_dim,
    #         nmb_prototypes=nmb_prototypes)
    # saved_chk_pt_path = os.path.join(model_dir, 'checkpoint.pth.tar')
    # model = load_pretrained(model, saved_chk_pt_path)
    train_finetuned(model, config, f_logger)
    


In [8]:
pre_train_and_finetune(dataset_name='stl10')

using stl10
Files already downloaded and verified
cpu
Files already downloaded and verified
Files already downloaded and verified
Num Samples For Training 50 Num Samples For Val 400





Test set: Average loss: 2.2571, Accuracy: 200/400 (50.00%)

Num Samples For Training 50 Num Samples For Val 400

Test set: Average loss: 2.2376, Accuracy: 271/400 (67.75%)

Num Samples For Training 50 Num Samples For Val 400
