In [1]:
from torch import nn
import torch.nn.functional as F
import torchvision
import torch

from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
import torchvision.transforms as transforms

from PIL import Image

# from simclr import SimCLR
# from simclr.modules import NT_Xent, get_resnet
# from simclr.modules.transformations import TransformsSimCLR
# from simclr.modules import LARS
from byol_pytorch import BYOL
# from byol.modules.sync_batchnorm import convert_model

from torch import optim
import resnet

import os
import argparse
import sys

In [2]:
# !pip install byol-pytorch

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, split, transform, limit=0):
        r"""
        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split you want to used, it should be one of train, val or unlabeled.
            transform: the transform you want to applied to the images.
        """

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        if limit == 0:
            self.num_images = len(os.listdir(self.image_dir))
        else:
            self.num_images = limit

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [4]:
def expand_greyscale(t):
    return t.expand(3, -1, -1)

In [5]:
class NYUImageNetDataModule(pl.LightningDataModule):
  
    def train_dataloader(self):
        train_transform = transforms.Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        trainset = CustomDataset(root='/dataset', split="train", transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
        return train_loader
    
    def val_dataloader(self):
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        evalset = CustomDataset(root='/dataset', split="val", transform=eval_transform)
        eval_loader = torch.utils.data.DataLoader(evalset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
        return eval_loader
    
    def ssl_train_dataloader(self, batch_size):
        ssl_train_transform = transforms.Compose([
#             transforms.Resize((96,96)),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        unlabeled_dataset = CustomDataset(root='/dataset', split='unlabeled', transform=ssl_train_transform)
        unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        return unlabeled_dataloader
        
    def ssl_val_dataloader(self, batch_size):
        ssl_eval_transform = transforms.Compose([
            transforms.Resize((96,96)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        val_dataset = CustomDataset(root='/dataset', split='val', transform=ssl_eval_transform)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        return val_dataloader

In [6]:
# from torch.utils.data import DataLoader, Dataset
# def expand_greyscale(t):
#     return t.expand(3, -1, -1)

# class ImagesDataset(Dataset):
#     def __init__(self, folder, image_size):
#         super().__init__()
#         self.folder = folder
#         self.paths = []

#         for path in Path(f'{folder}').glob('**/*'):
#             _, ext = os.path.splitext(path)
#             if ext.lower() in IMAGE_EXTS:
#                 self.paths.append(path)

#         print(f'{len(self.paths)} images found')

#         self.transform = transforms.Compose([
#             transforms.Resize(image_size),
#             transforms.CenterCrop(image_size),
#             transforms.ToTensor(),
#             transforms.Lambda(expand_greyscale)
#         ])

#     def __len__(self):
#         return len(self.paths)

#     def __getitem__(self, index):
#         path = self.paths[index]
#         img = Image.open(path)
#         img = img.convert('RGB')
#         return self.transform(img)

In [7]:
class ContrastiveLearning(LightningModule):
#     def __init__(self, net, **kwargs):
#         super().__init__()
#         self.net = net
#         self.save_hyperparameters()
#         self.learner = BYOL(self.net , **kwargs)
    def __init__(self):
        super().__init__()
        self.net = resnet.get_custom_resnet18()
        self.learner = BYOL(self.net, image_size = 96,
                            hidden_layer = 'avgpool',
                            projection_size = 256,
                            projection_hidden_size = 4096,
                            moving_average_decay = 0.99)

    def forward(self, images):
#         print("forward", images)
        return self.learner(images[0])
    
    def training_step(self, images, _):
        print("loss")
        loss = self.forward(images)
        return {'loss': loss}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=3e-4)

#     def on_before_zero_grad(self, _):
#         if self.learner.use_momentum:
#             self.learner.update_moving_average()
    
#     def validation_step(self, images, _):
#         loss = self.forward(images)
#         return {'loss': loss}

In [8]:
EPOCHS = 5
BATCH_SIZE = 256

In [9]:
# unlabeled_dataset = CustomDataset(root='/dataset', split='unlabeled', transform=TransformsSimCLR(96))
# unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

In [10]:
# val_dataset = CustomDataset(root='/dataset', split='val', transform=TransformsSimCLR(96))
# val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [11]:
data = NYUImageNetDataModule()

In [12]:
# encoder = resnet.get_custom_resnet18()
# byol = ContrastiveLearning(
#     net = encoder,
#     image_size = 96,
#     hidden_layer = 'avgpool',
#     projection_size = 256,
#     projection_hidden_size = 4096,
#     moving_average_decay = 0.99
#         )
# byol = ContrastiveLearning.load_from_checkpoint("/scratch/nr2229/byol/byolv1.ckpt")
byol = ContrastiveLearning()

In [13]:
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_last=True)

trainer = Trainer(
        gpus = 1,
        deterministic=True,
        max_epochs = EPOCHS,
        accumulate_grad_batches = 1,
        default_root_dir='/scratch/nr2229/byol', 
        profiler="simple",
        limit_val_batches= 5, 
        benchmark=True,
        callbacks=[checkpoint_callback],
        fast_dev_run=False,
        sync_batchnorm = True
    )
# trainer = Trainer(gpus=1,deterministic=True, max_epochs=EPOCHS, default_root_dir='/scratch/nr2229/byol', profiler="simple",
#                      limit_val_batches= 5, precision=16, benchmark=True, callbacks=[checkpoint_callback], fast_dev_run=False)
# trainer.sync_batchnorm=True

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [14]:
trainer.fit(byol, train_dataloader=data.ssl_train_dataloader(BATCH_SIZE))


  | Name    | Type   | Params
-----------------------------------
0 | net     | ResNet | 11.6 M
1 | learner | BYOL   | 31.6 M
-----------------------------------
31.6 M    Trainable params
0         Non-trainable params
31.6 M    Total params


Epoch 0:   0%|          | 0/2000 [00:00<?, ?it/s] loss
Epoch 0:   0%|          | 1/2000 [00:05<2:55:14,  5.26s/it, loss=3.96, v_num=35]loss
Epoch 0:   0%|          | 2/2000 [00:05<1:39:14,  2.98s/it, loss=3.08, v_num=35]loss
Epoch 0:   0%|          | 3/2000 [00:06<1:13:59,  2.22s/it, loss=2.65, v_num=35]loss
Epoch 0:   0%|          | 4/2000 [00:07<1:01:08,  1.84s/it, loss=2.37, v_num=35]loss
Epoch 0:   0%|          | 5/2000 [00:08<53:29,  1.61s/it, loss=2.2, v_num=35]   loss
Epoch 0:   0%|          | 6/2000 [00:08<48:23,  1.46s/it, loss=2.09, v_num=35]loss
Epoch 0:   0%|          | 7/2000 [00:09<44:51,  1.35s/it, loss=1.99, v_num=35]loss
Epoch 0:   0%|          | 8/2000 [00:10<42:08,  1.27s/it, loss=1.92, v_num=35]loss
Epoch 0:   0%|          | 9/2000 [00:10<39:55,  1.20s/it, loss=1.86, v_num=35]loss
Epoch 0:   0%|          | 10/2000 [00:11<38:10,  1.15s/it, loss=1.81, v_num=35]loss
Epoch 0:   1%|          | 11/2000 [00:12<36:48,  1.11s/it, loss=1.77, v_num=35]loss
Epoch 0:   1%|      

Saving latest checkpoint...


Epoch 4: 100%|██████████| 2000/2000 [23:34<00:00,  1.41it/s, loss=1.34, v_num=35]



Profiler Report

Action                      	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------------------
Total                       	|  -              	|_              	|  7077.5         	|  100 %          	|
-----------------------------------------------------------------------------------------------------------------------------
run_training_epoch          	|  1414.7         	|5              	|  7073.3         	|  99.939         	|
run_training_batch          	|  0.70454        	|10000          	|  7045.4         	|  99.546         	|
optimizer_step_and_closure_0	|  0.4072         	|10000          	|  4072.0         	|  57.534         	|
training_step_and_backward  	|  0.15382        	|10000          	|  1538.2         	|  21.734         	|
model_backward              	|  0.087565       	|10000          	|  875.65         	|  12.372    




1

In [15]:
trainer.save_checkpoint("/scratch/nr2229/byol/byol.ckpt")

In [16]:
# byol.net

In [17]:
# byol_Load = ContrastiveLearning.load_from_checkpoint("/scratch/nr2229/byol/byolv1.ckpt")

In [19]:
# byol.learner.state_dict().keys()
checkpoint_dir = "/scratch/nr2229/byol"
torch.save(byol.net.state_dict(), os.path.join(checkpoint_dir,'state_dict.pth'))

In [20]:

# torch.save(byol.learner.online_encoder.state_dict(), os.path.join(checkpoint_dir, 'byol_online_encoder.pth'))
# torch.save(byol.learner.target_encoder.state_dict(), os.path.join(checkpoint_dir, 'byol_target_encoder.pth'))

In [21]:
# Fine-tuning on labeled data

In [24]:
class ResNetClassifier(LightningModule):
    def __init__(self):
        super().__init__()
#         self.encoder = resnet.get_custom_resnet18()
#         self.encoder.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'simclr_encoder.pth')))
#         self.online_encoder = byol.learner.online_encoder
#         self.target_encoder = byol.learner.target_encoder
        states = byol.net.state_dict()
        self.encoder = resnet.get_custom_resnet18()
        self.encoder.load_state_dict(states)
        self.lastLayer = torch.nn.Linear(800, 800)
        self.criterion=torch.nn.CrossEntropyLoss()
        
    def forward(self, x):
#         print("x:",x)
        x = self.encoder(x)
#         print("shape1",x.shape)
        x = self.lastLayer(x)
#         print("shape2",x.shape)
        return x
    
    def training_step(self, batch, batch_idx):
        data, label = batch
        classProbs = self.forward(data)
        loss = self.criterion(classProbs, label)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self,batch,batch_idx):
        data, label = batch
        classProbs = self.forward(data)
        loss = self.criterion(classProbs, label)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return { 'val_loss' : loss, 'prediction' : classProbs, 'target' : label }
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        return ({'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'})

In [25]:
classifier = ResNetClassifier()

In [26]:
EPOCHS = 60
trainer = Trainer(gpus=1,deterministic=True, max_epochs=EPOCHS, default_root_dir='/scratch/nr2229/classifier_byol', profiler="simple",
                     limit_val_batches= 0.75, benchmark=True, callbacks=[checkpoint_callback], fast_dev_run=False)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [27]:
trainer.fit(classifier, train_dataloader=data.train_dataloader(), val_dataloaders=data.val_dataloader())


  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | ResNet           | 11.6 M
1 | lastLayer | Linear           | 640 K 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
12.2 M    Trainable params
0         Non-trainable params
12.2 M    Total params


Epoch 0:  57%|█████▋    | 401/700 [00:31<00:23, 12.55it/s, loss=6.6, v_num=15, val_loss=6.8] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/300 [00:00<?, ?it/s][A
Epoch 0:  58%|█████▊    | 403/700 [00:32<00:23, 12.48it/s, loss=6.6, v_num=15, val_loss=6.8]
Epoch 0:  59%|█████▉    | 413/700 [00:32<00:22, 12.66it/s, loss=6.6, v_num=15, val_loss=6.8]
Validating:   5%|▌         | 16/300 [00:01<00:17, 15.87it/s][A
Epoch 0:  60%|██████    | 423/700 [00:33<00:21, 12.74it/s, loss=6.6, v_num=15, val_loss=6.8]
Validating:   8%|▊         | 24/300 [00:01<00:16, 16.53it/s][A
Validating:   9%|▉         | 28/300 [00:01<00:15, 17.17it/s][A
Epoch 0:  62%|██████▏   | 433/700 [00:33<00:20, 12.76it/s, loss=6.6, v_num=15, val_loss=6.8]
Validating:  12%|█▏        | 36/300 [00:02<00:15, 17.06it/s][A
Epoch 0:  63%|██████▎   | 443/700 [00:34<00:19, 12.91it/s, loss=6.6, v_num=15, val_loss=6.8]
Validating:  16%|█▌        | 47/300 [00:02<00:09, 28.04it/s][A
Epoch 0:  65%|██████▍   | 453/70

Saving latest checkpoint...


Epoch 59: 100%|██████████| 700/700 [00:35<00:00, 19.57it/s, loss=1.45, v_num=15, val_loss=5.77]



Profiler Report

Action                      	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------------------
Total                       	|  -              	|_              	|  2151.2         	|  100 %          	|
-----------------------------------------------------------------------------------------------------------------------------
run_training_epoch          	|  35.83          	|60             	|  2149.8         	|  99.938         	|
run_training_batch          	|  0.070247       	|24000          	|  1685.9         	|  78.374         	|
optimizer_step_and_closure_0	|  0.012788       	|24000          	|  306.91         	|  14.267         	|
evaluation_step_and_end     	|  0.016425       	|18002          	|  295.68         	|  13.745         	|
training_step_and_backward  	|  0.0089752      	|24000          	|  215.41         	|  10.013    




1

In [28]:
torch.save(classifier.state_dict(), os.path.join(checkpoint_dir, 'classifier.pth'))

In [29]:
net = ResNetClassifier()
net.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'classifier.pth')))

<All keys matched successfully>

In [30]:
net = net.cuda()

net.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in data.val_dataloader():
        images, labels = batch

        images = images.cuda()
        labels = labels.cuda()

        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


print(f"Accuracy: {(100 * correct / total):.2f}%")

Accuracy: 16.73%
