<a href="https://colab.research.google.com/github/yoonkim97/pytorch-resnet-mnist/blob/master/MNISTResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, roc_curve, auc
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import urllib
from torch.nn import functional as F

In [48]:

num_workers = 0
batch_size = 20
basepath = '.'

def set_header_for(url, filename):
    opener = urllib.request.URLopener()
    opener.addheader('User-Agent', 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.47 Safari/537.36')
    opener.retrieve(
    url, f'{basepath}/{filename}')

set_header_for('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'train-images-idx3-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'train-labels-idx1-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz')
set_header_for('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 't10k-labels-idx1-ubyte.gz')


  if sys.path[0] == '':
  del sys.path[0]
  
  from ipykernel import kernelapp as app


In [0]:
class MnistResNet(ResNet):
  def __init__(self):
    super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
    self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  def forward(self, x):
    return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

In [0]:
train_label_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
test_label_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def get_same_indices(target, labels):
  label_indices = []
  for i in range (len(target)):
    for j in range (len(labels)):
      if target[i] == labels[j]:
        label_indices.append(i)
  return label_indices

In [0]:
def get_data_loaders(train_batch_size, val_batch_size):
  mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
  data_transform = Compose([Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

  train_dataset = MNIST(download=True, root=".", transform=data_transform, train=True)
  train_indices = get_same_indices(train_dataset.targets, train_label_classes)
  train_loader = DataLoader(dataset = train_dataset, batch_size=train_batch_size, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))

  val_dataset = MNIST(download=False, root=".", transform=data_transform, train=False)
  val_indices = get_same_indices(val_dataset.targets, test_label_classes)
  val_loader = DataLoader(dataset = val_dataset, batch_size=val_batch_size, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(val_indices))
  return train_loader, val_loader

In [0]:
def calculate_metric(metric_fn, true_y, pred_y):
  if "average" in inspect.getfullargspec(metric_fn).args:
    return metric_fn(true_y, pred_y, average="macro")
  else:
    return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
  for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
    print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [0]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 5

model = MnistResNet().to(device)
train_loader, val_loader = get_data_loaders(256, 256)


losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    confusion_actuals = []
    probabilities = []
    predictions = []
    roc_actuals = []

    which_class = 9

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            prediction = outputs.argmax(dim=1, keepdim=True)
            confusion_actuals.extend(y.view_as(prediction))
            roc_actuals.extend(y.view_as(prediction) == which_class)
            probabilities.extend(np.exp(outputs.cpu()[:, which_class]))  
            predictions.extend(prediction)

            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    confusion_actuals = [i.item() for i in confusion_actuals]
    predictions = [i.item() for i in predictions]
    roc_actuals = [i.item() for i in roc_actuals]
    class_probabilities = [i.item() for i in probabilities]

    print(roc_actuals)
    print(class_probabilities)

    # fpr, tpr, _ = roc_curve(roc_actuals, class_probabilities)
    # roc_auc = auc(fpr, tpr)
    # plt.figure()
    # lw = 2 
    # plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    # plt.plot([0, 1], [0, 1], color = 'navy', lw=lw, linestyle='--')
    # plt.xlim([0.0, 1.0])
    # plt.ylim([0.0, 1.05])
    # plt.xlabel('False Positive Rate')
    # plt.ylabel('True Positive Rate')
    # plt.title('ROC for digit=%d class' % which_class)
    # plt.legend(loc="lower right")
    # plt.show()
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print(confusion_matrix(confusion_actuals, predictions))    
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
print(losses)
print(f"Training time: {time.time()-start_ts}s")




HBox(children=(IntProgress(value=0, description='Loss: ', max=212, style=ProgressStyle(description_width='init…

In [0]:
class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, input):
        logits = self.model(input)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        return logits / temperature

    def set_temperature(self, valid_loader):
        self.cuda()
        nll_criterion = nn.CrossEntropyLoss().cuda()
        ece_criterion = _ECELoss().cuda()
        anomaly_criterion = _AnomalyDetection().cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for input, label in valid_loader:
                input = input.cuda()
                logits = self.model(input)
                logits_list.append(logits)
                labels_list.append(label)
            logits = torch.cat(logits_list).cuda()
            labels = torch.cat(labels_list).cuda()

        # Calculate NLL and ECE before temperature scaling
        # before_temperature_nll = nll_criterion(logits, labels).item()
        # before_temperature_ece = ece_criterion(logits, labels).item()
        # print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

        # Next: optimize the temperature w.r.t. NLL
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval():
            loss = nll_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss
        optimizer.step(eval)

        # Calculate NLL and ECE after temperature scaling
        after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_anomaly = anomaly_criterion(self.temperature_scale(logits), labels)
        print('Optimal temperature: %.3f' % self.temperature.item())
        print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
        print('Anomaly Count: %d' %after_temperature_anomaly)
        return self

In [0]:
class _ECELoss(nn.Module):
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
      with torch.no_grad():
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
    
      ece = torch.zeros(1, device=logits.device)
      for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
          # Calculated |confidence - accuracy| in each bin
          in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
          prop_in_bin = in_bin.float().mean()
          if prop_in_bin.item() > 0:
              accuracy_in_bin = accuracies[in_bin].float().mean()
              avg_confidence_in_bin = confidences[in_bin].mean()
              ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

      return ece

In [0]:
class _AnomalyDetection(nn.Module):
  def __init__(self):
    super(_AnomalyDetection, self).__init__()
  def forward(self, logits, labels):
    with torch.no_grad():
      softmaxes = F.softmax(logits, dim=1)
      confidences, predictions = torch.max(softmaxes, 1)
      accuracies = predictions.eq(labels)

    predictions_list = predictions.cpu().numpy()
    confidences_list = confidences.cpu().numpy()
    actuals_list = labels.cpu().numpy()

    all_triples = []
    anomaly_triples = []
    modified_anomaly_triples = []

    anomaly_count = 0
    for i in range(len(actuals_list)):
      all_triples.append((actuals_list[i], predictions_list[i], confidences_list[i]))
      if (confidences_list[i] < 0.80):
        anomaly_triples.append((i, actuals_list[i], predictions_list[i], confidences_list[i]))
        if (actuals_list[i] == 9):
          predictions_list[i] = 9
          modified_anomaly_triples.append((i, actuals_list[i], predictions_list[i], confidences_list[i]))
        anomaly_count += 1
    print(anomaly_triples)
    print(modified_anomaly_triples)
    print(len(anomaly_triples))
    print(len(modified_anomaly_triples))

    conf_matrix = confusion_matrix(actuals_list, predictions_list)
    precision = precision_score(actuals_list, predictions_list, average='macro')
    recall = recall_score(actuals_list, predictions_list, average='macro')
    f1 = f1_score(actuals_list, predictions_list, average='macro')
    print(conf_matrix) 
    print("Precision: %d", precision)
    print("Recall: %d", recall)
    print("F1 Score: %d", f1)

    anomaly_class = 9 
    actuals = []
    confidences = []
    for i in range(len(actuals_list)):
      if (actuals_list[i] == anomaly_class):
        actuals.append(actuals_list[i])
        confidences.append(confidences_list[i])

    print(actuals)
    print(confidences)
    fpr, tpr, _ = roc_curve(actuals, confidences)
    roc_auc = auc(fpr, tpr)
    print(roc_auc)
    return anomaly_count

In [46]:
scaled_model = ModelWithTemperature(model)
scaled_model.set_temperature(val_loader)


model_filename = 'model_with_temperature.pth'
torch.save(scaled_model.state_dict(), model_filename)
print('Temperature scaled model saved to %s', model_filename)
print('Done!')

[(57, 8, 8, 0.6025449), (68, 9, 7, 0.36073866), (88, 9, 4, 0.54358953), (125, 9, 7, 0.46659192), (149, 9, 4, 0.5654679), (176, 9, 7, 0.65573764), (236, 9, 4, 0.7288132), (242, 8, 5, 0.50411546), (243, 8, 6, 0.72702235), (319, 9, 7, 0.70804584), (320, 9, 5, 0.6563065), (322, 9, 4, 0.56559986), (327, 9, 4, 0.62139213), (347, 9, 7, 0.6404826), (378, 9, 4, 0.44418088), (429, 9, 4, 0.39046547), (444, 9, 7, 0.42731133), (509, 9, 8, 0.58918613), (542, 9, 4, 0.7811711), (550, 9, 4, 0.48983827), (606, 9, 4, 0.4486785), (622, 8, 5, 0.44585204), (626, 9, 0, 0.41016862), (637, 9, 5, 0.7727562), (648, 8, 8, 0.54579353), (651, 9, 4, 0.50510496), (669, 9, 7, 0.6597887), (670, 9, 5, 0.59488916), (673, 9, 7, 0.7962353), (719, 9, 4, 0.69263923), (755, 9, 5, 0.68667513), (783, 9, 7, 0.76774544), (789, 9, 5, 0.5064388), (791, 9, 7, 0.6975717), (815, 8, 5, 0.43734246), (838, 9, 5, 0.5443313), (856, 9, 5, 0.6272443), (932, 9, 7, 0.69174707), (965, 9, 7, 0.7010521), (993, 2, 2, 0.72791207), (1000, 9, 7, 0.77