# Mini neural net training
Author: Alfredo Canziani  
Date: Fri 14 Feb 2020

In [None]:
import torch
from torch import nn, optim
from IPython import display
from PIL import Image, ImageFont, ImageDraw
from matplotlib.pyplot import imshow, axis, figure
import numpy
import random

In [None]:
# Input definition
class input_settings:
    """Dimensions of the network input, in (batch, channels, height, width) order."""

    batch_size, channels = 1, 1
    height, width = 8, 8


# Random input tensor, used only to sanity-check the forward pass.
dummy_X = torch.randn(
    (
        input_settings.batch_size,
        input_settings.channels,
        input_settings.height,
        input_settings.width,
    )
)

In [None]:
# Network architecture
class model_settings:
    """Hyper-parameters of the conv -> ReLU -> max-pool -> linear classifier."""

    conv_channels = 4
    kernel = 3
    pooling_kernel = 3
    # Number of features entering the linear layer, derived from the input
    # geometry so it stays consistent when the kernel or pooling sizes change:
    #   - an unpadded, stride-1 convolution maps N -> N - kernel + 1;
    #   - a max-pool whose stride equals its kernel maps N -> N // pooling_kernel.
    # With the default 8x8 input: 8 -> 6 -> 2, so 4 * 2 * 2 = 16.
    flattened = (
        conv_channels
        * ((input_settings.height - kernel + 1) // pooling_kernel)
        * ((input_settings.width - kernel + 1) // pooling_kernel)
    )
    output_size = 1


model = nn.Sequential(
    nn.Conv2d(
        in_channels=input_settings.channels,
        out_channels=model_settings.conv_channels,
        kernel_size=model_settings.kernel,
        bias=True,
    ),
    nn.ReLU(),
    nn.MaxPool2d(
        kernel_size=model_settings.pooling_kernel,
        stride=model_settings.pooling_kernel,
    ),  # default input: conv_channels x 2x2 feature maps
    nn.Flatten(),  # -> model_settings.flattened features (16 by default)
    nn.Linear(
        in_features=model_settings.flattened,
        out_features=model_settings.output_size,
        bias=True,
    ),
)

In [None]:
# Inference: run the dummy input through the untrained network and show the
# output shape (no gradients are needed for this check).
with torch.no_grad():
    print(model(dummy_X).shape)  # .shape is an alias of .size()

In [None]:
# Print the model architecture (nn.Module's repr lists each child layer)
print(repr(model))

In [None]:
# Get weights and biases
def get_weights(net=None):
    """Print every layer that has weights, followed by its weight and bias.

    For each layer exposing a ``weight`` attribute, print the layer, its
    weight and its bias (``None`` if the layer has no bias), separated by
    newlines. For the default ``model`` these are the Conv2d at index 0 and
    the Linear at index 4, exactly as before.

    Args:
        net: an iterable of layers such as ``nn.Sequential``; defaults to the
            module-level ``model``.
    """
    if net is None:
        net = model
    for layer in net:
        # Activation, pooling and flattening layers have no weights; skip them.
        if hasattr(layer, 'weight'):
            print(layer, layer.weight, getattr(layer, 'bias', None), sep='\n')
    # TODO: maybe add some saving routines

In [None]:
# Font used by generate_data to render the glyphs. Pillow's truetype() size
# argument is in pixels. Loading by family name ('Verdana', with no path or
# extension) relies on Pillow finding the font in the system font directories;
# it raises OSError if it cannot, so install Verdana or pass a .ttf path.
font = ImageFont.truetype('Verdana', 8)  # keep Verdana at size 8 — NOTE(review): the offset ranges below presumably assume this exact font/size
# Per-character inclusive ranges for the random (x, y) offset at which
# generate_data draws the glyph on the canvas (negative values shift it up/left).
data_set_settings = {
    'D': {'x_min': -1, 'x_max': 2, 'y_min': -3, 'y_max': -1},
    'C': {'x_min': 0, 'x_max': 3, 'y_min': -4, 'y_max': -2},
}

In [None]:
def generate_data(visualise=False, target=None):
    """Render a randomly placed 'C' or 'D' into a small grayscale image.

    Args:
        visualise: if True, also display the rendered image with matplotlib.
        target: 'C' or 'D' to force the character; any falsy value picks one
            of the two uniformly at random.

    Returns:
        (image, label): ``image`` is a float32 tensor of shape
        (1, height, width) with values in [0, 1]; ``label`` is a 0-dim float
        tensor equal to 1.0 for 'C' and 0.0 for 'D'.

    Raises:
        KeyError: if ``target`` is truthy but not a key of data_set_settings.
    """
    character = target or random.choice(('C', 'D'))
    offsets = data_set_settings[character]
    # randint is inclusive on both ends, so every listed offset can occur.
    x = random.randint(offsets['x_min'], offsets['x_max'])
    y = random.randint(offsets['y_min'], offsets['y_max'])

    # PIL image sizes are (width, height); numpy.array(image) is then (height, width).
    image = Image.new('L', (input_settings.width, input_settings.height))
    draw = ImageDraw.Draw(image)
    draw.fontmode = '1'  # disable anti-aliasing so rendered pixels are either 0 or 255
    draw.text((x, y), character, (255,), font=font)
    data = numpy.array(image, dtype=numpy.float32) / 255

    if visualise:
        figure(facecolor='k')
        imshow(data)
        axis('off')

    # Add a leading channel dimension: (height, width) -> (1, height, width).
    return torch.from_numpy(data).unsqueeze(0), torch.tensor(character == 'C', dtype=torch.float)

In [None]:
# Generate a C, D batch
def get_batch(visualise=False):
    """Build a balanced batch of two samples: a 'C' first, then a 'D'.

    Returns:
        (x, y): ``x`` stacks the two image tensors along a new leading batch
        dimension and ``y`` stacks the two labels the same way.
    """
    samples = [generate_data(target=letter, visualise=visualise) for letter in ('C', 'D')]
    images, labels = zip(*samples)
    return torch.stack(images), torch.stack(labels)

In [None]:
# Set up network training: number of iterations, optimiser and objective.
nb_epochs = 10000
optimiser = optim.SGD(model.parameters(), lr=1e-3)  # plain SGD over all model parameters
loss = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy, applied to raw logits

In [None]:
# Training: one freshly generated C/D batch per iteration.
for epoch in range(nb_epochs):
    # Training step
    X, Y = get_batch()
    logits = model(X).squeeze(1)  # (batch, 1) -> (batch,) to match Y
    J = loss(logits, Y)  # computes the loss
    optimiser.zero_grad()  # clear gradients accumulated by the previous step
    J.backward()
    optimiser.step()

    # Batch accuracy: a positive logit means the model predicts 'C' (label 1).
    with torch.no_grad():
        acc = ((logits > 0).float() == Y).float().mean().item()
    display.clear_output(wait=True)  # replace the previous line instead of appending
    print(f"[EPOCH]: {epoch}, [LOSS]: {J.item():.6f}, [ACCURACY]: {acc:.3f}")

In [None]:
# Inference: C vs. D detector on a freshly generated random sample
with torch.no_grad():
    sample = generate_data(visualise=True)[0]
    logit = model(sample.unsqueeze(0))  # prepend the batch dimension
    # A positive logit corresponds to label 1, i.e. 'C'.
    print('C' if logit.item() > 0 else 'D')