In [None]:
import torch
import torch.optim as optim
import numpy as np

from tqdm import tqdm
from os import listdir, sep
from kornia.color import rgb_to_xyz, xyz_to_rgb

from network import *
from unet import UNet
from dataset import get_data
from loss import *

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device', device)

# The RuntimeError is swallowed on purpose -- presumably it is raised when a
# start method was already chosen (e.g. this cell is re-run in the same kernel).
try:
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass

# Dataset root and the (train, test) loaders built from it.
data_path = '/home/tima/projects/isp/dataset/S7-ISP-Dataset'
train, test = get_data(
    data_path,
    batch_size=2,
    crop_size=512,
    num_workers=4,
)

print(f'train batch number {len(train)}')
print(f'test  batch number {len(test)}')

In [None]:
# Epoch index to start from; a value > 0 makes the model cell load the
# checkpoint written for epoch start_epoch-1 before training continues.
start_epoch = 0
# Exclusive upper bound of the epoch range (epochs run start_epoch .. e-1).
e = 501
# The test pass (and checkpointing) only runs on epochs divisible by this.
test_every_n = 10
# Learning rate handed to the Adam optimizer.
lr = 5e-5
# NOTE(review): not referenced in any of the visible cells -- confirm it is still needed.
alpha = 0.5

# When True, a checkpoint is written after every test pass.
make_checkpoints = True
# Directory checkpoints are written to, and searched in when resuming.
checkpoint_path = '/home/tima/projects/isp/CameraNet/CP'

In [None]:
def get_cp_name(epoch, checkpoint_path):
    """Return the path of the checkpoint file saved for ``epoch``.

    A file matches when its name contains the marker ``_e{epoch}_``. The
    trailing underscore keeps epoch 1 from matching a file written for
    epoch 10.

    Args:
        epoch: Epoch number to look up.
        checkpoint_path: Directory that holds the checkpoint files.

    Returns:
        ``checkpoint_path`` joined with the matching file name using the OS
        path separator. If several files match, the lexicographically
        smallest name is returned so the choice does not depend on the
        directory listing order.

    Raises:
        FileNotFoundError: If no file in ``checkpoint_path`` matches.
    """
    marker = f'_e{epoch}_'
    matches = sorted(name for name in listdir(checkpoint_path) if marker in name)
    if not matches:
        raise FileNotFoundError(
            f'No checkpoint for epoch {epoch} found in {checkpoint_path!r}'
        )
    return sep.join([checkpoint_path, matches[0]])


# Only the `enhance` sub-module of CameraNet is trained in this notebook;
# model_name is the prefix used in checkpoint file names.
model = CameraNet().enhance
model_name = 'enhance'
# Checkpoints are saved from the DataParallel-wrapped model, so the wrapper
# has to be in place before a saved state dict is loaded back.
model = torch.nn.DataParallel(model)
model = model.to(device)

if start_epoch > 0:
    # Resume from the checkpoint written for the previous epoch.
    cp_path = get_cp_name(start_epoch - 1, checkpoint_path)
    # NOTE(review): `params` is not used in the visible cells; kept only so
    # the notebook's global names stay unchanged -- confirm and remove.
    params = cp_path.split(sep)[-1].split('_e')[0][5:].split('-')
    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device was selected above (including the CPU fallback).
    # torch.load unpickles the file: only load checkpoints you trust.
    cp = torch.load(cp_path, map_location=device)
    model.load_state_dict(cp['model_state_dict'])

In [None]:
# Report the model size in millions of parameters.
print('Total number of parameters:')
n_params = sum(p.numel() for p in model.parameters())
print(f'{(n_params / 1e6):.02f}M')

# Loss objects. Only crit_res is used by the training loop in this notebook.
# NOTE(review): crit_enh is created but never used in the visible cells --
# confirm which criterion is intended for the `enhance` model.
crit_res = RestoreNetLoss(device=device)
crit_enh = EnhanceNetLoss()

# Adam over every parameter of the (DataParallel-wrapped) model.
# NOTE(review): checkpoints store 'optimizer_state_dict', but it is not loaded
# back when resuming (start_epoch > 0), so Adam's moment estimates restart
# from zero -- confirm whether that is intended.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

In [None]:
print('Starting trainig...')

for epoch in range(start_epoch, e):
    # ---- training pass -------------------------------------------------
    # Put layers such as dropout / batch-norm (if the model has any) into
    # training behaviour; the test pass below switches them to eval.
    model.train()
    train_iter = tqdm(train, ncols=150, desc='Epoch: {}, training'.format(epoch))
    train_loss = []
    # Each batch is a triple; its first element is not used by this stage,
    # so it is not copied to the device.
    for x, mid, target in train_iter:
        mid = mid.float().to(device)
        target = target.float().to(device)

        loss = crit_res(model(mid), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float: appending the tensor itself would keep every
        # step's autograd graph alive until the end of the epoch.
        train_loss.append(loss.item())
        train_iter.set_postfix(str=f'loss: {np.mean(train_loss):.03f}')
    train_iter.close()
    torch.cuda.empty_cache()

    # Testing and checkpointing only happen every `test_every_n` epochs.
    if epoch % test_every_n != 0:
        continue

    # ---- test pass -----------------------------------------------------
    model.eval()
    test_loss = []
    test_iter = tqdm(test, ncols=150, desc='Epoch: {}, testing '.format(epoch))
    with torch.no_grad():
        for x, mid, target in test_iter:
            mid = mid.float().to(device)
            target = target.float().to(device)

            loss = crit_res(model(mid), target)

            test_loss.append(loss.item())
            test_iter.set_postfix(str=f'loss: {np.mean(test_loss):.03f}')
    test_iter.close()

    # ---- checkpoint ----------------------------------------------------
    if make_checkpoints:
        # Plain float (not a numpy scalar) so the checkpoint contains only
        # basic Python types next to the tensors.
        mean_test_loss = float(np.mean(test_loss))
        # The '_e{epoch}_' part of the file name is what get_cp_name()
        # searches for when resuming, so the format must stay in sync.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mean_test_loss,
        }, checkpoint_path + '/{}t_e{}_loss{}'.format(model_name, epoch, round(mean_test_loss, 3)))

print('Training done!')