In [1]:
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

import numpy as np

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps')


In [3]:
def features(img, model):
    submodules=list(model.children())[0:-1]
    features=[]
    last = img.to(device)
    for submodule in submodules:
        last = submodule(last).to(device)
        features.append(last)
    return features

def featureLoss(model,j):
    def feature_loss(output, target):
        output_ft = features(output, model)
        target_ft = features(target,model)
        feature_loss=[]

        submodules=list(model.children())[0:-1]

        for i in range(len(submodules)):
            scale = 1 / (np.prod(output_ft[i].shape)* np.prod(submodules[i].kernel_size))
            feature_loss.append(scale * torch.norm(output_ft[i] - target_ft[i]))
        return feature_loss[j]
    return feature_loss
        

In [5]:
num_workers = 8
# all the arguments in variables with their default values
train_file = 'original_training_data/x4/91-image_x4.h5' #TODO
#train_file = 'original_training_data/for_training/AMVTG_004.h5' #TODO
eval_file = 'original_training_data/x4/Set5_x4.h5' #TODO
outputs_dir = 'outputs'
scale = 4
lr = 1e-4
batch_size = 16
num_epochs = 10
num_workers = 8
seed = 123

# new output dir using the statics variables
outputs_dir = os.path.join(outputs_dir, 'x{}'.format(scale))

# if not os.path.exists(args.outputs_dir):
#     os.makedirs(args.outputs_dir)


if not os.path.exists(outputs_dir):
    os.makedirs(outputs_dir)

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps')

# torch.manual_seed(args.seed)
torch.manual_seed(seed)

model_orig = SRCNN().to(device)
model_orig.load_state_dict(torch.load("pretrained/srcnn_x4.pth", map_location=device))
model_orig.to(device)
model_orig.eval()

model_feat = SRCNN().to(device)

#criterion = nn.MSELoss()
criterion = featureLoss(model_orig, 2) #this number is j
optimizer = optim.Adam([
    {'params': model_feat.conv1.parameters()},
    {'params': model_feat.conv2.parameters()},
    {'params': model_feat.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train_dataset = TrainDataset(train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=num_workers,
                                pin_memory=True,
                                drop_last=True)
eval_dataset = EvalDataset(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model_feat.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(num_epochs):
    model_feat.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model_feat(inputs)

            loss = criterion(preds, labels)

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    torch.save(model_feat.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

    model_feat.eval()
    epoch_psnr = AverageMeter()

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            preds = model_feat(inputs).clamp(0.0, 1.0)

        epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model_feat.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))


epoch: 0/9:   0%|          | 0/21760 [00:00<?, ?it/s]

epoch: 0/9: 100%|██████████| 21760/21760 [01:12<00:00, 301.90it/s, loss=0.000184]


eval psnr: 28.26


epoch: 1/9: 100%|██████████| 21760/21760 [01:11<00:00, 304.58it/s, loss=0.000168]


eval psnr: 28.48


epoch: 2/9: 100%|██████████| 21760/21760 [01:10<00:00, 307.93it/s, loss=0.000167]


eval psnr: 28.90


epoch: 3/9: 100%|██████████| 21760/21760 [01:09<00:00, 311.43it/s, loss=0.000166]


eval psnr: 28.98


epoch: 4/9: 100%|██████████| 21760/21760 [01:09<00:00, 312.99it/s, loss=0.000165]


eval psnr: 29.07


epoch: 5/9: 100%|██████████| 21760/21760 [01:09<00:00, 312.18it/s, loss=0.000164]


eval psnr: 28.89


epoch: 6/9: 100%|██████████| 21760/21760 [01:10<00:00, 310.66it/s, loss=0.000163]


eval psnr: 29.07


epoch: 7/9: 100%|██████████| 21760/21760 [01:10<00:00, 307.92it/s, loss=0.000163]


eval psnr: 29.02


epoch: 8/9: 100%|██████████| 21760/21760 [01:10<00:00, 310.41it/s, loss=0.000163]


eval psnr: 29.12


epoch: 9/9: 100%|██████████| 21760/21760 [16:22<00:00, 22.14it/s, loss=0.000162] 


eval psnr: 29.10
best epoch: 8, psnr: 29.12


In [6]:
model_MSE = SRCNN().to(device)

criterion = nn.MSELoss()
#criterion = featureLoss(model_orig, 2) #this number is j
optimizer = optim.Adam([
    {'params': model_MSE.conv1.parameters()},
    {'params': model_MSE.conv2.parameters()},
    {'params': model_MSE.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train_dataset = TrainDataset(train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=num_workers,
                                pin_memory=True,
                                drop_last=True)
eval_dataset = EvalDataset(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model_MSE.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(num_epochs):
    model_MSE.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model_MSE(inputs)

            loss = criterion(preds, labels)

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    torch.save(model_MSE.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

    model_MSE.eval()
    epoch_psnr = AverageMeter()

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            preds = model_MSE(inputs).clamp(0.0, 1.0)

        epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model_MSE.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))


epoch: 0/9: 100%|██████████| 21760/21760 [16:19<00:00, 22.21it/s, loss=0.006608]  


eval psnr: 28.82


epoch: 1/9: 100%|██████████| 21760/21760 [00:59<00:00, 364.12it/s, loss=0.002877] 


eval psnr: 28.99


epoch: 2/9: 100%|██████████| 21760/21760 [00:55<00:00, 393.80it/s, loss=0.002808] 


eval psnr: 29.09


epoch: 3/9: 100%|██████████| 21760/21760 [00:54<00:00, 401.73it/s, loss=0.002772] 


eval psnr: 29.11


epoch: 4/9: 100%|██████████| 21760/21760 [00:57<00:00, 379.09it/s, loss=0.002748] 


eval psnr: 29.22


epoch: 5/9: 100%|██████████| 21760/21760 [00:55<00:00, 394.46it/s, loss=0.002730] 


eval psnr: 29.19


epoch: 6/9: 100%|██████████| 21760/21760 [00:55<00:00, 395.12it/s, loss=0.002714] 


eval psnr: 29.30


epoch: 7/9: 100%|██████████| 21760/21760 [00:56<00:00, 386.06it/s, loss=0.002698] 


eval psnr: 29.31


epoch: 8/9: 100%|██████████| 21760/21760 [00:57<00:00, 380.77it/s, loss=0.002688] 


eval psnr: 29.31


epoch: 9/9: 100%|██████████| 21760/21760 [00:58<00:00, 372.77it/s, loss=0.002675] 


eval psnr: 29.36
best epoch: 9, psnr: 29.36


In [7]:
for data in eval_dataloader:
    inputs, labels = data

    inputs = inputs.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        preds_feat = model_feat(inputs).clamp(0.0, 1.0)
        preds_MSE = model_MSE(inputs).clamp(0.0, 1.0)

    print(calc_psnr(preds_feat, labels), calc_psnr(preds_MSE, labels))




tensor(32.7767, device='mps:0') tensor(33.0503, device='mps:0')
tensor(30.0175, device='mps:0') tensor(30.3439, device='mps:0')
tensor(24.2898, device='mps:0') tensor(24.4742, device='mps:0')
tensor(30.5397, device='mps:0') tensor(30.9082, device='mps:0')
tensor(27.8911, device='mps:0') tensor(28.0476, device='mps:0')


  nonzero_finite_vals = torch.masked_select(
