In [1]:
import os
from pathlib import Path

from logs.wandblogger import WandBLogger1D
from training.trainer import MRTrainer
from datasets.signals import Signal1D
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import torch
from skimage.transform import pyramid_gaussian

In [2]:
# Seed PyTorch's global RNG so repeated runs are reproducible.
torch.manual_seed(777)

# Environment flags for wandb (notebook name) and PyTorch (MPS CPU fallback).
# NOTE(review): PyTorch may read PYTORCH_ENABLE_MPS_FALLBACK when torch is first
# imported; setting it after `import torch` may have no effect — confirm, and
# prefer exporting it before launching the kernel if MPS is ever used.
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "train-wb.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

# Project root is the parent of the notebook's working directory.
BASE_DIR = Path.cwd().parents[0]
DATA_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'

In [3]:
project_name = "siggraph_asia"

# Hyperparameters come from a YAML config; the path is relative to the
# notebook's working directory.
config_file = '../configs/config_1d_m_net.yml'
with open(config_file) as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# A string-valued batch_size is evaluated as a Python expression.
# NOTE(review): eval() executes arbitrary code from the config file — only load
# trusted configs here.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
print(hyper)

srcpath = os.path.join(DATA_PATH, hyper['image_name'])

{'model': 'M', 'positive_freqs': True, 'in_features': 1, 'out_features': 1, 'hidden_layers': 1, 'hidden_features': [[4, 32], [64, 24], 8, 16, 32], 'bias': True, 'max_stages': 1, 'period': 2, 'domain': [-1, 1], 'omega_0': [32, 128, 16, 32, 64], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0'], 'loss_weights': {'d0': 1}, 'loss_function': 'mse', 'opt_method': 'Adam', 'lr': 0.0005, 'loss_tol': 1e-14, 'diff_tol': 1e-11, 'max_epochs_per_stage': [4000, 1000, 401, 401, 401, 401, 401], 'batch_size': 16384, 'image_name': 'freqs', 'width': 512, 'channels': 1, 'device': 'cuda', 'eval_device': 'cpu', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2]}


In [4]:
# filtered = [data for data in pyramid_gaussian(noise.numpy(), max_layer=1,
#                                 sigma=2/3, mode='wrap')]
# print(len(filtered), len(filtered[-1]))
# torch.from_numpy(filtered[-1]).view(1, -1),
def summed_frequencies(x, freqs):
    """Sum of unit-amplitude sines, one per frequency.

    Computes ``sum_k sin(2*pi*k*x)`` for ``k`` in ``freqs``.

    Args:
        x: tensor of sample positions. Integer tensors are promoted to the
            default floating dtype.
        freqs: iterable of frequencies (cycles per unit of ``x``).

    Returns:
        Tensor with the shape of ``x``; all zeros when ``freqs`` is empty.
    """
    # With an integer x, zeros_like would yield an integer accumulator and the
    # in-place float `+=` below would raise, so promote to floating point first.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.get_default_dtype())
    res = torch.zeros_like(x)
    for k in freqs:
        res += torch.sin(k * 2*torch.pi * x)
    return res

def allmul_frequencies(x, freqs):
    """Product of unit-amplitude sines, one per frequency.

    Computes ``prod_k sin(2*pi*k*x)`` for ``k`` in ``freqs``.

    Args:
        x: tensor of sample positions. Integer tensors are promoted to the
            default floating dtype.
        freqs: iterable of frequencies (cycles per unit of ``x``).

    Returns:
        Tensor with the shape of ``x``; all ones when ``freqs`` is empty.
    """
    # With an integer x, ones_like would yield an integer accumulator and the
    # in-place float `*=` below would raise, so promote to floating point first.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.get_default_dtype())
    res = torch.ones_like(x)
    for k in freqs:
        res *= torch.sin(k * 2*torch.pi * x)
    return res

def onemul_frequencies(x, freqs):
    """Sum of sines modulated (multiplied) by one extra sine.

    Every entry of ``freqs`` except the last contributes a sine to a sum
    (via ``summed_frequencies``); that sum is multiplied by
    ``sin(2*pi*freqs[-1]*x)``. Raises IndexError when ``freqs`` is empty.
    """
    carrier_sum = summed_frequencies(x, freqs[:-1])
    modulator = torch.sin(freqs[-1] * 2*torch.pi * x)
    return modulator * carrier_sum

def composition(x, freqs):
    """Sine whose phase is driven by a sum of sines.

    Returns ``sin(2*pi*freqs[-1] * s)``, where ``s`` is
    ``summed_frequencies(x, freqs[:-1])``. Raises IndexError when ``freqs``
    is empty.
    """
    inner_sum = summed_frequencies(x, freqs[:-1])
    return torch.sin(freqs[-1] * 2*torch.pi * inner_sum)
# Code letter -> synthetic signal generator:
#   S: sum of sines            A: product of sines
#   O: sum of sines times one extra sine
#   C: sine of a sum of sines
func_map = dict(
    S=summed_frequencies,
    A=allmul_frequencies,
    O=onemul_frequencies,
    C=composition,
)
# Selects the generator used to build the training signal.
fcode = 'C'

In [5]:
# Sample the synthetic signal on a regular grid spanning the configured domain,
# so the sample positions agree with the `domain` handed to Signal1D below
# (previously hard-coded to [-1, 1]).
x = torch.linspace(hyper['domain'][0], hyper['domain'][1], hyper['width'])
# For the 'O' and 'C' generators the last entry is the modulating/outer frequency.
frequencies = [2, 7, 4]
synthetic = func_map[fcode](x, frequencies)
# Signal1D receives the samples reshaped to (1, width).
base_signal = Signal1D(synthetic.view(1, -1),
                        domain=hyper['domain'],
                        sampling_scheme=hyper['sampling_scheme'],
                        attributes=hyper['attributes'],
                        batch_size=hyper['batch_size'])

# Training pyramid uses the configured decimation setting; the test pyramid is
# the same call with False in that position.
train_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                       hyper['filter'], hyper['decimation'])
test_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                      hyper['filter'], False)

In [6]:
# Display the training device read from the YAML config (the same value is passed to mrtrainer.train).
hyper['device']

'cuda'

In [7]:
# W&B run name: model code + synthetic-function code + first 5 chars of the image name.
img_name = os.path.basename(hyper['image_name'])
wandblogger = WandBLogger1D(
    project_name,
    f"{hyper['model']}{fcode}{img_name[:5]}",
    hyper,
    BASE_DIR,
)

# Build the network from the config, then train it on the MR pyramids.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


DATA SIZE torch.Size([1, 512])
(4,) [-10.  -8.   6.  24.]
Logged frequencies
File  MCfreqs_1-1_w32F_hf432_MEp4000_hl1_r512_pr2.pth


0,1
D0 loss,██▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.21674


Total model parameters =  200
Training finished after 2696 epochs
