# Supervised Contrastive Learning

**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>
**Date created:** 2020/11/30<br>
**Last modified:** 2020/11/30<br>
**Description:** Using supervised contrastive learning for image classification.

## Introduction

[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)
(Prannay Khosla et al.) is a training methodology that, according to the paper,
outperforms supervised training with cross-entropy on classification tasks.

Essentially, training an image classification model with Supervised Contrastive
Learning is performed in two phases:

1. Training an encoder to produce vector representations of input images such
that representations of images in the same class are more similar to each other
than to representations of images in different classes.
2. Training a classifier on top of the frozen encoder.

Note that this example requires [TensorFlow Addons](https://www.tensorflow.org/addons). The data preparation step below also uses torchvision and scikit-learn. You can install them using the following command:

```shell
pip install tensorflow-addons torchvision scikit-learn
```

## Setup

In [40]:
!pip install tensorflow-addons



In [41]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

## Prepare the data

In [42]:
num_classes = 10
input_shape = (32, 32, 3)

# # Load the train and test data splits
# (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# # Display shapes of train and test datasets
# print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
# print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")


In [43]:
# import random
# from torchvision import datasets

# # Load CIFAR-10 training dataset
# cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True)

# # Load CIFAR-10 testing dataset
# cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)

# # Concatenate the training and testing datasets
# cifar10_combined = cifar10_train + cifar10_test


# # Convert CIFAR-10 dataset into a list of tuples containing pixel values
# cifar10_combined_list = [(image_tensor, label) for image_tensor, label in cifar10_combined]

# # Shuffle the combined dataset
# random.shuffle(cifar10_combined_list)

# # Calculate the number of samples for the first part (k%)
# k_percentage = 10  # You can adjust this value as needed
# k_samples = int(len(cifar10_combined_list) * k_percentage / 100)

# # Split the combined dataset into two parts
# cifar10_first_part = cifar10_combined_list[:k_samples]
# cifar10_second_part = cifar10_combined_list[k_samples:]

# # Optionally, convert the lists back into datasets if needed
# cifar10_first_part_dataset = CustomCIFAR10Dataset(cifar10_first_part)
# cifar10_second_part_dataset = CustomCIFAR10Dataset(cifar10_second_part)

In [44]:
# import torch
# from torchvision.transforms.functional import to_tensor

# # Initialize lists to store images and labels
# x_train = []
# y_train = []

# # Iterate through cifar10_first_part to extract image tensors and labels
# for image, label in cifar10_first_part:
#     # Convert image to tensor
#     image_tensor = to_tensor(image)
#     # Append image tensor to x_train
#     x_train.append(image_tensor)
#     # Append label to y_train
#     y_train.append(label)

# # Convert lists to tensors
# x_train = torch.stack(x_train)  # Convert list of tensors to a tensor
# y_train = torch.tensor(y_train)  # Convert list of labels to a tensor

# import torch
# from torchvision.transforms.functional import to_tensor

# # Initialize lists to store images and labels
# x_test = []
# y_test = []

# # Iterate through cifar10_second_part to extract image tensors and labels
# for image, label in cifar10_second_part:
#     # Convert image to tensor
#     image_tensor = to_tensor(image)
#     # Append image tensor to x_test
#     x_test.append(image_tensor)
#     # Append label to y_test
#     y_test.append(label)

# # Convert lists to tensors
# x_test = torch.stack(x_test)  # Convert list of tensors to a tensor
# y_test = torch.tensor(y_test)  # Convert list of labels to a tensor



In [45]:
import numpy as np
from sklearn.model_selection import train_test_split
from torchvision import datasets

# Load CIFAR-10 training dataset
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True)

# Load CIFAR-10 testing dataset
cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)

# Concatenate the training and testing datasets
cifar10_combined = cifar10_train + cifar10_test


# Percentage of the combined dataset used for training
k = 40  # Change this to the desired percentage split

# Convert each PIL image to a NumPy array and collect the labels
x_data = [np.array(image) for image, _ in cifar10_combined]
y_data = [label for _, label in cifar10_combined]

# Randomly split into a training part (k%) and a test part (the remainder)
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, train_size=k/100, random_state=42)

# Convert to numpy arrays
x_train = np.array(x_train)
x_test = np.array(x_test)
y_train = np.array(y_train)
y_test = np.array(y_test)

# Print shapes to verify
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print("x_test shape:", x_test.shape)
print("y_test shape:", y_test.shape)


Files already downloaded and verified
Files already downloaded and verified
x_train shape: (24000, 32, 32, 3)
y_train shape: (24000,)
x_test shape: (36000, 32, 32, 3)
y_test shape: (36000,)
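
Because the split above is random rather than stratified, the class proportions of the two parts can differ slightly. The following is a minimal sketch of a balance check. `class_fractions` and `demo_labels` are hypothetical names, and the synthetic labels merely stand in for `y_train`.

```python
import numpy as np


def class_fractions(labels, num_classes):
    """Return the fraction of samples that fall into each class."""
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()


# Hypothetical stand-in for y_train: 10 classes with 100 samples each.
demo_labels = np.repeat(np.arange(10), 100)
fractions = class_fractions(demo_labels, num_classes=10)
print(fractions)
```

If exact balance matters, `train_test_split` also accepts a `stratify` argument (for example `stratify=y_data`), which keeps the class proportions the same in both parts.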


## Using image data augmentation

In [46]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.02),
    ]
)

# Setting the state of the normalization layer.
data_augmentation.layers[0].adapt(x_train)
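
To make the `adapt` step less opaque, here is a rough NumPy sketch of what the normalization is assumed to compute with the layer's default `axis=-1`: one mean and one variance per color channel, estimated from the data, then used to standardize every pixel. The function names and the random `demo_images` are illustrative only and are not part of the Keras API.

```python
import numpy as np


def fit_channel_stats(images):
    """Per-channel mean and variance over batch, height and width."""
    images = images.astype("float32")
    mean = images.mean(axis=(0, 1, 2))
    var = images.var(axis=(0, 1, 2))
    return mean, var


def normalize(images, mean, var, eps=1e-7):
    """Standardize each channel; eps guards against division by zero."""
    return (images.astype("float32") - mean) / np.maximum(np.sqrt(var), eps)


# Hypothetical stand-in for x_train: 8 random 32x32 RGB images.
rng = np.random.default_rng(0)
demo_images = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)

mean, var = fit_channel_stats(demo_images)
normalized = normalize(demo_images, mean, var)
print(normalized.mean(axis=(0, 1, 2)), normalized.std(axis=(0, 1, 2)))
```

After this step each channel has approximately zero mean and unit variance on the data the statistics were fitted on.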

## Build the encoder model

The encoder model takes the image as input and turns it into a 2048-dimensional
feature vector.
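
With `pooling="avg"`, the final convolutional feature maps are averaged over their spatial dimensions, which is what yields a single 2048-dimensional vector per image. A minimal NumPy sketch of that step follows. The feature-map shape is made up for illustration and is not read from the actual network.

```python
import numpy as np


def global_average_pool(feature_maps):
    """Average over height and width: (batch, h, w, c) -> (batch, c)."""
    return feature_maps.mean(axis=(1, 2))


# Hypothetical feature maps for a batch of 4 images with 2048 channels.
demo_maps = np.ones((4, 2, 2, 2048), dtype="float32")
pooled = global_average_pool(demo_maps)
print(pooled.shape)
```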

In [47]:

def create_encoder():
    resnet = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg"
    )

    inputs = keras.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    outputs = resnet(augmented)
    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-encoder")
    return model


encoder = create_encoder()
encoder.summary()

learning_rate = 0.001
batch_size = 265
hidden_units = 512
projection_units = 128
num_epochs = 20
dropout_rate = 0.5
temperature = 0.05

Model: "cifar10-encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_25 (InputLayer)       [(None, 32, 32, 3)]       0         
                                                                 
 sequential_3 (Sequential)   (None, 32, 32, 3)         7         
                                                                 
 resnet50v2 (Functional)     (None, 2048)              23564800  
                                                                 
Total params: 23564807 (89.89 MB)
Trainable params: 23519360 (89.72 MB)
Non-trainable params: 45447 (177.53 KB)
_________________________________________________________________


## Build the classification model

The classification model adds a fully-connected hidden layer (with dropout) on top
of the encoder, plus a softmax output layer with one unit per target class.
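
As a rough illustration of what this head computes at inference time, the sketch below runs a forward pass in NumPy. Dropout is omitted because it is inactive at inference, and the weights are random placeholders rather than trained values. All names here are hypothetical.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def classifier_head(features, w_hidden, b_hidden, w_out, b_out):
    """Dense + ReLU, followed by a dense softmax output layer."""
    hidden = np.maximum(features @ w_hidden + b_hidden, 0.0)
    return softmax(hidden @ w_out + b_out)


rng = np.random.default_rng(0)
features = rng.normal(size=(4, 2048))             # encoder outputs
w_hidden = rng.normal(scale=0.01, size=(2048, 512))
b_hidden = np.zeros(512)
w_out = rng.normal(scale=0.01, size=(512, 10))
b_out = np.zeros(10)

probs = classifier_head(features, w_hidden, b_hidden, w_out, b_out)
print(probs.shape, probs.sum(axis=1))
```

Each row of `probs` is a probability distribution over the 10 classes.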

In [48]:

def create_classifier(encoder, trainable=True):

    for layer in encoder.layers:
        layer.trainable = trainable

    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dropout(dropout_rate)(features)
    features = layers.Dense(hidden_units, activation="relu")(features)
    features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-classifier")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


## Experiment 1: Train the baseline classification model

In this experiment, a baseline classifier is trained as usual, i.e., the
encoder and the classifier parts are trained together as a single model
to minimize the crossentropy loss.

In [49]:
encoder = create_encoder()
classifier = create_classifier(encoder)
classifier.summary()

history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)

accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")


Model: "cifar10-classifier"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_28 (InputLayer)       [(None, 32, 32, 3)]       0         
                                                                 
 cifar10-encoder (Functiona  (None, 2048)              23564807  
 l)                                                              
                                                                 
 dropout_10 (Dropout)        (None, 2048)              0         
                                                                 
 dense_12 (Dense)            (None, 512)               1049088   
                                                                 
 dropout_11 (Dropout)        (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                

## Experiment 2: Use supervised contrastive learning

In this experiment, the model is trained in two phases. In the first phase,
the encoder is pretrained to optimize the supervised contrastive loss,
described in [Prannay Khosla et al.](https://arxiv.org/abs/2004.11362).

In the second phase, the classifier is trained using the pretrained encoder with
its weights frozen; only the weights of the fully-connected layers with the
softmax are optimized.

### 1. Supervised contrastive learning loss function

In [50]:

class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)


def add_projection_head(encoder):
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    model = keras.Model(
        inputs=inputs, outputs=outputs, name="cifar-encoder_with_projection-head"
    )
    return model
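
To show what the loss above is doing numerically, here is a NumPy sketch. It assumes that `npairs_loss` amounts to a row-wise softmax cross-entropy between the similarity logits and a target distribution that is uniform over all samples sharing the anchor's label (the anchor itself included). That reading is an assumption and should be checked against the TensorFlow Addons documentation. The toy feature vectors are made up.

```python
import numpy as np


def supervised_contrastive_loss(labels, feature_vectors, temperature=1.0):
    """NumPy sketch of the loss under the assumption stated above."""
    labels = np.asarray(labels).reshape(-1)
    features = np.asarray(feature_vectors, dtype="float64")
    # L2-normalize rows so that dot products are cosine similarities.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    features = features / np.maximum(norms, 1e-12)
    logits = features @ features.T / temperature
    # Targets: uniform over the samples that share the anchor's label.
    same = (labels[:, None] == labels[None, :]).astype("float64")
    targets = same / same.sum(axis=1, keepdims=True)
    # Row-wise softmax cross-entropy via a stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(targets * log_probs).sum(axis=1).mean())


labels = np.array([0, 0, 1, 1])
# Same-class vectors point in nearly the same direction.
clustered = np.array([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]])
# Same-class vectors point in different directions.
mixed = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]])

loss_clustered = supervised_contrastive_loss(labels, clustered, temperature=0.05)
loss_mixed = supervised_contrastive_loss(labels, mixed, temperature=0.05)
print(loss_clustered, loss_mixed)
```

The loss is lower when representations of the same class are close together, which is the property the encoder is pretrained to achieve. A small temperature such as 0.05 sharpens the softmax, so badly placed samples are penalized more strongly.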


### 2. Pretrain the encoder

In [51]:
encoder = create_encoder()

encoder_with_projection_head = add_projection_head(encoder)
encoder_with_projection_head.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=SupervisedContrastiveLoss(temperature),
)

encoder_with_projection_head.summary()

history = encoder_with_projection_head.fit(
    x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs
)

Model: "cifar-encoder_with_projection-head"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_31 (InputLayer)       [(None, 32, 32, 3)]       0         
                                                                 
 cifar10-encoder (Functiona  (None, 2048)              23564807  
 l)                                                              
                                                                 
 dense_14 (Dense)            (None, 128)               262272    
                                                                 
Total params: 23827079 (90.89 MB)
Trainable params: 23781632 (90.72 MB)
Non-trainable params: 45447 (177.53 KB)
_________________________________________________________________
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
E

### 3. Train the classifier with the frozen encoder

In [52]:
classifier = create_classifier(encoder, trainable=False)

history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)

accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Test accuracy: 68.87%


This run reaches a test accuracy of 68.87%, an improvement over the baseline classifier from Experiment 1.

## Conclusion

As shown in the experiments, using the supervised contrastive learning technique
outperformed the conventional technique in terms of the test accuracy. Note that
the same training budget (i.e., number of epochs) was given to each technique.
Supervised contrastive learning pays off when the encoder has a complex
architecture, like ResNet, and when the task is a multi-class problem with many labels.
In addition, large batch sizes and multi-layer projection heads
improve its effectiveness. See the [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)
paper for more details.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/supervised-contrastive-learning-cifar10)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/supervised-contrastive-learning).