In [38]:
import argparse
import copy
import os
import sys

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from simclr import SimCLR
from simclr.modules import NT_Xent, get_resnet
from simclr.modules import LARS
from simclr.modules.sync_batchnorm import convert_model
from simclr.modules.transformations import TransformsSimCLR
from torch import nn
from tqdm import tqdm

import resnet

In [7]:
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset backed by a folder of ``<idx>.png`` files.

    Images are read from ``<root>/<split>/<idx>.png``. Labels are loaded from
    ``<root>/<split>_label_tensor.pt`` when that file exists; otherwise every
    sample gets the placeholder label ``-1``.
    """

    def __init__(self, root, split, transform, limit=0):
        r"""
        Args:
            root: Location of the dataset folder, usually /dataset.
            split: The split to use; one of train, val or unlabeled.
            transform: The transform applied to every loaded image.
            limit: When non-zero, expose only the first ``limit`` images (the
                image directory is not listed in that case). ``0`` means
                "use every ``.png`` file found in the split directory".

        Raises:
            ValueError: If ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError(f"limit must be >= 0, got {limit}")

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        if limit == 0:
            # __getitem__ addresses samples as "<idx>.png", so only those files
            # count; stray entries (hidden files, sub-directories, ...) must
            # not inflate the reported length.
            self.num_images = sum(
                1 for name in os.listdir(self.image_dir) if name.endswith(".png")
            )
        else:
            self.num_images = limit

        if os.path.exists(label_path):
            # NOTE: torch.load unpickles the file; only use trusted label files.
            self.labels = torch.load(label_path)
        else:
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        """Return the number of samples exposed by this dataset."""
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. This enforces
                ``limit`` and lets plain ``for sample in dataset`` terminate.
        """
        if not 0 <= idx < self.num_images:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {self.num_images} images"
            )

        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            # convert() forces the pixel data to be read before the file closes.
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [18]:
class NYUImageNetDataModule(pl.LightningDataModule):
    """Builds the supervised and SimCLR data loaders for the dataset under ``root``.

    Args:
        root: Dataset folder handed to ``CustomDataset``.
        batch_size: Batch size of the supervised train/val loaders.
        num_workers: Worker processes used by every loader.
        image_size: Size passed to ``TransformsSimCLR`` for the SSL loaders.

    The defaults reproduce the previously hard-coded values, so
    ``NYUImageNetDataModule()`` keeps behaving as before.
    """

    # Per-channel normalization constants (the widely used ImageNet statistics).
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, root='/dataset', batch_size=64, num_workers=4, image_size=96):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size

    def _make_loader(self, split, transform, batch_size, shuffle, drop_last=False):
        """Create the DataLoader for one split; shared by all public loader methods."""
        dataset = CustomDataset(root=self.root, split=split, transform=transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def _normalize(self):
        """Return the normalization transform shared by the train and val pipelines."""
        return transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)

    def train_dataloader(self):
        """Shuffled loader over the labeled ``train`` split with colour/flip augmentation."""
        train_transform = transforms.Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._make_loader("train", train_transform, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the ``val`` split with normalization only."""
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._make_loader("val", eval_transform, self.batch_size, shuffle=False)

    def ssl_train_dataloader(self, batch_size):
        """Shuffled loader over the ``unlabeled`` split using the SimCLR transform.

        Incomplete final batches are dropped: the contrastive loss in this
        notebook (``NT_Xent``) is constructed for one fixed batch size, so a
        shorter last batch is presumed not to be supported by it.
        """
        return self._make_loader(
            'unlabeled', TransformsSimCLR(self.image_size), batch_size,
            shuffle=True, drop_last=True,
        )

    def ssl_val_dataloader(self, batch_size):
        """Ordered loader over the ``val`` split using the SimCLR transform.

        Incomplete final batches are dropped for the same fixed-batch-size
        reason as in ``ssl_train_dataloader``.
        """
        return self._make_loader(
            'val', TransformsSimCLR(self.image_size), batch_size,
            shuffle=False, drop_last=True,
        )

In [9]:
class ContrastiveLearning(LightningModule):
    """SimCLR pre-training module.

    Wraps a custom ResNet-18 encoder in the ``SimCLR`` model and trains it with
    the ``NT_Xent`` loss on pairs of augmented views of the same image.

    The class reads the module-level constants ``BATCH_SIZE`` (when the module
    is instantiated and when optimizers are configured) and ``EPOCHS`` (when
    optimizers are configured); both must be defined before those points.
    """

    def __init__(self):
        super().__init__()
        # initialize ResNet
        self.encoder = resnet.get_custom_resnet18()
        # Input width of the encoder's fc layer, handed to SimCLR below.
        self.n_features = self.encoder.fc.in_features
        # NOTE(review): the encoder is reachable both as self.encoder and via
        # self.model; this is kept as-is so existing checkpoints keep loading.
        self.model = SimCLR(self.encoder, 512, self.n_features)
        self.criterion = self.configure_criterion()

    def forward(self, x_i, x_j):
        """Return the two outputs ``(z_i, z_j)`` that the loss is computed on."""
        _, _, z_i, z_j = self.model(x_i, x_j)
        return z_i, z_j

    def _contrastive_loss(self, batch):
        """Compute the NT-Xent loss for one ``((x_i, x_j), label)`` batch."""
        (x_i, x_j), _ = batch  # labels are not used for contrastive training
        z_i, z_j = self.forward(x_i, x_j)
        return self.criterion(z_i, z_j)

    def training_step(self, batch, batch_idx):
        """Run one optimization step's forward pass and log ``train_loss``."""
        loss = self._contrastive_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate the contrastive loss on one validation batch; logs epoch-level ``val_loss``."""
        loss = self._contrastive_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return { 'val_loss' : loss }

    def configure_criterion(self):
        """Build the NT-Xent loss (temperature 0.5) for ``BATCH_SIZE`` on a single process."""
        criterion = NT_Xent(BATCH_SIZE, 0.5, world_size=1)
        return criterion

    def configure_optimizers(self):
        """Return the LARS optimizer together with a cosine learning-rate schedule."""
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 10−6.
        learning_rate = 0.3 * BATCH_SIZE / 256
        optimizer = LARS(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.000001,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, EPOCHS, eta_min=0, last_epoch=-1
        )

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [24]:
# Notebook-wide training configuration.
SEED = 42          # arbitrary fixed seed so repeated runs draw the same random numbers
EPOCHS = 1         # SimCLR pre-training epochs; also the length of the cosine LR schedule
BATCH_SIZE = 256   # SimCLR batch size; also sizes the NT_Xent loss and scales the LARS learning rate

# The Trainers in this notebook are created with deterministic=True; without
# seeding the random number generators the runs still differ from one
# another, so seed them once before any model or loader is created.
pl.seed_everything(SEED)

In [12]:
# unlabeled_dataset = CustomDataset(root='/dataset', split='unlabeled', transform=TransformsSimCLR(96))
# unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

In [13]:
# val_dataset = CustomDataset(root='/dataset', split='val', transform=TransformsSimCLR(96))
# val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [19]:
# Data module whose *_dataloader methods build the loaders passed to trainer.fit further down.
data = NYUImageNetDataModule()

In [23]:
# Resume from the saved SimCLR checkpoint when it exists; otherwise start from
# a freshly initialized model. The checkpoint file is only written by a later
# cell, so loading unconditionally fails on a first run.
SIMCLR_CKPT_PATH = '/scratch/vvb238/simclr/simclr.ckpt'
if os.path.exists(SIMCLR_CKPT_PATH):
    simclr = ContrastiveLearning.load_from_checkpoint(SIMCLR_CKPT_PATH)
else:
    simclr = ContrastiveLearning()

In [25]:
# Checkpointing for the SimCLR pre-training run: monitors val_loss and also
# keeps the most recent checkpoint.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_last=True)

# NOTE(review): deterministic=True and benchmark=True pull in opposite
# directions (reproducibility vs. cuDNN autotuning) — confirm which is wanted.
trainer = Trainer(
    gpus=1,
    deterministic=True,
    max_epochs=EPOCHS,
    default_root_dir='/scratch/vvb238/simclr',
    profiler="simple",
    limit_val_batches=5,  # validate on 5 batches only
    precision=16,
    benchmark=True,
    # Passed to the constructor (the supported way) rather than set on the
    # trainer object afterwards.
    sync_batchnorm=True,
    callbacks=[checkpoint_callback],
    fast_dev_run=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [26]:
# Self-supervised pre-training: train on the 'unlabeled' split and validate on the 'val' split,
# both loaded with the TransformsSimCLR transform and a batch size of BATCH_SIZE.
trainer.fit(simclr, train_dataloader=data.ssl_train_dataloader(BATCH_SIZE), val_dataloaders=data.ssl_val_dataloader(BATCH_SIZE))

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params
--------------------------------------
0 | encoder   | ResNet  | 11.2 M
1 | model     | SimCLR  | 11.7 M
2 | criterion | NT_Xent | 0     
--------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.772    Total estimated model params size (MB)


Epoch 0:   0%|          | 1/2005 [00:32<18:17:51, 32.87s/it, loss=4.72, v_num=3, val_loss=4.740]

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1616554793803/work/torch/csrc/utils/python_arg_parser.cpp:1005.)
  next_v.mul_(momentum).add_(scaled_lr, grad)


Epoch 0: 100%|█████████▉| 2000/2005 [13:43<00:02,  2.43it/s, loss=4.76, v_num=3, val_loss=4.740]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/5 [00:00<?, ?it/s][A
Epoch 0: 100%|█████████▉| 2002/2005 [13:45<00:01,  2.43it/s, loss=4.76, v_num=3, val_loss=4.740]
Epoch 0: 100%|█████████▉| 2004/2005 [13:45<00:00,  2.43it/s, loss=4.76, v_num=3, val_loss=4.740]
Epoch 0: 100%|██████████| 2005/2005 [13:46<00:00,  2.42it/s, loss=4.76, v_num=3, val_loss=4.860]
                                                         [A

Saving latest checkpoint...


Epoch 0: 100%|██████████| 2005/2005 [13:47<00:00,  2.42it/s, loss=4.76, v_num=3, val_loss=4.860]



Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  838.48         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  827.2          	|1              	|  827.2          	|  98.655         	|
run_training_batch                 	|  0.29712        	|2000           	|  594.24         	|  70.871         	|
optimizer_step_and_closure_0       	|  0.29681        	|2000           	|  593.63         	|  70.798         	|
get_train_batch                    	|  0.10879        	|2000           	|  217.58         	|  25.95          	|
training_step_and_backward         	|  0.




1

In [27]:
# Write the full Lightning checkpoint; the model-loading cell earlier in the notebook reads this same path.
trainer.save_checkpoint("/scratch/vvb238/simclr/simclr.ckpt")

In [31]:
# Export the pre-trained encoder and projector weights as separate state_dict files.
checkpoint_dir = "/scratch/vvb238/simclr"
for weights_file, submodule in (
    ('simclr_encoder.pth', simclr.model.encoder),
    ('simclr_projector.pth', simclr.model.projector),
):
    torch.save(submodule.state_dict(), os.path.join(checkpoint_dir, weights_file))

In [None]:
# Fine-tuning on labeled data

In [48]:
class ResNetClassifier(LightningModule):
    """Supervised classifier: a pre-trained encoder followed by one linear layer.

    Args:
        encoder: Feature extractor to fine-tune. Defaults to the encoder of the
            module-level ``simclr`` model.
        num_classes: Number of output classes.
        feature_dim: Size of the feature vector produced by ``encoder``; the
            default of 512 matches the previously hard-coded ``Linear(512, 800)``.

    The encoder is deep-copied, so fine-tuning this classifier does not modify
    the weights of the model the encoder was taken from.
    """

    def __init__(self, encoder=None, num_classes=800, feature_dim=512):
        super().__init__()
        if encoder is None:
            encoder = simclr.model.encoder
        # Own copy: training must not silently change the source model in place.
        self.encoder = copy.deepcopy(encoder)
        self.lastLayer = torch.nn.Linear(feature_dim, num_classes)
        self.criterion=torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the per-class scores (logits, not probabilities) for a batch of images."""
        features = self.encoder(x)
        return self.lastLayer(features)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        data, label = batch
        logits = self.forward(data)
        loss = self.criterion(logits, label)
        self.log('train_loss', loss)
        return loss

    def validation_step(self,batch,batch_idx):
        """Compute loss and accuracy for one validation batch; both are logged per epoch."""
        data, label = batch
        logits = self.forward(data)
        loss = self.criterion(logits, label)
        accuracy = (logits.argmax(dim=1) == label).float().mean()
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return { 'val_loss' : loss, 'prediction' : logits, 'target' : label }

    def configure_optimizers(self):
        """Adam with default settings; the learning rate is reduced when ``val_loss`` plateaus."""
        optimizer = torch.optim.Adam(self.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        return ({'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'})

In [49]:
# Build the fine-tuning model; its encoder is taken from the module-level `simclr` object.
classifier = ResNetClassifier()

In [57]:
EPOCHS = 60

# Use a checkpoint callback of its own for fine-tuning. The callback object
# created for the SimCLR run has already been through a fit; reusing it would
# carry that run's checkpoint state (best score on a different loss, output
# location) into this one.
classifier_checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_last=True)

trainer = Trainer(
    gpus=1,
    deterministic=True,
    max_epochs=EPOCHS,
    default_root_dir='/scratch/vvb238/classifier',
    profiler="simple",
    limit_val_batches=0.75,  # validate on 75% of the validation batches
    precision=16,
    benchmark=True,
    callbacks=[classifier_checkpoint_callback],
    fast_dev_run=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [58]:
# Supervised fine-tuning: train on the labeled 'train' split and validate on the 'val' split.
trainer.fit(classifier, train_dataloader=data.train_dataloader(), val_dataloaders=data.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | ResNet           | 11.2 M
1 | lastLayer | Linear           | 410 K 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.6 M    Trainable params
0         Non-trainable params
11.6 M    Total params
46.317    Total estimated model params size (MB)


Epoch 0:  57%|█████▋    | 400/700 [00:17<00:13, 22.67it/s, loss=5.42, v_num=3, val_loss=5.570]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/300 [00:00<?, ?it/s][A
Epoch 0:  58%|█████▊    | 403/700 [00:18<00:13, 22.37it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  59%|█████▊    | 411/700 [00:18<00:12, 22.67it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  60%|█████▉    | 419/700 [00:18<00:12, 22.98it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  61%|██████    | 427/700 [00:18<00:11, 23.25it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  62%|██████▏   | 436/700 [00:18<00:11, 23.61it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  64%|██████▎   | 445/700 [00:18<00:10, 23.92it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  65%|██████▍   | 454/700 [00:18<00:10, 24.24it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  66%|██████▌   | 463/700 [00:18<00:09, 24.55it/s, loss=5.42, v_num=3, val_loss=5.570]
Epoch 0:  67%|██████▋   | 472/700 [00:18<00:09, 24.85it/

Saving latest checkpoint...


Epoch 59: 100%|██████████| 700/700 [00:22<00:00, 30.85it/s, loss=1.53, v_num=3, val_loss=4.900]



Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  1375.7         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  22.907         	|60             	|  1374.4         	|  99.91          	|
run_training_batch                 	|  0.040919       	|24000          	|  982.05         	|  71.388         	|
optimizer_step_and_closure_0       	|  0.040065       	|24000          	|  961.57         	|  69.899         	|
training_step_and_backward         	|  0.012763       	|24000          	|  306.3          	|  22.266         	|
model_backward                     	|  0.




1

In [59]:
# Save the fine-tuned classifier weights (encoder and linear layer) as a plain state_dict.
torch.save(classifier.state_dict(), os.path.join(checkpoint_dir, 'classifier.pth'))

In [60]:
# Build a new classifier instance and load into it the state_dict written by the previous cell.
net = ResNetClassifier()
net.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'classifier.pth')))

<All keys matched successfully>

In [61]:
# Top-1 accuracy of the reloaded classifier over the validation loader.
net = net.cuda()
net.eval()  # inference behaviour for layers such as batch norm

correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in data.val_dataloader():
        images, labels = images.cuda(), labels.cuda()

        # Index of the highest score per sample is the predicted class.
        predicted = torch.max(net(images), 1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {(100 * correct / total):.2f}%")

Accuracy: 17.45%
