# Model comparison using MNIST

Shun Li, 03/07/24

## Initialize and load dataset

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from EPLHb import EPLHb, gd, adam

import numpy as np
from scipy import stats

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import rcParams

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
rcParams['figure.autolayout'] = True

In [6]:
# Downloading MNIST data

train_data = datasets.MNIST(root = './data', train = True,
                        transform = transforms.ToTensor(), download = True)

test_data = datasets.MNIST(root = './data', train = False,
                       transform = transforms.ToTensor())

# Loading the data
batch_size = 100 # the size of input data took for one iteration

train_loader = torch.utils.data.DataLoader(dataset = train_data,batch_size = batch_size,shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_data,batch_size = batch_size,shuffle = False)

## Models to be tested

Different initialization scenarios
- random initialization of every synapses
- Dale's law initialization of every synapses

Different network structure
- LHb to DAN is all inhibitory + LHb to LHb is all excitatory (if RNN)
- Every layer have mixed excitatory/inhibitory output

Different update methods
- Normal ADAM
- Fixed sign ADAM

In [7]:
initialization = ['random','dales_law']
network_struct = ['real','mixed']
update_methods = ['corelease','fixed_sign']

LHb_network = ['MLP','RNN']

In [8]:
EP_size = 784 # img_size = (28,28) ---> 28*28=784 in total
LHb_size = 500 # number of nodes at hidden layer
DAN_size = 10 # number of output classes discrete range [0,9]
num_epochs = 10 # 20 # number of times which the entire dataset is passed throughout the model
lr = 1e-3 # size of step

prob_EP_to_LHb = 1
prob_LHb_to_LHb = 1
prob_LHb_to_DAN = 1

n_networks = 2 # number of networks to train

## Train models

In [9]:
training_loss_summary, test_accuracy_summar = {}, {}

for LHb in LHb_network:
    for init in initialization:
        for struct in network_struct:
            for method in update_methods:
                print('LHb: ',LHb, '; Initialization:',init,'; Network:',struct,'; Method:',method)
                
                # Initialize network-specific loss and accuracy summary
                network_training_loss, network_test_accuracy = [], []

                # Initialize network params
                if LHb == 'MLP': rnn = False
                else: rnn = True
                if init == 'random': fixed_sign_init = False
                else: fixed_sign_init = True
                if struct == 'real': real_circuit = True
                else: real_circuit = False
                if method == 'corelease': fixed_sign_update = False
                else: fixed_sign_update = True

                # Train n_networks networks
                for i in range(1,n_networks+1):
                    # Initialize a network
                    net = EPLHb(EP_size,LHb_size,DAN_size,
                                rnn=rnn,fixed_sign=fixed_sign_init,real_circuit=real_circuit,
                                prob_EP_to_LHb=prob_EP_to_LHb,prob_LHb_to_LHb=prob_LHb_to_LHb,prob_LHb_to_DAN=prob_LHb_to_DAN)
                    initial_params = net.record_params(calc_sign=False)
                    training_loss, test_accuracy = [], []
                    if torch.cuda.is_available(): net.cuda()

                    # Train on original data
                    optimizer = adam(net.parameters(), lr=lr, fixed_sign=fixed_sign_update)
                    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
                    training_loss, test_accuracy = net.train_model(num_epochs,train_loader,optimizer,
                                                    test_loader=test_loader,print_epoch=False,loss='CrossEntropyLoss')
                    training_loss.extend(training_loss)
                    test_accuracy.extend(test_accuracy)

                    # Train on flipped data
                    # optimizer = adam(net.parameters(), lr=lr, fixed_sign=fixed_sign_update)
                    # training_loss = net.train_model(num_epochs,flip_loader,optimizer,print_epoch=False)
                    # net_training_loss.extend(training_loss)

                    network_training_loss.append(training_loss)
                    network_test_accuracy.append(test_accuracy)
                    print('Finished training network %d/%d' %(i,n_networks))

                # Convert list to numpy array
                network_training_loss = np.array(network_training_loss)
                network_test_accuracy = np.array(network_test_accuracy)

                # Store name and stats of network to summary
                network_name = LHb+'_'+init+'_'+struct+'_'+method
                training_loss_summary[network_name] = network_training_loss
                test_accuracy_summar[network_name] = network_test_accuracy


LHb:  MLP ; Initialization: random ; Network: real ; Method: corelease
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: random ; Network: real ; Method: fixed_sign
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: random ; Network: mixed ; Method: corelease
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: random ; Network: mixed ; Method: fixed_sign
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: dales_law ; Network: real ; Method: corelease
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: dales_law ; Network: real ; Method: fixed_sign
Finished training co-release network 1/2
Finished training co-release network 2/2
LHb:  MLP ; Initialization: dales_law ; Network: mixed ; Method: corele

## Summary analysis

In [13]:
for network in training_loss_summary:
    print(network)
    print('Training Loss:', np.mean(training_loss_summary[network],axis=0))
    print('Test Accuracy:', np.mean(test_accuracy_summar[network],axis=0))

MLP_random_real_corelease
Training Loss: [2.296679   2.2139473  2.1373115  ... 0.8231911  0.8318355  0.80935943]
Test Accuracy: [17.4      89.399994 90.74     91.61     92.285    92.565    92.955
 93.244995 93.73     94.195    94.229996 94.64     94.96     95.085
 95.325    95.325    95.405    95.66     95.64     95.835    96.05
 96.195    96.119995 96.175    96.46     96.44     96.565    96.7
 96.770004 96.68     96.88     96.880005 96.97     96.92     96.979996
 97.       97.185    97.134995 97.155    97.155    97.16     97.275
 97.25     97.295    97.350006 97.415    97.369995 97.41     97.405
 97.475006 97.405    97.57     97.42     97.47     97.545    97.54
 97.61     97.595    97.485    97.53     17.4      89.399994 90.74
 91.61     92.285    92.565    92.955    93.244995 93.73     94.195
 94.229996 94.64     94.96     95.085    95.325    95.325    95.405
 95.66     95.64     95.835    96.05     96.195    96.119995 96.175
 96.46     96.44     96.565    96.7      96.770004 96.68  

In [10]:
# # Loss
# mean_cr_loss = np.mean(cr_training_loss_summary,axis=0)
# sem_cr_loss = stats.sem(cr_training_loss_summary)
# mean_fs_loss = np.mean(fs_training_loss_summary,axis=0)
# sem_fs_loss = stats.sem(fs_training_loss_summary)
# mean_fs_posneg_loss = np.mean(fs_posneg_training_loss_summary,axis=0)
# sem_fs_posneg_loss = stats.sem(fs_posneg_training_loss_summary)

# # Test accuracy
# mean_cr_accuracy = np.mean(cr_test_accuracy_summary,axis=0)
# sem_cr_accuracy = stats.sem(cr_test_accuracy_summary)
# mean_fs_accuracy = np.mean(fs_test_accuracy_summary,axis=0)
# sem_fs_accuracy = stats.sem(fs_test_accuracy_summary)
# mean_fs_posneg_accuracy = np.mean(fs_posneg_test_accuracy_summary,axis=0)
# sem_fs_posneg_accuracy = stats.sem(fs_posneg_test_accuracy_summary)

# # Plot
# fig, axs = plt.subplots(1,2,figsize=(15, 5))

# # Plot loss
# x = np.linspace(1,mean_cr_loss.shape[0],num=mean_cr_loss.shape[0],dtype='int32')
# axs[0].plot(mean_cr_loss, label='Co-release')
# axs[0].fill_between(x,mean_cr_loss+sem_cr_loss,mean_cr_loss-sem_cr_loss,alpha=0.2)
# axs[0].plot(mean_fs_loss, label='Fixed sign')
# axs[0].fill_between(x,mean_fs_loss+sem_fs_loss,mean_fs_loss-sem_fs_loss,alpha=0.2)
# axs[0].plot(mean_fs_posneg_loss, label='Fixed sign without neg output')
# axs[0].fill_between(x,mean_fs_posneg_loss+sem_fs_posneg_loss,mean_fs_posneg_loss-sem_fs_posneg_loss,alpha=0.2)

# axs[0].set_xlabel('Trianing epochs')
# axs[0].set_ylabel('Training loss')
# axs[0].legend()

# # Plot accuracy
# x = np.linspace(1,mean_cr_accuracy.shape[0],num=mean_cr_accuracy.shape[0],dtype='int32')
# axs[0].plot(mean_cr_accuracy, label='Co-release')
# axs[0].fill_between(x,mean_cr_accuracy+sem_cr_accuracy,mean_cr_accuracy-sem_cr_accuracy,alpha=0.2)
# axs[0].plot(mean_fs_accuracy, label='Fixed sign')
# axs[0].fill_between(x,mean_fs_accuracy+sem_fs_accuracy,mean_fs_accuracy-sem_fs_accuracy,alpha=0.2)
# axs[0].plot(mean_fs_posneg_accuracy, label='Fixed sign without neg output')
# axs[0].fill_between(x,mean_fs_posneg_accuracy+sem_fs_posneg_accuracy,mean_fs_posneg_accuracy-sem_fs_posneg_accuracy,alpha=0.2)

# axs[0].set_xlabel('Trianing epochs')
# axs[0].set_ylabel('Training loss')
# axs[0].legend()