<a href="https://colab.research.google.com/github/siyusa/Contrastive-Learning/blob/main/docs/tutorial_notebooks/tutorial17/SimCLR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-Supervised Contrastive Learning with SimCLR


In [None]:
#os 模块提供了与操作系统交互的功能，如读取环境变量、管理文件和目录等
import os
# deepcopy 用于创建对象的深拷贝，即完全复制一个对象及其所有嵌套的内容，修改新对象不会影响原对象。
from copy import deepcopy

## Imports for plotting
import matplotlib.pyplot as plt
# 将 matplotlib 的默认色彩映射设置为 ‘cividis’。这是一种现代、感知均匀、对色盲友好的颜色方案，常用于科学出版物，能更好地呈现数据细节
plt.set_cmap('cividis')
# 这是一个 IPython 魔法命令（以 % 开头）。使 matplotlib 的图形能够直接嵌入在 Jupyter Notebook 中显示，而不是作为单独的窗口弹出
%matplotlib inline
# 设置 matplotlib 的输出格式为 SVG 和 PDF，这两种格式适合高质量的图形导出和打印
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
# 将绘图中所有线条的默认宽度设置为 2.0。这使得图表中的线条比默认值更粗、更清晰
matplotlib.rcParams['lines.linewidth'] = 2.0
# seaborn 是一个基于 matplotlib 的高级统计图形库和数据可视化库，提供更高级的接口和美观的默认样式。下面的代码导入 seaborn 并应用其默认样式设置
import seaborn as sns
# 激活 seaborn 的默认主题和样式设置，自动覆盖 matplotlib 原有的样式，使图表立即变得更具现代感和美观性（如更改背景、网格、字体、调色板等）
sns.set()

## tqdm for loading bars
# tqdm 是一个用于显示循环进度条的 Python 库，特别适合在长时间运行的任务中提供视觉反馈。下面的代码导入 tqdm 的 notebook 版本，以便在 Jupyter Notebook 环境中使用
from tqdm.notebook import tqdm

# PyTorch
import torch
# Torch 的神经网络模块，包含各种神经网络层、损失函数等
import torch.nn as nn
# Torch 的函数式接口，提供各种操作的无状态版本，如激活函数、卷积操作等
import torch.nn.functional as F
# Torch 的数据处理模块，包含数据集和数据加载器等
import torch.utils.data as data
# Torch 的优化器模块，包含各种优化算法，如 SGD、Adam 等
import torch.optim as optim

## Torchvision
# Torchvision 是 PyTorch 的一个子库，专注于计算机视觉任务，提供了常用的数据集、模型和图像变换工具
import torchvision
# torchvision 的图像变换模块，提供各种图像预处理和数据增强操作
from torchvision import transforms

# PyTorch Lightning
# PyTorch Lightning 是一个用于简化 PyTorch 代码结构和训练流程的高层框架，旨在提高代码的可读性和可维护性，同时支持分布式训练等高级功能
#try:
#    import pytorch_lightning as pl
#except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
#    !pip install --quiet pytorch-lightning>=1.4
#    import pytorch_lightning as pl
# PyTorch Lightning 的回调模块，包含各种训练过程中的回调函数，如学习率监控、模型检查点等
#from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
# TensorBoard 是一个用于可视化机器学习实验的工具，常与 TensorFlow 和 PyTorch 一起使用。下面的代码导入 TensorBoard 的扩展，以便在 Jupyter Notebook 中使用
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
#DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
#CHECKPOINT_PATH = "../saved_models/tutorial17"
# In this notebook, we use data loaders with heavier computational processing. It is recommended to use as many
# workers as possible in a data loader, which corresponds to the number of CPU cores
#NUM_WORKERS = os.cpu_count()

# Setting the seed
#pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility



Global seed set to 42


Device: cuda:0
Number of workers: 16


In [None]:
# Strong SimCLR augmentation pipeline; it is applied independently to produce each view.
# No ToTensor step is included, so the input must already be an image tensor.
# NOTE(review): for single-channel (grayscale) tensors, torchvision's saturation/hue
# adjustments appear to be no-ops — confirm against the installed torchvision version.
_color_jitter = transforms.ColorJitter(
    brightness=0.5,
    contrast=0.5,
    saturation=0.5,
    hue=0.1,
)
contrast_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=256),          # random crop, resized to 256x256
        transforms.RandomApply([_color_jitter], p=0.8),  # jitter applied with probability 0.8
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.GaussianBlur(kernel_size=9),
    ]
)

In [None]:
class SimCLRUSDataset(data.Dataset):
    """Unlabelled image dataset that yields several augmented views per image (SimCLR).

    Every regular file found recursively under ``root_dir`` whose extension is in
    ``IMG_EXTENSIONS`` (case-insensitive) is one sample, e.g. ``data/<plane>/*.png``.
    Images are read as single-channel grayscale.

    Args:
        root_dir: Directory that is searched recursively for images.
        size: Side length of the square output when ``data_transforms`` is None.
        data_transforms: Callable applied independently to every view. It receives a
            float tensor of shape [1, H, W] with values in [0, 1] and must return a
            tensor image.
        n_views: Number of augmented views returned per image.

    ``__getitem__`` returns a list of ``n_views`` float tensors. A single-channel result
    is repeated to 3 channels, so each view is [3, H', W'] (ResNet expects RGB input).
    H' and W' are decided by ``data_transforms``; they are ``size`` x ``size`` when it is None.

    Raises:
        FileNotFoundError: If ``root_dir`` is not an existing directory.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif')

    def __init__(self, root_dir='./data', size=512, data_transforms=None, n_views=2):
        from pathlib import Path

        self.root = Path(root_dir)
        if not self.root.is_dir():
            raise FileNotFoundError(f"Image directory not found: {self.root}")
        # Recursively collect image files (directories with image-like names are skipped),
        # in sorted path order so the sample index is deterministic.
        self.samples = [
            p for p in sorted(self.root.rglob('*'))
            if p.is_file() and p.suffix.lower() in self.IMG_EXTENSIONS
        ]
        self.size = size
        # Strong augmentations used to create every SimCLR view.
        self.base_transforms = data_transforms
        self.n_views = n_views

    def __len__(self):
        return len(self.samples)

    def _load_image(self, path):
        """Read *path* as 8-bit grayscale; return a float tensor [1, H, W] scaled to [0, 1]."""
        # OpenCV is imported here because the notebook never imports it at module level.
        import cv2

        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise RuntimeError(f"Cannot read image: {path}")
        # Same scaling as transforms.ToTensor() applies to a uint8 HxW array.
        return torch.from_numpy(img).float().div_(255.0).unsqueeze(0)

    def _make_view(self, img):
        """Create one augmented view of *img* ([1, H, W]) as a 3-channel tensor."""
        if self.base_transforms is None:
            # No augmentation given: resize to a fixed square so views can be batched.
            view = F.interpolate(
                img.unsqueeze(0), size=(self.size, self.size),
                mode='bilinear', align_corners=False,
            ).squeeze(0)
        else:
            view = self.base_transforms(img)
        if view.dim() == 3 and view.shape[0] == 1:
            view = view.repeat(3, 1, 1)
        return view

    def __getitem__(self, idx):
        img = self._load_image(self.samples[idx])
        return [self._make_view(img) for _ in range(self.n_views)]


### SimCLR implementation

Using the data loader pipeline above, we can now implement SimCLR. At each iteration, we get for every image $x$ two differently augmented versions, which we refer to as $\tilde{x}_i$ and $\tilde{x}_j$. Both of these images are encoded into a one-dimensional feature vector; we want to maximize the similarity between these two vectors while minimizing their similarity to all other images in the batch. The encoder network is split into two parts: a base encoder network $f(\cdot)$, and a projection head $g(\cdot)$. The base network is usually a deep CNN as we have seen in e.g. [Tutorial 5](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.html) before, and is responsible for extracting a representation vector from the augmented data examples. In our experiments, we will use the common ResNet-18 architecture as $f(\cdot)$, and refer to the output as $f(\tilde{x}_i)=h_i$. The projection head $g(\cdot)$ maps the representation $h$ into a space where we apply the contrastive loss, i.e., compare similarities between vectors. It is often chosen to be a small MLP with non-linearities, and for simplicity, we follow the original SimCLR paper setup by defining it as a two-layer MLP with ReLU activation in the hidden layer. Note that in the follow-up paper, [SimCLRv2](https://arxiv.org/abs/2006.10029), the authors mention that larger/wider MLPs can boost the performance considerably. This is why we apply an MLP with four times larger hidden dimensions; deeper MLPs, however, were found to overfit on the given dataset. The general setup is visualized below (figure credit - [Ting Chen et al.](https://arxiv.org/abs/2006.10029)):

<center width="100%"><img src="https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial17/simclr_network_setup.svg?raw=1" width="350px"></center>

After finishing the training with contrastive learning, we will remove the projection head $g(\cdot)$, and use $f(\cdot)$ as a pretrained feature extractor. The representations $z$ that come out of the projection head $g(\cdot)$ have been shown to perform worse than those of the base network $f(\cdot)$ when finetuning the network for a new task. This is likely because the representations $z$ are trained to become invariant to many features like the color that can be important for downstream tasks. Thus, $g(\cdot)$ is only needed for the contrastive learning stage.

Now that the architecture is described, let's take a closer look at how we train the model. As mentioned before, we want to maximize the similarity between the representations of the two augmented versions of the same image, i.e., $z_i$ and $z_j$ in the figure above, while minimizing it to all other examples in the batch. SimCLR thereby applies the InfoNCE loss, originally proposed by [Aaron van den Oord et al.](https://arxiv.org/abs/1807.03748) for contrastive learning. In short, the InfoNCE loss compares the similarity of $z_i$ and $z_j$ to the similarity of $z_i$ to any other representation in the batch by performing a softmax over the similarity values. The loss can be formally written as:

$$
\ell_{i,j}=-\log \frac{\exp(\text{sim}(z_i,z_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(\text{sim}(z_i,z_k)/\tau)}=-\text{sim}(z_i,z_j)/\tau+\log\left[\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(\text{sim}(z_i,z_k)/\tau)\right]
$$

The function $\text{sim}$ is a similarity metric, and the hyperparameter $\tau$ is called temperature determining how peaked the distribution is. Since many similarity metrics are bounded, the temperature parameter allows us to balance the influence of many dissimilar image patches versus one similar patch. The similarity metric that is used in SimCLR is cosine similarity, as defined below:

$$
\text{sim}(z_i,z_j) = \frac{z_i^\top \cdot z_j}{||z_i||\cdot||z_j||}
$$

The maximum cosine similarity possible is $1$, while the minimum is $-1$. In general, we will see that the features of two different images will converge to a cosine similarity around zero since the minimum, $-1$, would require $z_i$ and $z_j$ to be in the exact opposite direction in all feature dimensions, which does not allow for great flexibility.

Finally, now that we have discussed all details, let's implement SimCLR below as a plain PyTorch `nn.Module` that exposes Lightning-style `training_step`/`validation_step` hooks:

In [None]:
class SimCLR(nn.Module):
    """SimCLR network: ResNet-18 base encoder f(.) whose final layer is a 2-layer MLP head g(.).

    Written as a plain ``nn.Module`` with Lightning-style ``training_step`` /
    ``validation_step`` hooks so it can be driven by an ordinary PyTorch loop.

    Args:
        hidden_dim: Output size of the projection head; the head's hidden layer has
            ``4 * hidden_dim`` units.
        lr: Learning rate. Only stored in ``self.hparams`` for the caller.
        temperature: InfoNCE temperature tau; must be strictly positive.
        weight_decay: Weight decay. Only stored in ``self.hparams`` for the caller.
        max_epochs: Planned number of epochs. Only stored in ``self.hparams``.
        writer: Optional logger exposing ``add_scalar(tag, value, step)`` (e.g. a
            ``SummaryWriter``). When ``None``, metrics are only kept in
            ``self.last_metrics``. It may also be assigned after construction.

    Raises:
        ValueError: If ``temperature`` is not a positive number.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500, writer=None):
        super().__init__()
        from types import SimpleNamespace

        # A non-positive (or NaN) temperature makes the softmax meaningless.
        if not temperature > 0.0:
            raise ValueError('The temperature must be a positive float!')
        # nn.Module has no save_hyperparameters(); keep the Lightning-style
        # ``self.hparams.<name>`` access pattern with a simple namespace.
        self.hparams = SimpleNamespace(
            hidden_dim=hidden_dim,
            lr=lr,
            temperature=temperature,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
        )
        self.writer = writer
        # Number of completed training steps; used as the step for logged metrics.
        self.global_step = 0
        # Metrics of the most recent info_nce_loss call (tag -> float).
        self.last_metrics = {}

        # Base encoder f(.): a randomly initialised ResNet-18 (no pretrained weights are
        # requested) whose last linear layer outputs 4*hidden_dim features.
        self.convnet = torchvision.models.resnet18(num_classes=4 * hidden_dim)
        # Projection head g(.): Linear -> ReLU -> Linear, built around ResNet's fc layer.
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet features, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def info_nce_loss(self, batch, mode='train'):
        """Compute the InfoNCE loss and ranking metrics for a batch of view pairs.

        Args:
            batch: Either a list/tuple of two view tensors, each [B, C, H, W] (what the
                default collate makes of a dataset returning a list of views), or a
                pair ``(views, labels)`` where ``views`` is such a list.
            mode: Prefix of the metric tags, e.g. ``'train'`` or ``'val'``.

        Returns:
            Scalar tensor with the mean InfoNCE loss over all 2*B embeddings.

        Raises:
            ValueError: If the batch does not contain exactly two views.
        """
        # Accept both "list of views" and "(list of views, labels)" batches.
        views = batch if isinstance(batch[0], torch.Tensor) else batch[0]
        if len(views) != 2:
            # The positive mask below pairs row i with row i +/- N/2, i.e. exactly 2 views.
            raise ValueError(f'SimCLR expects exactly 2 views per image, got {len(views)}')
        imgs = torch.cat(list(views), dim=0)  # [2B, C, H, W]
        if imgs.shape[1] == 1:
            # Grayscale input: replicate the channel for the RGB ResNet stem.
            imgs = imgs.expand(-1, 3, -1, -1)

        feats = self.convnet(imgs)  # [2B, hidden_dim]
        # Cosine similarity of all pairs: normalise once, then one matrix product.
        z = F.normalize(feats, dim=-1)
        cos_sim = z @ z.T  # [2B, 2B]
        n = cos_sim.shape[0]
        # Exclude each embedding's similarity to itself (exp(-9e15) underflows to 0).
        self_mask = torch.eye(n, dtype=torch.bool, device=cos_sim.device)
        cos_sim = cos_sim.masked_fill(self_mask, -9e15)
        # The positive of row i is the other view of the same image, N/2 rows away.
        pos_mask = self_mask.roll(shifts=n // 2, dims=0)
        cos_sim = cos_sim / self.hparams.temperature
        # l(i,j) = -sim(z_i,z_j)/tau + log sum_{k != i} exp(sim(z_i,z_k)/tau)
        nll = (-cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)).mean()

        with torch.no_grad():
            # Column 0 holds the positive similarity, followed by all other candidates.
            comb_sim = torch.cat(
                [cos_sim[pos_mask][:, None], cos_sim.masked_fill(pos_mask, -9e15)],
                dim=-1,
            )
            # Rank (0-based) of the positive among all candidates of its row.
            pos_rank = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
            metrics = {
                mode + '_loss': nll.item(),
                mode + '_acc_top1': (pos_rank == 0).float().mean().item(),
                mode + '_acc_top5': (pos_rank < 5).float().mean().item(),
                mode + '_acc_mean_pos': 1 + pos_rank.float().mean().item(),
            }
        self.last_metrics = metrics
        if self.writer is not None:
            for tag, value in metrics.items():
                self.writer.add_scalar(tag, value, self.global_step)

        return nll

    def training_step(self, batch, batch_idx):
        """Return the training loss for *batch* and advance ``global_step``."""
        loss = self.info_nce_loss(batch, mode='train')
        self.global_step += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute (and log) validation metrics for *batch* without tracking gradients."""
        with torch.no_grad():
            return self.info_nce_loss(batch, mode='val')

In [None]:
import argparse
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def move_batch_to_device(batch, device):
    """Recursively move every tensor in a (possibly nested) list/tuple *batch* to *device*.

    Non-tensor leaves (e.g. collated file paths or labels) are returned unchanged;
    lists stay lists and tuples become tuples.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, (list, tuple)):
        moved = [move_batch_to_device(item, device) for item in batch]
        return moved if isinstance(batch, list) else tuple(moved)
    return batch


def train(args):
    """Train SimCLR on the images under ``args.data_dir`` and save periodic checkpoints.

    Reads from *args*: ``log_dir``, ``img_size``, ``batch_size``, ``num_workers``,
    ``max_epochs``, ``hidden_dim``, ``lr``, ``temperature``, ``save_dir`` and
    ``save_interval``. It also reads these optional fields:
    - ``data_dir`` (default ``'./data'``);
    - ``use_cpu`` (default False);
    - ``seed`` (default 42).

    A checkpoint ``simclr_epoch{k}.pth`` is written every ``save_interval`` epochs and
    after the final epoch.

    Raises:
        ValueError: If the dataset holds fewer than ``batch_size`` images (with
            ``drop_last=True`` the loader would yield no batches at all).
    """
    # Reproducibility: fixed seed plus deterministic cuDNN kernels.
    torch.manual_seed(getattr(args, 'seed', 42))
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    use_cuda = torch.cuda.is_available() and not getattr(args, 'use_cpu', False)
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print("Device:", device)

    ds = SimCLRUSDataset(root_dir=getattr(args, 'data_dir', './data'), size=args.img_size,
                         data_transforms=contrast_transforms, n_views=2)
    if len(ds) < args.batch_size:
        raise ValueError(f"Need at least batch_size={args.batch_size} images "
                         f"(drop_last=True), found {len(ds)} in {ds.root}")
    # Pinned host memory only helps host->GPU copies.
    train_loader = data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True,
                                   pin_memory=use_cuda, num_workers=args.num_workers)

    model = SimCLR(max_epochs=args.max_epochs, hidden_dim=args.hidden_dim, lr=args.lr,
                   temperature=args.temperature, weight_decay=1e-4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    os.makedirs(args.save_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=args.log_dir)
    # Lets the model log its own ranking metrics (top-1/top-5 accuracy, mean rank).
    model.writer = writer
    global_step = 0
    try:
        for epoch in range(args.max_epochs):
            model.train()
            epoch_loss = 0.0
            progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.max_epochs}")
            for batch_idx, batch in enumerate(progress):
                batch = move_batch_to_device(batch, device)
                optimizer.zero_grad()
                loss = model.training_step(batch, batch_idx)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                epoch_loss += loss_value
                writer.add_scalar('Train/Loss', loss_value, global_step)
                global_step += 1

            avg_epoch_loss = epoch_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{args.max_epochs}, Loss: {avg_epoch_loss:.4f}")
            # Save after every `save_interval`-th completed epoch and after the last one.
            periodic = args.save_interval > 0 and (epoch + 1) % args.save_interval == 0
            if periodic or epoch == args.max_epochs - 1:
                torch.save(model.state_dict(),
                           os.path.join(args.save_dir, f"simclr_epoch{epoch+1}.pth"))
    finally:
        writer.close()
    print("Training complete. Saved to:", args.save_dir)

In [None]:
def build_arg_parser():
    """Return the command-line parser holding the SimCLR training options."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir', type=str, default='runs/simclr_experiment')
    parser.add_argument('--img_size', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--max_epochs', type=int, default=500)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--temperature', type=float, default=0.1)
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--save_interval', type=int, default=50)
    parser.add_argument('--data_dir', type=str, default='./data')

    parser.add_argument('--proj_dim', type=int, default=128)
    parser.add_argument('--use_cpu', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    return parser


if __name__ == '__main__':
    # parse_known_args() tolerates the extra arguments Jupyter passes to the kernel
    # process (e.g. "-f kernel.json"). parse_args() would reject them and exit.
    args, _unknown_args = build_arg_parser().parse_known_args()
    train(args)
