In [14]:
from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time

from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn.functional as F
from torchmetrics.functional import accuracy

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule

import resnet

In [15]:
checkpointDir = 'barlow-custom34-1000'

In [16]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, split, transform, limit=0):
        r"""
        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split you want to used, it should be one of train, val or unlabeled.
            transform: the transform you want to applied to the images.
        """
        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        
        label_path = os.path.join(root, f"{split}_label_tensor.pt")
        if limit == 0:
            self.num_images = len(os.listdir(self.image_dir))
        else:
            self.num_images = limit

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)
            
            
        if self.split == "unlabeled":
            label_path = os.path.join("label_15.pt")
            if os.path.exists(label_path):
                labels = torch.load(label_path)

            images = []
            f = open("requests.txt", "r")
            s = str(f.read()).split("\n")
            for img in s:
                images.append(int(img.replace(".png,","")))
                
            self.imageLabelDict = { images[i]: labels[i]  for i in range(len(images))} 

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            img = Image.open(f).convert('RGB')

        if self.split == "unlabeled" and idx in self.imageLabelDict:
            return self.transform(img), self.imageLabelDict[idx]            
        else:
            return self.transform(img), self.labels[idx]

In [17]:
class GaussianBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            sigma = random.random() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img

In [27]:
class NYUImageNetDataModule(pl.LightningDataModule):
    def __init__(self):
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(96, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    
    
    def train_dataloader(self):
        trainset = CustomDataset(root='/dataset', split="train", transform=self.train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
        return train_loader
    
    def extra_train_loader(self):
        unlabeledset = CustomDataset(root='/dataset', split="unlabeled", transform=self.train_transform)
        unlabeledGivenData = torch.utils.data.Subset(unlabeledset, list(unlabeledset.imageLabelDict.keys()))
        trainset = CustomDataset(root='/dataset', split="train", transform=self.train_transform)
        trainExtraDataset = torch.utils.data.ConcatDataset((unlabeledGivenData, trainset))
        train_loader = torch.utils.data.DataLoader(trainExtraDataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
        return train_loader
        
    
    def val_dataloader(self):
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        evalset = CustomDataset(root='/dataset', split="val", transform=eval_transform)
        eval_loader = torch.utils.data.DataLoader(evalset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
        return eval_loader

In [28]:
nyudata = NYUImageNetDataModule()

In [29]:
class BarlowTwins(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = resnet.get_custom_resnet34()
        self.backbone.fc = nn.Identity()

        # projector
        sizes = [512] + list(map(int, '1024-1024-1024'.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
        
def exclude_bias_and_norm(p):
    return p.ndim == 1

In [30]:
torch.backends.cudnn.benchmark = True

model = BarlowTwins().cuda()
if os.path.isfile('/scratch/vvb238/' + checkpointDir + '/best-checkpoint.pth'):
    ckpt = torch.load('/scratch/vvb238/' + checkpointDir + '/best-checkpoint.pth',
                      map_location='cpu')
    model.load_state_dict(ckpt['model'])
print(ckpt['epoch'])

762


In [31]:
class ResNetClassifier(LightningModule):
    def __init__(self):
        super().__init__()
#         self.backbone = torchvision.models.resnet34(zero_init_residual=True)
        self.backbone = resnet.get_custom_resnet34()
        self.backbone.fc = nn.Identity()
        self.backbone.load_state_dict(model.backbone.state_dict())
        
        self.lastLayer = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            nn.Dropout(p=0.3),
#             torch.nn.Linear(1024, 1024),
#             torch.nn.ReLU(),
#             nn.Dropout(p=0.3),
            torch.nn.Linear(1024, 800),
        )
#         self.lastLayer = torch.nn.Linear(512, 800)
        for layer in self.lastLayer.modules():
           if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(mean=0.0, std=0.01)
                layer.bias.data.zero_()
        
        self.param_groups = [dict(params=self.lastLayer.parameters(), lr=0.01)]
        self.param_groups.append(dict(params=model.parameters(), lr=0.0008))
        
        self.criterion=torch.nn.CrossEntropyLoss()
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.lastLayer(x)
        return x
    
    def training_step(self, batch, batch_idx):
        data, label = batch
        classProbs = self.forward(data)
        loss = self.criterion(classProbs, label)
        self.log('train_loss', loss)
        return loss
    
    def _evaluate(self, batch, batch_idx, stage=None):
        x, y = batch
        out = self.forward(x)
        logits = F.log_softmax(out, dim=-1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc
    
    def validation_step(self,batch,batch_idx):
        self._evaluate(batch, batch_idx, 'val')[0]
    
    def configure_optimizers(self):
        optimizer = optim.SGD(self.param_groups, 0, momentum=0.9, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

In [32]:
EPOCHS = 100
classifier = ResNetClassifier()

In [33]:
torch.save(classifier.state_dict(),
           '/scratch/vvb238/' + checkpointDir + '/base-classifier.pth')

In [34]:
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_last=True)
classifier_trainer = Trainer(gpus=1,deterministic=True, max_epochs=EPOCHS, default_root_dir='/scratch/vvb238/classifier-' + checkpointDir, profiler="simple",
                     limit_val_batches= 0.6, benchmark=True, callbacks=[checkpoint_callback], fast_dev_run=False, )
#                             resume_from_checkpoint='/scratch/vvb238/classifier-barlow-custom34-1000/lightning_logs/version_7/checkpoints/last.ckpt')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [35]:
classifier_trainer.fit(classifier, train_dataloader=nyudata.extra_train_loader(), val_dataloaders=nyudata.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | backbone  | ResNet           | 21.3 M
1 | lastLayer | Sequential       | 1.3 M 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
22.6 M    Trainable params
0         Non-trainable params
22.6 M    Total params
90.489    Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-02.
Adjusting learning rate of group 1 to 8.0000e-04.
Epoch 0:  71%|███████▏  | 300/420 [01:50<00:44,  2.71it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/120 [00:00<?, ?it/s][A
Epoch 0:  72%|███████▏  | 302/420 [01:51<00:43,  2.71it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  72%|███████▏  | 304/420 [01:51<00:42,  2.73it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  73%|███████▎  | 306/420 [01:51<00:41,  2.74it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  73%|███████▎  | 308/420 [01:51<00:40,  2.76it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  74%|███████▍  | 310/420 [01:51<00:39,  2.77it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  74%|███████▍  | 312/420 [01:51<00:38,  2.79it/s, loss=6.47, v_num=15, val_loss=6.700, val_acc=0.000]
Epoch 0:  75%|███████▍  | 314/420 

Saving latest checkpoint...


Epoch 99: 100%|██████████| 420/420 [01:13<00:00,  5.74it/s, loss=1.96, v_num=15, val_loss=2.520, val_acc=0.438]




Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  7373.0         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  73.647         	|100            	|  7364.7         	|  99.888         	|
run_training_batch                 	|  0.20857        	|30000          	|  6257.2         	|  84.867         	|
evaluation_step_and_end            	|  0.063657       	|12002          	|  764.01         	|  10.362         	|
validation_step                    	|  0.063486       	|12002          	|  761.96         	|  10.335         	|
optimizer_step_and_closure_0       	|  0.

1

In [36]:
net = classifier.cuda()

net.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in nyudata.val_dataloader():
        images, labels = batch

        images = images.cuda()
        labels = labels.cuda()

        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


print(f"Accuracy: {(100 * correct / total):.2f}%")

Accuracy: 43.73%


In [39]:
accuracy = (100 * correct / total)
print('/scratch/vvb238/extra-' + checkpointDir + '/' + str(accuracy).replace('.', '') + '-extra-classifier.pth')
torch.save(classifier.state_dict(),
           '/scratch/vvb238/' + checkpointDir + '/' + str(accuracy).replace('.', '') + '-extra-classifier.pth')

/scratch/vvb238/extra-barlow-custom34-1000/437265625-extra-classifier.pth
