# Experiments with CIFAR100 Dataset and Attentive Gate MoE Training

The experiments in this notebook include training the attentive gate MoE models as follows:

1. attentive gate MoE without regularization.
2. attentive gate MoE with $L_{importance}$ regularization.
3. attentive gate MoE with $L_s$ regularization.

In [38]:
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm  # colormaps

%matplotlib inline

In [39]:
import seaborn as sns
import numpy as np
from statistics import mean
from math import ceil, sin, cos, radians
from collections import OrderedDict
import os
import pandas as pd
from pprint import pprint
from copy import deepcopy
from itertools import product

In [40]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

In [41]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device', device)

device cuda:0


In [42]:
# import MoE expectation model. All experiments for this dataset are done with the expectation model as it
# provides the best guarantee of interpretable task decompositions
from moe_models.moe_expectation_model import moe_expectation_model
from helper.moe_models import cross_entropy_loss
from helper.visualise_results import *

In [43]:
# Paths to where the trained models and figures will be stored. You can change this as you see fit.
fig_path = '../figures'
model_path = '../models/hidden_256'
results_path = '../results'
pre_trained_model_path = '../models/pre_trained'

if not os.path.exists(fig_path):
    os.mkdir(fig_path)
if not os.path.exists(model_path):
    os.mkdir(model_path)
if not os.path.exists(results_path):
    os.mkdir(results_path)

In [44]:
stats = ((0.5074,0.4867,0.4411),(0.2011,0.1987,0.2025))
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32,padding=4,padding_mode="reflect"),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

In [45]:
cifar100_trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
cifar100_testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
cifar100_testset, cifar100_trainset

Files already downloaded and verified
Files already downloaded and verified


(Dataset CIFAR100
     Number of datapoints: 10000
     Root location: ./data
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.5074, 0.4867, 0.4411), std=(0.2011, 0.1987, 0.2025))
            ),
 Dataset CIFAR100
     Number of datapoints: 50000
     Root location: ./data
     Split: Train
     StandardTransform
 Transform: Compose(
                RandomHorizontalFlip(p=0.5)
                RandomCrop(size=(32, 32), padding=4)
                ToTensor()
                Normalize(mean=(0.5074, 0.4867, 0.4411), std=(0.2011, 0.1987, 0.2025))
            ))

In [46]:
num_classes = 100

In [47]:
trainsize = 50000
testsize = 10000

In [48]:
batch_size = 256

In [49]:
cifar100_trainloader = torch.utils.data.DataLoader(torch.utils.data.Subset(cifar100_trainset, range(trainsize)), batch_size=batch_size,
                                          shuffle=True, num_workers=4, pin_memory=True)
cifar100_testloader = torch.utils.data.DataLoader(torch.utils.data.Subset(cifar100_testset, range(testsize)), batch_size=batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True)

In [50]:
import csv
with open('data/cifar100_class_names.txt','r') as csvfile:
    csvreader = csv.reader(csvfile, delimiter=' ')
    classes_cifar100 = []
    for row in csvreader:
        if row:
            classes_cifar100.append(row[1])

classes_cifar100            

['apple',
 'aquarium_fish',
 'baby',
 'bear',
 'beaver',
 'bed',
 'bee',
 'beetle',
 'bicycle',
 'bottle',
 'bowl',
 'boy',
 'bridge',
 'bus',
 'butterfly',
 'camel',
 'can',
 'castle',
 'caterpillar',
 'cattle',
 'chair',
 'chimpanzee',
 'clock',
 'cloud',
 'cockroach',
 'couch',
 'cra',
 'crocodile',
 'cup',
 'dinosaur',
 'dolphin',
 'elephant',
 'flatfish',
 'forest',
 'fox',
 'girl',
 'hamster',
 'house',
 'kangaroo',
 'keyboard',
 'lamp',
 'lawn_mower',
 'leopard',
 'lion',
 'lizard',
 'lobster',
 'man',
 'maple_tree',
 'motorcycle',
 'mountain',
 'mouse',
 'mushroom',
 'oak_tree',
 'orange',
 'orchid',
 'otter',
 'palm_tree',
 'pear',
 'pickup_truck',
 'pine_tree',
 'plain',
 'plate',
 'poppy',
 'porcupine',
 'possum',
 'rabbit',
 'raccoon',
 'ray',
 'road',
 'rocket',
 'rose',
 'sea',
 'seal',
 'shark',
 'shrew',
 'skunk',
 'skyscraper',
 'snail',
 'snake',
 'spider',
 'squirrel',
 'streetcar',
 'sunflower',
 'sweet_pepper',
 'table',
 'tank',
 'telephone',
 'television',
 'tige

In [51]:
#Function to display the images
def plot_colour_images(images_to_plot, titles=None, nrows=None, ncols=6, thefigsize=(18,18)):
    # images_to_plot: list of images to be displayed
    # titles: list of titles corresponding to the images
    # ncols: The number of images per row to display. The number of rows 
    #        is computed from the number of images to display and the ncols
    # theFigsize: The size of the layour of all the displayed images
    
    n_images = images_to_plot.shape[0]
    
    # Compute the number of rows
    if nrows is None:
        nrows = np.ceil(n_images/ncols).astype(int)
    
    fig,ax = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=thefigsize)
    ax = ax.flatten()
    
    for i in range(n_images):
        img = images_to_plot[i,:,:,:]
        npimg = np.clip(img.numpy(),0,1)
        ax[i].imshow(npimg) 
        ax[i].axis('off')  
        if titles is not None and i<10:
            ax[i].set_title(titles[i%10])

In [52]:
# # get some random training images
# dataiter = iter(cifar100_trainloader)
# images, labels = dataiter.next()
# print(images.shape, np.unique(labels))

# images_to_plot = []
# count = 0
# selected_labels = []
# for i in range(100):
#     if count == 10:
#         break
#     index = np.where(labels==i)[0]
#     if len(index) >= 3:
#         selected_labels.append(i)
#         images_to_plot.append(images[index[0:3],:,:])
#         count += 1
    
# selected_labels = [classes_cifar100[i] for i in selected_labels]
# images_to_plot = torch.transpose(torch.stack(images_to_plot),0,1)
# new_shape = images_to_plot.shape
# images_to_plot = images_to_plot.reshape(new_shape[0]*new_shape[1], new_shape[2], new_shape[3], new_shape[4])
# images_to_plot = images_to_plot.permute(0,2,3,1)
# plot_colour_images(images_to_plot, nrows=3, ncols=10,thefigsize=(20,6), titles=selected_labels)

## Define expert and gate networks

In [54]:
# Convolutional network with one convultional layer and 2 hidden layers with ReLU activation
class expert_layers(nn.Module):
    def __init__(self, num_classes, channels=3):
        super(expert_layers, self).__init__()
        filter_size = 3
        self.filters = 16
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.filters, kernel_size=filter_size, padding=1)
        self.conv2 = nn.Conv2d(in_channels=self.filters, out_channels=self.filters*2, kernel_size=filter_size, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.filters*2)
        self.mp = nn.MaxPool2d(2,2)

        self.conv3 = nn.Conv2d(in_channels= self.filters*2, out_channels=self.filters*4, kernel_size=filter_size, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=self.filters*4, out_channels=self.filters*8, kernel_size=filter_size, stride=1, padding=1,bias=False)
        self.bn8 = nn.BatchNorm2d(self.filters*8)

        self.fc1 = nn.Linear(self.filters*8*2*2,1024)
        self.fc2 = nn.Linear(1024, 256)
        
        self.out = nn.Linear(in_features=256, out_features=num_classes)
                        
    def forward(self, x):
        # conv 1        
        x = self.mp(F.relu(self.conv1(x)))
        x = self.mp(F.relu(self.bn2(self.conv2(x))))    
    
        x = self.mp(F.relu(self.conv3(x)))
        x = self.mp(F.relu(self.bn8(self.conv4(x))))
        
        x = x.reshape(-1, self.filters*8*2*2)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        self.hidden = x
        
        x = F.relu(x)
        
        x = self.out(x)
        
        # output
        x = F.softmax(x, dim=1)

        return x    


In [19]:
class gate_attn_layers(nn.Module):
    def __init__(self, num_experts):
        super(gate_attn_layers, self).__init__()
        # define layers
       # define layers
        filter_size = 3
        self.filters = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.filters, kernel_size=filter_size, padding=1)
        self.conv2 = nn.Conv2d(in_channels=self.filters, out_channels=self.filters*2, kernel_size=filter_size, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.filters*2)
        self.mp = nn.MaxPool2d(2,2)
        
        self.conv3 = nn.Conv2d(in_channels= self.filters*2, out_channels=self.filters*4, kernel_size=filter_size, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=self.filters*4, out_channels=self.filters*8, kernel_size=filter_size, stride=1, padding=1, bias=False)
        self.bn8 = nn.BatchNorm2d(self.filters*8)

        self.fc1 = nn.Linear(self.filters*8*2*2, 1024)
        self.fc2 = nn.Linear(1024, 256)
                 
    def forward(self, x, T=1.0, y=None):
        # conv 1
        x = self.mp(F.relu(self.conv1(x)))
        x = self.mp(F.relu(self.bn2(self.conv2(x))))

        x = self.mp(F.relu(self.conv3(x)))
        x = self.mp(F.relu(self.bn8(self.conv4(x))))
        
        x = x.reshape(-1, self.filters*8*2*2)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x


In [21]:
class gate_layers(nn.Module):
    def __init__(self, num_experts):
        super(gate_layers, self).__init__()
        # define layers
        filter_size = 3
        self.filters = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.filters, kernel_size=filter_size, padding=1)
        self.conv2 = nn.Conv2d(in_channels=self.filters, out_channels=self.filters*2, kernel_size=filter_size, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.filters*2)
        self.mp = nn.MaxPool2d(2,2)
        
        self.conv3 = nn.Conv2d(in_channels= self.filters*2, out_channels=self.filters*4, kernel_size=filter_size, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=self.filters*4, out_channels=self.filters*8, kernel_size=filter_size, stride=1, padding=1, bias=False)
        self.bn8 = nn.BatchNorm2d(self.filters*8)

        self.fc1 = nn.Linear(self.filters*8*2*2, 1024)
        self.fc2 = nn.Linear(1024, 256)
        
        self.out = nn.Linear(in_features=256, out_features=num_experts)
        
    def forward(self, x, T=1.0, y=None):
        # conv 1        
        x = self.mp(F.relu(self.conv1(x)))
        x = self.mp(F.relu(self.bn2(self.conv2(x))))

        x = self.mp(F.relu(self.conv3(x)))
        x = self.mp(F.relu(self.bn8(self.conv4(x))))
        
        x = x.reshape(-1, self.filters*8*2*2)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        x = self.out(x)
        x = F.softmax(x/T, dim=1)
        return x


In [23]:
# create a set of experts
def experts(num_experts, num_classes, expert_layers_type=expert_layers):
    models = []
    for i in range(num_experts):
        models.append(expert_layers_type(num_classes))
    return nn.ModuleList(models)

## Initialize configurations and helper functions

In [24]:
# Compute accuracy of the model
def accuracy(out, yb, mean=True):
    preds = torch.argmax(out, dim=1).to(device, non_blocking=True)
    if mean:
        return (preds == yb).float().mean()
    else:
        return (preds == yb).float()

## Functions to train models

### Function to train attentive gate model with and without regularization

* w_importance_range is the range of values for the $w_{importance}$ hyperparameter of the $L_{importance}$ regularization.
* w_sample_sim_same_range is the range of values for $\beta_s$ hyperparameter of the $L_s$ regularization.
* w_sample_sim_diff_range is the range of values for $\beta_d$ hyperparameter of the $L_s$ regularization.

In [25]:
def train_with_attention(model_1, trainloader, testloader, runs, temps=[[1.0]*20], 
                         w_importance_range=[0.0], w_sample_sim_same_range=[0.0], 
                         w_sample_sim_diff_range=[0.0], 
                         num_classes=10, total_experts=5, num_epochs=20):
    
    hidden = 256
    
    for T, w_importance, w_sample_sim_same, w_sample_sim_diff in product(temps, w_importance_range, w_sample_sim_same_range,  
                                                                         w_sample_sim_diff_range):                                                                       
        
        print('w_importance','{:.1f}'.format(w_importance))
        if w_sample_sim_same < 1:
            print('w_sample_sim_same',str(w_sample_sim_same))
        else:
            print('w_sample_sim_same','{:.1f}'.format(w_sample_sim_same))
        
        if w_sample_sim_diff < 1:
            print('w_sample_sim_diff',str(w_sample_sim_diff))
        else:
            print('w_sample_sim_diff','{:.1f}'.format(w_sample_sim_diff))
        
        for run in range(1, runs+1):

            print('Run:', run)
            
            n_run_models_1 = []
            
            models = {'moe_expectation_model':{'model':moe_expectation_model,'loss':cross_entropy_loss().to(device),
                                               'experts':{}},}
            for key, val in models.items():

                expert_models = experts(total_experts, num_classes).to(device)
                
                gate_model = gate_attn_layers(total_experts).to(device)

                moe_model = val['model'](total_experts, num_classes, attention_flag=1, hidden=hidden, 
                                         experts=expert_models, gate=gate_model, device=device).to(device)
                
                optimizer_moe = optim.Adam(moe_model.parameters(), lr=0.001, amsgrad=False)
                

                hist = moe_model.train(trainloader, testloader,  val['loss'], optimizer_moe = optimizer_moe,
                                       T = T, w_importance=w_importance, 
                                       accuracy=accuracy, epochs=num_epochs)
                val['experts'][total_experts] = {'model':moe_model, 'history':hist}

            # Save all the trained models
            plot_file = generate_plot_file(model_1, T[0], w_importance=w_importance, w_sample_sim_same=w_sample_sim_same,w_sample_sim_diff=w_sample_sim_diff,
                                           specific=str(num_classes)+'_'+str(total_experts)+'_models.pt')
            
            if os.path.exists(os.path.join(model_path, plot_file)):
                n_run_models_1 = torch.load(open(os.path.join(model_path, plot_file),'rb'))
            n_run_models_1.append(models)
            torch.save(n_run_models_1,open(os.path.join(model_path, plot_file),'wb'))
            n_run_models_1 = []
            

### Function to distill the attentive gate model to the original model

In [26]:
def train_from_model(m, num_epochs, num_classes, total_experts, w_importance_range=[0.0], 
                     w_sample_sim_same_range=[0.0], w_sample_sim_diff_range=[0.0],
                     trainloader=None, testloader=None, expert_no_grad=True, gate_no_grad=False):
    
    T = [1.0]*num_epochs
    for w_importance, w_sample_sim_same, w_sample_sim_diff in product(w_importance_range, w_sample_sim_same_range, w_sample_sim_diff_range):
        
        print('w_importance','{:.1f}'.format(w_importance))
        
        if w_sample_sim_same < 1:
            print('w_sample_sim_same',str(w_sample_sim_same))
        else:
            print('w_sample_sim_same','{:.1f}'.format(w_sample_sim_same))
        
        if w_sample_sim_diff < 1:
            print('w_sample_sim_diff',str(w_sample_sim_diff))
        else:
            print('w_sample_sim_diff','{:.1f}'.format(w_sample_sim_diff))
        
        plot_file = generate_plot_file(m, temp=T[0], w_importance=w_importance,  
                                       w_sample_sim_same=w_sample_sim_same,w_sample_sim_diff=w_sample_sim_diff,
                                       specific=str(num_classes)+'_'+str(total_experts)+'_models.pt')

        attn_models = torch.load(open(os.path.join(model_path, plot_file),'rb'), map_location=device)

        n_run_models_1 = []
        for model in attn_models: 
            # Initialise the new expert weights to the weights of the experts of the trained attentive gate model.
            # Fix all the weights of the new experts so they are not trained. 

            new_expert_models = experts(total_experts, num_classes).to(device)
            old_expert_models = model['moe_expectation_model']['experts'][total_experts]['model'].experts
            for i, expert in enumerate(new_expert_models):
                old_expert = old_expert_models[i]
                expert.load_state_dict(old_expert.state_dict())
                if expert_no_grad:
                    for param in expert.parameters():
                        param.requires_grad = False

            new_gate_model = gate_layers(total_experts).to(device)
            old_gate_model = model['moe_expectation_model']['experts'][total_experts]['model'].gate
            new_gate_model.load_state_dict(old_gate_model.state_dict(), strict=False)

            if gate_no_grad:
                for param in new_gate_model.parameters():
                    param.requires_grad = False
                new_gate_model.out = nn.Linear(in_features=32, out_features=num_experts)
                
            gate_model = new_gate_model

            models = {'moe_expectation_model':{'model':moe_expectation_model,'loss':cross_entropy_loss().to(device),
                                           'experts':{}},}

            for key, val in models.items():

                # gate_model = gate_layers(total_experts).to(device)                

                moe_model = val['model'](total_experts, num_classes,
                                         experts=new_expert_models, gate= gate_model, device=device).to(device)

                optimizer_moe = optim.Adam(moe_model.parameters(), lr=0.001, amsgrad=False)


                hist = moe_model.train(trainloader, testloader,  val['loss'], optimizer_moe = optimizer_moe,
                                       T = T, accuracy=accuracy, epochs=num_epochs)
                val['experts'][total_experts] = {'model':moe_model, 'history':hist}

            plot_file = generate_plot_file('new_'+m, T[0], w_importance=w_importance, w_sample_sim_same=w_sample_sim_same,w_sample_sim_diff=w_sample_sim_diff,
                                           specific=str(num_classes)+'_'+str(total_experts)+'_models.pt')
        
            if os.path.exists(os.path.join(model_path, plot_file)):
                n_run_models_1 = torch.load(open(os.path.join(model_path, plot_file),'rb'))
                
            n_run_models_1.append(models)                                
            torch.save(n_run_models_1,open(os.path.join(model_path, plot_file),'wb'))
            n_run_models_1 = []
            print(plot_file)        


## Experiments

### Experiment 1: Attentive gate MoE model training

In [60]:
# Model with gate and expert parameters initialized to default values
model_1 = 'cifar100_with_attention'

In [61]:
total_experts = 20

In [62]:
num_epochs = 40

In [63]:
runs = 5

In [None]:
train_with_attention(model_1, cifar100_trainloader, cifar100_testloader, runs, num_classes=num_classes, 
                 total_experts=total_experts, num_epochs=num_epochs)

### Experiment 2: Attentive gate MoE model training with $L_{importance}$ regularization

In [377]:
# Model with gate and expert parameters initialized to default values
model_2 = 'cifar100_with_attn_reg'

In [378]:
total_experts = 20

In [379]:
num_epochs = 40

In [55]:
w_importance_range = [i * 0.2 for i in range(1, 6)]
print('w_importance_range = ', ['{:.1f}'.format(w) for w in w_importance_range])

w_importance_range =  ['0.2', '0.4', '0.6', '0.8', '1.0']


In [381]:
temps = [[1.0]*num_epochs]

In [382]:
runs = 5

In [383]:
train_with_attention(model_2, cifar100_trainloader, cifar100_testloader, runs, temps=temps, 
                     w_importance_range=w_importance_range, 
                     num_classes=num_classes, total_experts=total_experts, num_epochs=num_epochs)

w_importance 0.2
w_sample_sim_same 0.0
w_sample_sim_diff 0.0
Run: 1
epoch 0 training loss 3.83 , training accuracy 0.13 , test accuracy 0.18
epoch 1 training loss 3.24 , training accuracy 0.23 , test accuracy 0.25
epoch 2 training loss 2.97 , training accuracy 0.28 , test accuracy 0.28
epoch 3 training loss 2.79 , training accuracy 0.32 , test accuracy 0.32
epoch 4 training loss 2.66 , training accuracy 0.34 , test accuracy 0.35
epoch 5 training loss 2.54 , training accuracy 0.37 , test accuracy 0.37
epoch 6 training loss 2.44 , training accuracy 0.39 , test accuracy 0.38
epoch 7 training loss 2.37 , training accuracy 0.41 , test accuracy 0.40
epoch 8 training loss 2.29 , training accuracy 0.42 , test accuracy 0.41
epoch 9 training loss 2.24 , training accuracy 0.44 , test accuracy 0.42
epoch 10 training loss 2.19 , training accuracy 0.45 , test accuracy 0.43
epoch 11 training loss 2.12 , training accuracy 0.47 , test accuracy 0.44
epoch 12 training loss 2.08 , training accuracy 0.47 ,

### Experiment 3: Attentive gate MoE model training with sample similarity $L_s$ regularization

In [368]:
# Model with gate and expert parameters initialized to default values
model_3 = 'cifar100_with_attn_reg'

In [369]:
total_experts = 20

In [370]:
num_epochs = 40

In [56]:
w_sample_sim_same_range = [1e-5,1e-4,1e-3]
w_sample_sim_diff_range = [1e-7, 1e-6, 1e-5,1e-4,1e-3,1e-2,1e-1]
print('w_sample_sim_same_range = ', w_sample_sim_same_range)
print('w_sample_sim_diff_range = ', w_sample_sim_diff_range)

w_sample_sim_same_range =  [1e-05, 0.0001, 0.001]
w_sample_sim_diff_range =  [1e-07, 1e-06, 1e-05, 0.0001, 0.001, 0.01, 0.1]


In [372]:
runs = 5

In [373]:
train_with_attention(model_3, cifar100_trainloader, cifar100_testloader, runs, temps, 
                     w_sample_sim_same_range=w_sample_sim_same_range, w_sample_sim_diff_range=w_sample_sim_diff_range, 
                     num_classes=num_classes, total_experts=total_experts, num_epochs=num_epochs)

w_importance 0.0
w_sample_sim_same 0.0001
w_sample_sim_diff 1e-06
Run: 1
epoch 0 training loss 3.89 , training accuracy 0.11 , test accuracy 0.15
epoch 1 training loss 3.31 , training accuracy 0.20 , test accuracy 0.22
epoch 2 training loss 3.04 , training accuracy 0.25 , test accuracy 0.27
epoch 3 training loss 2.86 , training accuracy 0.29 , test accuracy 0.29
epoch 4 training loss 2.73 , training accuracy 0.32 , test accuracy 0.33
epoch 5 training loss 2.62 , training accuracy 0.34 , test accuracy 0.33
epoch 6 training loss 2.53 , training accuracy 0.36 , test accuracy 0.36
epoch 7 training loss 2.45 , training accuracy 0.38 , test accuracy 0.37
epoch 8 training loss 2.38 , training accuracy 0.40 , test accuracy 0.39
epoch 9 training loss 2.33 , training accuracy 0.41 , test accuracy 0.40
epoch 10 training loss 2.27 , training accuracy 0.42 , test accuracy 0.41
epoch 11 training loss 2.22 , training accuracy 0.43 , test accuracy 0.41
epoch 12 training loss 2.18 , training accuracy 0

### Experiment 4: Distilling attentive gate MoE model

#### Distilling attentive gate MoE model with $L_{importance}$ regularization

In [None]:
m = 'cifar100_with_attn_reg'

total_experts = 20
num_classes = 100
num_epochs = 40

w_importance_range = [i * 0.2 for i in range(1, 6)]
print('w_importance_range = ', ['{:.1f}'.format(w) for w in w_importance_range])

train_from_model(m, num_epochs, num_classes, total_experts, w_importance_range=w_importance_range, 
                 trainloader=cifar100_trainloader, testloader=cifar100_testloader)

#### Distilling attentive gate MoE model with $L_{s}$ regularization

In [None]:
m = 'cifar100_with_attn_reg'

total_experts = 20
num_classes = 100
num_epochs = 40

w_sample_sim_same_range = [1e-5,1e-4,1e-3]
w_sample_sim_diff_range = [1e-7,1e-6,1e-5,1e-4,1e-3,1e-2,1e-1]
print('w_sample_sim_same_range = ', w_sample_sim_same_range)
print('w_sample_sim_diff_range = ', w_sample_sim_diff_range)

train_from_model(m, num_epochs, num_classes, total_experts,
                 w_sample_sim_same_range=w_sample_sim_same_range,
                 w_sample_sim_diff_range=w_sample_sim_diff_range,
                 trainloader=cifar100_trainloader, testloader=cifar100_testloader)

w_sample_sim_same_range =  [1e-05, 0.001]
w_sample_sim_diff_range =  [1e-07, 1e-06, 1e-05, 0.0001, 0.001, 0.01, 0.1]
w_importance 0.0
w_sample_sim_same 1e-05
w_sample_sim_diff 1e-07
epoch 0 training loss 2.63 , training accuracy 0.36 , test accuracy 0.36
epoch 1 training loss 2.22 , training accuracy 0.43 , test accuracy 0.38
epoch 2 training loss 2.10 , training accuracy 0.45 , test accuracy 0.41
epoch 3 training loss 2.01 , training accuracy 0.47 , test accuracy 0.42
epoch 4 training loss 1.96 , training accuracy 0.48 , test accuracy 0.43
epoch 5 training loss 1.91 , training accuracy 0.50 , test accuracy 0.44
epoch 6 training loss 1.86 , training accuracy 0.51 , test accuracy 0.44
epoch 7 training loss 1.81 , training accuracy 0.52 , test accuracy 0.45
epoch 8 training loss 1.77 , training accuracy 0.53 , test accuracy 0.46
epoch 9 training loss 1.74 , training accuracy 0.53 , test accuracy 0.45
epoch 10 training loss 1.70 , training accuracy 0.54 , test accuracy 0.46
epoch 11 train

## Results

### Collect the train error, test error, mutual information $I(E;Y)$, sample entropy $H_s$ and expert usage entropy $H_u$ for all the models trained with CIFAR-100 dataset. Store the reuslts in the '../results/cifar100_results.csv' file.

In [27]:
pretrained_model_path = '../models/hidden_256'

In [30]:
results_file = 'cifar100_results_hidden_256.csv'

In [31]:
import sys
sys.path.append('../src')

In [32]:
import csv
from helper import moe_models

def collect_results(m, temps=[1.0], w_importance_range=[0.0], 
                    w_sample_sim_same_range=[0.0], w_sample_sim_diff_range=[0.0],
                    total_experts=5, num_classes=10, num_epochs=20, 
                    testloader=None, model_path=None, results_path=None, filename = results_file):
    
    filename = os.path.join(results_path, filename)
    
    if os.path.exists(filename):
        p = 'a'
    else:
        p = 'w'
        
    header = ['filename', 'train error', 'test error','mutual information', 'sample entropy', 'experts usage']
    
    with open(filename, p) as f:
                
        writer = csv.writer(f)        
        
        if p == 'w':            
            writer.writerow(header)
        
        for w_importance, w_sample_sim_same, w_sample_sim_diff in product(w_importance_range, w_sample_sim_same_range, w_sample_sim_diff_range):
            plot_file = generate_plot_file(m, w_importance=w_importance, w_sample_sim_same=w_sample_sim_same, w_sample_sim_diff=w_sample_sim_diff, 
                                           specific=str(num_classes)+'_'+str(total_experts)+'_models.pt')

            models = torch.load(open(os.path.join(model_path, plot_file),'rb'), map_location=device)
            for _ in range(len(models)):
                mod = models.pop()
                data = [plot_file]
                # model
                model = mod['moe_expectation_model']['experts'][total_experts]['model']
                # history
                history = mod['moe_expectation_model']['experts'][total_experts]['history']
                # train error
                data.append(1-history['accuracy'][-1].item())
                running_test_accuracy = 0.0
                running_entropy = 0.0
                num_batches = 0
                ey =  torch.zeros((num_classes, total_experts)).to(device)
                for test_inputs, test_labels in testloader:
                    test_inputs, test_labels = test_inputs.to(device, non_blocking=True), test_labels.to(device, non_blocking=True)                
                    outputs = model(test_inputs)
                    running_test_accuracy += accuracy(outputs, test_labels)
                    
                    selected_experts = torch.argmax(model.gate_outputs, dim=1)
                    y = test_labels
                    e = selected_experts
                    for j in range(y.shape[0]):
                        ey[int(torch.argmax(model.expert_outputs[j,e[j],:])), int(e[j])] += 1

                    running_entropy += moe_models.entropy(model.gate_outputs)
                    
                    num_batches+=1
 
                mutual_EY,_,_,_ = moe_models.mutual_information(ey.detach())
    
                test_error = 1-(running_test_accuracy/num_batches)
                data.append(test_error.item())
                data.append(mutual_EY.item())
                
                data.append(running_entropy.item()/num_batches)  
                gate_probabilities_sum = torch.mean(model.gate_outputs, dim=0)            
                data.append(entropy(gate_probabilities_sum).item())
                
                writer.writerow(data)
            
            

In [33]:
total_experts = 20
num_classes = 100
num_epochs = 40

In [27]:
model_path = os.path.join(pre_trained_model_path,'cifar10')
results_path = os.path.join(results_path,'test')

In [411]:
m = 'cifar100_without_reg'
collect_results(m, total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

In [72]:
m = 'cifar100_with_reg'
w_importance_range = [i * 0.2 for i in range(1, 6)]
collect_results(m, w_importance_range=w_importance_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

device cuda:0
Files already downloaded and verified
Files already downloaded and verified


In [32]:
m = 'cifar100_with_reg'
w_sample_sim_same_range = [1e-5]
w_sample_sim_diff_range = [1e-7, 1e-6, 1e-5,1e-4,1e-3,1e-2,1e-1]
collect_results(m, w_sample_sim_same_range=w_sample_sim_same_range, w_sample_sim_diff_range=w_sample_sim_diff_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

In [412]:
m = 'cifar100_with_attention'
collect_results(m, total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

In [30]:
m = 'cifar100_with_attn_reg'
w_importance_range = [i * 0.2 for i in range(1, 6)]
collect_results(m, w_importance_range=w_importance_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

In [31]:
m = 'cifar100_with_attn_reg'
w_sample_sim_same_range = [1e-5,1e-4,1e-3]
w_sample_sim_diff_range = [1e-7, 1e-6, 1e-5,1e-4,1e-3,1e-2,1e-1]
collect_results(m, w_sample_sim_same_range=w_sample_sim_same_range, w_sample_sim_diff_range=w_sample_sim_diff_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

device cuda:0


In [60]:
m = 'new_cifar100_with_attn_reg'
w_importance_range = [i * 0.2 for i in range(1, 6)]
w_importance_range = [1.0]
collect_results(m, w_importance_range=w_importance_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

In [49]:
m = 'new_cifar100_with_attn_reg'
w_sample_sim_same_range = [1e-7, 1e-6]
w_sample_sim_same_range = [1e-4]
w_sample_sim_diff_range = [1e-6, 1e-5,1e-4,1e-3,1e-2,1e-1]
w_sample_sim_diff_range = [1e-7]
collect_results(m, w_sample_sim_same_range=w_sample_sim_same_range, w_sample_sim_diff_range=w_sample_sim_diff_range,
                total_experts=total_experts, num_classes=num_classes, num_epochs=num_epochs, 
                testloader=cifar100_testloader, model_path=pretrained_model_path, results_path=results_path)

### Final results

#### These are the final performance results, reported in the paper, on the test data for the model with the minimum training error for each category of MoE model and training method.

In [34]:
import pandas as pd
filename = os.path.join(results_path, results_file)
data = pd.read_csv(filename)

In [35]:
final_results_filename = os.path.join(results_path,'cifar100_final_results_hidden_256.csv')

In [36]:
models = ['cifar100_single_model', 'cifar100_without_reg', 'cifar100_with_reg_importance', 
          'cifar100_with_reg_sample', 'cifar100_with_attention', 'cifar100_with_attn_reg_importance',
          'cifar100_with_attn_reg_sample', 'new_cifar100_with_attn_reg_importance',
          'new_cifar100_with_attn_reg_sample'] 
# 'cifar100_single_model', 
header = True
for i, m in enumerate(models):
    model_data = data[data['filename'].str.startswith(m)]
    if model_data.empty:
        continue
    if i:
        header = False
    row = data.loc[[data[data['filename'].str.startswith(m)]['train error'].idxmin()]]
    # row = data.loc[[data[data['filename'].str.startswith(m)]['test error'].idxmin()]]
    # row = data.loc[[data[data['filename'].str.startswith(m)]['experts usage'].idxmax()]]
    std = '{:.3f}'.format(data[data['filename'].str.startswith(m)]['test error'].std())
    row['val error std'] = std
    row.to_csv(final_results_filename,mode='a',header=header, index=False, float_format='%.3f')

In [37]:
data = pd.read_csv(final_results_filename)
data.fillna('NA', inplace=True)
data

Unnamed: 0,filename,train error,test error,mutual information,sample entropy,experts usage,val error std
0,cifar100_single_model_2,0.398,0.477,,,,0.006
1,cifar100_without_reg_100_20_models.pt,0.398,0.46,0.967,0.306,1.023,0.01
2,cifar100_with_reg_importance_0.2_100_20_models.pt,0.408,0.483,4.177,1.315,3.981,0.007
3,cifar100_with_reg_sample_sim_same_1e-05_sample...,0.393,0.457,1.424,0.381,1.279,0.012
4,cifar100_with_attention_100_20_models.pt,0.254,0.45,1.792,0.463,2.178,0.006
5,cifar100_with_attn_reg_importance_0.2_100_20_m...,0.254,0.451,3.684,1.036,4.141,0.005
6,cifar100_with_attn_reg_sample_sim_same_0.0001_...,0.238,0.456,3.126,0.773,3.252,0.016
7,new_cifar100_with_attn_reg_importance_0.2_100_...,0.399,0.531,3.179,1.75,3.843,0.131
8,new_cifar100_with_attn_reg_sample_sim_same_0.0...,0.15,0.482,1.605,0.718,2.541,0.065
