In [None]:
import os
import logging
import time
from datetime import timedelta
import pandas as pd


class LogFormatter:
    """
    Format log records as ``LEVEL - date time - elapsed - message``.

    The elapsed part is the time between ``start_time`` and the creation of
    the record, rounded to whole seconds. ``start_time`` is set when the
    formatter is built and can be overwritten to restart the clock.
    """

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        """Return the formatted text for ``record``; "" if its message is empty."""
        text = record.getMessage()
        if not text:
            return ""

        elapsed = timedelta(seconds=round(record.created - self.start_time))
        # the date/time shown is the moment of formatting (locale-dependent)
        header = "{} - {} - {}".format(
            record.levelname, time.strftime("%x %X"), elapsed
        )

        # continuation lines are indented so they line up under the message
        pad = " " * (len(header) + 3)
        body = ("\n" + pad).join(text.split("\n"))
        return "{} - {}".format(header, body)


def create_logger(filepath, rank):
    """
    Configure and return the root logger.

    The root logger's existing handlers are discarded and replaced by:
    - a file handler (DEBUG and above, append mode) when ``filepath`` is not
      None; processes with ``rank > 0`` write to ``"<filepath>-<rank>"`` so
      each process uses its own file;
    - a console handler (INFO and above).

    The returned logger carries a ``reset_time()`` attribute that restarts the
    elapsed-time counter printed in every line.
    """
    formatter = LogFormatter()
    handlers = []

    # optional per-process log file
    if filepath is not None:
        filepath = "%s-%i" % (filepath, rank) if rank > 0 else filepath
        to_file = logging.FileHandler(filepath, "a")
        to_file.setLevel(logging.DEBUG)
        to_file.setFormatter(formatter)
        handlers.append(to_file)

    # console output, less verbose than the file
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(formatter)
    handlers.append(to_console)

    # start from a clean root logger, then attach file (if any) and console
    root = logging.getLogger()
    root.handlers = []
    root.setLevel(logging.DEBUG)
    root.propagate = False
    for handler in handlers:
        root.addHandler(handler)

    def reset_time():
        # restart the elapsed-time clock used by the formatter
        formatter.start_time = time.time()

    root.reset_time = reset_time
    return root


class PD_Stats(object):
    """
    Keep rows of statistics in a pandas DataFrame that is pickled to disk.
    """

    def __init__(self, path, columns):
        """Load the table stored at ``path`` if it exists, else start an empty one."""
        self.path = path

        if not os.path.isfile(self.path):
            # nothing saved yet: empty table with the requested columns
            self.stats = pd.DataFrame(columns=columns)
            return

        # resume from the table written by a previous run
        self.stats = pd.read_pickle(self.path)

        # the saved table must have exactly the requested columns
        assert list(self.stats.columns) == list(columns)

    def update(self, row, save=True):
        """Append ``row`` at the end of the table and, if ``save``, re-pickle it."""
        next_label = len(self.stats.index)
        self.stats.loc[next_label] = row

        if not save:
            return
        self.stats.to_pickle(self.path)

In [None]:
import random
from logging import getLogger

from PIL import ImageFilter
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# getLogger() without a name returns the root logger.
logger = getLogger()


class MultiCropDataset(datasets.ImageFolder):
    """
    ImageFolder restricted to a train / val / test slice of every class, whose
    items are lists of transformed views of a single image.

    - ``split='train'``: multi-crop augmentation, ``nmb_crops[i]`` views of size
      ``size_crops[i]`` with scale range ``(min_scale_crops[i],
      max_scale_crops[i])``; items carry no label.
    - ``split='val'``: one random-resized-crop + flip view (used for supervised
      training); items carry the label.
    - ``split='test'``: one resize + center-crop view; items carry the label.
    - any other ``split``: samples are left untouched and no transform list is
      set.

    The slicing assumes ``self.samples`` lists 1000 classes with 50 images
    each, class after class (ILSVRC2012_img_val layout) -- TODO confirm for
    other datasets.
    """

    def __init__(
        self,
        data_path,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        split=None,
        return_index=False,
    ):
        super().__init__(data_path)
        for per_crop_setting in (size_crops, min_scale_crops, max_scale_crops):
            assert len(per_crop_setting) == len(nmb_crops)

        self.split = split

        # True to ILSVRC2012_img_val
        num_classes = 1000
        samples_per_class = 50
        p_train = int(0.7*samples_per_class)
        p_val = int(0.2*samples_per_class)
        p_test = int(0.1*samples_per_class)

        # [first, last) position of the requested slice inside each class
        if split == 'train':
            window = (0, p_train)
        elif split == 'val':
            window = (p_train, p_train + p_val)
        elif split == 'test':
            window = (p_train + p_val, p_train + p_val + p_test)
        else:
            window = None

        if window is not None:
            first, last = window
            self.samples = [
                self.samples[class_idx * samples_per_class + offset]
                for class_idx in range(num_classes)
                for offset in range(first, last)
            ]
        self.return_index = return_index

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        std = [0.228, 0.224, 0.225]
        trans = []
        if split == 'train':
            for i, crop_size in enumerate(size_crops):
                pipeline = transforms.Compose([
                    transforms.RandomResizedCrop(
                        crop_size,
                        scale=(min_scale_crops[i], max_scale_crops[i]),
                    ),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Compose(color_transform),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=std),
                ])
                # the same pipeline object is reused for every crop of this size
                trans.extend([pipeline] * nmb_crops[i])
            self.trans = trans
        elif split == 'val': # for later supervised training
            trans.append(transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]))
            self.trans = trans
        elif split == 'test': # for evaluation
            trans.append(transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]))
            self.trans = trans

    def __getitem__(self, index):
        """
        Return the views of sample ``index``.

        'train' split: ``views`` (or ``(index, views)`` with ``return_index``).
        Other splits: ``(views, target)`` (or ``(index, views, target)``).
        """
        path, target = self.samples[index]
        image = self.loader(path)
        views = [view(image) for view in self.trans]

        if self.split == 'train':
            return (index, views) if self.return_index else views
        if self.return_index:
            return index, views, target
        return views, target



class PILRandomGaussianBlur(object):
    """
    Randomly blur a PIL image with a Gaussian filter.

    With probability ``p`` the image is blurred with a radius drawn uniformly
    from ``[radius_min, radius_max]``; otherwise it is returned unchanged.
    This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob, self.radius_min, self.radius_max = p, radius_min, radius_max

    def __call__(self, img):
        # the coin flip uses numpy's RNG, the radius uses Python's `random`
        skip = not (np.random.rand() <= self.prob)
        if skip:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))


def get_color_distortion(s=1.0):
    """
    Build the color augmentation: color jitter applied with probability 0.8,
    followed by grayscale conversion with probability 0.2.

    ``s`` is the strength of color distortion; it scales the four ColorJitter
    arguments (0.8, 0.8, 0.8 and 0.2).
    """
    jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])

In [None]:
import argparse
from logging import getLogger
import pickle
import os

import numpy as np
import torch

import torch.distributed as dist

FALSY_STRINGS = {"off", "false", "0"}
TRUTHY_STRINGS = {"on", "true", "1"}


logger = getLogger()


def bool_flag(s):
    """
    Convert a command-line string to a bool (case-insensitive).

    "off"/"false"/"0" give False and "on"/"true"/"1" give True; anything else
    raises argparse.ArgumentTypeError.
    """
    token = s.lower()
    if token in FALSY_STRINGS:
        return False
    if token in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("invalid value for a boolean flag")


def init_distributed_mode(args): # Modified
    """
    Initialize the following variables:
        - world_size
        - rank
        - gpu_to_work_on

    Also creates the default NCCL process group (unless one already exists)
    and selects CUDA device 0 for this process. Returns ``args``.

    Raises:
        Whatever ``dist.init_process_group`` raises when the group cannot be
        created. Errors are no longer swallowed, because the rest of the
        pipeline cannot work without a process group.
    """

    args.rank = 0  # Since it's single-node training, set rank to 0
    args.world_size = torch.cuda.device_count()  # Set world_size to the number of available GPUs
    # NOTE(review): with more than one visible GPU, world_size is > 1 while only
    # rank 0 is started here -- confirm the other ranks are launched elsewhere,
    # otherwise init_process_group would wait for them.

    # Ask torch whether a default group exists instead of matching on the text
    # of the "initialized twice" exception, which is not a stable contract.
    if dist.is_initialized():
        print('default process group is already initialized')
    else:
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://localhost:12345",
            world_size=args.world_size,
            rank=args.rank,
        )

    # Set cuda device
    args.gpu_to_work_on = 0  # Set the GPU device to use (e.g., 0 for the first GPU)
    torch.cuda.set_device(args.gpu_to_work_on)

    return args

# # With TPU
# def init_distributed_mode(args):
#     """
#     Initialize the following variables:
#         - world_size
#         - rank
#     """

#     args.rank = 0  # Since it's single-node training, set rank to 0

#     if hasattr(args, 'tpu') and args.tpu:  # Check if TPU is enabled # Find another way
#         import torch_xla.core.xla_model as xm

#         args.world_size = xm.xrt_world_size()  # Set world_size to the number of available TPUs
#         args.rank = xm.get_ordinal()  # Set rank to the current TPU ordinal

#         # Initialize the TPU device
#         device = xm.xla_device()
#         torch.set_default_tensor_type("torch.FloatTensor")
#     else:
#         args.world_size = torch.cuda.device_count()  # Set world_size to the number of available GPUs

#         # Try to initialize the distributed process group
#         try:
#             dist.init_process_group(
#                 backend="nccl",
#                 init_method="tcp://localhost:12345",
#                 world_size=args.world_size,
#                 rank=args.rank,
#             )
#         except RuntimeError as e:
#             # Handle the case when the default process group is already initialized
#             if "trying to initialize the default process group twice" in str(e):
#                 print('default process group is already initialized')
#                 pass  # Ignore the error and continue

#         # Set the CUDA device
#         args.gpu_to_work_on = 0  # Set the GPU device to use (e.g., 0 for the first GPU)
#         torch.cuda.set_device(args.gpu_to_work_on)
#         device = torch.device("cuda")

#     return args, device


def initialize_exp(params, *args, dump_params=True):
    """
    Initialize the experiment:
    - dump parameters
    - create checkpoint repo
    - create a logger
    - create a panda object to keep track of the training statistics

    Args:
        params: object with at least ``dump_path`` and ``rank`` attributes;
            ``dump_checkpoints`` is set on it here.
        *args: column names of the training statistics table.
        dump_params: when True, pickle ``params`` to ``<dump_path>/params.pkl``.

    Returns:
        ``(logger, training_stats)``.
    """

    # dump parameters (context manager so the file is flushed and closed)
    if dump_params:
        with open(os.path.join(params.dump_path, "params.pkl"), "wb") as params_file:
            pickle.dump(params, params_file)

    # create repo to store checkpoints (only the process with rank 0 creates it)
    params.dump_checkpoints = os.path.join(params.dump_path, "checkpoints")
    if not params.rank and not os.path.isdir(params.dump_checkpoints):
        os.mkdir(params.dump_checkpoints)

    # create a panda object to log loss and acc (one file per rank)
    training_stats = PD_Stats(
        os.path.join(params.dump_path, "stats" + str(params.rank) + ".pkl"), args
    )

    # create a logger
    logger = create_logger(
        os.path.join(params.dump_path, "train.log"), rank=params.rank
    )
    logger.info("============ Initialized logger ============")
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items()))
    )
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("")
    return logger, training_stats


def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs):
    """
    Re-start from checkpoint

    Args:
        ckp_paths: a single checkpoint path, or a list of candidate paths of
            which the first existing file is used.
        run_variables: optional dict; every key that is also present in the
            checkpoint is overwritten in place with the checkpoint's value.
        **kwargs: maps a checkpoint key to the object that should receive it
            through ``load_state_dict`` (example: ``state_dict=model``).

    Returns:
        None. Nothing is loaded when no checkpoint file is found, including
        when ``ckp_paths`` is an empty list.
    """
    # look for a checkpoint in exp repository
    if isinstance(ckp_paths, list):
        # first existing candidate; None when the list is empty or nothing exists
        ckp_path = next((path for path in ckp_paths if os.path.isfile(path)), None)
    else:
        ckp_path = ckp_paths

    if ckp_path is None or not os.path.isfile(ckp_path):
        return

    logger.info("Found checkpoint at {}".format(ckp_path))

    # open checkpoint file
    checkpoint = torch.load(
        ckp_path, map_location="cuda:" + str(torch.distributed.get_rank() % torch.cuda.device_count())
    )

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print(msg)
            except TypeError:
                # load_state_dict of this object does not accept `strict`
                value.load_state_dict(checkpoint[key])
            logger.info("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
        else:
            logger.warning(
                "=> failed to load {} from checkpoint '{}'".format(key, ckp_path)
            )

    # re load variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]


def fix_random_seeds(seed=31):
    """
    Fix random seeds.

    Seeds Python's ``random``, torch (CPU and all CUDA devices) and numpy with
    ``seed``. Python's ``random`` is included because the blur augmentation in
    this file draws its radius from ``random.uniform``.
    """
    import random  # local import: this section of the file does not import it

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


class AverageMeter(object):
    """Track the latest value, the weighted sum, the count and the running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all tracked quantities back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k

    Returns one 1-element tensor per entry of ``topk``, holding the percentage
    of rows of ``output`` whose ``target`` is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        percent_per_hit = 100.0 / target.size(0)

        # (maxk, batch): row r holds the class ranked r-th for every sample
        ranked = output.topk(maxk, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

In [None]:
import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (bias-free; padding is set to ``dilation``)"""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (bias-free)"""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """
    Residual block made of two 3x3 convolutions, each followed by a norm layer.

    Only ``groups=1``, ``base_width=64`` and ``dilation<=1`` are accepted.
    ``downsample``, when given, is applied to the input before it is added to
    the output of the second norm layer.
    """

    expansion = 1
    __constants__ = ["downsample"]

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # shortcut branch: the input itself, or its downsampled version
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """
    Residual block made of a 1x1, a 3x3 and a 1x1 convolution, each followed
    by a norm layer; the last convolution outputs ``planes * expansion``
    channels.

    ``downsample``, when given, is applied to the input before it is added to
    the output of the last norm layer.
    """

    expansion = 4
    __constants__ = ["downsample"]

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        # channel count of the inner 3x3 convolution
        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/norm stages, add the shortcut and apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # shortcut branch: the input itself, or its downsampled version
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """
    ResNet trunk with an optional projection head and optional prototype layer(s).

    ``forward`` accepts a tensor or a list of tensors. Consecutive inputs with
    the same last-dimension size are concatenated and passed through the trunk
    together; the trunk outputs are concatenated and passed through the head.
    """

    def __init__(
            self,
            block,
            layers,
            zero_init_residual=False,
            groups=1,
            widen=1,
            width_per_group=64,
            replace_stride_with_dilation=None,
            norm_layer=None,
            normalize=False,
            output_dim=0,
            hidden_mlp=0,
            nmb_prototypes=0,
            eval_mode=False,
    ):
        """
        Args:
            block: residual block class; must expose ``expansion``.
            layers: number of blocks in each of the four stages.
            zero_init_residual: zero the weight of the last norm layer of
                every Bottleneck / BasicBlock.
            groups: passed to every block.
            widen: multiplier applied to ``width_per_group`` to get the
                channel count of the stem and of the first stage.
            width_per_group: base width, also passed to every block.
            replace_stride_with_dilation: None or 3 booleans, one per stage
                2-4; True replaces that stage's stride by dilation.
            norm_layer: normalization layer class (default nn.BatchNorm2d).
            normalize: L2-normalize the head output.
            output_dim: output size of the projection head; 0 disables it.
            hidden_mlp: hidden size of the projection head; 0 makes the head
                a single linear layer.
            nmb_prototypes: int (> 0 adds one bias-free linear prototype
                layer) or list of ints (one prototype layer per entry).
            eval_mode: make ``forward_backbone`` return the output of the
                last stage without pooling or flattening.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have
                exactly 3 elements.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.eval_mode = eval_mode
        # zero padding of 1 pixel on every side, applied before conv1
        self.padding = nn.ConstantPad2d(1, 0.0)

        # channel count fed to the next block; updated by _make_layer
        self.inplanes = width_per_group * widen
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group

        # change padding 3 -> 2 compared to original torchvision code because added a padding layer
        num_out_filters = width_per_group * widen
        self.conv1 = nn.Conv2d(
            3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False
        )
        self.bn1 = norm_layer(num_out_filters)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # four stages; the `planes` argument doubles from one stage to the next
        self.layer1 = self._make_layer(block, num_out_filters, layers[0])
        num_out_filters *= 2
        self.layer2 = self._make_layer(
            block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        num_out_filters *= 2
        self.layer3 = self._make_layer(
            block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        num_out_filters *= 2
        self.layer4 = self._make_layer(
            block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # normalize output features
        self.l2norm = normalize

        # projection head
        # (none, a single linear layer, or Linear -> BatchNorm1d -> ReLU -> Linear)
        if output_dim == 0:
            self.projection_head = None
        elif hidden_mlp == 0:
            self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim)
        else:
            self.projection_head = nn.Sequential(
                nn.Linear(num_out_filters * block.expansion, hidden_mlp),
                nn.BatchNorm1d(hidden_mlp),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_mlp, output_dim),
            )

        # prototype layer
        self.prototypes = None
        if isinstance(nmb_prototypes, list):
            self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
        elif nmb_prototypes > 0:
            self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)

        # weight initialisation of every convolution and 2d norm layer
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """
        Build one stage: ``blocks`` instances of ``block`` in an nn.Sequential.

        Only the first block receives ``stride`` and the downsample branch.
        Updates ``self.inplanes`` and, when ``dilate`` is True,
        ``self.dilation`` (the stride is then replaced by 1).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # a 1x1 conv + norm on the shortcut is needed whenever the first block
        # changes the resolution or the number of channels
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # the first block uses the dilation in effect before this stage
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def forward_backbone(self, x):
        """
        Run the trunk on ``x``.

        Returns the output of ``layer4`` when ``eval_mode`` is set, otherwise
        that output average-pooled to 1x1 and flattened from dimension 1.
        """
        x = self.padding(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.eval_mode:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

    def forward_head(self, x):
        """
        Apply the projection head (if any), then L2 normalization (if enabled).

        Returns ``(x, self.prototypes(x))`` when a prototype layer exists,
        otherwise ``x``.
        """
        if self.projection_head is not None:
            x = self.projection_head(x)

        if self.l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)

        if self.prototypes is not None:
            return x, self.prototypes(x)
        return x

    def forward(self, inputs):
        """
        Run trunk and head on a tensor or on a list of tensors.

        Inputs are moved to the current CUDA device before the trunk is
        applied, so a CUDA device is required.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        # end index (exclusive) of each run of consecutive inputs that share
        # the same last-dimension size
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in inputs]),
            return_counts=True,
        )[1], 0)
        start_idx = 0
        for end_idx in idx_crops:
            # one trunk pass per run of same-size inputs
            _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        return self.forward_head(output)


class MultiPrototypes(nn.Module):
    """
    Several bias-free linear prototype layers applied to the same input.

    ``nmb_prototypes[i]`` is the output size of the layer registered as
    ``prototypes<i>``.
    """

    def __init__(self, output_dim, nmb_prototypes):
        super().__init__()
        self.nmb_heads = len(nmb_prototypes)
        for head_idx, head_size in enumerate(nmb_prototypes):
            head = nn.Linear(output_dim, head_size, bias=False)
            self.add_module("prototypes{}".format(head_idx), head)

    def forward(self, x):
        """Return the list of outputs of every prototype layer, in order."""
        return [
            getattr(self, "prototypes{}".format(head_idx))(x)
            for head_idx in range(self.nmb_heads)
        ]


def resnet50(**kwargs):
    """Build a ResNet of Bottleneck blocks with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def resnet50w2(**kwargs):
    """Build a ResNet of Bottleneck blocks (stage depths 3-4-6-3) with widen=2."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, widen=2, **kwargs)


def resnet50w4(**kwargs):
    """Build a ResNet of Bottleneck blocks (stage depths 3-4-6-3) with widen=4."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, widen=4, **kwargs)


def resnet50w5(**kwargs):
    """Build a ResNet of Bottleneck blocks (stage depths 3-4-6-3) with widen=5."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, widen=5, **kwargs)

In [None]:
import os
import time
from logging import getLogger
import urllib

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# getLogger() without a name returns the root logger.
logger = getLogger()

In [None]:
class Args:
    """Hard-coded settings of the evaluation run, exposed as instance attributes."""

    def __init__(self):
        # ---------------- main parameters ----------------
        # fine-tune on either 1% or 10% of labels
        self.labels_perc = "10"
        # experiment dump path for checkpoints and log; change to the path to
        # save results (swav, swav_multicrop, deepcluster, simCLR)
        self.dump_path = "/content/drive/simCLR"
        # seed
        self.seed = 31
        # path to imagenet
        self.data_path = '/content/drive/ILSVRC2012_img_val'
        # number of data loading workers
        self.workers = 10

        # ---------------- model parameters ----------------
        # convnet architecture
        self.arch = "resnet50"
        # path to pretrained weights; change to the path of .tar file for
        # evaluation (swav, swav_multicrop, deepcluster, simCLR)
        # (might need to remove /checkpoint.pth.tar)
        self.pretrained = "/content/drive/simCLR/checkpoint.pth.tar"

        # ---------------- optim parameters ----------------
        # number of total epochs to run
        self.epochs = 50
        # batch size per gpu, i.e. how many unique instances per gpu
        self.batch_size = 32
        # initial learning rate - trunk
        self.lr = 0.01
        # initial learning rate - head
        self.lr_last_layer = 0.2
        # Epochs at which to decay learning rate.
        self.decay_epochs = [12, 16]
        # lr decay factor
        self.gamma = 0.2

        # ---------------- dist parameters ----------------
        # url used to set up distributed training
        self.dist_url = "env://"
        # number of processes: it is set automatically and should not be passed as an argument
        self.world_size = -1
        # rank of this process: it is set automatically and should not be passed as an argument
        self.rank = 0
        # this argument is not used and should be ignored
        self.local_rank = 0


# Experiment setup: build the settings object, set up the process group and
# CUDA device, seed the RNGs, then create the logger and the statistics table.
# NOTE: a `global` statement at module level has no effect.
global eval_args, best_acc
eval_args = Args()
init_distributed_mode(eval_args)
fix_random_seeds(eval_args.seed)
# the strings below become the column names of the training statistics table
logger, training_stats = initialize_exp(
    eval_args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val"
)

In [None]:
# Build the supervised training set from the 'val' slice of every class
sup_train_set = MultiCropDataset(
    data_path=eval_args.data_path,
    # ---filler inputs----
    # (the crop settings are only read for split='train')
    size_crops=[1],
    nmb_crops=[1],
    min_scale_crops=[1],
    max_scale_crops=[1],
    # --------------------
    split='val'
)
sampler = torch.utils.data.distributed.DistributedSampler(sup_train_set)
sup_train_loader = torch.utils.data.DataLoader(
    sup_train_set,
    sampler=sampler,
    batch_size=eval_args.batch_size,
    num_workers=eval_args.workers,
    pin_memory=True,
)

# Build the evaluation set from the 'test' slice of every class
val_set = MultiCropDataset(
    data_path=eval_args.data_path,
    # ---filler inputs----
    # (the crop settings are only read for split='train')
    size_crops=[1],
    nmb_crops=[1],
    min_scale_crops=[1],
    max_scale_crops=[1],
    # --------------------
    split='test'
)

# no sampler here: the evaluation loader iterates over the whole set
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=eval_args.batch_size,
    num_workers=eval_args.workers,
    pin_memory=True,
)

# first line reports the training set size, second line the evaluation set size
logger.info("Building data done with {} images loaded.".format(len(sup_train_set)))
logger.info("Building data done with {} images loaded.".format(len(val_set)))

In [None]:
# build model
# Map every supported architecture name to its constructor so that an
# unsupported name fails here with a clear message instead of leaving `model`
# undefined and failing later with a NameError.
arch_builders = {
    'resnet50': resnet50,
    'resnet50w2': resnet50w2,
    'resnet50w4': resnet50w4,
    'resnet50w5': resnet50w5,
}
if eval_args.arch not in arch_builders:
    raise ValueError(
        "unsupported arch '{}', expected one of {}".format(
            eval_args.arch, sorted(arch_builders)
        )
    )
# the projection head doubles as the 1000-way classifier
model = arch_builders[eval_args.arch](output_dim=1000)

# convert batch norm layers
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# load weights
if os.path.isfile(eval_args.pretrained):
    state_dict = torch.load(eval_args.pretrained, map_location="cuda:" + str(eval_args.gpu_to_work_on))
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    # remove prefixe "module."
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    for k, v in model.state_dict().items():
        # direct dict membership test (no temporary list per key)
        if k not in state_dict:
            logger.info('key "{}" could not be found in provided state dict'.format(k))
        elif state_dict[k].shape != v.shape:
            logger.info('key "{}" is of different shape in model and provided state dict'.format(k))
            # keep the model's own value for tensors whose shape does not match
            state_dict[k] = v
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Load pretrained model with msg: {}".format(msg))
else:
    logger.info("No pretrained weights found => training from random weights")

# model to gpu
model = model.cuda()
model = nn.parallel.DistributedDataParallel(
    model,
    device_ids=[eval_args.gpu_to_work_on],
    find_unused_parameters=True,
)

In [None]:
# set optimizer
# Parameters whose name contains 'head' get their own learning rate
# (lr_last_layer); all other parameters use the default lr.
trunk_parameters = []
head_parameters = []
for name, param in model.named_parameters():
    if 'head' in name:
        head_parameters.append(param)
    else:
        trunk_parameters.append(param)
# param group 0 = trunk, param group 1 = head
optimizer = torch.optim.SGD(
    [{'params': trunk_parameters},
     {'params': head_parameters, 'lr': eval_args.lr_last_layer}],
    lr=eval_args.lr,
    momentum=0.9,
    weight_decay=0,
)
# set scheduler
# multiplies the learning rates by gamma at each epoch listed in decay_epochs
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, eval_args.decay_epochs, gamma=eval_args.gamma
)

In [None]:
# Optionally resume from a checkpoint
# `to_restore` holds the defaults; restart_from_checkpoint overwrites the
# entries that are present in <dump_path>/checkpoint.pth.tar, if that file exists.
to_restore = {"epoch": 0, "best_acc": (0., 0.)}
restart_from_checkpoint(
    os.path.join(eval_args.dump_path, "checkpoint.pth.tar"),
    run_variables=to_restore,
    state_dict=model,
    optimizer=optimizer,
    scheduler=scheduler,
)
start_epoch = to_restore["epoch"]
best_acc = to_restore["best_acc"]
cudnn.benchmark = True

In [None]:
def train(model, optimizer, loader, epoch):
    """
    Run one training epoch of `model` over `loader`.

    Each batch is moved to the GPU, passed through the model, scored with
    cross entropy, and followed by one optimizer step. Progress is logged
    every 50 iterations on rank 0 (uses the module-level `eval_args` and
    `logger`).

    Returns the tuple (epoch, average loss, average top-1, average top-5).
    """
    # timing meters
    step_timer, load_timer = AverageMeter(), AverageMeter()

    # metric meters
    top1_meter, top5_meter, loss_meter = AverageMeter(), AverageMeter(), AverageMeter()
    tick = time.perf_counter()

    model.train()
    criterion = nn.CrossEntropyLoss().cuda()

    progress_template = (
        "Epoch[{0}] - Iter: [{1}/{2}]\t"
        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
        "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
        "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
        "Prec {top1.val:.3f} ({top1.avg:.3f})\t"
        "LR trunk {lr}\t"
        "LR head {lr_W}"
    )

    for step, (views, labels) in enumerate(loader):
        # time spent waiting for the batch
        load_timer.update(time.perf_counter() - tick)

        # only the first element of the loader's input is used
        images = views[0].cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        n_samples = images.size(0)

        # forward pass and loss
        logits = model(images)
        loss = criterion(logits, labels)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate metrics, weighted by batch size
        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(acc1[0], n_samples)
        top5_meter.update(acc5[0], n_samples)

        step_timer.update(time.perf_counter() - tick)
        tick = time.perf_counter()

        # only rank 0 reports, and only every 50th iteration
        if eval_args.rank != 0 or step % 50 != 0:
            continue
        groups = optimizer.param_groups
        logger.info(
            progress_template.format(
                epoch,
                step,
                len(loader),
                batch_time=step_timer,
                data_time=load_timer,
                loss=loss_meter,
                top1=top1_meter,
                lr=groups[0]["lr"],
                lr_W=groups[1]["lr"],
            )
        )

    return epoch, loss_meter.avg, top1_meter.avg.item(), top5_meter.avg.item()


def validate_network(val_loader, model):
    """
    Evaluate `model` on `val_loader` with gradients disabled.

    Updates the module-level `best_acc` with (top-1, top-5) whenever the
    average top-1 of this run exceeds the stored best, and logs a summary
    on rank 0 (uses the module-level `eval_args` and `logger`).

    Returns the tuple (average loss, average top-1, average top-5).
    """
    global best_acc

    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    # evaluation mode
    model.eval()

    criterion = nn.CrossEntropyLoss().cuda()

    with torch.no_grad():
        tick = time.perf_counter()
        for views, labels in val_loader:
            # only the first element of the loader's input is used
            images = views[0].cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            n_samples = images.size(0)

            # forward pass and loss
            logits = model(images)
            loss = criterion(logits, labels)

            # accumulate metrics, weighted by batch size
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(acc1[0], n_samples)
            top5_meter.update(acc5[0], n_samples)

            # per-batch wall time
            timer.update(time.perf_counter() - tick)
            tick = time.perf_counter()

    avg_top1 = top1_meter.avg.item()
    avg_top5 = top5_meter.avg.item()

    # keep the (top-1, top-5) pair of the best top-1 seen so far
    if avg_top1 > best_acc[0]:
        best_acc = (avg_top1, avg_top5)

    if eval_args.rank == 0:
        summary = (
            "Test:\t"
            "Time {batch_time.avg:.3f}\t"
            "Loss {loss.avg:.4f}\t"
            "Acc@1 {top1.avg:.3f}\t"
            "Best Acc@1 so far {acc:.1f}"
        )
        logger.info(
            summary.format(
                batch_time=timer, loss=loss_meter, top1=top1_meter, acc=best_acc[0]
            )
        )

    return loss_meter.avg, avg_top1, avg_top5

In [None]:
for epoch in range(start_epoch, eval_args.epochs):

    # announce the epoch
    logger.info("============ Starting epoch %i ... ============" % epoch)

    # tell the training sampler which epoch is about to run
    sup_train_loader.sampler.set_epoch(epoch)

    # one training pass, then evaluation; both result tuples are recorded together
    scores = train(model, optimizer, sup_train_loader, epoch)
    scores_val = validate_network(val_loader, model)
    training_stats.update(scores + scores_val)

    # advance the learning-rate schedule once per epoch
    scheduler.step()

    # only rank 0 writes the checkpoint
    if eval_args.rank != 0:
        continue
    save_dict = dict(
        epoch=epoch + 1,  # epoch to resume from
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
        best_acc=best_acc,
    )
    torch.save(save_dict, os.path.join(eval_args.dump_path, "checkpoint.pth.tar"))

# final report with the best accuracies recorded during the run
logger.info(
    "Fine-tuning with {}% of labels completed.\n"
    "Test accuracies: top-1 {acc1:.1f}, top-5 {acc5:.1f}".format(
        eval_args.labels_perc,
        acc1=best_acc[0],
        acc5=best_acc[1],
    )
)