In [1]:
from config import DEVICE, NUM_CLASSES, NUM_EPOCHS, OUT_DIR
from config import VISUALIZE_TRANSFORMED_IMAGES
from config import SAVE_PLOTS_EPOCH, SAVE_MODEL_EPOCH
from model import create_model
from utils import Averager
from tqdm.auto import tqdm
from datasets import train_loader, valid_loader, ButtonDataset
import torch
import matplotlib.pyplot as plt
import time

# Select matplotlib's 'ggplot' style sheet; it applies to figures created after this call.
plt.style.use('ggplot')

Number of training samples: 10
Number of validation samples: 4



In [2]:
# function for running training iterations
def train(train_data_loader, model):
    """Run one optimisation pass over every batch of ``train_data_loader``.

    This function relies on module-level state instead of arguments:
    ``optimizer`` and ``train_loss_hist`` are read from the global scope,
    and the globals ``train_itr`` and ``train_loss_list`` are updated here.

    Args:
        train_data_loader: Sized iterable yielding ``(images, targets)``
            batches. Each target is a dict whose values are moved to ``DEVICE``.
        model: Callable invoked as ``model(images, targets)``; its result is
            treated as a mapping whose values are summed into the total loss.

    Returns:
        The global ``train_loss_list`` with one scalar loss appended per batch.
    """
    print('Training')
    global train_itr
    global train_loss_list

    # wrap the loader in a tqdm progress bar
    progress = tqdm(train_data_loader, total=len(train_data_loader))

    for batch in progress:
        optimizer.zero_grad()
        batch_images, batch_targets = batch

        # move every image and every target entry to the compute device
        batch_images = [img.to(DEVICE) for img in batch_images]
        batch_targets = [
            {name: tensor.to(DEVICE) for name, tensor in target.items()}
            for target in batch_targets
        ]

        # the total loss is the sum of all terms returned by the model
        loss_terms = model(batch_images, batch_targets)
        total_loss = sum(loss_terms.values())
        batch_loss = total_loss.item()

        # record the scalar loss before back-propagating
        train_loss_list.append(batch_loss)
        train_loss_hist.send(batch_loss)

        total_loss.backward()
        optimizer.step()
        train_itr += 1

        # show the latest loss value beside the progress bar
        progress.set_description(desc=f"Loss: {batch_loss:.4f}")
    return train_loss_list

In [3]:
# function for running validation iterations
def validate(valid_data_loader, model):
    """Compute the loss for every batch of ``valid_data_loader`` without updating weights.

    This function relies on module-level state instead of arguments:
    ``val_loss_hist`` is read from the global scope, and the globals
    ``val_itr`` and ``val_loss_list`` are updated here.

    Args:
        valid_data_loader: Sized iterable yielding ``(images, targets)``
            batches. Each target is a dict whose values are moved to ``DEVICE``.
        model: Callable invoked as ``model(images, targets)``; its result is
            treated as a mapping whose values are summed into the total loss.

    Returns:
        The global ``val_loss_list`` with one scalar loss appended per batch.
    """
    print('Validating')
    global val_itr
    global val_loss_list

    # wrap the loader in a tqdm progress bar
    progress = tqdm(valid_data_loader, total=len(valid_data_loader))

    for batch in progress:
        batch_images, batch_targets = batch

        # move every image and every target entry to the compute device
        batch_images = [img.to(DEVICE) for img in batch_images]
        batch_targets = [
            {name: tensor.to(DEVICE) for name, tensor in target.items()}
            for target in batch_targets
        ]

        # forward pass only: gradient tracking is switched off for this call
        with torch.no_grad():
            loss_terms = model(batch_images, batch_targets)

        total_loss = sum(loss_terms.values())
        batch_loss = total_loss.item()

        val_loss_list.append(batch_loss)
        val_loss_hist.send(batch_loss)
        val_itr += 1

        # show the latest loss value beside the progress bar
        progress.set_description(desc=f"Loss: {batch_loss:.4f}")
    return val_loss_list

In [4]:
def _save_loss_plots(train_loss, val_loss, epoch_number):
    """Plot the recorded per-iteration losses and save both figures under OUT_DIR."""
    # one figure for the training curve, one for the validation curve
    figure_1, train_ax = plt.subplots()
    figure_2, valid_ax = plt.subplots()
    try:
        train_ax.plot(train_loss, color='blue')
        train_ax.set_xlabel('iterations')
        train_ax.set_ylabel('train loss')
        valid_ax.plot(val_loss, color='red')
        valid_ax.set_xlabel('iterations')
        valid_ax.set_ylabel('validation loss')
        figure_1.savefig(f"{OUT_DIR}/train_loss_{epoch_number}.png")
        figure_2.savefig(f"{OUT_DIR}/valid_loss_{epoch_number}.png")
    finally:
        # release the figures even if plotting or saving fails
        plt.close(figure_1)
        plt.close(figure_2)


def train_model():
    """Build the model and run the full training / validation loop.

    For each of ``NUM_EPOCHS`` epochs this resets the per-epoch loss
    averagers, calls ``train()`` and ``validate()``, prints the averaged
    losses and the elapsed time, and saves artefacts to ``OUT_DIR``:

    * the model ``state_dict`` every ``SAVE_MODEL_EPOCH`` epochs and after
      the last epoch;
    * the loss plots every ``SAVE_PLOTS_EPOCH`` epochs and after the last
      epoch.

    Side effects:
        Rebinds the module-level names ``optimizer``, ``train_loss_hist``,
        ``val_loss_hist``, ``train_itr``, ``val_itr``, ``train_loss_list``
        and ``val_loss_list``. ``train()`` and ``validate()`` look these up
        in the global scope, so they must not be locals of this function.

    Returns:
        None.
    """
    # train() and validate() read and update these names at module level;
    # binding them as locals here would raise NameError on a fresh kernel
    global optimizer, train_loss_hist, val_loss_hist
    global train_itr, val_itr, train_loss_list, val_loss_list

    # initialize the model and move to the computation device
    model = create_model(num_classes=NUM_CLASSES)
    model = model.to(DEVICE)
    # get the model parameters
    params = [p for p in model.parameters() if p.requires_grad]
    # define the optimizer
    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
    # initialize the Averager class
    train_loss_hist = Averager()
    val_loss_hist = Averager()
    train_itr = 1
    val_itr = 1
    # train and validation loss lists to store loss values of all...
    # ... iterations till end and plot graphs for all iterations
    train_loss_list = []
    val_loss_list = []
    # whether to show transformed images from data loader or not
    if VISUALIZE_TRANSFORMED_IMAGES:
        from utils import show_tranformed_image

        show_tranformed_image(train_loader)
    # start the training epochs
    for epoch in range(NUM_EPOCHS):
        epoch_number = epoch + 1
        is_last_epoch = epoch_number == NUM_EPOCHS
        print(f"\nEPOCH {epoch_number} of {NUM_EPOCHS}")
        # reset the training and validation loss histories for the current epoch
        train_loss_hist.reset()
        val_loss_hist.reset()
        # start timer and carry out training and validation
        start = time.time()
        train_loss = train(train_loader, model)
        val_loss = validate(valid_loader, model)
        # NOTE: these three messages print the zero-based epoch index
        print(f"Epoch #{epoch} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch} validation loss: {val_loss_hist.value:.3f}")
        end = time.time()
        print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch}")

        # save the model after every n epochs and once at the end; the two
        # conditions are merged so the last epoch is never written twice
        if epoch_number % SAVE_MODEL_EPOCH == 0 or is_last_epoch:
            torch.save(model.state_dict(), f"{OUT_DIR}/model{epoch_number}.pth")
            print('SAVING MODEL COMPLETE...\n')

        # save loss plots after every n epochs and once at the end; merged
        # for the same reason, so curves are not drawn twice on one axes
        if epoch_number % SAVE_PLOTS_EPOCH == 0 or is_last_epoch:
            _save_loss_plots(train_loss, val_loss, epoch_number)
            print('SAVING PLOTS COMPLETE...')

        # close any figure still open at the end of the epoch
        plt.close('all')


EPOCH 1 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #0 train loss: 0.751
Epoch #0 validation loss: 0.409
Took 6.850 minutes for epoch 0
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 2 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #1 train loss: 0.349
Epoch #1 validation loss: 0.384
Took 4.915 minutes for epoch 1
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 3 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #2 train loss: 0.366
Epoch #2 validation loss: 0.349
Took 4.549 minutes for epoch 2
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 4 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #3 train loss: 0.329
Epoch #3 validation loss: 0.346
Took 3.886 minutes for epoch 3
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 5 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #4 train loss: 0.328
Epoch #4 validation loss: 0.309
Took 3.725 minutes for epoch 4
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 6 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #5 train loss: 0.321
Epoch #5 validation loss: 0.289
Took 3.891 minutes for epoch 5
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 7 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #6 train loss: 0.310
Epoch #6 validation loss: 0.277
Took 4.108 minutes for epoch 6
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 8 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #7 train loss: 0.298
Epoch #7 validation loss: 0.256
Took 3.955 minutes for epoch 7
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 9 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #8 train loss: 0.262
Epoch #8 validation loss: 0.239
Took 4.087 minutes for epoch 8
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...

EPOCH 10 of 10
Training


  0%|          | 0/5 [00:00<?, ?it/s]

Validating


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch #9 train loss: 0.266
Epoch #9 validation loss: 0.209
Took 4.197 minutes for epoch 9
SAVING MODEL COMPLETE...

SAVING PLOTS COMPLETE...
