In [1]:
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN, SRCNN_video
from datasets import TrainDataset, EvalDataset, TrainDataset_3D, EvalDataset_3D
from utils import AverageMeter, calc_psnr

In [2]:
# Configuration: every tunable value for this training run lives here.
# Alternative datasets used in earlier experiments (kept for reference):
#   train: 'original_training_data/x4/91-image_x4.h5'
#   eval:  'original_training_data/x4/Set5_x4.h5'
#   eval:  'original_training_data/for_training/car05_001.h5'
train_file = 'original_training_data/for_training/AMVTG_004.h5'  # TODO: choose training set
eval_file = 'preparing_data/prepare_out_3d/for_eval/AMVTG_004.h5'  # TODO: choose eval set
outputs_dir = 'outputs'
scale = 4
lr = 1e-4
batch_size = 320
num_epochs = 10
num_workers = 12
seed = 123

# Checkpoints are written to a per-scale subdirectory, e.g. outputs/x4.
outputs_dir = os.path.join(outputs_dir, 'x{}'.format(scale))
os.makedirs(outputs_dir, exist_ok=True)

# cuDNN autotuner; it only has an effect when running on CUDA.
cudnn.benchmark = True

# Prefer CUDA, then Apple MPS, and fall back to CPU. Previously 'mps' was chosen
# unconditionally whenever CUDA was missing, which breaks `.to(device)` on
# machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

torch.manual_seed(seed)

model = SRCNN_video().to(device)
criterion = nn.MSELoss()

# Per-layer parameter groups: conv1/conv2 use the base learning rate, while
# conv3 trains with a 10x smaller learning rate.
optimizer = optim.Adam(
    [
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': lr * 0.1},
    ],
    lr=lr,
)

# Pinned host memory only accelerates host->CUDA transfers; on MPS/CPU PyTorch
# ignores the flag and emits a warning, so enable it only for CUDA devices.
use_pinned_memory = device.type == 'cuda'

# CHANGE HERE: dataset class choice (3-D variants vs. TrainDataset/EvalDataset).
train_dataset = TrainDataset_3D(train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=use_pinned_memory,
                              drop_last=True)  # skip the trailing partial batch
eval_dataset = EvalDataset_3D(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

# Best-so-far tracking. It starts from the untrained weights, so best.pth is
# always written even if no epoch improves on the initial PSNR of 0.0.
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

def _run_training_epoch(epoch_idx):
    """Run one optimisation pass over ``train_dataloader``.

    The model is updated in place through ``optimizer``. The tqdm bar counts
    samples and excludes the remainder that ``drop_last=True`` discards.
    Returns the AverageMeter holding the sample-weighted running loss.
    """
    model.train()
    loss_meter = AverageMeter()
    samples_per_epoch = len(train_dataset) - len(train_dataset) % batch_size

    with tqdm(total=samples_per_epoch) as progress:
        progress.set_description('epoch: {}/{}'.format(epoch_idx + 1, num_epochs))

        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            loss = criterion(model(inputs), labels)
            loss_meter.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress.set_postfix(loss='{:.6f}'.format(loss_meter.avg))
            progress.update(len(inputs))

    return loss_meter


def _evaluate_psnr():
    """Return an AverageMeter of PSNR over ``eval_dataloader``.

    Predictions are clamped to [0, 1] before scoring.
    """
    model.eval()
    psnr_meter = AverageMeter()

    for inputs, labels in eval_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        with torch.no_grad():
            preds = model(inputs).clamp(0.0, 1.0)
        psnr_meter.update(calc_psnr(preds, labels), len(inputs))

    return psnr_meter


for epoch in range(num_epochs):
    _run_training_epoch(epoch)
    # Per-epoch checkpoint (file names are 0-based epoch indices).
    torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

    epoch_psnr_avg = _evaluate_psnr().avg
    print('eval psnr: {:.2f}'.format(epoch_psnr_avg))

    if epoch_psnr_avg > best_psnr:
        best_epoch, best_psnr = epoch, epoch_psnr_avg
        best_weights = copy.deepcopy(model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))


epoch: 0/9: 100%|██████████| 74240/74240 [01:13<00:00, 1011.38it/s, loss=0.012667]


eval psnr: 23.73


epoch: 1/9: 100%|██████████| 74240/74240 [01:13<00:00, 1016.83it/s, loss=0.004150]


eval psnr: 24.38


epoch: 2/9: 100%|██████████| 74240/74240 [01:12<00:00, 1017.49it/s, loss=0.003670]


eval psnr: 24.61


epoch: 3/9: 100%|██████████| 74240/74240 [01:12<00:00, 1017.60it/s, loss=0.003461]


eval psnr: 24.73


epoch: 4/9: 100%|██████████| 74240/74240 [01:13<00:00, 1016.02it/s, loss=0.003328]


eval psnr: 24.82


epoch: 5/9: 100%|██████████| 74240/74240 [01:13<00:00, 1008.84it/s, loss=0.003226]


eval psnr: 24.90


epoch: 6/9: 100%|██████████| 74240/74240 [01:13<00:00, 1005.47it/s, loss=0.003138]


eval psnr: 24.98


epoch: 7/9: 100%|██████████| 74240/74240 [01:13<00:00, 1006.97it/s, loss=0.003061]


eval psnr: 25.05


epoch: 8/9: 100%|██████████| 74240/74240 [01:13<00:00, 1005.04it/s, loss=0.002993]


eval psnr: 25.11


epoch: 9/9: 100%|██████████| 74240/74240 [01:13<00:00, 1011.25it/s, loss=0.002931]


eval psnr: 25.15
best epoch: 9, psnr: 25.15


In [None]:
import h5py

def print_structure(h5_file, indent='', show_keys=True):
    """Recursively print the layout of an HDF5 file or group.

    Each member is printed on its own line as ``<indent><name>: ...``.
    Datasets report their shape and dtype. Groups are printed as ``Group`` and
    then descended into with four extra spaces of indentation. Any other member
    type (e.g. a committed ``h5py.Datatype``) is reported by its type name.

    Args:
        h5_file: An open ``h5py.File`` or ``h5py.Group`` to describe.
        indent: Prefix written before every line at this nesting level.
        show_keys: If True (default), first print the raw key listing of every
            group visited, as the original helper did.
    """
    if show_keys:
        print(f"{indent}All keys: {h5_file.keys()} ")
    for key in h5_file.keys():
        item = h5_file[key]
        if isinstance(item, h5py.Dataset):
            print(f'{indent}{key}: Dataset with shape {item.shape} and data type {item.dtype}')
        elif isinstance(item, h5py.Group):
            print(f'{indent}{key}: Group')
            print_structure(item, indent + '    ', show_keys)
        else:
            # Previously this case left the line unterminated, so the next
            # entry was printed on the same line.
            print(f'{indent}{key}: {type(item).__name__}')

# Inspect the HDF5 files configured for this run (`eval_file` / `train_file`
# from the configuration cell). Reusing those variables keeps the printed
# layout in sync with what the training loop actually reads.
for label, h5_path in (('EVAL DATA', eval_file), ('TRAINING DATA', train_file)):
    print(label)
    with h5py.File(h5_path, 'r') as h5_handle:
        print_structure(h5_handle)

EVAL DATA
All keys: <KeysViewHDF5 ['hr', 'lr', 'prev_lr']> 
hr: Group
All keys: <KeysViewHDF5 ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '4', '5', '6', '7', '8', '9']> 
    1: Dataset with shape (540, 960) and data type float32
    10: Dataset with shape (540, 960) and data type float32
    11: Dataset with shape (540, 960) and data type float32
    12: Dataset with shape (540, 960) and data type float32
    13: Dataset with shape (540, 960) and data type float32
    14: Dataset with shape (540, 960) and data type float32
    15: Dataset with shape (540, 960) and data type float32
    16: Dataset with shape (540, 960) and data type float32
    17: Dataset with shape (540, 960) and data type float32
    18: Dataset with shape (540, 960) and data type float32
    19: Dataset with shape (540, 960) and data type float32
    2: Dataset with shape (540, 960) and data type float32
    20: Datase