In [1]:
from torchvision import datasets, transforms, models
from torch import nn
import torch.nn.functional as F
import torchvision
import torch

from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
import torchvision.transforms as transforms

from PIL import Image
from simclr import SimCLR
from simclr.modules import NT_Xent
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model
from simclr.modules import LARS
from simclr.modules.identity import Identity

import random
from typing import Type, Any, Callable, Union, List, Optional
from torch import Tensor

import resnet

import os
import argparse
import sys

In [2]:
class CustomDataset(torch.utils.data.Dataset):
    r"""Dataset of ``<root>/<split>/<idx>.png`` images with optional labels.

    Labels are read from ``<root>/<split>_label_tensor.pt`` when that file
    exists; otherwise every sample gets the placeholder label ``-1``.
    """

    def __init__(self, root, split, transform, limit=0):
        r"""
        Args:
            root: Location of the dataset folder, usually ``/dataset``.
            split: Which split to use; one of ``train``, ``val`` or ``unlabeled``.
            transform: Transform applied to each PIL image before it is returned.
            limit: When non-zero, the dataset reports exactly this many samples
                instead of counting the files in the split directory.
        """
        self.split = split
        self.transform = transform
        self.image_dir = os.path.join(root, split)

        # limit == 0 means "use every file found in the split directory".
        self.num_images = limit if limit != 0 else len(os.listdir(self.image_dir))

        label_path = os.path.join(root, f"{split}_label_tensor.pt")
        if not os.path.exists(label_path):
            # No label file for this split: fill with the -1 placeholder.
            self.labels = torch.full((self.num_images,), -1, dtype=torch.long)
        else:
            self.labels = torch.load(label_path)

    def __len__(self):
        """Number of samples exposed by the dataset."""
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample number ``idx``."""
        image_path = os.path.join(self.image_dir, f"{idx}.png")
        with open(image_path, 'rb') as handle:
            # convert('RGB') reads the pixels while the file handle is still open.
            image = Image.open(handle).convert('RGB')

        return self.transform(image), self.labels[idx]

In [3]:
class NYUImageNetDataModule(pl.LightningDataModule):
    """Supervised and SimCLR DataLoaders over the dataset stored under ``root``.

    ``train_dataloader`` / ``val_dataloader`` serve the ``train`` and ``val``
    splits, normalised with the ImageNet mean/std below. The ``ssl_*`` loaders
    apply ``TransformsSimCLR(96)`` to the ``unlabeled`` and ``val`` splits.

    Args:
        root: Dataset folder holding the split directories (default ``/dataset``).
        batch_size: Batch size of the supervised loaders (default 32). The
            ``ssl_*`` loaders take their batch size as an explicit argument.
        num_workers: Worker processes used by every DataLoader (default 4).
    """

    # ImageNet channel statistics, shared by both supervised loaders.
    NORMALIZE_MEAN = [0.485, 0.456, 0.406]
    NORMALIZE_STD = [0.229, 0.224, 0.225]

    def __init__(self, root='/dataset', batch_size=32, num_workers=4):
        super().__init__()
        self.data_root = root
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, split, transform, batch_size, shuffle):
        """Wrap ``CustomDataset`` for ``split`` in a pinned-memory DataLoader."""
        dataset = CustomDataset(root=self.data_root, split=split, transform=transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def _normalize(self):
        """Mean/std normalisation applied after ``ToTensor``."""
        return transforms.Normalize(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD)

    def train_dataloader(self):
        """Shuffled loader over ``train`` with colour-jitter and random flips."""
        train_transform = transforms.Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._make_loader("train", train_transform, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled, un-augmented loader over ``val``."""
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])
        return self._make_loader("val", eval_transform, self.batch_size, shuffle=False)

    def ssl_train_dataloader(self, batch_size):
        """Shuffled SimCLR-augmented loader over the ``unlabeled`` split."""
        return self._make_loader('unlabeled', TransformsSimCLR(96), batch_size, shuffle=True)

    def ssl_val_dataloader(self, batch_size):
        """Unshuffled SimCLR-augmented loader over the ``val`` split."""
        return self._make_loader('val', TransformsSimCLR(96), batch_size, shuffle=False)

In [4]:
# Data module providing the supervised (train/val) and SimCLR (ssl_*) loaders;
# the evaluation cell below iterates over data.val_dataloader().
data = NYUImageNetDataModule()

In [5]:
class ResNetClassifier(LightningModule):
    """Linear classification head on top of a custom ResNet-18 encoder.

    The encoder's own ``fc`` layer is replaced by ``Identity`` so that it emits
    ``feature_dim``-wide features. ``lastLayer`` maps those features to
    ``num_classes`` logits, which are trained with cross-entropy.

    Args:
        num_classes: Number of output classes (default 800).
        feature_dim: Width of the encoder output fed to the head (default 512;
            presumably the ResNet-18 penultimate width - confirm against
            ``resnet.get_custom_resnet18``).
    """

    def __init__(self, num_classes=800, feature_dim=512):
        super().__init__()
        self.encoder = resnet.get_custom_resnet18()
        # To start from SimCLR pre-trained weights, load them into self.encoder here,
        # before the fc swap (previously: <checkpoint_dir>/simclr_encoder.pth).
        self.encoder.fc = Identity()
        self.lastLayer = torch.nn.Linear(feature_dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw class logits (no softmax) for a batch of images."""
        return self.lastLayer(self.encoder(x))

    def _shared_step(self, batch):
        """Run the model on an ``(images, labels)`` batch; return ``(loss, logits, labels)``."""
        images, labels = batch
        logits = self(images)
        # CrossEntropyLoss expects unnormalised logits, not probabilities.
        loss = self.criterion(logits, labels)
        return loss, logits, labels

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self._shared_step(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'val_loss': loss, 'prediction': logits, 'target': labels}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        # Lower the LR when the epoch-level 'val_loss' logged in validation_step
        # stops decreasing for `patience` consecutive scheduler checks.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

In [6]:
# Restore the trained classifier weights. map_location="cpu" lets the checkpoint
# load even on a machine without the GPU it was saved from; the model is moved to
# the evaluation device afterwards.
CLASSIFIER_CHECKPOINT = os.path.join('/scratch/vvb238/finalSubmission', 'classifier.pth')
classifier = ResNetClassifier()
classifier.load_state_dict(torch.load(CLASSIFIER_CHECKPOINT, map_location="cpu"))

<All keys matched successfully>

In [7]:
# Evaluate on the validation split. For every sample, record the top-1 softmax
# confidence (conf), predicted class index (pred) and true label (actual) for
# the confidence analysis that follows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = classifier.to(device)
classifier.eval()

conf_batches, pred_batches, actual_batches = [], [], []
correct = 0
total = 0
with torch.no_grad():
    for images, labels in data.val_dataloader():
        images = images.to(device)
        labels = labels.to(device)

        probs = F.softmax(classifier(images), dim=1)
        batch_conf, batch_pred = probs.max(dim=1)

        conf_batches.append(batch_conf)
        pred_batches.append(batch_pred)
        actual_batches.append(labels)

        total += labels.size(0)
        correct += (batch_pred == labels).sum().item()

# Concatenate once: calling torch.cat on a growing tensor every batch is quadratic.
# pred/actual keep integer dtypes instead of being promoted to float.
conf = torch.cat(conf_batches) if conf_batches else torch.empty(0, device=device)
pred = torch.cat(pred_batches) if pred_batches else torch.empty(0, dtype=torch.long, device=device)
actual = torch.cat(actual_batches) if actual_batches else torch.empty(0, dtype=torch.long, device=device)

if total == 0:
    print("Accuracy: n/a (validation loader yielded no samples)")
else:
    print(f"Accuracy: {(100 * correct / total):.2f}%")

Accuracy: 17.45%


In [9]:
# Order all validation samples by top-1 confidence, most confident first.
# NOTE: sortedLabel holds sample *indices* into conf/pred/actual, not class labels.
sortedConf, sortedLabel = conf.cpu().sort(descending=True)

In [31]:
# Precision of the most confident predictions: take the TOP_FRACTION of validation
# samples with the highest confidence and count how many are classified correctly.
TOP_FRACTION = 0.03
limit = int(actual.shape[0] * TOP_FRACTION)
top_idx = sortedLabel[:limit]  # indices of the `limit` most confident samples
is_correct = pred[top_idx].cpu() == actual[top_idx].cpu()
n_correct = int(is_correct.count_nonzero())

print("Correct Labels", n_correct)
print("Incorrect Labels", limit - n_correct)
if limit == 0:
    print("Percentage correct: n/a (no samples selected)")
else:
    print(f"Percentage correct {100 * n_correct / limit:.2f}%")

Correct Labels tensor(592)
Incorrect Labels tensor(176)
Percentage correct tensor(0.7708)


In [23]:
# Number of validation samples that were evaluated.
len(actual)

25600

In [10]:
# Save the classifier weights into the iterativeDataFeeding scratch directory,
# creating it first so torch.save does not fail when the directory is missing.
ITERATIVE_FEEDING_DIR = '/scratch/vvb238/iterativeDataFeeding'
os.makedirs(ITERATIVE_FEEDING_DIR, exist_ok=True)
torch.save(classifier.state_dict(), os.path.join(ITERATIVE_FEEDING_DIR, 'classifier.pth'))

In [29]:
# Scratch check: a zero-length default-dtype tensor, as used for the old accumulators.
torch.empty(0)

tensor([])