In [1]:
import os
from pathlib import Path

from logs.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.signals import ImageSignal#, make_mask
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import os

In [2]:
# Notebook environment: tell W&B which notebook is running, and let PyTorch
# fall back to CPU for operations the MPS backend does not implement.
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "train-wb.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

# Project layout, resolved relative to the parent of the notebook's working directory.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'

In [3]:
project_name = "highres-sandbox"
#-- hyperparameters in configs --#
# Anchored on BASE_DIR (like IMAGE_PATH/MODEL_PATH) instead of a cwd-relative '..'.
config_file = os.path.join(BASE_DIR, 'configs', 'config_base_m_net.yml')
with open(config_file) as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# batch_size may be given in the YAML as an arithmetic expression string
# (e.g. "512*512"); evaluate it to a number.
if isinstance(hyper['batch_size'], str):
    # NOTE(review): eval() executes arbitrary Python taken from the config file.
    # Only load trusted configs here, or replace this with a restricted
    # arithmetic parser (e.g. an ast-based evaluator) if configs can come from elsewhere.
    hyper['batch_size'] = eval(hyper['batch_size'])
# Fail fast here rather than deep inside data loading / training.
assert isinstance(hyper['batch_size'], int) and hyper['batch_size'] > 0, \
    f"batch_size must be a positive int, got {hyper['batch_size']!r}"
print(hyper)

imgpath = os.path.join(IMAGE_PATH, hyper['image_name'])
# NOTE(review): maskpath is not referenced by the cells shown here; set it to a
# mask image path if a later cell needs one.
maskpath = None

{'model': 'M', 'positive_freqs': False, 'in_features': 2, 'out_features': 1, 'hidden_layers': 1, 'hidden_features': [[24, 32], [48, 32], [64, 48], [128, 96], [256, 192], [512, 384]], 'bias': True, 'max_stages': 6, 'period': 2, 'domain': [-1, 1], 'omega_0': [3, 6, 16, 64, 128, 256], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0', 'd1'], 'loss_function': 'hermite', 'loss_weights': {'d0': 1.0, 'd1': 0.0001}, 'opt_method': 'Adam', 'lr': 0.0003, 'loss_tol': 1e-11, 'diff_tol': 1e-07, 'max_epochs_per_stage': [1804, 1404, 1004, 804, 604, 404, 102], 'batch_size': 262144, 'image_name': 'pexels_textures/pic0.png', 'width': 1024, 'height': 1024, 'channels': 1, 'YCbCr': False, 'device': 'cuda', 'eval_device': 'cuda', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2], 'zoom': [2, 4]}


In [4]:
# Optionally keep only the last K entries of the train/test multiresolution
# structures and set hyper['max_stages'] = K, so later cells that read hyper
# see K stages. None (the default) keeps all hyper['max_stages'] stages.
# NOTE: a non-None value overwrites hyper['max_stages']; re-run the config cell
# before re-running this cell.
LAST_K_STAGES = None

base_signal = ImageSignal.init_fromfile(
                    imgpath,
                    domain=hyper['domain'],
                    channels=hyper['channels'],
                    sampling_scheme=hyper['sampling_scheme'],
                    width=hyper['width'], height=hyper['height'],
                    attributes=hyper['attributes'],
                    batch_size=hyper['batch_size'],
                    YCbCr=hyper.get('YCbCr', False))

train_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                       hyper['filter'], hyper['decimation'])
# The test structure passes False where training passes hyper['decimation'].
test_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                      hyper['filter'], False)

if LAST_K_STAGES is not None:
    train_dataloader = train_dataloader[-LAST_K_STAGES:]
    test_dataloader = test_dataloader[-LAST_K_STAGES:]
    hyper['max_stages'] = LAST_K_STAGES

In [5]:
# Display the device string that is later passed to mrtrainer.train.
configured_device = hyper['device']
configured_device

'cuda'

In [6]:
# W&B run name: model id + upper-cased first letter of the filter name +
# first five characters of the image file name.
img_name = os.path.basename(hyper['image_name'])
run_name = "{}{}{}".format(hyper['model'], hyper['filter'][0].upper(), img_name[:5])
wandblogger = WandBLogger2D(project_name, run_name, hyper, BASE_DIR)

# Build the model described by the hyperparameters.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data structures and logger into the trainer, then train on the configured device.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataloader, test_dataloader, wandblogger, hyper)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


DATA SIZE torch.Size([1, 32, 32])
[Logger] All inference done in 28.460923671722412s on cuda


0,1
D0 loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,█▇▇▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00095
D1 loss,0.58389
Total loss,0.00101


DATA SIZE torch.Size([1, 64, 64])
[Logger] All inference done in 29.949442148208618s on cuda


0,1
D0 loss,█▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,█▆▄▃▂▁▁▁▂▂▃▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████
Total loss,█▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00055
D1 loss,0.64577
Total loss,0.00061


DATA SIZE torch.Size([1, 128, 128])
[Logger] All inference done in 30.4000825881958s on cuda


0,1
D0 loss,█▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▃▁▁▁▁▂▂▃▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████████
Total loss,█▆▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00068
D1 loss,1.74579
Total loss,0.00086


DATA SIZE torch.Size([1, 256, 256])
[Logger] All inference done in 30.685975074768066s on cuda


0,1
D0 loss,█▆▅▅▄▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▂▁▁▂▃▄▅▅▆▆▇▇▇▇▇▇▇███████████████████████
Total loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.0009
D1 loss,4.298
Total loss,0.00133


DATA SIZE torch.Size([1, 512, 512])
[Logger] All inference done in 30.182589769363403s on cuda


0,1
D0 loss,██▆▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▂▁▁▂▂▃▅▅▆▆▆▇▇▇▇▇▇▇▇█████████████████████
Total loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00123
D1 loss,8.83678
Total loss,0.00212


DATA SIZE torch.Size([1, 1024, 1024])
[Logger] All inference done in 31.63190770149231s on cuda
File  MGL1pic0._6-6_w256F_hf512384_MEp404_hl1_r1024_pr2.pth


0,1
D0 loss,█▇▆▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D1 loss,▃▁▂▃▄▅▆▆▇▇▇▇▇▇▇█████████████████████████
Total loss,█▆▅▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00202
D1 loss,17.51847
Total loss,0.00377


Total model parameters =  367416
Training finished after 5444 epochs
