<a href="https://colab.research.google.com/github/yoonkim97/pytorch-resnet-mnist/blob/master/MNISTResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, roc_curve, auc
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import urllib
from torch.nn import functional as F

In [17]:

num_workers = 0
batch_size = 20
basepath = '.'

def set_header_for(url, filename):
    opener = urllib.request.URLopener()
    opener.addheader('User-Agent', 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.47 Safari/537.36')
    opener.retrieve(
    url, f'{basepath}/{filename}')

set_header_for('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'train-images-idx3-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'train-labels-idx1-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 't10k-labels-idx1-ubyte.gz')


  if sys.path[0] == '':
  del sys.path[0]
  
  from ipykernel import kernelapp as app


In [0]:
class MnistResNet(ResNet):
  def __init__(self):
    super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
    self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  def forward(self, x):
    return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

In [0]:
train_label_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
test_label_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def get_same_indices(target, labels):
  label_indices = []
  for i in range (len(target)):
    for j in range (len(labels)):
      if target[i] == labels[j]:
        label_indices.append(i)
  return label_indices

In [0]:
def get_data_loaders(train_batch_size, val_batch_size):
  mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
  data_transform = Compose([Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

  train_dataset = MNIST(download=True, root=".", transform=data_transform, train=True)
  train_indices = get_same_indices(train_dataset.targets, train_label_classes)
  train_loader = DataLoader(dataset = train_dataset, batch_size=train_batch_size, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))

  val_dataset = MNIST(download=False, root=".", transform=data_transform, train=False)
  val_indices = get_same_indices(val_dataset.targets, test_label_classes)
  val_loader = DataLoader(dataset = val_dataset, batch_size=val_batch_size, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(val_indices))
  return train_loader, val_loader

In [0]:
def calculate_metric(metric_fn, true_y, pred_y):
  if "average" in inspect.getfullargspec(metric_fn).args:
    return metric_fn(true_y, pred_y, average="macro")
  else:
    return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
  for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
    print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [36]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 5

model = MnistResNet().to(device)
train_loader, val_loader = get_data_loaders(256, 256)


losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    confusion_actuals = []
    probabilities = []
    predictions = []
    roc_actuals = []

    which_class = 9

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            prediction = outputs.argmax(dim=1, keepdim=True)
            confusion_actuals.extend(y.view_as(prediction))
            roc_actuals.extend(y.view_as(prediction) == which_class)
            probabilities.extend(np.exp(outputs.cpu()[:, which_class]))  
            predictions.extend(prediction)

            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    confusion_actuals = [i.item() for i in confusion_actuals]
    predictions = [i.item() for i in predictions]
    roc_actuals = [i.item() for i in roc_actuals]
    class_probabilities = [i.item() for i in probabilities]


    # fpr, tpr, _ = roc_curve(roc_actuals, class_probabilities)
    # roc_auc = auc(fpr, tpr)
    # plt.figure()
    # lw = 2 
    # plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    # plt.plot([0, 1], [0, 1], color = 'navy', lw=lw, linestyle='--')
    # plt.xlim([0.0, 1.0])
    # plt.ylim([0.0, 1.05])
    # plt.xlabel('False Positive Rate')
    # plt.ylabel('True Positive Rate')
    # plt.title('ROC for digit=%d class' % which_class)
    # plt.legend(loc="lower right")
    # plt.show()
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print(confusion_matrix(confusion_actuals, predictions))    
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
print(losses)
print(f"Training time: {time.time()-start_ts}s")




HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…




  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5, training loss: 1.6082462094864756, validation loss: 2.0825254917144775
[[  20    0  231  680   32    0    0    3   14    0]
 [   0 1008    1   18   85    0   10    0   13    0]
 [   0    0   55  976    0    0    0    1    0    0]
 [   0    1    0 1008    0    0    0    1    0    0]
 [   0    6   11  169  780    0    0   16    0    0]
 [   0    1    2  848    0   41    0    0    0    0]
 [   0    2  299  486   45    3  101    0   22    0]
 [   0   13   10  368    3    0    0  634    0    0]
 [   0    0   53  853    1    1    0    1   65    0]
 [   0    8   81  572   51    1    0  291    5    0]]
	     precision: 0.4947
	        recall: 0.3611
	            F1: 0.3133
	      accuracy: 0.3757


HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…


Epoch 2/5, training loss: 1.4751488684483294, validation loss: 1.8221243619918823
[[ 609    0  230    0    1    0   14    0  126    0]
 [   0  891   13    0   16    0    3    0  212    0]
 [   0    0 1023    0    0    0    0    0    9    0]
 [   0    0  722  107    0    1    0    0  180    0]
 [   0    0    3    0  939    0    7    0   33    0]
 [   0    0  167    0    0  629    8    0   88    0]
 [   2    1    9    0    0    0  926    0   20    0]
 [   0    2  565    1   38    0    0  266  156    0]
 [   0    0    3    0    0    1    0    0  970    0]
 [   4    0   27    1  229   19    1    4  724    0]]
	     precision: 0.7367
	        recall: 0.6395
	            F1: 0.5948
	      accuracy: 0.6387


HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…


Epoch 3/5, training loss: 1.470278873758496, validation loss: 1.5689407587051392
[[ 973    0    2    0    0    0    2    1    2    0]
 [   0 1130    1    3    0    0    0    1    0    0]
 [   0    0 1029    0    0    0    0    2    1    0]
 [   0    0    0 1008    0    1    0    1    0    0]
 [   1    1    5    0  969    0    1    1    4    0]
 [   0    0    0    9    0  881    1    0    1    0]
 [   5    5    2    1    0    1  943    0    1    0]
 [   0    2    7    0    0    0    0 1017    2    0]
 [   0    0    0    3    0    0    0    0  971    0]
 [  15    1   44   17  134   13    0  260  525    0]]
	     precision: 0.8185
	        recall: 0.8930
	            F1: 0.8498
	      accuracy: 0.8917


HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…


Epoch 4/5, training loss: 1.4670822052460797, validation loss: 1.56833016872406
[[ 971    0    2    0    0    2    0    2    3    0]
 [   0 1125    2    3    1    0    0    4    0    0]
 [   0    0 1031    0    0    0    0    1    0    0]
 [   0    0    0 1003    0    5    0    1    1    0]
 [   0    0    1    0  973    0    1    2    5    0]
 [   0    0    0    2    0  889    1    0    0    0]
 [   7    5    2    1    0   10  930    0    3    0]
 [   0    2    5    0    0    1    0 1019    1    0]
 [   0    0    0    1    1    2    0    0  970    0]
 [  15    1    4    3  320   79    0  324  263    0]]
	     precision: 0.8119
	        recall: 0.8914
	            F1: 0.8458
	      accuracy: 0.8922


HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…


Epoch 5/5, training loss: 1.4658914618896988, validation loss: 1.577235221862793
[[ 969    0    2    0    1    2    3    3    0    0]
 [   0 1134    0    1    0    0    0    0    0    0]
 [   1    3  988    7    1    0    1   31    0    0]
 [   0    0    0 1002    0    5    0    3    0    0]
 [   0    1    0    0  979    0    2    0    0    0]
 [   0    0    0    4    0  887    1    0    0    0]
 [   0    6    0    0    1    4  947    0    0    0]
 [   0    5    2    0    0    1    0 1020    0    0]
 [   1    0    1   23    4   20    6    4  915    0]
 [  28    6    6   12  354  275    3  317    8    0]]
	     precision: 0.8087
	        recall: 0.8833
	            F1: 0.8388
	      accuracy: 0.8839
[1.6082462094864756, 1.4751488684483294, 1.470278873758496, 1.4670822052460797, 1.4658914618896988]
Training time: 591.5420858860016s


In [0]:
class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, input):
        logits = self.model(input)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        return logits / temperature

    def set_temperature(self, valid_loader):
        self.cuda()
        nll_criterion = nn.CrossEntropyLoss().cuda()
        ece_criterion = _ECELoss().cuda()
        anomaly_criterion = _AnomalyDetection().cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for input, label in valid_loader:
                input = input.cuda()
                logits = self.model(input)
                logits_list.append(logits)
                labels_list.append(label)
            logits = torch.cat(logits_list).cuda()
            labels = torch.cat(labels_list).cuda()

        # Calculate NLL and ECE before temperature scaling
        # before_temperature_nll = nll_criterion(logits, labels).item()
        # before_temperature_ece = ece_criterion(logits, labels).item()
        # print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

        # Next: optimize the temperature w.r.t. NLL
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval():
            loss = nll_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss
        optimizer.step(eval)

        # Calculate NLL and ECE after temperature scaling
        after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_anomaly = anomaly_criterion(self.temperature_scale(logits), labels)
        print('Optimal temperature: %.3f' % self.temperature.item())
        print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
        print('Anomaly Count: %d' %after_temperature_anomaly)
        return self

In [0]:
class _ECELoss(nn.Module):
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
      with torch.no_grad():
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
    
      ece = torch.zeros(1, device=logits.device)
      for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
          # Calculated |confidence - accuracy| in each bin
          in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
          prop_in_bin = in_bin.float().mean()
          if prop_in_bin.item() > 0:
              accuracy_in_bin = accuracies[in_bin].float().mean()
              avg_confidence_in_bin = confidences[in_bin].mean()
              ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

      return ece

In [0]:
class _AnomalyDetection(nn.Module):
  def __init__(self):
    super(_AnomalyDetection, self).__init__()
  def forward(self, logits, labels):
    with torch.no_grad():
      softmaxes = F.softmax(logits, dim=1)
      confidences, predictions = torch.max(softmaxes, 1)
      accuracies = predictions.eq(labels)

    predictions_list = predictions.cpu().numpy()
    confidences_list = confidences.cpu().numpy()
    actuals_list = labels.cpu().numpy()

    all_triples = []
    anomaly_triples = []
    modified_anomaly_triples = []

    anomaly_count = 0
    for i in range(len(actuals_list)):
      all_triples.append((actuals_list[i], predictions_list[i], confidences_list[i]))
      if (confidences_list[i] < 0.80):
        anomaly_triples.append((i, actuals_list[i], predictions_list[i], confidences_list[i]))
        if (actuals_list[i] == 9):
          predictions_list[i] = 9
          modified_anomaly_triples.append((i, actuals_list[i], predictions_list[i], confidences_list[i]))
        anomaly_count += 1
    print(anomaly_triples)
    print(modified_anomaly_triples)
    print(len(anomaly_triples))
    print(len(modified_anomaly_triples))

    return anomaly_count

In [42]:
scaled_model = ModelWithTemperature(model)
scaled_model.set_temperature(val_loader)


model_filename = 'model_with_temperature.pth'
torch.save(scaled_model.state_dict(), model_filename)
print('Temperature scaled model saved to %s', model_filename)
print('Done!')

[(11, 9, 5, 0.685994), (13, 9, 3, 0.45989692), (45, 8, 8, 0.37976483), (130, 9, 5, 0.6563065), (145, 9, 7, 0.69590145), (178, 9, 4, 0.5665457), (179, 9, 4, 0.60991746), (241, 9, 4, 0.62459296), (242, 9, 7, 0.62323684), (243, 9, 0, 0.71782786), (263, 9, 5, 0.60171026), (303, 9, 7, 0.74335974), (400, 9, 4, 0.33204296), (445, 9, 7, 0.5391221), (457, 8, 4, 0.58640194), (460, 9, 5, 0.5841426), (467, 2, 2, 0.45893326), (504, 9, 5, 0.6720578), (530, 9, 5, 0.45750248), (557, 9, 5, 0.59489113), (576, 9, 5, 0.792452), (588, 9, 4, 0.39732212), (607, 9, 5, 0.3166903), (621, 9, 5, 0.71744514), (628, 9, 4, 0.766451), (640, 9, 5, 0.3493439), (677, 9, 5, 0.36712888), (690, 9, 5, 0.6383828), (712, 9, 5, 0.44902757), (715, 9, 7, 0.7010521), (720, 8, 5, 0.5184946), (748, 9, 7, 0.6609909), (764, 9, 5, 0.2799803), (821, 6, 6, 0.38658223), (823, 9, 4, 0.4966704), (841, 9, 5, 0.75184554), (845, 9, 4, 0.58853316), (862, 2, 2, 0.5479177), (870, 1, 1, 0.7335397), (875, 9, 4, 0.50488394), (880, 9, 8, 0.69092685)