In [1]:
import torch
import os
from pathlib import Path
#from logs.locallogger import LocalLogger3D
from logs.wandblogger import WandBLogger3D
from training.trainer import MRTrainer
from datasets.signals import VolumeSignal
from datasets.utils import checker
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt

In [2]:
# Resolve project directories relative to the parent of the current working
# directory (the notebook is expected to be launched from a sub-folder).
BASE_DIR = Path.cwd().parents[0]
VOXEL_PATH = BASE_DIR / 'vox'
MODEL_PATH = BASE_DIR / 'models'

# W&B reads this variable to learn which notebook started the run.
os.environ["WANDB_NOTEBOOK_NAME"] = "train3d.ipynb"

In [3]:
# W&B project that receives the runs, and the YAML file holding the
# hyperparameters for this experiment.
project_name = "dev-sandbox"
config_file = '../configs/config_3d_m_net.yml'

# Only the parsing needs the file handle; post-processing happens afterwards.
with open(config_file) as f:
    hyper = yaml.safe_load(f)

# The batch size may be stored as a string (presumably an arithmetic
# expression such as a product - TODO confirm against the config file), in
# which case it is evaluated to obtain the numeric value.
# NOTE(review): eval() executes arbitrary Python taken from the config file;
# only use configs from trusted sources, or replace this with a restricted
# arithmetic parser.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])

print(hyper)

{'model': 'M', 'in_features': 3, 'hidden_layers': 1, 'hidden_features': [80, 160, 256], 'bias': True, 'max_stages': 3, 'period': 2, 'domain': [-1, 1], 'omega_0': [2, 4, 8], 'hidden_omega_0': [30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0'], 'loss_function': 'mse', 'opt_method': 'Adam', 'lr': 0.0005, 'loss_tol': 1e-12, 'diff_tol': 1e-09, 'max_epochs_per_stage': [100, 21, 21], 'batch_size': 65536, 'image_name': 'checker.npy', 'width': 128, 'height': 128, 'channels': 1, 'device': 'cuda', 'eval_device': 'cpu', 'save_format': 'general', 'visualize_grad': True}


In [4]:
# Side length of the cubic volume, in voxels.
dim = 128

# Synthetic test volume produced by the project's `checker` helper
# (presumably a 3D checkerboard; the meaning of the second argument, 32,
# should be confirmed in datasets.utils).
vol = torch.from_numpy(checker(dim, 32))

# Reshape to (1, dim, dim, dim) before wrapping the data in a signal object.
base_signal = VolumeSignal(
    vol.view((1,) + (dim,) * 3),
    hyper['domain'],
    batch_size=hyper['batch_size'],
)

# Training structure uses the decimation setting from the config ...
train_dataloader = create_MR_structure(
    base_signal,
    hyper['max_stages'],
    hyper['filter'],
    hyper['decimation'],
)
# ... while the test structure is always built with that flag set to False.
test_dataloader = create_MR_structure(
    base_signal,
    hyper['max_stages'],
    hyper['filter'],
    False,
)

In [5]:
# Run name = model tag + upper-cased first letter of the filter name
# + first five characters of the image file name.
img_name = os.path.basename(hyper['image_name'])
run_name = f"{hyper['model']}{hyper['filter'][0].upper()}{img_name[0:5]}"

wandblogger = WandBLogger3D(
    project_name,
    run_name,
    hyper,
    BASE_DIR,
    # Falls back to False when the config does not define 'visualize_grad'.
    visualize_gt_grads=hyper.get('visualize_grad', False),
)

# Build the network described by the config and report its concrete class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data and logger into the trainer, then train on the
# device named in the config.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel,
    train_dataloader,
    test_dataloader,
    wandblogger,
    hyper,
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mhallpaz[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


[Logger] All inference done in 1.8087921142578125s on cpu


0,1
D0 loss,█▇▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.01895


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

[Logger] All inference done in 3.8001890182495117s on cpu


0,1
D0 loss,█▇▆▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00343


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

[Logger] All inference done in 8.655534744262695s on cpu
File  MGcheck_3-3_w8F_hf256_MEp21_hl1_128px.pth


0,1
D0 loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00791


Total model parameters =  154272
Training finished after 142 epochs
