In [1]:
from torch import nn
import torch.nn.functional as F
import torchvision
import torch

from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
import torchvision.transforms as transforms
from torchmetrics.functional import accuracy

from PIL import Image
from simclr import SimCLR
from simclr.modules import NT_Xent, get_resnet
# from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model
from simclr.modules import LARS

import resnet

import os
import argparse
import sys

In [2]:
# Root folder for SimCLR checkpoints, the exported encoder/projector weights and
# Lightning logs. Override with the SIMCLR_CHECKPOINT_DIR environment variable
# instead of editing a hard-coded, user-specific path.
checkpoint_dir = os.environ.get("SIMCLR_CHECKPOINT_DIR", "/scratch/vvb238/simclr")

### `CustomDataset` reads the images of one split (and their labels, when a label file exists) from the root folder

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    r"""Image dataset stored as ``<root>/<split>/<idx>.png`` files.

    Labels are read from ``<root>/<split>_label_tensor.pt`` when that file
    exists; otherwise every item gets the placeholder label ``-1`` (e.g. for
    the unlabeled split).

    Args:
        root: Location of the dataset folder, usually ``/dataset``.
        split: The split to use; it should be one of ``train``, ``val`` or
            ``unlabeled``.
        transform: Callable applied to each image before it is returned.
        limit: If positive, expose only the images ``0.png`` .. ``<limit-1>.png``;
            ``0`` (the default) uses every file in the split folder.

    Raises:
        ValueError: if ``limit`` is negative.
    """

    def __init__(self, root, split, transform, limit=0):
        # A negative limit would otherwise surface much later as a negative __len__.
        if limit < 0:
            raise ValueError(f"limit must be >= 0, got {limit}")

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        # limit == 0 means "no limit": count the files in the split folder.
        self.num_images = limit if limit > 0 else len(os.listdir(self.image_dir))

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            # No label file: -1 placeholders keep every item an (image, label) pair.
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Images are named by index, so item `idx` lives in `<idx>.png`.
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            # convert() forces PIL's lazy decode while the file is still open.
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [4]:
class TransformsSimCLR:
    r"""SimCLR's stochastic augmentation: one image in, two correlated views out.

    Every call runs the same random training pipeline twice on the input, giving
    the views $\tilde{x}_i$ and $\tilde{x}_j$ that are treated as a positive pair.
    ``test_transform`` (resize, tensor, normalize) is built as well but is not
    used by ``__call__``.

    Args:
        size: output size passed to ``RandomResizedCrop`` / ``Resize``.
    """

    def __init__(self, size):
        T = torchvision.transforms
        strength = 1

        # ImageNet channel statistics, shared by both pipelines.
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        jitter = T.ColorJitter(0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength)

        train_steps = [
            T.RandomResizedCrop(size=size),
            T.RandomHorizontalFlip(),  # flips with probability 0.5
            T.RandomApply([jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            normalize,
        ]
        self.train_transform = T.Compose(train_steps)

        self.test_transform = T.Compose([T.Resize(size=size), T.ToTensor(), normalize])

    def __call__(self, x):
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view

#### Data module: dataloaders for supervised fine-tuning and self-supervised (SimCLR) training

In [5]:
class NYUImageNetDataModule(pl.LightningDataModule):
    """Dataloaders for supervised fine-tuning and SimCLR pre-training.

    Args:
        root: dataset folder holding the ``train``, ``val`` and ``unlabeled`` splits.
        batch_size: batch size of the supervised train/val loaders.
        num_workers: worker processes per DataLoader.
    """

    def __init__(self, root='/dataset', batch_size=32, num_workers=4):
        super().__init__()
        self.root = root
        self.supervised_batch_size = batch_size
        self.num_workers = num_workers

    @staticmethod
    def _normalize():
        """ImageNet mean/std normalization shared by every pipeline."""
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def _loader(self, split, transform, batch_size, shuffle, drop_last=False):
        """Builds a pinned-memory DataLoader over one split of ``self.root``."""
        dataset = CustomDataset(root=self.root, split=split, transform=transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        """Labeled train split with photometric jitter and random flips."""
        train_transform = transforms.Compose([
            # NOTE(review): hue=0.5 is the largest hue shift ColorJitter allows,
            # i.e. colours are fully randomized; confirm this is intended.
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._loader("train", train_transform, self.supervised_batch_size, shuffle=True)

    def val_dataloader(self):
        """Labeled val split, no augmentation, fixed order."""
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._loader("val", eval_transform, self.supervised_batch_size, shuffle=False)

    # The SimCLR loss (NT_Xent) is constructed for one fixed batch size, so the
    # SSL loaders drop a trailing partial batch rather than feed it a batch of a
    # different size.
    def ssl_train_dataloader(self, batch_size):
        """Unlabeled split yielding ((view_i, view_j), label) pairs for SimCLR."""
        return self._loader("unlabeled", TransformsSimCLR(96), batch_size,
                            shuffle=True, drop_last=True)

    def ssl_val_dataloader(self, batch_size):
        """Val split yielding SimCLR view pairs, used to monitor the SSL loss."""
        return self._loader("val", TransformsSimCLR(96), batch_size,
                            shuffle=False, drop_last=True)

#### Self-supervised learning with SimCLR (encoder defined in `resnet.py`)

In [6]:
class ContrastiveLearning(LightningModule):
    """SimCLR pre-training of the custom ResNet-18 encoder from ``resnet.py``.

    Args:
        batch_size: images per batch. NT_Xent is built for exactly this size, so
            every batch fed to the model must have it. Defaults to the notebook's
            ``BATCH_SIZE``, read once at construction time.
        epochs: length of the cosine LR schedule (``T_max``). Defaults to the
            notebook's ``EPOCHS``, read once at construction time.
        temperature: NT-Xent temperature.
        projection_dim: output size of the SimCLR projection head.
    """

    def __init__(self, batch_size=None, epochs=None, temperature=0.5, projection_dim=1024):
        super().__init__()
        # Snapshot the globals now. Reading them lazily (as before) let later
        # cells that re-bind EPOCHS/BATCH_SIZE silently change this model's loss
        # or learning-rate schedule.
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.lr_schedule_epochs = EPOCHS if epochs is None else epochs
        self.temperature = temperature

        # initialize ResNet
        self.encoder = resnet.get_custom_resnet18()
        self.n_features = self.encoder.fc.in_features  # width of the encoder's fc input
        self.model = SimCLR(self.encoder, projection_dim, self.n_features)
        self.criterion = self.configure_criterion()

    def forward(self, x_i, x_j):
        """Returns the projections (z_i, z_j) of the two augmented views."""
        h_i, h_j, z_i, z_j = self.model(x_i, x_j)
        return z_i, z_j

    def training_step(self, batch, batch_idx):
        # Batches are ((view_i, view_j), label); labels are unused by SimCLR.
        (x_i, x_j), _ = batch
        z_i, z_j = self.forward(x_i, x_j)
        loss = self.criterion(z_i, z_j)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Same contrastive loss on held-out view pairs, aggregated per epoch.
        (x_i, x_j), _ = batch
        z_i, z_j = self.forward(x_i, x_j)
        loss = self.criterion(z_i, z_j)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_criterion(self):
        """Builds the NT-Xent loss for this model's batch size (single process)."""
        # world_size is passed explicitly, matching how the loss was originally
        # constructed in __init__.
        return NT_Xent(self.batch_size, self.temperature, world_size=1)

    def configure_optimizers(self):
        # LARS with linear learning-rate scaling
        # (LearningRate = 0.3 × BatchSize/256) and weight decay of 1e-6.
        learning_rate = 0.3 * self.batch_size / 256
        optimizer = LARS(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.000001,
            # NOTE(review): LARS receives bare parameters (no names) and these are
            # TF-style names, so this exclusion list probably has no effect —
            # confirm against the simclr LARS implementation.
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.lr_schedule_epochs, eta_min=0, last_epoch=-1, verbose=True
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [7]:
# Self-supervised pre-training hyper-parameters. ContrastiveLearning reads these
# globals, so define them before instantiating the model.
BATCH_SIZE = 256  # batch size of the NT-Xent loss and basis of the scaled LARS LR
EPOCHS = 300      # length of the cosine learning-rate schedule

In [8]:
# One data module provides both the SimCLR loaders and the supervised loaders.
data = NYUImageNetDataModule()

In [9]:
# Fresh SSL model, configured by BATCH_SIZE / EPOCHS above. To restore a saved
# model instead, use:
#   simclr = ContrastiveLearning.load_from_checkpoint(os.path.join(checkpoint_dir, 'simclr.ckpt'))
simclr = ContrastiveLearning()

In [10]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the checkpoint with the lowest SSL validation loss and also write
# last.ckpt, which the SSL trainer below resumes from.
checkpoint_callback = ModelCheckpoint(save_last=True, monitor='val_loss')

In [11]:
# Seed every RNG so that `deterministic=True` gives repeatable runs.
# NOTE: benchmark=True lets cuDNN pick kernels by timing, which can still vary
# between runs; set benchmark=False if bit-exact reruns matter more than speed.
SSL_SEED = 42
pl.seed_everything(SSL_SEED)

# Resume from the previous run's last checkpoint only if it exists; otherwise
# start fresh instead of pointing the trainer at a missing file.
ssl_resume_path = os.path.join(checkpoint_dir, 'lightning_logs/version_0/checkpoints/last.ckpt')

trainer = Trainer(
    gpus=1,
    deterministic=True,
    benchmark=True,
    precision=16,
    max_epochs=EPOCHS,
    default_root_dir=checkpoint_dir,
    profiler="simple",
    limit_val_batches=5,
    callbacks=[checkpoint_callback],
    fast_dev_run=False,
    # Passed to the constructor: assigning trainer.sync_batchnorm after the
    # Trainer is built does not reconfigure it (and with gpus=1 it is moot).
    sync_batchnorm=True,
    resume_from_checkpoint=ssl_resume_path if os.path.exists(ssl_resume_path) else None,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [12]:
# SimCLR pre-training on the unlabeled split, monitored on view pairs from val.
trainer.fit(
    simclr,
    train_dataloader=data.ssl_train_dataloader(BATCH_SIZE),
    val_dataloaders=data.ssl_val_dataloader(BATCH_SIZE),
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params
--------------------------------------
0 | encoder   | ResNet  | 11.2 M
1 | model     | SimCLR  | 12.0 M
2 | criterion | NT_Xent | 0     
--------------------------------------
12.0 M    Trainable params
0         Non-trainable params
12.0 M    Total params
47.821    Total estimated model params size (MB)


Adjusting learning rate of group 0 to 3.0000e-01.
Epoch 0:   0%|          | 1/2005 [00:32<18:06:28, 32.53s/it, loss=6.22, v_num=0, val_loss=6.230]

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1616554793803/work/torch/csrc/utils/python_arg_parser.cpp:1005.)
  next_v.mul_(momentum).add_(scaled_lr, grad)


Epoch 0: 100%|█████████▉| 2001/2005 [13:21<00:01,  2.50it/s, loss=4.96, v_num=0, val_loss=6.230]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/5 [00:00<?, ?it/s][A
Epoch 0: 100%|█████████▉| 2002/2005 [13:22<00:01,  2.50it/s, loss=4.96, v_num=0, val_loss=6.230]
Epoch 0: 100%|█████████▉| 2004/2005 [13:22<00:00,  2.50it/s, loss=4.96, v_num=0, val_loss=6.230]
Validating: 100%|██████████| 5/5 [00:02<00:00,  2.86it/s][AAdjusting learning rate of group 0 to 2.9999e-01.
Epoch 0: 100%|██████████| 2005/2005 [13:23<00:00,  2.49it/s, loss=4.96, v_num=0, val_loss=4.990]
Epoch 1: 100%|█████████▉| 2000/2005 [10:15<00:01,  3.25it/s, loss=4.86, v_num=0, val_loss=4.990]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/5 [00:00<?, ?it/s][A
Epoch 1: 100%|█████████▉| 2002/2005 [10:16<00:00,  3.25it/s, loss=4.86, v_num=0, val_loss=4.990]
Epoch 1: 100%|█████████▉| 2004/2005 [10:16<00:00,  3.25it/s, loss=4.86, v_num=0, val_loss=4.990]
Validating: 100%|██████████| 5/5 [00:

Saving latest checkpoint...
Traceback (most recent call last):
Traceback (most recent call last):
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/ext3/miniconda3/envs/dev/lib/python3.8/multiprocessing/co

Epoch 32:   2%|▏         | 32/2005 [00:11<11:58,  2.75it/s, loss=4.66, v_num=0, val_loss=4.670]




Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  2.0069e+04     	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  607.24         	|33             	|  2.0039e+04     	|  99.852         	|
run_training_batch                 	|  0.30228        	|64033          	|  1.9356e+04     	|  96.449         	|
optimizer_step_and_closure_0       	|  0.30186        	|64033          	|  1.9329e+04     	|  96.316         	|
training_step_and_backward         	|  0.1091         	|64033          	|  6985.9         	|  34.81          	|
model_forward                      	|  0.

1

In [13]:
# Export the encoder and projection head as separate state dicts; the
# fine-tuning section loads simclr_encoder.pth on its own.
for _filename, _module in (
    ("simclr_encoder.pth", simclr.model.encoder),
    ("simclr_projector.pth", simclr.model.projector),
):
    torch.save(_module.state_dict(), os.path.join(checkpoint_dir, _filename))

#### Supervised learning: fine-tuning the SSL encoder on labeled data

In [29]:
from simclr.modules.identity import Identity

class ResNetClassifier(LightningModule):
    """Fine-tunes the SimCLR-pretrained ResNet-18 encoder as an image classifier.

    The encoder weights come from the SSL section. Its fc layer is replaced by a
    fresh linear head, and the whole network (encoder + head) is trained.

    Args:
        num_classes: number of output classes of the new head.
        lr: Adam learning rate for all parameters.
            NOTE(review): the logged run with lr=0.01 shows train loss falling
            (1.77) while val loss rises (4.75 -> 6.08); a smaller value is likely
            better for fine-tuning a pretrained encoder.
        encoder_path: state-dict file of the SSL encoder; defaults to
            ``simclr_encoder.pth`` inside ``checkpoint_dir``.
    """

    def __init__(self, num_classes=800, lr=0.01, encoder_path=None):
        super().__init__()
        if encoder_path is None:
            encoder_path = os.path.join(checkpoint_dir, 'simclr_encoder.pth')
        self.lr = lr

        self.encoder = resnet.get_custom_resnet18()
        # Read the feature width before fc is swapped out (instead of hard-coding 512).
        n_features = self.encoder.fc.in_features
        # Mirror the SSL setup, where the encoder's fc is presumably an Identity,
        # so the saved state-dict keys match under a strict load.
        self.encoder.fc = Identity()
        # map_location keeps loading independent of the device the weights were saved on.
        self.encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
        # New, randomly initialized classification head.
        self.encoder.fc = torch.nn.Linear(n_features, num_classes)

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Returns raw class logits (no softmax)."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        """Computes cross-entropy loss and top-1 accuracy for one batch.

        When ``stage`` is given, logs ``{stage}_loss`` and ``{stage}_acc`` to the
        progress bar. Returns ``(loss, acc)``.
        """
        x, y = batch
        logits = self.forward(x)
        # Same loss as training; equivalent to log_softmax followed by nll_loss.
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        loss, _ = self._evaluate(batch, batch_idx, 'val')
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # Reduce the LR when the logged val_loss stops improving for 4 epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

In [30]:
# Build the classifier on top of the exported SimCLR encoder weights.
# To restore a previously fine-tuned classifier, follow up with:
#   classifier.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'classifier.pth')))
classifier = ResNetClassifier()

<All keys matched successfully>

In [31]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Fine-tuning length. Deliberately not named EPOCHS: ContrastiveLearning reads the
# EPOCHS global for its cosine schedule, so re-binding it here could silently
# change the SSL model's LR schedule if it is configured again later.
CLASSIFIER_EPOCHS = 50

# Fresh callback for this run. Re-using the SSL `checkpoint_callback` would carry
# over its state from pre-training (resolved checkpoint directory, best val_loss),
# so classifier checkpoints — including last.ckpt — could be written into, and
# compared against, the SSL run's checkpoints.
classifier_checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_last=True)

classifier_trainer = Trainer(
    gpus=1,
    deterministic=True,
    max_epochs=CLASSIFIER_EPOCHS,
    default_root_dir='/scratch/vvb238/classifier',
    profiler="simple",
    limit_val_batches=0.5,
    precision=16,
    benchmark=True,
    callbacks=[classifier_checkpoint_callback],
    fast_dev_run=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [32]:
# Fine-tune on the labeled train split, validating on the labeled val split.
classifier_trainer.fit(
    classifier,
    train_dataloader=data.train_dataloader(),
    val_dataloaders=data.val_dataloader(),
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | ResNet           | 11.6 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.6 M    Trainable params
0         Non-trainable params
11.6 M    Total params
46.317    Total estimated model params size (MB)


Epoch 0:  67%|██████▋   | 801/1200 [00:22<00:11, 35.49it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/400 [00:00<?, ?it/s][A
Epoch 0:  67%|██████▋   | 805/1200 [00:22<00:11, 35.18it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  68%|██████▊   | 819/1200 [00:22<00:10, 35.63it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  70%|██████▉   | 834/1200 [00:23<00:10, 36.13it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  71%|███████   | 849/1200 [00:23<00:09, 36.60it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  72%|███████▏  | 864/1200 [00:23<00:09, 37.08it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  73%|███████▎  | 880/1200 [00:23<00:08, 37.59it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  75%|███████▍  | 896/1200 [00:23<00:07, 38.10it/s, loss=5.39, v_num=3, val_loss=4.750, val_acc=0.141]
Epoch 0:  76%|███████▌ 

Saving latest checkpoint...


Epoch 33:  44%|████▍     | 527/1200 [00:15<00:19, 34.92it/s, loss=1.77, v_num=3, val_loss=6.080, val_acc=0.133]



Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  868.06         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  25.491         	|34             	|  866.68         	|  99.841         	|
run_training_batch                 	|  0.025641       	|26928          	|  690.46         	|  79.541         	|
optimizer_step_and_closure_0       	|  0.024795       	|26928          	|  667.67         	|  76.915         	|
training_step_and_backward         	|  0.015064       	|26928          	|  405.65         	|  46.731         	|
model_backward                     	|  0.




1

In [28]:
# Persist the fine-tuned weights.
# NOTE(review): this cell's execution count ([28]) is lower than the training
# cells above ([29]-[32]), so the saved file predates that training — re-run it.
classifier_path = os.path.join(checkpoint_dir, 'classifier.pth')
torch.save(classifier.state_dict(), classifier_path)

In [33]:
# Evaluate the classifier trained above. To evaluate a saved one instead:
#   net = ResNetClassifier()
#   net.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'classifier.pth')))
net = classifier

In [34]:
# Top-1 accuracy of `net` on the full validation split (the trainer above only
# validated on limit_val_batches=0.5 of it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
net.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in data.val_dataloader():
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        predicted = net(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("validation loader yielded no samples; cannot compute accuracy")

print(f"Accuracy: {(100 * correct / total):.2f}%")

Accuracy: 13.75%
