In [None]:
import torch
import os
import skimage
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from training.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.imageutils import gaussian_pyramid2D, gaussian_tower2D, box_kernel2D
import yaml
from yaml.loader import SafeLoader

In [None]:
# Project layout: the notebook is expected to run from a subdirectory of the
# project, so the project root is the parent of the current working directory.
BASE_DIR = Path.cwd().parents[0]
# Input images live in <project root>/img.
IMAGE_PATH = BASE_DIR / 'img'
# Lets Weights & Biases associate runs with this notebook file.
os.environ["WANDB_NOTEBOOK_NAME"] = "l-net.ipynb"

In [None]:
project_name = "testing_lnet_yaml_dictionary_test_gradients"

# Hyperparameters.
# The config path is built from BASE_DIR (the parent of the working directory)
# so that every path in this notebook derives from the same root, instead of
# repeating that logic as a separate '../configs/...' string.
CONFIG_PATH = BASE_DIR.joinpath('configs', 'config_base_l_net.yml')
with open(CONFIG_PATH, encoding='utf-8') as f:
    # SafeLoader only builds plain Python objects from the YAML.
    hyper = yaml.load(f, Loader=SafeLoader)

# An empty or scalar YAML file would otherwise surface later as an obscure
# TypeError on the first hyper['...'] lookup; fail here with a clear message.
if not isinstance(hyper, dict):
    raise ValueError(
        f"Expected a mapping of hyperparameters in {CONFIG_PATH}, "
        f"got {type(hyper).__name__}")
print(hyper)


In [None]:
# Size of the box filter passed to the pyramid/tower builders below.
KERNEL_SIZE = 5
kernel = box_kernel2D(KERNEL_SIZE)

base_signal = ImageSignal.init_fromfile(
                    os.path.join(IMAGE_PATH, hyper['image_name']),
                    useattributes=hyper.get('useattributes', False),
                    # Optional config key; when absent this falls back to None,
                    # the value that was previously hard-coded here.
                    batch_pixels=hyper.get('batch_pixels', None),
                    width=hyper['width'], height=hyper['height'])
# Overwrite the configured width/height with the values reported by the loaded signal.
hyper['width'], hyper['height'] = base_signal.dimensions()

if hyper['multiresolution'] == 'capacity':
    # 'capacity' mode: train and test directly on the single base signal.
    train_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'], shuffle=True, pin_memory=True, num_workers=0)
    test_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'], pin_memory=True, num_workers=0)
else:
    # Multiresolution modes: one DataLoader per stage (up to hyper['max_stages']).
    pyramid = gaussian_pyramid2D(base_signal, hyper['max_stages'], kernel)
    tower = gaussian_tower2D(base_signal, hyper['max_stages'], kernel)
    # Training uses the pyramid only for 'pyramid'; any other value trains on the tower.
    trainsource = pyramid if hyper['multiresolution'] == 'pyramid' else tower
    train_dataloader = [DataLoader(signal, shuffle=True, batch_size=hyper['batch_size'])
                        for signal in trainsource]
    # Evaluation always uses the tower stages, regardless of the training source.
    test_dataloader = [DataLoader(signal, batch_size=hyper['batch_size'])
                        for signal in tower]

In [None]:
# Logger name prefix (presumably the W&B run-name prefix):
# model id + first letter of the multiresolution mode (upper-cased)
# + first four characters of the image file name, followed by '_'.
mode_tag = hyper['multiresolution'][0].upper()
image_tag = hyper['image_name'][:4]
run_prefix = f"{hyper['model']}{mode_tag}{image_tag}_"
wandblogger = WandBLogger2D(project_name, run_prefix, hyper, BASE_DIR)

# Build the model from the same hyperparameter dictionary and report its class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data and logger into the trainer, then train on the configured device.
mrtrainer = MRTrainer.init_from_dict(mrmodel, train_dataloader, test_dataloader, wandblogger, hyper)
device = hyper['device']
mrtrainer.train(device)