In [None]:
from torchvision import models
from torch import nn
import torch

class Identity(nn.Module):
  """Pass-through module: `forward` hands its input back unchanged."""

  def __init__(self):
    super().__init__()

  def forward(self, x):
    # No parameters and no computation; the very same object is returned.
    return x


from torch import nn
import math

class SimpleCNN(nn.Module):
  """
  Small CNN: one unpadded 3x3 convolution followed by two linear layers.

  The first linear layer expects 30*30*256 flattened features, i.e. the
  output of the convolution (256 channels, kernel 3, stride 1, no padding)
  applied to 3x32x32 inputs. The classifier produces 10 outputs.
  """

  def __init__(self):
    super().__init__()

    self.conv = nn.Conv2d(in_channels=3, out_channels=256,
                          kernel_size=3, stride=1, padding=0)
    flattened_features = 256 * 30 * 30
    self.fc = nn.Linear(flattened_features, 64)
    self.classifier = nn.Linear(64, 10)

  def forward(self, x):
    # conv -> ReLU -> flatten (all dims except batch) -> fc -> ReLU -> classifier
    hidden = torch.relu(self.conv(x))
    hidden = torch.flatten(hidden, start_dim=1)
    hidden = torch.relu(self.fc(hidden))
    return self.classifier(hidden)

# install and import the torchinfo library
from torchinfo import summary

# Sanity check: instantiate the model and print its torchinfo summary.
# No input size is passed to `summary`, so no forward pass is requested here.
simple_cnn = SimpleCNN()
print(summary(simple_cnn))

import torchvision.transforms as transforms
import torchvision
import copy
import numpy as np
from sklearn.feature_extraction import image

def get_cifar10_sets(num_examples_upstream, num_examples_downstream, num_classes_upstream, \
                     num_classes_downstream, is_downstream_random, batch_size):
  """
  Function to create dataloaders to be used throughout the experiments.

  The CIFAR-10 training indices are shuffled (unseeded, so every call yields a
  different split). The first `num_examples_upstream` shuffled indices form the
  upstream subset and the last `num_examples_downstream` form the downstream
  subset. Upstream labels are always replaced by random labels. Both train
  loaders share one dataset object, so if the two counts together exceed the
  training-set size the subsets overlap and share the overwritten labels.

  num_examples_upstream: # of samples to be used in upstream training
  num_examples_downstream: # of samples to be used in downstream training
  num_classes_upstream: # of random classes to be used in upstream training
  num_classes_downstream: # of random classes to be used in downstream training. If `is_downstream_random`
                          is False then this parameter will be ignored.
  is_downstream_random: Whether downstream task uses random labels.
  batch_size: Batch size.

  Returns (train_upstream_loader, train_downstream_loader, test_loader).
  """

  # ToTensor followed by Normalize(mean=0.5, std=0.5) per channel
  # normalizes the input between [-1,1]
  TF = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])

  trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=TF)
  testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=TF)

  # Create random indices for splitting upstream and downstream tasks
  indices = np.arange(len(trainset))
  np.random.shuffle(indices)
  indices = indices.tolist()

  upstream_indices = indices[:num_examples_upstream]
  # Slice from an explicit start: `indices[-n:]` would select the WHOLE list
  # when n == 0. max(..., 0) keeps the "take everything" result for n > len.
  downstream_start = max(len(indices) - num_examples_downstream, 0)
  downstream_indices = indices[downstream_start:]

  # Make upstream labels random
  labels = np.array(trainset.targets)
  labels[upstream_indices] = np.random.randint(num_classes_upstream, size=num_examples_upstream)

  # If the downstream task uses random labels
  if is_downstream_random:
    labels[downstream_indices] = np.random.randint(num_classes_downstream, size=num_examples_downstream)
  trainset.targets = labels.tolist()

  # Build dataloaders for upstream and downstream tasks. `shuffle` stays False
  # because the samplers already draw their subset in random order.
  train_upstream_loader = torch.utils.data.DataLoader(
      trainset, batch_size=batch_size, shuffle=False,
      sampler=torch.utils.data.SubsetRandomSampler(upstream_indices))

  train_downstream_loader = torch.utils.data.DataLoader(
      trainset, batch_size=batch_size, shuffle=False,
      sampler=torch.utils.data.SubsetRandomSampler(downstream_indices))

  test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

  # To check the random labels, inspect `next(iter(loader))[1]` for either train loader.
  return (train_upstream_loader, train_downstream_loader, test_loader)


def accuracy(model, dataloader):
  """
  Calculates accuracy of model in given dataloader

  Each batch is moved to the device of the model's first parameter before the
  forward pass, so a model living on the GPU can be evaluated with a
  dataloader that yields CPU tensors. Models without parameters (or plain
  callables) are called with the batches as they are.

  model: Model that the accuracy is calculated for
  dataloader: Dataloader to get samples that accuracy is calculated on.
              Must yield (inputs, labels) pairs.

  Returns the fraction of correctly classified samples, or 0.0 when the
  dataloader yields no samples.
  """
  # Find out where the model lives; None means "leave the batches alone".
  parameters = getattr(model, "parameters", None)
  first_param = next(iter(parameters()), None) if callable(parameters) else None
  device = first_param.device if first_param is not None else None

  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in dataloader:
      if device is not None:
        inputs = inputs.to(device)
        labels = labels.to(device)
      outputs = model(inputs)
      # Index of the largest score along dim 1 is the predicted class.
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  # Guard against an empty dataloader instead of dividing by zero.
  if total == 0:
    return 0.0
  return (correct / total)

def train(model, criterion, optimizer, epochs, dataloader, device, scheduler=None,
          acc_mode="none", acc_dataloaders=None, verbose=True, acc_every=15):
  """
  Trains given model with given configuration.

  model: Model to be used
  criterion: Loss function to be used
  optimizer: Optimizer to be used.
  epochs: # of epochs
  dataloader: Dataloader for training set
  device: which hardware the model is trained on. CPU or GPU(cuda)
  scheduler: learning rate scheduler, stepped once per epoch
  acc_mode: Accuracy mode indicating that which datasets accuracy metrics are calculated.
            "none"  = no accuracy calculation
            "train" = accuracy calculation on training set
            "test"  = accuracy calculation on test set
            "both"  = accuracy calculation on both training and test set
            Any other value behaves like "none".
  acc_dataloaders: Dataloaders to get samples that are used for accuracy calculation. The training
                   loader is read from index 0 and the test loader from index 1 (also in "test"
                   mode), so the sequence must be long enough for the chosen `acc_mode`.
  verbose: Verbosity option for the per-epoch loss/accuracy report
  acc_every: Accuracy is computed every `acc_every` batches within an epoch (the batch counter
             restarts each epoch). Must be >= 1.

  Returns `loss_history` (one loss per batch) for "none"; (loss_history, train_accuracies) for
  "train"; (loss_history, test_accuracies) for "test"; and
  (loss_history, train_accuracies, test_accuracies) for "both".

  Raises ValueError if `acc_every` is smaller than 1.
  """
  if acc_every < 1:
    raise ValueError("acc_every must be a positive integer")

  track_train = acc_mode in ("train", "both")
  track_test = acc_mode in ("test", "both")

  loss_history = []
  train_accuracies = []
  test_accuracies = []
  # Pre-bind so the verbose report cannot hit an unbound name when an epoch
  # has fewer than `acc_every` batches.
  train_acc = None
  test_acc = None

  for epoch in range(1, epochs + 1):
    for count, (inputs, labels) in enumerate(dataloader, 1):
      inputs = inputs.to(device)
      labels = labels.to(device)

      # canonical forward-backward-update chain
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      loss_history.append(loss.item())

      # `acc_every` controls the resolution of the accuracy data: evaluating on
      # the whole accuracy set after every batch is very time consuming, so it
      # is only done every `acc_every` batches.
      if count % acc_every == 0:
        if track_train:
          train_acc = accuracy(model, acc_dataloaders[0])
          train_accuracies.append(train_acc)
        if track_test:
          test_acc = accuracy(model, acc_dataloaders[1])
          test_accuracies.append(test_acc)

    if verbose:
      # Average over the losses actually available (at most the last 5).
      recent_losses = loss_history[-5:]
      avg_recent_loss = np.mean(recent_losses) if recent_losses else float("nan")
      print(f'Epoch {epoch} / {epochs}: avg. loss of last 5 iterations {avg_recent_loss}')
      if track_train:
        print(f'Epoch {epoch} / {epochs}: Last train acc:{train_acc}')
      if track_test:
        print(f'Epoch {epoch} / {epochs}: Last test acc:{test_acc}')

    if scheduler is not None:
      scheduler.step()
      # Print Learning Rate
      print('Epoch:', epoch,'LR:', scheduler.get_last_lr())

  if acc_mode=="both":
    return loss_history, train_accuracies, test_accuracies
  elif acc_mode=="train":
    return loss_history, train_accuracies
  elif acc_mode=="test":
    return loss_history, test_accuracies
  return loss_history


import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

def get_training_components(learnable_params, init_lr, step_size):
  '''
  Build the loss, optimizer and learning-rate scheduler. The same recipe is
  used by every experiment.

  learnable_params: Parameters that the optimizer updates. Which ones are
                    passed depends on whether training starts from scratch
                    or from a pre-trained model.
  init_lr: Initial learning rate handed to the optimizer.
  step_size: Every `step_size` epochs the scheduler multiplies the learning
             rate by the constant 1/3.

  Returns the tuple (criterion, optimizer, scheduler).
  '''
  # SGD with momentum 0.9; the scheduler decays its learning rate stepwise.
  sgd = optim.SGD(learnable_params, lr=init_lr, momentum=0.9)
  lr_schedule = StepLR(sgd, step_size=step_size, gamma=1/3)
  # The task is always image classification, hence cross-entropy.
  loss_fn = nn.CrossEntropyLoss()
  return (loss_fn, sgd, lr_schedule)

def get_device():
  """
  Pick the device to train on and report the choice.

  Returns torch.device("cuda") when CUDA is available, otherwise
  torch.device("cpu").
  """
  if torch.cuda.is_available():
    print("Cuda (GPU support) is available and enabled!")
    device = torch.device("cuda")
  else:
    print("Cuda (GPU support) is not available :(")
    device = torch.device("cpu")
  # Callers assign the result (`device = get_device()`), so it must be returned.
  return device

def get_learnable_parameters(model):
  """Return a list of the model's parameters that have `requires_grad` set."""
  return [
      weight
      for _name, weight in model.named_parameters()
      if weight.requires_grad
  ]

def create_patches_data(trainloader, patch_size):
  """
  Collect every `patch_size` x `patch_size` patch of every image in
  `trainloader` into one 2-D matrix with one flattened patch per row.

  trainloader: iterable yielding (inputs, labels) batches; each image in
               `inputs` is permuted with (1, 2, 0) before patch extraction
               (assumes channels-first tensors — TODO confirm for other loaders).
  patch_size: side length of the square patches.

  Prints the shape of the stacked patches and of the returned matrix.

  Returns an array of shape (num_images * patches_per_image, flattened_patch_length).
  Raises ValueError if `trainloader` yields no images.
  """
  patches_per_image = []
  for inputs, _labels in trainloader:
    for img in inputs:
      # move the channel axis last before handing the image to extract_patches_2d
      img = img.permute(1, 2, 0)
      patches_per_image.append(image.extract_patches_2d(img, (patch_size, patch_size)))

  if not patches_per_image:
    raise ValueError("trainloader yielded no images; cannot build the patch matrix")

  stacked = np.stack(patches_per_image, axis=0)
  print(stacked.shape)
  # Merge the image and patch axes; each row is one flattened patch. The row
  # length follows from the data instead of a hard-coded channel count.
  data_matrix = np.reshape(stacked, (stacked.shape[0] * stacked.shape[1], -1))
  print(data_matrix.shape)
  return data_matrix

def create_eigs(data):
  """
  Eigen-decomposition of the covariance matrix of `data`.

  data: 2-D array-like with one observation per row and one variable per
        column (`rowvar=False`).

  Returns (w, v): eigenvalues `w` in ascending order and the matching
  eigenvectors as the columns of `v`.
  """
  # np.cov returns a NumPy array, so the decomposition has to come from NumPy;
  # the file-level `LA` alias is torch.linalg, which rejects ndarrays.
  # atleast_2d covers a single-variable input, where np.cov returns a 0-d array.
  cov_matrix = np.atleast_2d(np.cov(data, rowvar=False))
  # A covariance matrix is symmetric, so `eigh` applies and yields real values.
  w, v = np.linalg.eigh(cov_matrix)
  return w, v

# Upstream training for 1st and 2nd graphs of figure1
from torch import linalg as LA

# Build the loaders shared by the training cells below: 10000 upstream samples
# whose labels are replaced by 10 random classes, and 10000 downstream samples.
# With `is_downstream_random = False` the downstream labels are not
# re-randomized, so `num_classes_downstream` is ignored by this call.
train_upstream_loader, train_downstream_loader, test_loader = get_cifar10_sets(
    num_examples_upstream = 10000,
    num_examples_downstream = 10000,
    num_classes_upstream = 10,
    num_classes_downstream = 10,
    is_downstream_random = False,
    batch_size = 256
  )

In [None]:


model = SimpleCNN()
# Remember the norm of every parameter at initialization so the trained
# parameters can be rescaled to it afterwards. `detach` keeps the stored norms
# out of the autograd graph.
norm_dict = {name: LA.norm(param.detach()) for name, param in model.named_parameters()}

epochs = 40
learnable_parameters = model.parameters()
# The learning rate is scaled by 1/3 every `epochs // 3` epochs.
criterion, optimizer, scheduler = get_training_components(learnable_parameters, 0.001, epochs // 3)

device = get_device()
model = model.to(device)
model.train()
loss_history = train(model, criterion, optimizer, epochs, train_upstream_loader, device, scheduler=scheduler)

# save upstream training before any rescaling ("noscale" checkpoints)
torch.save(model.state_dict(), "./pretrainednoscale_model_dict_upstream_1_2.pt")
torch.save(model, "./pretrainednoscale_model_upstream_1_2.pt")

# scale norms back to initial values after upstream training
with torch.no_grad():
  for name, param in model.named_parameters():
    scale = norm_dict[name]
    param.copy_(param * (scale.item() / LA.norm(param)))

# Covariance of the rescaled first-layer filters: one row per output channel,
# one column per flattened kernel entry. `.cpu()` is required before `.numpy()`
# when the model was trained on the GPU.
conv_weight = model.conv.weight.detach().cpu().numpy()
CNN_weights = np.reshape(conv_weight, (conv_weight.shape[0], -1))
Cov_y = np.cov(CNN_weights , rowvar=False)
print(Cov_y.shape)
# save upstream training
torch.save(model.state_dict(), "./pretrained_model_dict_upstream_1_2.pt")
torch.save(model, "./pretrained_model_upstream_1_2.pt")

In [None]:



def setup_and_train_downstream(train_upstream_loader, train_downstream_loader, test_loader,
                   acc_train_loader, acc_test_loader, epochs, lr, step_size, num_examples_upstream, 
                   num_examples_downstream, num_classes_upstream, num_classes_downstream, 
                   is_downstream_random, batch_size, num_outputs,  upstream_model_path=None, 
                   weights_are_covariance=False, cov_mat=None):
  """
  Builds a fresh SimpleCNN for the downstream task, optionally initializes it
  from an upstream checkpoint and/or from a covariance matrix, and trains it
  on `train_downstream_loader`.

  train_upstream_loader, test_loader: accepted for interface compatibility; not used here.
  train_downstream_loader: Dataloader the model is trained on.
  acc_train_loader: Dataloader used to measure training accuracy.
  acc_test_loader: Dataloader used to measure test accuracy (only when
                   `is_downstream_random` is False).
  epochs: # of epochs
  lr: learning rate
  step_size: # of epochs between learning rate decays
  num_examples_upstream, num_examples_downstream, num_classes_upstream,
  num_classes_downstream, batch_size: accepted for interface compatibility; not used here
                                      (the dataloaders are passed in ready-made).
  is_downstream_random: Whether downstream task uses random labels. Selects
                        which accuracies are tracked and returned.
  num_outputs: # of neurons in the classifier layer. Only applied when
               `upstream_model_path` is given; otherwise the SimpleCNN default
               classifier is kept.
  upstream_model_path: Path to load weights of upstream training.
  weights_are_covariance: If true, the convolution filters are drawn from a
                          zero-mean multivariate normal with covariance `cov_mat`.
  cov_mat: Covariance matrix over the flattened filter entries; required when
           `weights_are_covariance` is true.

  Returns (loss_history, train_accuracies, test_accuracies) when
  `is_downstream_random` is False, otherwise (loss_history, train_accuracies).

  Raises ValueError if `weights_are_covariance` is true but `cov_mat` is None.
  """
  model2 = SimpleCNN()
  if upstream_model_path is not None:
    # Load on the CPU regardless of where the checkpoint was saved; the model
    # is moved to the training device further down.
    state_dict = torch.load(upstream_model_path, map_location="cpu")
    model2.load_state_dict(state_dict, strict=False)
    # The classification layer is replaced by a freshly initialized one with
    # `num_outputs` neurons, so the upstream classifier weights are not reused.
    model2.classifier = nn.Linear(64, num_outputs)
    torch.nn.init.kaiming_normal_(model2.classifier.weight, mode='fan_out', nonlinearity='relu')
  if weights_are_covariance:
    if cov_mat is None:
      raise ValueError("cov_mat must be provided when weights_are_covariance is set")
    # One sample per output channel, drawn in a single call and reshaped to
    # the shape of the existing convolution weight.
    kernel_shape = tuple(model2.conv.weight.shape)
    sampled_filters = np.random.multivariate_normal(
        np.zeros(cov_mat.shape[0]), cov_mat, kernel_shape[0])
    new_weights = sampled_filters.reshape(kernel_shape)
    # A new Parameter requires grad by default, so the filters stay trainable.
    model2.conv.weight = torch.nn.Parameter(torch.tensor(new_weights).float())

  learnable_parameters = model2.parameters()
  criterion, optimizer, scheduler = get_training_components(learnable_parameters, lr, step_size)
  device = get_device()
  model2 = model2.to(device)
  model2.train()
  if not is_downstream_random:
    loss_history, train_accuracies, test_accuracies = train(model2, criterion, optimizer, epochs, train_downstream_loader, 
                                                            device, scheduler, "both",
                                                            [acc_train_loader, acc_test_loader])
    return loss_history, train_accuracies, test_accuracies
  # With random downstream labels only the training accuracy is tracked.
  loss_history, train_accuracies = train(model2, criterion, optimizer, epochs, train_downstream_loader, 
                                          device, scheduler, "train",
                                          [acc_train_loader])
  return loss_history, train_accuracies


num_examples_upstream = 10000
num_examples_downstream = 10000
num_classes_upstream = 10
num_classes_downstream = 10
is_downstream_random = False
batch_size = 256

# Separate dataloaders only used for accuracy. They are built from the SAME
# dataset and sampler as the training loaders: `get_cifar10_sets` reshuffles
# the indices on every call (no seed is set), so calling it again would measure
# "train" accuracy on a different random subset than the one trained on.
# `batch_size = num_examples_downstream` evaluates each set in large batches.
acc_train_loader = torch.utils.data.DataLoader(
    train_downstream_loader.dataset,
    batch_size=num_examples_downstream,
    sampler=train_downstream_loader.sampler)
acc_test_loader = torch.utils.data.DataLoader(
    test_loader.dataset,
    batch_size=num_examples_downstream,
    shuffle=True)

# Arguments shared by every downstream run below.
downstream_loaders = (train_upstream_loader, train_downstream_loader,
                      test_loader, acc_train_loader, acc_test_loader)
downstream_settings = dict(
    epochs=40, lr=0.001, step_size=40 // 3, num_examples_upstream=10000,
    num_examples_downstream=10000, num_classes_upstream=10, num_classes_downstream=10,
    is_downstream_random=False, batch_size=256, num_outputs=10)

# Using upstream training's weights before norm rescaling, with real labels
loss_noscale_real, accs_noscale_real_train, accs_noscale_real_test = setup_and_train_downstream(
    *downstream_loaders, **downstream_settings,
    upstream_model_path="pretrainednoscale_model_dict_upstream_1_2.pt")

# Using upstream training's (rescaled) weights with real labels
loss_upstream_real, accs_upstream_real_train, accs_upstream_real_test = setup_and_train_downstream(
    *downstream_loaders, **downstream_settings,
    upstream_model_path="pretrained_model_dict_upstream_1_2.pt")

# Convolution filters sampled from the covariance of the upstream filters
loss_covariance_real, accs_covariance_real_train, accs_covariance_real_test = setup_and_train_downstream(
    *downstream_loaders, **downstream_settings,
    weights_are_covariance=True, cov_mat=Cov_y)

# From scratch with real labels
loss_scratch_real, accs_scratch_real_train, accs_scratch_real_test = setup_and_train_downstream(
    *downstream_loaders, **downstream_settings)





In [None]:

import matplotlib.pyplot as plt
orange = "#fcba03"
light_blue = "#03bafc"
black = "#080504"

# One entry per training-accuracy curve: (values, line style, colour, legend label).
# Test-accuracy curves are not plotted.
train_curves = [
    (accs_scratch_real_train, "-", black, "from scratch (train)"),
    (accs_upstream_real_train, "-", light_blue, "pretrained (train)"),
    (accs_covariance_real_train, "-", orange, "covariance"),
    (accs_noscale_real_train, "--", black, "no scale"),
]
for curve_values, curve_style, curve_color, _curve_label in train_curves:
  plt.plot(curve_values, linestyle=curve_style, color=curve_color)

# Legend entries follow the plotting order above.
plt.legend([entry[3] for entry in train_curves], loc ="best")
plt.xlabel('# of iterations')
plt.ylabel('Accuracy')
plt.show()