In [1]:
import torch
import os
import skimage
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from logging_module.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [2]:
# Tell wandb which notebook file these runs come from.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"

# Project root = parent directory of the kernel's current working directory;
# input images are looked up under <root>/img.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'

In [3]:
project_name = "lvelho_fase2_teste_laplace_with_mrweights"

#-- hyperparameters in configs --#
# Resolve the config under the project root (BASE_DIR), the same way IMAGE_PATH
# is built. BASE_DIR is the parent of the working directory, so this is the
# same file as the former '../configs/config_base_m_net.yml'.
config_file = BASE_DIR.joinpath('configs', 'config_base_m_net.yml')
with open(config_file) as f:
    # SafeLoader only builds plain YAML types (no arbitrary Python objects).
    hyper = yaml.load(f, Loader=SafeLoader)

# An empty file loads as None (and a scalar file as a scalar). Fail here with a
# clear message instead of an obscure TypeError at the first hyper[...] lookup.
if not isinstance(hyper, dict):
    raise ValueError(
        f"Config {config_file} must contain a YAML mapping, "
        f"got {type(hyper).__name__}")
print(hyper)

{'omega_0': [8, 16, 32, 64], 'in_features': 2, 'hidden_features': [64, 64, 96, 96], 'hidden_layers': 1, 'superposition_w0': True, 'hidden_omega_0': [30, 30, 30, 30], 'sampling_scheme': 'uniform', 'multiresolution': 'laplace_pyramid', 'max_epochs_per_stage': [500, 500, 600, 600], 'opt_method': 'Adam', 'loss_function': ['d0'], 'lr': 0.0001, 'loss_tol': 1e-16, 'diff_tol': 1e-05, 'batch_pixels_perc': 1, 'batch_size': 1, 'image_name': 'lena.png', 'width': 256, 'height': 256, 'channels': 1, 'max_stages': 4, 'model': 'L', 'useattributes': True, 'device': 'cuda', 'eval_device': 'cpu', 'bias': False}


In [4]:
# Load the base image signal from IMAGE_PATH using the configured size and
# sampling options.
base_signal = ImageSignal.init_fromfile(
    str(IMAGE_PATH / hyper['image_name']),
    useattributes=hyper.get('useattributes', False),
    batch_pixels_perc=hyper['batch_pixels_perc'],
    width=hyper['width'], height=hyper['height'])

if hyper['multiresolution'] == 'signal':
    # Single-resolution training: the same full signal feeds train and test.
    train_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'], shuffle=True, pin_memory=True, num_workers=0)
    test_dataloader = DataLoader(base_signal, batch_size=hyper['batch_size'], pin_memory=True, num_workers=0)
else:
    # Expected form '<type>_<shape>' (e.g. 'laplace_pyramid'). <type> is
    # forwarded to create_MR_structure as '<type>_pyramid' / '<type>_tower',
    # and <shape> picks which of the two structures is used for training.
    mr_parts = hyper['multiresolution'].split('_')
    if len(mr_parts) != 2 or mr_parts[1] not in ('pyramid', 'tower'):
        # Previously a malformed value raised an opaque unpack error, and an
        # unknown shape silently fell back to the tower.
        raise ValueError(
            "hyper['multiresolution'] must be 'signal', '<type>_pyramid' or "
            f"'<type>_tower', got {hyper['multiresolution']!r}")
    # Stored back into hyper so later consumers of the dict can see them.
    hyper['type_mr'], hyper['shape_mr'] = mr_parts

    pyramid = create_MR_structure(base_signal, hyper['max_stages'], type_pyr=hyper['type_mr'] + "_pyramid")
    tower = create_MR_structure(base_signal, hyper['max_stages'], type_pyr=hyper['type_mr'] + "_tower")

    trainsource = pyramid if hyper['shape_mr'] == 'pyramid' else tower
    # One DataLoader per element of the structure (presumably one per
    # resolution stage — confirm against create_MR_structure). Evaluation
    # always uses the tower, whichever shape is used for training.
    train_dataloader = [DataLoader(signal, shuffle=True, batch_size=hyper['batch_size'])
                        for signal in trainsource]
    test_dataloader = [DataLoader(signal, batch_size=hyper['batch_size'])
                       for signal in tower]

In [5]:
# wandb run name: model id, upper-cased first letter of the multiresolution
# scheme, first four characters of the image file name, and a trailing "_"
# (with the config printed above this is "LLlena_").
run_name = f"{hyper['model']}{hyper['multiresolution'][0].upper()}{hyper['image_name'][0:4]}_"
wandblogger = WandBLogger2D(project_name, run_name, hyper, BASE_DIR)

# Build the network described by the hyperparameters and report its class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data and logger into the trainer, then train on hyper['device'].
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.LNet'>
MRWEIGHTS  tensor([1.])


[34m[1mwandb[0m: Currently logged in as: [33mlvelho[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


MRWEIGHTS  tensor([0., 1.])


0,1
D0 loss,█▆▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00088


MRWEIGHTS  tensor([0., 0., 1.])


0,1
D0 loss,█▇▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,8e-05


MRWEIGHTS  tensor([0., 0., 0., 1.])


0,1
D0 loss,█▆▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00017


0,1
D0 loss,██▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00073


Total model parameters =  27904
Training finished after 2200 epochs
