In [1]:
import os
from pathlib import Path
import torch
from logs.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.signals import ImageSignal#, make_mask
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [2]:
# Tell wandb which notebook this run comes from.
os.environ["WANDB_NOTEBOOK_NAME"] = "train-wb.ipynb"
# NOTE(review): PyTorch is reported to read PYTORCH_ENABLE_MPS_FALLBACK when torch is
# first imported; since `import torch` already ran in the imports cell, this assignment
# probably has no effect. TODO move it above `import torch` if the MPS fallback is needed.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Project root = parent of the working directory (assumes the kernel starts in the
# notebook's own folder).
BASE_DIR = Path('.').absolute().parents[0]
IMAGE_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'
torch.manual_seed(777)

project_name = "img_masked"

# Hyperparameters come from a YAML config (path relative to the working directory).
config_file = '../configs/siggraph_asia/config_siggraph_masked.yml'
with open(config_file) as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# batch_size may be written as an expression string in the YAML (e.g. "2**15").
# NOTE(review): eval() executes arbitrary code taken from the config file — only load
# trusted configs, or replace this with a restricted arithmetic parser.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
# A missing or zero channel count defaults to the network's output width.
if hyper.get('channels', 0) == 0:
    hyper['channels'] = hyper['out_features']
print(hyper)

imgpath = os.path.join(IMAGE_PATH, hyper['image_name'])
# Optional mask image; NOTE(review): maskpath is not used by the data/training cells below.
# maskpath = os.path.join(IMAGE_PATH, 'synthetic', 'mask_inverted.png')  # or make_mask(imgpath, hyper['mask_color'])
maskpath = None
hyper['device']

{'model': 'M', 'positive_freqs': False, 'in_features': 2, 'out_features': 3, 'hidden_layers': 1, 'hidden_features': [[24, 32], [48, 32], [80, 64], [192, 160], [256, 256], [512, 512]], 'bias': True, 'max_stages': 6, 'period': 2, 'pmode': 'wrap', 'domain': [-1, 1], 'omega_0': [3, 6, 12, 24, 48, 96], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0', 'd1'], 'loss_function': 'hermite', 'loss_weights': {'d0': 1, 'd1': 0.0}, 'opt_method': 'Adam', 'lr': 0.0005, 'loss_tol': 1e-12, 'diff_tol': 1e-05, 'max_epochs_per_stage': [1400, 1000, 800, 600, 400, 200], 'batch_size': 32768, 'image_name': 'siggraph_asia/autumm.jpg', 'width': 0, 'height': 0, 'channels': 3, 'color_space': 'RGB', 'device': 'cuda', 'eval_device': 'cuda', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2], 'zoom': [2, 4], 'zoom_filters': ['linear', 'cubic', 'nearest']}


'cuda'

In [3]:
# Load the source image as an ImageSignal, using the sampling/colour settings from the config.
signal_kwargs = {
    key: hyper[key]
    for key in ('domain', 'channels', 'sampling_scheme', 'width', 'height',
                'attributes', 'batch_size', 'color_space')
}
base_signal = ImageSignal.init_fromfile(imgpath, **signal_kwargs)

# Train and test structures share the stage count and filter; the training one uses the
# configured decimation flag, while the test one always passes False in that slot.
pyramid_args = (hyper['max_stages'], hyper['filter'])
train_dataset = create_MR_structure(base_signal, *pyramid_args, hyper['decimation'], hyper['pmode'])
test_dataset = create_MR_structure(base_signal, *pyramid_args, False, hyper['pmode'])

# Fill in image dimensions that the config left at 0.
# NOTE(review): both width and height are taken from base_signal.shape[0], which is only
# right for square images (the run below is square) — confirm which axis of `shape` is height.
for dim_key in ('width', 'height'):
    if hyper[dim_key] == 0:
        hyper[dim_key] = base_signal.shape[0]

In [4]:
# Build the model described by the config.
img_name = os.path.basename(hyper['image_name'])
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Run name = model tag + upper-cased first letter of the filter + first 5 chars of the
# image file name (e.g. "MGautum" for model 'M', filter 'gauss', image 'autumm.jpg').
# Built after from_dict in case the factory touches `hyper`.
run_name = f"{hyper['model']}{hyper['filter'][0].upper()}{img_name[:5]}"
wandblogger = WandBLogger2D(project_name, run_name, hyper, BASE_DIR)

mrtrainer = MRTrainer.init_from_dict(mrmodel, train_dataset, test_dataset, wandblogger, hyper)
# Train on the device named in the config.
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


DATA SIZE torch.Size([3, 32, 32])
32.12817843686081
[Logger] All inference done in 28.734549045562744s on cuda


VBox(children=(Label(value='6.127 MB of 6.127 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
D0 loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,███▇▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Total loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00107
D1 loss,3.1483
Total loss,0.00107


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

DATA SIZE torch.Size([3, 63, 63])
27.101546646667288
[Logger] All inference done in 26.516663551330566s on cuda


0,1
D0 loss,██▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▂▂▂▂▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█████
Total loss,██▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00107
D1 loss,42.66385
Total loss,0.00107


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666668335, max=1.0…

DATA SIZE torch.Size([3, 125, 125])
25.539147374945188
[Logger] All inference done in 28.71558952331543s on cuda


VBox(children=(Label(value='11.884 MB of 11.884 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
D0 loss,█▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▁▂▂▃▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████
Total loss,█▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00162
D1 loss,329.53143
Total loss,0.00162


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

DATA SIZE torch.Size([3, 250, 250])
26.330719628542198
[Logger] All inference done in 27.618685245513916s on cuda


VBox(children=(Label(value='16.182 MB of 16.182 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
D0 loss,█▆▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▄▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████████████
Total loss,█▆▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00093
D1 loss,1381.91388
Total loss,0.00093


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333333330816, max=1.0…

DATA SIZE torch.Size([3, 500, 500])
27.796647994155688
[Logger] All inference done in 28.505040645599365s on cuda


VBox(children=(Label(value='20.798 MB of 20.798 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
D0 loss,█▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▁▁▂▄▄▅▆▆▇▇▇▇▇▇█████▇▇███████████████████
Total loss,█▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00174
D1 loss,4695.94562
Total loss,0.00174


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333333330816, max=1.0…

DATA SIZE torch.Size([3, 1000, 1000])
