In [None]:
import os
from pathlib import Path

from logs.wandblogger import WandBLogger1D
from training.trainer import MRTrainer
from datasets.signals import Signal1D
from datasets.utils import perlin_noise
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [None]:
# Environment flags read by wandb / PyTorch; set them before any training runs.
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "train-wb.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

# Project root = parent of the directory the notebook kernel was started in.
BASE_DIR = Path.cwd().parents[0]
DATA_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'

In [None]:
project_name = "dev-sandbox"
#-- hyperparameters in configs --#
config_file = '../configs/config_1d_m_net.yml'
# Read the config as UTF-8 explicitly so the platform's locale encoding cannot
# mis-decode it; the file is closed as soon as parsing is done.
with open(config_file, encoding='utf-8') as f:
    hyper = yaml.load(f, Loader=SafeLoader)
# An empty file loads as None (and a scalar/list as a non-dict); fail here with
# a clear message instead of an opaque TypeError on the first key lookup.
if not isinstance(hyper, dict):
    raise ValueError(f"{config_file} must contain a YAML mapping of hyperparameters, "
                     f"got {type(hyper).__name__}")
# batch_size may be given as a string (presumably an arithmetic expression such
# as "128*128", which YAML does not evaluate); turn it into a value.
# NOTE(review): eval() runs arbitrary code from the config and can see this
# notebook's globals. Acceptable only for trusted local configs; consider a
# restricted arithmetic parser if configs ever come from elsewhere.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
print(hyper)
srcpath = os.path.join(DATA_PATH, hyper['image_name'])

In [None]:
# Base signal: synthetic 1D Perlin noise reshaped to a single row
# (the file at `srcpath` is not loaded in this cell).
noise = perlin_noise(hyper['width'], octaves=7, p=1.4)
signal_kwargs = dict(
    domain=hyper['domain'],
    sampling_scheme=hyper['sampling_scheme'],
    attributes=hyper['attributes'],
    batch_size=hyper['batch_size'],
)
base_signal = Signal1D(noise.view(1, -1), **signal_kwargs)

# Multiresolution structures built from the same signal. The train one uses
# hyper['decimation']; the test one always passes False in that position
# (presumably disabling decimation -- confirm against create_MR_structure).
stage_args = (hyper['max_stages'], hyper['filter'])
train_dataloader = create_MR_structure(base_signal, *stage_args, hyper['decimation'])
test_dataloader = create_MR_structure(base_signal, *stage_args, False)

In [None]:
# Display the device from the config; the same value is passed to
# mrtrainer.train() in the final cell.
hyper['device']

In [None]:
# W&B run name = model name + upper-cased first letter of the filter
# + first five characters of the configured image file name.
# NOTE(review): the signal trained on here is Perlin noise, not this image,
# so the image-derived suffix may be misleading in the run list.
img_name = os.path.basename(hyper['image_name'])
filter_tag = hyper['filter'][0].upper()
run_name = f"{hyper['model']}{filter_tag}{img_name[:5]}"
wandblogger = WandBLogger1D(
    project_name,
    run_name,
    hyper,
    BASE_DIR,
    visualize_gt_grads=hyper.get('visualize_grad', False),
)

# Build the model from the config, wire it to the dataloaders and logger,
# then train on the configured device.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper
)
mrtrainer.train(hyper['device'])