# **Multiclass Classification (MNIST) Demo**

In [None]:
# Colab setup: when running in Google Colab, clone the course repository and
# make its supervised-learning modules importable. The cell is safe to run
# anywhere; outside Colab it only sets IN_COLAB = False.
try:
    import google.colab  # noqa: F401  (imported only to probe for Colab)
    IN_COLAB = True
except ImportError:
    # Catch only the "module not available" case; a bare `except:` would also
    # hide KeyboardInterrupt/SystemExit and genuine errors raised during import.
    IN_COLAB = False

if IN_COLAB:
    import os
    import subprocess
    import sys

    repo_url = "https://github.com/satishchandrareddy/WhatisML.git"
    repo_dir = "/content/WhatisML"
    # Clone into the same directory that code_location points at (instead of
    # the current working directory), and skip cloning on re-runs because git
    # refuses to clone into an existing directory.
    if not os.path.isdir(repo_dir):
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
    code_location = os.path.join(repo_dir, "Code", "Supervised")
    code_path = os.path.abspath(code_location)
    # Avoid adding duplicate sys.path entries when the cell is re-run.
    if code_path not in sys.path:
        sys.path.append(code_path)

## **Import Libraries**

In [None]:
from IPython.display import HTML
import load_mnist
import NeuralNetwork
import matplotlib.pyplot as plt
import numpy as np
import Optimizer
import plot_results

## **Settings to Change**
If you would like to experiment, here are some settings you can change.

In [None]:
# Experiment settings. Things to try:
# - Change the random seed to get different random numbers: seed
# - Change the number of training data samples: ntrain (up to 60000)
# - Change the number of validation data samples: nvalid (up to 10000)
# - Change the learning rate for optimization: learning_rate
# - Change the number of training iterations: niteration
seed = 10
ntrain = 6000
nvalid = 1000
learning_rate = 0.02
niteration = 40

### **1. Set up Data**

In [None]:
# Number of output classes (the ten MNIST digits).
nclass = 10
# Load ntrain training samples and nvalid validation samples.
Xtrain, Ytrain, Xvalid, Yvalid = load_mnist.load_mnist(ntrain, nvalid)
# Preview some training images; the plotting helper also receives the targets Ytrain.
plot_results.plot_data_mnist(Xtrain, Ytrain)

### **2. Define Model**

In [None]:
# Input size is taken from the first axis of Xtrain, so samples appear to be
# stored one per column (NOTE(review): confirm against load_mnist).
nfeature = Xtrain.shape[0]
# Seed NumPy's global RNG before building the network; presumably the layer
# weight initialisation draws from np.random (TODO confirm).
np.random.seed(seed)
model = NeuralNetwork.NeuralNetwork(nfeature)
# Hidden layer of 128 ReLU units, then a softmax output with one unit per class.
model.add_layer(128, "relu")
model.add_layer(nclass, "softmax")

### **3. Compile Model**

In [None]:
# Adam optimizer. The trailing positional arguments are presumably
# beta1, beta2 and epsilon; confirm against Optimizer.Adam's signature.
optimizer = Optimizer.Adam(
    learning_rate,
    0.9,
    0.999,
    1e-7,
)
# Attach the cross-entropy loss and the optimizer, then show the model summary.
model.compile("crossentropy", optimizer)
model.summary()

### **4. Learning**

In [None]:
# Train on the training set for `niteration` iterations. The returned history
# is read later by its "loss" entry.
history = model.fit(Xtrain, Ytrain, niteration)

### **5. Plot Results**

In [None]:
# Plot the "loss" series recorded in the training history.
plot_results.plot_results_history(history, ["loss"])

In [None]:
# Evaluate the trained model on the held-out validation set.
Yvalid_pred = model.predict(Xvalid)
accuracy = model.accuracy(Yvalid, Yvalid_pred)
print(f"Accuracy for Validation Data Set: {accuracy}")

In [None]:
# Animate validation predictions against the true labels and embed the result
# as an HTML5 video. If `ani` is a matplotlib Animation, to_html5_video needs
# an ffmpeg writer installed (NOTE(review): confirm what the animation helper
# returns).
ani = plot_results.plot_results_mnist_animation(
    Xvalid,
    Yvalid,
    Yvalid_pred,
    25,  # NOTE(review): meaning of 25 is not visible here (frame count?); confirm
)
vid = HTML(ani.to_html5_video())
vid  # last expression in the cell, so Jupyter renders the video