In [1]:
import torch
import os
import skimage
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from training.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [2]:
# Tell Weights & Biases which notebook this session belongs to.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"

# The parent of the current working directory serves as the base directory;
# input images are looked up in its `img` sub-folder.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'

In [3]:
project_name = "test_mnet_f2"

# Hyperparameters are kept in YAML files under ../configs.
# Alternative configuration: '../configs/config_base_l_net.yml'
config_file = '../configs/config_base_m_net.yml'
with open(config_file) as config_stream:
    # safe_load restricts parsing to plain YAML types (same as Loader=SafeLoader).
    hyper = yaml.safe_load(config_stream)

# Echo the parsed configuration so the run's settings are visible in the notebook.
print(hyper)

{'omega_0': [64, 128], 'in_features': 2, 'hidden_features': [128, 256], 'hidden_layers': 1, 'superposition_w0': False, 'hidden_omega_0': [70, 80], 'sampling_scheme': 'uniform', 'multiresolution': 'pyramid', 'max_epochs_per_stage': [250, 150], 'opt_method': 'Adam', 'loss_function': ['d0'], 'lr': 0.0001, 'loss_tol': 1e-14, 'diff_tol': 1e-11, 'batch_pixels_perc': 1.0, 'batch_size': 1, 'image_name': 'lena.png', 'width': 128, 'height': 128, 'channels': 1, 'stage': 1, 'max_stages': 2, 'model': 'M', 'useattributes': True, 'device': 'cuda', 'eval_device': 'cuda', 'bias': True, 'save_format': 'general'}


In [4]:
# Build the base signal from the image named in the configuration.
image_file = os.path.join(IMAGE_PATH, hyper['image_name'])
base_signal = ImageSignal.init_fromfile(
    image_file,
    useattributes=hyper.get('useattributes', False),
    batch_pixels_perc=hyper['batch_pixels_perc'],
    width=hyper['width'],
    height=hyper['height'],
)

mr_mode = hyper['multiresolution']
loader_batch = hyper['batch_size']

if mr_mode != 'capacity':
    # Build every multiresolution structure from the same base signal.
    # NOTE(review): the two Laplacian structures are not consumed by the cells
    # visible here; they are kept in case later cells rely on them.
    num_stages = hyper['max_stages']
    pyramid = create_MR_structure(base_signal, num_stages, type_pyr="pyramid")
    tower = create_MR_structure(base_signal, num_stages, type_pyr="tower")
    laplacian_tower = create_MR_structure(base_signal, num_stages, type_pyr="laplace_tower")
    laplacian_pyramid = create_MR_structure(base_signal, num_stages, type_pyr="laplace_pyramid")

    # Train on the pyramid only when explicitly requested; any other
    # (non-capacity) mode falls back to the tower.
    if mr_mode == 'pyramid':
        trainsource = pyramid
    else:
        trainsource = tower

    # One loader per level; evaluation always iterates over the tower levels.
    train_dataloader = [
        DataLoader(level, shuffle=True, batch_size=loader_batch)
        for level in trainsource
    ]
    test_dataloader = [
        DataLoader(level, batch_size=loader_batch)
        for level in tower
    ]
else:
    # 'capacity' mode: a single loader pair over the base signal itself.
    train_dataloader = DataLoader(
        base_signal,
        batch_size=loader_batch,
        shuffle=True,
        pin_memory=True,
        num_workers=0,
    )
    test_dataloader = DataLoader(
        base_signal,
        batch_size=loader_batch,
        pin_memory=True,
        num_workers=0,
    )

In [5]:
# Label passed to the logger: model id + upper-cased first letter of the
# multiresolution mode + first four characters of the image name + "_".
run_label = f"{hyper['model']}{hyper['multiresolution'][0].upper()}{hyper['image_name'][0:4]}_"
wandblogger = WandBLogger2D(project_name, run_label, hyper, BASE_DIR)

# Instantiate the network described by the hyperparameters and report its class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data loaders and logger into the trainer, then train on the
# device named in the configuration.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel,
    train_dataloader,
    test_dataloader,
    wandblogger,
    hyper,
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
D0 loss,█▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00025


File  MPlena__2-2_w128F_hf256_MEp150_hl1_128px.pth


0,1
D0 loss,█▇▇▆▆▅▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.0


Training finished after 400 epochs
