Self-Supervised Relational Reasoning
------------------------------------------------------

Official PyTorch implementation of the paper: 

```
"Self-Supervised Relational Reasoning for Representation Learning", Patacchiola, M., and Storkey, A., *Advances in Neural Information Processing Systems (NeurIPS)*
```

This notebook presents an essential implementation of the method, which is modular and easy to extend. The code has been tested on `Ubuntu 18.04 LTS` with `PyTorch 1.4`, `Torchvision 0.5`, and `PIL 7.0`, so we suggest a similar configuration. First, we import the necessary modules and print their versions:

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import math

print(torch.__version__)
print(torchvision.__version__)
print(Image.__version__)

We subclass the torchvision CIFAR10 and CIFAR100 datasets with versions that return **multiple augmentations** of each input image. This is done by overriding the `__getitem__()` method. We add the parameter `K`, which specifies the number of augmentations to apply to the input image. The output of `__getitem__()` is a list of length `K` containing different augmentations of the input instance, together with the label of that image (the label is discarded during the self-supervised training phase).

In [3]:
class MultiCIFAR10(torchvision.datasets.CIFAR10):
  """Override torchvision CIFAR10 for multi-image management.
  A similar class can be defined for other datasets (e.g. CIFAR100).
  Given K total augmentations, it returns a list of length K with
  different augmentations of the input image.
  """
  def __init__(self, K, **kwds):
    super().__init__(**kwds)
    self.K = K # tot number of augmentations
            
  def __getitem__(self, index):
    img, target = self.data[index], self.targets[index]
    pic = Image.fromarray(img)            
    img_list = list()
    if self.transform is not None:
      for _ in range(self.K):
        img_transformed = self.transform(pic.copy())
        img_list.append(img_transformed)
    else:
      img_list = img
    return img_list, target

In [4]:
class MultiCIFAR100(torchvision.datasets.CIFAR100):
  """Override torchvision CIFAR100 for multi-image management.
  Given K total augmentations, it returns a list of length K with
  different augmentations of the input image.
  """
  def __init__(self, K, **kwds):
    super().__init__(**kwds)
    self.K = K # tot number of augmentations
            
  def __getitem__(self, index):
    img, target = self.data[index], self.targets[index]
    pic = Image.fromarray(img)            
    img_list = list()
    if self.transform is not None:
      for _ in range(self.K):
        img_transformed = self.transform(pic.copy())
        img_list.append(img_transformed)
    else:
      img_list = img
    return img_list, target
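To make the data layout concrete, here is a small sketch of what a `DataLoader` produces from a dataset that returns a list of `K` views per item, like the classes above. `ToyMultiView` is a hypothetical stand-in that fabricates noisy constant images instead of downloading CIFAR, so only the shapes and the ordering matter. It illustrates that PyTorch's default collation turns each mini-batch into a list of `K` tensors, which `torch.cat(..., 0)` in the training loop stacks into `K` contiguous blocks of the whole mini-batch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMultiView(Dataset):
    """Hypothetical stand-in for MultiCIFAR10: returns K noisy 'views' per item."""

    def __init__(self, n_items=10, K=4):
        self.n_items = n_items
        self.K = K

    def __len__(self):
        return self.n_items

    def __getitem__(self, index):
        base = torch.full((3, 32, 32), float(index))  # fake 3x32x32 image
        # K different "augmentations" of the same image (here: small added noise)
        views = [base + 0.01 * torch.randn_like(base) for _ in range(self.K)]
        return views, index  # the index plays the role of the CIFAR label


loader = DataLoader(ToyMultiView(n_items=10, K=4), batch_size=8, shuffle=False)
views, labels = next(iter(loader))

# Default collation: a list of K tensors per item becomes
# a list of K batched tensors, one per augmentation.
print(len(views), views[0].shape)  # 4 torch.Size([8, 3, 32, 32])

# Concatenating along dim 0 (as in the training loop) gives K contiguous blocks:
# rows [0:8] hold augmentation 0, rows [8:16] augmentation 1, and so on.
x = torch.cat(views, 0)
print(x.shape)  # torch.Size([32, 3, 32, 32])
```

This block layout is what the `aggregate()` method of the relational model relies on when it slices the backbone features in chunks of one mini-batch.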

Now we need a convolutional neural network (CNN) backbone, which is used as a preliminary stage for dimensionality reduction of the input images. Here, we define a simple 4-layer CNN, but any other network that outputs a flat feature vector can be used.

In [5]:
class Conv4(torch.nn.Module):
    """A simple 4-layer CNN.
    Used as backbone.    
    """
    def __init__(self):
        super(Conv4, self).__init__()
        self.feature_size = 64
        self.name = "conv4"

        self.layer1 = torch.nn.Sequential(
          torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False),
          torch.nn.BatchNorm2d(8),
          torch.nn.ReLU(),
          torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer2 = torch.nn.Sequential(
          torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=False),
          torch.nn.BatchNorm2d(16),
          torch.nn.ReLU(),
          torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = torch.nn.Sequential(
          torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
          torch.nn.BatchNorm2d(32),
          torch.nn.ReLU(),
          torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer4 = torch.nn.Sequential(
          torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
          torch.nn.BatchNorm2d(64),
          torch.nn.ReLU(),
          torch.nn.AdaptiveAvgPool2d(1)
        )

        self.flatten = torch.nn.Flatten()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        h = self.layer3(h)
        h = self.layer4(h)
        h = self.flatten(h)
        return h

Now we define the Relational Reasoning class. It consists of an `__init__()` method, an internal `aggregate()` method, and a `train()` routine.

The `__init__()` method takes the CNN backbone and the feature size, i.e. the number of output units of the backbone. At init time the relation head is created: a multi-layer perceptron (MLP) with 256 hidden units, batch normalization, and a leaky-ReLU activation function. The type of relation head is important: if the relation head is too complex, it can easily discriminate the relation pairs on its own, and as a result the backbone will not learn useful representations.
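The following sketch shows the shapes involved when the relation head scores concatenated pairs. It rebuilds an MLP with the same layer sizes as the relation head in the class definition (the variable names here are illustrative): because of the concatenation, its input has `2 * feature_size` units, and it outputs a single logit per pair, which a sigmoid turns into the probability that the two feature vectors come from the same image.

```python
import torch

feature_size = 64
# Same layer sizes as the relation head: 2*feature_size -> 256 -> 1
relation_head = torch.nn.Sequential(
    torch.nn.Linear(feature_size * 2, 256),
    torch.nn.BatchNorm1d(256),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(256, 1))

z_a = torch.randn(10, feature_size)  # features of 10 images, augmentation a
z_b = torch.randn(10, feature_size)  # features of the same images, augmentation b
pairs = torch.cat([z_a, z_b], 1)     # 'cat' aggregation: shape (10, 128)
scores = relation_head(pairs).squeeze(1)  # one logit per pair: shape (10,)
prob = torch.sigmoid(scores)         # estimated probability that each pair matches
```

Concatenation keeps the order of the two members, so `[z_a, z_b]` and `[z_b, z_a]` are different inputs for the head.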

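The `aggregate()` method takes as input the features produced by the backbone in the forward pass and `K`, the total number of augmentations. The features are arranged in `K` contiguous blocks, one per augmentation, each containing the whole mini-batch. For every pair of blocks, positive pairs join the features of the same image under two different augmentations, while negative pairs join features of different images, obtained by rolling the second block along the batch dimension. The output of the aggregation phase is the set of relation pairs, joined by concatenation, together with their binary targets (1 for positives, 0 for negatives). Other aggregation operators can be used (for example commutative ones such as sum or mean), but concatenation is the most effective.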

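The `train()` routine is an iterative learning schedule that optimizes the parameters of the backbone and relation head. It takes as input the total number of epochs and a train loader. Here we use the Binary Cross-Entropy (BCE) loss, but a Focal Loss can give a further boost. Note that this method shadows `torch.nn.Module.train()`: calling `model.eval()` on a `RelationalReasoning` object invokes `self.train(False)` and therefore fails, so the train/eval mode should be set on `model.backbone` and `model.relation_head` directly.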
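As a sketch of the Focal Loss alternative, the function below implements the standard binary focal loss on logits. It is a generic formulation, not necessarily the exact variant or hyper-parameters used by the authors; the name `binary_focal_loss_with_logits` and the default `gamma=2.0` are assumptions. The factor `(1 - p_t) ** gamma` down-weights pairs that the relation head already classifies confidently, and with `gamma=0` the loss reduces to the BCE used in `train()`.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss_with_logits(logits, targets, gamma=2.0, alpha=None):
    """Generic binary focal loss on raw logits (hypothetical helper).

    With gamma=0 and alpha=None it equals BCEWithLogitsLoss with mean reduction.
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true class
    loss = (1 - p_t) ** gamma * bce  # down-weight easy, well-classified pairs
    if alpha is not None:
        # optional class-balancing weight for positives (alpha) vs negatives
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()
```

Inside `train()`, the line `loss = BCE(score, targets)` would become `loss = binary_focal_loss_with_logits(score, targets)`.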

In [6]:
class RelationalReasoning(torch.nn.Module):
  """Self-Supervised Relational Reasoning.
  Essential implementation of the method, which uses
  the 'cat' aggregation function (the most effective),
  and can be used with any backbone.
  """
  def __init__(self, backbone, feature_size=64):
    super(RelationalReasoning, self).__init__()
    self.backbone = backbone
    self.relation_head = torch.nn.Sequential(
                             torch.nn.Linear(feature_size*2, 256),
                             torch.nn.BatchNorm1d(256),
                             torch.nn.LeakyReLU(),
                             torch.nn.Linear(256, 1))

  def aggregate(self, features, K):
    relation_pairs_list = list()
    targets_list = list()
    size = int(features.shape[0] / K)
    shifts_counter=1
    for index_1 in range(0, size*K, size):
      for index_2 in range(index_1+size, size*K, size):
        # Using the 'cat' aggregation function by default
        pos_pair = torch.cat([features[index_1:index_1+size], 
                              features[index_2:index_2+size]], 1)
        # Shuffle without collisions by rolling the mini-batch (negatives)
        neg_pair = torch.cat([
                     features[index_1:index_1+size], 
                     torch.roll(features[index_2:index_2+size], 
                     shifts=shifts_counter, dims=0)], 1)
        relation_pairs_list.append(pos_pair)
        relation_pairs_list.append(neg_pair)
        targets_list.append(torch.ones(size, dtype=torch.float32))
        targets_list.append(torch.zeros(size, dtype=torch.float32))
        shifts_counter+=1
        if(shifts_counter>=size): 
            shifts_counter=1 # avoid identity pairs
    relation_pairs = torch.cat(relation_pairs_list, 0)
    targets = torch.cat(targets_list, 0)
    return relation_pairs, targets

  def train(self, tot_epochs, train_loader):
    optimizer = torch.optim.Adam([
                  {'params': self.backbone.parameters()},
                  {'params': self.relation_head.parameters()}])                               
    BCE = torch.nn.BCEWithLogitsLoss()
    self.backbone.train()
    self.relation_head.train()
    for epoch in range(tot_epochs):
      # the real target is discarded (unsupervised)
      for i, (data_augmented, _) in enumerate(train_loader):
        K = len(data_augmented) # tot augmentations
        x = torch.cat(data_augmented, 0)
        optimizer.zero_grad()              
        # forward pass (backbone)
        features = self.backbone(x) 
        # aggregation function
        relation_pairs, targets = self.aggregate(features, K)
        # forward pass (relation head)
        score = self.relation_head(relation_pairs).squeeze()        
        # cross-entropy loss and backward
        loss = BCE(score, targets)
        loss.backward()
        optimizer.step()            
        # estimate the accuracy
        predicted = torch.round(torch.sigmoid(score))
        correct = predicted.eq(targets.view_as(predicted)).sum()
        accuracy = (100.0 * correct / float(len(targets)))
        
        if(i%100==0):
          print('Epoch [{}][{}/{}] loss: {:.5f}; accuracy: {:.2f}%' \
            .format(epoch+1, i+1, len(train_loader)+1, 
                    loss.item(), accuracy.item()))

Unsupervised training
--------------------------------

In this section we use self-supervised relational reasoning to train a Conv-4 backbone on CIFAR-10 without using the labels.

In the next cell we define some hyper-parameters: the total number of augmentations `K`, the mini-batch size, the number of epochs `tot_epochs`, and the feature size (which depends on the backbone). The number of relation pairs, and hence the cost of the relation head, grows quadratically with `K` (there are `K(K-1)/2` pairs of augmentations), therefore here we use a small value just for checking the code. In the paper we used `K=32` for the CIFAR-10 and CIFAR-100 experiments, with a mini-batch of 64. In this example we train the relational model for just *10 epochs*, whereas in the paper we trained for 200 epochs. We invite the reader to experiment with the hyper-parameters.
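The quadratic growth can be made concrete with a little arithmetic. The hypothetical helper below counts the relation pairs that `aggregate()` builds for one full mini-batch: each of the `K(K-1)/2` pairs of augmentation blocks contributes `batch_size` positive and `batch_size` negative pairs.

```python
def num_relation_pairs(batch_size, K):
    """Relation pairs built by aggregate() for one full mini-batch (hypothetical helper).

    Each of the K*(K-1)/2 pairs of augmentation blocks yields batch_size
    positive and batch_size negative pairs.
    """
    return K * (K - 1) // 2 * 2 * batch_size


print(num_relation_pairs(64, 4))   # 768   (setting of this notebook)
print(num_relation_pairs(64, 32))  # 63488 (setting reported for the paper)
```

Going from `K=4` to `K=32` therefore multiplies the number of pairs scored by the relation head by roughly 83.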

In [7]:
# Hyper-parameters of the simulation
K = 4 # tot augmentations, in the paper K=32 for CIFAR10/100
batch_size = 64 # 64 has been used in the paper
tot_epochs = 10 # 200 has been used in the paper
feature_size = 64 # number of units for the Conv4 backbone

Below we define the augmentation strategy. Note that here we use the CIFAR-10 normalization values, which must be changed if CIFAR-100 is used.
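Per-channel normalization constants like the CIFAR-10 and CIFAR-100 values in the next cell are commonly obtained as the mean and standard deviation of each color channel over all training pixels, scaled to [0, 1]. The sketch below shows one way to compute them with NumPy; `channel_stats` is a hypothetical helper, and whether the notebook's constants were computed exactly this way is an assumption (small differences can arise from rounding or from averaging per-image statistics instead).

```python
import numpy as np


def channel_stats(images_uint8):
    """Per-channel mean and std of uint8 images shaped (N, H, W, C), on a [0, 1] scale."""
    x = images_uint8.astype(np.float64) / 255.0
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))


# Synthetic example; for CIFAR one could pass the `data` attribute of a
# torchvision CIFAR10 dataset, a uint8 array of shape (50000, 32, 32, 3).
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
mean, std = channel_stats(fake)
print(mean.round(3), std.round(3))
```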

In [8]:
# These are the transformations used in the paper
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], 
                                 std=[0.247, 0.243, 0.262]) # CIFAR10
#normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], 
#                                 std=[0.267, 0.256, 0.276]) # CIFAR100

color_jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8, 
                                      saturation=0.8, hue=0.2)
rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
rnd_gray = transforms.RandomGrayscale(p=0.2)
rnd_rcrop = transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0), 
                                         interpolation=2)
rnd_hflip = transforms.RandomHorizontalFlip(p=0.5)
train_transform = transforms.Compose([rnd_rcrop, rnd_hflip,
                                      rnd_color_jitter, rnd_gray, 
                                      transforms.ToTensor(), normalize])

Next we define the backbone, the model, and the train loader. The dataset is downloaded if it is not already available.

In [9]:
backbone = Conv4() # simple CNN with 64 linear output units
model = RelationalReasoning(backbone, feature_size)    
train_set = MultiCIFAR10(K=K, root='data', train=True, 
                         transform=train_transform, 
                         download=True)
train_loader = torch.utils.data.DataLoader(train_set, 
                                           batch_size=batch_size, 
                                           shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz



Extracting data/cifar-10-python.tar.gz to data


The next cell trains the model for the specified number of epochs and then saves the backbone to a local file. This may take a while, depending on your hardware, the value of `K`, and the number of epochs. The code can be adapted to run on a GPU with a few changes. The following cell takes ~30 minutes to complete the 10 training epochs on a medium-level laptop (with no GPU acceleration).
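A minimal sketch of the changes needed for GPU training, written so that it also runs on a CPU-only machine. Calling `model.to(device)` moves both the backbone and the relation head, since they are sub-modules; inside `train()` the augmented views must also be moved to the device, and the targets that `aggregate()` builds with `torch.ones(size, ...)` and `torch.zeros(size, ...)` must be created on the same device as the features, otherwise the loss raises a device-mismatch error. The helpers `views_to_device` and `pair_targets` are hypothetical names for these two steps.

```python
import torch

# Pick the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def views_to_device(data_augmented, device):
    """Move the K augmented views to `device` and stack them, as train() does."""
    return torch.cat([v.to(device, non_blocking=True) for v in data_augmented], 0)


def pair_targets(size, features):
    """Positive/negative targets created on the same device as the features."""
    ones = torch.ones(size, dtype=torch.float32, device=features.device)
    zeros = torch.zeros(size, dtype=torch.float32, device=features.device)
    return ones, zeros


views = [torch.randn(8, 3, 32, 32) for _ in range(4)]  # fake mini-batch, K=4
x = views_to_device(views, device)
pos_t, neg_t = pair_targets(8, x)
print(x.shape, x.device == pos_t.device)  # torch.Size([32, 3, 32, 32]) True
```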

In [10]:
model.train(tot_epochs=tot_epochs, train_loader=train_loader)
torch.save(model.backbone.state_dict(), './backbone.tar')

Epoch [1][1/783] loss: 0.70595; accuracy: 50.52%

Epoch [1][101/783] loss: 0.64139; accuracy: 63.41%
Epoch [1][201/783] loss: 0.53704; accuracy: 72.27%
Epoch [1][301/783] loss: 0.52131; accuracy: 75.65%
Epoch [1][401/783] loss: 0.51761; accuracy: 75.91%
Epoch [1][501/783] loss: 0.53399; accuracy: 73.31%
Epoch [1][601/783] loss: 0.52216; accuracy: 75.13%
Epoch [1][701/783] loss: 0.48697; accuracy: 76.95%
Epoch [2][1/783] loss: 0.52206; accuracy: 73.57%
Epoch [2][101/783] loss: 0.47944; accuracy: 75.91%
Epoch [2][201/783] loss: 0.48841; accuracy: 76.56%
Epoch [2][301/783] loss: 0.45724; accuracy: 77.99%
Epoch [2][401/783] loss: 0.52364; accuracy: 73.18%
Epoch [2][501/783] loss: 0.45385; accuracy: 78.39%
Epoch [2][601/783] loss: 0.48026; accuracy: 74.74%
Epoch [2][701/783] loss: 0.44490; accuracy: 79.04%
Epoch [3][1/783] loss: 0.43150; accuracy: 79.30%
Epoch [3][101/783] loss: 0.49314; accuracy: 77.99%
Epoch [3][201/783] loss: 0.43615; accuracy: 79.95%
Epoch [3][301/783] loss: 0.50594; ac

KeyboardInterrupt: 

Linear evaluation
------------------------

Once the model has been trained and the backbone saved, we can use the backbone for downstream tasks such as classification or regression. Here, we perform a **linear evaluation** test: we take the backbone, stack a linear layer on top of it, and train only the weights of the linear classifier (no backpropagation through the backbone). This allows us to check the quality of the representations, and how close they are to the fully supervised upper bound. We perform linear evaluation on the same dataset used for self-supervised training (CIFAR-10), this time using its labels.
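The linear-evaluation code in this notebook freezes the backbone by putting it in `eval()` mode and calling `.detach()` on its output. An equivalent, slightly more explicit pattern is sketched below on toy modules (the tiny `backbone` and `linear_layer` are hypothetical stand-ins for `Conv4` and the classifier): `requires_grad_(False)` marks the backbone parameters as frozen, and `torch.no_grad()` avoids building an autograd graph for the backbone forward pass, which saves memory and time. The final check confirms that only the linear layer is updated.

```python
import torch

# Toy stand-ins (hypothetical): a tiny "backbone" and a linear classifier.
backbone = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))
linear_layer = torch.nn.Linear(4, 3)

backbone.eval()                 # BatchNorm uses its running statistics
backbone.requires_grad_(False)  # no gradients for the backbone parameters
optimizer = torch.optim.Adam(linear_layer.parameters())

data, target = torch.randn(16, 8), torch.randint(0, 3, (16,))
before = [p.clone() for p in backbone.parameters()]
w_before = linear_layer.weight.detach().clone()

with torch.no_grad():           # no autograd graph for the frozen forward pass
    features = backbone(data)
loss = torch.nn.functional.cross_entropy(linear_layer(features), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The backbone parameters are unchanged after the optimization step
print(all(torch.equal(b, p) for b, p in zip(before, backbone.parameters())))  # True
```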

In [28]:
# no augmentations used for linear evaluation
transform_lineval = transforms.Compose([transforms.ToTensor(), normalize]) 
train_set_lineval = torchvision.datasets.CIFAR10('data', train=True, transform=transform_lineval)
test_set_lineval = torchvision.datasets.CIFAR10('data', train=False, transform=transform_lineval)
train_loader_lineval = torch.utils.data.DataLoader(train_set_lineval, batch_size=128, shuffle=True)
test_loader_lineval = torch.utils.data.DataLoader(test_set_lineval, batch_size=128, shuffle=False)
# 64 is the number of output features of the backbone, and 10 the number of classes
linear_layer = torch.nn.Linear(64, 10)
# loading the saved backbone
backbone_lineval = Conv4() #defining a raw backbone model
checkpoint = torch.load('./backbone.tar')
backbone_lineval.load_state_dict(checkpoint)

<All keys matched successfully>

Here we train the linear classifier on the labeled CIFAR-10 training set for 10 epochs. This phase is much faster because only the linear layer is being optimized. This cell takes ~2 minutes to complete the 10 epochs (medium-level laptop with no GPU acceleration).

In [37]:
optimizer = torch.optim.Adam(linear_layer.parameters())                               
CE = torch.nn.CrossEntropyLoss()
linear_layer.train()
backbone_lineval.eval()

print('Linear evaluation')
for epoch in range(10):
    accuracy_list = list()
    for i, (data, target) in enumerate(train_loader_lineval):
        optimizer.zero_grad()
        output = backbone_lineval(data).detach()
        output = linear_layer(output)
        loss = CE(output, target)
        loss.backward()
        optimizer.step()
        # estimate the accuracy
        prediction = output.argmax(-1)
        correct = prediction.eq(target.view_as(prediction)).sum()
        accuracy = (100.0 * correct / len(target))
        accuracy_list.append(accuracy.item())
    print('Epoch [{}] loss: {:.5f}; accuracy: {:.2f}%' \
            .format(epoch+1, loss.item(), sum(accuracy_list)/len(accuracy_list)))          

Linear evaluation
Epoch [1] loss: 1.47552; accuracy: 52.99%
Epoch [2] loss: 1.22791; accuracy: 52.98%
Epoch [3] loss: 1.44275; accuracy: 53.03%
Epoch [4] loss: 1.32339; accuracy: 53.02%
Epoch [5] loss: 1.47076; accuracy: 53.05%
Epoch [6] loss: 1.33488; accuracy: 53.07%
Epoch [7] loss: 1.22281; accuracy: 53.15%
Epoch [8] loss: 1.29908; accuracy: 53.11%
Epoch [9] loss: 1.34019; accuracy: 53.12%
Epoch [10] loss: 1.29688; accuracy: 53.16%


We can now test both backbone and linear layer on the test set of CIFAR-10.

In [48]:
accuracy_list = list()
for i, (data, target) in enumerate(test_loader_lineval):
    output = backbone_lineval(data).detach()
    output = linear_layer(output)
    # estimate the accuracy
    prediction = output.argmax(-1)
    correct = prediction.eq(target.view_as(prediction)).sum()
    accuracy = (100.0 * correct / len(target))
    accuracy_list.append(accuracy.item())

print('Test accuracy: {:.2f}%'.format(sum(accuracy_list)/len(accuracy_list)))

Test accuracy: 52.33%
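Two details of the test cell above are worth noting. It averages per-batch accuracies, so the last batch (10,000 test images with batch size 128 leave a final batch of 16) counts as much as a full one, and it runs without `torch.no_grad()`, so an autograd graph is built for the linear layer. The hypothetical `evaluate` helper below pools correct predictions over all samples instead. The toy check uses an identity "backbone" and a linear head that always predicts class 0, with batches of 8 and 4 samples, to show the difference.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate(backbone, head, loader):
    """Accuracy (%) pooled over all samples, without building autograd graphs."""
    backbone.eval()
    head.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in loader:
            prediction = head(backbone(data)).argmax(-1)
            correct += prediction.eq(target).sum().item()
            total += target.numel()
    return 100.0 * correct / total


# Toy check: an identity head on inputs [1, 0] always predicts class 0.
head = torch.nn.Linear(2, 2)
with torch.no_grad():
    head.weight.copy_(torch.eye(2))
    head.bias.zero_()
x = torch.tensor([[1.0, 0.0]] * 12)
y = torch.tensor([0] * 8 + [1] * 4)  # first batch all correct, second all wrong
loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=False)
print(evaluate(torch.nn.Identity(), head, loader))  # about 66.67 (8 of 12 correct)
```

Averaging the two per-batch accuracies in the toy check would give 50.0 instead. With the notebook's objects the call would be `evaluate(backbone_lineval, linear_layer, test_loader_lineval)`, and the result may differ slightly from the per-batch average printed above.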


Note that the test accuracy is well above chance level (10% on CIFAR-10), meaning that during self-supervised training the backbone was able to learn useful representations without accessing the labels.