In [None]:
import os

import torch
import torch.optim as optim
from tqdm import tqdm

from dataset import get_data
from network import DeepISP
from loss import deepISPloss

In [None]:
# Root directory of the S7 ISP dataset. The commented-out path is an
# alternative location; swap the two assignments to use it instead.
# data_path = '/home/jupyter/mnt/datasets/S7Dataset/S7-ISP-Dataset'
data_path = '/home/tima/projects/isp/dataset/S7-ISP-Dataset'

# get_data hands back a (train, test) pair; both objects are iterated batch by
# batch by the training cell and support len().
splits = get_data(data_path, batch_size=1, crop_size=256)
train, test = splits

# Report how many batches each split yields.
print(f'train batch number {len(train)}')
print(f'test  batch number {len(test)}')

In [None]:
# Training configuration, model/optimizer construction and optional resume
# from a previously saved checkpoint.
#
# Open question (author's note):
# - sum in low level

# --- hyper-parameters -------------------------------------------------------
start_epoch = 0       # > 0 resumes from the checkpoint written at epoch start_epoch - 1
e = 100               # exclusive upper bound of the epoch index
lr = 1e-4             # same value as the previous spelling `10e-5`
momentum = 0.9        # only used by the commented-out SGD optimizer below
betas = (0.9, 0.999)

# --- checkpointing ----------------------------------------------------------
make_checkpoints = True
checkpoint_path = '/home/jupyter/work/resources/deepISP-implementation/checkp/15-3m'
# checkpoint_path = '/home/tima/projects/isp/deepisp/CP'

epochs = list(range(start_epoch, e))

# We can create any number of low-level layers, but only a limited number of
# high-level layers: every high-level layer does pool(2, 2), so the maximum is
# hlc = log2(img_size), assuming the image is a square matrix with
# height = width = img_size.
llc, hlc = 15, 3
model = DeepISP(llc, hlc).float()
criterion = deepISPloss()
# optimizer = optim.SGD(model.parameters(), lr, momentum)
optimizer = optim.Adam(model.parameters(), lr, betas)

# Debugging aid: makes autograd report the forward op that produced a
# NaN/failing backward. It adds overhead, so switch it off once training is
# known to be stable.
torch.autograd.set_detect_anomaly(True)

if start_epoch > 0:
    # Checkpoints are saved as 'model{llc}-{hlc}_e{epoch}_loss{loss}', so the
    # one to resume from is identified by the '_e{start_epoch - 1}_' tag.
    resume_tag = f'_e{start_epoch - 1}_'
    # Sorted so the choice is deterministic if several files carry the tag.
    candidates = sorted(
        name for name in os.listdir(checkpoint_path) if resume_tag in name
    )
    if not candidates:
        raise FileNotFoundError(
            f'no checkpoint containing {resume_tag!r} found in {checkpoint_path!r}; '
            f'cannot resume from start_epoch={start_epoch}'
        )
    checkpoint = torch.load(os.path.join(checkpoint_path, candidates[0]))

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Main loop: for every epoch run one optimisation pass over `train`, then
# measure the mean loss over `test`, and optionally write a checkpoint whose
# filename carries the `_e{epoch}_` tag that the resume logic searches for.
print('Starting trainig...')

for epoch in epochs:
    # ---- training pass ----
    model.train()
    # The `with` form closes the progress bar even if a batch raises.
    with tqdm(train, ncols=100, desc='Epoch: {}, training'.format(epoch)) as train_iter:
        for x, target in train_iter:
            optimizer.zero_grad()
            y = model(x.float())
            loss = criterion(y, target)
            # NOTE(review): this is a no-op when the loss is already attached
            # to the autograd graph. If the criterion ever returns a detached
            # tensor, this call hides that and no weights get updated --
            # confirm deepISPloss keeps the graph and then drop this line.
            loss.requires_grad_()

            loss.backward()
            optimizer.step()

    # ---- evaluation pass ----
    model.eval()
    # Reset every epoch; accumulated as a plain float so no autograd graph is
    # kept alive and the checkpoint name below gets a number, not a tensor repr.
    test_loss = 0.0
    batches_seen = 0
    with torch.no_grad(), \
            tqdm(test, ncols=128, desc='Epoch: {}, testing '.format(epoch)) as test_iter:
        for idx, (x, target) in enumerate(test_iter):
            # Same float cast as the training pass.
            y = model(x.float())
            test_loss += criterion(y, target).item()
            batches_seen = idx + 1
            test_iter.set_postfix(str=f'loss: {test_loss / batches_seen}')
    if batches_seen:
        # Mean over the batches; stays 0.0 when the test split is empty.
        test_loss /= batches_seen

    if make_checkpoints:
        # torch.save does not create missing directories.
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_name = 'model{}-{}_e{}_loss{}'.format(llc, hlc, epoch, test_loss)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss,
        }, os.path.join(checkpoint_path, checkpoint_name))

print('Training done!')