In [4]:
from keras.datasets import fashion_mnist,mnist
import wandb
import numpy as np
from sklearn.metrics import classification_report
import copy
import argparse
from types import SimpleNamespace
import plotly.figure_factory as ff

from utilities.NeuralNetwork import NN
from utilities.HelperFunctions import OneHotEncoder,compute_accuracy_score,compute_confusion_matrix
from utilities.config import * # reading global variables 


def pre_process(x):
    '''
    Flatten every sample into a 784-element row and rescale pixel values.

    Dividing by 255 maps 8-bit intensities (0-255) onto the [0, 1] range.

    Parameters
    ----------
    x : numpy.ndarray
        Image data whose total size is a multiple of 784.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_samples, 784) with values divided by 255.
    '''
    flat = x.reshape((-1, 784))
    return flat / 255

def load_data(dataset=fmnist_dataset,split_size=valid_split_size):
    '''
    Load MNIST or Fashion-MNIST, split the training set into train/valid,
    flatten and rescale the images, and one-hot encode the labels.

    Parameters
    ----------
    dataset : config constant
        Either ``fmnist_dataset`` or ``mnist_dataset``.
    split_size : float
        Fraction of the original training set kept for training. The first
        ``int(len * split_size)`` samples go to train and the rest become the
        validation split (no shuffling). Must lie in [0, 1].
        NOTE(review): the default is ``valid_split_size`` while the notebook
        calls this with ``1 - valid_split_size``; confirm the default is intended.

    Returns
    -------
    tuple
        ``(x_train, y_train, x_valid, y_valid, x_test, y_test)``. Images are
        passed through ``pre_process``; labels through ``OneHotEncoder(10).transform``.

    Raises
    ------
    ValueError
        If ``dataset`` is not a recognised dataset constant, or ``split_size``
        is outside [0, 1].
    '''
    if dataset==fmnist_dataset:
        source=fashion_mnist
    elif dataset==mnist_dataset:
        source=mnist
    else:
        # Fail fast instead of hitting an UnboundLocalError on x_train below.
        raise ValueError('unknown dataset: {!r}'.format(dataset))
    if not 0<=split_size<=1:
        raise ValueError('split_size must be in [0, 1], got {!r}'.format(split_size))

    (x_train, y_train), (x_test, y_test) = source.load_data()

    # A single cut index shared by images and labels keeps them aligned.
    cut=int(len(x_train)*split_size)
    x_train,x_valid=x_train[:cut],x_train[cut:]
    y_train,y_valid=y_train[:cut],y_train[cut:]

    x_train=pre_process(x_train)
    x_valid=pre_process(x_valid)
    x_test=pre_process(x_test)

    one_hot=OneHotEncoder(10)
    y_train=one_hot.transform(y_train)
    y_valid=one_hot.transform(y_valid)
    y_test=one_hot.transform(y_test)

    return x_train,y_train,x_valid,y_valid,x_test,y_test

# Load MNIST keeping (1 - valid_split_size) of the training images for training;
# the remainder forms the validation split.
x_train, y_train, x_valid, y_valid, x_test, y_test = load_data(
    dataset=mnist_dataset,
    split_size=1 - valid_split_size,
)
    

In [10]:
def concat(*x):
    '''Join all given arrays end-to-end along their first axis (axis 0).'''
    return np.concatenate(tuple(x), axis=0)

# Hyper-parameter search is finished, so the validation split is folded back
# into the training set for the final model.
# NOTE(review): this cell is not idempotent - re-running it appends
# x_valid / y_valid to x_train / y_train a second time.
x_train = concat(x_train, x_valid)        # samples are rows: stack along axis 0
y_train = concat(y_train.T, y_valid.T).T  # labels appear to be class-major: transpose, stack, transpose back

# Final hyper-parameters used for training.
params = dict(
    layer_size=128,        # presumably 64 was also tried
    batch_size=64,         # presumably 16 and 32 were also tried
    num_layers=3,
    optimizer='nadam',
    activation='tanh',
    epochs=10,
    learning_rate=0.0001,
    weight_init='xavier',
    loss=entropy_loss,     # constant provided by utilities.config
    weight_decay=0.0005,
)

def model_fit(params, train=None, evaluation=None):
    '''
    Train a fully-connected ``NN`` with the given hyper-parameters, tracked in W&B.

    A W&B run is started here and left open; the caller is expected to log
    any further metrics and call ``wandb.finish()``.

    Parameters
    ----------
    params : dict
        Hyper-parameters. Must contain at least ``layer_size``, ``num_layers`` and
        ``batch_size``. The dict is used as the W&B run config and, converted to a
        SimpleNamespace, passed to ``NN``.
    train : tuple (x, y), optional
        Training data. ``x`` has one sample per row; ``y`` is one-hot with one
        sample per column. Defaults to the notebook globals ``x_train``/``y_train``.
    evaluation : tuple (x, y), optional
        Data passed to ``NN.train`` as its second (x, y) pair, with the same layout
        as ``train``. Defaults to the notebook globals ``x_test``/``y_test``.

    Returns
    -------
    NN
        The trained network.
    '''
    # Explicit data arguments avoid silently depending on notebook globals;
    # the globals remain the default for backward compatibility.
    x_fit, y_fit = (x_train, y_train) if train is None else train
    x_eval, y_eval = (x_test, y_test) if evaluation is None else evaluation

    wandb.init(project='Assign_1_DL', config=params)
    wandb.run.name = 'MNIST ' + "-batch_" + str(wandb.config.batch_size) + "-layerSize_" + str(wandb.config.layer_size)

    hp = SimpleNamespace(**params)
    # Derive the network's input width and output size from the data instead of
    # hard-coding 784 / 10.
    n_features = x_fit.shape[1]
    n_classes = y_fit.shape[0]
    layers = [hp.layer_size] * hp.num_layers + [n_classes]

    obj = NN(n_features, layers, hp)
    # NN expects features x samples, hence the transposes.
    obj.train(x_fit.T, y_fit, x_eval.T, y_eval)
    print(": Done")
    return obj

obj = model_fit(params)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01675302639999927, max=1.0)…

epoch 1 : train loss = 0.43 valid loss = 0.24 train accuracy = 93.38 valid accuracy = 92.72
epoch 2 : train loss = 0.19 valid loss = 0.17 train accuracy = 95.53 valid accuracy = 94.77
epoch 3 : train loss = 0.14 valid loss = 0.14 train accuracy = 96.58 valid accuracy = 95.70
epoch 4 : train loss = 0.11 valid loss = 0.12 train accuracy = 97.26 valid accuracy = 96.11
epoch 5 : train loss = 0.09 valid loss = 0.11 train accuracy = 97.69 valid accuracy = 96.45
epoch 6 : train loss = 0.07 valid loss = 0.10 train accuracy = 98.06 valid accuracy = 96.69
epoch 7 : train loss = 0.06 valid loss = 0.09 train accuracy = 98.36 valid accuracy = 97.00
epoch 8 : train loss = 0.05 valid loss = 0.09 train accuracy = 98.58 valid accuracy = 97.13
epoch 9 : train loss = 0.05 valid loss = 0.09 train accuracy = 98.79 valid accuracy = 97.26
epoch 10 : train loss = 0.04 valid loss = 0.08 train accuracy = 98.95 valid accuracy = 97.40
: Done


In [11]:
# Score the final model on the test split, log the result to W&B and close the run.
out = obj.predict_probas(x_test.T)  # per-class outputs; classes lie along axis 0
pred = np.argmax(out, axis=0)
test_acc = 100 * compute_accuracy_score(np.argmax(y_test, axis=0), pred)
print("Accuracy on test data is ", test_acc)
wandb.log({'test accuracy': test_acc})
wandb.finish()

Accuracy on test data is  97.39999999999999


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test accuracy,▁
train accuracy,▁▄▅▆▆▇▇███
train loss,█▄▃▂▂▂▁▁▁▁
valid accuracy,▁▄▅▆▇▇▇███
valid loss,█▅▄▃▂▂▂▁▁▁

0,1
test accuracy,97.4
train accuracy,98.94861
train loss,0.03949
valid accuracy,97.4
valid loss,0.08361
