# Assignment:
1. Define a network class that regresses to the 7 outputs.
2. Train a sufficiently large network to perform the categorization.
3. Measure the test accuracy of the model by counting the number of accurate labels

# Stretch Goals:
- Test out different network architectures (depth, breadth) and examine training performance

In [67]:
import numpy as np
import csv

rows = []

with open('Fish.csv') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    for row in csv_reader:
        rows.append(row)

print(len(rows))
print(rows[0]) # first row is a header
print(rows[1])

rows = rows[1:]

labels = {} # Create a dictionary of label strings to numeric values
for row in rows:
    if row[0] not in labels:
        labels[row[0]]=len(labels)

print(labels)

inputs = np.array([row[1:] for row in rows], dtype='float32')
outputs = np.array([labels[row[0]] for row in rows])
print(outputs)

160
['\ufeffSpecies', 'Weight', 'Length1', 'Length2', 'Length3', 'Height', 'Width']
['Bream', '242', '23.2', '25.4', '30', '11.52', '4.02']
{'Bream': 0, 'Roach': 1, 'Whitefish': 2, 'Parkki': 3, 'Perch': 4, 'Pike': 5, 'Smelt': 6}
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 6 6 6
 6 6 6 6 6 6 6 6 6 6 6]


In [65]:
def output_to_one_hot(categories, max_val):
    data = np.zeros((len(categories), max_val))
    data[np.arange(len(categories)), categories] = 1
    return data

encodings = output_to_one_hot(outputs, len(labels))
print(encodings[:10])
print(encodings[-10:])

[[1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]]


In [69]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(inputs, encodings)

In [33]:
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as utils

class Net(nn.Module):
    def __init__(self, inputs=1, outputs=1, breadth=500, depth=3):
        super().__init__()
        
        self.input_layer = nn.Linear(inputs, breadth)
        for i in range(depth-2):
            self.add_module('hidden_layer%i' % i, nn.Linear(breadth, breadth))
        self.output_layer = nn.Linear(breadth, outputs)
        
    def forward(self, x):
        layers = list(self.children())
        for layer in layers[:-1]:
            x = F.relu(layer(x))
        x = layers[-1](x)
        return x

In [93]:
from tqdm import trange # Used to provide progress bar



net = Net(inputs=inputs.shape[1], outputs=encodings.shape[1], depth=6, breadth=500)

net.zero_grad()
outputs = net(Variable(torch.Tensor([0] * inputs.shape[1])))
outputs.backward(torch.randn(encodings.shape[1])) # Use random gradients to break symmetry?

learning_rate = 1 # Need to initialize carefully
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

losses = []

# create your optimizer
optimizer = optim.Adam(net.parameters())
criterion = nn.MSELoss()

num_epochs = 1000
t = trange(num_epochs)
for epoch in t:  # loop over the dataset multiple times

    running_loss = 0.0

    # wrap them in Variable
    reshaped_inputs = torch.from_numpy(X_train) # Structure with each input in its own row
    #reshaped_outputs = true_vals.view(-1, 1) # Neglecting to have outputs and true vals to match dimension is a common mistake.

    #reshaped_inputs = inputs
    reshaped_outputs = torch.from_numpy(y_train)

    # forward + backward + optimize
    train_out = net(reshaped_inputs.float())
    #print(outputs)
    #print(reshaped_outputs)
    loss = criterion(train_out, reshaped_outputs.float())
    losses.append(loss)

    loss.backward()
    optimizer.step()

    # zero the parameter gradients
    optimizer.zero_grad()

    t.set_description('ML (loss=%g)' % loss.item()) # Updates Loss information

print('Finished Training')













  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=323.436):   0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=5.94656):   0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=1.03276):   0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.981415):   0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=1.13394):   0%|          | 0/1000 [00:00<?, ?it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=1.39135):   0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=1.39135):   1%|          | 6/1000 [00:00<00:16, 59.37it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.56879):   1%|          | 6/1000 [00:00<00:16, 59.37it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.878214):   1%| 

ML (loss=0.117548):   6%|▌         | 57/1000 [00:00<00:12, 76.93it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.119804):   6%|▌         | 57/1000 [00:00<00:12, 76.93it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.109432):   6%|▌         | 57/1000 [00:00<00:12, 76.93it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.109432):   6%|▋         | 65/1000 [00:00<00:12, 77.29it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.12167):   6%|▋         | 65/1000 [00:00<00:12, 77.29it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.106028):   6%|▋         | 65/1000 [00:00<00:12, 77.29it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.117298):   6%|▋         | 65/1000 [00:00<00:12, 77.29it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.107853):   6%|▋         | 65/1000 [00:00<00:12, 77.29it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.110862):   6%|▋         | 65/1000 [00:00<00:12, 77.29

ML (loss=0.0963916):  12%|█▏        | 119/1000 [00:01<00:10, 82.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.096277):  12%|█▏        | 119/1000 [00:01<00:10, 82.70it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0961524):  12%|█▏        | 119/1000 [00:01<00:10, 82.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0960493):  12%|█▏        | 119/1000 [00:01<00:10, 82.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0960493):  13%|█▎        | 128/1000 [00:01<00:10, 81.95it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0959606):  13%|█▎        | 128/1000 [00:01<00:10, 81.95it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0958718):  13%|█▎        | 128/1000 [00:01<00:10, 81.95it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0957728):  13%|█▎        | 128/1000 [00:01<00:10, 81.95it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0956578):  13%|█▎        | 128/1000 [

ML (loss=0.0955707):  18%|█▊        | 182/1000 [00:02<00:09, 82.32it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.102971):  18%|█▊        | 182/1000 [00:02<00:09, 82.32it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0948456):  18%|█▊        | 182/1000 [00:02<00:09, 82.32it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0947126):  18%|█▊        | 182/1000 [00:02<00:09, 82.32it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0995281):  18%|█▊        | 182/1000 [00:02<00:09, 82.32it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0995281):  19%|█▉        | 191/1000 [00:02<00:09, 82.61it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0931973):  19%|█▉        | 191/1000 [00:02<00:09, 82.61it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0935568):  19%|█▉        | 191/1000 [00:02<00:09, 82.61it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0972274):  19%|█▉        | 191/1000 [

ML (loss=0.0855095):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0856737):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0862121):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0854002):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0860245):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0856166):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0857867):  24%|██▍       | 245/1000 [00:03<00:09, 82.50it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0857867):  25%|██▌       | 254/1000 [00:03<00:08, 83.49it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0854889):  25%|██▌       | 254/1000 [

ML (loss=0.0930945):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0908668):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0971704):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0870223):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0936969):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0905665):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0866625):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0926768):  31%|███       | 308/1000 [00:03<00:08, 84.56it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0926768):  32%|███▏      | 317/1000 [

ML (loss=0.077666):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0775413):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0774443):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.077359):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.077283):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0771773):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.077052):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0769379):  37%|███▋      | 371/1000 [00:04<00:07, 84.63it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0768209):  37%|███▋      | 371/1000 [00

ML (loss=0.0730879):  42%|████▎     | 425/1000 [00:05<00:06, 86.12it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0726402):  42%|████▎     | 425/1000 [00:05<00:06, 86.12it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0726402):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0726407):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0738985):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0800388):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.108005):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.219098):  43%|████▎     | 434/1000 [00:05<00:06, 85.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.340194):  43%|████▎     | 434/1000 [00

ML (loss=0.0732445):  49%|████▉     | 488/1000 [00:05<00:05, 86.12it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0727053):  49%|████▉     | 488/1000 [00:05<00:05, 86.12it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0728961):  49%|████▉     | 488/1000 [00:05<00:05, 86.12it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0728961):  50%|████▉     | 497/1000 [00:05<00:05, 86.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0723473):  50%|████▉     | 497/1000 [00:05<00:05, 86.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0721732):  50%|████▉     | 497/1000 [00:05<00:05, 86.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0721286):  50%|████▉     | 497/1000 [00:05<00:05, 86.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0716188):  50%|████▉     | 497/1000 [00:06<00:05, 86.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0716056):  50%|████▉     | 497/1000 [

ML (loss=0.0661511):  55%|█████▌    | 551/1000 [00:06<00:05, 86.83it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0643671):  55%|█████▌    | 551/1000 [00:06<00:05, 86.83it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0644987):  55%|█████▌    | 551/1000 [00:06<00:05, 86.83it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0647778):  55%|█████▌    | 551/1000 [00:06<00:05, 86.83it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0646095):  55%|█████▌    | 551/1000 [00:06<00:05, 86.83it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0646095):  56%|█████▌    | 560/1000 [00:06<00:05, 86.67it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0635086):  56%|█████▌    | 560/1000 [00:06<00:05, 86.67it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0634603):  56%|█████▌    | 560/1000 [00:06<00:05, 86.67it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0638449):  56%|█████▌    | 560/1000 [

ML (loss=0.0577694):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0585034):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0582508):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0582565):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.058433):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0561679):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0567845):  61%|██████▏   | 614/1000 [00:07<00:04, 85.08it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0567845):  62%|██████▏   | 623/1000 [00:07<00:04, 83.59it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0563671):  62%|██████▏   | 623/1000 [

ML (loss=0.0525309):  67%|██████▋   | 672/1000 [00:08<00:05, 61.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.050121):  67%|██████▋   | 672/1000 [00:08<00:05, 61.33it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.050121):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0514547):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0509047):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0510553):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0530821):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0536661):  68%|██████▊   | 679/1000 [00:08<00:05, 56.22it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0552054):  68%|██████▊   | 679/1000 [0

ML (loss=0.0448849):  73%|███████▎  | 733/1000 [00:09<00:06, 40.96it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0445556):  73%|███████▎  | 733/1000 [00:09<00:06, 40.96it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0443762):  73%|███████▎  | 733/1000 [00:09<00:06, 40.96it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0437264):  73%|███████▎  | 733/1000 [00:09<00:06, 40.96it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0437264):  74%|███████▍  | 738/1000 [00:09<00:06, 40.01it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.043189):  74%|███████▍  | 738/1000 [00:09<00:06, 40.01it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0429022):  74%|███████▍  | 738/1000 [00:09<00:06, 40.01it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0428447):  74%|███████▍  | 738/1000 [00:09<00:06, 40.01it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0429553):  74%|███████▍  | 738/1000 [

ML (loss=0.0561224):  79%|███████▊  | 787/1000 [00:11<00:06, 35.25it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0461671):  79%|███████▊  | 787/1000 [00:11<00:06, 35.25it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0461671):  79%|███████▉  | 791/1000 [00:11<00:05, 35.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0554351):  79%|███████▉  | 791/1000 [00:11<00:05, 35.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0477374):  79%|███████▉  | 791/1000 [00:11<00:05, 35.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0495522):  79%|███████▉  | 791/1000 [00:11<00:05, 35.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0458107):  79%|███████▉  | 791/1000 [00:11<00:05, 35.33it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0458107):  80%|███████▉  | 795/1000 [00:11<00:05, 35.79it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0473835):  80%|███████▉  | 795/1000 [

ML (loss=0.053186):  84%|████████▍ | 843/1000 [00:12<00:04, 37.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0640646):  84%|████████▍ | 843/1000 [00:12<00:04, 37.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0477526):  84%|████████▍ | 843/1000 [00:12<00:04, 37.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0650712):  84%|████████▍ | 843/1000 [00:12<00:04, 37.70it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0650712):  85%|████████▍ | 847/1000 [00:12<00:04, 37.94it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0466114):  85%|████████▍ | 847/1000 [00:12<00:04, 37.94it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0570299):  85%|████████▍ | 847/1000 [00:12<00:04, 37.94it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0473153):  85%|████████▍ | 847/1000 [00:12<00:04, 37.94it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0504263):  85%|████████▍ | 847/1000 [0

ML (loss=0.0483429):  90%|████████▉ | 896/1000 [00:14<00:02, 40.39it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0343502):  90%|████████▉ | 896/1000 [00:14<00:02, 40.39it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0343502):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.038919):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s] [A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0432819):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0383162):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0367556):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0348454):  90%|█████████ | 901/1000 [00:14<00:02, 40.58it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0348454):  91%|█████████ | 906/1000 [

ML (loss=0.0636683):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0580013):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0322691):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0554992):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0934089):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0446919):  96%|█████████▌| 956/1000 [00:15<00:00, 45.81it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0446919):  96%|█████████▌| 961/1000 [00:15<00:00, 46.02it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0683638):  96%|█████████▌| 961/1000 [00:15<00:00, 46.02it/s][A[A[A[A[A[A[A[A[A[A[A[A











ML (loss=0.0555941):  96%|█████████▌| 961/1000 [

Finished Training





In [94]:
train_out = net(torch.from_numpy(X_test).float())
output_labels = np.argmax(train_out.data.numpy(), axis=1)
test_labels = np.argmax(y_test, axis=1)

In [95]:
[ x == y for x, y in zip(output_labels, test_labels) ].count(1) / len(output_labels)

0.725