Initialize a shallow feedforward fully-connected network with V1-type weights and classify the __full MNIST__ dataset using __stochastic gradient descent__.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from data_fns import load_mnist
from estimator import classical_weights, V1_inspired_weights
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import pickle

In [None]:
# Read the MNIST arrays from disk and stage everything on the GPU up front.
train, train_labels, test, test_labels = load_mnist('./data/mnist/')

# Inputs are cast to float tensors, labels to long tensors; all are moved to 'cuda'.
X_train, X_test = (torch.from_numpy(arr).float().to('cuda') for arr in (train, test))
y_train, y_test = (torch.from_numpy(arr).long().to('cuda') for arr in (train_labels, test_labels))

# X_train is unpacked as a 2-D matrix: n rows, d columns.
# d is later used as the input width of every network's first layer.
n, d = X_train.size()

# batch_size=1 with shuffling: the loader yields one randomly ordered sample at a time.
train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

In [None]:
class V1_net(nn.Module):
    """One-hidden-layer ReLU classifier with a V1-inspired first-layer initialisation.

    The first layer maps the module-level input width ``d`` to ``hidden_size``
    units; its weight matrix is overwritten with the output of
    ``V1_inspired_weights(hidden_size, d, t=5, l=3, scale=scale)``. The bias of
    that layer and the whole 10-way output layer keep PyTorch's default
    initialisation.

    Args:
        hidden_size: number of hidden units.
        scale: forwarded unchanged to ``V1_inspired_weights``.
    """

    def __init__(self, hidden_size, scale):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        # NOTE(review): assumes V1_inspired_weights returns a (hidden_size, d)
        # array-like so it can stand in for fc1.weight -- confirm in estimator.
        v1_init = V1_inspired_weights(hidden_size, d, t=5, l=3, scale=scale)
        self.fc1.weight.data = torch.FloatTensor(v1_init)
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return the 10 output-layer activations for each row of ``inputs``."""
        hidden = nn.functional.relu(self.fc1(inputs))
        logits = self.output(hidden)
        return logits
    
class He_net(nn.Module):
    """One-hidden-layer ReLU classifier with Kaiming-normal first-layer weights.

    The first layer maps the module-level input width ``d`` to ``hidden_size``
    units and is re-initialised in place with ``kaiming_normal_``. The 10-way
    output layer keeps PyTorch's default initialisation.

    Args:
        hidden_size: number of hidden units.
        scale: accepted so the constructor signature matches the other
            networks in this file; it is not used here.
    """

    def __init__(self, hidden_size, scale):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return the 10 output-layer activations for each row of ``inputs``."""
        hidden = nn.functional.relu(self.fc1(inputs))
        logits = self.output(hidden)
        return logits
    
class RF_net(nn.Module):
    """One-hidden-layer ReLU classifier initialised from ``classical_weights``.

    The first layer maps the module-level input width ``d`` to ``hidden_size``
    units; its weight matrix is overwritten with the output of
    ``classical_weights(hidden_size, d, scale=scale)``. The bias of that layer
    and the whole 10-way output layer keep PyTorch's default initialisation.

    Args:
        hidden_size: number of hidden units.
        scale: forwarded unchanged to ``classical_weights``.
    """

    def __init__(self, hidden_size, scale):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        # NOTE(review): assumes classical_weights returns a (hidden_size, d)
        # array-like so it can stand in for fc1.weight -- confirm in estimator.
        rf_init = classical_weights(hidden_size, d, scale=scale)
        self.fc1.weight.data = torch.FloatTensor(rf_init)
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return the 10 output-layer activations for each row of ``inputs``."""
        hidden = nn.functional.relu(self.fc1(inputs))
        logits = self.output(hidden)
        return logits

In [None]:
def predict(model, X):
    """Return, for each row of ``X``, the index of the largest model output.

    Args:
        model: callable mapping ``X`` to a 2-D tensor of per-class scores.
        X: batch of inputs passed straight to ``model``.

    Returns:
        1-D tensor of indices taken along dimension 1 of the model output.
    """
    scores = model(X).data  # .data drops the autograd history of the output
    _, labels = scores.max(1)
    return labels

def error(model, X, y):
    """Return the misclassification rate of ``model`` on ``(X, y)``.

    Args:
        model: network whose predictions are obtained through ``predict``.
        X: batch of inputs.
        y: tensor of true labels, one per row of ``X``.

    Returns:
        0-dim tensor equal to ``1 - (number of correct predictions / len(y))``.
    """
    n_correct = torch.sum(predict(model, X) == y)
    accuracy = 1.0 * n_correct / len(y)  # 1.0 * ... promotes the count to float
    return 1 - accuracy

In [None]:
# Experiment grid: every (model, hidden size, learning rate) combination is
# trained n_trials times for n_epochs epochs.
n_trials, n_epochs = 5, 10
models = {'V1': V1_net, 'He': He_net}
h_list = [50, 100, 400, 1000]
lr_list = [1e-3, 1e-2, 1e-1, 1e0]


def _empty_results():
    """Build a fresh results[model][h][lr] -> zeros((n_trials, n_epochs)) table."""
    table = {}
    for m in models:
        table[m] = {}
        for h in h_list:
            table[m][h] = {lr: np.zeros((n_trials, n_epochs)) for lr in lr_list}
    return table


# Three independent tables (no shared arrays), one per recorded metric.
train_err = _empty_results()
test_err = _empty_results()
loss_list = _empty_results()

In [None]:
import os

# Weight scale handed to every network constructor; t and l are only used
# here to name the output file.
scale = 2/d
t, l = 5, 3
loss_func = nn.CrossEntropyLoss()

for h in h_list:
    for lr in lr_list:
        for m, network in models.items():
            for i in range(n_trials):
                # Each trial starts from a freshly initialised model and optimizer.
                model = network(h, scale).to('cuda')
                optim = torch.optim.SGD(model.parameters(), lr=lr)
                for j in range(n_epochs):
                    # One full pass over the training loader.
                    for x_batch, y_batch in train_loader:
                        optim.zero_grad()
                        loss = loss_func(model(x_batch), y_batch)
                        loss.backward()
                        optim.step()

                    # End-of-epoch evaluation on the full train and test sets.
                    # no_grad() avoids building an autograd graph over the whole
                    # training set; .item() turns each 0-dim (GPU) tensor into a
                    # Python float so it can be stored in the NumPy result
                    # arrays without relying on implicit tensor->NumPy coercion.
                    with torch.no_grad():
                        train_err[m][h][lr][i, j] = error(model, X_train, y_train).item()
                        test_err[m][h][lr][i, j] = error(model, X_test, y_test).item()
                        loss_list[m][h][lr][i, j] = loss_func(model(X_train), y_train).item()

                    # Progress line, printed every epoch.
                    print('Trial %d, Epoch: %d, %s model, h=%d, lr=%0.5f, Loss=%0.5f, test err=%0.3f' % (
                        i, j, m, h, lr, loss_list[m][h][lr][i, j], test_err[m][h][lr][i, j]))

# Everything is written once, after the whole grid has finished.
results = {'test_err': test_err, 'train_err': train_err, 'loss': loss_list}
out_path = 'results/initialize_mnist/full_data_SGD/clf_t=%0.2f_l=%0.2f.pickle' % (t, l)
# Create the output folder if needed so a long run is not lost at the final write.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

# plot results

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt

# t and l select which results file to read; models, h_list and lr_list must
# match the keys that were stored in that file.
t, l = 5, 3
models = ['V1', 'He']
h_list = [50, 100, 400, 1000]
lr_list = [1e-3, 1e-2, 1e-1, 1e0]

# One entry per panel:
# (subplot position, results key, panel title, y-axis label, call plt.yticks?)
panels = [
    (131, 'loss', 'Network loss', 'Training loss', False),
    (132, 'train_err', 'Train error', 'Training error', True),
    (133, 'test_err', 'Test error', 'Test error', True),
]


def _plot_panel(fig, position, runs_by_model, title, ylabel, set_yticks):
    """Draw one panel: mean curve with a +/- one-std band for each model.

    Args:
        fig: figure that receives the new subplot.
        position: three-digit subplot code passed to ``fig.add_subplot``.
        runs_by_model: dict mapping model name to an array whose axis 0 is
            averaged over and whose remaining axis is plotted against epoch.
        title: panel title.
        ylabel: y-axis label.
        set_yticks: when true, call ``plt.yticks(np.arange(0, 1, 0.2))``
            before switching the axis to log scale.
    """
    ax = fig.add_subplot(position)
    plt.title(title, fontsize=16)
    for name, runs in runs_by_model.items():
        avg = np.mean(runs, axis=0)
        std = np.std(runs, axis=0)
        # Take the epoch axis from the data itself so it always matches the
        # number of epochs that was actually recorded.
        epochs = np.arange(avg.shape[0])
        plt.plot(epochs, avg, label=name, lw=3)
        plt.fill_between(epochs, avg - std, avg + std, alpha=0.2)
    plt.xlabel('Epoch', fontsize=20)
    plt.ylabel(ylabel, fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=14, width=2, length=6)
    if set_yticks:
        plt.yticks(np.arange(0, 1, 0.2))
    plt.yscale('log')
    plt.legend(fontsize=18)


# The file name does not depend on h or lr, so read it once.
# NOTE: pickle.load executes arbitrary code; only open files produced by this notebook.
with open('results/initialize_mnist/full_data_SGD/clf_t=%0.2f_l=%0.2f.pickle' % (t, l), 'rb') as handle:
    sims = pickle.load(handle)

# One figure (three panels) per (hidden size, learning rate) pair.
for h in h_list:
    for lr in lr_list:
        fig = plt.figure(figsize=(12, 5))
        plt.suptitle(r'Shallow FFW FC net w/ SGD. h=%d, lr=%0.4f, '% (h, lr), fontsize=16)

        for position, key, title, ylabel, set_yticks in panels:
            runs_by_model = {m: sims[key][m][h][lr] for m in models}
            _plot_panel(fig, position, runs_by_model, title, ylabel, set_yticks)

        plt.tight_layout()
        plt.subplots_adjust(top=0.8)  # leave room for the suptitle

        print(h, lr)
        plt.savefig('results/initialize_mnist/full_data_SGD/init_t=%0.2f_l=%0.2f_h=%d_lr=%0.4f.png' % (t, l, h, lr))
        plt.close()

In [None]:
!nvidia-smi