In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
%matplotlib inline

In [2]:
hflip = transforms.RandomHorizontalFlip(p=0.5)
train_transform = transforms.Compose([hflip, transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
dataset = CIFAR100(root = 'data/', download = True, transform = train_transform)
test_ds = CIFAR100(root = 'data/', train = False)
train_ds, val_ds = random_split(dataset, [40000, 10000])

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:05<00:00, 28957589.91it/s]


Extracting data/cifar-100-python.tar.gz to data/


In [3]:
device = torch.device("cuda:0")

In [4]:
class VanillaSoftmax(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x, dim=-1):
    return torch.softmax(x, dim=dim)

class GumbelSoftmax(nn.Module):
  def __init__(self):
      super().__init__()
      self.linear = nn.Linear(100, 100)
      self.temperature = 1

  def forward(self, x):
      logits = self.linear(x)
      return nn.functional.gumbel_softmax(logits, tau=self.temperature)

class AdaptiveSoftmax(nn.Module):
  def __init__(self):
    super().__init__()
    self.adapt_softmax = nn.AdaptiveLogSoftmaxWithLoss(100, 100, cutoffs=[49, 51], device="cuda:0")  #since all the classes have equal probability, all of them belong to the head cluster

  def forward(self, x, y):
    return self.adapt_softmax(x, y)

In [5]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def to_device(data, device):
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

In [6]:
class CifarCNN(nn.Module):
    def __init__(self, softmax_type):
        super().__init__()
        self.network = nn.Sequential(
          nn.Conv2d(3, 32, kernel_size=3, padding=1),
          nn.ReLU(),
          nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.MaxPool2d(2, 2),

          nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.MaxPool2d(2, 2),

          nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.MaxPool2d(2, 2),

          nn.Flatten(), 
          nn.Linear(256*4*4, 1024),
          nn.ReLU(),
          nn.Linear(1024, 512),
          nn.ReLU(),
          nn.Linear(512, 100))
        
        self.norm_fn = softmax_type
        
    def forward(self, x):
        return self.network(x)

    def train_batch(self, batch):

      images, targets = batch
      preds = self(images)
      norm_func = self.norm_fn

      if norm_func == "VanillaSoftmax":
        m = VanillaSoftmax()
      elif norm_func == "GumbelSoftmax":
        m = GumbelSoftmax()
      elif norm_func == "AdaptiveSoftmax":
        m = AdaptiveSoftmax()

      if norm_func != "AdaptiveSoftmax":
        norm_preds = m(preds)
        loss = nn.CrossEntropyLoss()
        loss = loss(preds, targets)

      else:
        _, loss = m(preds, targets)
      
      return loss

    def val_step(self, batch):
      images, targets = batch
      preds = self(images)

      norm_func = self.norm_fn

      if norm_func == "VanillaSoftmax":
        m = VanillaSoftmax()
      elif norm_func == "GumbelSoftmax":
        m = GumbelSoftmax()
      elif norm_func == "AdaptiveSoftmax":
        m = AdaptiveSoftmax()

      if norm_func != "AdaptiveSoftmax":
        norm_preds = m(preds)
        loss = nn.CrossEntropyLoss()
        loss = loss(preds, targets)
      else:
        _, loss = m(preds, targets)
      val_acc = accuracy(preds, targets)

      return {'val_loss': loss.detach(), 'val_acc': val_acc}

    def val_epoch_status(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()     
        
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_status(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))

In [7]:
def evaluate(model, val_loader):
  model.eval()
  torch.no_grad()
  outputs = [model.val_step(batch) for batch in val_loader]
  return model.val_epoch_status(outputs)


def fit(model, train_loader, val_loader, epochs=20, lr=0.001, opt_func=torch.optim.Adam):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch in train_loader:
            loss = model.train_batch(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_status(epoch, result)
        history.append(result)

In [8]:
train_loader = DataLoader(train_ds, num_workers=2, batch_size=128, pin_memory=True)
val_loader = DataLoader(val_ds, num_workers=2, batch_size=128, pin_memory=True)
test_loader = DataLoader(test_ds, num_workers=2, batch_size=128, pin_memory=True)

In [9]:
model1 = CifarCNN("VanillaSoftmax")
to_device(model1, device)

CifarCNN(
  (network): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Flatten(start_dim=1, end_dim=-1)
    (16): Linear(in_features=4096, out_features=1024, bias=True)
    (17): ReLU()
    (18): Linear(in_features=10

In [10]:
train_loader = DeviceDataLoader(train_loader, device)
val_loader = DeviceDataLoader(val_loader, device)
test_loader = DeviceDataLoader(val_loader, device)

In [11]:
fit(model1, train_loader, val_loader)

Epoch [0], train_loss: 4.2820, val_loss: 3.9340, val_acc: 0.0789
Epoch [1], train_loss: 3.6580, val_loss: 3.4377, val_acc: 0.1672
Epoch [2], train_loss: 3.2363, val_loss: 3.1203, val_acc: 0.2274
Epoch [3], train_loss: 2.9260, val_loss: 2.9019, val_acc: 0.2725
Epoch [4], train_loss: 2.6888, val_loss: 2.7645, val_acc: 0.2984
Epoch [5], train_loss: 2.4809, val_loss: 2.6493, val_acc: 0.3279
Epoch [6], train_loss: 2.3038, val_loss: 2.5819, val_acc: 0.3428
Epoch [7], train_loss: 2.1527, val_loss: 2.5595, val_acc: 0.3505
Epoch [8], train_loss: 2.0315, val_loss: 2.5513, val_acc: 0.3571
Epoch [9], train_loss: 1.9071, val_loss: 2.5746, val_acc: 0.3617
Epoch [10], train_loss: 1.8047, val_loss: 2.5210, val_acc: 0.3776
Epoch [11], train_loss: 1.7088, val_loss: 2.6438, val_acc: 0.3641
Epoch [12], train_loss: 1.6293, val_loss: 2.6471, val_acc: 0.3758
Epoch [13], train_loss: 1.5677, val_loss: 2.6753, val_acc: 0.3784
Epoch [14], train_loss: 1.4899, val_loss: 2.7519, val_acc: 0.3776
Epoch [15], train_lo

In [12]:
model2 = CifarCNN("GumbelSoftmax")
to_device(model2, device)

CifarCNN(
  (network): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Flatten(start_dim=1, end_dim=-1)
    (16): Linear(in_features=4096, out_features=1024, bias=True)
    (17): ReLU()
    (18): Linear(in_features=10

In [13]:
fit(model2, train_loader, val_loader)

RuntimeError: ignored

In [14]:
model3 = CifarCNN("AdaptiveSoftmax")
to_device(model3, device)

CifarCNN(
  (network): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Flatten(start_dim=1, end_dim=-1)
    (16): Linear(in_features=4096, out_features=1024, bias=True)
    (17): ReLU()
    (18): Linear(in_features=10

In [16]:
fit(model3, train_loader, test_loader)

Epoch [0], train_loss: 5.8544, val_loss: 5.8468, val_acc: 0.0088
Epoch [1], train_loss: 5.8543, val_loss: 5.8476, val_acc: 0.0088
Epoch [2], train_loss: 5.8542, val_loss: 5.8485, val_acc: 0.0097
Epoch [3], train_loss: 5.8538, val_loss: 5.8478, val_acc: 0.0088
Epoch [4], train_loss: 5.8523, val_loss: 5.8471, val_acc: 0.0088
Epoch [5], train_loss: 5.8548, val_loss: 5.8492, val_acc: 0.0088
Epoch [6], train_loss: 5.8547, val_loss: 5.8501, val_acc: 0.0103
Epoch [7], train_loss: 5.8558, val_loss: 5.8468, val_acc: 0.0103
Epoch [8], train_loss: 5.8542, val_loss: 5.8499, val_acc: 0.0103
Epoch [9], train_loss: 5.8522, val_loss: 5.8451, val_acc: 0.0106
Epoch [10], train_loss: 5.8539, val_loss: 5.8468, val_acc: 0.0106
Epoch [11], train_loss: 5.8539, val_loss: 5.8443, val_acc: 0.0103
Epoch [12], train_loss: 5.8519, val_loss: 5.8494, val_acc: 0.0096
Epoch [13], train_loss: 5.8551, val_loss: 5.8471, val_acc: 0.0096
Epoch [14], train_loss: 5.8538, val_loss: 5.8480, val_acc: 0.0096
Epoch [15], train_lo