# Train Model

Now we get into the nuts and bolts of training the model. First, let's load the data from the prior step. Next, we split it into train and test sets: the first 40k images are used for training and the remainder (images 40,000 through 49,999) for testing. We also import `utils`, which contains some functions used during the model training process. In particular:

- `scale_images()`: Standardizes images to mean 0 and standard deviation 1, or inverts that standardization
- `add_noise()`: Adds random noise to the images, then clips the result to stay within the valid pixel range
- `get_batch()`: A data loader that grabs a batch of images (1000 by default)
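
Since `utils` itself is not shown here, the following is only a rough sketch of what helpers like these might look like. The `_sketch` suffix marks each function as a hypothetical stand-in. The noise level, the `[0, 1]` pixel range, the tensor shape `(N, 1, 64, 64)`, and the random sampling are all assumptions, not the actual implementation.

```python
import torch


def scale_images_sketch(x, mean, sd, invert=False):
    """Standardize to mean 0 / sd 1, or undo that scaling when invert=True."""
    if invert:
        return x * sd + mean
    return (x - mean) / sd


def add_noise_sketch(x, noise_sd=0.1, low=0.0, high=1.0):
    """Add Gaussian noise, then clip back into the assumed pixel range."""
    noisy = x + noise_sd * torch.randn_like(x)
    return noisy.clamp(low, high)


def get_batch_sketch(images, n=1000):
    """Sample n images at random and return (noisy inputs, clean targets)."""
    idx = torch.randint(0, images.shape[0], (n,))
    targets = images[idx]
    inputs = add_noise_sketch(targets)
    return inputs, targets


# Usage on a small random stand-in for the real image tensor
demo = torch.rand(50, 1, 64, 64)
inputs, targets = get_batch_sketch(demo, n=8)
print(inputs.shape, targets.shape)
```

Returning the noisy and clean versions together keeps each training step to a single call, which matches how `utils.get_batch()` is used later in this section.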

In [11]:
import os
import sys
import tqdm
import torch
import polars as pl
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Load modeling utilities
sys.path.append('../src/')
import utils

# Load images
images = torch.load('../data/images.pt')

# Divide into train and test
train = images[0:40000]
test = images[40000:50000]

Now we define our neural network architecture using `nn.Module` from `torch`. We are using a simple CNN, which applies convolutions in each layer. Convolutions are well suited to spatial data like images, where neighboring pixels are related, as opposed to tabular data where each input is a distinct and independent feature. This is a fairly small CNN, but it should be suitable for our purposes with 64x64 images.

In [4]:
class ConvNet(nn.Module):

    # Defines network parameters
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=3, padding=1)

    # Defines forward pass
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x
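
As a quick sanity check on the architecture, the sketch below rebuilds the same three layers as an `nn.Sequential` stand-in (not the `ConvNet` class itself) and passes a dummy batch through it. It assumes single-channel inputs of shape `(N, 1, 64, 64)`. With `kernel_size=3` and `padding=1`, each convolution preserves height and width, so the output has the same shape as the input, which is what an image-to-image denoiser needs.

```python
import torch
import torch.nn as nn

# Stand-in with the same layers as ConvNet above
net = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 1, kernel_size=3, padding=1),
)

x = torch.randn(4, 1, 64, 64)  # dummy batch; the shape is an assumption
with torch.no_grad():
    y = net(x)

# Count weights and biases across all three convolutions
n_params = sum(p.numel() for p in net.parameters())
print(y.shape)   # torch.Size([4, 1, 64, 64])
print(n_params)  # 5089
```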

Here we define the training loop. We take batches of 100 images and add noise to construct the inputs (artificially noisy images) and the targets (original images). The model trains for a set number of epochs, where each epoch here is a single optimizer step on one batch, to learn how to de-noise the images in a repeatable way. We can then take a look at the testing data performance to get a sense of how the model is doing.

In [18]:
# Set a random seed
torch.manual_seed(1)
losses = []

# Define model, loss, and optimizer
model = ConvNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for 1000 epochs (each is one optimizer step on a single batch)
for epoch in tqdm.tqdm(range(1000)):
    optimizer.zero_grad()
    inputs, targets = utils.get_batch(train, n=100)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
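
To look at test performance once training finishes, one option is to compute the same MSE on noisy/clean pairs drawn from the held-out images with gradients disabled. The sketch below is self-contained, so it uses a tiny stand-in model, random stand-in data, an assumed noise level, and a hypothetical `evaluate` helper. In the notebook, `model` and a pair from `utils.get_batch(test, n=100)` would take their place.

```python
import torch
import torch.nn as nn


def evaluate(model, inputs, targets):
    """Return mean squared error on one batch without tracking gradients."""
    model.eval()
    with torch.no_grad():
        return nn.functional.mse_loss(model(inputs), targets).item()


torch.manual_seed(1)
# Placeholder for the trained model
stand_in_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
# Placeholder for held-out test images
clean = torch.rand(16, 1, 64, 64)
# Assumed noise level, clipped to an assumed [0, 1] pixel range
noisy = (clean + 0.1 * torch.randn_like(clean)).clamp(0, 1)

test_mse = evaluate(stand_in_model, noisy, clean)
# Error of leaving the noisy image untouched
baseline_mse = nn.functional.mse_loss(noisy, clean).item()
print(f"model: {test_mse:.4f}  noisy baseline: {baseline_mse:.4f}")
```

Comparing against the noisy baseline shows whether the model removes noise at all: a useful denoiser should score below it.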


Finally, we save the model weights (the `state_dict`) and the loss curve history.

In [26]:
torch.save(model.state_dict(), '../data/model.pt')
pl.DataFrame({'losses': losses}).with_row_index('epoch').write_csv('../data/losses.csv')
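
To reuse the saved weights later, the usual pattern is to rebuild the architecture and load the `state_dict` into it. The round trip below is a sketch that uses a temporary file and a small stand-in layer rather than `../data/model.pt` and `ConvNet`.

```python
import os
import tempfile

import torch
import torch.nn as nn

original = nn.Conv2d(1, 16, kernel_size=3, padding=1)  # stand-in for ConvNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")  # temporary path for this demo only
    torch.save(original.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it
    restored = nn.Conv2d(1, 16, kernel_size=3, padding=1)
    restored.load_state_dict(torch.load(path))
    restored.eval()  # switch to inference mode before denoising

# Every saved tensor should match the original exactly
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Saving only the `state_dict` keeps the file independent of the class definition, at the cost of having to construct the model before loading.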