In [1]:
from numpy import vstack
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch import Tensor
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Sigmoid
from torch.nn import Module
from torch.optim import SGD
from torch.nn import BCELoss
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# dataset definition
class CSVDataset(Dataset):
    # load the dataset
    def __init__(self, path):
        # load the csv file as a dataframe
        df = read_csv(path)
        # store the inputs and outputs
        self.X = df.values[:, :-1]
        self.y = df.values[:, -1]
        # ensure input data is floats
        self.X = self.X.astype('float32')
        # label encode target and ensure the values are floats
        self.y = LabelEncoder().fit_transform(self.y)
        self.y = self.y.astype('float32')
        self.y = self.y.reshape((len(self.y), 1))

    # number of rows in the dataset
    def __len__(self):
        return len(self.X)

    # get a row at an index
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    # get indexes for train and test rows
    def get_splits(self, n_test=0.33):
        # determine sizes
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        # calculate the split
        return random_split(self, [train_size, test_size])

In [3]:
# model definition
class MLP(Module):
    # define model elements
    def __init__(self, n_inputs):
        super(MLP, self).__init__()
        # input to first hidden layer
        self.hidden1 = Linear(n_inputs, 10)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer
        self.hidden2 = Linear(10, 8)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # third hidden layer and output
        self.hidden3 = Linear(8, 1)
        xavier_uniform_(self.hidden3.weight)
        self.act3 = Sigmoid()

    # forward propagate input
    def forward(self, X):
        # input to first hidden layer
        X = self.hidden1(X)
        X = self.act1(X)
         # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # third hidden layer and output
        X = self.hidden3(X)
        X = self.act3(X)
        return X

In [4]:
# prepare the dataset
def prepare_data(train_path,test_path = None):
    if test_path == None:
        dataset = CSVDataset(train_path)
        train, test = dataset.get_splits()
    else:
        train = CSVDataset(train_path)
        test = CSVDataset(test_path)

    train_dl = DataLoader(train, batch_size=32, shuffle=True)
    test_dl = DataLoader(test, batch_size=1024, shuffle=False)
    return train_dl, test_dl

In [5]:
# train the model
def train_model(train_dl, model):
    # define the optimization
    criterion = BCELoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # enumerate epochs
    for epoch in range(100):
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, targets)
            print(f'epoch {epoch}: {loss} loss')
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()

In [6]:
# evaluate the model
def evaluate_model(test_dl, model):
    predictions, actuals = list(), list()
    for i, (inputs, targets) in enumerate(test_dl):
        # evaluate the model on the test set
        yhat = model(inputs)
        # retrieve numpy array
        yhat = yhat.detach().numpy()
        actual = targets.numpy()
        actual = actual.reshape((len(actual), 1))
        # round to class values
        yhat = yhat.round()
        # store
        predictions.append(yhat)
        actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)

    # calculate accuracy
    acc = accuracy_score(actuals, predictions)
    return predictions, actuals, acc

# make a class prediction for one row of data
def predict(row, model):
    # convert row to data
    row = Tensor([row])
    # make prediction
    yhat = model(row)
    # retrieve numpy array
    yhat = yhat.detach().numpy()
    return yhat

In [7]:
# load data

train_path = '../data/5_train_dataset.csv'
test_path = '../data/4_test_dataset.csv'

train_dl, test_dl = prepare_data(train_path, test_path)
print(len(train_dl.dataset), len(test_dl.dataset))

n_features = len(train_dl.dataset.X[0])
n_features

5889 3927


21

In [8]:
# define the network
model = MLP(n_features)
# train the model
train_model(train_dl, model)


epoch 0: 9880.1298828125 loss
epoch 0: -3449.89404296875 loss
epoch 0: -7952.8427734375 loss
epoch 0: -2033.5908203125 loss
epoch 0: -5324.779296875 loss
epoch 0: -4074.63525390625 loss
epoch 0: -5005.625 loss
epoch 0: -3214.55859375 loss
epoch 0: -6640.44287109375 loss
epoch 0: -6812.41015625 loss
epoch 0: -2746.55078125 loss
epoch 0: -9071.14453125 loss
epoch 0: -3502.80810546875 loss
epoch 0: -1274.9130859375 loss
epoch 0: -4714.494140625 loss
epoch 0: -5512.5 loss
epoch 0: -7599.97314453125 loss
epoch 0: -6962.5 loss
epoch 0: -2608.91162109375 loss
epoch 0: -5993.6181640625 loss
epoch 0: -11612.400390625 loss
epoch 0: -3774.743896484375 loss
epoch 0: -2312.45703125 loss
epoch 0: -7774.9765625 loss
epoch 0: -4571.83984375 loss
epoch 0: -6253.0927734375 loss
epoch 0: -715.566650390625 loss
epoch 0: -3175.0 loss
epoch 0: -5431.203125 loss
epoch 0: -5065.6044921875 loss
epoch 0: -9.316631317138672 loss
epoch 0: -5978.046875 loss
epoch 0: -2603.0859375 loss
epoch 0: -4609.29736328125 lo

epoch 3: 6603.1259765625 loss
epoch 3: 3065.62841796875 loss
epoch 3: 6050.009765625 loss
epoch 3: 7068.75244140625 loss
epoch 3: 3743.751708984375 loss
epoch 3: 4200.00048828125 loss
epoch 3: 9106.255859375 loss
epoch 3: 8781.2578125 loss
epoch 3: 6553.13427734375 loss
epoch 3: 6312.50048828125 loss
epoch 3: 8581.251953125 loss
epoch 3: 3018.75 loss
epoch 3: 12000.0 loss
epoch 4: 2071.884521484375 loss
epoch 4: 7600.001953125 loss
epoch 4: 2568.75048828125 loss
epoch 4: 6296.87646484375 loss
epoch 4: 3837.502197265625 loss
epoch 4: 0.0 loss
epoch 4: 3896.876953125 loss
epoch 4: 7321.87841796875 loss
epoch 4: 878.1260986328125 loss
epoch 4: 1378.1279296875 loss
epoch 4: 3068.75048828125 loss
epoch 4: 4693.751953125 loss
epoch 4: 6850.00341796875 loss
epoch 4: 5893.751953125 loss
epoch 4: 4865.625 loss
epoch 4: 2759.375 loss
epoch 4: 6718.751953125 loss
epoch 4: 246.87745666503906 loss
epoch 4: 6934.376953125 loss
epoch 4: 1768.7593994140625 loss
epoch 4: 4196.87890625 loss
epoch 4: 629

epoch 8: 9950.0 loss
epoch 8: 5590.62548828125 loss
epoch 8: 8859.376953125 loss
epoch 8: 2146.875244140625 loss
epoch 8: 1956.255126953125 loss
epoch 8: 4037.5 loss
epoch 8: 6159.37548828125 loss
epoch 8: 3781.253173828125 loss
epoch 8: 3415.62548828125 loss
epoch 8: 2590.62548828125 loss
epoch 8: 5146.875 loss
epoch 8: 5956.25048828125 loss
epoch 8: 8659.375 loss
epoch 8: 2881.25048828125 loss
epoch 8: 6284.37548828125 loss
epoch 8: 6796.87548828125 loss
epoch 8: 2746.87646484375 loss
epoch 8: 9831.25 loss
epoch 8: 1862.505615234375 loss
epoch 8: 10225.0009765625 loss
epoch 8: 2043.7501220703125 loss
epoch 8: 4059.375 loss
epoch 8: 6534.3759765625 loss
epoch 8: 2784.375244140625 loss
epoch 8: 11771.875 loss
epoch 8: 9262.5 loss
epoch 8: 3650.0 loss
epoch 8: 5050.0 loss
epoch 8: 1446.875732421875 loss
epoch 8: 1684.37548828125 loss
epoch 8: 3190.625 loss
epoch 8: 3606.25048828125 loss
epoch 8: 11890.625 loss
epoch 8: 5218.75048828125 loss
epoch 8: 2165.62646484375 loss
epoch 8: 6618.7

epoch 11: 0.0 loss
epoch 12: 6190.625 loss
epoch 12: 4381.25 loss
epoch 12: 3484.37548828125 loss
epoch 12: 1768.7501220703125 loss
epoch 12: 4912.5 loss
epoch 12: 8240.625 loss
epoch 12: 6168.75 loss
epoch 12: 10450.0 loss
epoch 12: 7275.0 loss
epoch 12: 5846.87548828125 loss
epoch 12: 6921.875 loss
epoch 12: 3725.0 loss
epoch 12: 8846.875 loss
epoch 12: 2693.75 loss
epoch 12: 3646.875 loss
epoch 12: 5109.375 loss
epoch 12: 4300.0 loss
epoch 12: 1443.75 loss
epoch 12: 2331.251953125 loss
epoch 12: 2040.6253662109375 loss
epoch 12: 2646.87548828125 loss
epoch 12: 2531.250244140625 loss
epoch 12: 3040.62548828125 loss
epoch 12: 3412.500244140625 loss
epoch 12: 5918.75 loss
epoch 12: 7356.25 loss
epoch 12: 5015.625 loss
epoch 12: 4687.5 loss
epoch 12: 1737.5009765625 loss
epoch 12: 2000.00048828125 loss
epoch 12: 4512.50048828125 loss
epoch 12: 3478.125 loss
epoch 12: 1806.250244140625 loss
epoch 12: 7865.62548828125 loss
epoch 12: 8715.625 loss
epoch 12: 6681.25 loss
epoch 12: 5787.5004

epoch 15: 7443.75 loss
epoch 15: 10362.5 loss
epoch 15: 5006.25 loss
epoch 15: 6665.625 loss
epoch 15: 8231.25 loss
epoch 15: 5234.375 loss
epoch 15: 6209.375 loss
epoch 15: 7362.5 loss
epoch 15: 8012.5 loss
epoch 15: 5043.75 loss
epoch 15: 5828.126953125 loss
epoch 15: 6990.626953125 loss
epoch 15: 10484.375 loss
epoch 15: 2375.0 loss
epoch 15: 6240.625 loss
epoch 15: 4003.125732421875 loss
epoch 15: 3118.753173828125 loss
epoch 15: 2475.000244140625 loss
epoch 15: 4025.00390625 loss
epoch 15: 7550.0 loss
epoch 15: 7537.5 loss
epoch 15: 8918.75 loss
epoch 15: 8125.0 loss
epoch 15: 4053.125244140625 loss
epoch 15: 8703.125 loss
epoch 15: 4556.25048828125 loss
epoch 15: 4075.003173828125 loss
epoch 15: 3021.875244140625 loss
epoch 15: 4968.75 loss
epoch 15: 7971.875 loss
epoch 15: 11028.125 loss
epoch 15: 3956.25 loss
epoch 15: 3825.0 loss
epoch 15: 3853.125244140625 loss
epoch 15: 1812.5001220703125 loss
epoch 15: 1800.0 loss
epoch 15: 3431.25 loss
epoch 15: 5853.125 loss
epoch 15: 426

epoch 19: 5265.625 loss
epoch 19: 4690.625 loss
epoch 19: 8259.375 loss
epoch 19: 6850.0 loss
epoch 19: 3406.25 loss
epoch 19: 1650.000244140625 loss
epoch 19: 4740.625 loss
epoch 19: 4490.625 loss
epoch 19: 3287.5 loss
epoch 19: 2753.125 loss
epoch 19: 9090.625 loss
epoch 19: 11212.5 loss
epoch 19: 2003.1251220703125 loss
epoch 19: 7118.75 loss
epoch 19: 4221.875 loss
epoch 19: 4812.5009765625 loss
epoch 19: 5465.625 loss
epoch 19: 11753.125 loss
epoch 19: 731.2501220703125 loss
epoch 19: 9337.5 loss
epoch 19: 5356.25 loss
epoch 19: 5065.625 loss
epoch 19: 5778.125 loss
epoch 19: 5862.5 loss
epoch 19: 2459.375 loss
epoch 19: 8703.125 loss
epoch 19: 12037.5 loss
epoch 19: 3409.375 loss
epoch 19: 1118.7515869140625 loss
epoch 19: 1825.0001220703125 loss
epoch 19: 6525.0 loss
epoch 19: 4071.87548828125 loss
epoch 19: 3075.0 loss
epoch 19: 1743.75 loss
epoch 19: 4675.0 loss
epoch 19: 11946.875 loss
epoch 19: 2800.000244140625 loss
epoch 19: 4381.25 loss
epoch 19: 828.1250610351562 loss
ep

epoch 23: 3515.62548828125 loss
epoch 23: 7706.25 loss
epoch 23: 5153.125 loss
epoch 23: 4300.001953125 loss
epoch 23: 5678.125 loss
epoch 23: 4740.62646484375 loss
epoch 23: 8281.25 loss
epoch 23: 5187.5 loss
epoch 23: 8453.125 loss
epoch 23: 7309.375 loss
epoch 23: 10656.25 loss
epoch 23: 7481.25 loss
epoch 23: 3359.375 loss
epoch 23: 7509.37646484375 loss
epoch 23: 2037.5001220703125 loss
epoch 23: 7200.0 loss
epoch 23: 9250.0 loss
epoch 23: 9468.75 loss
epoch 23: 8162.5 loss
epoch 23: 4178.1259765625 loss
epoch 23: 4296.875 loss
epoch 23: 5034.375 loss
epoch 23: 3053.12646484375 loss
epoch 23: 4475.0 loss
epoch 23: 2181.25 loss
epoch 23: 8.697900739207398e-06 loss
epoch 23: 2059.375 loss
epoch 23: 584.3750610351562 loss
epoch 23: 5968.75 loss
epoch 23: 13103.126953125 loss
epoch 23: 5550.0 loss
epoch 23: 6371.875 loss
epoch 23: 3140.625 loss
epoch 23: 6693.75 loss
epoch 23: 7615.625 loss
epoch 23: 9312.5 loss
epoch 23: 6156.25 loss
epoch 23: 3906.25 loss
epoch 23: 6103.125 loss
epo

epoch 27: 6221.875 loss
epoch 27: 7909.375 loss
epoch 27: 6571.875 loss
epoch 27: 3384.375 loss
epoch 27: 5828.125 loss
epoch 27: 3559.375 loss
epoch 27: 8990.625 loss
epoch 27: 3487.5 loss
epoch 27: 1068.75 loss
epoch 27: 2328.125 loss
epoch 27: 5956.25 loss
epoch 27: 6112.5 loss
epoch 27: 4400.0 loss
epoch 27: 296.8751220703125 loss
epoch 27: 5462.5 loss
epoch 27: 2231.25 loss
epoch 27: 5087.5009765625 loss
epoch 27: 6665.625 loss
epoch 27: 3150.0 loss
epoch 27: 4750.0 loss
epoch 27: 4631.2509765625 loss
epoch 27: 3187.5 loss
epoch 27: 4368.75 loss
epoch 27: 3390.625 loss
epoch 27: 6231.25 loss
epoch 27: 3100.0 loss
epoch 27: 3587.5 loss
epoch 27: 375.0001220703125 loss
epoch 27: 7187.5 loss
epoch 27: 8071.875 loss
epoch 27: 3709.375 loss
epoch 27: 8262.5 loss
epoch 27: 8468.75 loss
epoch 27: 8043.75 loss
epoch 27: 4671.875 loss
epoch 27: 3831.25 loss
epoch 27: 6453.125 loss
epoch 27: 4931.25 loss
epoch 27: 4625.0009765625 loss
epoch 27: 3550.0 loss
epoch 27: 6668.75 loss
epoch 27: 1

epoch 30: 6568.75 loss
epoch 30: 3578.125 loss
epoch 30: 5868.75 loss
epoch 30: 50600.0 loss
epoch 31: 7478.125 loss
epoch 31: 7303.125 loss
epoch 31: 2062.5 loss
epoch 31: 4884.375 loss
epoch 31: 6456.2509765625 loss
epoch 31: 7940.625 loss
epoch 31: 6825.0 loss
epoch 31: 8087.5 loss
epoch 31: 5925.0 loss
epoch 31: 8525.0 loss
epoch 31: 5275.0 loss
epoch 31: 4346.875 loss
epoch 31: 3000.0 loss
epoch 31: 2384.375 loss
epoch 31: 3143.75 loss
epoch 31: 3584.375 loss
epoch 31: 5493.75 loss
epoch 31: 5975.0 loss
epoch 31: 6268.75 loss
epoch 31: 3543.75 loss
epoch 31: 2818.75 loss
epoch 31: 8184.375 loss
epoch 31: 8653.125 loss
epoch 31: 9381.25 loss
epoch 31: 9568.75 loss
epoch 31: 3850.0 loss
epoch 31: 6506.25 loss
epoch 31: 1768.75 loss
epoch 31: 7062.5009765625 loss
epoch 31: 6384.375 loss
epoch 31: 4543.75 loss
epoch 31: 10318.75 loss
epoch 31: 3859.375 loss
epoch 31: 3512.5009765625 loss
epoch 31: 3837.5 loss
epoch 31: 3543.75 loss
epoch 31: 9884.375 loss
epoch 31: 3975.0 loss
epoch 3

epoch 34: 3428.125 loss
epoch 34: 4068.75 loss
epoch 34: 7456.25 loss
epoch 34: 12331.25 loss
epoch 34: 3550.0 loss
epoch 34: 7284.375 loss
epoch 34: 5246.875 loss
epoch 34: 4793.75 loss
epoch 34: 4200.0 loss
epoch 34: 7556.25048828125 loss
epoch 34: 7193.75 loss
epoch 34: 6143.75 loss
epoch 34: 5303.1259765625 loss
epoch 34: 10721.875 loss
epoch 34: 0.0 loss
epoch 35: 3131.25 loss
epoch 35: 750.0 loss
epoch 35: 4368.75 loss
epoch 35: 4.9561451305635273e-05 loss
epoch 35: 6475.0 loss
epoch 35: 8709.375 loss
epoch 35: 7637.5 loss
epoch 35: 2275.0 loss
epoch 35: 4271.875 loss
epoch 35: 8046.875 loss
epoch 35: 4306.25 loss
epoch 35: 6103.125 loss
epoch 35: 7781.25 loss
epoch 35: 4196.875 loss
epoch 35: 12437.5 loss
epoch 35: 3900.0 loss
epoch 35: 2806.25 loss
epoch 35: 7406.2509765625 loss
epoch 35: 7103.125 loss
epoch 35: 1196.875 loss
epoch 35: 1012.5001220703125 loss
epoch 35: 7687.5 loss
epoch 35: 215.62503051757812 loss
epoch 35: 6771.875 loss
epoch 35: 5800.0 loss
epoch 35: 4118.75 

epoch 38: 4981.25 loss
epoch 38: 684.3751220703125 loss
epoch 38: 4100.0 loss
epoch 38: 68500.0 loss
epoch 39: 6246.875 loss
epoch 39: 6503.125 loss
epoch 39: 3681.25 loss
epoch 39: 3303.125 loss
epoch 39: 3028.125 loss
epoch 39: 559.3750610351562 loss
epoch 39: 6893.75 loss
epoch 39: 540.6250610351562 loss
epoch 39: 2312.5 loss
epoch 39: 7665.625 loss
epoch 39: 6506.25 loss
epoch 39: 3481.25 loss
epoch 39: 8296.875 loss
epoch 39: 6318.75 loss
epoch 39: 856.2500610351562 loss
epoch 39: 3484.375 loss
epoch 39: 7590.625 loss
epoch 39: 5565.62548828125 loss
epoch 39: 9531.25 loss
epoch 39: 2415.626220703125 loss
epoch 39: 8509.375 loss
epoch 39: 5412.5 loss
epoch 39: 7287.5 loss
epoch 39: 6803.12548828125 loss
epoch 39: 2056.25 loss
epoch 39: 6834.375 loss
epoch 39: 8593.75 loss
epoch 39: 3387.5 loss
epoch 39: 8475.0 loss
epoch 39: 8971.875 loss
epoch 39: 4762.5 loss
epoch 39: 7256.25 loss
epoch 39: 9493.75 loss
epoch 39: 16390.625 loss
epoch 39: 6740.625 loss
epoch 39: 9175.0 loss
epoch 

epoch 43: 4471.875 loss
epoch 43: 3712.5 loss
epoch 43: 7700.0 loss
epoch 43: 3550.0 loss
epoch 43: 7803.125 loss
epoch 43: 10900.0 loss
epoch 43: 4334.375 loss
epoch 43: 4262.5 loss
epoch 43: 4881.25 loss
epoch 43: 4978.125 loss
epoch 43: 6603.125 loss
epoch 43: 4034.375 loss
epoch 43: 4921.875 loss
epoch 43: 3425.0 loss
epoch 43: 3859.375 loss
epoch 43: 2443.75 loss
epoch 43: 4196.875 loss
epoch 43: 2903.125 loss
epoch 43: 5193.75 loss
epoch 43: 3956.25 loss
epoch 43: 9503.125 loss
epoch 43: 6278.125 loss
epoch 43: 10371.875 loss
epoch 43: 6362.5 loss
epoch 43: 3468.75 loss
epoch 43: 2859.375 loss
epoch 43: 7393.75 loss
epoch 43: 5184.375 loss
epoch 43: 3153.125 loss
epoch 43: 4787.5 loss
epoch 43: 7112.5 loss
epoch 43: 375.00103759765625 loss
epoch 43: 12028.125 loss
epoch 43: 3243.75 loss
epoch 43: 9853.125 loss
epoch 43: 3300.0 loss
epoch 43: 2912.5 loss
epoch 43: 5912.5 loss
epoch 43: 3181.25 loss
epoch 43: 4296.875 loss
epoch 43: 8687.5 loss
epoch 43: 5.8079972404812e-06 loss
ep

epoch 47: 3665.62548828125 loss
epoch 47: 2268.75 loss
epoch 47: 2921.875 loss
epoch 47: 6587.5 loss
epoch 47: 2784.375 loss
epoch 47: 11643.75 loss
epoch 47: 5056.25 loss
epoch 47: 4856.25 loss
epoch 47: 3971.875 loss
epoch 47: 7012.5 loss
epoch 47: 3043.75 loss
epoch 47: 7015.625 loss
epoch 47: 2059.375 loss
epoch 47: 1343.75 loss
epoch 47: 4453.125 loss
epoch 47: 2412.5 loss
epoch 47: 1281.25 loss
epoch 47: 3206.25 loss
epoch 47: 4534.375 loss
epoch 47: 5315.625 loss
epoch 47: 2675.0 loss
epoch 47: 6653.125 loss
epoch 47: 10296.875 loss
epoch 47: 2537.5 loss
epoch 47: 7156.25 loss
epoch 47: 6143.75 loss
epoch 47: 5740.625 loss
epoch 47: 5668.75 loss
epoch 47: 3600.0 loss
epoch 47: 6156.25 loss
epoch 47: 4153.125 loss
epoch 47: 4978.125 loss
epoch 47: 3934.375 loss
epoch 47: 4221.875 loss
epoch 47: 4990.625 loss
epoch 47: 9125.0 loss
epoch 47: 3600.0 loss
epoch 47: 2215.625 loss
epoch 47: 7906.25 loss
epoch 47: 6275.0 loss
epoch 47: 6834.375 loss
epoch 47: 3371.875 loss
epoch 47: 565

epoch 51: 3406.25 loss
epoch 51: 5700.0 loss
epoch 51: 3562.5 loss
epoch 51: 5112.5 loss
epoch 51: 5840.625 loss
epoch 51: 6515.625 loss
epoch 51: 10675.0 loss
epoch 51: 8306.25 loss
epoch 51: 4425.0 loss
epoch 51: 3278.125 loss
epoch 51: 5446.875 loss
epoch 51: 3990.625 loss
epoch 51: 7028.125 loss
epoch 51: 2625.0 loss
epoch 51: 8012.5 loss
epoch 51: 5712.5 loss
epoch 51: 6393.75 loss
epoch 51: 6540.62548828125 loss
epoch 51: 3103.125 loss
epoch 51: 6696.875 loss
epoch 51: 6318.75 loss
epoch 51: 2218.75 loss
epoch 51: 8443.75 loss
epoch 51: 3718.75 loss
epoch 51: 2628.125 loss
epoch 51: 2215.625 loss
epoch 51: 5384.375 loss
epoch 51: 3312.5 loss
epoch 51: 2809.375 loss
epoch 51: 6390.625 loss
epoch 51: 3081.25 loss
epoch 51: 3562.5 loss
epoch 51: 4137.5 loss
epoch 51: 4300.0 loss
epoch 51: 2409.375 loss
epoch 51: 6828.125 loss
epoch 51: 6284.375 loss
epoch 51: 2821.875 loss
epoch 51: 8646.875 loss
epoch 51: 5709.375 loss
epoch 51: 7859.375 loss
epoch 51: 2221.875 loss
epoch 51: 2350.

epoch 55: 8815.625 loss
epoch 55: 5234.375 loss
epoch 55: 6068.75 loss
epoch 55: 4725.0 loss
epoch 55: 4718.75 loss
epoch 55: 7059.37548828125 loss
epoch 55: 3446.875 loss
epoch 55: 6012.5 loss
epoch 55: 9340.625 loss
epoch 55: 4425.0 loss
epoch 55: 3100.0 loss
epoch 55: 3881.25 loss
epoch 55: 5343.75 loss
epoch 55: 4315.625 loss
epoch 55: 4668.75 loss
epoch 55: 3025.0 loss
epoch 55: 2162.5 loss
epoch 55: 9221.875 loss
epoch 55: 4243.75 loss
epoch 55: 281.25 loss
epoch 55: 4106.25 loss
epoch 55: 3284.375 loss
epoch 55: 5768.75 loss
epoch 55: 3515.625 loss
epoch 55: 8515.625 loss
epoch 55: 2562.5 loss
epoch 55: 7593.75 loss
epoch 55: 10815.625 loss
epoch 55: 9193.75 loss
epoch 55: 6575.0 loss
epoch 55: 5503.125 loss
epoch 55: 5753.125 loss
epoch 55: 6268.75 loss
epoch 55: 4200.00048828125 loss
epoch 55: 8365.625 loss
epoch 55: 6978.125 loss
epoch 55: 8800.0 loss
epoch 55: 9281.25 loss
epoch 55: 5812.5 loss
epoch 55: 7346.875 loss
epoch 55: 6956.25 loss
epoch 55: 11934.375 loss
epoch 55:

epoch 59: 4084.37548828125 loss
epoch 59: 8950.0 loss
epoch 59: 4509.375 loss
epoch 59: 6031.25 loss
epoch 59: 6015.625 loss
epoch 59: 1481.250244140625 loss
epoch 59: 6787.50048828125 loss
epoch 59: 3871.875 loss
epoch 59: 2450.0 loss
epoch 59: 3440.625 loss
epoch 59: 5993.75 loss
epoch 59: 3728.125 loss
epoch 59: 3987.5 loss
epoch 59: 7271.875 loss
epoch 59: 1187.5 loss
epoch 59: 671.875 loss
epoch 59: 2800.0 loss
epoch 59: 6784.375 loss
epoch 59: 3059.375 loss
epoch 59: 7043.75 loss
epoch 59: 2653.125 loss
epoch 59: 7646.875 loss
epoch 59: 4040.625 loss
epoch 59: 3734.375 loss
epoch 59: 1046.875 loss
epoch 59: 4284.375 loss
epoch 59: 8990.625 loss
epoch 59: 8175.0 loss
epoch 59: 4931.25 loss
epoch 59: 10728.125 loss
epoch 59: 6406.25 loss
epoch 59: 2328.125 loss
epoch 59: 9096.875 loss
epoch 59: 3809.375 loss
epoch 59: 4403.125 loss
epoch 59: 7687.5 loss
epoch 59: 6303.125 loss
epoch 59: 9106.25 loss
epoch 59: 1.6969157741186791e-06 loss
epoch 59: 9856.25 loss
epoch 59: 5915.625 los

epoch 63: 3665.625 loss
epoch 63: 9165.625 loss
epoch 63: 10090.625 loss
epoch 63: 2334.375 loss
epoch 63: 4803.125 loss
epoch 63: 6153.125 loss
epoch 63: 1975.0 loss
epoch 63: 2790.625 loss
epoch 63: 3793.75 loss
epoch 63: 3293.75 loss
epoch 63: 8387.5 loss
epoch 63: 5968.75 loss
epoch 63: 2271.875 loss
epoch 63: 2940.625 loss
epoch 63: 8456.25 loss
epoch 63: 1684.375 loss
epoch 63: 12387.5 loss
epoch 63: 1159.375 loss
epoch 63: 4750.0 loss
epoch 63: 6971.875 loss
epoch 63: 9137.5 loss
epoch 63: 7043.75 loss
epoch 63: 6328.125 loss
epoch 63: 4181.25 loss
epoch 63: 6184.375 loss
epoch 63: 6709.375 loss
epoch 63: 4978.125 loss
epoch 63: 6562.5 loss
epoch 63: 6193.75 loss
epoch 63: 1096.875 loss
epoch 63: 3303.125 loss
epoch 63: 6175.0 loss
epoch 63: 512.5 loss
epoch 63: 8625.0 loss
epoch 63: 7490.625 loss
epoch 63: 6809.375 loss
epoch 63: 8006.25 loss
epoch 63: 2609.375 loss
epoch 63: 1796.875 loss
epoch 63: 1803.125 loss
epoch 63: 7050.0 loss
epoch 63: 9143.75 loss
epoch 63: 2828.125 l

epoch 67: 7300.0 loss
epoch 67: 6618.75 loss
epoch 67: 4262.5 loss
epoch 67: 7081.25 loss
epoch 67: 6681.25 loss
epoch 67: 3209.375 loss
epoch 67: 5496.875 loss
epoch 67: 12071.875 loss
epoch 67: 7821.875 loss
epoch 67: 6271.875 loss
epoch 67: 7906.25 loss
epoch 67: 7328.125 loss
epoch 67: 6296.875 loss
epoch 67: 9300.0 loss
epoch 67: 5534.375 loss
epoch 67: 1428.1251220703125 loss
epoch 67: 3262.5 loss
epoch 67: 506.25 loss
epoch 67: 6709.375 loss
epoch 67: 4540.625 loss
epoch 67: 9321.875 loss
epoch 67: 7162.5 loss
epoch 67: 2184.375 loss
epoch 67: 7337.5 loss
epoch 67: 3193.75 loss
epoch 67: 10221.875 loss
epoch 67: 5465.625 loss
epoch 67: 5759.375 loss
epoch 67: 5518.75 loss
epoch 67: 6412.5 loss
epoch 67: 9003.125 loss
epoch 67: 3859.375 loss
epoch 67: 1656.25 loss
epoch 67: 4731.25 loss
epoch 67: 5996.875 loss
epoch 67: 4618.75 loss
epoch 67: 6509.375 loss
epoch 67: 3196.875 loss
epoch 67: 4478.125 loss
epoch 67: 2959.375 loss
epoch 67: 3393.75 loss
epoch 67: 6125.0 loss
epoch 67

epoch 71: 6462.5 loss
epoch 71: 7725.0 loss
epoch 71: 8968.75 loss
epoch 71: 8762.5 loss
epoch 71: 10046.875 loss
epoch 71: 5400.0 loss
epoch 71: 4156.25 loss
epoch 71: 7268.75 loss
epoch 71: 7528.125 loss
epoch 71: 603.125 loss
epoch 71: 8825.0 loss
epoch 71: 5415.625 loss
epoch 71: 7084.375 loss
epoch 71: 3850.0 loss
epoch 71: 6315.625 loss
epoch 71: 6412.5 loss
epoch 71: 6643.75 loss
epoch 71: 11587.5 loss
epoch 71: 5956.25 loss
epoch 71: 5331.25 loss
epoch 71: 5506.25 loss
epoch 71: 8093.75 loss
epoch 71: 4418.75 loss
epoch 71: 1881.25 loss
epoch 71: 4768.75 loss
epoch 71: 1759.375 loss
epoch 71: 8137.5 loss
epoch 71: 1303.125 loss
epoch 71: 8921.875 loss
epoch 71: 1628.125244140625 loss
epoch 71: 2593.75048828125 loss
epoch 71: 1371.875 loss
epoch 71: 5562.5 loss
epoch 71: 2281.25 loss
epoch 71: 2000.0 loss
epoch 71: 3659.375 loss
epoch 71: 5703.125 loss
epoch 71: 3206.25 loss
epoch 71: 6656.25 loss
epoch 71: 6715.625 loss
epoch 71: 8403.125 loss
epoch 71: 3890.625 loss
epoch 71: 

epoch 75: 2343.75 loss
epoch 75: 3112.5 loss
epoch 75: 7543.75 loss
epoch 75: 5187.5 loss
epoch 75: 4900.0 loss
epoch 75: 6565.625 loss
epoch 75: 459.375 loss
epoch 75: 4281.25 loss
epoch 75: 4434.375 loss
epoch 75: 9946.875 loss
epoch 75: 5234.375 loss
epoch 75: 3021.875 loss
epoch 75: 9443.75 loss
epoch 75: 3628.125 loss
epoch 75: 5406.25 loss
epoch 75: 6606.25 loss
epoch 75: 3856.25 loss
epoch 75: 5618.75 loss
epoch 75: 2840.625 loss
epoch 75: 4003.125 loss
epoch 75: 2043.75 loss
epoch 75: 2918.75 loss
epoch 75: 3653.125 loss
epoch 75: 5943.75 loss
epoch 75: 4953.125 loss
epoch 75: 750.0 loss
epoch 75: 4537.5 loss
epoch 75: 2193.75 loss
epoch 75: 3662.5 loss
epoch 75: 7331.25 loss
epoch 75: 4640.625 loss
epoch 75: 5565.625 loss
epoch 75: 5390.625 loss
epoch 75: 2875.0 loss
epoch 75: 753.125 loss
epoch 75: 2875.000244140625 loss
epoch 75: 3990.625 loss
epoch 75: 1778.125 loss
epoch 75: 11153.125 loss
epoch 75: 5625.0 loss
epoch 75: 8665.625 loss
epoch 75: 3578.125 loss
epoch 75: 1203

epoch 79: 2537.5 loss
epoch 79: 4771.875 loss
epoch 79: 5206.25 loss
epoch 79: 3068.75 loss
epoch 79: 6893.75 loss
epoch 79: 3503.125 loss
epoch 79: 9956.25 loss
epoch 79: 3134.375 loss
epoch 79: 6721.875 loss
epoch 79: 3468.75 loss
epoch 79: 3421.875 loss
epoch 79: 10009.375 loss
epoch 79: 240.6250457763672 loss
epoch 79: 2484.375 loss
epoch 79: 6259.375 loss
epoch 79: 2890.625 loss
epoch 79: 9156.25 loss
epoch 79: 4359.375 loss
epoch 79: 6596.875 loss
epoch 79: 2256.25 loss
epoch 79: 2071.875 loss
epoch 79: 4312.5 loss
epoch 79: 4393.75 loss
epoch 79: 5150.0 loss
epoch 79: 7318.75 loss
epoch 79: 5193.75 loss
epoch 79: 7368.75 loss
epoch 79: 5996.875 loss
epoch 79: 5431.25 loss
epoch 79: 5037.5 loss
epoch 79: 5446.875 loss
epoch 79: 8656.25 loss
epoch 79: 3550.0 loss
epoch 79: 2315.625 loss
epoch 79: 3750.0 loss
epoch 79: 9062.5 loss
epoch 79: 2018.75 loss
epoch 79: 7531.25 loss
epoch 79: 5340.625 loss
epoch 79: 6340.625 loss
epoch 79: 0.0 loss
epoch 80: 878.125 loss
epoch 80: 5496.87

epoch 83: 4565.625 loss
epoch 83: 4990.625 loss
epoch 83: 4809.375 loss
epoch 83: 4384.375 loss
epoch 83: 887.5 loss
epoch 83: 6981.25 loss
epoch 83: 11406.25 loss
epoch 83: 8646.875 loss
epoch 83: 6656.25 loss
epoch 83: 7915.625 loss
epoch 83: 6265.625 loss
epoch 83: 5034.375 loss
epoch 83: 4721.875 loss
epoch 83: 4662.5 loss
epoch 83: 7359.375 loss
epoch 83: 5040.625 loss
epoch 83: 5378.125 loss
epoch 83: 5628.125 loss
epoch 83: 1365.625 loss
epoch 83: 3703.125 loss
epoch 83: 3562.5 loss
epoch 83: 3728.125 loss
epoch 83: 3678.125 loss
epoch 83: 4015.625 loss
epoch 83: 3931.25 loss
epoch 83: 3400.0 loss
epoch 83: 8362.5 loss
epoch 83: 4578.125 loss
epoch 83: 9290.625 loss
epoch 83: 3025.0 loss
epoch 83: 6037.5 loss
epoch 83: 7550.0 loss
epoch 83: 0.0003680071677081287 loss
epoch 84: 2509.375 loss
epoch 84: 2400.0 loss
epoch 84: 9231.25 loss
epoch 84: 2893.75 loss
epoch 84: 7587.5 loss
epoch 84: 4368.75 loss
epoch 84: 4787.5 loss
epoch 84: 1700.0 loss
epoch 84: 7428.125 loss
epoch 84: 

epoch 87: 3481.25 loss
epoch 87: 10450.0 loss
epoch 87: 8768.75 loss
epoch 87: 1459.375 loss
epoch 87: 4456.25 loss
epoch 87: 8006.25 loss
epoch 87: 2756.25 loss
epoch 87: 8012.5 loss
epoch 87: 6806.25 loss
epoch 87: 4737.5 loss
epoch 87: 3840.625 loss
epoch 87: 9056.25 loss
epoch 87: 2640.625 loss
epoch 87: 8731.25 loss
epoch 87: 5587.5 loss
epoch 87: 6456.25 loss
epoch 87: 5778.125 loss
epoch 87: 6668.75 loss
epoch 87: 8237.5 loss
epoch 87: 0.0 loss
epoch 88: 7484.375 loss
epoch 88: 1550.0 loss
epoch 88: 6521.875 loss
epoch 88: 4865.625 loss
epoch 88: 4271.875 loss
epoch 88: 5043.75 loss
epoch 88: 7159.375 loss
epoch 88: 2503.125 loss
epoch 88: 6456.25 loss
epoch 88: 3759.375 loss
epoch 88: 10965.625 loss
epoch 88: 1150.0 loss
epoch 88: 5071.875 loss
epoch 88: 6831.25 loss
epoch 88: 9021.875 loss
epoch 88: 4534.375 loss
epoch 88: 8615.625 loss
epoch 88: 9187.5 loss
epoch 88: 3128.125 loss
epoch 88: 8215.625 loss
epoch 88: 6259.375 loss
epoch 88: 5790.625 loss
epoch 88: 7221.875 loss


epoch 91: 3881.25 loss
epoch 91: 11915.625 loss
epoch 91: 7659.375 loss
epoch 91: 10815.625 loss
epoch 91: 3512.5 loss
epoch 91: 3231.25 loss
epoch 91: 4978.125 loss
epoch 91: 3128.125 loss
epoch 91: 2643.75 loss
epoch 91: 7093.75 loss
epoch 91: 10865.625 loss
epoch 91: 6546.875 loss
epoch 91: 6778.125 loss
epoch 91: 3065.625 loss
epoch 91: 7271.875 loss
epoch 91: 3375.0 loss
epoch 91: 12000.0 loss
epoch 92: 12962.5 loss
epoch 92: 6.521075556520373e-05 loss
epoch 92: 4381.25 loss
epoch 92: 7237.5 loss
epoch 92: 4859.375 loss
epoch 92: 6915.625 loss
epoch 92: 10471.875 loss
epoch 92: 1762.5 loss
epoch 92: 11875.0 loss
epoch 92: 2093.75 loss
epoch 92: 4459.375 loss
epoch 92: 3375.0 loss
epoch 92: 7021.875 loss
epoch 92: 2409.375 loss
epoch 92: 6446.875 loss
epoch 92: 3631.25 loss
epoch 92: 1662.5 loss
epoch 92: 7128.125 loss
epoch 92: 6978.125 loss
epoch 92: 2534.375 loss
epoch 92: 3831.25 loss
epoch 92: 5087.5 loss
epoch 92: 5909.375 loss
epoch 92: 3890.625 loss
epoch 92: 5418.75 loss
e

epoch 95: 0.0 loss
epoch 96: 1912.5 loss
epoch 96: 5121.875 loss
epoch 96: 10121.875 loss
epoch 96: 5553.125 loss
epoch 96: 5640.625 loss
epoch 96: 3065.625 loss
epoch 96: 3512.5 loss
epoch 96: 11756.25 loss
epoch 96: 1750.0 loss
epoch 96: 5081.25 loss
epoch 96: 5865.625 loss
epoch 96: 4318.75 loss
epoch 96: 3581.25 loss
epoch 96: 9625.0 loss
epoch 96: 3668.75 loss
epoch 96: 6643.75 loss
epoch 96: 2687.5 loss
epoch 96: 3193.75 loss
epoch 96: 7234.375 loss
epoch 96: 9106.25 loss
epoch 96: 3181.25 loss
epoch 96: 2412.5 loss
epoch 96: 8306.25 loss
epoch 96: 5915.625 loss
epoch 96: 6040.625 loss
epoch 96: 9337.5 loss
epoch 96: 8815.625 loss
epoch 96: 3850.0 loss
epoch 96: 2390.625 loss
epoch 96: 2412.5 loss
epoch 96: 8918.75 loss
epoch 96: 6025.0 loss
epoch 96: 9218.75 loss
epoch 96: 5909.375 loss
epoch 96: 6250.0 loss
epoch 96: 1615.625 loss
epoch 96: 10409.375 loss
epoch 96: 446.87518310546875 loss
epoch 96: 4128.125 loss
epoch 96: 7178.125 loss
epoch 96: 5450.0 loss
epoch 96: 2206.25 lo

In [9]:
# evaluate the model
predictions, actuals, acc = evaluate_model(test_dl, model)
print('Accuracy: %.3f' % acc)
print(predictions)
print(actuals)

Accuracy: 0.861
[[0.]
 [0.]
 [0.]
 ...
 [0.]
 [0.]
 [0.]]
[[0.]
 [0.]
 [0.]
 ...
 [0.]
 [0.]
 [0.]]
