In [1]:
import os
from pathlib import Path
import torch
from logs.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.signals import ImageSignal#, make_mask
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [2]:
# Environment: name the notebook for W&B and let PyTorch fall back to the CPU
# for operators the Apple MPS backend does not implement.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Project layout: the project root is the parent of the notebook's working
# directory; images are read from <root>/img, models live in <root>/models.
BASE_DIR = Path('.').absolute().parents[0]
IMAGE_PATH = BASE_DIR.joinpath('img')
MODEL_PATH = BASE_DIR.joinpath('models')

# Fixed seed so weight initialisation is reproducible across runs.
SEED = 777
torch.manual_seed(SEED)

project_name = "siggraph_asia_mrnet"
#-- hyperparameters in configs --#
config_file = '../configs/siggraph_asia/config_siggraph_imgs.yml'
with open(config_file, encoding="utf-8") as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# A string batch_size is evaluated as a Python expression (presumably
# arithmetic such as "2**15").
# NOTE(review): eval() executes arbitrary code from the config file -- only
# load trusted configs here; a restricted arithmetic parser would be safer.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
# A missing or zero 'channels' entry defaults to the network's out_features.
if hyper.get('channels', 0) == 0:
    hyper['channels'] = hyper['out_features']
print(hyper)

imgpath = os.path.join(IMAGE_PATH, hyper['image_name'])
# This run trains without a mask.
maskpath = None
# Show the device training will run on.
hyper['device']

{'model': 'M', 'positive_freqs': False, 'in_features': 2, 'out_features': 3, 'hidden_layers': 1, 'hidden_features': [[15, 32], [24, 32], [32, 32], [96, 64], [192, 160], [384, 256]], 'bias': True, 'max_stages': 6, 'period': 2, 'pmode': 'wrap', 'domain': [-1, 1], 'omega_0': [2, 6, 12, 24, 32, 96], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0', 'd1'], 'loss_function': 'hermite', 'loss_weights': {'d0': 1, 'd1': 0.0}, 'opt_method': 'Adam', 'lr': [0.0008, 0.0004, 0.0002, 0.0001, 5e-05, 2e-05], 'loss_tol': 1e-12, 'diff_tol': 1e-05, 'max_epochs_per_stage': [1600, 1200, 1000, 800, 600, 400], 'batch_size': 32768, 'image_name': 'siggraph_asia/periodic512.jpg', 'width': 0, 'height': 0, 'channels': 3, 'color_space': 'YCbCr', 'device': 'cuda', 'eval_device': 'cuda', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2], 'zoom': [2, 4], 'zoom_filters': ['linear', 'c

'cuda'

In [3]:
# Load the target image once; both pyramids below are built from it.
base_signal = ImageSignal.init_fromfile(
    imgpath,
    domain=hyper['domain'],
    channels=hyper['channels'],
    sampling_scheme=hyper['sampling_scheme'],
    width=hyper['width'], height=hyper['height'],
    attributes=hyper['attributes'],
    batch_size=hyper['batch_size'],
    color_space=hyper['color_space'])

# Multiresolution pyramids: the training one uses the configured decimation
# setting, the evaluation one is built with decimation disabled.
train_dataset = create_MR_structure(base_signal,
                                    hyper['max_stages'],
                                    hyper['filter'],
                                    hyper['decimation'],
                                    hyper['pmode'])
test_dataset = create_MR_structure(base_signal,
                                   hyper['max_stages'],
                                   hyper['filter'],
                                   False,
                                   hyper['pmode'])

# A width/height of 0 in the config means "take it from the loaded image".
# Assumes base_signal.shape ends with (..., height, width), matching the
# (C, H, W) "DATA SIZE" values logged during training -- TODO confirm.
# Height must come from the second-to-last axis; reading shape[-1] here
# would silently store the width for non-square images.
if hyper['width'] == 0:
    hyper['width'] = base_signal.shape[-1]
if hyper['height'] == 0:
    hyper['height'] = base_signal.shape[-2]

In [4]:
# Build the network described by the hyperparameters and report its class.
img_name = os.path.basename(hyper['image_name'])
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# W&B run name = model letter + filter initial (upper-cased) + first five
# characters of the image file name + colour-space initial,
# e.g. "MGperioY" for model M, gauss filter, periodic512.jpg, YCbCr.
filter_tag = hyper['filter'][0].upper()
color_tag = hyper['color_space'][0]
run_name = f"{hyper['model']}{filter_tag}{img_name[:5]}{color_tag}"
wandblogger = WandBLogger2D(project_name, run_name, hyper, BASE_DIR)

# Wire model, pyramids and logger into the trainer, then train on the
# configured device.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataset, test_dataset, wandblogger, hyper
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


DATA SIZE torch.Size([3, 16, 16])
32.1454355799884
[Logger] All inference done in 9.294735193252563s on cuda


0,1
D0 loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▅▆▇▆▆█▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00023
D1 loss,0.09643
Total loss,0.00023


DATA SIZE torch.Size([3, 32, 32])
28.034824908728034
[Logger] All inference done in 9.062615871429443s on cuda


0,1
D0 loss,█▆▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▂▁▁▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████
Total loss,█▆▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00016
D1 loss,1.4541
Total loss,0.00016


DATA SIZE torch.Size([3, 64, 64])
24.692183680134644
[Logger] All inference done in 9.546090126037598s on cuda


0,1
D0 loss,█▆▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▁▁▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███████████
Total loss,█▆▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00037
D1 loss,15.76277
Total loss,0.00037


DATA SIZE torch.Size([3, 128, 128])
25.001245745902928
[Logger] All inference done in 9.612078189849854s on cuda


0,1
D0 loss,█▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████
Total loss,█▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00044
D1 loss,64.08427
Total loss,0.00044


DATA SIZE torch.Size([3, 256, 256])
27.89464076509894
[Logger] All inference done in 10.146342277526855s on cuda


0,1
D0 loss,█▇▆▅▅▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████
Total loss,█▇▆▅▅▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00023
D1 loss,184.34068
Total loss,0.00023


DATA SIZE torch.Size([3, 512, 512])
25.662789139811032
[Logger] All inference done in 9.909473896026611s on cuda
File  MGperioY_6-6_w96F_hf384256_MEp400_hl1_r512_pr2.pth


0,1
D0 loss,█▇▇▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▁▁▁▁▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████
Total loss,█▇▇▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00051
D1 loss,1612.31317
Total loss,0.00051


Total model parameters =  197281
Training finished after 4165 epochs
