In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/robust_DL
%mkdir results

In [None]:
from __future__ import print_function
import os
import argparse
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision
import numpy as np
import torch.nn.functional as F
from losses.trades import trades_loss
import copy
import torch
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from models.wideresnet import *
from models.resnet import *
from models.small_cnn import *
from torch.utils.data import Dataset, DataLoader
from models.AlexNet import AlexNet
from torch.utils.data.sampler import SubsetRandomSampler
import json

In [None]:
class Data:
  """Plain container holding the four loaders one experiment works with."""

  def __init__(self, train_loader, valid_loader, test_loader, attack_loader):
    # The loaders are stored untouched; nothing is iterated or validated here.
    (self.train_loader,
     self.valid_loader,
     self.test_loader,
     self.attack_loader) = (train_loader, valid_loader, test_loader, attack_loader)

class Model:
  """Identifier-carrying placeholder for a model."""

  # Class-level default: every instance reports `model is None` until an
  # instance attribute of the same name is assigned.
  model = None

  def __init__(self, id):
    # `id` shadows the builtin, but it is the public parameter name, so it stays.
    self.id = id

class Loss:
  """Pairs a loss callable with an optional identifier.

  `loss_fn` is invoked elsewhere in this file as
  `loss_fn(model, data, target, optimizer)`.
  """

  def __init__(self, loss_fn, id=None):
    self.loss_fn, self.id = loss_fn, id

class Configuration:
  """One experiment setup: data, model, loss and attack, plus optional extras."""

  def __init__(self, data, model, loss, attack, model_pt=None, id=None):
    self.data, self.model, self.loss, self.attack = data, model, loss, attack
    # Optional checkpoint path; None means "no checkpoint".
    self.model_pt = model_pt # Should move this to model
    self.id = id

  def getConfig(self):
    """Return the four core components as ``(data, model, loss, attack)``."""
    components = (self.data, self.model, self.loss, self.attack)
    return components

  def getId(self):
    """Return the identifier given at construction (may be None)."""
    return self.id

class CIFAR10CDataset(Dataset):
    """Map-style dataset over parallel, indexable `data` and `labels` sequences.

    Each item is ``(image, label)``; when a truthy `transform` was supplied
    the image is passed through it first, the label is returned as stored.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # Length follows `data`; `labels` is not consulted.
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.labels[idx]
        # Truthiness test (not `is not None`) mirrors how the transform is gated.
        if not self.transform:
            return sample, target
        return self.transform(sample), target

def general_trades_loss_fn(beta=6.0, epsilon=0.3, step_size=0.007, num_steps=10):
  """Build a loss callable that invokes `trades_loss` with fixed hyper-parameters.

  The returned function takes ``(model, data, target, optimizer)`` and forwards
  them as `model`, `x_natural`, `y` and `optimizer`, together with the values
  captured here (`num_steps` is passed as `perturb_steps`; the distance is
  always ``'l_inf'``).
  """
  # Captured once, so every call of the returned function reuses the same settings.
  fixed_kwargs = dict(step_size=step_size, epsilon=epsilon, perturb_steps=num_steps,
                      beta=beta, distance='l_inf')

  def trades_loss_fn(model, data, target, optimizer):
    return trades_loss(model=model, x_natural=data, y=target, optimizer=optimizer,
                       **fixed_kwargs)

  return trades_loss_fn

def ce_loss_fn(model, data, target, optimizer):
    """Cross-entropy between `model(data)` and `target`.

    `optimizer` is accepted only to match the loss-callable signature used in
    this file; it is not used.
    """
    logits = model(data)
    return F.cross_entropy(logits, target)
    
def identity_attack(model, X, y):
  """No-perturbation "attack": count how many rows of `X` the model labels as `y`.

  Returns the number of matches as a Python float.
  """
  # Index of the largest output along dim 1 is taken as the predicted class.
  predicted = model(X).data.max(1)[1]
  n_correct = (predicted == y.data).float().sum()
  return n_correct.item()



In [None]:
def accuracy(model, data_loader, device):
    """Return the percentage of samples in `data_loader` that `model` classifies correctly.

    The model is switched to eval mode and gradients are disabled. Each batch
    is moved to `device`; the prediction is the index of the largest output
    along dim 1.
    """
    print('EVAL')
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = torch.max(model(inputs), 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return 100. * n_correct / n_seen

def robust_accuracy(model, attack, data_loader, device):
    """Return the percentage of samples counted as correct by `attack`.

    `attack(model, X, y)` is called once per batch and must return the number
    of correctly classified samples in that batch. The input is handed over
    with `requires_grad=True`.
    """
    print('ROBUST EVAL')
    model.eval()
    n_correct = 0
    n_seen = 0

    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        X = Variable(inputs, requires_grad=True)
        y = Variable(labels)
        n_correct += attack(model, X, y)
        n_seen += labels.size(0)
    return 100. * n_correct / n_seen

def _save_training_plot(run_id, epoch_losses, eval_epochs, eval_accuracies, attack_acc=None):
  """Write the loss / validation-accuracy figure for one run into ``plots/``.

  Args:
    run_id: figure title and file-name prefix.
    epoch_losses: mean training loss per completed epoch (epoch 1 first).
    eval_epochs: epoch numbers at which validation was run.
    eval_accuracies: validation accuracy for each entry of `eval_epochs`.
    attack_acc: when given, drawn as a horizontal reference line.
  """
  fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
  fig.suptitle(run_id)

  # Plot against real epoch numbers so the x axis matches its label.
  ax_loss.plot(range(1, len(epoch_losses) + 1), epoch_losses, label='Training Loss')
  ax_loss.set_title('Loss vs. Epochs')
  ax_loss.set_xlabel('Epochs')
  ax_loss.set_ylabel('Loss')
  ax_loss.legend()

  # Validation may run only every `eval_interval` epochs, hence explicit x values.
  ax_acc.plot(eval_epochs, eval_accuracies, label='Validation Accuracy')
  if attack_acc is not None:
    ax_acc.axhline(y=attack_acc, color='r', linestyle='-', label='CIFAR 10 C EVAL')
  ax_acc.set_title('Validation Accuracy vs. Epochs')
  ax_acc.set_xlabel('Epochs')
  ax_acc.set_ylabel('Accuracy')
  ax_acc.legend()

  os.makedirs('plots', exist_ok=True)
  fig.savefig(os.path.join('plots', run_id + '_training_validation_plot.png'))
  plt.close(fig)


def train(model, data, loss, config, epochs, eval_interval, device):
  """Train `model` with SGD, validating periodically and stopping early on a plateau.

  Args:
    model: network to train; moved to `device` and updated in place.
    data: object exposing `train_loader`, `valid_loader` and `attack_loader`.
    loss: object exposing `loss_fn(model, data, target, optimizer)` and a string
      `id`; the id names the checkpoint and plot files.
    config: object exposing `model_pt` (optional checkpoint path loaded before
      training) and `attack` (callable handed to `robust_accuracy`).
    epochs: maximum number of epochs.
    eval_interval: validate on epoch 1, on every `eval_interval`-th epoch and on
      the last epoch.
    device: torch device for the model and the batches.

  Returns:
    The sum of the per-batch losses of the last epoch that ran (0.0 if no
    epoch ran).

  Side effects:
    Writes ``weights/<id>.pt`` and ``optimizers/<id>.tar`` whenever validation
    accuracy improves, and rewrites ``plots/<id>_training_validation_plot.png``
    after every epoch. The best checkpoint is only saved, never reloaded here,
    so `model` leaves this function with the weights of the last epoch.
  """
  print('TRAINING')
  train_loader = data.train_loader
  valid_loader = data.valid_loader
  attack_loader = data.attack_loader

  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  # TODO: We can move the optimizer to a field of Loss object

  model.to(device)
  if config.model_pt is not None:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
    model.load_state_dict(torch.load(config.model_pt, map_location=device))

  # torch.save does not create missing parent directories.
  for checkpoint_dir in ('weights', 'optimizers'):
    os.makedirs(checkpoint_dir, exist_ok=True)

  best_eval_acc = 0.0
  patience = 5  # number of VAL Acc values observed after best value to stop training
  # Must exist before the first evaluation: a first accuracy of 0.0 is not an
  # improvement over best_eval_acc and goes straight to the increment branch.
  patience_count = 0
  total_loss = 0.0  # keeps the return value defined when epochs < 1

  # Per-epoch mean loss, and validation accuracy with the epoch it was measured at.
  epoch_losses = []
  eval_epochs = []
  eval_accuracies = []

  for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    # Batch variables get their own names so the `data` argument is not shadowed.
    for batch_idx, (inputs, target) in enumerate(train_loader):
      inputs, target = inputs.to(device), target.to(device)
      optimizer.zero_grad()
      batch_loss = loss.loss_fn(model, inputs, target, optimizer)
      batch_loss.backward()
      optimizer.step()
      total_loss += batch_loss.item()

      print(f"{loss.id} @ EP={epoch} & Batch idx {batch_idx} / {len(train_loader) - 1} Loss: {batch_loss.item()}")

    epoch_losses.append(total_loss / len(train_loader))

    if epoch == 1 or epoch % eval_interval == 0 or epoch == epochs:
      eval_acc = accuracy(model, valid_loader, device)
      eval_epochs.append(epoch)
      eval_accuracies.append(eval_acc)

      if eval_acc > best_eval_acc:  # best so far so save checkpoint to restore later
        best_eval_acc = eval_acc
        patience_count = 0
        torch.save(model.state_dict(), os.path.join("weights", loss.id + ".pt"))
        torch.save(optimizer.state_dict(), os.path.join("optimizers", loss.id + ".tar"))
      else:
        patience_count += 1

    stopping_early = patience_count >= patience

    # The attack-set accuracy is only computed for the final plot of the run.
    attack_acc = None
    if epoch == epochs or stopping_early:
      attack_acc = robust_accuracy(model, config.attack, attack_loader, device)
    _save_training_plot(loss.id, epoch_losses, eval_epochs, eval_accuracies, attack_acc)

    if stopping_early:
      print(f"Early Stopping!, epoch {epoch}")
      break

  return total_loss


In [None]:
def _load_results(path):
  """Return the dict stored as JSON at `path`, or an empty dict if the file does not exist."""
  if not os.path.exists(path):
    return {}
  with open(path, 'r') as fp:
    return json.load(fp)


def _dump_results(path, results):
  """Overwrite `path` with `results` serialised as JSON."""
  with open(path, 'w') as fp:
    json.dump(results, fp)


def run_experiment(num_epoch=2, valid_size=0.2, eval_interval=1):
  """Train and evaluate two TRADES AlexNets and a cross-entropy AlexNet baseline.

  Training data is CIFAR-10 with the first `valid_size` fraction of the
  training indices held out for validation; the attack set is the CIFAR-10-C
  "spatter" file. For every configuration the final training loss, test
  accuracy and attack-set accuracy are merged into the three JSON files under
  ``results/`` right after that configuration finishes.

  Args:
    num_epoch: maximum number of epochs per configuration.
    valid_size: fraction of the CIFAR-10 training set used for validation.
    eval_interval: validation frequency in epochs, forwarded to `train`.

  Returns:
    ``(final_loss, natural_accuracy, robustness_accuracy)``: dicts keyed by
    configuration id, including any entries already present in the JSON files.
  """
  transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),])
  transform_test = transforms.Compose([
    transforms.ToTensor(),])
  # No Normalize here: train/valid/test feed plain ToTensor output to the model,
  # so the attack set has to arrive on that same scale to be comparable.
  transform_attack = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),])

  use_cuda = torch.cuda.is_available()
  device = torch.device('cuda' if use_cuda else 'cpu')
  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

  trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
  # Same images as `trainset`, but without random crop/flip, so validation
  # accuracy is measured on unaugmented inputs.
  validset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_test)
  testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)

  # The leading `split` indices become the validation subset, the rest is training.
  num_train = len(trainset)
  indices = list(range(num_train))
  split = int(np.floor(valid_size * num_train))
  train_idx, valid_idx = indices[split:], indices[:split]
  train_sampler = SubsetRandomSampler(train_idx)
  valid_sampler = SubsetRandomSampler(valid_idx)

  cifar10_train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, sampler=train_sampler, **kwargs)
  cifar10_valid_loader = torch.utils.data.DataLoader(validset, batch_size=128, sampler=valid_sampler, **kwargs)
  cifar10_test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, **kwargs)

  images = np.load('../data/CIFAR-10-C/spatter.npy') # Set this to whatever we want
  labels = np.load('../data/CIFAR-10-C/labels.npy')
  cifar10c_dataset = CIFAR10CDataset(data=images, labels=labels, transform=transform_attack)
  cifar10c_attack_loader = DataLoader(cifar10c_dataset, batch_size=200, shuffle=False)

  cifar10_c_data = Data(cifar10_train_loader, cifar10_valid_loader, cifar10_test_loader, cifar10c_attack_loader)

  configurations = []
  final_loss = {}
  natural_accuracy = {}
  robustness_accuracy = {}

  # One TRADES configuration per alpha, with beta = 1 / alpha.
  for alpha in [5, 6]:
    beta = 1 / alpha
    run_id = f'CIFARC10:Alexnet:TRADES_LOSS:BETA={beta}'
    model_pt = None # os.path.join("weights", f'CIFARC10:RESNET18:TRADES_LOSS:BETA={beta}_ep=0.pt')

    alexnet = AlexNet().to(device)
    trades_loss_beta = Loss(general_trades_loss_fn(beta=beta), run_id)
    configurations.append(
      Configuration(cifar10_c_data, alexnet, trades_loss_beta, identity_attack, model_pt, trades_loss_beta.id))

  # Cross-entropy baseline on a freshly initialised network.
  alexnet = AlexNet().to(device)
  ce_loss = Loss(ce_loss_fn, 'CIFARC10:Alexnet:CE_LOSS')
  configurations.append(Configuration(cifar10_c_data, alexnet, ce_loss, identity_attack, id=ce_loss.id))

  os.makedirs('results', exist_ok=True)

  for configuration in configurations:
    data, model, loss, attack = configuration.getConfig()
    config_id = configuration.getId()

    # Re-read before every run so existing entries are kept; on a first run the
    # files do not exist yet and each dict simply starts empty.
    final_loss = _load_results('results/final_loss.json')
    natural_accuracy = _load_results('results/natural_accuracy.json')
    robustness_accuracy = _load_results('results/robustness_accuracy.json')

    final_loss[config_id] = train(model, data, loss, configuration, num_epoch, eval_interval, device)
    natural_accuracy[config_id] = accuracy(model, data.test_loader, device)
    robustness_accuracy[config_id] = robust_accuracy(model, attack, data.attack_loader, device)

    # Persist straight away so a later failure does not lose finished runs.
    _dump_results('results/final_loss.json', final_loss)
    _dump_results('results/natural_accuracy.json', natural_accuracy)
    _dump_results('results/robustness_accuracy.json', robustness_accuracy)

  return final_loss, natural_accuracy, robustness_accuracy

In [None]:
# Full run: up to 50 epochs per configuration, 20% of the CIFAR-10 training set
# held out for validation, validation after every epoch.
final_loss, natural_accuracy, robustness_accuracy = run_experiment(num_epoch=50, valid_size=0.2, eval_interval=1)

In [None]:
def _load_trades_accuracies(directory='results/',
                            config_pattern='CIFARC10:Alexnet:TRADES_LOSS:BETA='):
  """Collect natural and robustness accuracy per beta from the result JSON files.

  Every ``*.json`` file in `directory` whose name contains ``natural_accuracy``
  or ``robustness_accuracy`` (and not ``final_loss``) is read; entries whose key
  starts with `config_pattern` contribute one value, with beta parsed from the
  text after the last ``=`` in the key.

  Returns:
    ``(betas, natural, robustness)``: `betas` is sorted ascending and holds only
    the betas present in BOTH result sets, so the two accuracy lists always
    line up with it.
  """
  natural_by_beta = {}
  robust_by_beta = {}

  # Sorted so the outcome does not depend on the order the OS lists files in.
  for file_name in sorted(os.listdir(directory)):
    if not file_name.endswith('.json') or 'final_loss' in file_name:
      continue
    if 'natural_accuracy' in file_name:
      bucket = natural_by_beta
    elif 'robustness_accuracy' in file_name:
      bucket = robust_by_beta
    else:
      continue

    with open(os.path.join(directory, file_name), 'r') as fp:
      results = json.load(fp)
    for config_id, value in results.items():
      if config_id.startswith(config_pattern):
        bucket[float(config_id.split('=')[-1])] = value

  # Only betas with both measurements can be plotted side by side; a beta
  # missing from one file is skipped instead of raising KeyError.
  betas = sorted(set(natural_by_beta) & set(robust_by_beta))
  return (betas,
          [natural_by_beta[beta] for beta in betas],
          [robust_by_beta[beta] for beta in betas])


def make_result_plot_combined():
  """Plot natural and robustness accuracy vs beta on one shared axis.

  Reads the result JSON files under ``results/`` and writes
  ``plots/ACCURACY_COMBINED.png`` (creating ``plots/`` if needed).
  """
  betas, natural_acc, robust_acc = _load_trades_accuracies()

  fig, ax = plt.subplots(figsize=(10, 5))
  ax.plot(betas, natural_acc, label='Natural Accuracy', marker='o')
  ax.plot(betas, robust_acc, label='Robustness Accuracy', marker='x')
  ax.set_title('Natural and Robustness Accuracy vs Beta')
  ax.set_xlabel('Beta')
  ax.set_ylabel('Accuracy')
  ax.legend()

  os.makedirs('plots', exist_ok=True)
  # File name matches what this function draws: both curves in one plot.
  fig.savefig(os.path.join('plots', 'ACCURACY_COMBINED.png'))
  plt.close(fig)


def make_result_plot_separate():
  """Plot natural and robustness accuracy vs beta in two stacked subplots.

  Reads the result JSON files under ``results/`` and writes
  ``plots/ACCURACY_SEPARATE.png`` (creating ``plots/`` if needed).
  """
  betas, natural_acc, robust_acc = _load_trades_accuracies()

  fig, (ax_natural, ax_robust) = plt.subplots(2, 1, figsize=(10, 10))

  ax_natural.plot(betas, natural_acc, label='Natural Accuracy', marker='o')
  ax_natural.set_title('Natural Accuracy vs Beta')
  ax_natural.set_xlabel('Beta')
  ax_natural.set_ylabel('Natural Accuracy')
  ax_natural.legend()

  ax_robust.plot(betas, robust_acc, label='Robustness Accuracy', marker='x')
  ax_robust.set_title('Robustness Accuracy vs Beta')
  ax_robust.set_xlabel('Beta')
  ax_robust.set_ylabel('Robustness Accuracy')
  ax_robust.legend()

  # Adjust layout to prevent overlap between the two panels.
  fig.tight_layout()
  os.makedirs('plots', exist_ok=True)
  # File name matches what this function draws: one panel per metric.
  fig.savefig(os.path.join('plots', 'ACCURACY_SEPARATE.png'))
  plt.close(fig)



In [None]:
# Build both summary figures from the JSON files in results/; each call saves
# one PNG under plots/.
make_result_plot_separate()
make_result_plot_combined()