In [125]:
import numpy as np
from torchvision.datasets import CIFAR100
from collections import defaultdict
import random
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn

In [126]:
# Get Dataset
dataset = CIFAR100(root='./data', train=True, download=True)
print(dataset)

with open('data/cifar-100-python/meta', 'rb') as fo:
    metadata = pickle.load(fo, encoding='bytes')

classLabels = dict(list(enumerate(metadata[b'fine_label_names'])))
print(classLabels)

Files already downloaded and verified
Dataset CIFAR100
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
{0: b'apple', 1: b'aquarium_fish', 2: b'baby', 3: b'bear', 4: b'beaver', 5: b'bed', 6: b'bee', 7: b'beetle', 8: b'bicycle', 9: b'bottle', 10: b'bowl', 11: b'boy', 12: b'bridge', 13: b'bus', 14: b'butterfly', 15: b'camel', 16: b'can', 17: b'castle', 18: b'caterpillar', 19: b'cattle', 20: b'chair', 21: b'chimpanzee', 22: b'clock', 23: b'cloud', 24: b'cockroach', 25: b'couch', 26: b'crab', 27: b'crocodile', 28: b'cup', 29: b'dinosaur', 30: b'dolphin', 31: b'elephant', 32: b'flatfish', 33: b'forest', 34: b'fox', 35: b'girl', 36: b'hamster', 37: b'house', 38: b'kangaroo', 39: b'keyboard', 40: b'lamp', 41: b'lawn_mower', 42: b'leopard', 43: b'lion', 44: b'lizard', 45: b'lobster', 46: b'man', 47: b'maple_tree', 48: b'motorcycle', 49: b'mountain', 50: b'mouse', 51: b'mushroom', 52: b'oak_tree', 53: b'orange', 54: b'orchid', 55: b'otter', 56: b'palm_tree', 57: b'pear

In [127]:
with open('data/cifar-100-python/train', 'rb') as fo:
    trainMeta = pickle.load(fo, encoding='bytes')

with open('data/cifar-100-python/test', 'rb') as fo:
    testMeta = pickle.load(fo, encoding='bytes')


totalTest = len(testMeta[b'data'])
valNum = 4000

# Get training/testing data and labels
trainData = trainMeta[b'data']
trainData = trainData.reshape((len(trainData), 3, 32, 32))
trainLabel = np.array(trainMeta[b'fine_labels'])

mask = list(range(valNum))
valData = testMeta[b'data'][mask]
valData = valData.reshape((len(valData), 3, 32, 32))
valLabel = np.array(trainMeta[b'fine_labels'])[mask]

mask = list(range(valNum, totalTest))
testData = testMeta[b'data'][mask]
testData = testData.reshape((len(testData), 3, 32, 32))
testLabel = np.array(testMeta[b'fine_labels'])[mask]

In [128]:
# Get the number of images per class
imgPerClass = []

for cls in range(100):
    num = 500 * (0.05 ** (cls / 99))
    imgPerClass.append(int(num))

print(imgPerClass)
print(np.sum(imgPerClass))

[500, 485, 470, 456, 442, 429, 416, 404, 392, 380, 369, 358, 347, 337, 327, 317, 308, 298, 290, 281, 272, 264, 256, 249, 241, 234, 227, 220, 214, 207, 201, 195, 189, 184, 178, 173, 168, 163, 158, 153, 149, 144, 140, 136, 132, 128, 124, 120, 116, 113, 110, 106, 103, 100, 97, 94, 91, 89, 86, 83, 81, 78, 76, 74, 72, 69, 67, 65, 63, 61, 60, 58, 56, 54, 53, 51, 50, 48, 47, 45, 44, 43, 41, 40, 39, 38, 37, 35, 34, 33, 32, 31, 30, 29, 29, 28, 27, 26, 25, 25]
15907


In [129]:
# Get LT training data
trainDataLT, trainLabelLT = [], []
random.shuffle(imgPerClass)

for cls, numImg in enumerate(imgPerClass):
    clsIndx = np.where(trainLabel == cls)[0]
    numSampledImages = np.random.choice(clsIndx, numImg, replace=False)

    trainDataLT.append(trainData[numSampledImages])
    trainLabelLT.append(trainLabel[numSampledImages])

trainDataLT, trainLabelLT = np.concatenate(trainDataLT), np.concatenate(trainLabelLT)

In [None]:

def test_model(model, dataSet):

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_data in dataSet:
            images, labels = batch_data

            output = model(images)
            predicted = torch.argmax(output, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc = correct // total
    return acc

In [None]:
def train_model(trainData, valData, epochs: int, model=None, optimizer=None, scheduler=None, lossFunc=None, freq=50) -> None:

    for epoch in range(epochs):
        model.train()
        
        avgLoss = 0
        totalCorrect = 0
        totalNum = 0

        for i, batch in tqdm(enumerate(trainData)):
            data, labels = batch

            optimizer.zero_grad()
            output = model(data)
            target = torch.tensor(labels, dtype=torch.long)
            loss = lossFunc(output, target)
            loss.backward()
            optimizer.step()

            avgLoss += loss.item()
            pred = torch.argmax(output, 1)
            totalCorrect += (pred == labels).sum().item()
            totalNum += labels.size(0)

            if i % freq == 0:    # print every certain number of mini-batches
                avgLoss = avgLoss / freq
                accuracy = totalCorrect / totalNum
                last_lr = scheduler.get_last_lr()[0]
                print(f'[{epoch + 1}/{epochs}, {i + 1:5d}/{len(trainData)}] loss: {avgLoss:.3f} acc: {accuracy:.3f} lr: {last_lr:.5f}')
                avgLoss = 0.0
                totalCorrect = 0.0
                totalNum = 0.0

        scheduler.step()

        val_acc = test_model(model, valData)
        print(f'[{epoch + 1}/{epochs}] val acc: {val_acc:.3f}')

            


In [159]:
class baseModel(nn.Module):
    def __init__(self):
        super(baseModel, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 100)

        self.dropout = nn.Dropout(0.5)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))

        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.softmax(self.fc2(x))

        return x

In [160]:
trainSet = torch.utils.data.TensorDataset(torch.Tensor(trainDataLT), torch.Tensor(trainLabelLT))
trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=32, shuffle=True, num_workers=2)
valSet = torch.utils.data.TensorDataset(torch.Tensor(valData), torch.Tensor(valLabel))
valLoader = torch.utils.data.DataLoader(valSet, batch_size=32, shuffle=True, num_workers=2)

model = baseModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss = nn.CrossEntropyLoss()

train_model(trainLoader, valLoader, 10, model, optimizer, scheduler, loss)

  return self._call_impl(*args, **kwargs)
  target = torch.tensor(labels, dtype=torch.long)
11it [00:02,  7.46it/s]

[1/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


55it [00:03, 44.20it/s]

[1/10,    51/498] loss: 4.610 acc: 1.125 lr: 0.00100


109it [00:04, 52.03it/s]

[1/10,   101/498] loss: 4.609 acc: 1.312 lr: 0.00100


157it [00:05, 53.25it/s]

[1/10,   151/498] loss: 4.608 acc: 1.438 lr: 0.00100


205it [00:06, 53.48it/s]

[1/10,   201/498] loss: 4.612 acc: 1.000 lr: 0.00100


259it [00:07, 52.40it/s]

[1/10,   251/498] loss: 4.610 acc: 1.250 lr: 0.00100


307it [00:07, 53.12it/s]

[1/10,   301/498] loss: 4.605 acc: 1.750 lr: 0.00100


355it [00:08, 52.23it/s]

[1/10,   351/498] loss: 4.612 acc: 1.000 lr: 0.00100


409it [00:09, 52.65it/s]

[1/10,   401/498] loss: 4.608 acc: 1.375 lr: 0.00100


457it [00:10, 52.33it/s]

[1/10,   451/498] loss: 4.612 acc: 1.062 lr: 0.00100


498it [00:11, 41.54it/s]


[1/10] val acc: 1.000


6it [00:02,  3.38it/s]

[2/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


59it [00:03, 44.20it/s]

[2/10,    51/498] loss: 4.615 acc: 0.688 lr: 0.00100


107it [00:04, 51.16it/s]

[2/10,   101/498] loss: 4.603 acc: 1.938 lr: 0.00100


160it [00:05, 49.28it/s]

[2/10,   151/498] loss: 4.609 acc: 1.312 lr: 0.00100


208it [00:06, 51.54it/s]

[2/10,   201/498] loss: 4.614 acc: 0.812 lr: 0.00100


256it [00:07, 50.54it/s]

[2/10,   251/498] loss: 4.607 acc: 1.562 lr: 0.00100


308it [00:08, 53.23it/s]

[2/10,   301/498] loss: 4.612 acc: 1.000 lr: 0.00100


356it [00:09, 51.73it/s]

[2/10,   351/498] loss: 4.611 acc: 1.125 lr: 0.00100


410it [00:10, 53.02it/s]

[2/10,   401/498] loss: 4.605 acc: 1.688 lr: 0.00100


458it [00:11, 53.09it/s]

[2/10,   451/498] loss: 4.608 acc: 1.438 lr: 0.00100


498it [00:12, 40.84it/s]


[2/10] val acc: 1.000


7it [00:02,  4.32it/s]

[3/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


55it [00:03, 43.02it/s]

[3/10,    51/498] loss: 4.610 acc: 1.188 lr: 0.00100


109it [00:04, 52.61it/s]

[3/10,   101/498] loss: 4.609 acc: 1.312 lr: 0.00100


157it [00:05, 52.25it/s]

[3/10,   151/498] loss: 4.607 acc: 1.562 lr: 0.00100


205it [00:05, 52.68it/s]

[3/10,   201/498] loss: 4.609 acc: 1.312 lr: 0.00100


259it [00:06, 53.23it/s]

[3/10,   251/498] loss: 4.610 acc: 1.250 lr: 0.00100


307it [00:07, 53.30it/s]

[3/10,   301/498] loss: 4.608 acc: 1.375 lr: 0.00100


355it [00:08, 53.43it/s]

[3/10,   351/498] loss: 4.613 acc: 0.875 lr: 0.00100


409it [00:09, 54.22it/s]

[3/10,   401/498] loss: 4.606 acc: 1.625 lr: 0.00100


457it [00:10, 53.96it/s]

[3/10,   451/498] loss: 4.610 acc: 1.188 lr: 0.00100


498it [00:11, 42.15it/s]


[3/10] val acc: 1.000


7it [00:02,  4.31it/s]

[4/10,     1/498] loss: 0.092 acc: 3.125 lr: 0.00100


61it [00:03, 47.00it/s]

[4/10,    51/498] loss: 4.607 acc: 1.500 lr: 0.00100


109it [00:04, 53.86it/s]

[4/10,   101/498] loss: 4.613 acc: 0.938 lr: 0.00100


157it [00:04, 54.04it/s]

[4/10,   151/498] loss: 4.610 acc: 1.188 lr: 0.00100


211it [00:05, 54.07it/s]

[4/10,   201/498] loss: 4.607 acc: 1.500 lr: 0.00100


259it [00:06, 53.32it/s]

[4/10,   251/498] loss: 4.610 acc: 1.250 lr: 0.00100


307it [00:07, 53.57it/s]

[4/10,   301/498] loss: 4.610 acc: 1.250 lr: 0.00100


355it [00:08, 53.43it/s]

[4/10,   351/498] loss: 4.607 acc: 1.500 lr: 0.00100


409it [00:09, 53.12it/s]

[4/10,   401/498] loss: 4.608 acc: 1.438 lr: 0.00100


457it [00:10, 53.78it/s]

[4/10,   451/498] loss: 4.613 acc: 0.875 lr: 0.00100


498it [00:11, 42.42it/s]


[4/10] val acc: 1.000


6it [00:02,  3.37it/s]

[5/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


58it [00:03, 44.69it/s]

[5/10,    51/498] loss: 4.609 acc: 1.312 lr: 0.00100


106it [00:04, 52.73it/s]

[5/10,   101/498] loss: 4.612 acc: 1.000 lr: 0.00100


160it [00:05, 52.79it/s]

[5/10,   151/498] loss: 4.604 acc: 1.812 lr: 0.00100


208it [00:06, 53.39it/s]

[5/10,   201/498] loss: 4.608 acc: 1.438 lr: 0.00100


256it [00:07, 53.22it/s]

[5/10,   251/498] loss: 4.610 acc: 1.250 lr: 0.00100


310it [00:08, 53.29it/s]

[5/10,   301/498] loss: 4.606 acc: 1.625 lr: 0.00100


358it [00:09, 53.63it/s]

[5/10,   351/498] loss: 4.606 acc: 1.625 lr: 0.00100


406it [00:09, 53.07it/s]

[5/10,   401/498] loss: 4.612 acc: 1.000 lr: 0.00100


460it [00:10, 53.42it/s]

[5/10,   451/498] loss: 4.612 acc: 1.062 lr: 0.00100


498it [00:12, 41.28it/s]


[5/10] val acc: 1.000


6it [00:02,  3.68it/s]

[6/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


58it [00:03, 44.95it/s]

[6/10,    51/498] loss: 4.611 acc: 1.125 lr: 0.00100


106it [00:04, 51.16it/s]

[6/10,   101/498] loss: 4.603 acc: 1.938 lr: 0.00100


160it [00:05, 51.96it/s]

[6/10,   151/498] loss: 4.612 acc: 1.000 lr: 0.00100


207it [00:06, 52.53it/s]

[6/10,   201/498] loss: 4.611 acc: 1.125 lr: 0.00100


255it [00:07, 52.34it/s]

[6/10,   251/498] loss: 4.607 acc: 1.500 lr: 0.00100


309it [00:08, 53.24it/s]

[6/10,   301/498] loss: 4.611 acc: 1.125 lr: 0.00100


357it [00:08, 53.21it/s]

[6/10,   351/498] loss: 4.616 acc: 0.625 lr: 0.00100


411it [00:09, 52.34it/s]

[6/10,   401/498] loss: 4.609 acc: 1.312 lr: 0.00100


459it [00:10, 54.01it/s]

[6/10,   451/498] loss: 4.605 acc: 1.750 lr: 0.00100


498it [00:12, 40.82it/s]


[6/10] val acc: 1.000


7it [00:02,  4.26it/s]

[7/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


58it [00:03, 45.32it/s]

[7/10,    51/498] loss: 4.604 acc: 1.812 lr: 0.00100


106it [00:04, 51.77it/s]

[7/10,   101/498] loss: 4.609 acc: 1.312 lr: 0.00100


160it [00:05, 53.41it/s]

[7/10,   151/498] loss: 4.607 acc: 1.500 lr: 0.00100


208it [00:06, 53.40it/s]

[7/10,   201/498] loss: 4.612 acc: 1.062 lr: 0.00100


256it [00:06, 52.77it/s]

[7/10,   251/498] loss: 4.609 acc: 1.312 lr: 0.00100


310it [00:07, 53.40it/s]

[7/10,   301/498] loss: 4.608 acc: 1.375 lr: 0.00100


358it [00:08, 53.05it/s]

[7/10,   351/498] loss: 4.610 acc: 1.188 lr: 0.00100


406it [00:09, 52.36it/s]

[7/10,   401/498] loss: 4.610 acc: 1.250 lr: 0.00100


454it [00:10, 50.14it/s]

[7/10,   451/498] loss: 4.608 acc: 1.375 lr: 0.00100


498it [00:11, 41.59it/s]


[7/10] val acc: 1.000


6it [00:02,  3.71it/s]

[8/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


59it [00:03, 45.87it/s]

[8/10,    51/498] loss: 4.605 acc: 1.750 lr: 0.00100


107it [00:04, 52.64it/s]

[8/10,   101/498] loss: 4.615 acc: 0.750 lr: 0.00100


155it [00:04, 53.43it/s]

[8/10,   151/498] loss: 4.610 acc: 1.250 lr: 0.00100


209it [00:06, 52.56it/s]

[8/10,   201/498] loss: 4.604 acc: 1.812 lr: 0.00100


257it [00:06, 53.42it/s]

[8/10,   251/498] loss: 4.608 acc: 1.375 lr: 0.00100


311it [00:07, 53.71it/s]

[8/10,   301/498] loss: 4.610 acc: 1.188 lr: 0.00100


359it [00:08, 53.00it/s]

[8/10,   351/498] loss: 4.610 acc: 1.188 lr: 0.00100


407it [00:09, 52.04it/s]

[8/10,   401/498] loss: 4.611 acc: 1.125 lr: 0.00100


455it [00:10, 53.55it/s]

[8/10,   451/498] loss: 4.614 acc: 0.812 lr: 0.00100


498it [00:11, 41.98it/s]


[8/10] val acc: 1.000


6it [00:02,  3.63it/s]

[9/10,     1/498] loss: 0.092 acc: 0.000 lr: 0.00100


58it [00:03, 45.52it/s]

[9/10,    51/498] loss: 4.607 acc: 1.562 lr: 0.00100


106it [00:04, 52.16it/s]

[9/10,   101/498] loss: 4.608 acc: 1.438 lr: 0.00100


160it [00:05, 53.60it/s]

[9/10,   151/498] loss: 4.610 acc: 1.250 lr: 0.00100


208it [00:06, 52.61it/s]

[9/10,   201/498] loss: 4.612 acc: 1.000 lr: 0.00100


256it [00:06, 53.06it/s]

[9/10,   251/498] loss: 4.608 acc: 1.438 lr: 0.00100


310it [00:08, 53.06it/s]

[9/10,   301/498] loss: 4.607 acc: 1.500 lr: 0.00100


358it [00:08, 53.06it/s]

[9/10,   351/498] loss: 4.615 acc: 0.688 lr: 0.00100


412it [00:09, 53.66it/s]

[9/10,   401/498] loss: 4.608 acc: 1.375 lr: 0.00100


454it [00:10, 52.55it/s]

[9/10,   451/498] loss: 4.612 acc: 1.000 lr: 0.00100


498it [00:11, 41.55it/s]


[9/10] val acc: 1.000


7it [00:02,  4.22it/s]

[10/10,     1/498] loss: 0.092 acc: 3.125 lr: 0.00100


60it [00:03, 46.15it/s]

[10/10,    51/498] loss: 4.613 acc: 0.938 lr: 0.00100


108it [00:04, 53.58it/s]

[10/10,   101/498] loss: 4.605 acc: 1.750 lr: 0.00100


156it [00:05, 51.68it/s]

[10/10,   151/498] loss: 4.610 acc: 1.188 lr: 0.00100


210it [00:06, 52.95it/s]

[10/10,   201/498] loss: 4.608 acc: 1.375 lr: 0.00100


258it [00:07, 52.92it/s]

[10/10,   251/498] loss: 4.608 acc: 1.375 lr: 0.00100


306it [00:07, 52.73it/s]

[10/10,   301/498] loss: 4.606 acc: 1.625 lr: 0.00100


360it [00:08, 53.18it/s]

[10/10,   351/498] loss: 4.612 acc: 1.062 lr: 0.00100


408it [00:09, 51.45it/s]

[10/10,   401/498] loss: 4.610 acc: 1.250 lr: 0.00100


456it [00:10, 53.21it/s]

[10/10,   451/498] loss: 4.610 acc: 1.250 lr: 0.00100


498it [00:12, 41.47it/s]


[10/10] val acc: 1.000
