In [1]:
from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time

from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms

import lightly
import lightly.models as models
import lightly.loss as loss
import lightly.data as data
from lightly.models.barlowtwins import BarlowTwins
from lightly.models.simclr import SimCLR

from simclr.modules.transformations import TransformsSimCLR
from PIL import Image, ImageOps, ImageFilter

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule

import resnet



In [2]:
# Name of this pre-training run; used to build checkpoint / output paths below.
checkpointDir = "barlow-custom34-1000"

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    r"""Dataset over ``{root}/{split}/{idx}.png`` images.

    Args:
        root: Location of the dataset folder, usually it is /dataset
        split: The split you want to use; one of train, val or unlabeled. It
            selects the image directory ``{root}/{split}`` and the optional
            label file ``{root}/{split}_label_tensor.pt``.
        transform: the transform applied to each RGB PIL image.
        limit: number of images to expose. ``0`` (default) counts every entry
            of the split directory; any other value is used as-is.

    When the split's label file is missing, every image gets label ``-1``.

    For the ``unlabeled`` split, extra labels for a subset of images are read
    from ``requests.txt`` (one ``<idx>.png,`` entry per line) and
    ``label_15.pt`` (labels in the same order). Both paths are resolved
    relative to the current working directory.

    Raises:
        FileNotFoundError: split is ``unlabeled`` and ``label_15.pt`` is missing.
        ValueError: ``label_15.pt`` holds fewer labels than ``requests.txt``
            lists images.
    """

    def __init__(self, root, split, transform, limit=0):
        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)

        if limit == 0:
            self.num_images = len(os.listdir(self.image_dir))
        else:
            self.num_images = limit

        label_path = os.path.join(root, f"{split}_label_tensor.pt")
        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            # No label file for this split: mark every image as unlabeled (-1).
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

        if self.split == "unlabeled":
            self.imageLabelDict = self._load_requested_labels()

    @staticmethod
    def _load_requested_labels(label_file="label_15.pt", request_file="requests.txt"):
        """Return ``{image index: label}`` for the extra labelled unlabeled-split images."""
        if not os.path.exists(label_file):
            # Fail with a clear message instead of a NameError further down.
            raise FileNotFoundError(
                f"label file {label_file!r} is required for the 'unlabeled' split")
        labels = torch.load(label_file)

        with open(request_file, "r") as f:
            # Skip blank lines, e.g. the one produced by a trailing newline.
            images = [int(line.replace(".png,", ""))
                      for line in f.read().split("\n") if line.strip()]

        if len(labels) < len(images):
            raise ValueError(
                f"{label_file!r} has {len(labels)} labels but {request_file!r} "
                f"lists {len(images)} images")

        return {img: labels[i] for i, img in enumerate(images)}

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for image ``idx``.

        The label comes from the extra requested labels when available for
        the unlabeled split, otherwise from ``self.labels``.
        """
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            img = Image.open(f).convert('RGB')

        if self.split == "unlabeled" and idx in self.imageLabelDict:
            return self.transform(img), self.imageLabelDict[idx]
        return self.transform(img), self.labels[idx]

In [4]:
class GaussianBlur(object):
    """Randomly blur a PIL image.

    With probability ``p`` the image is filtered with a Gaussian kernel whose
    radius is drawn uniformly between 0.1 and 2.0; otherwise it is returned
    unchanged.
    """

    def __init__(self, p):
        self.p = p  # probability of applying the blur

    def __call__(self, img):
        if not random.random() < self.p:
            return img
        radius = random.uniform(0.1, 2.0)
        return img.filter(ImageFilter.GaussianBlur(radius))


class Solarization(object):
    """Randomly solarize a PIL image with probability ``p`` (``ImageOps.solarize``)."""

    def __init__(self, p):
        self.p = p  # probability of applying the solarization

    def __call__(self, img):
        apply_it = random.random() < self.p
        return ImageOps.solarize(img) if apply_it else img

In [5]:
class Transform:
    """Two-view augmentation used for Barlow Twins pre-training.

    Calling an instance on one PIL image returns ``(y1, y2)``: two
    independently augmented, ImageNet-normalised tensors. Both views share the
    same crop / flip / colour-jitter / grayscale settings and differ only in
    how often Gaussian blur and solarization are applied:

    * ``transform``       -- blur p=1.0, solarize p=0.0
    * ``transform_prime`` -- blur p=0.1, solarize p=0.2
    """

    def __init__(self):
        self.transform = self._build_pipeline(blur_p=1.0, solarize_p=0.0)
        self.transform_prime = self._build_pipeline(blur_p=0.1, solarize_p=0.2)

    @staticmethod
    def _build_pipeline(blur_p, solarize_p):
        """Return the shared 96x96 augmentation pipeline with the given blur / solarize probabilities."""
        return transforms.Compose([
            transforms.RandomResizedCrop(96, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

In [6]:
class LARS(optim.Optimizer):
    """SGD with momentum and layer-wise adaptive rate scaling (LARS).

    For each parameter ``p`` with gradient ``g`` one step does::

        d  = g + weight_decay * p        # skipped if weight_decay_filter(p) is True
        d  = d * eta * ||p|| / ||d||     # skipped if lars_adaptation_filter(p) is True;
                                         # trust ratio is 1 when either norm is 0
        mu = momentum * mu + d
        p  = p - lr * mu

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        weight_decay: L2 coefficient added to the gradient.
        momentum: momentum factor for the ``mu`` buffer.
        eta: LARS trust coefficient.
        weight_decay_filter: optional predicate; parameters for which it
            returns True get no weight decay.
        lars_adaptation_filter: optional predicate; parameters for which it
            returns True get no LARS scaling.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim.Optimizer.step`` contract).

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure needs autograd to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if g['weight_decay_filter'] is None or not g['weight_decay_filter'](p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                if g['lars_adaptation_filter'] is None or not g['lars_adaptation_filter'](p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio eta*||p||/||d||, falling back to 1 for zero norms.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss


def exclude_bias_and_norm(p):
    """Return True when ``p`` is a 1-D tensor.

    Used below as both ``weight_decay_filter`` and ``lars_adaptation_filter``
    of ``LARS``, so 1-D parameters (typically biases and BatchNorm affine
    parameters) skip weight decay and LARS scaling.
    """
    return p.dim() == 1

In [7]:
# Pre-training data: the unlabeled split, each image returned as a pair of
# augmented views by Transform().
dataset = CustomDataset(
    root='/dataset',
    split='unlabeled',
    transform=Transform(),
)

# Shuffled loader with 1024 images per batch.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [8]:
class BarlowTwins(nn.Module):
    """Barlow Twins model: custom ResNet-34 backbone + MLP projector + loss.

    ``forward(y1, y2)`` returns the Barlow Twins loss for two augmented views
    of the same batch.

    NOTE(review): this class shadows ``lightly.models.barlowtwins.BarlowTwins``
    imported in the first cell.

    Args:
        projector: dash-separated widths of the projector's layers. The input
            width is fixed at 512 (assumed backbone feature width with ``fc``
            removed -- confirm against ``resnet.get_custom_resnet34``).
        lambd: weight of the off-diagonal (redundancy-reduction) term.
        scale_loss: constant factor applied to both loss terms.
    """

    def __init__(self, projector='1024-1024-1024', lambd=3.9e-3, scale_loss=1 / 32):
        super().__init__()
        self.lambd = lambd
        self.scale_loss = scale_loss

        self.backbone = resnet.get_custom_resnet34()
        self.backbone.fc = nn.Identity()

        # projector: (Linear -> BatchNorm -> ReLU) * k, then a final Linear
        sizes = [512] + [int(width) for width in projector.split('-')]
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # Empirical cross-correlation matrix, normalised by the actual batch
        # size so a smaller final batch is not mis-scaled.
        # NOTE(review): for multi-GPU training, all_reduce c and divide by the
        # global batch size instead.
        c = self.bn(z1).T @ self.bn(z2)
        c.div_(z1.shape[0])

        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.scale_loss)
        loss = on_diag + self.lambd * off_diag
        return loss

In [9]:
def off_diagonal(x):
    """Return the off-diagonal entries of square matrix ``x`` as a flat tensor (row-major order).

    Dropping the final element leaves n*n - 1 = (n-1)*(n+1) entries. Laid out
    as n-1 rows of n+1, every diagonal element lands in column 0, so slicing
    that column away leaves exactly the off-diagonal entries.
    """
    rows, cols = x.shape
    assert rows == cols
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

In [10]:
def adjust_learning_rate(optimizer, loader, step, epochs=300, warmup_epochs=10,
                         batch_size=1024, base_learning_rate=0.2):
    """Set the learning rate of every param group for global step ``step``.

    The base rate is ``base_lr = base_learning_rate * batch_size / 256``. The
    schedule has two phases:

    * Linear warm-up from 0 to ``base_lr`` over ``warmup_epochs`` epochs.
    * Cosine decay from ``base_lr`` to ``base_lr * 0.001`` over the remaining
      ``epochs - warmup_epochs`` epochs.

    Steps past the end of the schedule keep the final learning rate instead of
    following the cosine back up.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        loader: training loader; ``len(loader)`` is the number of steps per epoch.
        step: global step index (0-based, counted across epochs).
        epochs, warmup_epochs, batch_size, base_learning_rate: schedule
            hyper-parameters. The defaults reproduce the original hard-coded values.

    Returns:
        The learning rate that was set.
    """
    steps_per_epoch = len(loader)
    max_steps = epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    base_lr = base_learning_rate * batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        # Guard against a zero-length decay phase (e.g. an empty loader).
        decay_steps = max(max_steps - warmup_steps, 1)
        # Clamp so the cosine does not climb back up after the schedule ends.
        decayed = min(step - warmup_steps, decay_steps)
        q = 0.5 * (1 + math.cos(math.pi * decayed / decay_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [11]:
# Let cuDNN autotune convolution algorithms for the fixed-size inputs.
torch.backends.cudnn.benchmark = True

# Pre-training model on the GPU with its LARS optimizer. lr=0 is a placeholder;
# adjust_learning_rate() sets the real value per step in the training loop.
# 1-D parameters skip both weight decay and LARS adaptation.
model = BarlowTwins().cuda()
optimizer = LARS(
    model.parameters(),
    lr=0,
    weight_decay=1e-6,
    weight_decay_filter=exclude_bias_and_norm,
    lars_adaptation_filter=exclude_bias_and_norm,
)

In [12]:
# # automatically resume from checkpoint if it exists
# if os.path.isfile('/scratch/vvb238/' + checkpointDir + '/checkpoint.pth'):
#     ckpt = torch.load('/scratch/vvb238/' + checkpointDir + '/checkpoint.pth',
#                       map_location='cpu')
#     start_epoch = ckpt['epoch']
#     model.load_state_dict(ckpt['model'])
#     optimizer.load_state_dict(ckpt['optimizer'])
# else:
#     start_epoch = 0

In [13]:
# start_time = time.time()
# scaler = torch.cuda.amp.GradScaler()
# for epoch in range(start_epoch, 300):
# #     sampler.set_epoch(epoch)
#     for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
#         y1 = y1.cuda()
#         y2 = y2.cuda()
#         lr = adjust_learning_rate(optimizer, loader, step)
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             loss = model.forward(y1, y2)
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()
#         if step % 10 == 0:
# #             torch.distributed.reduce(loss.div_(args.world_size), 0)
# #             if args.rank == 0:
#             stats = dict(epoch=epoch, step=step, learning_rate=lr,
#                          loss=loss.item(),
#                          time=int(time.time() - start_time))
#             print(json.dumps(stats))
# #                 print(json.dumps(stats), file=stats_file)
#         # save checkpoint
#     state = dict(epoch=epoch + 1, model=model.state_dict(),
#                  optimizer=optimizer.state_dict())
#     torch.save(state, '/scratch/vvb238/' + checkpointDir + '/checkpoint.pth')
    
# torch.save(model.backbone.state_dict(),
#            '/scratch/vvb238/' + checkpointDir + '/resnet50.pth')

In [14]:
# Part 2: load a mid-training Barlow Twins checkpoint and fine-tune a classifier on top of its backbone.

In [15]:
class GaussianBlur(object):
    """Blur a PIL image with probability ``p``, using a random radius between 0.1 and 2.0.

    NOTE(review): identical to the GaussianBlur defined in an earlier cell;
    this redefinition shadows it for everything created after this cell.
    Consider keeping only one copy.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        draw = random.random()
        if draw < self.p:
            sigma = 0.1 + random.random() * 1.9
            return img.filter(ImageFilter.GaussianBlur(sigma))
        return img


class Solarization(object):
    """Apply ``ImageOps.solarize`` to a PIL image with probability ``p``.

    NOTE(review): identical to the Solarization defined in an earlier cell;
    this redefinition shadows it. Consider keeping only one copy.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            img = ImageOps.solarize(img)
        return img

In [16]:
class NYUImageNetDataModule(pl.LightningDataModule):
    """Data loaders for supervised fine-tuning on the ``/dataset`` splits.

    All loaders use batch size 128, 4 workers and pinned memory. The training
    loaders shuffle; the validation loader does not.
    """

    @staticmethod
    def _build_train_transform():
        """Fine-tuning augmentation: 96x96 random crop + flip, light colour jitter / grayscale / blur."""
        return transforms.Compose([
            transforms.RandomResizedCrop(96, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.1
            ),
            transforms.RandomGrayscale(p=0.1),
            GaussianBlur(p=0.2),
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _make_loader(dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with the settings shared by every split."""
        return torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=shuffle,
                                           num_workers=4, pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the ``train`` split."""
        trainset = CustomDataset(root='/dataset', split="train",
                                 transform=self._build_train_transform())
        return self._make_loader(trainset, shuffle=True)

    def added_train_loader(self):
        """Shuffled loader over ``train`` plus the labelled subset of ``unlabeled``.

        The extra images are those whose indices appear in the unlabeled
        dataset's ``imageLabelDict``.
        """
        train_transform = self._build_train_transform()
        unlabeledset = CustomDataset(root='/dataset', split="unlabeled", transform=train_transform)
        unlabeledGivenData = torch.utils.data.Subset(unlabeledset, list(unlabeledset.imageLabelDict.keys()))
        trainset = CustomDataset(root='/dataset', split="train", transform=train_transform)
        trainExtraDataset = torch.utils.data.ConcatDataset((unlabeledGivenData, trainset))
        return self._make_loader(trainExtraDataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the ``val`` split (tensor conversion + normalisation only)."""
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        evalset = CustomDataset(root='/dataset', split="val", transform=eval_transform)
        return self._make_loader(evalset, shuffle=False)

nyudata = NYUImageNetDataModule()

In [17]:
# Pre-trained Barlow Twins checkpoint whose backbone initialises the classifier.
PRETRAINED_CKPT_PATH = '/scratch/nr2229/best-checkpoint.pth'

# Fail loudly when it is missing. Otherwise `ckpt` stays undefined and the
# classifier (which copies model.backbone) silently starts from random weights.
if not os.path.isfile(PRETRAINED_CKPT_PATH):
    raise FileNotFoundError(f"pre-trained checkpoint not found: {PRETRAINED_CKPT_PATH}")
ckpt = torch.load(PRETRAINED_CKPT_PATH, map_location='cpu')
model.load_state_dict(ckpt['model'])

In [18]:
# Epoch stored in the loaded pre-training checkpoint (shown as the cell output).
ckpt["epoch"]

620

In [19]:
from simclr.modules.identity import Identity
import torch.nn.functional as F
from torchmetrics.functional import accuracy

class ResNetClassifier(LightningModule):
    """Fine-tuning model: pre-trained custom ResNet-34 backbone + MLP head.

    The backbone is a fresh ``resnet.get_custom_resnet34()`` (``fc`` replaced
    by identity) initialised from ``pretrained_backbone``'s weights. The head
    maps 512 backbone features to 800 class logits.

    Args:
        pretrained_backbone: module whose ``state_dict`` initialises the
            backbone. ``None`` (default) uses ``model.backbone``, the
            notebook-global Barlow Twins model.
        head_lr: SGD learning rate for the classification head.
        backbone_lr: SGD learning rate for fine-tuning the backbone.
    """

    def __init__(self, pretrained_backbone=None, head_lr=0.01, backbone_lr=0.0005):
        super().__init__()
        if pretrained_backbone is None:
            pretrained_backbone = model.backbone
        self.backbone = resnet.get_custom_resnet34()
        self.backbone.fc = nn.Identity()
        self.backbone.load_state_dict(pretrained_backbone.state_dict())

        self.lastLayer = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            nn.Dropout(p=0.3),
            torch.nn.Linear(1024, 800),
        )
        # Small-Gaussian weights / zero biases for the newly added head.
        for layer in self.lastLayer.modules():
            if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(mean=0.0, std=0.01)
                layer.bias.data.zero_()

        self.head_lr = head_lr
        self.backbone_lr = backbone_lr

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.lastLayer(self.backbone(x))

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on one ``(images, labels)`` batch, logged as ``train_loss``."""
        data, label = batch
        logits = self.forward(data)
        loss = self.criterion(logits, label)
        self.log('train_loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        """Return ``(loss, top-1 accuracy)`` for a batch; log both under ``{stage}_`` when ``stage`` is set."""
        x, y = batch
        out = self.forward(x)
        logits = F.log_softmax(out, dim=-1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=-1)
        # Computed directly so it does not depend on the module-level name
        # `accuracy`, which is easy to shadow in a notebook.
        acc = (preds == y).float().mean()

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, batch_idx, 'val')

    def configure_optimizers(self):
        """SGD (momentum 0.9, weight decay 1e-5) over head and backbone, cosine-annealed over EPOCHS."""
        param_groups = [
            dict(params=self.lastLayer.parameters(), lr=self.head_lr),
            # Must be this module's own backbone: it is the copy the
            # classification loss back-propagates into.
            dict(params=self.backbone.parameters(), lr=self.backbone_lr),
        ]
        optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

In [20]:
# Number of fine-tuning epochs (Trainer max_epochs and the cosine schedule's T_max).
EPOCHS = 100

classifier = ResNetClassifier()

In [23]:
from pytorch_lightning.callbacks import ModelCheckpoint
# Keep the best checkpoint by val_loss plus the most recent one.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_last=True)

# NOTE(review): deterministic=True and benchmark=True pull in opposite
# directions (cuDNN autotuning may select non-deterministic algorithms), and no
# seed is set -- decide which behaviour is actually intended.
classifier_trainer = Trainer(
    gpus=1,
    deterministic=True,
    max_epochs=EPOCHS,
    default_root_dir='/scratch/nr2229/classifier-' + checkpointDir,
    profiler="simple",
    limit_val_batches=0.3,  # validate on 30% of the val batches each epoch
    benchmark=True,
    callbacks=[checkpoint_callback],
    fast_dev_run=False,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Fine-tune on train + the extra labelled unlabeled-split images; validate on val.
# Loaders are passed positionally because the keyword for the training loader
# was renamed across PyTorch Lightning versions (train_dataloader -> train_dataloaders).
classifier_trainer.fit(classifier, nyudata.added_train_loader(), nyudata.val_dataloader())

Adjusting learning rate of group 0 to 1.0000e-02.
Adjusting learning rate of group 1 to 5.0000e-04.



  | Name      | Type             | Params
-----------------------------------------------
0 | backbone  | ResNet           | 21.3 M
1 | lastLayer | Sequential       | 1.3 M 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
22.6 M    Trainable params
0         Non-trainable params
22.6 M    Total params


Epoch 0:  84%|████████▎ | 301/360 [01:32<00:18,  3.24it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/60 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▍ | 302/360 [01:33<00:17,  3.24it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  84%|████████▍ | 304/360 [01:33<00:17,  3.25it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  85%|████████▌ | 306/360 [01:33<00:16,  3.27it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  86%|████████▌ | 308/360 [01:33<00:15,  3.29it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  86%|████████▌ | 310/360 [01:33<00:15,  3.31it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  87%|████████▋ | 312/360 [01:33<00:14,  3.32it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  87%|████████▋ | 314/360 [01:34<00:13,  3.33it/s, loss=6.47, v_num=1, val_loss=6.68, val_acc=0]
Epoch 0:  88%|████████▊ | 316/360 [01:34<00:13,  3.34it/s, loss=6.47, v_

In [None]:
# Top-1 accuracy of the fine-tuned classifier over the full val split
# (the Trainer only validates on a 30% subset).
net = classifier.cuda()

net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in nyudata.val_dataloader():
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        predicted = net(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("validation loader yielded no samples")
print(f"Accuracy: {(100 * correct / total):.2f}%")

In [None]:
# Validation accuracy in percent. Named val_accuracy so it does not rebind the
# torchmetrics `accuracy` function imported earlier in the notebook.
val_accuracy = 100 * correct / total

# File name is the accuracy with the decimal point removed, e.g. 45.23 -> "4523-classifier.pth".
classifier_path = os.path.join('/scratch/vvb238', checkpointDir,
                               str(val_accuracy).replace('.', '') + '-classifier.pth')
print(classifier_path)
os.makedirs(os.path.dirname(classifier_path), exist_ok=True)
torch.save(classifier.state_dict(), classifier_path)