 # **Challenge 1**

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import numpy as np

# define datasets and loaders

In [2]:
# Install necessary libraries
!pip install requests tqdm pillow

import pickle
import os
import numpy as np
from tqdm import tqdm
import requests
import tarfile
from PIL import Image
import shutil  # Importing shutil here

import os.path as osp
import os
from torch.utils.data import Dataset
from torchvision import transforms
import torch
import numpy as np




In [3]:
def download_file(url, filename):
    """
    Stream `url` to `filename` with a progress bar and return `filename`.

    The payload is written to `<filename>.part` and renamed only once the transfer
    has completed, so an interrupted download never leaves a truncated file under
    the final name (callers use the existence of `filename` as "already downloaded").

    Raises `requests.HTTPError` when the server answers with an error status.
    """
    chunkSize = 1024
    partial = f"{filename}.part"
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        # Content-Length is optional (e.g. chunked responses); tqdm accepts total=None
        length = r.headers.get('Content-Length')
        total = int(length) if length is not None else None
        with open(partial, 'wb') as f, tqdm(unit="B", total=total) as pbar:
            for chunk in r.iter_content(chunk_size=chunkSize):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    pbar.update(len(chunk))
    os.replace(partial, filename)
    return filename

# Names of the downloaded archive, the folder it unpacks to, and the generated image tree
tarname = "cifar-10-python.tar.gz"
datapath = "cifar-10-batches-py"
output_root = 'cifar-fs'

if not os.path.exists(tarname):
    print("Downloading cifar-10-python.tar.gz\n")
    download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', tarname)
    print("Downloading done.\n")
else:
    print("Dataset already downloaded. Did not download twice.\n")

# Unpack the tar file
print("Untarring: {}".format(tarname))
with tarfile.open(tarname) as tar:
    # Use the restrictive 'data' extraction filter on Python versions that provide it
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(filter='data')
    else:
        tar.extractall()

print("Extracting jpg images and classes from pickle files")
# in CIFAR 10, the files are given in multiple batch files for training
batches = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', 'test_batch']
# NOTE(review): pickle.load can execute arbitrary code; only run this on an archive from a trusted source
with open(os.path.join(datapath, 'batches.meta'), 'rb') as meta_file:
    labels = pickle.load(meta_file, encoding="ASCII")

# Start from an empty output tree so a re-run never mixes in files from an earlier run
shutil.rmtree(output_root, ignore_errors=True)
for split in ('train', 'val', 'test'):
    os.makedirs(os.path.join(output_root, split), exist_ok=True)

for i, batch in enumerate(batches):
    print("Handling pickle file: {}".format(batch))
    fpath = os.path.join(datapath, batch)
    with open(fpath, 'rb') as f:
        d = pickle.load(f, encoding='bytes')
    is_test_batch = i == len(batches) - 1  # Last batch is test set
    for j, (img, label) in enumerate(zip(d[b'data'], d[b'labels'])):
        img = img.reshape(3, 32, 32).transpose(1, 2, 0)  # Convert from CHW to HWC
        img_folder = labels['label_names'][label]
        if is_test_batch:
            dataset_type = 'test'
        elif j < 1000:  # First 1000 images of every train batch are used as val
            dataset_type = 'val'
        else:
            dataset_type = 'train'
        img_path = os.path.join(output_root, dataset_type, img_folder)
        os.makedirs(img_path, exist_ok=True)
        # `j` restarts at 0 in every batch, so the batch name must be part of the file name;
        # otherwise same-class images from different batches overwrite each other.
        img_filename = os.path.join(img_path, f'{batch}_{j}.jpg')
        Image.fromarray(img).save(img_filename)

print("Cleaning up downloaded files")
os.remove(tarname)
shutil.rmtree(datapath, ignore_errors=True)
print("Setup complete.")


Downloading cifar-10-python.tar.gz



100%|██████████| 170498071/170498071 [00:05<00:00, 31868428.22B/s]


Downloading done.

Untarring: cifar-10-python.tar.gz
Extracting jpg images and classes from pickle files
Handling pickle file: data_batch_1
Handling pickle file: data_batch_2
Handling pickle file: data_batch_3
Handling pickle file: data_batch_4
Handling pickle file: data_batch_5
Handling pickle file: test_batch
Cleaning up downloaded files
Setup complete.


In [4]:
class CategoriesSampler:
    """
    Batch sampler that yields one few-shot episode (a flat tensor of dataset indices) per step.

    Every episode holds `num_way` classes with `num_shot + num_query` samples each.
    The indices are interleaved by class: [c0, c1, ..., c0, c1, ...].

    :param set_name - Name of the split; stored only.
    :param labels - Integer class label of every dataset item.
    :param num_episodes - Number of episodes produced per pass.
    :param num_way - Classes per episode.
    :param num_shot - Support samples per class.
    :param num_query - Query samples per class.
    :param const_loader - Pre-sample the episodes once and replay them on every pass.
    :param replace - Draw the episode classes with replacement (a class may then appear more than once).
    """

    def __init__(self, set_name, labels, num_episodes,
                 num_way, num_shot, num_query, const_loader, replace=True):

        self.set_name = set_name
        self.num_way = num_way
        self.num_shot = num_shot
        self.num_query = num_query
        self.num_episodes = num_episodes
        self.const_loader = const_loader   # same tasks in different epochs. good for validation
        self.replace = replace             # sample episode classes with replacement

        self.m_ind = []
        self.batches = []

        # m_ind[c] holds the dataset indices of every item labelled c
        labels = np.array(labels)
        for i in range(max(labels) + 1):
            ind = np.argwhere(labels == i).reshape(-1)
            self.m_ind.append(torch.from_numpy(ind))

        self.classes = np.arange(len(self.m_ind))

        if self.const_loader:
            self.batches = [self._sample_episode() for _ in range(self.num_episodes)]

    def _sample_episode(self):
        """Draw one episode using numpy's global RNG and return its flat index tensor."""
        per_class = self.num_shot + self.num_query
        classes = np.random.choice(self.classes, size=self.num_way, replace=self.replace)
        rows = []
        for c in classes:
            pool = self.m_ind[c]
            # samples within a class are always drawn without replacement
            pos = np.random.choice(np.arange(pool.shape[0]), size=per_class, replace=False)
            rows.append(pool[pos])
        # (way, shot + query) -> (shot + query, way) -> flat, i.e. interleaved by class
        return torch.stack(rows).t().reshape(-1)

    def __len__(self):
        return self.num_episodes

    def __iter__(self):
        if self.const_loader:
            for batch_idx in range(self.num_episodes):
                yield self.batches[batch_idx]
        else:
            for batch_idx in range(self.num_episodes):
                yield self._sample_episode()



class CIFAR(Dataset):
    """
    Image-folder dataset: every sub-directory of `<data_path>/<setname>` is one class.

    Items are returned as `(image_tensor, label, path)`. Class ids follow the sorted
    order of the class directory names, so they are stable across runs and filesystems.

    :param data_path - Root folder that contains the split folders.
    :param setname - Split folder to read.
    :param backbone - Accepted for interface compatibility; not used by this class.
    :param augment - Apply the random train-time augmentations (only honoured for setname == 'train').
    """

    def __init__(self, data_path: str, setname: str, backbone: str, augment: bool):
        self.data, self.label = self._scan_class_dirs(osp.join(data_path, setname))

        mean = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
        std = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
        normalize = transforms.Normalize(mean=mean, std=std)

        self.image_size = 32
        if augment and setname == 'train':
            transforms_list = [
                transforms.RandomResizedCrop(self.image_size),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        else:
            transforms_list = [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
            ]

        self.transform = transforms.Compose(
            transforms_list + [normalize]
        )

    @staticmethod
    def _scan_class_dirs(root):
        """Return (paths, labels) for every file found in the class sub-directories of `root`."""
        class_dirs = sorted(
            entry for entry in (osp.join(root, name) for name in os.listdir(root))
            if osp.isdir(entry)
        )
        data, label = [], []
        for lb, class_dir in enumerate(class_dirs):
            # sorted so that the index -> image mapping is reproducible as well
            for image_name in sorted(os.listdir(class_dir)):
                data.append(osp.join(class_dir, image_name))
                label.append(lb)
        return data, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        # close the file handle as soon as the pixels have been converted
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        return image, label, path


def get_transform(img_size: int, split_name: str):
    """
    Build the torchvision preprocessing pipeline for one split.

    :param img_size - Side length of the square images produced by the pipeline.
    :param split_name - 'train' enables the random augmentations; any other value gives
                        the deterministic resize pipeline.
    :return - A `transforms.Compose` ending in `ToTensor` and channel normalization.
    """
    mean = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
    std = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
    normalize = transforms.Normalize(mean=mean, std=std)

    if split_name == 'train':
        return transforms.Compose([
            # Resize first so the crop below works for any requested img_size and the
            # train output has the same size as the eval output.
            transforms.Resize((img_size, img_size)),
            transforms.RandomCrop(img_size, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])

    else:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            normalize
        ])

In [5]:
# Registry of dataset classes by name.
# NOTE: this rebinds the name `datasets` that the first cell imported from torchvision.
datasets = dict(cifar=CIFAR)


def get_dataloader(set_name: str, num_episodes, train_way, val_way, num_shot, num_query, data_path,
                   dataset: str = "cifar", backbone='resnet12', constant: bool = False,
                   augment: bool = False):
    """
    Build an episodic DataLoader for few-shot classification.

    :param set_name - Split to load; 'train' selects `train_way`, anything else `val_way`.
    :param num_episodes - Episodes per pass over the loader.
    :param train_way - Classes per episode for the train split.
    :param val_way - Classes per episode for every other split.
    :param num_shot - Support samples per class.
    :param num_query - Query samples per class.
    :param data_path - Root folder handed to the dataset class.
    :param dataset - Key into the `datasets` registry (case-insensitive).
    :param backbone - Forwarded to the dataset class.
    :param constant - Replay the same pre-sampled episodes on every pass; also turns pin_memory off.
    :param augment - Request train-time augmentation (only forwarded for the train split).
    :return - DataLoader driven by a CategoriesSampler batch sampler.
    """
    is_train = set_name == 'train'
    episode_way = train_way if is_train else val_way

    # dataset for the requested split
    dataset_cls = datasets[dataset.lower()]
    data_set = dataset_cls(data_path, set_name, backbone, augment=is_train and augment)

    # every batch produced by the sampler is one complete episode
    episode_sampler = CategoriesSampler(
        set_name,
        data_set.label,
        num_episodes,
        num_way=episode_way,
        num_shot=num_shot,
        num_query=num_query,
        const_loader=constant,
    )
    return DataLoader(data_set, batch_sampler=episode_sampler, pin_memory=not constant)

In [6]:
# The setup cell writes the images to ./cifar-fs relative to the working directory,
# so resolve that folder instead of hard-coding a machine-specific absolute path.
dataset_path = os.path.abspath('cifar-fs')

# Settings shared by the train and validation loaders
episode_config = dict(
    train_way=2,
    val_way=2,
    num_shot=5,
    num_query=20,
    data_path=dataset_path,
    dataset="cifar",
    backbone='resnet12',
    augment=False,
)

# training episodes are re-sampled on every pass
train_loader = get_dataloader(set_name='train', num_episodes=1, constant=False, **episode_config)

# validation replays the same pre-sampled episodes on every pass
val_loader = get_dataloader(set_name='val', num_episodes=8, constant=True, **episode_config)


# define model

In [7]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

In [8]:
class DropBlock(nn.Module):
    """
    DropBlock regularization: zero out square `block_size` x `block_size` regions of a feature map.

    Only active in training mode; in eval mode the input is returned unchanged.
    """

    def __init__(self, block_size):
        super(DropBlock, self).__init__()

        self.block_size = block_size

    def forward(self, x, gamma):
        """
        :param x - Feature map of shape (bsize, channels, height, width).
        :param gamma - Bernoulli probability for a position to become the corner of a dropped block.
        :return - x with the dropped blocks zeroed and the kept units rescaled by (#units / #kept units).
        """
        if not self.training:
            return x

        batch_size, channels, height, width = x.shape
        # Block corners are sampled on the reduced grid where a full block still fits
        seed_shape = (batch_size, channels,
                      height - (self.block_size - 1), width - (self.block_size - 1))
        # follow the device of x instead of assuming CUDA whenever it is available
        mask = Bernoulli(gamma).sample(seed_shape).to(x.device)
        block_mask = self._compute_block_mask(mask)
        countM = block_mask.numel()
        # clamp: when every unit is dropped the rescale factor would be a division by zero (NaN output)
        count_ones = block_mask.sum().clamp(min=1)

        return block_mask * x * (countM / count_ones)

    def _compute_block_mask(self, mask):
        """
        Expand every seed at (h, w) into the block [h, h + block_size) x [w, w + block_size).

        :param mask - 0/1 seed tensor of shape (bsize, channels, H - block_size + 1, W - block_size + 1).
        :return - Keep-mask of shape (bsize, channels, H, W): 0 inside dropped blocks, 1 elsewhere.
        """
        span = self.block_size - 1
        # Zero-pad by block_size - 1 on every side; a stride-1 max-pool with a
        # block_size window then marks output (y, x) iff a seed lies in
        # [y - span, y] x [x - span, x], i.e. iff (y, x) is covered by a block.
        padded_seeds = F.pad(mask, (span, span, span, span))
        dropped = F.max_pool2d(padded_seeds, kernel_size=self.block_size, stride=1)
        return 1 - dropped

def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with one pixel of padding on every side."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class ResBasicBlock(nn.Module):
    """
    Residual block: three conv3x3 + BatchNorm stages with LeakyReLU(0.1), a shortcut
    connection, max-pooling, and optional dropout / DropBlock on the pooled output.

    `stride` is used as the max-pool kernel size; the convolutions themselves run with stride 1.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
        super().__init__()
        # convolutional trunk
        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        # spatial reduction and shortcut projection
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        self.stride = stride
        # regularization settings
        self.drop_rate = drop_rate
        self.num_batches_tracked = 0   # forward-call counter that drives the DropBlock schedule
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        self.num_batches_tracked += 1

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # shortcut: identity unless a projection module was supplied
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.maxpool(self.relu(out))

        if not self.drop_rate > 0:
            return out

        if self.drop_block == True:
            feat_size = out.size()[2]
            # keep rate decays linearly with the number of forward calls, floored at 1 - drop_rate
            keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
            valid_positions = feat_size - self.block_size + 1
            # DropBlock is skipped when the feature map is smaller than one block
            if valid_positions > 0:
                gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / valid_positions**2
                out = self.DropBlock(out, gamma=gamma)
        else:
            out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)

        return out


class ResNet(nn.Module):
    """
    Four-stage residual feature extractor (64 -> 160 -> 320 -> 640 channels).

    Each stage is a single `block` that pools with stride 2; the last two stages use
    DropBlock. `forward` returns the flattened feature map, after 5x5 average pooling
    when `avg_pool` is set. The `dropout` module built from `keep_prob` is stored on
    the model but not applied in `forward`.
    """

    def __init__(self, block=ResBasicBlock, keep_prob=1.0, avg_pool=True, dropout=0.1, dropblock_size=5):
        super().__init__()
        self.inplanes = 3
        drop_rate = dropout

        # (output planes, extra keyword arguments) for layer1 .. layer4
        stage_specs = (
            (64, {}),
            (160, {}),
            (320, dict(drop_block=True, block_size=dropblock_size)),
            (640, dict(drop_block=True, block_size=dropblock_size)),
        )
        for idx, (planes, extra) in enumerate(stage_specs, start=1):
            stage = self._make_layer(block, planes, stride=2, drop_rate=drop_rate, **extra)
            setattr(self, f'layer{idx}', stage)

        if avg_pool:
            self.avgpool = nn.AvgPool2d(5, stride=1)
        self.keep_prob = keep_prob
        self.keep_avg_pool = avg_pool
        self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
        self.drop_rate = drop_rate

        # weight initialization
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
        """Build one stage and advance `self.inplanes` to the stage's output width."""
        out_planes = planes * block.expansion

        # 1x1 projection on the shortcut whenever the stage changes resolution or width
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            downsample = None

        stage = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)
        self.inplanes = out_planes
        return nn.Sequential(stage)

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        if self.keep_avg_pool:
            x = self.avgpool(x)
        return x.view(x.size(0), -1)


def Res12(keep_prob=1.0, avg_pool=False, **kwargs):
    """Build the ResNet-12 backbone from `ResBasicBlock`; extra keyword arguments go to `ResNet`."""
    return ResNet(ResBasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)

In [9]:
### dropout has been removed in this code. original code had dropout#####
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable

import sys, os
import numpy as np
import random

# Module-level ReLU instance; none of the definitions visible in this part of the notebook reference it.
act = torch.nn.ReLU()

import math
from torch.nn.utils.weight_norm import WeightNorm


class WRNBasicBlock(nn.Module):
    """
    Pre-activation residual block: BN-ReLU-conv3x3 (strided) then BN-ReLU-[dropout]-conv3x3.

    When the channel count changes, the shortcut is a 1x1 convolution applied to the
    pre-activated input; otherwise it is the raw input.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        conv3_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, stride=stride, **conv3_kwargs)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, stride=1, **conv3_kwargs)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        activated = self.relu1(self.bn1(x))

        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)

        # identity shortcut uses the raw input, projected shortcut the pre-activated one
        if self.equalInOut:
            shortcut = x
        else:
            shortcut = self.convShortcut(activated)
        return torch.add(shortcut, out)


class distLinear(nn.Module):
    """
    Cosine-style classifier head: L2-normalizes the input rows, applies a bias-free
    linear layer and multiplies the result by a fixed scale factor.

    With `class_wise_learnable_norm` (always enabled here) the linear layer is wrapped
    in weight normalization, so each class keeps a learnable norm.
    """

    def __init__(self, indim, outdim):
        super().__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.class_wise_learnable_norm = True  # See the issue#4&8 in the github
        if self.class_wise_learnable_norm:
            # split the weight into a direction and a per-class norm
            WeightNorm.apply(self.L, 'weight', dim=0)

        # fixed factor that stretches the cosine scores before the softmax;
        # heads with more than 200 outputs use the larger factor
        self.scale_factor = 2 if outdim <= 200 else 10

    def forward(self, x):
        row_norm = x.norm(p=2, dim=1, keepdim=True)
        x_unit = x / (row_norm + 0.00001)

        if not self.class_wise_learnable_norm:
            weight = self.L.weight.data
            weight_norm = weight.norm(p=2, dim=1, keepdim=True)
            self.L.weight.data = weight / (weight_norm + 0.00001)

        # with WeightNorm the product is additionally scaled by the class-wise learnable norm
        return self.scale_factor * self.L(x_unit)


class NetworkBlock(nn.Module):
    """
    Sequential stack of `nb_layers` blocks.

    The first block maps `in_planes` -> `out_planes` with the given `stride`;
    every following block keeps `out_planes` channels and uses stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            is_first = i == 0
            # Conditional expressions instead of `a and b or c`, which silently
            # swaps in the fallback whenever the selected value is falsy.
            layers.append(block(
                in_planes if is_first else out_planes,
                out_planes,
                stride if is_first else 1,
                dropRate,
            ))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


def to_one_hot(inp, num_classes):
    """
    One-hot encode a 1-D tensor of class ids.

    Returns a float tensor of shape (len(inp), num_classes) that does not require
    grad; it is placed on the GPU whenever CUDA is available.
    """
    on_gpu = torch.cuda.is_available()

    class_ids = inp.type(torch.LongTensor)
    encoded = torch.FloatTensor(inp.size(0), num_classes)
    if on_gpu:
        encoded = encoded.cuda()
        class_ids = class_ids.cuda()

    # clear the uninitialised buffer, then write a 1 into each row's class column
    encoded.zero_()
    encoded.scatter_(1, class_ids.unsqueeze(1), 1)

    return Variable(encoded, requires_grad=False)


def mixup_data(x, y, lam):
    """
    Mix every sample of a batch with a randomly chosen partner from the same batch.

    :param x - Batch of inputs; the first dimension is the batch dimension.
    :param y - Targets aligned with x.
    :param lam - Weight of the original sample; the partner gets 1 - lam.
    :return - (mixed_x, y_a, y_b, lam) where y_a are the original targets and y_b
              the targets of the partners.
    """
    batch_size = x.size()[0]
    # Drawn on the CPU (same RNG stream as before), then moved to wherever the
    # indexed tensor lives instead of to CUDA unconditionally.
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index.to(x.device)]
    y_index = index.to(y.device) if torch.is_tensor(y) else index
    y_a, y_b = y, y[y_index]

    return mixed_x, y_a, y_b, lam


class WideResNet(nn.Module):
    """
    Wide ResNet feature extractor with optional (manifold) mixup in `forward`.

    The constructor does not create a classification head. `forward(x)` returns the
    pooled features. Requesting logits (`target` given and `return_logits=True`)
    needs a `linear` module attached to the instance.

    :param depth - Total depth; (depth - 4) must be divisible by 6.
    :param widen_factor - Channel multiplier of the three stages.
    :param num_classes - Stored on the instance; not used to build any layer here.
    :param loss_type - Accepted for interface compatibility; unused here.
    :param per_img_std - Accepted for interface compatibility; unused here.
    :param stride - Stride of the first stage.
    :param dropRate - Dropout rate inside the residual blocks.
    """

    def __init__(self, depth=28, widen_factor=10, num_classes=200, loss_type='dist', per_img_std=False, stride=1,
                 dropRate=0.5):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        blocks_per_stage = (depth - 4) // 6
        block = WRNBasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(blocks_per_stage, nChannels[0], nChannels[1], block, stride, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(blocks_per_stage, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(blocks_per_stage, nChannels[2], nChannels[3], block, 2, dropRate)
        # final normalization before global average pooling
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.nChannels = nChannels[3]

        self.num_classes = num_classes
        # width of the flattened feature vector (640 for the default widen_factor=10)
        self.final_feat_dim = nChannels[3]
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _pool_features(self, out):
        """BN + ReLU, global average pooling and flattening to (batch, channels)."""
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, out.size()[2:])
        return out.view(out.size(0), -1)

    def forward(self, x, target=None, mixup=False, mixup_hidden=True, mixup_alpha=None, lam=0.4, return_logits=True):
        """
        :param x - Batch of images.
        :param target - Optional targets; when given, mixup may be applied and the (mixed) targets are returned.
        :param mixup - Mix at the input (only consulted when mixup_hidden is False).
        :param mixup_hidden - Mix at a random depth in {0, 1, 2, 3}.
        :param mixup_alpha - Unused.
        :param lam - Mixing weight handed to `mixup_data`.
        :param return_logits - With targets: also return the output of `self.linear`.
        :return - features when target is None; (features, target_a, target_b) when
                  return_logits is False; otherwise (features, logits, target_a, target_b).
        :raises AttributeError - logits were requested but no `linear` head is attached.
        """
        if target is None:
            out = self.conv1(x)
            out = self.block1(out)
            out = self.block2(out)
            out = self.block3(out)
            return self._pool_features(out)

        # Fail before doing any work rather than after the whole forward pass
        if return_logits and not hasattr(self, 'linear'):
            raise AttributeError(
                "WideResNet has no classification head ('linear'); call forward(..., return_logits=False) "
                "or attach a head to the model before requesting logits."
            )

        if mixup_hidden:
            layer_mix = random.randint(0, 3)
        elif mixup:
            layer_mix = 0
        else:
            layer_mix = None

        out = x
        target_a = target_b = target

        if layer_mix == 0:
            out, target_a, target_b, lam = mixup_data(out, target, lam=lam)

        out = self.conv1(out)
        out = self.block1(out)

        if layer_mix == 1:
            out, target_a, target_b, lam = mixup_data(out, target, lam=lam)

        out = self.block2(out)

        if layer_mix == 2:
            out, target_a, target_b, lam = mixup_data(out, target, lam=lam)

        out = self.block3(out)
        if layer_mix == 3:
            out, target_a, target_b, lam = mixup_data(out, target, lam=lam)

        out = self._pool_features(out)
        if not return_logits:
            return out, target_a, target_b

        out1 = self.linear(out)
        return out, out1, target_a, target_b


def wrn28_10(num_classes=200, loss_type='dist', dropout=0):
    """Build a WideResNet with depth 28 and widen factor 10; `dropout` is passed on as `dropRate`."""
    config = dict(
        depth=28,
        widen_factor=10,
        num_classes=num_classes,
        loss_type=loss_type,
        per_img_std=False,
        stride=1,
        dropRate=dropout,
    )
    return WideResNet(**config)

In [10]:
# Registry of backbone constructors by (lower-case) name
models = dict(wrn=wrn28_10, resnet12=Res12)


def get_model(model_name: str, img_size, temperature=0.1, dropout = 0, ):
    """
    Build a backbone from the `models` registry.

    :param model_name - Registry key (case-insensitive).
    :param img_size - Only forwarded to architectures whose name contains 'vit'.
    :param temperature - Unused by this function.
    :param dropout - Forwarded as `dropout` to every non-'vit' architecture.
    :raises ValueError - `model_name` is not registered.

    Also switches cuDNN benchmark mode on when CUDA is available.
    """
    arch = model_name.lower()
    if arch not in models.keys():
        raise ValueError(f'Model {model_name} not implemented. available models are: {list(models.keys())}')

    build = models[arch]
    if 'vit' in arch:
        model = build(img_size=img_size, patch_size=16)
    else:
        model = build(dropout=dropout)

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    return model


 # Self Optimal Transport

In [12]:
def log_sum_exp(u: torch.Tensor, dim: int):
    """
    Numerically stable log(sum(exp(u))) along `dim`; the reduced dimension is removed.

    Delegates to `torch.logsumexp`, which also returns -inf (rather than NaN) for a
    slice that is entirely -inf.
    """
    return torch.logsumexp(u, dim=dim)


def log_sinkhorn(M: torch.Tensor, reg: float, num_iters: int):
    """
    Sinkhorn iterations carried out in log space.

    :param M - Cost matrix; tensors with more than two dimensions are handed to `batched_log_sinkhorn`.
    :param reg - Entropic regularization; the log-kernel is -M / reg.
    :param num_iters - Number of row/column normalization rounds.
    :return - Log of the transport plan, same shape as M.
    """
    if M.dim() > 2:
        return batched_log_sinkhorn(M=M, reg=reg, num_iters=num_iters)

    # log of the kernel exp(-M / reg)
    log_kernel = -M / reg

    # column scaling starts at log(1) = 0; the row scaling is derived inside the loop
    log_col = torch.zeros(M.size()[1]).to(M.device)

    for _ in range(num_iters):
        # rescale rows, then columns
        log_row = - log_sum_exp(log_kernel + log_col.unsqueeze(0), dim=1)
        log_col = - log_sum_exp(log_row.unsqueeze(1) + log_kernel, dim=0)

    # log-plan: row scaling + kernel + column scaling
    return log_row.unsqueeze(1) + log_kernel + log_col.unsqueeze(0)


def batched_log_sinkhorn(M, reg: float, num_iters: int, thresh: float = 1e-1):
    """
    Batched log-space Sinkhorn with uniform marginals.

    :param M - Cost tensor of shape (batch, x_points, y_points).
    :param reg - Entropic regularization.
    :param num_iters - Maximum number of iterations.
    :param thresh - Stop early once the mean absolute change of the row potentials
                    (summed over points) drops below this value.
    :return - Log of the transport plan, same shape as M.
    """
    batch_size, x_points, y_points = M.shape
    dtype = M.dtype if M.is_floating_point() else torch.float

    # Both marginals are uniform. They are built directly on M's device and kept
    # 2-D: squeezing them would collapse real dimensions of size 1.
    mu = torch.full((batch_size, x_points), 1.0 / x_points, dtype=dtype, device=M.device)
    nu = torch.full((batch_size, y_points), 1.0 / y_points, dtype=dtype, device=M.device)
    log_mu = torch.log(mu + 1e-8)
    log_nu = torch.log(nu + 1e-8)

    u = torch.zeros_like(mu)
    v = torch.zeros_like(nu)

    def C(M, u, v, reg):
        """Modified cost for logarithmic updates"""
        return (-M + u.unsqueeze(-1) + v.unsqueeze(-2)) / reg

    # Sinkhorn iterations
    for _ in range(num_iters):
        u_prev = u  # kept to measure the size of the update
        u = reg * (log_mu - torch.logsumexp(C(M, u, v, reg), dim=-1)) + u
        v = reg * (log_nu - torch.logsumexp(C(M, u, v, reg).transpose(-2, -1), dim=-1)) + v
        err = (u - u_prev).abs().sum(-1).mean()
        if err.item() < thresh:
            break

    # log transport plan for the final potentials
    return C(M, u, v, reg)


class SOT(object):
    """
    Self-Optimal-Transport feature transform.

    Calling the object on a feature matrix X builds a pairwise cost matrix of X with
    itself, masks its diagonal, runs log-space Sinkhorn and returns the resulting
    (optionally max-scaled) transport plan with the diagonal set to 1.
    """
    supported_distances = ['cosine', 'euclidean']

    def __init__(self, distance_metric: str = 'cosine', ot_reg: float = 0.1, sinkhorn_iterations: int = 10,
                 sigmoid: bool = False, mask_diag: bool = True, max_scale: bool = True):
        """
        :param distance_metric - Compute the cost matrix.
        :param ot_reg - Sinkhorn entropy regularization (lambda). For few-shot classification, 0.1-0.2 works best.
        :param sinkhorn_iterations - Maximum number of sinkhorn iterations.
        :param sigmoid - If to apply sigmoid(log_p) instead of the usual exp(log_p).
        :param mask_diag - Set to true to apply diagonal masking before and after the OT.
        :param max_scale - Re-scale the SOT values to range [0,1].
        """
        super().__init__()

        assert distance_metric.lower() in SOT.supported_distances and sinkhorn_iterations > 0

        self.sinkhorn_iterations = sinkhorn_iterations
        self.distance_metric = distance_metric.lower()
        self.mask_diag = mask_diag
        self.sigmoid = sigmoid
        self.ot_reg = ot_reg
        self.max_scale = max_scale
        self.diagonal_val = 1e3                         # value to mask self-values with

    def compute_cost(self, X: torch.Tensor) -> torch.Tensor:
        """
        Compute the pairwise cost matrix of X with itself.
        """
        if self.distance_metric == 'euclidean':
            M = torch.cdist(X, X, p=2)
            # scale euclidean distances to [0, 1]
            return M / M.max()

        elif self.distance_metric == 'cosine':
            # cosine distance
            return 1 - SOT.cosine_similarity(X)

    def mask_diagonal(self, M: torch.Tensor, value: float):
        """
        Write `value` onto the diagonal of M (of every matrix for batched input), in place.
        No-op when `mask_diag` is False. Returns M.
        """
        if self.mask_diag:
            if M.dim() > 2:
                # in-place fill through a view of the diagonals; avoids building a
                # (batch, N, N) boolean mask on the host for every call
                torch.diagonal(M, dim1=-2, dim2=-1).fill_(value)
            else:
                M.fill_diagonal_(value)
        return M

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        """
        Compute the SOT features for X
        """
        # get masked cost matrix
        C = self.compute_cost(X=X)
        M = self.mask_diagonal(C, value=self.diagonal_val)

        # compute self-OT
        z_log = log_sinkhorn(M=M, reg=self.ot_reg, num_iters=self.sinkhorn_iterations)

        if self.sigmoid:
            z = torch.sigmoid(z_log)
        else:
            z = torch.exp(z_log)

        # divide the SOT matrix by its max to scale it up
        if self.max_scale:
            z_max = z.max().item() if z.dim() <= 2 else z.amax(dim=(1, 2), keepdim=True)
            z = z / z_max
        elif self.mask_diag:
            # exp / sigmoid keep their output for the backward pass; mask a copy so the
            # in-place diagonal write below does not invalidate that saved tensor
            z = z.clone()

        # set self-values to 1
        return self.mask_diagonal(z, value=1)

    @staticmethod
    def cosine_similarity(a: torch.Tensor, eps: float = 1e-8):
        """
        Compute the pairwise cosine similarity between a matrix to itself.
        Row norms are clamped to at least `eps`, so zero rows give 0 instead of NaN.
        """
        d_n = a / a.norm(dim=-1, keepdim=True).clamp_min(eps)
        if len(a.shape) > 2:
            C = torch.bmm(d_n, d_n.transpose(1, 2))
        else:
            C = torch.mm(d_n, d_n.transpose(0, 1))
        return C

# Method

In [13]:
class ProtoLoss(nn.Module):
    """
    Prototype-based few-shot criterion.

    Class centroids are averaged from the support rows and queries are scored by
    their negated, temperature-scaled distance to each centroid. When an SOT
    transform is supplied it is applied to the features first.
    """

    def __init__(self, train_way, val_way, num_shot, num_query, temperature = 0.1, sot: SOT = None):
        super().__init__()
        # number of classes per episode, looked up by mode name in forward()
        self.way_dict = {'train': train_way, 'val': val_way}
        self.num_shot = num_shot
        self.num_query = num_query
        self.temperature = temperature
        # optional feature transform; None disables it
        self.SOT = sot
        # set on every forward() call
        self.num_labeled = None

    @staticmethod
    def get_accuracy(probas: torch.Tensor, labels: torch.Tensor):
        """
        Fraction of rows whose arg-min column equals the label.

        Note that despite its name `probas` is scored with argmin; forward()
        passes distances, so the smallest value wins.
        """
        predicted = probas.argmin(dim=-1)
        return labels.eq(predicted).float().mean().item()

    def forward(self, X: torch.Tensor, labels: torch.Tensor, mode: str):
        """
        :param X: episode features; the first num_way * num_shot rows are the support set.
        :param labels: labels compared against the per-query predictions.
        :param mode: key into `way_dict` ('train' or 'val').
        :return: (negated scaled distances, accuracy as a float).
        """
        num_way = self.way_dict[mode]
        self.num_labeled = num_way * self.num_shot

        if self.SOT is not None:
            # replace the raw features by the SOT matrix
            X = self.SOT(X)

        support = X[:self.num_labeled]
        queries = X[self.num_labeled:]

        # rows are ordered class-cyclically: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...]
        # so (shot, way, feat) -> (way, shot, feat) and average over the shots
        centroids = support.reshape(self.num_shot, num_way, -1).transpose(0, 1).mean(dim=1)

        # query-to-centroid distances, sharpened by the temperature
        distances = torch.cdist(queries, centroids) / self.temperature
        accuracy = ProtoLoss.get_accuracy(distances, labels)

        return -distances, accuracy


In [14]:
import torch
from torch import nn
import torch.nn.functional as F

"""
Implementation of PT-MAP as a differential module.
Original code in https://github.com/yhu01/PT-MAP
"""


def centerDatas(X: torch.Tensor, n_lsamples: int):
    """
    Subtract the column-wise mean from the first `n_lsamples` rows and,
    independently, from the remaining rows.

    X is modified in place and also returned.
    """
    labeled_rows = slice(None, n_lsamples)
    unlabeled_rows = slice(n_lsamples, None)
    for rows in (labeled_rows, unlabeled_rows):
        # each group is centered around its own mean
        X[rows, :] = X[rows, :] - X[rows, :].mean(0, keepdim=True)
    return X


# ---------  GaussianModel

class GaussianModel:
    """
    Class-centroid model used by the MAP loop.

    Holds one centroid per class (`mus`) and assigns unlabeled samples to classes
    with a Sinkhorn-style optimal-transport normalisation.
    """

    def __init__(self, num_way: int, num_shot: int, num_query: int, lam: float):
        self.num_way = num_way
        self.num_shot = num_shot
        self.num_query = num_query
        self.n_lsamples = num_way * num_shot    # number of labeled rows
        self.n_usamples = num_way * num_query   # number of unlabeled rows
        self.lam = lam                          # multiplier of the cost inside exp(-lam * M)
        self.mus = None  # shape [n_ways][feat_dim]

    def cuda(self):
        """Move the centroids to the GPU (requires `mus` to be initialised)."""
        self.mus = self.mus.cuda()

    def init_from_labelled(self, X: torch.Tensor):
        """Initialise the centroids as the mean of the first `num_shot` samples of each class."""
        # rows are viewed as (shot + query, way, feat); only the shot slices are averaged
        self.mus = X.reshape(self.num_shot + self.num_query, self.num_way, -1)[:self.num_shot, ].mean(0)

    def update_from_estimate(self, estimate, alpha):
        """Move the centroids a fraction `alpha` of the way towards `estimate`."""
        Dmus = estimate - self.mus
        self.mus = self.mus + alpha * Dmus

    def compute_optimal_transport(self, M: torch.Tensor, r: torch.Tensor, c: torch.Tensor, epsilon: float = 1e-6):
        """
        Iteratively rescale exp(-lam * M) so its row sums approach `r` and its
        column sums approach `c`.

        :param M: cost tensor of shape (n_runs, n, m).
        :param r: target row sums.
        :param c: target column sums.
        :param epsilon: stop once the row sums change by at most this much.
        :return: the rescaled tensor; the leading dimension is squeezed when n_runs == 1.
        """
        n_runs, n, _ = M.shape
        P = torch.exp(-self.lam * M)
        P = P / P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1)
        # allocate on the input's device instead of a hard-coded 'cuda', so the
        # routine also works for CPU tensors
        u = torch.zeros((n_runs, n), device=M.device, dtype=M.dtype)
        maxiters = 1000
        iters = 1
        # normalize this matrix
        while torch.max(torch.abs(u - P.sum(-1))) > epsilon:
            u = P.sum(dim=-1)
            P *= (r / u).view((n_runs, -1, 1))
            P *= (c / P.sum(1)).view((n_runs, 1, -1))
            if iters == maxiters:
                break
            iters += 1

        if n_runs == 1:
            return P.squeeze(0)
        return P

    def get_probas(self, X: torch.Tensor, labels: torch.Tensor):
        """
        Build the class-assignment matrix for all samples.

        Unlabeled rows receive the optimal-transport assignment computed from their
        distances to the centroids; labeled rows receive a one-hot row taken
        from `labels`.
        """
        dist = torch.cdist(X, self.mus)
        p_xj = torch.zeros_like(dist)
        # marginals live on the same device/dtype as the distances (was hard-coded 'cuda')
        r = torch.ones(1, self.num_way * self.num_query, device=dist.device, dtype=dist.dtype)
        c = torch.ones(1, self.num_way, device=dist.device, dtype=dist.dtype) * self.num_query
        p_xj_test = self.compute_optimal_transport(dist.unsqueeze(0)[:, self.n_lsamples:], r, c, epsilon=1e-6)
        p_xj[self.n_lsamples:] = p_xj_test

        # labeled rows are still zero from zeros_like; mark their known class
        p_xj[:self.n_lsamples].scatter_(1, labels[:self.n_lsamples].unsqueeze(1), 1)
        return p_xj

    def estimate_from_mask(self, X: torch.Tensor, mask: torch.Tensor):
        """Weighted mean of X per class, using the columns of `mask` as weights."""
        emus = mask.T.matmul(X).div(mask.sum(dim=0).unsqueeze(1))
        return emus


# =========================================
#    MAP
# =========================================

class MAP:
    """
    Driver of the iterative centroid refinement: repeatedly asks the model for
    class assignments and moves the centroids towards the resulting estimates.
    """

    def __init__(self, labels, alpha: float, num_labeled: int, n_runs=1):
        self.alpha = alpha              # step size of each centroid update
        self.num_labeled = num_labeled
        # the first `num_labeled` labels belong to the support set, the rest to the queries
        self.s_labels = labels[:self.num_labeled]
        self.q_labels = labels[self.num_labeled:]
        self.n_runs = n_runs            # stored only; not read inside this class

    def get_accuracy(self, probas: torch.Tensor):
        """
        Accuracy on the query rows of `probas`.

        :return: (mean accuracy, 1.96 * population std of the per-sample matches).
        """
        predicted = probas[self.num_labeled:].argmax(dim=-1)
        hits = self.q_labels.eq(predicted).float()
        mean_acc = hits.mean().item()
        half_width = hits.std(unbiased=False).item() * 1.96
        return mean_acc, half_width

    def perform_epoch(self, model: GaussianModel, X: torch.Tensor):
        """One refinement step: assignments -> centroid estimates -> centroid update."""
        assignments = model.get_probas(X=X, labels=self.s_labels)
        estimates = model.estimate_from_mask(X=X, mask=assignments)
        model.update_from_estimate(estimates, self.alpha)

    def loop(self, X: torch.Tensor, model: GaussianModel, n_epochs: int = 20):
        """Run `n_epochs` refinement steps and return the final assignment matrix."""
        for _ in range(n_epochs):
            self.perform_epoch(model=model, X=X)
        return model.get_probas(X=X, labels=self.s_labels)


class PTMAPLoss(nn.Module):
    """
    PT-MAP criterion: power transform + scaling of the features, optional SOT
    transform, then MAP refinement of class centroids.

    `temperature` is accepted for signature parity with the other criteria but is
    not used by this class.
    """

    def __init__(self, train_way, val_way, num_shot, num_query, temperature = 0.1,
                 lam: float = 10, alpha: float = 0.2, n_epochs: int = 20, sot=None):
        super().__init__()
        # number of classes per episode, looked up by mode name in forward()
        self.way_dict = {'train': train_way, 'val': val_way}
        self.num_shot = num_shot
        self.num_query = num_query
        self.lam = lam
        self.alpha = alpha
        self.n_epochs = n_epochs
        # set on every forward() call
        self.num_labeled = None
        # optional feature transform; None disables it
        self.SOT = sot

    def scale(self, X: torch.Tensor, mode: str):
        """
        L2-normalise the rows of X. Outside of 'train' mode the rows are first
        normalised and centered (labeled / unlabeled separately) before the
        final normalisation.
        """
        if mode != 'train':
            X = centerDatas(F.normalize(X, p=2, dim=-1), self.num_labeled)
        return F.normalize(X, p=2, dim=-1)

    def forward(self, X: torch.Tensor, labels: torch.Tensor, mode: str):
        """
        :param X: non-negative episode features.
        :param labels: labels of all samples (support first, then queries).
        :param mode: key into `way_dict` ('train' or 'val').
        :return: (log of the query assignment matrix, accuracy as a float).
        """
        num_way = self.way_dict[mode]
        self.num_labeled = num_way * self.num_shot

        # power transform (PT part) and scaling
        assert X.min() >= 0, "Error: To use PT-MAP you need to apply another ReLU on the output features (or use WRN)."
        powered = torch.pow(X + 1e-6, 0.5)
        Z = self.scale(X=powered, mode=mode)

        # optionally replace the features by their SOT matrix
        if self.SOT is not None:
            Z = self.SOT(X=Z)

        # MAP: initialise centroids from the support set, then refine them
        centroid_model = GaussianModel(num_way=num_way, num_shot=self.num_shot,
                                       num_query=self.num_query, lam=self.lam)
        centroid_model.init_from_labelled(X=Z)

        refiner = MAP(labels=labels, alpha=self.alpha, num_labeled=self.num_labeled)
        P = refiner.loop(X=Z, model=centroid_model, n_epochs=self.n_epochs)
        accuracy, _ = refiner.get_accuracy(probas=P)

        # small offset keeps the logarithm finite for zero entries
        query_log_probs = torch.log(P[self.num_labeled:] + 1e-5)
        return query_log_probs, accuracy

In [15]:
# registry of the available few-shot criteria; the "_sot" names map to the same
# classes, whether SOT is applied depends only on the `sot` argument
methods = {
    'pt_map': PTMAPLoss,
    'pt_map_sot': PTMAPLoss,
    'proto': ProtoLoss,
    'proto_sot': ProtoLoss,
}

def get_method(method:str
               , sot: SOT
               , train_way
               , val_way
               , num_shot
               , num_query
               , temperature = 0.1
               ):
    """
    Instantiate the few-shot classification criterion registered under `method`
    (case-insensitive, e.g. pt_map).

    :raises ValueError: if `method` is not a key of `methods`.
    """
    key = method.lower()
    if key not in methods.keys():
        raise ValueError(f'Not implemented method. available methods are: {methods.keys()}')

    criterion_cls = methods[key]
    return criterion_cls(
        train_way=train_way,
        val_way=val_way,
        num_shot=num_shot,
        num_query=num_query,
        temperature=temperature,
        sot=sot,
    )


# Train

In [16]:
from time import time
try:
    import wandb
    HAS_WANDB = True
except Exception:
    # wandb is optional: any failure while importing it just disables logging.
    # `Exception` (instead of a bare except) lets KeyboardInterrupt / SystemExit
    # propagate as they should.
    HAS_WANDB = False

In [17]:
def get_fs_labels(method: str, num_way: int, num_query: int, num_shot: int):
    """
    Build the class-cyclic label vector of one episode, e.g. for 5-way:
    [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...].

    Methods whose name contains 'map' get labels for support and query samples
    (num_shot + num_query repetitions); all others get query labels only.
    The result is moved to the GPU when one is available.
    """
    if 'map' in method:
        repetitions = num_shot + num_query
    else:
        repetitions = num_query

    class_ids = torch.arange(num_way, dtype=torch.int16)
    labels = class_ids.repeat(repetitions).type(torch.LongTensor)

    return labels.cuda() if torch.cuda.is_available() else labels

def log_stepper(results: dict, logger= None):
    """
    Forward `results` to the logger (when one is given) and print every entry.

    Entries whose key contains 'acc' are printed as percentages, all others
    with four decimals.
    """
    if logger is not None:
        logger.log(results)

    for key in results:
        value = results[key]
        line = f"{key}: {100 * value:.2f}%" if 'acc' in key else f"{key}: {value:.4f}"
        print(line)

def print_and_log(results: dict, n: int = 0, logger = None):
    """
    Print every entry of `results` and hand the dict to the logger, if any.

    When n > 0, entries are first divided by n in place, except those whose key
    contains 'time' or '/epoch'.
    """
    for key, total in results.items():
        keep_as_is = n <= 0 or 'time' in key or '/epoch' in key
        if not keep_as_is:
            # only values are replaced, so mutating during iteration is safe
            results[key] = total / n

        print(f'{key}: {results[key]:.4f}')

    if logger is not None:
        logger.log(results)


def train_one_epoch(model, loader, optimizer, method, criterion, labels, log_step, epoch, device = 'cuda', logger=None):
    """
    Run one optimisation pass over `loader`.

    For every batch the features are passed through the few-shot `method`, the
    loss is computed against the matching labels, and one optimiser step is taken.
    Returns the dict of epoch results (averaged in place by print_and_log).
    """
    model.train()
    acc_total, loss_total = 0, 0

    for batch_idx, batch in enumerate(loader):
        features = model(batch[0].to(device))

        # apply few_shot method
        probas, accuracy = method(features, labels=labels, mode='train')

        # when the method scores only part of the samples, keep the trailing labels
        if len(labels) == len(probas):
            q_labels = labels
        else:
            q_labels = labels[-len(probas):]

        loss = criterion(probas, q_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        acc_total += accuracy

        if log_step and (batch_idx + 1) % 50 == 0:
            global_step = batch_idx + ((epoch - 1) * len(loader))
            step_results = {
                'train/loss_step': loss.item(),
                'train/accuracy_step': accuracy,
                'train/train_step': global_step,
            }
            log_stepper(results=step_results, logger=logger)

    results = {'train/accuracy': acc_total, 'train/loss': loss_total, 'train/epoch': epoch}
    print_and_log(results=results, n=len(loader), logger=logger)
    return results


@torch.no_grad()
def eval_one_epoch(model, loader, method, criterion, labels, epoch, logger=None, set_name='val', device = 'cuda'):
    """
    Evaluate `model` on every batch of `loader` without tracking gradients.

    :param method: few-shot criterion; always invoked with mode='val'.
    :param labels: episode labels passed to `method` and used for the loss.
    :param epoch: epoch index, used for the step counter and stored in the results.
    :param set_name: prefix of every result key (e.g. 'val' -> 'val/accuracy').
    :return: dict with '<set_name>/accuracy', '<set_name>/loss' (both averaged over
             the batches by print_and_log) and '<set_name>/epoch'.
    """
    model.eval()
    results = {f'{set_name}/accuracy': 0, f'{set_name}/loss': 0}

    for batch_idx, batch in enumerate(loader):
        images = batch[0].to(device)

        features = model(images)

        # apply few_shot method
        probas, accuracy = method(X=features, labels=labels, mode='val')
        # when the method scores only part of the samples, keep the trailing labels
        q_labels = labels if len(labels) == len(probas) else labels[-len(probas):]

        loss = criterion(probas, q_labels)

        results[f"{set_name}/loss"] += loss.item()
        results[f"{set_name}/accuracy"] += accuracy

        if batch_idx % 50 == 0:
            step = batch_idx+((epoch-1) * len(loader))
            print(f"Batch {batch_idx + 1}/{len(loader)}: ")
            log_stepper(
                results={f'{set_name}/loss_step': loss.item(), f'{set_name}/accuracy_step': accuracy,
                         f'{set_name}/{set_name}_step': step},
                logger=logger
            )

    # was hard-coded as "val/epoch", which mislabeled the entry for any other
    # set_name; identical for the default set_name='val'
    results[f"{set_name}/epoch"] = epoch
    print_and_log(results=results, n=len(loader), logger=logger)
    return results


In [18]:
# --- experiment configuration -------------------------------------------------
learning_rate = 0.0002
gamma =  0.5            # multiplicative LR decay applied by the scheduler
num_way = 2
num_query_train = 20
num_query_test = 20
num_shot = 5
# fall back to the CPU when no GPU is present instead of failing on .to("cuda")
device = "cuda" if torch.cuda.is_available() else "cpu"

model = get_model('resnet12', 32)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# `gamma` was defined above but never handed to the scheduler, which therefore
# silently used StepLR's own default decay factor
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=gamma)
sot = SOT()
method = get_method("proto_sot", sot, num_way, num_way, num_shot, num_query_train)

 # few-shot labels
train_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_train, num_shot=num_shot)
val_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_test, num_shot=num_shot)
criterion = torch.nn.CrossEntropyLoss()

Main training loop


In [None]:
print("Start training...")
best_loss = 1000
best_acc = 0
max_epochs = 20
eval_freq = 5

# train for `max_epochs` epochs, stepping the LR scheduler once per epoch
for epoch in range(1, max_epochs + 1):
    print(f"Epoch {epoch}/{max_epochs}: ")
    train_one_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        method=method,
        criterion=criterion,
        labels=train_labels,
        log_step=False,
        epoch=epoch,
        device=device,
    )
    print('------------------------------------------------------------------------------------')
    if scheduler is not None:
        scheduler.step()


# single evaluation pass after training has finished
print('***********************************************************************************')
result = eval_one_epoch(
    model=model,
    loader=val_loader,
    method=method,
    criterion=criterion,
    labels=val_labels,
    epoch=epoch,
    device=device,
)
print('***********************************************************************************')

Start training...
Epoch 1/20: 
train/accuracy: 0.5000
train/loss: 0.7970
train/epoch: 1.0000
------------------------------------------------------------------------------------
Epoch 2/20: 
train/accuracy: 0.8000
train/loss: 0.5512
train/epoch: 2.0000
------------------------------------------------------------------------------------
Epoch 3/20: 
train/accuracy: 0.7500
train/loss: 0.7373
train/epoch: 3.0000
------------------------------------------------------------------------------------
Epoch 4/20: 
train/accuracy: 0.7000
train/loss: 0.4940
train/epoch: 4.0000
------------------------------------------------------------------------------------
Epoch 5/20: 
train/accuracy: 0.5500
train/loss: 0.7236
train/epoch: 5.0000
------------------------------------------------------------------------------------
Epoch 6/20: 
train/accuracy: 0.8000
train/loss: 0.4489
train/epoch: 6.0000
------------------------------------------------------------------------------------
Epoch 7/20: 
train/acc

# Testbed

ResNet + ProtoNet + SOT

In [None]:
import time

times = []
accs = []
max_epochs = 100


# ResNet-12 + ProtoNet + SOT: train a fresh model per seed, then evaluate once
for seed in range(1, 5):
  # the loop variable was never used before; seed the RNGs so every run is
  # reproducible and the runs actually differ by seed
  torch.manual_seed(seed)
  np.random.seed(seed)

  train_loader = get_dataloader(set_name='train'
                              , num_episodes=1
                              , train_way = 2
                              , val_way = 2
                              , num_shot = 5
                              , num_query =20
                              , data_path = dataset_path
                              ,dataset = "cifar"
                              , backbone = 'resnet12'
                              , constant = False
                              ,augment = False)

  val_loader = get_dataloader(set_name='val'
                                , num_episodes=8
                                , train_way = 2
                                , val_way = 2
                                , num_shot = 5
                                , num_query =20
                                , data_path = dataset_path
                                , dataset = "cifar"
                                , backbone = 'resnet12'
                                , constant = True
                                ,augment = False)

  model = get_model('resnet12', 32)
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  # pass the configured decay factor; it was defined but never used
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=gamma)
  sot = SOT()
  method = get_method("proto_sot", sot, num_way, num_way, num_shot, num_query_train)

  # few-shot labels
  train_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_train, num_shot=num_shot)
  val_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_test, num_shot=num_shot)
  criterion = torch.nn.CrossEntropyLoss()

  # only the training phase is timed
  start_time = time.time()
  for epoch in range(max_epochs):
    print(f"Epoch {epoch}/{max_epochs}: ")
    # train
    train_one_epoch(model, train_loader, optimizer, method, criterion, train_labels, log_step=False , epoch=epoch, device = device)
    print('------------------------------------------------------------------------------------')
    if scheduler is not None:
        scheduler.step()
  end_time = time.time()
  times.append(end_time - start_time)

  print('***********************************************************************************')
  result = eval_one_epoch(model, val_loader, method, criterion, val_labels, epoch = epoch, device = device)
  print('***********************************************************************************')
  accs.append(result['val/accuracy'])

times = np.array(times)
accs = np.array(accs)
# report the real number of runs (the message used to claim 5 while the loop ran 4)
print('Acc over %d instances: %.2f +- %.2f'%(len(accs), accs.mean(),accs.std()))
print(f"Average Time over {len(times)} instances: {times.mean()} +-{times.std()}")


Epoch 0/100: 
train/accuracy: 0.7000
train/loss: 0.7659
train/epoch: 0.0000
------------------------------------------------------------------------------------
Epoch 1/100: 
train/accuracy: 0.7750
train/loss: 0.3665
train/epoch: 1.0000
------------------------------------------------------------------------------------
Epoch 2/100: 
train/accuracy: 0.5250
train/loss: 0.9087
train/epoch: 2.0000
------------------------------------------------------------------------------------
Epoch 3/100: 
train/accuracy: 0.5250
train/loss: 0.8493
train/epoch: 3.0000
------------------------------------------------------------------------------------
Epoch 4/100: 
train/accuracy: 0.5750
train/loss: 0.6838
train/epoch: 4.0000
------------------------------------------------------------------------------------
Epoch 5/100: 
train/accuracy: 0.7750
train/loss: 0.4782
train/epoch: 5.0000
------------------------------------------------------------------------------------
Epoch 6/100: 
train/accuracy: 0.90

WRN + PTMap + SOT

In [None]:
# imported here so the cell does not depend on an earlier cell having run;
# `from time import time` elsewhere in the notebook would otherwise shadow the module
import time

accs = []
times = []
max_epochs = 100


# WRN + PT-MAP + SOT: train a fresh model per seed, then evaluate once
for seed in range(1, 5):
  # the loop variable was never used before; seed the RNGs so every run is
  # reproducible and the runs actually differ by seed
  torch.manual_seed(seed)
  np.random.seed(seed)

  train_loader = get_dataloader(set_name='train'
                              , num_episodes=1
                              , train_way = 2
                              , val_way = 2
                              , num_shot = 5
                              , num_query =20
                              , data_path = dataset_path
                              ,dataset = "cifar"
                              , backbone = 'wrn'
                              , constant = False
                              ,augment = False)

  val_loader = get_dataloader(set_name='val'
                                , num_episodes=8
                                , train_way = 2
                                , val_way = 2
                                , num_shot = 5
                                , num_query =20
                                , data_path = dataset_path
                                , dataset = "cifar"
                                , backbone = 'wrn'
                                , constant = True
                                ,augment = False)

  model = get_model('wrn', 32)
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  # pass the configured decay factor; it was defined but never used
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=gamma)
  sot = SOT()
  # SOT is applied because `sot` is passed, even though the key is plain "pt_map"
  method = get_method("pt_map", sot, num_way, num_way, num_shot, num_query_train)

  # few-shot labels
  train_labels = get_fs_labels(method="pt_map" ,num_way = num_way , num_query=num_query_train, num_shot=num_shot)
  val_labels = get_fs_labels(method="pt_map" ,num_way = num_way , num_query=num_query_test, num_shot=num_shot)
  criterion = torch.nn.CrossEntropyLoss()

  # only the training phase is timed
  start_time = time.time()
  for epoch in range(max_epochs):
    print(f"Epoch {epoch}/{max_epochs}: ")
    # train
    train_one_epoch(model, train_loader, optimizer, method, criterion, train_labels, log_step=False , epoch=epoch, device = device)
    print('------------------------------------------------------------------------------------')
    if scheduler is not None:
        scheduler.step()

  end_time = time.time()
  times.append(end_time - start_time)

  print('***********************************************************************************')
  result = eval_one_epoch(model, val_loader, method, criterion, val_labels, epoch = epoch, device = device)
  print('***********************************************************************************')
  accs.append(result['val/accuracy'])

times = np.array(times)
accs = np.array(accs)
# report the real number of runs (the message used to claim 5 while the loop ran 4)
print('Acc over %d instances: %.2f +- %.2f'%(len(accs), accs.mean(),accs.std()))
print(f"Average Time over {len(times)} instances: {times.mean()} +-{times.std()}")


Epoch 0/100: 
train/accuracy: 0.5250
train/loss: 2.4627
train/epoch: 0.0000
------------------------------------------------------------------------------------
Epoch 1/100: 
train/accuracy: 0.5000
train/loss: 0.7139
train/epoch: 1.0000
------------------------------------------------------------------------------------
Epoch 2/100: 
train/accuracy: 0.4750
train/loss: 0.7076
train/epoch: 2.0000
------------------------------------------------------------------------------------
Epoch 3/100: 
train/accuracy: 0.3500
train/loss: 0.7573
train/epoch: 3.0000
------------------------------------------------------------------------------------
Epoch 4/100: 
train/accuracy: 0.8000
train/loss: 0.6360
train/epoch: 4.0000
------------------------------------------------------------------------------------
Epoch 5/100: 
train/accuracy: 0.5750
train/loss: 0.6763
train/epoch: 5.0000
------------------------------------------------------------------------------------
Epoch 6/100: 
train/accuracy: 0.72

WRN + Proto + SOT

In [None]:
# imported here so the cell does not depend on an earlier cell having run;
# `from time import time` elsewhere in the notebook would otherwise shadow the module
import time

accs = []
times = []
max_epochs = 100


# WRN + ProtoNet + SOT: train a fresh model per seed, then evaluate once
for seed in range(1, 5):
  # the loop variable was never used before; seed the RNGs so every run is
  # reproducible and the runs actually differ by seed
  torch.manual_seed(seed)
  np.random.seed(seed)

  train_loader = get_dataloader(set_name='train'
                              , num_episodes=1
                              , train_way = 2
                              , val_way = 2
                              , num_shot = 5
                              , num_query =20
                              , data_path = dataset_path
                              ,dataset = "cifar"
                              , backbone = 'wrn'
                              , constant = False
                              ,augment = False)

  val_loader = get_dataloader(set_name='val'
                                , num_episodes=8
                                , train_way = 2
                                , val_way = 2
                                , num_shot = 5
                                , num_query =20
                                , data_path = dataset_path
                                , dataset = "cifar"
                                , backbone = 'wrn'
                                , constant = True
                                ,augment = False)

  model = get_model('wrn', 32)
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  # pass the configured decay factor; it was defined but never used
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=gamma)
  sot = SOT()
  method = get_method("proto_sot", sot, num_way, num_way, num_shot, num_query_train)

  # few-shot labels
  train_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_train, num_shot=num_shot)
  val_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_test, num_shot=num_shot)
  criterion = torch.nn.CrossEntropyLoss()

  # only the training phase is timed
  start_time = time.time()
  for epoch in range(max_epochs):
    print(f"Epoch {epoch}/{max_epochs}: ")
    # train
    train_one_epoch(model, train_loader, optimizer, method, criterion, train_labels, log_step=False , epoch=epoch, device = device)
    print('------------------------------------------------------------------------------------')
    if scheduler is not None:
        scheduler.step()

  end_time = time.time()
  times.append(end_time - start_time)

  print('***********************************************************************************')
  result = eval_one_epoch(model, val_loader, method, criterion, val_labels, epoch = epoch, device = device)
  print('***********************************************************************************')
  accs.append(result['val/accuracy'])

times = np.array(times)
accs = np.array(accs)
# report the real number of runs (the message used to claim 5 while the loop ran 4)
print('Acc over %d instances: %.2f +- %.2f'%(len(accs), accs.mean(),accs.std()))
print(f"Average Time over {len(times)} instances: {times.mean()} +-{times.std()}")


Epoch 0/100: 
train/accuracy: 0.9000
train/loss: 1.3140
train/epoch: 0.0000
------------------------------------------------------------------------------------
Epoch 1/100: 
train/accuracy: 0.4500
train/loss: 1.5551
train/epoch: 1.0000
------------------------------------------------------------------------------------
Epoch 2/100: 
train/accuracy: 0.5750
train/loss: 0.7137
train/epoch: 2.0000
------------------------------------------------------------------------------------
Epoch 3/100: 
train/accuracy: 0.5250
train/loss: 0.7062
train/epoch: 3.0000
------------------------------------------------------------------------------------
Epoch 4/100: 
train/accuracy: 0.7250
train/loss: 0.6101
train/epoch: 4.0000
------------------------------------------------------------------------------------
Epoch 5/100: 
train/accuracy: 0.5000
train/loss: 1.0500
train/epoch: 5.0000
------------------------------------------------------------------------------------
Epoch 6/100: 
train/accuracy: 0.55

# Challenge 2

In [23]:
def _strip_prefix(state_dict: dict, prefix: str, keep_unprefixed: bool) -> dict:
    """Return a copy of `state_dict` with `prefix` removed from the keys that carry it.

    Keys without the prefix are kept only when `keep_unprefixed` is true; on a name
    clash the entry that carried the prefix wins.
    """
    stripped = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    if not keep_unprefixed:
        return stripped
    merged = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
    merged.update(stripped)
    return merged


def load_weights(model: torch.nn.Module, pretrained_path: str):
    """
    Load pretrained weights from given path.

    Three checkpoint layouts are recognised:
      * {'state': ...}  - a leading 'module.' is removed from the keys; loaded non-strictly.
      * {'params': ...} - only 'encoder.*' entries are kept (prefix removed); loaded strictly.
      * anything else   - treated as a plain state dict.

    :param model: module that receives the weights (modified in place).
    :param pretrained_path: checkpoint file; a falsy value skips loading.
    :return: the same `model` object.
    """
    if not pretrained_path:
        return model

    print(f'Loading weights from {pretrained_path}')
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    # map_location='cpu' lets checkpoints saved on a GPU be read on CPU-only hosts;
    # load_state_dict copies the values onto the model's own device afterwards.
    state_dict = torch.load(pretrained_path, map_location='cpu')
    sd_keys = list(state_dict.keys())
    if 'state' in sd_keys:
        state_dict = _strip_prefix(state_dict['state'], 'module.', keep_unprefixed=True)
        model.load_state_dict(state_dict, strict=False)

    elif 'params' in sd_keys:
        # everything that is not part of the encoder is dropped
        state_dict = _strip_prefix(state_dict['params'], 'encoder.', keep_unprefixed=False)
        model.load_state_dict(state_dict, strict=True)
    else:
        model.load_state_dict(state_dict)

    print("Weights loaded successfully ")
    return model

In [24]:
# build the ResNet-12 backbone on the target device and initialise it from the checkpoint
weights_path = './feat-5-shot.pth'
model = get_model('resnet12', 32)
model.to(device)
load_weights(model, weights_path)

Loading weights from ./feat-5-shot.pth
Weights loaded successfully 


ResNet(
  (layer1): Sequential(
    (0): ResBasicBlock(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): LeakyReLU(negative_slope=0.1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (downsample): Sequential(
        (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (DropBlock): DropBlock()
    )
  )
  (

In [31]:
accs = []
max_epochs = 10


# fine-tune the pretrained ResNet-12 with ProtoNet + SOT, once per seed
for seed in range(1, 5):
  # the loop variable was never used before; seed the RNGs so every run is
  # reproducible and the runs actually differ by seed
  torch.manual_seed(seed)
  np.random.seed(seed)

  train_loader = get_dataloader(set_name='train'
                              , num_episodes=1
                              , train_way = 2
                              , val_way = 2
                              , num_shot = 5
                              , num_query =20
                              , data_path = dataset_path
                              ,dataset = "cifar"
                              , backbone = 'resnet12'
                              , constant = False
                              ,augment = False)

  val_loader = get_dataloader(set_name='val'
                                , num_episodes=8
                                , train_way = 2
                                , val_way = 2
                                , num_shot = 5
                                , num_query =20
                                , data_path = dataset_path
                                , dataset = "cifar"
                                , backbone = 'resnet12'
                                , constant = True
                                ,augment = False)

  # start every run from the same pretrained checkpoint; previously the single
  # model object kept being fine-tuned across runs, so the runs were not independent
  model = get_model('resnet12', 32)
  model.to(device)
  load_weights(model=model, pretrained_path=weights_path)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  # pass the configured decay factor; it was defined but never used
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=gamma)
  sot = SOT()
  method = get_method("proto_sot", sot, num_way, num_way, num_shot, num_query_train)

  # few-shot labels
  train_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_train, num_shot=num_shot)
  val_labels = get_fs_labels(method="proto_sot" ,num_way = num_way , num_query=num_query_test, num_shot=num_shot)
  criterion = torch.nn.CrossEntropyLoss()

  for epoch in range(max_epochs):
    print(f"Epoch {epoch}/{max_epochs}: ")
    # train
    train_one_epoch(model, train_loader, optimizer, method, criterion, train_labels, log_step=False , epoch=epoch, device = device)
    print('------------------------------------------------------------------------------------')
    if scheduler is not None:
        scheduler.step()

  print('***********************************************************************************')
  result = eval_one_epoch(model, val_loader, method, criterion, val_labels, epoch = epoch, device = device)
  print('***********************************************************************************')
  accs.append(result['val/accuracy'])

accs = np.array(accs)
# report the real number of runs (the message used to claim 5 while the loop ran 4)
print('Acc over %d instances: %.2f +- %.2f'%(len(accs), accs.mean(),accs.std()))


Epoch 0/10: 
train/accuracy: 0.8000
train/loss: 0.4792
train/epoch: 0.0000
------------------------------------------------------------------------------------
Epoch 1/10: 
train/accuracy: 0.8000
train/loss: 0.4844
train/epoch: 1.0000
------------------------------------------------------------------------------------
Epoch 2/10: 
train/accuracy: 0.4750
train/loss: 0.7628
train/epoch: 2.0000
------------------------------------------------------------------------------------
Epoch 3/10: 
train/accuracy: 0.8000
train/loss: 0.3110
train/epoch: 3.0000
------------------------------------------------------------------------------------
Epoch 4/10: 
train/accuracy: 0.5750
train/loss: 0.6592
train/epoch: 4.0000
------------------------------------------------------------------------------------
Epoch 5/10: 
train/accuracy: 0.8750
train/loss: 0.2793
train/epoch: 5.0000
------------------------------------------------------------------------------------
Epoch 6/10: 
train/accuracy: 0.4250
trai