# Train

This notebook trains the networks. After each training epoch, we save a checkpoint file containing the model's parameters, the optimizer's parameters, the average loss over the epoch, and the average validation dice score. These checkpoints let us plot learning curves later and restart training from any checkpoint.
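
The checkpointing scheme described above can be sketched as follows. This is a minimal illustration, not the code in `train.py`: the helper names `save_checkpoint` and `load_checkpoint`, the dictionary keys, and the small convolution used as a stand-in model are all assumptions made for the example.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch, avg_loss, avg_val_score):
    """Write everything needed to resume training or plot curves later."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "avg_loss": avg_loss,            # average training loss over the epoch
            "avg_val_score": avg_val_score,  # average validation score
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place and return the stored metadata."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["avg_loss"], checkpoint["avg_val_score"]


# Hypothetical stand-in network; the notebook's real models are RFDN1, BaseN, etc.
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Save to a temporary location (hypothetical file name).
path = os.path.join(tempfile.mkdtemp(), "checkpoint_demo.tar")
save_checkpoint(path, model, optimizer, epoch=0, avg_loss=0.05, avg_val_score=0.9)

# Restore into freshly constructed objects, as one would when resuming training.
restored_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=1e-2)
epoch, avg_loss, avg_val_score = load_checkpoint(path, restored_model, restored_optimizer)
```

Storing the per-epoch metrics next to the weights means a learning curve can be rebuilt by reading one checkpoint per epoch, without rerunning any training.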

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Install dependencies if needed and import modules

In [2]:
!pip install pytorch-msssim

Collecting pytorch-msssim
  Downloading https://files.pythonhosted.org/packages/9d/d3/3cb0f397232cf79e1762323c3a8862e39ad53eca0bb5f6be9ccc8e7c070e/pytorch_msssim-0.2.1-py3-none-any.whl
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-0.2.1


In [3]:
# Import Modules
import sys
sys.path.insert(1, "/content/drive/My Drive/CMPUT511/Project/Code/RFDN")
import train

from RFDN import RFDN, RFDN1, RFDN2
from BaseN import BaseN

## Training Loop

The cells below contain the training loop. The trainer object takes the checkpoint file, the data directory, the learning rate, and the model number. It then trains the network, saving checkpoints in the same directory as the checkpoint file passed as an argument. The trainer can also load that checkpoint file and resume training.

To train our networks, we use the mean $L_1$ loss, as specified in the RFDN paper:
$$
    \mathcal{L}(\theta) = \frac{1}{N} \sum\limits_{i=1}^{N} || H(I_{i}^{LR}) - I_{i}^{HR} ||_{1}
$$
where $H$ is the hypothesis of the model being trained, $I_{i}^{LR}$ is the $i^{th}$ pixel of the low-resolution image, $I_{i}^{HR}$ is the $i^{th}$ pixel of the high-resolution image, $\theta$ are the model parameters, $N$ is the number of pixels in the image, and $|| \cdot ||_1$ is the $L_1$ norm.
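
As a small worked example of the formula above, the sketch below computes the mean $L_1$ loss for a handful of pixel values in plain Python. The function name `mean_l1_loss` and the pixel values are made up for illustration; this is not the implementation used by the Trainer class.

```python
def mean_l1_loss(predicted, target):
    """Mean absolute difference between two equally sized pixel sequences."""
    if len(predicted) != len(target):
        raise ValueError("predicted and target must have the same length")
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(predicted)


# Hypothetical flattened pixel values: model output vs. high-resolution target.
predicted = [0.25, 0.5, 1.0, 0.0]
target = [0.5, 0.5, 0.5, 0.5]

# Absolute errors are 0.25, 0.0, 0.5, 0.5, so the mean is 1.25 / 4 = 0.3125.
loss = mean_l1_loss(predicted, target)
print(loss)
```

In PyTorch, `torch.nn.L1Loss` with its default `reduction="mean"` computes the same quantity over all elements of its input tensors.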

The Trainer class creates and optimizes this loss function. It also creates and maintains the Adam optimizer that we use for training.

In [4]:
model = RFDN1(nf=10, upscale=2)
# model = BaseN(nf=10, upscale=2)
# model = FDCN(nf=10, upscale=2)
# model = SRN(nf=10, upscale=2)
data_dir = "/content/drive/My Drive/CMPUT511/Project/Data"
checkpoint_file = "/content/drive/My Drive/CMPUT511/Project/Checkpoints/checkpoint_0_1.tar"
trainer = train.Trainer(model, checkpoint_file, data_dir, lr=1e-2, div=1.005, num=0)

In [5]:
trainer.train(40, load=False)

100%|██████████| 800/800 [56:50<00:00,  4.26s/it]
100%|██████████| 800/800 [02:48<00:00,  4.75it/s]
100%|██████████| 800/800 [02:47<00:00,  4.79it/s]
100%|██████████| 800/800 [02:47<00:00,  4.79it/s]
100%|██████████| 800/800 [02:47<00:00,  4.78it/s]
100%|██████████| 800/800 [02:48<00:00,  4.74it/s]
100%|██████████| 800/800 [02:48<00:00,  4.74it/s]
100%|██████████| 800/800 [02:48<00:00,  4.76it/s]
100%|██████████| 800/800 [02:48<00:00,  4.75it/s]
100%|██████████| 800/800 [02:48<00:00,  4.73it/s]
100%|██████████| 800/800 [02:48<00:00,  4.75it/s]
100%|██████████| 800/800 [02:49<00:00,  4.72it/s]
100%|██████████| 800/800 [02:49<00:00,  4.73it/s]
100%|██████████| 800/800 [02:48<00:00,  4.74it/s]
100%|██████████| 800/800 [02:48<00:00,  4.74it/s]
100%|██████████| 800/800 [02:49<00:00,  4.72it/s]
100%|██████████| 800/800 [02:49<00:00,  4.73it/s]
100%|██████████| 800/800 [02:48<00:00,  4.76it/s]
100%|██████████| 800/800 [02:48<00:00,  4.76it/s]
100%|██████████| 800/800 [02:48<00:00,  4.74it/s]
