In [1]:
%matplotlib inline
from src.connect4 import train

import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from sklearn.model_selection import train_test_split


torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(999)
np.random.seed(0)

In [2]:
channels = 2
filters = 64

n_epocs = 10000
batch_size = 512
test_size = 0.2
learning_rate = 0.0002 * (batch_size / 1024.0)
momentum = 0.0

In [3]:
class Connect4Dataset(data.Dataset):
    def __init__(self, boards, values):
        assert len(boards) == len(values)
        self.boards = boards
        self.values = values
        
    def __len__(self):
        return len(self.boards)
    
    def __getitem__(self, idx):
        return self.boards[idx], self.values[idx]

boards = torch.load('/home/richard/Downloads/connect4_boards.pth').numpy()
values = torch.load('/home/richard/Downloads/connect4_values.pth').numpy()

# Here we don't want to have the player to move channel
boards = boards[:, 3 - channels:]

board_train, board_test, value_train, value_test = train_test_split(boards, values, test_size=test_size, shuffle=True)

In [4]:
# augment data
v_wins = value_train[value_train == 1.0]
v_draws = value_train[value_train == 0.5]
v_losses = value_train[value_train == 0.0]
b_wins = board_train[value_train == 1.0]
b_draws = board_train[value_train == 0.5]
b_losses = board_train[value_train == 0.0]

n_wins = len(v_wins)
n_draws = len(v_draws)
n_losses = len(v_losses)

v_augmented_draws = np.repeat(v_draws, n_wins/n_draws)
v_augmented_losses = np.repeat(v_losses, n_wins/n_losses)
b_augmented_draws = np.repeat(b_draws, n_wins/n_draws, axis=0)
b_augmented_losses = np.repeat(b_losses, n_wins/n_losses, axis=0)

extra_draw_idx = np.random.choice(range(len(v_draws)), n_wins - len(v_augmented_draws), replace=False)
extra_losses_idx = np.random.choice(range(len(v_losses)), n_wins - len(v_augmented_losses), replace=False)

v_augmented_draws = np.hstack([v_augmented_draws, v_draws[extra_draw_idx]])
v_augmented_losses = np.hstack([v_augmented_losses, v_losses[extra_losses_idx]])
b_augmented_draws = np.concatenate([b_augmented_draws, b_draws[extra_draw_idx]], axis=0)
b_augmented_losses = np.concatenate([b_augmented_losses, b_losses[extra_losses_idx]], axis=0)

value_train = np.hstack([v_wins, v_augmented_draws, v_augmented_losses])
board_train = np.concatenate([b_wins, b_augmented_draws, b_augmented_losses], axis=0)

# np.random.shuffle(value_train)
# np.random.shuffle(board_train)

print(len(value_train), len(board_train))

train = Connect4Dataset(torch.from_numpy(board_train), torch.from_numpy(value_train))
test = Connect4Dataset(torch.from_numpy(board_test), torch.from_numpy(value_test))

train_gen = data.DataLoader(train, batch_size, shuffle=False)
test_gen = data.DataLoader(test, batch_size, shuffle=False)

106617 106617


In [5]:
from src.connect4.utils import Connect4Stats as info

# Input with N * channels * (6,7)
# Output with N * filters * (6,7)
convolutional_layer = \
    nn.Sequential(nn.Conv2d(in_channels=channels,
                            out_channels=filters,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            dilation=1,
                            groups=1,
                            bias=False),
                  nn.BatchNorm2d(filters),
                  nn.LeakyReLU())

# Input with N * filters * (6,7)
# Output with N * filters * (6,7)
class ResidualLayer(nn.Module):
    def __init__(self):
        super(ResidualLayer, self).__init__()
        self.conv1 = nn.Conv2d(filters, filters, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(filters, filters, 3, padding=1, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(filters)
        self.batch_norm2 = nn.BatchNorm2d(filters)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)

        out += residual
        out = self.relu(out)
        return out


# Input with N * filters * (6,7)
# Output with N * 1 * 1
class ValueHead(nn.Module):
    def __init__(self):
        super(ValueHead, self).__init__()
        self.conv1 = nn.Conv2d(filters, 1, 1)
        self.batch_norm = nn.BatchNorm2d(1)
        self.relu = nn.LeakyReLU()
        self.fcN = nn.Sequential(*[nn.Linear(info.area, info.area) for _ in range(4)])
        self.fc2 = nn.Linear(info.area, 1)
        self.tanh = torch.nn.Tanh()
        self.w1 = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        self.w2 = nn.Parameter(torch.tensor(0.5), requires_grad=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = x.view(x.shape[0], 1, -1)
        x = self.fcN(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.tanh(x)
#         map from [-1, 1] to [0, 1]
        x = (x + self.w1) * self.w2
        x = x.view(-1, 1)
        return x

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.002)
    elif type(m) == nn.Conv2d:
        nn.init.normal_(m.weight, std=0.002)
    elif type(m) == nn.BatchNorm2d:
        nn.init.normal_(m.weight, std=0.002)
#         nn.init.constant_(m.bias, 0)

net = nn.Sequential(convolutional_layer,
                    nn.Sequential(*[ResidualLayer() for _ in range(4)]),
                    ValueHead())

# net.apply(init_normal)

# net = convolutional_layer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

Sequential(
  (0): Sequential(
    (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (1): Sequential(
    (0): ResidualLayer(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): LeakyReLU(negative_slope=0.01)
    )
    (1): ResidualLayer(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0

In [6]:
# criterion = nn.MSELoss()
criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
# optimizer = optim.Adam(net.parameters(), lr=learning_rate)

In [7]:
def categorise_predictions(preds):
    preds = preds * 3.0
    torch.floor_(preds)
    preds = preds / 2.0
    return preds

class Stats():
    def __init__(self):
        self.n = 0
        self.average_value = 0.0
        self.loss = 0.0
        self.smallest = 1.0
        self.largest = 0.0
        self.correct = {i : 0 for i in [0.0, 0.5, 1.0]}
        self.total = {i : 0 for i in [0.0, 0.5, 1.0]}
        
    def __repr__(self):
        x = "Average loss:  " + "{:.5f}".format(self.loss / self.n) + \
            ",  Smallest:  " + "{:.5f}".format(self.smallest) + \
            ",  Largest:  " + "{:.5f}".format(self.largest) + \
            ",  Average:  " + "{:.5f}".format(self.average_value / self.n)
        
        for k in self.correct:
            x += "\nCategory, # Predictions, # Correct:  {}, {}, {}".format(
                k,
                self.total[k],
                self.correct[k])
        return x

def update_stats(stats: Stats, outputs, values, loss):
    stats.n += len(values)
    stats.average_value += outputs.sum().item()
    stats.loss += loss.item() * len(values)
    stats.smallest = min(stats.smallest, torch.min(outputs).item())
    stats.largest = max(stats.largest, torch.max(outputs).item())
            
    categories = categorise_predictions(outputs)
    values = values.view(-1)
    categories = categories.view(-1)

    for k in stats.correct:
        idx = (values == k).nonzero()
        stats.total[k] += len(idx)
        stats.correct[k] += len(torch.eq(categories[idx], values[idx]).nonzero()) 

    return stats


In [8]:
%%time

for epoch in range(n_epocs):
    
    net = net.train()
    train_stats = Stats()
    test_stats = Stats()
    
    for board, value in train_gen:
        board, value = board.to(device), value.to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output = net(board)
        loss = criterion(output, value)
        loss.backward()
        optimizer.step()
        
        train_stats = update_stats(train_stats, output, value, loss)

    # validate
    with torch.set_grad_enabled(False):
        net = net.eval()
        for board, value in test_gen:
            board, value = board.to(device), value.to(device)

            output = net(board)
            loss = criterion(output, value)
            
            test_stats = update_stats(test_stats, output, value, loss)
            
    print("Epoch:  ", epoch, "  Train:\n", train_stats, "\nTest:\n", test_stats)
            

print('Finished Training')

Train:
 Average loss:  0.34150,  Smallest:  0.42234,  Largest:  0.55620,  Average:  0.47951
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.46389,  Smallest:  0.42018,  Largest:  0.53730,  Average:  0.47834
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.34031,  Smallest:  0.42603,  Largest:  0.56021,  Average:  0.48333
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.46215,  Smallest:  0.42355,  Largest:  0.54104,  Average:  0.48186
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33581,  Smallest:  0.44390,  Largest:  0.56908,  Average:  0.49769
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45586,  Smallest:  0.44177,  Largest:  0.54882,  Average:  0.49513
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33573,  Smallest:  0.44414,  Largest:  0.56880,  Average:  0.49772
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45583,  Smallest:  0.44200,  Largest:  0.54872,  Average:  0.49519
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33458,  Smallest:  0.44719,  Largest:  0.56480,  Average:  0.49800
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45544,  Smallest:  0.44480,  Largest:  0.54772,  Average:  0.49593
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33452,  Smallest:  0.44737,  Largest:  0.56458,  Average:  0.49801
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45543,  Smallest:  0.44493,  Largest:  0.54771,  Average:  0.49596
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33373,  Smallest:  0.44936,  Largest:  0.56161,  Average:  0.49810
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45522,  Smallest:  0.44665,  Largest:  0.54833,  Average:  0.49642
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33368,  Smallest:  0.44944,  Largest:  0.56144,  Average:  0.49810
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45522,  Smallest:  0.44672,  Largest:  0.54839,  Average:  0.49644
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33304,  Smallest:  0.44761,  Largest:  0.55954,  Average:  0.49808
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45510,  Smallest:  0.44742,  Largest:  0.54994,  Average:  0.49677
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33300,  Smallest:  0.44745,  Largest:  0.55947,  Average:  0.49809
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45510,  Smallest:  0.44745,  Largest:  0.55007,  Average:  0.49679
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33240,  Smallest:  0.44358,  Largest:  0.55847,  Average:  0.49808
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45503,  Smallest:  0.44563,  Largest:  0.55263,  Average:  0.49710
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33236,  Smallest:  0.44323,  Largest:  0.55842,  Average:  0.49807
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45503,  Smallest:  0.44541,  Largest:  0.55279,  Average:  0.49711
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33174,  Smallest:  0.43724,  Largest:  0.56353,  Average:  0.49798
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45501,  Smallest:  0.44139,  Largest:  0.55594,  Average:  0.49736
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33169,  Smallest:  0.43682,  Largest:  0.56404,  Average:  0.49797
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45501,  Smallest:  0.44112,  Largest:  0.55615,  Average:  0.49737
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33101,  Smallest:  0.42965,  Largest:  0.57247,  Average:  0.49790
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45501,  Smallest:  0.43573,  Largest:  0.56005,  Average:  0.49765
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33097,  Smallest:  0.42917,  Largest:  0.57307,  Average:  0.49789
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45501,  Smallest:  0.43537,  Largest:  0.56046,  Average:  0.49768
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.33019,  Smallest:  0.42086,  Largest:  0.58336,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45500,  Smallest:  0.42941,  Largest:  0.56732,  Average:  0.49804
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.33014,  Smallest:  0.42030,  Largest:  0.58405,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45500,  Smallest:  0.42902,  Largest:  0.56778,  Average:  0.49807
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32925,  Smallest:  0.41079,  Largest:  0.59625,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45500,  Smallest:  0.42213,  Largest:  0.57544,  Average:  0.49851
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32918,  Smallest:  0.41016,  Largest:  0.59710,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45500,  Smallest:  0.42165,  Largest:  0.57595,  Average:  0.49854
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32810,  Smallest:  0.39931,  Largest:  0.61131,  Average:  0.49795
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45498,  Smallest:  0.41314,  Largest:  0.58536,  Average:  0.49912
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32802,  Smallest:  0.39856,  Largest:  0.61222,  Average:  0.49795
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45499,  Smallest:  0.41256,  Largest:  0.58597,  Average:  0.49915
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32670,  Smallest:  0.38643,  Largest:  0.62678,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45508,  Smallest:  0.40229,  Largest:  0.59812,  Average:  0.49961
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32661,  Smallest:  0.38565,  Largest:  0.62782,  Average:  0.49786
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45508,  Smallest:  0.40160,  Largest:  0.59893,  Average:  0.49964
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32499,  Smallest:  0.37259,  Largest:  0.64294,  Average:  0.49789
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45517,  Smallest:  0.38835,  Largest:  0.61180,  Average:  0.50023
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32488,  Smallest:  0.37172,  Largest:  0.64395,  Average:  0.49789
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45518,  Smallest:  0.38752,  Largest:  0.61259,  Average:  0.50027
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32288,  Smallest:  0.35673,  Largest:  0.66085,  Average:  0.49765
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45547,  Smallest:  0.37285,  Largest:  0.62935,  Average:  0.50056
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32274,  Smallest:  0.35567,  Largest:  0.66196,  Average:  0.49763
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 0 
Test:
 Average loss:  0.45549,  Smallest:  0.37185,  Largest:  0.63043,  Average:  0.50057
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.32029,  Smallest:  0.33749,  Largest:  0.68049,  Average:  0.49742
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 2 
Test:
 Average loss:  0.45589,  Smallest:  0.35422,  Largest:  0.64903,  Average:  0.50079
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Average loss:  0.32011,  Smallest:  0.33625,  Largest:  0.68172,  Average:  0.49739
Category, # Predictions, # Correct:  0.0, 35539, 0
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 2 
Test:
 Average loss:  0.45593,  Smallest:  0.35301,  Largest:  0.65022,  Average:  0.50077
Category, # Predictions, # Correct:  0.0, 3300, 0
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 0
Train:
 Aver

Train:
 Average loss:  0.31701,  Smallest:  0.31414,  Largest:  0.70234,  Average:  0.49687
Category, # Predictions, # Correct:  0.0, 35539, 9
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 57 
Test:
 Average loss:  0.45670,  Smallest:  0.33165,  Largest:  0.67065,  Average:  0.50038
Category, # Predictions, # Correct:  0.0, 3300, 2
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.31680,  Smallest:  0.31263,  Largest:  0.70372,  Average:  0.49685
Category, # Predictions, # Correct:  0.0, 35539, 11
Category, # Predictions, # Correct:  0.5, 35539, 35539
Category, # Predictions, # Correct:  1.0, 35539, 69 
Test:
 Average loss:  0.45675,  Smallest:  0.33017,  Largest:  0.67203,  Average:  0.50034
Category, # Predictions, # Correct:  0.0, 3300, 3
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 A

Train:
 Average loss:  0.31282,  Smallest:  0.28901,  Largest:  0.72730,  Average:  0.49622
Category, # Predictions, # Correct:  0.0, 35539, 194
Category, # Predictions, # Correct:  0.5, 35539, 35532
Category, # Predictions, # Correct:  1.0, 35539, 410 
Test:
 Average loss:  0.45797,  Smallest:  0.30507,  Largest:  0.69543,  Average:  0.49919
Category, # Predictions, # Correct:  0.0, 3300, 30
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 41
Train:
 Average loss:  0.31256,  Smallest:  0.28752,  Largest:  0.72884,  Average:  0.49619
Category, # Predictions, # Correct:  0.0, 35539, 219
Category, # Predictions, # Correct:  0.5, 35539, 35532
Category, # Predictions, # Correct:  1.0, 35539, 451 
Test:
 Average loss:  0.45806,  Smallest:  0.30339,  Largest:  0.69696,  Average:  0.49911
Category, # Predictions, # Correct:  0.0, 3300, 33
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 44


Train:
 Average loss:  0.30737,  Smallest:  0.26065,  Largest:  0.75480,  Average:  0.49512
Category, # Predictions, # Correct:  0.0, 35539, 1112
Category, # Predictions, # Correct:  0.5, 35539, 35493
Category, # Predictions, # Correct:  1.0, 35539, 1476 
Test:
 Average loss:  0.46022,  Smallest:  0.27399,  Largest:  0.72301,  Average:  0.49613
Category, # Predictions, # Correct:  0.0, 3300, 125
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 8934, 210
Train:
 Average loss:  0.30700,  Smallest:  0.25875,  Largest:  0.75644,  Average:  0.49505
Category, # Predictions, # Correct:  0.0, 35539, 1189
Category, # Predictions, # Correct:  0.5, 35539, 35492
Category, # Predictions, # Correct:  1.0, 35539, 1576 
Test:
 Average loss:  0.46039,  Smallest:  0.27201,  Largest:  0.72478,  Average:  0.49588
Category, # Predictions, # Correct:  0.0, 3300, 143
Category, # Predictions, # Correct:  0.5, 1278, 1278
Category, # Predictions, # Correct:  1.0, 89

Train:
 Average loss:  0.30008,  Smallest:  0.22568,  Largest:  0.78336,  Average:  0.49341
Category, # Predictions, # Correct:  0.0, 35539, 3315
Category, # Predictions, # Correct:  0.5, 35539, 35368
Category, # Predictions, # Correct:  1.0, 35539, 3537 
Test:
 Average loss:  0.46410,  Smallest:  0.23615,  Largest:  0.75412,  Average:  0.48982
Category, # Predictions, # Correct:  0.0, 3300, 432
Category, # Predictions, # Correct:  0.5, 1278, 1271
Category, # Predictions, # Correct:  1.0, 8934, 473
Train:
 Average loss:  0.29958,  Smallest:  0.22346,  Largest:  0.78553,  Average:  0.49328
Category, # Predictions, # Correct:  0.0, 35539, 3506
Category, # Predictions, # Correct:  0.5, 35539, 35352
Category, # Predictions, # Correct:  1.0, 35539, 3703 
Test:
 Average loss:  0.46440,  Smallest:  0.23378,  Largest:  0.75602,  Average:  0.48929
Category, # Predictions, # Correct:  0.0, 3300, 458
Category, # Predictions, # Correct:  0.5, 1278, 1270
Category, # Predictions, # Correct:  1.0, 89

Train:
 Average loss:  0.28996,  Smallest:  0.18520,  Largest:  0.82378,  Average:  0.49040
Category, # Predictions, # Correct:  0.0, 35539, 7838
Category, # Predictions, # Correct:  0.5, 35539, 34900
Category, # Predictions, # Correct:  1.0, 35539, 6456 
Test:
 Average loss:  0.47117,  Smallest:  0.19453,  Largest:  0.78567,  Average:  0.47678
Category, # Predictions, # Correct:  0.0, 3300, 968
Category, # Predictions, # Correct:  0.5, 1278, 1232
Category, # Predictions, # Correct:  1.0, 8934, 778
Train:
 Average loss:  0.28924,  Smallest:  0.18266,  Largest:  0.82647,  Average:  0.49014
Category, # Predictions, # Correct:  0.0, 35539, 8219
Category, # Predictions, # Correct:  0.5, 35539, 34878
Category, # Predictions, # Correct:  1.0, 35539, 6637 
Test:
 Average loss:  0.47174,  Smallest:  0.19195,  Largest:  0.78751,  Average:  0.47568
Category, # Predictions, # Correct:  0.0, 3300, 1016
Category, # Predictions, # Correct:  0.5, 1278, 1223
Category, # Predictions, # Correct:  1.0, 8

Train:
 Average loss:  0.27540,  Smallest:  0.14103,  Largest:  0.86840,  Average:  0.48577
Category, # Predictions, # Correct:  0.0, 35539, 15247
Category, # Predictions, # Correct:  0.5, 35539, 34018
Category, # Predictions, # Correct:  1.0, 35539, 9973 
Test:
 Average loss:  0.48378,  Smallest:  0.14758,  Largest:  0.81322,  Average:  0.45186
Category, # Predictions, # Correct:  0.0, 3300, 1903
Category, # Predictions, # Correct:  0.5, 1278, 1067
Category, # Predictions, # Correct:  1.0, 8934, 979
Train:
 Average loss:  0.27437,  Smallest:  0.13845,  Largest:  0.87103,  Average:  0.48544
Category, # Predictions, # Correct:  0.0, 35539, 15745
Category, # Predictions, # Correct:  0.5, 35539, 33940
Category, # Predictions, # Correct:  1.0, 35539, 10193 
Test:
 Average loss:  0.48477,  Smallest:  0.14475,  Largest:  0.81464,  Average:  0.44988
Category, # Predictions, # Correct:  0.0, 3300, 1958
Category, # Predictions, # Correct:  0.5, 1278, 1040
Category, # Predictions, # Correct:  1.

Train:
 Average loss:  0.25388,  Smallest:  0.09966,  Largest:  0.91050,  Average:  0.48079
Category, # Predictions, # Correct:  0.0, 35539, 23720
Category, # Predictions, # Correct:  0.5, 35539, 32298
Category, # Predictions, # Correct:  1.0, 35539, 14234 
Test:
 Average loss:  0.50508,  Smallest:  0.10167,  Largest:  0.83671,  Average:  0.40848
Category, # Predictions, # Correct:  0.0, 3300, 2744
Category, # Predictions, # Correct:  0.5, 1278, 626
Category, # Predictions, # Correct:  1.0, 8934, 969
Train:
 Average loss:  0.25233,  Smallest:  0.09737,  Largest:  0.91279,  Average:  0.48049
Category, # Predictions, # Correct:  0.0, 35539, 24162
Category, # Predictions, # Correct:  0.5, 35539, 32186
Category, # Predictions, # Correct:  1.0, 35539, 14508 
Test:
 Average loss:  0.50669,  Smallest:  0.09919,  Largest:  0.83831,  Average:  0.40516
Category, # Predictions, # Correct:  0.0, 3300, 2775
Category, # Predictions, # Correct:  0.5, 1278, 603
Category, # Predictions, # Correct:  1.0

Train:
 Average loss:  0.22293,  Smallest:  0.06196,  Largest:  0.94480,  Average:  0.47787
Category, # Predictions, # Correct:  0.0, 35539, 30422
Category, # Predictions, # Correct:  0.5, 35539, 29885
Category, # Predictions, # Correct:  1.0, 35539, 19725 
Test:
 Average loss:  0.53716,  Smallest:  0.06318,  Largest:  0.85542,  Average:  0.34248
Category, # Predictions, # Correct:  0.0, 3300, 3146
Category, # Predictions, # Correct:  0.5, 1278, 249
Category, # Predictions, # Correct:  1.0, 8934, 709
Train:
 Average loss:  0.22078,  Smallest:  0.05986,  Largest:  0.94647,  Average:  0.47785
Category, # Predictions, # Correct:  0.0, 35539, 30698
Category, # Predictions, # Correct:  0.5, 35539, 29744
Category, # Predictions, # Correct:  1.0, 35539, 20108 
Test:
 Average loss:  0.53941,  Smallest:  0.06134,  Largest:  0.85602,  Average:  0.33785
Category, # Predictions, # Correct:  0.0, 3300, 3162
Category, # Predictions, # Correct:  0.5, 1278, 232
Category, # Predictions, # Correct:  1.0

Train:
 Average loss:  0.18349,  Smallest:  0.03383,  Largest:  0.96764,  Average:  0.48065
Category, # Predictions, # Correct:  0.0, 35539, 33729
Category, # Predictions, # Correct:  0.5, 35539, 27105
Category, # Predictions, # Correct:  1.0, 35539, 26281 
Test:
 Average loss:  0.57814,  Smallest:  0.03431,  Largest:  0.85246,  Average:  0.25852
Category, # Predictions, # Correct:  0.0, 3300, 3275
Category, # Predictions, # Correct:  0.5, 1278, 66
Category, # Predictions, # Correct:  1.0, 8934, 327
Train:
 Average loss:  0.18108,  Smallest:  0.03256,  Largest:  0.96867,  Average:  0.48101
Category, # Predictions, # Correct:  0.0, 35539, 33865
Category, # Predictions, # Correct:  0.5, 35539, 26962
Category, # Predictions, # Correct:  1.0, 35539, 26655 
Test:
 Average loss:  0.58062,  Smallest:  0.03295,  Largest:  0.85120,  Average:  0.25346
Category, # Predictions, # Correct:  0.0, 3300, 3278
Category, # Predictions, # Correct:  0.5, 1278, 59
Category, # Predictions, # Correct:  1.0, 

Train:
 Average loss:  0.14390,  Smallest:  0.01582,  Largest:  0.98301,  Average:  0.48856
Category, # Predictions, # Correct:  0.0, 35539, 34940
Category, # Predictions, # Correct:  0.5, 35539, 25200
Category, # Predictions, # Correct:  1.0, 35539, 31475 
Test:
 Average loss:  0.61687,  Smallest:  0.01669,  Largest:  0.84294,  Average:  0.18036
Category, # Predictions, # Correct:  0.0, 3300, 3291
Category, # Predictions, # Correct:  0.5, 1278, 21
Category, # Predictions, # Correct:  1.0, 8934, 114
Train:
 Average loss:  0.14175,  Smallest:  0.01510,  Largest:  0.98369,  Average:  0.48922
Category, # Predictions, # Correct:  0.0, 35539, 34957
Category, # Predictions, # Correct:  0.5, 35539, 25159
Category, # Predictions, # Correct:  1.0, 35539, 31689 
Test:
 Average loss:  0.61879,  Smallest:  0.01603,  Largest:  0.84516,  Average:  0.17652
Category, # Predictions, # Correct:  0.0, 3300, 3293
Category, # Predictions, # Correct:  0.5, 1278, 19
Category, # Predictions, # Correct:  1.0, 

Train:
 Average loss:  0.11241,  Smallest:  0.00783,  Largest:  0.99175,  Average:  0.49600
Category, # Predictions, # Correct:  0.0, 35539, 35314
Category, # Predictions, # Correct:  0.5, 35539, 24511
Category, # Predictions, # Correct:  1.0, 35539, 33966 
Test:
 Average loss:  0.64471,  Smallest:  0.00881,  Largest:  0.87168,  Average:  0.12505
Category, # Predictions, # Correct:  0.0, 3300, 3299
Category, # Predictions, # Correct:  0.5, 1278, 6
Category, # Predictions, # Correct:  1.0, 8934, 43
Train:
 Average loss:  0.11087,  Smallest:  0.00757,  Largest:  0.99205,  Average:  0.49613
Category, # Predictions, # Correct:  0.0, 35539, 35329
Category, # Predictions, # Correct:  0.5, 35539, 24441
Category, # Predictions, # Correct:  1.0, 35539, 34066 
Test:
 Average loss:  0.64600,  Smallest:  0.00855,  Largest:  0.87274,  Average:  0.12251
Category, # Predictions, # Correct:  0.0, 3300, 3299
Category, # Predictions, # Correct:  0.5, 1278, 6
Category, # Predictions, # Correct:  1.0, 893

Train:
 Average loss:  0.09097,  Smallest:  0.00452,  Largest:  0.99534,  Average:  0.50007
Category, # Predictions, # Correct:  0.0, 35539, 35440
Category, # Predictions, # Correct:  0.5, 35539, 24517
Category, # Predictions, # Correct:  1.0, 35539, 34899 
Test:
 Average loss:  0.66223,  Smallest:  0.00541,  Largest:  0.88223,  Average:  0.09058
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 3
Category, # Predictions, # Correct:  1.0, 8934, 20
Train:
 Average loss:  0.09026,  Smallest:  0.00439,  Largest:  0.99548,  Average:  0.50065
Category, # Predictions, # Correct:  0.0, 35539, 35442
Category, # Predictions, # Correct:  0.5, 35539, 24564
Category, # Predictions, # Correct:  1.0, 35539, 34931 
Test:
 Average loss:  0.66299,  Smallest:  0.00525,  Largest:  0.88194,  Average:  0.08908
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 2
Category, # Predictions, # Correct:  1.0, 893

Train:
 Average loss:  0.07761,  Smallest:  0.00287,  Largest:  0.99691,  Average:  0.50240
Category, # Predictions, # Correct:  0.0, 35539, 35467
Category, # Predictions, # Correct:  0.5, 35539, 25000
Category, # Predictions, # Correct:  1.0, 35539, 35275 
Test:
 Average loss:  0.67262,  Smallest:  0.00391,  Largest:  0.87962,  Average:  0.07019
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934, 14
Train:
 Average loss:  0.07656,  Smallest:  0.00279,  Largest:  0.99695,  Average:  0.50186
Category, # Predictions, # Correct:  0.0, 35539, 35473
Category, # Predictions, # Correct:  0.5, 35539, 25191
Category, # Predictions, # Correct:  1.0, 35539, 35293 
Test:
 Average loss:  0.67296,  Smallest:  0.00380,  Largest:  0.87943,  Average:  0.06953
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 893

Train:
 Average loss:  0.06805,  Smallest:  0.00208,  Largest:  0.99768,  Average:  0.50254
Category, # Predictions, # Correct:  0.0, 35539, 35487
Category, # Predictions, # Correct:  0.5, 35539, 25938
Category, # Predictions, # Correct:  1.0, 35539, 35441 
Test:
 Average loss:  0.67925,  Smallest:  0.00316,  Largest:  0.87256,  Average:  0.05719
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934, 10
Train:
 Average loss:  0.06845,  Smallest:  0.00206,  Largest:  0.99770,  Average:  0.50319
Category, # Predictions, # Correct:  0.0, 35539, 35483
Category, # Predictions, # Correct:  0.5, 35539, 25739
Category, # Predictions, # Correct:  1.0, 35539, 35446 
Test:
 Average loss:  0.67943,  Smallest:  0.00314,  Largest:  0.87327,  Average:  0.05684
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 893

Train:
 Average loss:  0.06246,  Smallest:  0.00170,  Largest:  0.99819,  Average:  0.50323
Category, # Predictions, # Correct:  0.0, 35539, 35491
Category, # Predictions, # Correct:  0.5, 35539, 26416
Category, # Predictions, # Correct:  1.0, 35539, 35499 
Test:
 Average loss:  0.68326,  Smallest:  0.00268,  Largest:  0.86188,  Average:  0.04934
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934, 5
Train:
 Average loss:  0.06315,  Smallest:  0.00168,  Largest:  0.99819,  Average:  0.50353
Category, # Predictions, # Correct:  0.0, 35539, 35487
Category, # Predictions, # Correct:  0.5, 35539, 26472
Category, # Predictions, # Correct:  1.0, 35539, 35503 
Test:
 Average loss:  0.68356,  Smallest:  0.00264,  Largest:  0.85979,  Average:  0.04873
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.05739,  Smallest:  0.00140,  Largest:  0.99853,  Average:  0.50316
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 27346
Category, # Predictions, # Correct:  1.0, 35539, 35523 
Test:
 Average loss:  0.68571,  Smallest:  0.00213,  Largest:  0.86001,  Average:  0.04453
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934, 6
Train:
 Average loss:  0.05754,  Smallest:  0.00139,  Largest:  0.99854,  Average:  0.50223
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 27255
Category, # Predictions, # Correct:  1.0, 35539, 35523 
Test:
 Average loss:  0.68597,  Smallest:  0.00213,  Largest:  0.85528,  Average:  0.04403
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.05306,  Smallest:  0.00127,  Largest:  0.99885,  Average:  0.50155
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 28112
Category, # Predictions, # Correct:  1.0, 35539, 35533 
Test:
 Average loss:  0.68769,  Smallest:  0.00181,  Largest:  0.84652,  Average:  0.04067
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 1
Category, # Predictions, # Correct:  1.0, 8934, 5
Train:
 Average loss:  0.05394,  Smallest:  0.00128,  Largest:  0.99886,  Average:  0.50348
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 27721
Category, # Predictions, # Correct:  1.0, 35539, 35534 
Test:
 Average loss:  0.68794,  Smallest:  0.00183,  Largest:  0.84305,  Average:  0.04017
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.05030,  Smallest:  0.00116,  Largest:  0.99899,  Average:  0.50214
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 28781
Category, # Predictions, # Correct:  1.0, 35539, 35535 
Test:
 Average loss:  0.68904,  Smallest:  0.00154,  Largest:  0.83299,  Average:  0.03801
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 4
Train:
 Average loss:  0.05032,  Smallest:  0.00116,  Largest:  0.99900,  Average:  0.50290
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 28465
Category, # Predictions, # Correct:  1.0, 35539, 35535 
Test:
 Average loss:  0.68917,  Smallest:  0.00154,  Largest:  0.82952,  Average:  0.03777
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.04812,  Smallest:  0.00106,  Largest:  0.99908,  Average:  0.50342
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 28824
Category, # Predictions, # Correct:  1.0, 35539, 35537 
Test:
 Average loss:  0.68996,  Smallest:  0.00142,  Largest:  0.83336,  Average:  0.03622
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 4
Train:
 Average loss:  0.04751,  Smallest:  0.00110,  Largest:  0.99908,  Average:  0.50337
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 29064
Category, # Predictions, # Correct:  1.0, 35539, 35537 
Test:
 Average loss:  0.69018,  Smallest:  0.00141,  Largest:  0.82903,  Average:  0.03579
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.04631,  Smallest:  0.00105,  Largest:  0.99915,  Average:  0.50257
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 29407
Category, # Predictions, # Correct:  1.0, 35539, 35538 
Test:
 Average loss:  0.69095,  Smallest:  0.00131,  Largest:  0.83227,  Average:  0.03427
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 3
Train:
 Average loss:  0.04565,  Smallest:  0.00106,  Largest:  0.99914,  Average:  0.50185
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 29491
Category, # Predictions, # Correct:  1.0, 35539, 35538 
Test:
 Average loss:  0.69117,  Smallest:  0.00132,  Largest:  0.83075,  Average:  0.03384
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.04316,  Smallest:  0.00099,  Largest:  0.99920,  Average:  0.50344
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 30000
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69176,  Smallest:  0.00123,  Largest:  0.82933,  Average:  0.03269
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 3
Train:
 Average loss:  0.04436,  Smallest:  0.00100,  Largest:  0.99922,  Average:  0.50360
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 29737
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69175,  Smallest:  0.00124,  Largest:  0.83104,  Average:  0.03272
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.04184,  Smallest:  0.00096,  Largest:  0.99923,  Average:  0.50207
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 30396
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69229,  Smallest:  0.00118,  Largest:  0.82473,  Average:  0.03166
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 3
Train:
 Average loss:  0.04278,  Smallest:  0.00095,  Largest:  0.99925,  Average:  0.50333
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 30268
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69217,  Smallest:  0.00116,  Largest:  0.82636,  Average:  0.03189
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03989,  Smallest:  0.00086,  Largest:  0.99926,  Average:  0.50151
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 30902
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69291,  Smallest:  0.00114,  Largest:  0.81330,  Average:  0.03044
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 3
Train:
 Average loss:  0.03946,  Smallest:  0.00086,  Largest:  0.99928,  Average:  0.50092
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 30907
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69299,  Smallest:  0.00112,  Largest:  0.81602,  Average:  0.03030
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03748,  Smallest:  0.00079,  Largest:  0.99930,  Average:  0.50170
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 31505
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69311,  Smallest:  0.00108,  Largest:  0.81741,  Average:  0.03005
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 3
Train:
 Average loss:  0.03927,  Smallest:  0.00080,  Largest:  0.99929,  Average:  0.50309
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 31012
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69322,  Smallest:  0.00110,  Largest:  0.81234,  Average:  0.02985
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03711,  Smallest:  0.00076,  Largest:  0.99935,  Average:  0.50211
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 31712
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69390,  Smallest:  0.00104,  Largest:  0.81421,  Average:  0.02851
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03616,  Smallest:  0.00073,  Largest:  0.99937,  Average:  0.50122
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 31852
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69364,  Smallest:  0.00102,  Largest:  0.81471,  Average:  0.02901
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03707,  Smallest:  0.00070,  Largest:  0.99937,  Average:  0.50161
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 31877
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69396,  Smallest:  0.00101,  Largest:  0.80829,  Average:  0.02840
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03676,  Smallest:  0.00069,  Largest:  0.99937,  Average:  0.50283
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 31742
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69390,  Smallest:  0.00105,  Largest:  0.80489,  Average:  0.02851
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03598,  Smallest:  0.00067,  Largest:  0.99942,  Average:  0.50377
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 32165
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69465,  Smallest:  0.00098,  Largest:  0.80282,  Average:  0.02704
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03559,  Smallest:  0.00065,  Largest:  0.99943,  Average:  0.50337
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 32326
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69444,  Smallest:  0.00098,  Largest:  0.80111,  Average:  0.02744
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03387,  Smallest:  0.00062,  Largest:  0.99946,  Average:  0.50262
Category, # Predictions, # Correct:  0.0, 35539, 35501
Category, # Predictions, # Correct:  0.5, 35539, 32491
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69476,  Smallest:  0.00096,  Largest:  0.79804,  Average:  0.02683
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03500,  Smallest:  0.00063,  Largest:  0.99946,  Average:  0.50078
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 32484
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69488,  Smallest:  0.00095,  Largest:  0.80214,  Average:  0.02659
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03327,  Smallest:  0.00060,  Largest:  0.99948,  Average:  0.50288
Category, # Predictions, # Correct:  0.0, 35539, 35505
Category, # Predictions, # Correct:  0.5, 35539, 32639
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69528,  Smallest:  0.00089,  Largest:  0.80022,  Average:  0.02580
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03095,  Smallest:  0.00059,  Largest:  0.99950,  Average:  0.50279
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 33054
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69537,  Smallest:  0.00087,  Largest:  0.79782,  Average:  0.02562
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03290,  Smallest:  0.00055,  Largest:  0.99952,  Average:  0.50158
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 33025
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69527,  Smallest:  0.00086,  Largest:  0.78775,  Average:  0.02582
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03282,  Smallest:  0.00056,  Largest:  0.99953,  Average:  0.50264
Category, # Predictions, # Correct:  0.0, 35539, 35503
Category, # Predictions, # Correct:  0.5, 35539, 32852
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69539,  Smallest:  0.00087,  Largest:  0.79355,  Average:  0.02560
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

Train:
 Average loss:  0.03194,  Smallest:  0.00053,  Largest:  0.99955,  Average:  0.50046
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 33179
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69558,  Smallest:  0.00085,  Largest:  0.79780,  Average:  0.02522
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934, 2
Train:
 Average loss:  0.03189,  Smallest:  0.00054,  Largest:  0.99955,  Average:  0.50370
Category, # Predictions, # Correct:  0.0, 35539, 35499
Category, # Predictions, # Correct:  0.5, 35539, 32958
Category, # Predictions, # Correct:  1.0, 35539, 35539 
Test:
 Average loss:  0.69549,  Smallest:  0.00086,  Largest:  0.79534,  Average:  0.02539
Category, # Predictions, # Correct:  0.0, 3300, 3300
Category, # Predictions, # Correct:  0.5, 1278, 0
Category, # Predictions, # Correct:  1.0, 8934

KeyboardInterrupt: 

In [9]:
# validate
with torch.set_grad_enabled(False):
    net = net.eval()
    correct = {i : 0 for i in [0, 0.5, 1]}
    total = {i : 0 for i in [0, 0.5, 1]}
    for board, value in test_gen:
        board, value = board.to(device), value.to(device)

        outputs = net(board)
        categories = categorise_predictions(outputs)

        for k in correct:
            idx = (categories == k).nonzero()
            total[k] += len(idx)
            correct[k] += (categories[idx] == value[idx]).nonzero().sum().item()
    for k in correct:
        print('Category, # Predictions, Accuracy of the network on the test%s: %d %d %%' % (
            k,
            total[k],
            (100 * float(correct[k]) / float(total[k])) if total[k] != 0 else 0.))

RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1

In [None]:
assert(False)
# save that crap
torch.save({
    'net_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
},
    open('/home/richard/Downloads/nn.pth', 'wb'))

In [None]:
# alternatively load it
assert(False)
checkpoint = torch.load('/home/richard/Downloads/nn.pth')
net.load_state_dict(checkpoint['net_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']

In [None]:
dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))