In [8]:
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

In [9]:
# --- Training configuration ---------------------------------------------
# Defaults carried over from the original argparse-based script.
# TODO: confirm these point at the intended training / evaluation HDF5 files.
# Alternative training set: 'original_training_data/x4/91-image_x4.h5'
train_file = 'original_training_data/for_training/AMVTG_004.h5'
eval_file = 'original_training_data/x4/Set5_x4.h5'
outputs_dir = 'outputs'
scale = 4          # upscaling factor; also names the output sub-directory
lr = 1e-4          # base Adam learning rate
batch_size = 16
num_epochs = 400
num_workers = 8    # DataLoader worker processes (assigned once)
seed = 123

# Checkpoints are written to <outputs_dir>/x<scale>, e.g. outputs/x4.
outputs_dir = os.path.join(outputs_dir, 'x{}'.format(scale))
# exist_ok avoids the check-then-create race and makes the cell idempotent.
os.makedirs(outputs_dir, exist_ok=True)

# --- Device, model, optimizer and data ------------------------------------
# cuDNN autotuning only has an effect on CUDA; it is harmless elsewhere.
cudnn.benchmark = True

# Prefer CUDA, then Apple MPS, then CPU, so the notebook also runs on
# machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

torch.manual_seed(seed)

model = SRCNN().to(device)
criterion = nn.MSELoss()
# conv1/conv2 use the base learning rate; conv3 is trained with lr * 0.1.
optimizer = optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train_dataset = TrainDataset(train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              # Pinned host memory only speeds up host->CUDA copies
                              # (PyTorch warns that it is unsupported on MPS).
                              pin_memory=(device.type == 'cuda'),
                              # Incomplete final batch is discarded every epoch.
                              drop_last=True)
eval_dataset = EvalDataset(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

# Tracking state for the best epoch (by mean eval PSNR).
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

# --- Train / evaluate ------------------------------------------------------
# Each epoch: train on all full batches, save epoch_<n>.pth, evaluate mean
# PSNR on the eval set, and remember the weights with the highest PSNR.
# best.pth is written in `finally` so an interrupted run (e.g. a
# KeyboardInterrupt) still keeps the best weights found so far; the
# interrupt itself is re-raised after saving.
evaluated_any = False  # only save best.pth once at least one epoch was evaluated
try:
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # drop_last=True in the loader, so the bar total excludes the remainder.
        with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
            t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))

            for inputs, labels in train_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        with torch.no_grad():
            for inputs, labels in eval_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs).clamp(0.0, 1.0)
                # float() keeps the running average a plain Python number rather
                # than a device tensor (assumes calc_psnr returns a scalar or a
                # one-element tensor — TODO confirm against utils.calc_psnr).
                epoch_psnr.update(float(calc_psnr(preds, labels)), len(inputs))

        evaluated_any = True
        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())
finally:
    # Skipped if interrupted before the first evaluation, so untrained initial
    # weights never overwrite an existing best.pth.
    if evaluated_any:
        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
        torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))


epoch: 0/399:   0%|          | 0/74368 [00:00<?, ?it/s]

epoch: 0/399: 100%|██████████| 74368/74368 [01:20<00:00, 928.59it/s, loss=0.003858] 


eval psnr: 27.74


epoch: 1/399: 100%|██████████| 74368/74368 [01:20<00:00, 928.32it/s, loss=0.002854] 


eval psnr: 27.82


epoch: 2/399: 100%|██████████| 74368/74368 [01:20<00:00, 928.42it/s, loss=0.002623] 


eval psnr: 27.84


epoch: 3/399: 100%|██████████| 74368/74368 [01:20<00:00, 924.79it/s, loss=0.002494] 


eval psnr: 27.91


epoch: 4/399: 100%|██████████| 74368/74368 [01:17<00:00, 961.85it/s, loss=0.002406] 


KeyboardInterrupt: 