In [None]:
import glob
import os

import h5py
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch.loggers import CSVLogger
from skimage.transform import resize
from torch.utils.data import DataLoader

# NOTE(review): these star imports supply DnCNN, LSTDataset, dncnn_lightning and
# weights_init_kaiming (exact module per name unknown from here) — prefer explicit imports.
from models import *
from utils import *
from dataset import *

In [None]:
# Run configuration: optimisation hyperparameters and the global random seed.
LR = 1e-3          # Adam learning rate
DECAY = 1e-7       # Adam weight_decay
EPOCHS = 10        # Trainer max_epochs
BATCH_SIZE = 64    # training mini-batch size
PATCH_SIZE = 128   # NOTE(review): not referenced in this notebook — presumably the dataset patch size; confirm or remove
SEED = 42          # global seed for reproducible runs

# Seed Python / NumPy / torch RNGs before weight init and DataLoader shuffling,
# so Restart & Run All reproduces the same training run.
L.seed_everything(SEED)

In [None]:
# Load the train / valid / test splits from HDF5 files in the dataset directory.
path = '../data/super_res_set1'
train = LSTDataset(os.path.join(path, 'train.hdf5'))
valid = LSTDataset(os.path.join(path, 'valid.hdf5'))
test = LSTDataset(os.path.join(path, 'test.hdf5'))

# Only the training split is shuffled. Shuffling evaluation data adds
# nondeterminism without changing what is measured (Lightning also warns about it).
loader_train = DataLoader(dataset=train, batch_size=BATCH_SIZE, shuffle=True)
# Each evaluation split is served as a single full-size batch.
# NOTE(review): this may not fit in GPU memory if the valid/test splits grow.
loader_valid = DataLoader(dataset=valid, batch_size=len(valid), shuffle=False)
loader_test = DataLoader(dataset=test, batch_size=len(test), shuffle=False)

In [None]:
# NOTE(review): these lists are not referenced elsewhere in this notebook;
# loss curves are read back from the CSVLogger output instead.
train_loss, valid_loss = [], []

# This notebook requires CUDA: fail fast instead of silently training on CPU.
if not torch.cuda.is_available():
    raise RuntimeError('No GPU')
device = torch.device(torch.cuda.current_device())

# Single-channel DnCNN placed on the active CUDA device.
model = DnCNN(channels=1).to(device)

In [None]:
# Squared error summed over every element (not averaged), so the loss and its
# gradients grow with the number of elements per batch; keep this in mind when tuning LR.
loss_fn = nn.MSELoss(reduction='sum')

# Apply the project's weight initialiser to every submodule, then build the optimiser
# over the freshly initialised parameters.
model.apply(weights_init_kaiming)
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=DECAY)

# Bundle model, optimiser and loss into the Lightning module that trainer.fit consumes.
modis_model = dncnn_lightning(model=model, optimizer=optimizer, loss_fn=loss_fn)
# To resume from a previous run instead, restore the wrapper from its checkpoint:
# modis_model = dncnn_lightning.load_from_checkpoint('./logs/super_res1/version_2/checkpoints/epoch=8-step=720.ckpt', model=model, optimizer=optimizer, loss_fn=loss_fn)

In [None]:
# Each run writes to logs/<exp_name>/version_N/ via the CSV logger.
exp_name = 'super_res1'
logger = CSVLogger('logs', name=exp_name)
# Record every setting that shapes the run so each version directory is self-describing.
logger.log_hyperparams({
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'loss_fn': str(loss_fn),
    'lr': LR,
    'weight_decay': DECAY,
    'optimizer': str(optimizer),
})

trainer = L.Trainer(max_epochs=EPOCHS, logger=logger, log_every_n_steps=5)
trainer.fit(model=modis_model, train_dataloaders=loader_train, val_dataloaders=loader_valid)

In [None]:
# Read the metrics of the run just trained. CSVLogger creates a new version_N directory
# per run, so a hardcoded version number would silently load a stale run.
# To inspect an older run, point this at that run's version directory explicitly.
metrics = pd.read_csv(os.path.join(logger.log_dir, 'metrics.csv'))

In [None]:
# CSVLogger writes training and validation metrics on separate rows, leaving NaN in the
# other column. Drop those rows per series: otherwise each valid_loss value sits between
# NaNs and a line-only plot draws nothing for it. Plot against the logged step, not the row index.
fig, ax = plt.subplots()
for column in ('train_loss', 'valid_loss'):
    series = metrics[['step', column]].dropna()
    ax.plot(series['step'], series[column], marker='o', markersize=3, label=column)
ax.set(xlabel='step', ylabel='loss', title='DnCNN training curves')
ax.legend();

In [None]:
# Evaluate the trained Lightning module on the held-out test split;
# per Lightning's Trainer.test API the result is a list of logged-metric dicts.
losses = trainer.test(modis_model, dataloaders=loader_test)