In [None]:
import torch
import os
import skimage
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from logs.locallogger import LocalLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
from datasets.sampler import make2Dcoords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt

In [None]:
# Read by wandb to identify which notebook produced a run.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"

# Project root = parent of the notebook's working directory; images and checkpoints live under it.
BASE_DIR = Path('.').absolute().parents[0]
IMAGE_PATH = BASE_DIR.joinpath('img')
MODEL_PATH = BASE_DIR.joinpath('models')

# Seed torch's global RNG (CPU and all CUDA devices) so torch-driven randomness such as
# DataLoader shuffling (and presumably model initialisation) is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

In [None]:
project_name = "test_net_f3"

# Resolve the config from the project root (BASE_DIR), consistent with IMAGE_PATH and
# MODEL_PATH, rather than through a '../' path relative to the current working directory.
config_file = BASE_DIR.joinpath('configs', 'config_base_m_net.yml')
with open(config_file) as f:
    # SafeLoader only constructs standard YAML types (no arbitrary Python objects).
    hyper = yaml.load(f, Loader=SafeLoader)

# Fail fast with a clear message instead of a TypeError/KeyError several cells later.
# These are the keys the later cells read unconditionally; 'batch_size' (single-resolution
# runs) and 'max_stages' (multiresolution runs) are only needed on one branch each.
required_keys = {'image_name', 'batch_samples_perc', 'sampling_scheme', 'width', 'height',
                 'attributes', 'multiresolution', 'model', 'device'}
if not isinstance(hyper, dict):
    raise TypeError(f"{config_file} must contain a YAML mapping, got {type(hyper).__name__}")
missing_keys = sorted(required_keys - hyper.keys())
if missing_keys:
    raise KeyError(f"{config_file} is missing required keys: {missing_keys}")
print(hyper)

In [None]:
# Load the target image as a sampled signal using the sampling settings from the config.
base_signal = ImageSignal.init_fromfile(
    os.path.join(IMAGE_PATH, hyper['image_name']),
    batch_samples_perc=hyper['batch_samples_perc'],
    sampling_scheme=hyper['sampling_scheme'],
    width=hyper['width'], height=hyper['height'],
    attributes=hyper['attributes'])

if hyper['multiresolution'] == 'signal':
    # Single-resolution training: batch the base signal directly.
    train_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'],
                                  shuffle=True, pin_memory=True, num_workers=0)
    test_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'],
                                 pin_memory=True, num_workers=0)
else:
    # Multiresolution mode has the form "<type>_<shape>", where <shape> selects whether
    # training uses the '<type>_pyramid' or the '<type>_tower' structure.
    mr_parts = hyper['multiresolution'].split('_')
    if len(mr_parts) != 2 or mr_parts[1] not in ('pyramid', 'tower'):
        # Previously any unrecognised shape silently fell back to training on the tower.
        raise ValueError(
            "hyper['multiresolution'] must be 'signal', '<type>_pyramid' or '<type>_tower', "
            f"got {hyper['multiresolution']!r}")
    # Stored back into hyper because the logger and trainer receive the whole dict.
    hyper['type_mr'], hyper['shape_mr'] = mr_parts

    train_on_pyramid = hyper['shape_mr'] == 'pyramid'
    # Only build the pyramid when it is actually the training source.
    if train_on_pyramid:
        pyramid = create_MR_structure(base_signal, hyper['max_stages'],
                                      type_pyr=hyper['type_mr'] + "_pyramid")
    # The tower is always needed: it is the test set in both cases.
    tower = create_MR_structure(base_signal, hyper['max_stages'],
                                type_pyr=hyper['type_mr'] + "_tower")

    trainsource = pyramid if train_on_pyramid else tower
    # Materialise the structures into plain lists consumed by the trainer.
    train_dataloader = list(trainsource)
    test_dataloader = list(tower)

In [None]:
# Run prefix: model name + upper-cased first letter of the multiresolution mode
# + first four characters of the image file name, terminated by '_'.
run_prefix = f"{hyper['model']}{hyper['multiresolution'][0].upper()}{hyper['image_name'][:4]}_"
locallogger = LocalLogger2D(project_name, run_prefix, hyper, BASE_DIR)

# Build the network from the config dict, then train it on the configured device.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))
mrtrainer = MRTrainer.init_from_dict(mrmodel, train_dataloader, test_dataloader,
                                     locallogger, hyper)
mrtrainer.train(hyper['device'])

In [None]:
# Checkpoint name: model name + upper-cased first letter of the multiresolution mode
# + first four characters of the image file name.
filename = f"{hyper['model']}{hyper['multiresolution'][0].upper()}{hyper['image_name'][0:4]}.pth"

# MODEL_PATH is not created anywhere else in this notebook; make sure it exists before saving.
MODEL_PATH.mkdir(parents=True, exist_ok=True)
path = os.path.join(MODEL_PATH, filename)

MRFactory.save(mrmodel, path)

In [None]:
# Reload the checkpoint written to `path` and render the network's output on a square
# EVAL_RES x EVAL_RES grid of 2D coordinates.
# NOTE(review): this grid size is fixed and independent of hyper['width'] / hyper['height'].
EVAL_RES = 128

mrmodel_eval = MRFactory.load_state_dict(path)

# Inference only: disable autograd so no graph is built over the whole coordinate grid.
with torch.no_grad():
    output = mrmodel_eval(make2Dcoords(EVAL_RES, EVAL_RES))
model_out = torch.clamp(output['model_out'], 0.0, 1.0)

# Assumes one output value per coordinate (required by the reshape) — TODO confirm for
# multi-channel models. Values were clamped to [0, 1] above, so pin the colour range to it
# instead of letting imshow autoscale each image.
fig, ax = plt.subplots()
ax.imshow(model_out.detach().cpu().reshape(EVAL_RES, EVAL_RES).numpy(),
          cmap='gray', vmin=0.0, vmax=1.0)
ax.set_title(f"{type(mrmodel_eval).__name__} output ({EVAL_RES}x{EVAL_RES})")
ax.axis('off');