In [None]:
import torch
import os
from pathlib import Path
from logs.locallogger import LocalLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
from datasets.sampler import make_grid_coords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt

In [None]:
# Project-relative locations. Everything is resolved from the kernel's
# working directory: the project root is taken to be its parent directory.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'      # input images
MODEL_PATH = BASE_DIR / 'models'   # saved checkpoints

In [None]:
# Experiment configuration: every hyperparameter used by the cells below is
# read from this YAML file into the `hyper` dict.
project_name = "evaluation"
config_file = '../configs/config_base_m_net.yml'
# YAML is UTF-8; pass the encoding explicitly instead of relying on the
# platform/locale default (which differs e.g. on Windows).
with open(config_file, encoding='utf-8') as f:
    hyper = yaml.load(f, Loader=SafeLoader)
# An empty or malformed config would otherwise surface much later as an
# obscure TypeError on the first hyper[...] lookup.
assert isinstance(hyper, dict), f"config {config_file} did not parse to a mapping"
print(hyper)

In [None]:
# Load the source image as a sampled signal, configured entirely from `hyper`.
image_file = os.path.join(IMAGE_PATH, hyper['image_name'])
base_signal = ImageSignal.init_fromfile(
    image_file,
    domain=hyper['domain'],
    batch_samples_perc=hyper['batch_samples_perc'],
    sampling_scheme=hyper['sampling_scheme'],
    width=hyper['width'],
    height=hyper['height'],
    attributes=hyper['attributes'],
    channels=hyper['channels'],
)

# Build the multiresolution structures for training and testing. Both share
# the stage count and filter; only the training one is given hyper['decimation'].
# NOTE(review): presumably the test structure is meant to stay undecimated — confirm.
max_stages, filter_name = hyper['max_stages'], hyper['filter']
train_dataloader = create_MR_structure(base_signal, max_stages, filter_name,
                                       hyper['decimation'])
test_dataloader = create_MR_structure(base_signal, max_stages, filter_name)

In [None]:
# Run name: model name + uppercased first letter of the filter + first four
# characters of the image name.
run_prefix = f"{hyper['model']}{hyper['filter'][0].upper()}{hyper['image_name'][0:4]}"
locallogger = LocalLogger2D(project_name, f"{run_prefix}_", hyper, BASE_DIR,
                            to_file=True)

# Build the network from the config, then train it on the configured device.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))
mrtrainer = MRTrainer.init_from_dict(
    mrmodel,
    train_dataloader,
    test_dataloader,
    locallogger,
    hyper,
)
mrtrainer.train(hyper['device'])

In [None]:
# Persist the trained model under MODEL_PATH. The file name uses the same
# model/filter/image prefix as the logger run name, with a .pth extension.
filename = f"{hyper['model']}{hyper['filter'][0].upper()}{hyper['image_name'][0:4]}.pth"
path = os.path.join(MODEL_PATH, filename)

# MODEL_PATH is never created elsewhere; make sure it exists so saving on a
# fresh checkout does not depend on the folder already being there.
os.makedirs(MODEL_PATH, exist_ok=True)
MRFactory.save(mrmodel, path)

In [None]:
# Reload the checkpoint saved above and render its reconstruction on a
# width x height coordinate grid over hyper['domain'].
mrmodel_eval = MRFactory.load_state_dict(path)

grid_coords = make_grid_coords((hyper['width'], hyper['height']),
                               *hyper['domain'], dim=2)
# Inference only: no_grad avoids building an autograd graph over every grid point.
with torch.no_grad():
    output = mrmodel_eval(grid_coords)
# Clamp to [0, 1] so the values are directly displayable.
model_out = torch.clamp(output['model_out'], 0.0, 1.0)

# NOTE(review): assumes a single-channel output whose flat order matches
# view(width, height); for non-square images confirm the axis order produced
# by make_grid_coords, and for channels > 1 this view will not fit.
fig, ax = plt.subplots()
ax.imshow(model_out.cpu().view(hyper['width'], hyper['height']).detach().numpy())
ax.set_title(f"Reconstruction of {hyper['image_name']}")
ax.axis('off')
plt.show()