In [1]:
# Imports and early environment setup for the training notebook.
import os
from pathlib import Path

import yaml
from yaml.loader import SafeLoader

# NOTE(review): PyTorch is understood to read PYTORCH_ENABLE_MPS_FALLBACK only once, when torch is
# first imported. The project modules below presumably import torch transitively (TODO confirm),
# so setting the variable in a later cell can come too late. Set it here, before torch loads.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# Project modules, kept in their original relative import order.
from logs.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure

In [2]:
# Environment variables consumed by wandb (WANDB_NOTEBOOK_NAME) and PyTorch (MPS -> CPU fallback).
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "train-wb.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

# NOTE(review): this is an ordinary Python name. Nothing visible reads it, and it does not configure
# PyTorch; only the environment variable above can. PyTorch may read that variable only when torch is
# first imported, and the first cell's project imports presumably already pulled torch in (TODO confirm).
PYTORCH_ENABLE_MPS_FALLBACK = 1

# Parent of the notebook's working directory (presumably the repository root); input images are
# loaded from its `img` subdirectory.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'

In [3]:
# wandb project name under which the training runs of this notebook are logged.
project_name = "testing_periodic"

#-- hyperparameters in configs --#
# Resolved relative to the notebook's working directory.
config_file = '../configs/config_base_m_net.yml'
# Explicit encoding: without it, open() uses the platform's locale encoding.
with open(config_file, encoding='utf-8') as f:
    # SafeLoader builds only plain YAML types (dicts, lists, scalars), never arbitrary objects.
    hyper = yaml.load(f, Loader=SafeLoader)

# An empty file loads as None, and a scalar or list document is not a mapping. Fail here with a
# clear message rather than with a TypeError at the first `hyper[...]` lookup further down.
if not isinstance(hyper, dict):
    raise TypeError(
        f"{config_file} must contain a YAML mapping of hyperparameters, "
        f"got {type(hyper).__name__}"
    )
print(hyper)

{'model': 'M', 'in_features': 2, 'hidden_layers': 1, 'hidden_features': [25, 81, 289], 'bias': True, 'periodic': True, 'max_stages': 3, 'omega_0': [4, 8, 16], 'hidden_omega_0': [30, 30, 30, 30], 'superposition_w0': True, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0'], 'loss_function': 'mse', 'opt_method': 'Adam', 'lr': 0.0001, 'loss_tol': 1e-14, 'diff_tol': 1e-11, 'max_epochs_per_stage': [8000, 4000, 2000, 1000], 'batch_size': 1, 'batch_samples_perc': 1.0, 'image_name': 'periodic.png', 'width': 256, 'height': 256, 'channels': 1, 'device': 'cpu', 'eval_device': 'cpu', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-3, 3]}


In [4]:
# Load the target image as a sampled signal, then build the multiresolution structures from it.
# The training structure receives hyper['decimation'] as a fourth positional argument. The test
# structure omits it and so gets create_MR_structure's default, which is not visible here
# (presumably no decimation, for full-resolution evaluation; TODO confirm).
base_signal = ImageSignal.init_fromfile(
    os.path.join(IMAGE_PATH, hyper['image_name']),
    width=hyper['width'],
    height=hyper['height'],
    channels=hyper['channels'],
    attributes=hyper['attributes'],
    sampling_scheme=hyper['sampling_scheme'],
    batch_samples_perc=hyper['batch_samples_perc'],
)
train_dataloader = create_MR_structure(
    base_signal, hyper['max_stages'], hyper['filter'], hyper['decimation'],
)
test_dataloader = create_MR_structure(
    base_signal, hyper['max_stages'], hyper['filter'],
)

In [5]:
# Display the device that the training cell below passes to `mrtrainer.train`.
hyper["device"]

'cpu'

In [6]:
# Experiment tracking. The run name is hyper['model'] + the upper-cased first letter of
# hyper['filter'] + the first four characters of hyper['image_name'] + "_C"
# (e.g. "MGperi_C" for this config).
wandblogger = WandBLogger2D(
    project_name,
    "{}{}{}_C".format(hyper['model'], hyper['filter'][0].upper(), hyper['image_name'][:4]),
    hyper,
    BASE_DIR,
    visualize_gt_grads=hyper.get('visualize_grad', False),
)

# Build the network described by the config and report which concrete class the factory chose.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Hand model, train/test structures, logger and config to the trainer, then train on the
# configured device.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper,
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


No gradients in sampler and visualization is True. Set visualize_grad to False


0,1
D0 loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00288


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016673314583333366, max=1.0…

No gradients in sampler and visualization is True. Set visualize_grad to False


0,1
D0 loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00052


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01668764583333342, max=1.0)…

No gradients in sampler and visualization is True. Set visualize_grad to False
File  MGperi_C_3-3_w16T_hf289_MEp2000_hl1_256px.pth


0,1
D0 loss,█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00023


Total model parameters =  118116
Training finished after 13735 epochs


In [7]:
import torch

# Show the first stage's input-layer weights in units of pi. For this run they display as whole
# numbers (see the output below).
torch.div(mrmodel.stages[0].first_layer.linear.weight, torch.pi)

tensor([[ 1., -1.],
        [-4.,  0.],
        [-2., -1.],
        [ 2.,  3.],
        [ 4., -4.],
        [-4.,  1.],
        [-2.,  0.],
        [-3., -1.],
        [ 4.,  4.],
        [-2., -2.],
        [ 0., -4.],
        [-3.,  0.],
        [ 0., -3.],
        [-1.,  1.],
        [ 0., -1.],
        [-1., -3.],
        [ 2., -4.],
        [ 1.,  2.],
        [-1.,  4.],
        [ 3., -3.],
        [-1., -2.],
        [ 3.,  0.],
        [ 0., -2.],
        [-2.,  1.],
        [-4., -1.]])

In [8]:
# Check the model's `periodic` flag (the config requested 'periodic': True).
getattr(mrmodel, "periodic")

True