In [1]:
import os
from pathlib import Path

from logs.wandblogger import WandBLogger1D
from training.trainer import MRTrainer
from datasets.signals import Signal1D
from datasets.utils import perlin_noise
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import torch
from skimage.transform import pyramid_gaussian

In [2]:
# Seed torch's global RNG so repeated runs start from the same state.
torch.manual_seed(777)

# Process-wide settings read by W&B and PyTorch.
# NOTE(review): torch is already imported when the MPS fallback flag is set
# here -- confirm the flag is still honoured when set after import.
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "train-wb.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

# Project layout: BASE_DIR is the parent of the current working directory;
# the data and model folders sit directly beneath it.
BASE_DIR = Path('.').absolute().parents[0]
DATA_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'

In [3]:
# Project name handed to the experiment logger further down.
project_name = "optim-init"

#-- hyperparameters in configs --#
config_file = '../configs/config_1d_m_net.yml'
with open(config_file) as config_stream:
    # Only the read needs the open handle; post-processing happens afterwards.
    hyper = yaml.load(config_stream, Loader=SafeLoader)

# A batch size given as a string is treated as a Python expression and
# replaced by its evaluated value.
# NOTE(review): eval() runs arbitrary code taken from the config file. Keep
# the configs trusted, or switch to a restricted arithmetic parser once the
# accepted expression format is confirmed.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
print(hyper)

# Location of the input named by the config, inside the data folder.
srcpath = os.path.join(DATA_PATH, hyper['image_name'])

{'model': 'M', 'positive_freqs': True, 'in_features': 1, 'out_features': 1, 'hidden_layers': 1, 'hidden_features': [[4, 32], [64, 24], 8, 16, 32], 'bias': True, 'max_stages': 1, 'period': 2, 'domain': [-1, 1], 'omega_0': [32, 128, 16, 32, 64], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0'], 'loss_function': 'mse', 'opt_method': 'Adam', 'lr': 0.0005, 'loss_tol': 1e-14, 'diff_tol': 1e-11, 'max_epochs_per_stage': [4000, 1000, 401, 401, 401, 401, 401], 'batch_size': 16384, 'image_name': 'odd', 'width': 128, 'channels': 1, 'device': 'cpu', 'eval_device': 'cpu', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2]}


In [4]:
# Build a noise signal sized by the configured `width` (presumably Perlin
# noise, going by the helper's name -- semantics of `octaves`/`p` live in
# datasets.utils; TODO confirm).
# NOTE(review): no active code in the visible cells reads `noise`. Confirm
# before removing it: the call may consume random state that later cells
# implicitly depend on.
noise = perlin_noise(hyper['width'], octaves=7, p=1.4)

In [5]:
def func(x):
    """Evaluate the synthetic test signal at positions ``x``.

    The signal is the sum of four unit-amplitude cosines,
    ``cos(2*pi*k*x)`` for ``k`` in (4, 8, 24, 30).

    Args:
        x: ``torch.Tensor`` of sample positions.

    Returns:
        ``torch.Tensor`` with the same shape as ``x``.
    """
    return (torch.cos(4*x * 2*torch.pi) + torch.cos(8*x * 2*torch.pi)
            + torch.cos(24*x * 2*torch.pi) + torch.cos(30*x * 2*torch.pi))


# Sample the signal at `width` evenly spaced points on [-1, 1].
# NOTE(review): the interval is hard-coded here while Signal1D receives
# hyper['domain'] -- confirm the two are meant to stay in sync.
synthetic = func(torch.linspace(-1, 1, hyper['width']))

# The signal is passed as a single row (1 x width). A Gaussian-pyramid
# filtered version of `noise` used to be an alternative input here; that
# disabled code path has been dropped from this cell.
base_signal = Signal1D(synthetic.view(1, -1),
                       domain=hyper['domain'],
                       sampling_scheme=hyper['sampling_scheme'],
                       attributes=hyper['attributes'],
                       batch_size=hyper['batch_size'])

# Both structures share the signal, stage count and filter. The training one
# passes the configured `decimation` flag as the last argument, the test one
# always passes False.
train_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                       hyper['filter'], hyper['decimation'])
test_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                      hyper['filter'], False)

In [6]:
# Display the configured `device` entry as this cell's output.
hyper['device']

'cpu'

In [7]:
# Label passed as the logger's second argument: model tag, upper-cased first
# letter of the filter name, and the first five characters of the image name.
img_name = os.path.basename(hyper['image_name'])
filter_initial = hyper['filter'][0].upper()
run_label = f"{hyper['model']}{filter_initial}{img_name[:5]}"

# Ground-truth gradient visualisation is off unless the config enables it.
show_gt_grads = hyper.get('visualize_grad', False)
wandblogger = WandBLogger1D(project_name, run_label, hyper, BASE_DIR,
                            visualize_gt_grads=show_gt_grads)

# Build the model from the hyperparameter dict and report its concrete class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data and logger into the trainer, then train on the
# configured device.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


DATA SIZE torch.Size([1, 128])
(4,) [ 3.  7. 23. 29.]
Logged
File  MGodd_1-1_w32F_hf432_MEp4000_hl1_r128_pr2.pth


0,1
D0 loss,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,1e-05


Total model parameters =  200
Training finished after 4000 epochs
