# Training the Untrainable: Introducing Inductive Bias via Representational Alignment

This notebook gives an overview of our method for training untrainable networks, which we refer to as **Guidance**. It is meant as a clean starting point for plugging in your own networks and similarity metrics and experimenting with the method. We hope people find this useful! Please let us know what features would make it easier to use.

For this tutorial, we use our Deep FCN model with a ResNet-18 guide for image classification on ImageNet.

**A quick note**: This code is memory- and compute-intensive, so we can't promise that it will run on Google Colab. We still hope it gives a good overview of our approach in code.

## Packages

In [2]:
!pip install datasets

In [63]:
import os
import numpy as np
import datasets
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import warnings

from collections import OrderedDict
from tqdm import tqdm

warnings.filterwarnings('ignore')

## Dataset/Dataloading

We include our dataset/dataloading code for ImageNet. This can be modified for your own application.

In [64]:
class ImageNetDataset(data.Dataset):
    def __init__(self, split = 'train'):
        dataset = datasets.load_dataset('ILSVRC/imagenet-1k')
        self.data = dataset[split]
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
        if split == 'train':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        img = entry['image'].convert('RGB')
        label = entry['label']

        img = self.transform(img)
        # Use the 'image' key, which is what the training and evaluation loops read from each batch.
        return {'image': img, 'label': label}

In [37]:
def get_dataloaders(batch_size = 64, num_workers = 4):
    train_dataset, val_dataset, test_dataset = ImageNetDataset('train'), ImageNetDataset('validation'), ImageNetDataset('test')
    train_loader = data.DataLoader(train_dataset, batch_size, shuffle = True, num_workers = num_workers)
    val_loader = data.DataLoader(val_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = False)
    test_loader = data.DataLoader(test_dataset, batch_size = 1)
    return train_loader, val_loader, test_loader
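
Since downloading ImageNet is expensive, it can help to smoke-test the rest of the pipeline on random data first. The sketch below is not part of the original code: `SyntheticImageDataset` is a hypothetical stand-in that yields random tensors in the same dictionary layout (`image` and `label` keys) that the training and evaluation loops read.

```python
import torch
import torch.utils.data as data


class SyntheticImageDataset(data.Dataset):
    """Hypothetical stand-in for ImageNet: random images with random labels."""

    def __init__(self, num_samples=16, num_classes=1000, image_size=224, seed=0):
        generator = torch.Generator().manual_seed(seed)
        self.images = torch.randn(num_samples, 3, image_size, image_size, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Same dictionary layout that the training and evaluation loops index into.
        return {'image': self.images[idx], 'label': self.labels[idx]}


synthetic_loader = data.DataLoader(SyntheticImageDataset(), batch_size=8, shuffle=False)
batch = next(iter(synthetic_loader))
print(batch['image'].shape, batch['label'].shape)
```

The default collate function stacks dictionary entries key by key, so each batch is a dictionary of tensors.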

## Representational Alignment
We provide all code to compute representational similarity across batches while training. This code can be modified to use new metrics for similarity, etc.

### Centered Kernel Alignment

The code to compute linear centered kernel alignment (CKA) is provided below. The class builds its centering matrices on the `device` passed to the constructor, so the computation runs on the GPU when the inputs live there.

In [38]:
class CKA(object):
    '''
    Calculates the linear centered kernel alignment between two sets of representations in pytorch.
    The main entry point is linear_CKA.
    >>> cka = CKA(torch.device('cpu'))
    >>> x, y = torch.randn(200, 768), torch.randn(200, 768)
    >>> cka.linear_CKA(x, y)
    '''
    def __init__(self, device):
        self.device = device

    def centering(self, K):
        n = K.shape[0]
        I = torch.eye(n, device=self.device)
        H = I - torch.ones([n, n], device=self.device) / n
        return H @ K @ H

    def linear_HSIC(self, X, Y):
        #Calculate Gram matrix.
        L_X = X @ X.T
        L_Y = Y @ Y.T
        #Center the two Gram matrices and calculate the HSIC between them i.e. the trace of their product.
        return torch.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        #Numerator HSIC
        hsic = self.linear_HSIC(X, Y)
        #Denominator HSICs
        var1 = torch.sqrt(self.linear_HSIC(X, X))
        var2 = torch.sqrt(self.linear_HSIC(Y, Y))

        return hsic / (var1 * var2)
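
As a quick sanity check on what linear CKA measures, the sketch below uses a compact standalone function (`linear_cka`, written for this example rather than taken from the class above) and centers the features directly, which is equivalent to centering the Gram matrices. It illustrates two well-known properties: the score is 1 for identical representations and is unchanged by an orthogonal rotation or a uniform rescaling of the features.

```python
import torch


def linear_cka(x, y):
    """Linear CKA between two (n_samples, n_features) matrices."""
    # Subtracting the per-feature mean is equivalent to H @ K @ H on the Gram matrices.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    gram_x, gram_y = x @ x.T, y @ y.T
    hsic_xy = (gram_x * gram_y).sum()
    hsic_xx = (gram_x * gram_x).sum()
    hsic_yy = (gram_y * gram_y).sum()
    return hsic_xy / torch.sqrt(hsic_xx * hsic_yy)


torch.manual_seed(0)
x = torch.randn(64, 32, dtype=torch.float64)
rotation, _ = torch.linalg.qr(torch.randn(32, 32, dtype=torch.float64))  # orthogonal matrix
unrelated = torch.randn(64, 32, dtype=torch.float64)

cka_identical = linear_cka(x, x)
cka_rotated = linear_cka(x, x @ rotation)
cka_scaled = linear_cka(x, 3.0 * x)
cka_unrelated = linear_cka(x, unrelated)
print(cka_identical.item(), cka_rotated.item(), cka_scaled.item(), cka_unrelated.item())
```

These invariances are why CKA can compare layers of different widths and architectures without first learning a mapping between them.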

### Layerwise Extraction

In [48]:
class FeatureMapExtractor:
    '''
    Activation extractor for pytorch networks. This extracts activations from outputs of set layers based on whether the layers have tunable parameters.
    Pass in your model and run `get_feature_maps` as follows:
    >>> model = YourModel()
    >>> extractor = FeatureMapExtractor(model)
    >>> feature_maps = extractor.get_feature_maps(inputs)
    '''
    def __init__(self, model, enforce_input_shape = True, eval = True):
        self.model = model
        self.enforce_input_shape = enforce_input_shape
        self.layers_to_retain = [nn.Conv2d, nn.Linear, nn.AdaptiveAvgPool2d, nn.LSTM, nn.RNN, nn.TransformerDecoderLayer, nn.LayerNorm, nn.MultiheadAttention]
        self.feature_maps = OrderedDict()
        self.hooks = []
        self.eval = eval
        self.device = next(model.parameters()).device

    @staticmethod
    def get_module_name(module, feature_maps):
        return f'{type(module).__name__}_{len(feature_maps)}'

    @staticmethod
    def get_module_type(module):
        return type(module)

    @staticmethod
    def check_for_input_axis(feature_map, input_size):
        axis_match = [dim for dim in feature_map.shape if dim == input_size]
        return True if len(axis_match) == 1 else False

    @staticmethod
    def reset_input_axis(feature_map, input_size):
        input_axis = feature_map.shape.index(input_size)
        return torch.swapaxes(feature_map, 0, input_axis)

    def register_hook(self, module):
        def hook(module, input, output):
            def process_output(output, module_name):
                if isinstance(output, nn.utils.rnn.PackedSequence):
                    output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first = True)
                if isinstance(output, torch.Tensor):
                    output = output.to(self.device)
                    if self.enforce_input_shape:
                        if output.shape[0] == self.inputs.shape[0]:
                            self.feature_maps[module_name] = output
                        else:
                            if self.check_for_input_axis(output, self.inputs.shape[0]):
                                output = self.reset_input_axis(output, self.inputs.shape[0])
                                self.feature_maps[module_name] = output
                            else:
                                self.feature_maps[module_name] = None
                    else:
                        self.feature_maps[module_name] = output

            module_type = self.get_module_type(module)
            if module_type in self.layers_to_retain:
                module_name = self.get_module_name(module, self.feature_maps)
                if any([isinstance(output, type_) for type_ in (tuple, list)]):
                    if module_type in [nn.RNN, nn.LSTM]:
                        output = output[:-1]
                    for output_i, output_ in enumerate(output):
                        module_name_ = '-'.join([module_name, str(output_i+1)])
                        process_output(output_, module_name_)
                else:
                    process_output(output, module_name)

        if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)):
            self.hooks.append(module.register_forward_hook(hook))

    def get_feature_maps(self, inputs, **kwargs):
        self.inputs = inputs
        self.feature_maps = OrderedDict()
        self.hooks = []
        self.model.apply(self.register_hook)

        if self.eval:
            with torch.no_grad():
                self.model(inputs, **kwargs)
        else:
            # Forward the extra arguments here as well, so both modes call the model identically.
            self.model(inputs, **kwargs)

        for hook in self.hooks:
            hook.remove()

        self.feature_maps = {map: features for (map, features) in list(self.feature_maps.items())[:-1]
                             if features is not None}
        return self.feature_maps
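
The extractor above is built on PyTorch forward hooks. For readers unfamiliar with them, here is a minimal sketch of the same idea on a toy network. `toy_model`, `activations`, and `make_hook` are names made up for this example, and unlike `FeatureMapExtractor` the sketch keeps every hooked layer, including the last one.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

toy_model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 4),
)

activations = OrderedDict()
handles = []


def make_hook(name):
    # Each hook stores the output of its layer under the layer's name.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


# Hook only the layers with trainable weights, mirroring `layers_to_retain`.
for name, module in toy_model.named_modules():
    if isinstance(module, nn.Linear):
        handles.append(module.register_forward_hook(make_hook(name)))

with torch.no_grad():
    toy_model(torch.randn(8, 16))

# Remove the hooks so later forward passes are not affected.
for handle in handles:
    handle.remove()

print({name: tuple(act.shape) for name, act in activations.items()})
```

When the activations feed a training loss, as they do for the student network, the `detach()` call and the `no_grad()` context must be dropped so that gradients can flow back through them.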

In [49]:
def layer_supervision(target_model_layers, student_model_layers):
    source_count = len(target_model_layers)
    target_count = len(student_model_layers)
    step = (target_count - 1) / (source_count - 1) if source_count > 1 else 1

    mapping = {}
    for i, source_layer in enumerate(target_model_layers):
        target_index = min(round(i * step), target_count - 1)
        mapping[source_layer] = student_model_layers[target_index]
    return mapping
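
To make the mapping rule concrete, the sketch below repeats the same evenly spaced rule in a standalone helper (`evenly_spaced_mapping`, a name chosen for this example) and applies it to made-up layer names.

```python
def evenly_spaced_mapping(guide_layers, student_layers):
    """Same rule as `layer_supervision`, repeated here so the snippet runs on its own."""
    guide_count, student_count = len(guide_layers), len(student_layers)
    step = (student_count - 1) / (guide_count - 1) if guide_count > 1 else 1
    return {
        guide_layer: student_layers[min(round(i * step), student_count - 1)]
        for i, guide_layer in enumerate(guide_layers)
    }


# Hypothetical layer names: a 3-layer guide spread across a 7-layer student.
guide_layers = ['Conv2d_0', 'Conv2d_1', 'Linear_2']
student_layers = [f'Linear_{i}' for i in range(7)]
mapping = evenly_spaced_mapping(guide_layers, student_layers)
print(mapping)

# When the guide is deeper than the student, several guide layers share one student layer.
deep_guide = [f'Conv2d_{i}' for i in range(5)]
shallow_student = ['Linear_0', 'Linear_1', 'Linear_2']
shared_mapping = evenly_spaced_mapping(deep_guide, shallow_student)
print(shared_mapping)
```

In the second case, note that `layermap_sim` stores its scores keyed by student layer, so when two guide layers map to the same student layer only the later pair appears to be kept.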

In [50]:
def get_layer_outputs(model, inputs, eval = True, **kwargs):
    extractor = FeatureMapExtractor(model, eval = eval)
    if kwargs.get('lengths') is not None:
        feature_maps = extractor.get_feature_maps(inputs, **kwargs)
    else:
        feature_maps = extractor.get_feature_maps(inputs)
    return feature_maps

def torchvision_fe(model, inputs, device):
    '''
    Torchvision has a tricky layer-wise extraction set-up. In this case, I defer to the torchvision implementation...
    '''
    layers = torchvision.models.feature_extraction.get_graph_node_names(model)[0]
    extract_layers = [l for l in layers if ('mlp' in l) or ('self_attention' in l)] + ['getitem_5']
    feature_extractor = torchvision.models.feature_extraction.create_feature_extractor(model, return_nodes = extract_layers)
    feature_extractor = feature_extractor.to(device)
    with torch.no_grad():
        output = feature_extractor(inputs)
    return output

In [51]:
def layermap_sim(train_model, target_model, rep_sim, inputs, target_inputs, device, lengths = None, torchvision_extract = False):
    cka = CKA(device)
    if not torchvision_extract:
        pretrained_outputs = get_layer_outputs(target_model, target_inputs, eval = True, lengths = lengths)
    else:
        pretrained_outputs = torchvision_fe(target_model, target_inputs, device)
    training_outputs = get_layer_outputs(train_model, inputs, eval = False, lengths = lengths)
    teacher_layers = list(pretrained_outputs.keys())
    student_layers = list(training_outputs.keys())
    # The same evenly spaced mapping is used whether the guide has fewer or more layers than the student.
    model_mapping = layer_supervision(teacher_layers, student_layers)
    sim_scores = {}
    for layer in model_mapping:
        assert layer in pretrained_outputs, f'Layer {layer} is not in target network {pretrained_outputs.keys()}'
        tr_layer = model_mapping[layer]
        assert tr_layer in training_outputs, f'Layer {tr_layer} is not in student network {training_outputs.keys()}'

        pretrained_output = pretrained_outputs[layer]
        training_output = training_outputs[tr_layer]

        assert pretrained_output.shape[0] == training_output.shape[0]

        if rep_sim == 'CKA':
            pretrained_output = pretrained_output.contiguous().view(inputs.size(0), -1)
            training_output = training_output.contiguous().view(inputs.size(0), -1)
            sim = 1 - cka.linear_CKA(training_output.to(torch.float32), pretrained_output.to(torch.float32))
        else:
            raise NotImplementedError()
        sim_scores[tr_layer] = sim
    del pretrained_outputs
    del training_outputs
    return sim_scores

### Representational Similarity Loss Calculation

In [52]:
def rep_similarity_loss(train_model, target_model, rep_sim, inputs, device, lengths = None,
                        torchvision_extract = False):
    target_batch = inputs #NOTE: The guide receives the same batch as the student. An earlier option to pass noise into the guide network was removed for simplicity.
    if rep_sim == 'CKA':
        sims = torch.stack(list(layermap_sim(train_model, target_model, rep_sim, inputs, target_batch, device,
                                                lengths = lengths, torchvision_extract = torchvision_extract).values()))
        sim = torch.sum(sims)
    else:
        raise NotImplementedError
    return sim
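
The loss above sums `1 - CKA` over the mapped layer pairs. A single-pair toy version is sketched below to show how gradients flow. All module names (`student_hidden`, `student_head`, `guide_hidden`) and sizes are invented for this example, and the guide is a randomly initialized layer rather than a ResNet-18.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_cka(x, y):
    """Compact linear CKA, written for this sketch."""
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    gram_x, gram_y = x @ x.T, y @ y.T
    return (gram_x * gram_y).sum() / torch.sqrt((gram_x * gram_x).sum() * (gram_y * gram_y).sum())


torch.manual_seed(0)
student_hidden = nn.Linear(20, 16)
student_head = nn.Linear(16, 5)
guide_hidden = nn.Linear(20, 64)  # hypothetical guide layer; widths need not match

inputs = torch.randn(32, 20)
labels = torch.randint(0, 5, (32,))

# Student activations keep their graph; guide activations are computed without gradients.
student_act = torch.relu(student_hidden(inputs))
with torch.no_grad():
    guide_act = torch.relu(guide_hidden(inputs))

ce_loss = F.cross_entropy(student_head(student_act), labels)
rep_loss = 1 - linear_cka(student_act, guide_act)
alpha = 1.0  # plays the role of `rep_sim_alpha`
loss = ce_loss + alpha * rep_loss
loss.backward()

print(ce_loss.item(), rep_loss.item())
```

Only the student receives gradients: the guide acts purely as a fixed target for the representations.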

## Guidance Training

We include the definition of the main Deep FCN network and all training code for aligning it with ResNet-18.

In [44]:
class DeepFCN(nn.Module):
    def __init__(self):
        super(DeepFCN, self).__init__()
        self.flatten = nn.Flatten()
        self.initial_layer = nn.Sequential(
            nn.Linear(150528, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        self.intermediate_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Linear(2048 if i == 0 else 1024, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Dropout(0.1)
              ) for i in range(47)]
        )
        self.output_layer = nn.Sequential(
            nn.Linear(1024, 1000),
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.initial_layer(x)
        x = self.intermediate_layers(x)
        x = self.output_layer(x)
        return x

In [45]:
def adjust_learning_rate(lr, optimizer, epoch):
    lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def validate(model, val_loader, loss_fn, device):
    model = model.eval()
    val_loss = 0.0
    for batch in tqdm(val_loader, desc = 'Iterating over validation batches...'):
        imgs, labels = batch['image'].to(device), batch['label'].to(device)
        with torch.no_grad():
            preds = model(imgs)
        loss = loss_fn(preds, labels)
        val_loss += loss.item()
    torch.cuda.empty_cache()
    return val_loss/len(val_loader)
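
`adjust_learning_rate` implements a step decay: the learning rate is multiplied by 0.1 every 30 epochs. The short sketch below prints the resulting schedule; `step_decay` is a helper written for this example with the same formula.

```python
def step_decay(base_lr, epoch, drop_every=30, factor=0.1):
    """Learning rate after `epoch` epochs under a step-decay schedule."""
    return base_lr * (factor ** (epoch // drop_every))


schedule = {epoch: step_decay(1e-3, epoch) for epoch in (0, 29, 30, 59, 60, 90)}
for epoch, lr_value in schedule.items():
    print(f'epoch {epoch:3d}: lr = {lr_value:.1e}')
```

With 100 epochs, the schedule therefore drops the learning rate three times, at epochs 30, 60, and 90.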

In [53]:
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = torch.topk(output, maxk, dim = 1, largest = True, sorted = True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def eval_loop(model, val_loader, device):
    model = model.eval()
    top1_acc = 0
    top5_acc = 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc = 'Iterating over test batches...'):
            img, label = batch['image'].to(device), batch['label'].to(device)
            outputs = model(img)
            acc1, acc5 = accuracy(outputs, label, topk=(1, 5))
            top1_acc += acc1.item() * img.size(0)
            top5_acc += acc5.item() * img.size(0)
            total_samples += img.size(0)
    top1_acc /= total_samples
    top5_acc /= total_samples
    return top1_acc, top5_acc
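
Top-k accuracy counts a prediction as correct when the true label is among the k highest-scoring classes. The toy example below uses a compact helper (`topk_accuracy`, written for this sketch) on four hand-made score vectors so the result can be checked by eye.

```python
import torch


def topk_accuracy(logits, targets, topk=(1,)):
    """Percentage of samples whose target is among the top-k scores, for each k."""
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1)        # (batch, maxk), best class first
    correct = pred.eq(targets.unsqueeze(1))   # broadcast targets across the k columns
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]


logits = torch.tensor([
    [0.1, 0.7, 0.2],   # ranking: 1, 2, 0
    [0.5, 0.3, 0.2],   # ranking: 0, 1, 2
    [0.2, 0.3, 0.5],   # ranking: 2, 1, 0
    [0.6, 0.1, 0.3],   # ranking: 0, 2, 1
])
targets = torch.tensor([1, 1, 0, 2])

top1, top2 = topk_accuracy(logits, targets, topk=(1, 2))
print(top1, top2)
```

Only the first sample is right at k = 1, while the second and fourth become correct once the runner-up class is allowed.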

In [55]:
def total_loss(train_model, target_model, rep_sim, loss_fn, preds, imgs, labels, rep_sim_alpha, device,
               torchvision_extract = False):
    rep_sim = rep_similarity_loss(train_model, target_model, rep_sim, imgs, device, torchvision_extract = torchvision_extract)
    ce_loss = loss_fn(preds, labels)
    return ce_loss + rep_sim_alpha * rep_sim, rep_sim, ce_loss
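
Written out, the objective computed by `total_loss` can be summarized as follows. The notation is ours, introduced for this summary: $f_\theta$ is the student network, $A_s(x)$ and $A_g(x)$ are the flattened activations of a student layer $s$ and its mapped guide layer $g$ on the batch $x$, $M$ is the set of layer pairs produced by `layer_supervision`, $\alpha$ is `rep_sim_alpha`, and $H = I - \tfrac{1}{n}\mathbf{1}\mathbf{1}^{\top}$ is the centering matrix for a batch of size $n$.

```latex
\mathcal{L}(\theta)
  = \mathcal{L}_{\mathrm{CE}}\big(f_\theta(x),\, y\big)
  + \alpha \sum_{(g,\, s) \in M} \Big(1 - \mathrm{CKA}\big(A_s(x),\, A_g(x)\big)\Big)

\mathrm{CKA}(X, Y)
  = \frac{\mathrm{HSIC}(X, Y)}{\sqrt{\mathrm{HSIC}(X, X)\,\mathrm{HSIC}(Y, Y)}},
\qquad
\mathrm{HSIC}(X, Y)
  = \operatorname{tr}\big(H X X^{\top} H \; H Y Y^{\top} H\big)
```

Because the centered Gram matrices are symmetric, the trace of their product equals the sum of their elementwise product, which is how `linear_HSIC` computes it.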

In [57]:
def train_image_classifier(exp_name, rep_sim, num_epochs, lr = 1e-3, batch_size = 64, num_workers = 16, pretrained = True, rep_dist = None, rep_sim_alpha = 1.0):
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  train_loader, val_loader, _ = get_dataloaders(batch_size, num_workers)
  ## CHANGE WITH ANY MODEL
  model = DeepFCN()
  model = model.to(device)

  ## CHANGE WITH ANY MODEL
  target_model = torchvision.models.resnet18(pretrained = pretrained).to(device)

  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr = lr)

  epoch_train_losses = []
  step_train_losses = []
  val_losses = []
  step_ce_loss = []
  step_rep_sim_loss = []
  accs = []
  for epoch in range(num_epochs):
    adjust_learning_rate(lr, optimizer, epoch)
    avg_val_loss = validate(model, val_loader, loss_fn, device)
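    # Default to +inf so the first validation pass saves a checkpoint (comparisons against NaN are always False).
    if avg_val_loss < min(val_losses, default = float('inf')):
      os.makedirs('saved_models', exist_ok = True)
      torch.save(model.state_dict(), f'saved_models/{exp_name}.pt')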
    print(f'Epoch {epoch}, Validation Loss: {avg_val_loss}')
    val_losses.append(avg_val_loss)

    acc1, acc5 = eval_loop(model, val_loader, device)
    print(f'Epoch {epoch}, Validation Top1: {acc1}, Top5: {acc5}')
    accs.append(acc5)

    model = model.train()
    train_loss = 0.0
    for i, batch in enumerate(tqdm(train_loader, desc = 'Iterating over training batches...')):
      imgs, labels = batch['image'].to(device), batch['label'].to(device)
      preds = model(imgs)

      if not rep_sim:
        loss = loss_fn(preds, labels)
        ce_loss = None
      else:
        # Store the alignment term under its own name so the boolean `rep_sim` flag is not overwritten.
        loss, rep_sim_loss, ce_loss = total_loss(model, target_model, rep_dist, loss_fn, preds, imgs, labels, rep_sim_alpha, device,
                                                 torchvision_extract = False)
        step_ce_loss.append(ce_loss.item())
        step_rep_sim_loss.append(rep_sim_loss.item())
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      if ce_loss is None:
          train_loss += loss.item()
      else:
          train_loss += ce_loss.item()
      step_train_losses.append(loss.item())

    avg_train_loss = train_loss/len(train_loader)
    print(f'Epoch {epoch + 1}, Training Loss: {avg_train_loss}')
    epoch_train_losses.append(avg_train_loss)
  final_avg_val_loss = validate(model, val_loader, loss_fn, device)
  print(f'Epoch {epoch+1}, Validation Loss: {final_avg_val_loss}')
  return model, step_train_losses, val_losses, epoch_train_losses, step_ce_loss, step_rep_sim_loss

### Run Guidance!

In [65]:
exp_name = 'guidance_tutorial'
rep_sim = True
num_epochs = 100
lr = 1e-4
batch_size = 256
pretrained = False
repdist = 'CKA'

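# Pass the optional arguments by keyword: the signature has `num_workers` before `pretrained`,
# so positional passing would shift `pretrained` and `repdist` into the wrong parameters.
train_image_classifier(exp_name, rep_sim, num_epochs, lr = lr, batch_size = batch_size, pretrained = pretrained, rep_dist = repdist)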
