In [54]:
import argparse
import os
import sys

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from tqdm import tqdm

from simclr import SimCLR
from simclr.modules import LARS, NT_Xent, get_resnet
from simclr.modules.sync_batchnorm import convert_model
from simclr.modules.transformations import TransformsSimCLR

In [55]:
class CustomDataset(torch.utils.data.Dataset):
    r"""Image dataset stored as ``{root}/{split}/{idx}.png`` with optional labels.

    Args:
        root: Location of the dataset folder, usually ``/dataset``.
        split: The split to use; one of ``train``, ``val`` or ``unlabeled``.
        transform: Callable applied to each RGB PIL image before it is returned.
        limit: If non-zero, expose only the first ``limit`` images (indices
            ``0 .. limit-1``). ``0`` means use every ``.png`` file in the split folder.

    Labels are loaded from ``{root}/{split}_label_tensor.pt`` when that file exists;
    otherwise every label is ``-1``.

    Raises:
        ValueError: if ``limit`` is negative.
    """

    def __init__(self, root, split, transform, limit=0):
        if limit < 0:
            raise ValueError(f"limit must be >= 0, got {limit}")

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        if limit == 0:
            # Count only the files __getitem__ can actually load ({idx}.png), so stray
            # entries in the folder do not inflate the length and cause missing-file errors.
            self.num_images = sum(
                1 for name in os.listdir(self.image_dir) if name.endswith(".png")
            )
        else:
            self.num_images = limit

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            # No label file (e.g. the unlabeled split): use -1 as a placeholder label.
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for the image ``{idx}.png``."""
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            # convert() makes PIL read the pixel data while the file is still open.
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [71]:
class ContrastiveLearning(LightningModule):
    """Lightning module that pre-trains a ResNet-18 encoder with the SimCLR NT-Xent loss.

    Args:
        batch_size: Batch size the NT-Xent loss is built for; also scales the LARS
            learning rate. Defaults to the notebook-level ``BATCH_SIZE``.
        epochs: Length of the cosine learning-rate schedule, in epochs. Defaults to
            the notebook-level ``EPOCHS``.
        temperature: NT-Xent temperature.
        projection_dim: Output size passed to ``SimCLR`` as its projection dimension.
    """

    def __init__(self, batch_size=None, epochs=None, temperature=0.5, projection_dim=512):
        super().__init__()

        # Fall back to the notebook globals so a bare ``ContrastiveLearning()`` keeps working.
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.epochs = EPOCHS if epochs is None else epochs
        self.temperature = temperature

        # ResNet-18 from scratch; its fc input width is the feature size fed to SimCLR.
        self.encoder = get_resnet("resnet18", pretrained=False)
        self.n_features = self.encoder.fc.in_features
        self.model = SimCLR(self.encoder, projection_dim, self.n_features)
        self.criterion = self.configure_criterion()

    def forward(self, x_i, x_j):
        """Return the NT-Xent loss for two augmented views ``x_i`` and ``x_j``."""
        _, _, z_i, z_j = self.model(x_i, x_j)
        return self.criterion(z_i, z_j)

    def training_step(self, batch, batch_idx):
        # Labels are ignored; the two views of each image form the positive pair.
        (x_i, x_j), _ = batch
        loss = self(x_i, x_j)
        self.log("train_loss", loss)
        return loss

    def configure_criterion(self):
        """Build the NT-Xent loss for ``self.batch_size`` samples on a single process."""
        return NT_Xent(self.batch_size, self.temperature, world_size=1)

    def configure_optimizers(self):
        """LARS with linear LR scaling plus cosine decay over ``self.epochs`` epochs."""
        # Linear learning-rate scaling: LearningRate = 0.3 * BatchSize / 256, weight decay 1e-6.
        # The exclude list is meant to skip weight decay for batch-norm and bias parameters.
        learning_rate = 0.3 * self.batch_size / 256
        optimizer = LARS(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.000001,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.epochs, eta_min=0, last_epoch=-1
        )

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [72]:
# Training configuration read by ContrastiveLearning and by the manual training loop below.
EPOCHS = 10
BATCH_SIZE = 256
SEED = 42

# Trainer(deterministic=True) only makes runs reproducible if the RNGs are seeded as well.
pl.seed_everything(SEED)

In [73]:
# SimCLR pre-training uses the unlabeled split; 96 is the image size given to TransformsSimCLR.
train_dataset = CustomDataset(root='/dataset', split='unlabeled', transform=TransformsSimCLR(96))

# drop_last: NT_Xent is constructed for exactly BATCH_SIZE samples, so a smaller final
# batch would not match it. pin_memory lets the later .cuda(non_blocking=True) copies
# actually overlap with compute.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=torch.cuda.is_available(),
)

In [74]:
# Lightning SimCLR module trained by trainer.fit below; it reads BATCH_SIZE when constructed.
# NOTE(review): the name matches the `simclr` package, but the from-imports above do not bind it.
simclr = ContrastiveLearning()

In [78]:
from pytorch_lightning.callbacks import ModelCheckpoint

# sync_batchnorm is a Trainer constructor argument; setting the attribute after the Trainer
# is built is not the documented way to enable it. It only matters when training spans
# several processes, so with gpus=1 it has no practical effect.
trainer = Trainer(
    gpus=1,
    deterministic=True,
    max_epochs=EPOCHS,
    # NOTE(review): "vae" root dir for a SimCLR run looks copy-pasted — confirm it is intended.
    default_root_dir='/scratch/vvb238/vae',
    profiler="simple",
    fast_dev_run=False,
    sync_batchnorm=True,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Pass the loader positionally: the keyword was renamed (train_dataloader -> train_dataloaders)
# in later Lightning releases, while the positional slot works in both.
trainer.fit(simclr, train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params
--------------------------------------
0 | encoder   | ResNet  | 11.2 M
1 | model     | SimCLR  | 11.7 M
2 | criterion | NT_Xent | 0     
--------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.803    Total estimated model params size (MB)


Epoch 0:   0%|          | 5/2000 [00:03<23:21,  1.42it/s, loss=6.22, v_num=0]  

In [36]:
from simclr.modules import LARS

# Manual (non-Lightning) SimCLR pre-training loop with resumable checkpoints.
# It reuses train_dataloader, BATCH_SIZE and EPOCHS from the cells above.

# --- Model and loss ---
encoder = torchvision.models.resnet18(pretrained=False)
criterion = NT_Xent(BATCH_SIZE, 0.5, 1)
# SimCLR(encoder, projection_dim, n_features): take n_features from the backbone's fc
# input width (as ContrastiveLearning does) instead of hard-coding 512.
model = SimCLR(encoder, 1024, encoder.fc.in_features)

# --- Resume from previous weights ---
checkpoint_dir = "/scratch/vvb238/simclr"
# Load and save the same files so each run continues from the last saved epoch
# (previously it loaded *_18.pth but saved to different names, so resume never advanced).
encoder_ckpt = os.path.join(checkpoint_dir, "simclr_encoder_18.pth")
projector_ckpt = os.path.join(checkpoint_dir, "simclr_projector_18.pth")

if os.path.exists(encoder_ckpt) and os.path.exists(projector_ckpt):
    print('Loading previous model')
    model.encoder.load_state_dict(torch.load(encoder_ckpt, map_location="cpu"))
    model.projector.load_state_dict(torch.load(projector_ckpt, map_location="cpu"))

model = model.cuda()

# --- Optimizer and schedule ---
# LARS with linear learning rate scaling (LearningRate = 0.3 * BatchSize / 256)
# and weight decay of 1e-6.
learning_rate = 0.3 * BATCH_SIZE / 256
optimizer = LARS(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.000001,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

# "decay the learning rate with the cosine decay schedule without restarts",
# spread over this run's EPOCHS and stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, EPOCHS, eta_min=0, last_epoch=-1
)

# --- Training ---
numOfBatches = len(train_dataloader)
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCHS):
    model.train()
    print('Current Epoch: {}'.format(epoch))

    loss_epoch = 0.0
    for step, ((x_i, x_j), _) in tqdm(enumerate(train_dataloader), total=numOfBatches):
        optimizer.zero_grad()
        x_i = x_i.cuda(non_blocking=True)
        x_j = x_j.cuda(non_blocking=True)

        # positive pair, with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if step % 100 == 0:
            print('Step: {}, Train Loss: {}'.format(step, loss_value))
        loss_epoch += loss_value

    # Advance the cosine schedule once per epoch; without this the LR never decays.
    scheduler.step()

    torch.save(model.encoder.state_dict(), encoder_ckpt)
    torch.save(model.projector.state_dict(), projector_ckpt)
    avg_loss = loss_epoch / max(numOfBatches, 1)

    print('Epoch: {}, Train Loss: {}'.format(epoch + 1, avg_loss))