In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torchbones
from torchbones import Net
import torch
import torch.nn as nn
from torch import optim
import pandas as pd
import pickle

## Network Building
This notebook demonstrates how to create simple networks using torchbones without automatic training implementation. We will use MNIST wine data to demonstrate.

In [2]:
wines = pd.read_csv('winequality-red.csv', delimiter = ';').to_numpy()
truth = wines[:, -1] #ratings of wines
data = wines[:, :-1] #features of wines
data.shape

(1599, 11)

### Initializing the Network

In [29]:
sizes = data.shape[1:] #size of input layer
lins = [100, 100, 10, 1] # size of each linear layer, the last value should be the desired number of outputs
activation = nn.Tanh # activation function

# create a net object
net = Net(sizes, lins, activation).double()



## You can now train in whatever way is most appropriate for your application

In [32]:
optimizer =  optim.SGD(net.parameters(), lr=0.001) # specify optiization algorythm and learning rate
batches = 1000 # number of batches to train on
batch_size = int(len(truth)/batches) # size of each batch
epochs = 10 # number of epochs to train for

for k in range(epochs):
    shuffle = torch.randperm(len(truth)) #shuffle training set
    for i in range(batches):
        where = shuffle[i * batch_size:(i + 1) * batch_size] #take batch of training set
        output = net(torch.tensor(data[where]).double())
        truetrain = truth[where]
        loss = torch.mean(nn.MSELoss()(output.squeeze(), torch.tensor(truetrain).double()))
        optimizer.zero_grad()
        loss.backward(retain_graph = True)
        optimizer.step()
    print(f'Epoch {k}: loss = {loss}')

Epoch 0: loss = 16.000019366805347
Epoch 1: loss = 25.00002401517848
Epoch 2: loss = 25.000023745372406
Epoch 3: loss = 36.000028490122375
Epoch 4: loss = 36.0000284693846
Epoch 5: loss = 9.000014208173525
Epoch 6: loss = 16.000019033420333
Epoch 7: loss = 36.00002885387495
Epoch 8: loss = 9.000014386780462
Epoch 9: loss = 25.000024048260354
