In [1]:
from numpy import vstack
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch import Tensor
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Sigmoid
from torch.nn import Module
from torch.optim import SGD
from torch.nn import BCELoss
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

In [2]:
# dataset definition
class CSVDataset(Dataset):
    # load the dataset
    def __init__(self, path):
        # load the csv file as a dataframe
        df = read_csv(path)
        # store the inputs and outputs
        self.X = df.values[:, :-1]
        self.y = df.values[:, -1]
        # ensure input data is floats
        self.X = self.X.astype('float32')
        # label encode target and ensure the values are floats
        self.y = LabelEncoder().fit_transform(self.y)
        self.y = self.y.astype('float32')
        self.y = self.y.reshape((len(self.y), 1))

    # number of rows in the dataset
    def __len__(self):
        return len(self.X)

    # get a row at an index
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    # get indexes for train and test rows
    def get_splits(self, n_test=0.33):
        # determine sizes
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        # calculate the split
        return random_split(self, [train_size, test_size])

In [3]:
# model definition
class MLP(Module):
    # define model elements
    def __init__(self, n_inputs):
        super(MLP, self).__init__()
        # input to first hidden layer
        self.hidden1 = Linear(n_inputs, 10)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer
        self.hidden2 = Linear(10, 8)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # third hidden layer and output
        self.hidden3 = Linear(8, 1)
        xavier_uniform_(self.hidden3.weight)
        self.act3 = Sigmoid()

    # forward propagate input
    def forward(self, X):
        # input to first hidden layer
        X = self.hidden1(X)
        X = self.act1(X)
         # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # third hidden layer and output
        X = self.hidden3(X)
        X = self.act3(X)
        return X

In [4]:
# prepare the dataset
def prepare_data(train_path,test_path = None):
    if test_path == None:
        dataset = CSVDataset(train_path)
        train, test = dataset.get_splits()
    else:
        train = CSVDataset(train_path)
        test = CSVDataset(test_path)

    train_dl = DataLoader(train, batch_size=32, shuffle=True)
    test_dl = DataLoader(test, batch_size=1024, shuffle=False)
    return train_dl, test_dl

In [5]:
# train the model
def train_model(train_dl, model):
    # define the optimization
    criterion = BCELoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # enumerate epochs
    for epoch in range(100):
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, targets)
            print(f'epoch {epoch}: {loss} loss')
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()

In [6]:
# evaluate the model
def evaluate_model(test_dl, model):
    predictions, actuals = list(), list()
    for i, (inputs, targets) in enumerate(test_dl):
        # evaluate the model on the test set
        yhat = model(inputs)
        # retrieve numpy array
        yhat = yhat.detach().numpy()
        actual = targets.numpy()
        actual = actual.reshape((len(actual), 1))
        # round to class values
        yhat = yhat.round()
        # store
        predictions.append(yhat)
        actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)

    # calculate accuracy
    acc = accuracy_score(actuals, predictions)
    return predictions, actuals, acc

# make a class prediction for one row of data
def predict(row, model):
    # convert row to data
    row = Tensor([row])
    # make prediction
    yhat = model(row)
    # retrieve numpy array
    yhat = yhat.detach().numpy()
    return yhat

In [7]:
# load data

train_path = '../data/5_train_dataset.csv'
test_path = '../data/4_test_dataset.csv'

train_dl, test_dl = prepare_data(train_path, test_path)
print(len(train_dl.dataset), len(test_dl.dataset))

n_features = len(train_dl.dataset.X[0])
n_features

5889 3927


21

In [8]:
# define the network
model = MLP(n_features)
# train the model
train_model(train_dl, model)


epoch 0: 4471.98193359375 loss
epoch 0: 7365.65966796875 loss
epoch 0: 4106.443359375 loss
epoch 0: 4981.37353515625 loss
epoch 0: 6582.47412109375 loss
epoch 0: 3613.3212890625 loss
epoch 0: 7559.375 loss
epoch 0: 4846.875 loss
epoch 0: 4521.875 loss
epoch 0: 1953.125 loss
epoch 0: 7984.375 loss
epoch 0: 4012.5 loss
epoch 0: 6406.25 loss
epoch 0: 115.625 loss
epoch 0: 2015.625 loss
epoch 0: 6062.5 loss
epoch 0: 3178.125 loss
epoch 0: 5790.625 loss
epoch 0: 3468.75 loss
epoch 0: 7418.75 loss
epoch 0: 1625.0 loss
epoch 0: 3209.375 loss
epoch 0: 4065.625 loss
epoch 0: 4687.5 loss
epoch 0: 2275.0 loss
epoch 0: 6250.0 loss
epoch 0: 3515.625 loss
epoch 0: 4990.625 loss
epoch 0: 5918.75 loss
epoch 0: 10865.625 loss
epoch 0: 4778.125 loss
epoch 0: 5393.75 loss
epoch 0: 9393.75 loss
epoch 0: 4375.0 loss
epoch 0: 4653.125 loss
epoch 0: 10481.25 loss
epoch 0: 7334.375 loss
epoch 0: 10287.5 loss
epoch 0: 6081.25 loss
epoch 0: 3528.125 loss
epoch 0: 6931.25 loss
epoch 0: 8981.25 loss
epoch 0: 7765

epoch 3: 7531.25 loss
epoch 3: 3690.625 loss
epoch 3: 2553.125 loss
epoch 3: 5206.25 loss
epoch 3: 5915.625 loss
epoch 3: 6068.75 loss
epoch 3: 4559.375 loss
epoch 3: 6896.875 loss
epoch 3: 8546.875 loss
epoch 3: 5534.375 loss
epoch 3: 750.0 loss
epoch 3: 5071.875 loss
epoch 3: 7084.375 loss
epoch 3: 2665.625 loss
epoch 3: 6715.625 loss
epoch 3: 8362.5 loss
epoch 3: 4678.125 loss
epoch 3: 7825.0 loss
epoch 3: 8559.375 loss
epoch 3: 896.875 loss
epoch 3: 3740.625 loss
epoch 3: 7490.625 loss
epoch 3: 8821.875 loss
epoch 3: 6093.75 loss
epoch 3: 5381.25 loss
epoch 3: 4500.0 loss
epoch 3: 5203.125 loss
epoch 3: 7296.875 loss
epoch 3: 6946.875 loss
epoch 3: 6096.875 loss
epoch 3: 6743.75 loss
epoch 3: 6950.0 loss
epoch 3: 4059.375 loss
epoch 3: 8962.5 loss
epoch 3: 5493.75 loss
epoch 3: 12925.0 loss
epoch 3: 1553.125 loss
epoch 3: 5146.875 loss
epoch 3: 5043.75 loss
epoch 3: 9484.375 loss
epoch 3: 4015.625 loss
epoch 3: 2415.625 loss
epoch 3: 6756.25 loss
epoch 3: 7343.75 loss
epoch 3: 5762

epoch 6: 1475.0 loss
epoch 6: 9253.125 loss
epoch 6: 787.5 loss
epoch 6: 12343.75 loss
epoch 6: 1468.75 loss
epoch 6: 4162.5 loss
epoch 6: 4650.0 loss
epoch 6: 3990.625 loss
epoch 6: 9390.625 loss
epoch 6: 2265.625 loss
epoch 6: 5978.125 loss
epoch 6: 4184.375 loss
epoch 6: 5571.875 loss
epoch 6: 5140.625 loss
epoch 6: 3415.625 loss
epoch 6: 5712.5 loss
epoch 6: 4637.5 loss
epoch 6: 1600.0 loss
epoch 6: 4540.625 loss
epoch 6: 2940.625 loss
epoch 6: 3459.375 loss
epoch 6: 8987.5 loss
epoch 6: 3109.375 loss
epoch 6: 9771.875 loss
epoch 6: 10293.75 loss
epoch 6: 7568.75 loss
epoch 6: 4703.125 loss
epoch 6: 6075.0 loss
epoch 6: 4493.75 loss
epoch 6: 6103.125 loss
epoch 6: 3796.875 loss
epoch 6: 2715.625 loss
epoch 6: 2693.75 loss
epoch 6: 7537.5 loss
epoch 6: 2578.125 loss
epoch 6: 3025.0 loss
epoch 6: 7556.25 loss
epoch 6: 5665.625 loss
epoch 6: 4662.5 loss
epoch 6: 6306.25 loss
epoch 6: 4956.25 loss
epoch 6: 6956.25 loss
epoch 6: 4737.5 loss
epoch 6: 4840.625 loss
epoch 6: 2156.25 loss
e

epoch 10: 6590.625 loss
epoch 10: 6321.875 loss
epoch 10: 4659.375 loss
epoch 10: 4518.75 loss
epoch 10: 3496.875 loss
epoch 10: 7334.375 loss
epoch 10: 5646.875 loss
epoch 10: 6053.125 loss
epoch 10: 7665.625 loss
epoch 10: 5381.25 loss
epoch 10: 7415.625 loss
epoch 10: 7000.0 loss
epoch 10: 6690.625 loss
epoch 10: 2178.125 loss
epoch 10: 6881.25 loss
epoch 10: 3653.125 loss
epoch 10: 3803.125 loss
epoch 10: 4581.25 loss
epoch 10: 4150.0 loss
epoch 10: 4381.25 loss
epoch 10: 3303.125 loss
epoch 10: 5418.75 loss
epoch 10: 5746.875 loss
epoch 10: 2378.125 loss
epoch 10: 8850.0 loss
epoch 10: 4421.875 loss
epoch 10: 9900.0 loss
epoch 10: 5718.75 loss
epoch 10: 6909.375 loss
epoch 10: 4568.75 loss
epoch 10: 2115.625 loss
epoch 10: 5290.625 loss
epoch 10: 4037.5 loss
epoch 10: 2528.125 loss
epoch 10: 8840.625 loss
epoch 10: 3993.75 loss
epoch 10: 2521.875 loss
epoch 10: 2168.75 loss
epoch 10: 4484.375 loss
epoch 10: 4596.875 loss
epoch 10: 8700.0 loss
epoch 10: 3081.25 loss
epoch 10: 10631

epoch 12: 8118.75 loss
epoch 12: 9146.875 loss
epoch 12: 8793.75 loss
epoch 12: 8043.75 loss
epoch 12: 4581.25 loss
epoch 12: 6109.375 loss
epoch 12: 3462.5 loss
epoch 12: 1125.0 loss
epoch 12: 4140.625 loss
epoch 12: 4578.125 loss
epoch 12: 3618.75 loss
epoch 12: 4134.375 loss
epoch 12: 131.25 loss
epoch 12: 2600.0 loss
epoch 12: 6878.125 loss
epoch 12: 6946.875 loss
epoch 12: 3309.375 loss
epoch 12: 4853.125 loss
epoch 12: 5696.875 loss
epoch 12: 6459.375 loss
epoch 12: 2675.0 loss
epoch 12: 1484.375 loss
epoch 12: 11859.375 loss
epoch 12: 7743.75 loss
epoch 12: 10675.0 loss
epoch 12: 4740.625 loss
epoch 12: 5325.0 loss
epoch 12: 6884.375 loss
epoch 12: 5934.375 loss
epoch 12: 7146.875 loss
epoch 12: 6218.75 loss
epoch 12: 6303.125 loss
epoch 12: 5081.25 loss
epoch 12: 8637.5 loss
epoch 12: 2618.75 loss
epoch 12: 8318.75 loss
epoch 12: 7603.125 loss
epoch 12: 2862.5 loss
epoch 12: 5765.625 loss
epoch 12: 2315.625 loss
epoch 12: 4409.375 loss
epoch 12: 8425.0 loss
epoch 12: 8568.75 lo

epoch 14: 3475.0 loss
epoch 14: 4721.875 loss
epoch 14: 3759.375 loss
epoch 14: 2265.625 loss
epoch 14: 2778.125 loss
epoch 14: 1290.625 loss
epoch 14: 9046.875 loss
epoch 14: 10865.625 loss
epoch 14: 1943.75 loss
epoch 14: 5143.75 loss
epoch 14: 2259.375 loss
epoch 14: 6912.5 loss
epoch 14: 5303.125 loss
epoch 14: 8656.25 loss
epoch 14: 3646.875 loss
epoch 14: 10762.5 loss
epoch 14: 7878.125 loss
epoch 14: 7125.0 loss
epoch 14: 4871.875 loss
epoch 14: 3937.5 loss
epoch 14: 1481.25 loss
epoch 14: 6881.25 loss
epoch 14: 5378.125 loss
epoch 14: 12896.875 loss
epoch 14: 4412.5 loss
epoch 14: 6618.75 loss
epoch 14: 4796.875 loss
epoch 14: 4221.875 loss
epoch 14: 6515.625 loss
epoch 14: 4000.0 loss
epoch 14: 7275.0 loss
epoch 14: 3453.125 loss
epoch 14: 6278.125 loss
epoch 14: 2021.875 loss
epoch 14: 5012.5 loss
epoch 14: 1540.625 loss
epoch 14: 5165.625 loss
epoch 14: 4490.625 loss
epoch 14: 4528.125 loss
epoch 14: 3718.75 loss
epoch 14: 2331.25 loss
epoch 14: 3268.75 loss
epoch 14: 3700.0

epoch 16: 6100.0 loss
epoch 16: 3831.25 loss
epoch 16: 7925.0 loss
epoch 16: 7434.375 loss
epoch 16: 346.875 loss
epoch 16: 3543.75 loss
epoch 16: 9675.0 loss
epoch 16: 2837.5 loss
epoch 16: 11081.25 loss
epoch 16: 2328.125 loss
epoch 16: 4693.75 loss
epoch 16: 5312.5 loss
epoch 16: 6168.75 loss
epoch 16: 10818.75 loss
epoch 16: 6121.875 loss
epoch 16: 3809.375 loss
epoch 16: 11434.375 loss
epoch 16: 3368.75 loss
epoch 16: 5075.0 loss
epoch 16: 2765.625 loss
epoch 16: 9537.5 loss
epoch 16: 2250.0 loss
epoch 16: 3009.375 loss
epoch 16: 2706.25 loss
epoch 16: 5775.0 loss
epoch 16: 8940.625 loss
epoch 16: 3306.25 loss
epoch 16: 4540.625 loss
epoch 16: 7303.125 loss
epoch 16: 4637.5 loss
epoch 16: 2459.375 loss
epoch 16: 4153.125 loss
epoch 16: 2393.75 loss
epoch 16: 3718.75 loss
epoch 16: 1896.875 loss
epoch 16: 4575.0 loss
epoch 16: 3250.0 loss
epoch 16: 7978.125 loss
epoch 16: 6209.375 loss
epoch 16: 2390.625 loss
epoch 16: 4831.25 loss
epoch 16: 5215.625 loss
epoch 16: 3896.875 loss
ep

epoch 20: 8412.5 loss
epoch 20: 1565.625 loss
epoch 20: 4265.625 loss
epoch 20: 1375.0 loss
epoch 20: 3043.75 loss
epoch 20: 4246.875 loss
epoch 20: 3568.75 loss
epoch 20: 5359.375 loss
epoch 20: 4628.125 loss
epoch 20: 7440.625 loss
epoch 20: 4100.0 loss
epoch 20: 2009.375 loss
epoch 20: 4131.25 loss
epoch 20: 6643.75 loss
epoch 20: 9103.125 loss
epoch 20: 3140.625 loss
epoch 20: 6112.5 loss
epoch 20: 7215.625 loss
epoch 20: 2490.625 loss
epoch 20: 6381.25 loss
epoch 20: 7575.0 loss
epoch 20: 5806.25 loss
epoch 20: 2568.75 loss
epoch 20: 8309.375 loss
epoch 20: 5071.875 loss
epoch 20: 4143.75 loss
epoch 20: 6184.375 loss
epoch 20: 8143.75 loss
epoch 20: 8315.625 loss
epoch 20: 8115.625 loss
epoch 20: 0.0 loss
epoch 20: 343.75 loss
epoch 20: 7037.5 loss
epoch 20: 2156.25 loss
epoch 20: 7615.625 loss
epoch 20: 2340.625 loss
epoch 20: 6575.0 loss
epoch 20: 3725.0 loss
epoch 20: 1259.375 loss
epoch 20: 6612.5 loss
epoch 20: 2740.625 loss
epoch 20: 8462.5 loss
epoch 20: 7734.375 loss
epoch

epoch 22: 3756.25 loss
epoch 22: 5956.25 loss
epoch 22: 11381.25 loss
epoch 22: 5581.25 loss
epoch 22: 6753.125 loss
epoch 22: 1334.375 loss
epoch 22: 2621.875 loss
epoch 22: 7193.75 loss
epoch 22: 6621.875 loss
epoch 22: 5825.0 loss
epoch 22: 5928.125 loss
epoch 22: 3637.5 loss
epoch 22: 5596.875 loss
epoch 22: 4509.375 loss
epoch 22: 3693.75 loss
epoch 22: 3546.875 loss
epoch 22: 4112.5 loss
epoch 22: 1462.5 loss
epoch 22: 5693.75 loss
epoch 22: 6950.0 loss
epoch 22: 1328.125 loss
epoch 22: 0.0 loss
epoch 22: 3740.625 loss
epoch 22: 3718.75 loss
epoch 22: 1934.375 loss
epoch 22: 3440.625 loss
epoch 22: 11540.625 loss
epoch 22: 6993.75 loss
epoch 22: 7325.0 loss
epoch 22: 4631.25 loss
epoch 22: 4290.625 loss
epoch 22: 0.0 loss
epoch 23: 4800.0 loss
epoch 23: 4603.125 loss
epoch 23: 5121.875 loss
epoch 23: 5800.0 loss
epoch 23: 3475.0 loss
epoch 23: 1640.625 loss
epoch 23: 2090.625 loss
epoch 23: 1734.375 loss
epoch 23: 10571.875 loss
epoch 23: 1853.125 loss
epoch 23: 5871.875 loss
epo

epoch 24: 4690.625 loss
epoch 24: 8462.5 loss
epoch 24: 1171.875 loss
epoch 24: 5434.375 loss
epoch 24: 4512.5 loss
epoch 24: 4456.25 loss
epoch 24: 2625.0 loss
epoch 24: 5312.5 loss
epoch 24: 4956.25 loss
epoch 24: 1900.0 loss
epoch 24: 4625.0 loss
epoch 24: 8859.375 loss
epoch 24: 14759.375 loss
epoch 24: 3334.375 loss
epoch 24: 3103.125 loss
epoch 24: 9031.25 loss
epoch 24: 7368.75 loss
epoch 24: 1390.625 loss
epoch 24: 3965.625 loss
epoch 24: 2459.375 loss
epoch 24: 2690.625 loss
epoch 24: 3690.625 loss
epoch 24: 6200.0 loss
epoch 24: 1603.125 loss
epoch 24: 8884.375 loss
epoch 24: 5368.75 loss
epoch 24: 8800.0 loss
epoch 24: 6750.0 loss
epoch 24: 1893.75 loss
epoch 24: 4303.125 loss
epoch 24: 4700.0 loss
epoch 24: 3228.125 loss
epoch 24: 7581.25 loss
epoch 24: 69200.0 loss
epoch 25: 8725.0 loss
epoch 25: 2943.75 loss
epoch 25: 8703.125 loss
epoch 25: 5384.375 loss
epoch 25: 3087.5 loss
epoch 25: 3284.375 loss
epoch 25: 10125.0 loss
epoch 25: 3806.25 loss
epoch 25: 1896.875 loss
ep

epoch 28: 2543.75 loss
epoch 28: 9356.25 loss
epoch 28: 3253.125 loss
epoch 28: 5368.75 loss
epoch 28: 5068.75 loss
epoch 28: 5443.75 loss
epoch 28: 3456.25 loss
epoch 28: 5003.125 loss
epoch 28: 3493.75 loss
epoch 28: 5446.875 loss
epoch 28: 6668.75 loss
epoch 28: 6790.625 loss
epoch 28: 4284.375 loss
epoch 28: 5109.375 loss
epoch 28: 7240.625 loss
epoch 28: 5371.875 loss
epoch 28: 9903.125 loss
epoch 28: 4731.25 loss
epoch 28: 3759.375 loss
epoch 28: 328.125 loss
epoch 28: 2853.125 loss
epoch 28: 6115.625 loss
epoch 28: 2875.0 loss
epoch 28: 4512.5 loss
epoch 28: 5553.125 loss
epoch 28: 9256.25 loss
epoch 28: 4590.625 loss
epoch 28: 6025.0 loss
epoch 28: 2450.0 loss
epoch 28: 0.0 loss
epoch 28: 1793.75 loss
epoch 28: 1278.125 loss
epoch 28: 4653.125 loss
epoch 28: 3887.5 loss
epoch 28: 4409.375 loss
epoch 28: 3640.625 loss
epoch 28: 4059.375 loss
epoch 28: 2559.375 loss
epoch 28: 2337.5 loss
epoch 28: 2896.875 loss
epoch 28: 5250.0 loss
epoch 28: 6453.125 loss
epoch 28: 12343.75 loss

epoch 30: 1284.375 loss
epoch 30: 5756.25 loss
epoch 30: 6356.25 loss
epoch 30: 3718.75 loss
epoch 30: 1356.25 loss
epoch 30: 7278.125 loss
epoch 30: 4953.125 loss
epoch 30: 9431.25 loss
epoch 30: 8496.875 loss
epoch 30: 2859.375 loss
epoch 30: 3993.75 loss
epoch 30: 6846.875 loss
epoch 30: 6103.125 loss
epoch 30: 3478.125 loss
epoch 30: 10025.0 loss
epoch 30: 6178.125 loss
epoch 30: 6309.375 loss
epoch 30: 1909.375 loss
epoch 30: 2618.75 loss
epoch 30: 3143.75 loss
epoch 30: 9809.375 loss
epoch 30: 4690.625 loss
epoch 30: 1375.0 loss
epoch 30: 4053.125 loss
epoch 30: 1571.875 loss
epoch 30: 3934.375 loss
epoch 30: 1490.625 loss
epoch 30: 7821.875 loss
epoch 30: 3975.0 loss
epoch 30: 9784.375 loss
epoch 30: 2437.5 loss
epoch 30: 9059.375 loss
epoch 30: 4915.625 loss
epoch 30: 3062.5 loss
epoch 30: 3009.375 loss
epoch 30: 1643.75 loss
epoch 30: 1303.125 loss
epoch 30: 1425.0 loss
epoch 30: 5071.875 loss
epoch 30: 4065.625 loss
epoch 30: 8809.375 loss
epoch 30: 4075.0 loss
epoch 30: 3975

epoch 33: 5509.375 loss
epoch 33: 3878.125 loss
epoch 33: 4784.375 loss
epoch 33: 1656.25 loss
epoch 33: 7528.125 loss
epoch 33: 2225.0 loss
epoch 33: 9018.75 loss
epoch 33: 8409.375 loss
epoch 33: 5534.375 loss
epoch 33: 10465.625 loss
epoch 33: 2890.625 loss
epoch 33: 7315.625 loss
epoch 33: 5068.75 loss
epoch 33: 893.75 loss
epoch 33: 6478.125 loss
epoch 33: 6662.5 loss
epoch 33: 3578.125 loss
epoch 33: 5656.25 loss
epoch 33: 4134.375 loss
epoch 33: 3296.875 loss
epoch 33: 4484.375 loss
epoch 33: 6434.375 loss
epoch 33: 4715.625 loss
epoch 33: 5868.75 loss
epoch 33: 3890.625 loss
epoch 33: 5596.875 loss
epoch 33: 6581.25 loss
epoch 33: 1937.5 loss
epoch 33: 5071.875 loss
epoch 33: 3590.625 loss
epoch 33: 10637.5 loss
epoch 33: 2193.75 loss
epoch 33: 3450.0 loss
epoch 33: 4328.125 loss
epoch 33: 3621.875 loss
epoch 33: 6975.0 loss
epoch 33: 5996.875 loss
epoch 33: 5437.5 loss
epoch 33: 5778.125 loss
epoch 33: 7328.125 loss
epoch 33: 5665.625 loss
epoch 33: 11837.5 loss
epoch 33: 3825

epoch 37: 5493.75 loss
epoch 37: 10000.0 loss
epoch 37: 2271.875 loss
epoch 37: 2556.25 loss
epoch 37: 3503.125 loss
epoch 37: 6284.375 loss
epoch 37: 6659.375 loss
epoch 37: 2821.875 loss
epoch 37: 4196.875 loss
epoch 37: 1418.75 loss
epoch 37: 4159.375 loss
epoch 37: 5737.5 loss
epoch 37: 2821.875 loss
epoch 37: 778.125 loss
epoch 37: 4071.875 loss
epoch 37: 4687.5 loss
epoch 37: 4950.0 loss
epoch 37: 4565.625 loss
epoch 37: 6778.125 loss
epoch 37: 6175.0 loss
epoch 37: 8325.0 loss
epoch 37: 5712.5 loss
epoch 37: 9265.625 loss
epoch 37: 8006.25 loss
epoch 37: 11312.5 loss
epoch 37: 4446.875 loss
epoch 37: 5287.5 loss
epoch 37: 11187.5 loss
epoch 37: 5565.625 loss
epoch 37: 1696.875 loss
epoch 37: 6509.375 loss
epoch 37: 2890.625 loss
epoch 37: 2625.0 loss
epoch 37: 8562.5 loss
epoch 37: 4562.5 loss
epoch 37: 8781.25 loss
epoch 37: 4771.875 loss
epoch 37: 5937.5 loss
epoch 37: 4309.375 loss
epoch 37: 3271.875 loss
epoch 37: 3675.0 loss
epoch 37: 7865.625 loss
epoch 37: 4518.75 loss
ep

epoch 39: 5021.875 loss
epoch 39: 5084.375 loss
epoch 39: 1253.125 loss
epoch 39: 3043.75 loss
epoch 39: 7043.75 loss
epoch 39: 6096.875 loss
epoch 39: 3890.625 loss
epoch 39: 2515.625 loss
epoch 39: 4412.5 loss
epoch 39: 2893.75 loss
epoch 39: 8503.125 loss
epoch 39: 3787.5 loss
epoch 39: 2896.875 loss
epoch 39: 3581.25 loss
epoch 39: 6546.875 loss
epoch 39: 4346.875 loss
epoch 39: 0.0 loss
epoch 40: 4893.75 loss
epoch 40: 5875.0 loss
epoch 40: 3600.0 loss
epoch 40: 8387.5 loss
epoch 40: 4037.5 loss
epoch 40: 8356.25 loss
epoch 40: 3865.625 loss
epoch 40: 3046.875 loss
epoch 40: 3403.125 loss
epoch 40: 2134.375 loss
epoch 40: 5796.875 loss
epoch 40: 5421.875 loss
epoch 40: 3534.375 loss
epoch 40: 1743.75 loss
epoch 40: 3603.125 loss
epoch 40: 2196.875 loss
epoch 40: 4828.125 loss
epoch 40: 675.0 loss
epoch 40: 4290.625 loss
epoch 40: 10184.375 loss
epoch 40: 3396.875 loss
epoch 40: 4281.25 loss
epoch 40: 6606.25 loss
epoch 40: 5712.5 loss
epoch 40: 4303.125 loss
epoch 40: 15546.875 lo

epoch 42: 2865.625 loss
epoch 42: 5515.625 loss
epoch 42: 12487.5 loss
epoch 42: 5146.875 loss
epoch 42: 4690.625 loss
epoch 42: 3509.375 loss
epoch 42: 7862.5 loss
epoch 42: 4243.75 loss
epoch 42: 7400.0 loss
epoch 42: 10384.375 loss
epoch 42: 0.0 loss
epoch 43: 5156.25 loss
epoch 43: 4440.625 loss
epoch 43: 5415.625 loss
epoch 43: 2165.625 loss
epoch 43: 4590.625 loss
epoch 43: 3040.625 loss
epoch 43: 5309.375 loss
epoch 43: 3784.375 loss
epoch 43: 6425.0 loss
epoch 43: 4100.0 loss
epoch 43: 4400.0 loss
epoch 43: 6087.5 loss
epoch 43: 6325.0 loss
epoch 43: 9618.75 loss
epoch 43: 5156.25 loss
epoch 43: 806.25 loss
epoch 43: 6703.125 loss
epoch 43: 7415.625 loss
epoch 43: 1387.5 loss
epoch 43: 4743.75 loss
epoch 43: 6146.875 loss
epoch 43: 4178.125 loss
epoch 43: 6259.375 loss
epoch 43: 2753.125 loss
epoch 43: 6818.75 loss
epoch 43: 5434.375 loss
epoch 43: 2528.125 loss
epoch 43: 8925.0 loss
epoch 43: 6343.75 loss
epoch 43: 6375.0 loss
epoch 43: 2137.5 loss
epoch 43: 5215.625 loss
epoc

epoch 46: 10215.625 loss
epoch 46: 7446.875 loss
epoch 46: 5659.375 loss
epoch 46: 7068.75 loss
epoch 46: 4006.25 loss
epoch 46: 2850.0 loss
epoch 46: 9146.875 loss
epoch 46: 3234.375 loss
epoch 46: 4746.875 loss
epoch 46: 1587.5 loss
epoch 46: 3503.125 loss
epoch 46: 5078.125 loss
epoch 46: 1159.375 loss
epoch 46: 7206.25 loss
epoch 46: 12168.75 loss
epoch 46: 2825.0 loss
epoch 46: 4759.375 loss
epoch 46: 8071.875 loss
epoch 46: 7368.75 loss
epoch 46: 7878.125 loss
epoch 46: 4246.875 loss
epoch 46: 6665.625 loss
epoch 46: 9106.25 loss
epoch 46: 5443.75 loss
epoch 46: 5890.625 loss
epoch 46: 9165.625 loss
epoch 46: 5218.75 loss
epoch 46: 5078.125 loss
epoch 46: 5471.875 loss
epoch 46: 7646.875 loss
epoch 46: 9718.75 loss
epoch 46: 6512.5 loss
epoch 46: 4293.75 loss
epoch 46: 3709.375 loss
epoch 46: 3531.25 loss
epoch 46: 3053.125 loss
epoch 46: 3615.625 loss
epoch 46: 990.625 loss
epoch 46: 7587.5 loss
epoch 46: 5406.25 loss
epoch 46: 4318.75 loss
epoch 46: 2887.5 loss
epoch 46: 4553.1

epoch 49: 3368.75 loss
epoch 49: 2018.75 loss
epoch 49: 3384.375 loss
epoch 49: 4334.375 loss
epoch 49: 331.25 loss
epoch 49: 4778.125 loss
epoch 49: 8609.375 loss
epoch 49: 2378.125 loss
epoch 49: 11534.375 loss
epoch 49: 1103.125 loss
epoch 49: 2346.875 loss
epoch 49: 5496.875 loss
epoch 49: 3756.25 loss
epoch 49: 725.0 loss
epoch 49: 10396.875 loss
epoch 49: 4871.875 loss
epoch 49: 6921.875 loss
epoch 49: 8150.0 loss
epoch 49: 7168.75 loss
epoch 49: 3603.125 loss
epoch 49: 9090.625 loss
epoch 49: 5825.0 loss
epoch 49: 375.0 loss
epoch 49: 3637.5 loss
epoch 49: 9771.875 loss
epoch 49: 5925.0 loss
epoch 49: 4793.75 loss
epoch 49: 2021.875 loss
epoch 49: 6178.125 loss
epoch 49: 6956.25 loss
epoch 49: 1956.25 loss
epoch 49: 7868.75 loss
epoch 49: 2259.375 loss
epoch 49: 0.0 loss
epoch 50: 3137.5 loss
epoch 50: 6559.375 loss
epoch 50: 8206.25 loss
epoch 50: 5690.625 loss
epoch 50: 7168.75 loss
epoch 50: 3971.875 loss
epoch 50: 7625.0 loss
epoch 50: 6571.875 loss
epoch 50: 8634.375 loss
e

epoch 51: 4003.125 loss
epoch 51: 6690.625 loss
epoch 51: 5078.125 loss
epoch 51: 0.0 loss
epoch 52: 7100.0 loss
epoch 52: 11215.625 loss
epoch 52: 2909.375 loss
epoch 52: 5662.5 loss
epoch 52: 8509.375 loss
epoch 52: 4634.375 loss
epoch 52: 4168.75 loss
epoch 52: 5437.5 loss
epoch 52: 9962.5 loss
epoch 52: 7475.0 loss
epoch 52: 5918.75 loss
epoch 52: 10487.5 loss
epoch 52: 4934.375 loss
epoch 52: 4534.375 loss
epoch 52: 6903.125 loss
epoch 52: 8165.625 loss
epoch 52: 3009.375 loss
epoch 52: 11775.0 loss
epoch 52: 3921.875 loss
epoch 52: 3296.875 loss
epoch 52: 8259.375 loss
epoch 52: 8334.375 loss
epoch 52: 8368.75 loss
epoch 52: 4906.25 loss
epoch 52: 5671.875 loss
epoch 52: 13140.625 loss
epoch 52: 5625.0 loss
epoch 52: 3356.25 loss
epoch 52: 9287.5 loss
epoch 52: 2640.625 loss
epoch 52: 5596.875 loss
epoch 52: 6521.875 loss
epoch 52: 5596.875 loss
epoch 52: 6437.5 loss
epoch 52: 6043.75 loss
epoch 52: 4584.375 loss
epoch 52: 5756.25 loss
epoch 52: 2928.125 loss
epoch 52: 3062.5 los

epoch 54: 7153.125 loss
epoch 54: 2203.125 loss
epoch 54: 6043.75 loss
epoch 54: 4609.375 loss
epoch 54: 9050.0 loss
epoch 54: 5521.875 loss
epoch 54: 11071.875 loss
epoch 54: 6025.0 loss
epoch 54: 6859.375 loss
epoch 54: 3281.25 loss
epoch 54: 6184.375 loss
epoch 54: 2853.125 loss
epoch 54: 6753.125 loss
epoch 54: 4234.375 loss
epoch 54: 4750.0 loss
epoch 54: 968.75 loss
epoch 54: 2318.75 loss
epoch 54: 2771.875 loss
epoch 54: 5340.625 loss
epoch 54: 9768.75 loss
epoch 54: 5400.0 loss
epoch 54: 3115.625 loss
epoch 54: 12537.5 loss
epoch 54: 1353.125 loss
epoch 54: 4534.375 loss
epoch 54: 8428.125 loss
epoch 54: 1671.875 loss
epoch 54: 4250.0 loss
epoch 54: 750.0 loss
epoch 54: 4493.75 loss
epoch 54: 12525.0 loss
epoch 54: 3928.125 loss
epoch 54: 3365.625 loss
epoch 54: 7037.5 loss
epoch 54: 878.125 loss
epoch 54: 968.75 loss
epoch 54: 2434.375 loss
epoch 54: 8012.5 loss
epoch 54: 4631.25 loss
epoch 54: 3503.125 loss
epoch 54: 8756.25 loss
epoch 54: 4587.5 loss
epoch 54: 5700.0 loss
ep

epoch 58: 6796.875 loss
epoch 58: 8828.125 loss
epoch 58: 10546.875 loss
epoch 58: 8315.625 loss
epoch 58: 7543.75 loss
epoch 58: 6246.875 loss
epoch 58: 8359.375 loss
epoch 58: 3909.375 loss
epoch 58: 8915.625 loss
epoch 58: 3375.0 loss
epoch 58: 2731.25 loss
epoch 58: 4850.0 loss
epoch 58: 10196.875 loss
epoch 58: 7106.25 loss
epoch 58: 4996.875 loss
epoch 58: 4396.875 loss
epoch 58: 1240.625 loss
epoch 58: 5315.625 loss
epoch 58: 4343.75 loss
epoch 58: 8187.5 loss
epoch 58: 9056.25 loss
epoch 58: 4531.25 loss
epoch 58: 4096.875 loss
epoch 58: 7303.125 loss
epoch 58: 2218.75 loss
epoch 58: 4168.75 loss
epoch 58: 7881.25 loss
epoch 58: 1975.0 loss
epoch 58: 721.875 loss
epoch 58: 5412.5 loss
epoch 58: 3190.625 loss
epoch 58: 9481.25 loss
epoch 58: 4656.25 loss
epoch 58: 2109.375 loss
epoch 58: 2509.375 loss
epoch 58: 4653.125 loss
epoch 58: 3003.125 loss
epoch 58: 4156.25 loss
epoch 58: 2559.375 loss
epoch 58: 4459.375 loss
epoch 58: 4468.75 loss
epoch 58: 5659.375 loss
epoch 58: 7146

epoch 60: 6581.25 loss
epoch 60: 9184.375 loss
epoch 60: 7243.75 loss
epoch 60: 4443.75 loss
epoch 60: 7118.75 loss
epoch 60: 3181.25 loss
epoch 60: 7153.125 loss
epoch 60: 3893.75 loss
epoch 60: 2062.5 loss
epoch 60: 5887.5 loss
epoch 60: 4737.5 loss
epoch 60: 0.0 loss
epoch 60: 1975.0 loss
epoch 60: 2618.75 loss
epoch 60: 9278.125 loss
epoch 60: 5440.625 loss
epoch 60: 6090.625 loss
epoch 60: 5371.875 loss
epoch 60: 8293.75 loss
epoch 60: 4371.875 loss
epoch 60: 4396.875 loss
epoch 60: 4462.5 loss
epoch 60: 6403.125 loss
epoch 60: 3296.875 loss
epoch 60: 9115.625 loss
epoch 60: 3259.375 loss
epoch 60: 993.75 loss
epoch 60: 6462.5 loss
epoch 60: 9578.125 loss
epoch 60: 8250.0 loss
epoch 60: 2609.375 loss
epoch 60: 8237.5 loss
epoch 60: 5759.375 loss
epoch 60: 4687.5 loss
epoch 60: 3293.75 loss
epoch 60: 1653.125 loss
epoch 60: 3668.75 loss
epoch 60: 9315.625 loss
epoch 60: 981.25 loss
epoch 60: 4006.25 loss
epoch 60: 6615.625 loss
epoch 60: 7512.5 loss
epoch 60: 3112.5 loss
epoch 60: 

epoch 62: 3825.0 loss
epoch 62: 2084.375 loss
epoch 62: 8606.25 loss
epoch 62: 3500.0 loss
epoch 62: 6437.5 loss
epoch 62: 562.5 loss
epoch 62: 953.125 loss
epoch 62: 4915.625 loss
epoch 62: 909.375 loss
epoch 62: 3693.75 loss
epoch 62: 2837.5 loss
epoch 62: 1400.0 loss
epoch 62: 5656.25 loss
epoch 62: 7543.75 loss
epoch 62: 2496.875 loss
epoch 62: 4843.75 loss
epoch 62: 3843.75 loss
epoch 62: 4100.0 loss
epoch 62: 553.125 loss
epoch 62: 6403.125 loss
epoch 62: 9865.625 loss
epoch 62: 6156.25 loss
epoch 62: 0.0 loss
epoch 63: 4134.375 loss
epoch 63: 5131.25 loss
epoch 63: 5334.375 loss
epoch 63: 7703.125 loss
epoch 63: 4246.875 loss
epoch 63: 9031.25 loss
epoch 63: 1337.5 loss
epoch 63: 3793.75 loss
epoch 63: 7343.75 loss
epoch 63: 4259.375 loss
epoch 63: 2875.0 loss
epoch 63: 3003.125 loss
epoch 63: 6012.5 loss
epoch 63: 13106.25 loss
epoch 63: 3303.125 loss
epoch 63: 3453.125 loss
epoch 63: 10496.875 loss
epoch 63: 5528.125 loss
epoch 63: 3815.625 loss
epoch 63: 2643.75 loss
epoch 63

epoch 66: 3737.5 loss
epoch 66: 5112.5 loss
epoch 66: 1434.375 loss
epoch 66: 721.875 loss
epoch 66: 5418.75 loss
epoch 66: 2196.875 loss
epoch 66: 6596.875 loss
epoch 66: 9215.625 loss
epoch 66: 2187.5 loss
epoch 66: 6050.0 loss
epoch 66: 0.0 loss
epoch 66: 3690.625 loss
epoch 66: 4375.0 loss
epoch 66: 6865.625 loss
epoch 66: 7125.0 loss
epoch 66: 5834.375 loss
epoch 66: 775.0 loss
epoch 66: 12743.75 loss
epoch 66: 3931.25 loss
epoch 66: 6262.5 loss
epoch 66: 2681.25 loss
epoch 66: 5228.125 loss
epoch 66: 5990.625 loss
epoch 66: 7581.25 loss
epoch 66: 6528.125 loss
epoch 66: 4006.25 loss
epoch 66: 4334.375 loss
epoch 66: 3175.0 loss
epoch 66: 2418.75 loss
epoch 66: 1890.625 loss
epoch 66: 3153.125 loss
epoch 66: 5165.625 loss
epoch 66: 5062.5 loss
epoch 66: 6381.25 loss
epoch 66: 6643.75 loss
epoch 66: 6437.5 loss
epoch 66: 5746.875 loss
epoch 66: 4275.0 loss
epoch 66: 13537.5 loss
epoch 66: 5578.125 loss
epoch 66: 1621.875 loss
epoch 66: 6253.125 loss
epoch 66: 5078.125 loss
epoch 66

epoch 68: 1484.375 loss
epoch 68: 4190.625 loss
epoch 68: 8578.125 loss
epoch 68: 5821.875 loss
epoch 68: 10115.625 loss
epoch 68: 4812.5 loss
epoch 68: 4384.375 loss
epoch 68: 6778.125 loss
epoch 68: 3565.625 loss
epoch 68: 3521.875 loss
epoch 68: 1378.125 loss
epoch 68: 5853.125 loss
epoch 68: 5375.0 loss
epoch 68: 6668.75 loss
epoch 68: 8031.25 loss
epoch 68: 7296.875 loss
epoch 68: 7437.5 loss
epoch 68: 5150.0 loss
epoch 68: 4581.25 loss
epoch 68: 6371.875 loss
epoch 68: 6706.25 loss
epoch 68: 4503.125 loss
epoch 68: 2903.125 loss
epoch 68: 3375.0 loss
epoch 68: 5237.5 loss
epoch 68: 9781.25 loss
epoch 68: 3884.375 loss
epoch 68: 7584.375 loss
epoch 68: 10015.625 loss
epoch 68: 6881.25 loss
epoch 68: 9312.5 loss
epoch 68: 3081.25 loss
epoch 68: 312.5 loss
epoch 68: 3037.5 loss
epoch 68: 1078.125 loss
epoch 68: 8771.875 loss
epoch 68: 1518.75 loss
epoch 68: 4637.5 loss
epoch 68: 0.0 loss
epoch 69: 2834.375 loss
epoch 69: 2603.125 loss
epoch 69: 2506.25 loss
epoch 69: 2828.125 loss
e

epoch 72: 10118.75 loss
epoch 72: 4171.875 loss
epoch 72: 4915.625 loss
epoch 72: 5175.0 loss
epoch 72: 4193.75 loss
epoch 72: 5150.0 loss
epoch 72: 3443.75 loss
epoch 72: 3021.875 loss
epoch 72: 5550.0 loss
epoch 72: 4850.0 loss
epoch 72: 553.125 loss
epoch 72: 2815.625 loss
epoch 72: 7975.0 loss
epoch 72: 7365.625 loss
epoch 72: 7793.75 loss
epoch 72: 1843.75 loss
epoch 72: 10356.25 loss
epoch 72: 6328.125 loss
epoch 72: 4162.5 loss
epoch 72: 9287.5 loss
epoch 72: 2090.625 loss
epoch 72: 5781.25 loss
epoch 72: 2765.625 loss
epoch 72: 6271.875 loss
epoch 72: 3875.0 loss
epoch 72: 6781.25 loss
epoch 72: 1778.125 loss
epoch 72: 3578.125 loss
epoch 72: 71.875 loss
epoch 72: 8046.875 loss
epoch 72: 3734.375 loss
epoch 72: 3687.5 loss
epoch 72: 4046.875 loss
epoch 72: 15568.75 loss
epoch 72: 4656.25 loss
epoch 72: 8646.875 loss
epoch 72: 4953.125 loss
epoch 72: 4850.0 loss
epoch 72: 8396.875 loss
epoch 72: 4915.625 loss
epoch 72: 7512.5 loss
epoch 72: 9534.375 loss
epoch 72: 11678.125 loss

epoch 75: 2378.125 loss
epoch 75: 3846.875 loss
epoch 75: 8718.75 loss
epoch 75: 2540.625 loss
epoch 75: 3850.0 loss
epoch 75: 12290.625 loss
epoch 75: 4443.75 loss
epoch 75: 5146.875 loss
epoch 75: 9621.875 loss
epoch 75: 6803.125 loss
epoch 75: 4771.875 loss
epoch 75: 1450.0 loss
epoch 75: 3621.875 loss
epoch 75: 11065.625 loss
epoch 75: 6806.25 loss
epoch 75: 8190.625 loss
epoch 75: 5112.5 loss
epoch 75: 4000.0 loss
epoch 75: 9053.125 loss
epoch 75: 4490.625 loss
epoch 75: 7037.5 loss
epoch 75: 6668.75 loss
epoch 75: 6987.5 loss
epoch 75: 6303.125 loss
epoch 75: 1875.0 loss
epoch 75: 5943.75 loss
epoch 75: 7693.75 loss
epoch 75: 1681.25 loss
epoch 75: 6693.75 loss
epoch 75: 2840.625 loss
epoch 75: 5478.125 loss
epoch 75: 6346.875 loss
epoch 75: 9237.5 loss
epoch 75: 6140.625 loss
epoch 75: 8331.25 loss
epoch 75: 6565.625 loss
epoch 75: 2687.5 loss
epoch 75: 5181.25 loss
epoch 75: 1865.625 loss
epoch 75: 5609.375 loss
epoch 75: 3884.375 loss
epoch 75: 4759.375 loss
epoch 75: 5096.875

epoch 80: 3484.375 loss
epoch 80: 3403.125 loss
epoch 80: 12478.125 loss
epoch 80: 7434.375 loss
epoch 80: 4971.875 loss
epoch 80: 4850.0 loss
epoch 80: 4731.25 loss
epoch 80: 4631.25 loss
epoch 80: 5168.75 loss
epoch 80: 9681.25 loss
epoch 80: 8459.375 loss
epoch 80: 8075.0 loss
epoch 80: 8918.75 loss
epoch 80: 3925.0 loss
epoch 80: 6546.875 loss
epoch 80: 7537.5 loss
epoch 80: 3650.0 loss
epoch 80: 2465.625 loss
epoch 80: 5328.125 loss
epoch 80: 6859.375 loss
epoch 80: 8346.875 loss
epoch 80: 7578.125 loss
epoch 80: 5537.5 loss
epoch 80: 3253.125 loss
epoch 80: 4128.125 loss
epoch 80: 6546.875 loss
epoch 80: 1612.5 loss
epoch 80: 5643.75 loss
epoch 80: 2975.0 loss
epoch 80: 4884.375 loss
epoch 80: 3015.625 loss
epoch 80: 4456.25 loss
epoch 80: 6490.625 loss
epoch 80: 4125.0 loss
epoch 80: 7734.375 loss
epoch 80: 3300.0 loss
epoch 80: 5425.0 loss
epoch 80: 4915.625 loss
epoch 80: 5090.625 loss
epoch 80: 5759.375 loss
epoch 80: 8200.0 loss
epoch 80: 1675.0 loss
epoch 80: 2365.625 loss


epoch 84: 5459.375 loss
epoch 84: 3293.75 loss
epoch 84: 3875.0 loss
epoch 84: 4887.5 loss
epoch 84: 6168.75 loss
epoch 84: 5284.375 loss
epoch 84: 5406.25 loss
epoch 84: 2946.875 loss
epoch 84: 3281.25 loss
epoch 84: 5640.625 loss
epoch 84: 2465.625 loss
epoch 84: 5512.5 loss
epoch 84: 5750.0 loss
epoch 84: 3403.125 loss
epoch 84: 11103.125 loss
epoch 84: 7868.75 loss
epoch 84: 4796.875 loss
epoch 84: 5465.625 loss
epoch 84: 7165.625 loss
epoch 84: 4406.25 loss
epoch 84: 2059.375 loss
epoch 84: 7643.75 loss
epoch 84: 5381.25 loss
epoch 84: 3384.375 loss
epoch 84: 5009.375 loss
epoch 84: 5246.875 loss
epoch 84: 9296.875 loss
epoch 84: 7362.5 loss
epoch 84: 4603.125 loss
epoch 84: 3800.0 loss
epoch 84: 2128.125 loss
epoch 84: 2856.25 loss
epoch 84: 6778.125 loss
epoch 84: 3031.25 loss
epoch 84: 2856.25 loss
epoch 84: 8428.125 loss
epoch 84: 4343.75 loss
epoch 84: 8390.625 loss
epoch 84: 8340.625 loss
epoch 84: 9418.75 loss
epoch 84: 10953.125 loss
epoch 84: 7275.0 loss
epoch 84: 2153.12

epoch 89: 1756.25 loss
epoch 89: 5253.125 loss
epoch 89: 2521.875 loss
epoch 89: 2631.25 loss
epoch 89: 7259.375 loss
epoch 89: 4900.0 loss
epoch 89: 4412.5 loss
epoch 89: 8428.125 loss
epoch 89: 6496.875 loss
epoch 89: 6400.0 loss
epoch 89: 4700.0 loss
epoch 89: 3859.375 loss
epoch 89: 5028.125 loss
epoch 89: 2059.375 loss
epoch 89: 2240.625 loss
epoch 89: 5400.0 loss
epoch 89: 8087.5 loss
epoch 89: 8587.5 loss
epoch 89: 8100.0 loss
epoch 89: 2481.25 loss
epoch 89: 7146.875 loss
epoch 89: 6603.125 loss
epoch 89: 2671.875 loss
epoch 89: 3646.875 loss
epoch 89: 6368.75 loss
epoch 89: 1443.75 loss
epoch 89: 4725.0 loss
epoch 89: 6693.75 loss
epoch 89: 12375.0 loss
epoch 89: 6984.375 loss
epoch 89: 4903.125 loss
epoch 89: 5131.25 loss
epoch 89: 9343.75 loss
epoch 89: 9456.25 loss
epoch 89: 3740.625 loss
epoch 89: 5143.75 loss
epoch 89: 5546.875 loss
epoch 89: 7525.0 loss
epoch 89: 1337.5 loss
epoch 89: 6721.875 loss
epoch 89: 4275.0 loss
epoch 89: 7087.5 loss
epoch 89: 11909.375 loss
epoc

epoch 91: 3671.875 loss
epoch 91: 7993.75 loss
epoch 91: 5790.625 loss
epoch 91: 6575.0 loss
epoch 91: 5487.5 loss
epoch 91: 3146.875 loss
epoch 91: 3968.75 loss
epoch 91: 7418.75 loss
epoch 91: 4215.625 loss
epoch 91: 3853.125 loss
epoch 91: 2590.625 loss
epoch 91: 3106.25 loss
epoch 91: 7753.125 loss
epoch 91: 7428.125 loss
epoch 91: 5571.875 loss
epoch 91: 4390.625 loss
epoch 91: 4181.25 loss
epoch 91: 5368.75 loss
epoch 91: 1009.375 loss
epoch 91: 4531.25 loss
epoch 91: 7856.25 loss
epoch 91: 7034.375 loss
epoch 91: 5640.625 loss
epoch 91: 4365.625 loss
epoch 91: 3756.25 loss
epoch 91: 3887.5 loss
epoch 91: 2625.0 loss
epoch 91: 8415.625 loss
epoch 91: 3984.375 loss
epoch 91: 8631.25 loss
epoch 91: 1921.875 loss
epoch 91: 0.0 loss
epoch 92: 3840.625 loss
epoch 92: 3293.75 loss
epoch 92: 1218.75 loss
epoch 92: 4740.625 loss
epoch 92: 4959.375 loss
epoch 92: 4106.25 loss
epoch 92: 8118.75 loss
epoch 92: 1018.75 loss
epoch 92: 6015.625 loss
epoch 92: 4484.375 loss
epoch 92: 1596.875 l

epoch 93: 10018.75 loss
epoch 93: 5028.125 loss
epoch 93: 5800.0 loss
epoch 93: 13756.25 loss
epoch 93: 9559.375 loss
epoch 93: 6896.875 loss
epoch 93: 375.0 loss
epoch 93: 7587.5 loss
epoch 93: 5237.5 loss
epoch 93: 3834.375 loss
epoch 93: 5390.625 loss
epoch 93: 4418.75 loss
epoch 93: 3078.125 loss
epoch 93: 5931.25 loss
epoch 93: 3934.375 loss
epoch 93: 9384.375 loss
epoch 93: 7556.25 loss
epoch 93: 962.5 loss
epoch 93: 1109.375 loss
epoch 93: 3175.0 loss
epoch 93: 3534.375 loss
epoch 93: 0.0 loss
epoch 94: 6287.5 loss
epoch 94: 6809.375 loss
epoch 94: 2409.375 loss
epoch 94: 10450.0 loss
epoch 94: 4281.25 loss
epoch 94: 7193.75 loss
epoch 94: 10178.125 loss
epoch 94: 5253.125 loss
epoch 94: 662.5 loss
epoch 94: 2590.625 loss
epoch 94: 3256.25 loss
epoch 94: 7128.125 loss
epoch 94: 12981.25 loss
epoch 94: 8965.625 loss
epoch 94: 2000.0 loss
epoch 94: 950.0 loss
epoch 94: 8290.625 loss
epoch 94: 3693.75 loss
epoch 94: 3378.125 loss
epoch 94: 4540.625 loss
epoch 94: 993.75 loss
epoch 

epoch 97: 8490.625 loss
epoch 97: 4190.625 loss
epoch 97: 7525.0 loss
epoch 97: 6031.25 loss
epoch 97: 4525.0 loss
epoch 97: 6853.125 loss
epoch 97: 5943.75 loss
epoch 97: 4125.0 loss
epoch 97: 8571.875 loss
epoch 97: 2134.375 loss
epoch 97: 6668.75 loss
epoch 97: 6831.25 loss
epoch 97: 6237.5 loss
epoch 97: 3956.25 loss
epoch 97: 7771.875 loss
epoch 97: 3175.0 loss
epoch 97: 6834.375 loss
epoch 97: 3131.25 loss
epoch 97: 5078.125 loss
epoch 97: 4968.75 loss
epoch 97: 3950.0 loss
epoch 97: 5878.125 loss
epoch 97: 10059.375 loss
epoch 97: 1940.625 loss
epoch 97: 6584.375 loss
epoch 97: 7818.75 loss
epoch 97: 318.75 loss
epoch 97: 2171.875 loss
epoch 97: 4909.375 loss
epoch 97: 7328.125 loss
epoch 97: 6918.75 loss
epoch 97: 2400.0 loss
epoch 97: 3550.0 loss
epoch 97: 6759.375 loss
epoch 97: 5596.875 loss
epoch 97: 8046.875 loss
epoch 97: 10915.625 loss
epoch 97: 5284.375 loss
epoch 97: 9200.0 loss
epoch 97: 2918.75 loss
epoch 97: 3078.125 loss
epoch 97: 5731.25 loss
epoch 97: 4150.0 loss

In [9]:
# evaluate the model
predictions, actuals, acc = evaluate_model(test_dl, model)
print('Accuracy: %.3f' % acc)
print(predictions)
print(actuals)

Accuracy: 0.862
[[0.]
 [0.]
 [0.]
 ...
 [0.]
 [0.]
 [0.]]
[[0.]
 [0.]
 [0.]
 ...
 [0.]
 [0.]
 [0.]]


In [10]:
torch.save(model.state_dict(), '../models/neural_networks.pt')
