In [3]:
from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time

from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms

import lightly
import lightly.models as models
import lightly.loss as loss
import lightly.data as data
from lightly.models.barlowtwins import BarlowTwins
from lightly.models.simclr import SimCLR

from simclr.modules.identity import Identity
import torch.nn.functional as F
from torchmetrics.functional import accuracy

from simclr.modules.transformations import TransformsSimCLR
from PIL import Image, ImageOps, ImageFilter

from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule

import resnet

In [4]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, split, transform, limit=0):
        r"""
        Dataset of images stored as ``<root>/<split>/<idx>.png``.

        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split you want to use; it should be one of train, val or unlabeled.
            transform: The transform you want to apply to the images, or None to get
                the raw PIL image back.
            limit: If non-zero, this value is used as the dataset length instead of
                counting the entries in the split directory.
        """

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        if limit == 0:
            # Every entry of the split directory is counted as one image.
            self.num_images = len(os.listdir(self.image_dir))
        else:
            self.num_images = limit

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            # No label file for this split: use -1 as a placeholder label.
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        """Return the number of images exposed by the dataset."""
        return self.num_images

    def __getitem__(self, idx):
        """Load image ``idx``.

        Returns:
            ``(img, label)`` when no transform is configured, otherwise
            ``(transform(img), label, idx_as_float_tensor)``.
        """
        # Accepts tensor / float indices as well as plain ints.
        idx = int(idx)
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            img = Image.open(f).convert('RGB')

        # Identity check: a transform object with a custom __eq__ must not be
        # mistaken for "no transform".
        if self.transform is None:
            return img, self.labels[idx]

        return self.transform(img), self.labels[idx], torch.tensor(idx).float()

In [5]:
class ResNetClassifier(LightningModule):
    """ResNet-34 feature extractor followed by a two-layer head with 800 outputs.

    Args:
        epochs: Length of the cosine learning-rate schedule. When left as
            ``None`` the notebook-level ``EPOCHS`` constant is used, which
            must then be defined before ``configure_optimizers`` runs.
    """

    def __init__(self, epochs=None):
        super().__init__()
        self.epochs = epochs

        self.backbone = torchvision.models.resnet34(zero_init_residual=True)
        # Drop the stock classification layer so the backbone emits features.
        self.backbone.fc = nn.Identity()

        self.lastLayer = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            nn.Dropout(p=0.1),
            torch.nn.Linear(1024, 800),
        )

        for layer in self.lastLayer.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=0.01)
                nn.init.zeros_(layer.bias)

        # Kept as an attribute for backward compatibility. Parameters are
        # materialised into lists: a stored generator can be consumed only once
        # and cannot be pickled or deep-copied along with the module.
        self.param_groups = self._build_param_groups()

        self.criterion = torch.nn.CrossEntropyLoss()

    def _build_param_groups(self):
        """Return fresh optimizer groups: head at lr=0.01, backbone at lr=0.0001."""
        return [
            dict(params=list(self.lastLayer.parameters()), lr=0.01),
            dict(params=list(self.backbone.parameters()), lr=0.0001),
        ]

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for a batch of images."""
        features = self.backbone(x)
        return self.lastLayer(features)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        data, label, _ = batch
        classScores = self(data)
        loss = self.criterion(classScores, label)
        self.log('train_loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        """Return ``(loss, accuracy)`` for a batch; log both when ``stage`` is set."""
        x, y, _ = batch
        logProbs = F.log_softmax(self(x), dim=-1)
        loss = F.nll_loss(logProbs, y)
        preds = torch.argmax(logProbs, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the loss."""
        # The loss was previously computed and then dropped; return it.
        return self._evaluate(batch, batch_idx, 'val')[0]

    def configure_optimizers(self):
        """Build SGD with per-group learning rates and a cosine schedule."""
        # Build fresh groups so this method can safely be called more than once.
        # The default lr of 0 is overridden by every group's own lr.
        optimizer = optim.SGD(self._build_param_groups(), lr=0, momentum=0.9, weight_decay=1e-5)
        scheduleLength = EPOCHS if self.epochs is None else self.epochs
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduleLength, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

In [7]:
# Name of the sub-directory that holds the classifier checkpoint to load.
checkpointDir = "barlow-34"

In [8]:
classifier = ResNetClassifier()

# Build the checkpoint path once instead of repeating the string concatenation.
checkpointPath = os.path.join('/scratch/vvb238/', checkpointDir, '27-classifier.pth')

if os.path.isfile(checkpointPath):
    # Load on CPU first; the model is moved to the GPU in a later cell.
    ckpt = torch.load(checkpointPath, map_location='cpu')
    classifier.load_state_dict(ckpt)
else:
    # Do not fail, but make it visible: every downstream ranking would otherwise
    # silently be computed with a randomly initialised network.
    print(f"WARNING: checkpoint not found at {checkpointPath}; "
          "classifier keeps its random initialisation")

In [9]:
# Inference-time preprocessing: tensor conversion followed by per-channel
# normalisation (the standard ImageNet mean / std values). No augmentation.
_normalizeStep = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
unlabeled_transform = transforms.Compose([transforms.ToTensor(), _normalizeStep])

# The full unlabeled split (no `limit`, so every file in the folder is counted).
entireUnlabeledDataset = CustomDataset(
    root='/dataset',
    split="unlabeled",
    transform=unlabeled_transform,
)

In [10]:
classifier = classifier.cuda()

# The ranking cell keeps only the 150 least-confident images of each batch, so
# which images share a batch affects the final selection. Seed the shuffle so a
# fresh "Restart & Run All" reproduces the same batches.
shuffleGenerator = torch.Generator()
shuffleGenerator.manual_seed(0)

entireUnlabeledDataLoader = torch.utils.data.DataLoader(
    entireUnlabeledDataset,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    generator=shuffleGenerator,
)

In [11]:
# Rank unlabeled images by the margin between their two most probable classes
# (small margin = the model is torn between two classes) and keep the indices
# of the most ambiguous ones.
PER_BATCH_KEEP = 150      # candidates retained per batch, to bound memory
NUM_TO_SELECT = 100000    # images retained overall

# Collect per-batch results in lists and concatenate once at the end; calling
# torch.cat inside the loop re-copies the growing tensor on every batch.
marginChunks, marginIndexChunks = [], []

classifier.eval()
print("\tStarting the evaluation process with unlabeled data")
with torch.no_grad():
    # len(loader) is the true batch count (it includes the final partial batch),
    # so the progress bar total is exact.
    for images, _, indices in tqdm(entireUnlabeledDataLoader, total=len(entireUnlabeledDataLoader)):
        images = images.cuda()

        classProbs = F.softmax(classifier(images), dim=1)

        # Only the two largest probabilities are needed, not a full sort.
        topTwoProbs, _ = torch.topk(classProbs, k=2, dim=1)
        differenceInTopTwo = topTwoProbs[:, 0] - topTwoProbs[:, 1]

        # Smallest margins first; guard against a batch shorter than the cap.
        keepCount = min(PER_BATCH_KEEP, differenceInTopTwo.numel())
        smallestMargins, smallestMarginPos = torch.topk(differenceInTopTwo, k=keepCount, largest=False)

        marginChunks.append(smallestMargins.cpu())
        # `indices` comes from the DataLoader; move the positions to the CPU
        # before indexing so the lookup does not rely on cross-device indexing.
        marginIndexChunks.append(indices.cpu()[smallestMarginPos.cpu()])

    print("\tGot the predictions of" , len(entireUnlabeledDataset), " images")

    allDifferenceInTopTwo = torch.cat(marginChunks) if marginChunks else torch.Tensor()
    allIndices = torch.cat(marginIndexChunks) if marginIndexChunks else torch.tensor([])

    # Sorting all the retained candidates by margin, ascending
    allSortedDifferenceInTopTwo, allSortedDifferenceInTopTwoPos = torch.sort(allDifferenceInTopTwo, descending=False)
    print("\tSorted the predictions based on confidence scores")

    # Positions of the NUM_TO_SELECT smallest margins
    leastDifferenceInTopTwoPos = allSortedDifferenceInTopTwoPos[:NUM_TO_SELECT]
    print("\tGot the top ", NUM_TO_SELECT, "confidence indices")

    # Mapping those positions back to indices in the original dataset
    topConfidenceIndices = allIndices[leastDifferenceInTopTwoPos]

	Starting the evaluation process with unlabeled data


  0%|          | 0/1000 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Restrict the dataset to the selected images. The selected indices are stored
# as a float tensor, so cast them to integers before handing them to Subset
# rather than relying on the dataset to coerce float indices.
unlabeledFilteredData = torch.utils.data.Subset(
    entireUnlabeledDataset,
    topConfidenceIndices.long().tolist(),
)
# Order is preserved (shuffle=False) so encodings line up with the subset order.
unlabeledFiteredDataLoader = torch.utils.data.DataLoader(
    unlabeledFilteredData,
    batch_size=512,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [12]:
# Encode every selected image as its softmax class-probability vector; these
# vectors are what the clustering step consumes.
encodingChunks, encodingIndexChunks = [], []

classifier.eval()
print("\tStarting the evaluation process with unlabeled data")
with torch.no_grad():
    # len(loader) counts the final partial batch too, so the progress bar total
    # matches the number of iterations.
    for images, _, indices in tqdm(unlabeledFiteredDataLoader, total=len(unlabeledFiteredDataLoader)):
        images = images.cuda()

        classProbs = F.softmax(classifier(images), dim=1)

        encodingChunks.append(classProbs.cpu())
        encodingIndexChunks.append(indices.cpu())

# Concatenate once; growing a tensor with torch.cat per batch is quadratic.
allImageEncoding = torch.cat(encodingChunks) if encodingChunks else torch.tensor([])
allIndices = torch.cat(encodingIndexChunks) if encodingIndexChunks else torch.Tensor()

	Starting the evaluation process with unlabeled data


196it [02:00,  1.62it/s]                         


In [13]:
from sklearn.cluster import KMeans
import numpy as np

In [14]:
# Convert to NumPy for scikit-learn. np.asarray accepts both a CPU tensor and
# an existing ndarray, so re-running this cell is harmless (calling .numpy()
# a second time would fail because the names are rebound to ndarrays).
allImageEncoding = np.asarray(allImageEncoding)
allIndices = np.asarray(allIndices)

In [15]:
# Sanity check: one row per encoded image, one column per classifier output.
allImageEncoding.shape

(100000, 800)

In [16]:
# Cluster the probability vectors. random_state is fixed so the cluster
# assignment (and therefore the final image selection) is reproducible.
kmeans = KMeans(n_clusters=800, n_init=5, random_state=0, verbose=1).fit(allImageEncoding)

Initialization complete
Iteration 0, inertia 2656.10791015625
Iteration 1, inertia 2190.6669921875
Iteration 2, inertia 2122.38720703125
Iteration 3, inertia 2092.43408203125
Iteration 4, inertia 2075.581298828125
Iteration 5, inertia 2065.463134765625
Iteration 6, inertia 2058.611083984375
Iteration 7, inertia 2053.73681640625
Iteration 8, inertia 2050.10791015625
Iteration 9, inertia 2047.1480712890625
Iteration 10, inertia 2044.8035888671875
Iteration 11, inertia 2042.8455810546875
Iteration 12, inertia 2041.1810302734375
Iteration 13, inertia 2039.7398681640625
Iteration 14, inertia 2038.52685546875
Iteration 15, inertia 2037.4541015625
Iteration 16, inertia 2036.523681640625
Iteration 17, inertia 2035.7159423828125
Iteration 18, inertia 2034.9560546875
Iteration 19, inertia 2034.3880615234375
Iteration 20, inertia 2033.875
Iteration 21, inertia 2033.430419921875
Iteration 22, inertia 2033.055908203125
Iteration 23, inertia 2032.7059326171875
Iteration 24, inertia 2032.421997070312

In [17]:
from collections import defaultdict

# Walk the images in their stored order and hand each one to its cluster's
# bucket, with at most 17 images per cluster and 12800 images overall.
clusterImageIdMap = defaultdict(list)
totalCount = 0
for clusterId, image in zip(kmeans.labels_, allIndices):
    clusterMembers = clusterImageIdMap[clusterId]
    if len(clusterMembers) >= 17:
        # This cluster is already full; skip the image.
        continue
    clusterMembers.append(image)
    totalCount += 1
    # The overall cap can only be hit right after an append.
    if totalCount == 12800:
        print("Reached max limit")
        break

Reached max limit


In [59]:
# Inspect the fitted centroids (one row per cluster, in encoding space).
kmeans.cluster_centers_

array([[4.2580566e-04, 5.9748301e-04, 5.5818297e-03, ..., 5.2689691e-05,
        2.0025240e-04, 2.7144561e-03],
       [1.3270357e-05, 2.6593683e-05, 4.3253647e-05, ..., 2.9688352e-05,
        1.8629110e-03, 3.4080073e-04],
       [1.6505823e-03, 4.7148514e-04, 9.9293201e-04, ..., 7.7215867e-04,
        5.0312490e-05, 2.4045401e-03],
       ...,
       [2.5987683e-05, 7.5475611e-03, 2.3264412e-03, ..., 8.7153399e-05,
        6.1658910e-05, 1.2810214e-04],
       [1.2027473e-03, 8.2049519e-07, 1.0367716e-05, ..., 3.2855608e-03,
        6.7136716e-07, 4.3596374e-06],
       [3.1851232e-06, 2.9807896e-03, 1.9166067e-03, ..., 9.4703631e-05,
        7.0744427e-05, 1.8749142e-04]], dtype=float32)

In [18]:
# Append one "<index>.png" line per selected image to the request file.
# NOTE: the file is opened in append mode, so re-running this cell adds the
# same entries again.
with open("imageRequest-23.txt", "a") as f:
    # The context manager closes the file even if a write raises.
    for key in clusterImageIdMap:
        for image in clusterImageIdMap[key]:
            f.write(str(int(image)) + ".png\n")

In [63]:
# Count distinct selected image ids across all clusters; the result should
# equal the number of images selected if no image was picked twice.
uniqueCheck = {
    image
    for clusterImages in clusterImageIdMap.values()
    for image in clusterImages
}
len(uniqueCheck)

12800