In [None]:
import torch
import os
from pathlib import Path
#from logs.locallogger import LocalLogger3D
from logs.wandblogger import WandBLogger3D 
from training.trainer import MRTrainer
from datasets.signals import VolumeSignal, Procedural3DSignal
from networks.mrnet import MRFactory
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt
from datasets.procedural import voronoi_texture, marble_texture

In [None]:
# Tell Weights & Biases which notebook produced the run.
os.environ["WANDB_NOTEBOOK_NAME"] = "train3d.ipynb"

# Project root = parent of the current working directory (the notebook is
# expected to be launched from a sub-directory of the repo).
BASE_DIR = Path.cwd().parents[0]
VOXEL_PATH = BASE_DIR / 'vox'     # voxel volumes on disk
MODEL_PATH = BASE_DIR / 'models'  # model output directory

In [None]:
project_name = "dev-sandbox"
config_file = '../configs/config_3d_m_net.yml'

# Parse the experiment configuration; SafeLoader only builds plain YAML types.
with open(config_file) as config_stream:
    hyper = yaml.load(config_stream, Loader=SafeLoader)

# batch_size may be given in the YAML as a string expression; evaluate it.
# SECURITY NOTE(review): eval() executes arbitrary Python taken from the config
# file, which undoes the protection SafeLoader gives above. Only use trusted
# config files; consider an ast-based arithmetic evaluator instead.
batch_spec = hyper['batch_size']
if isinstance(batch_spec, str):
    hyper['batch_size'] = eval(batch_spec)

# A channels value of 0 is replaced by the number of output features.
if hyper['channels'] == 0:
    hyper['channels'] = hyper['out_features']
print(hyper)

filepath = os.path.join(VOXEL_PATH, hyper['filename'])
torch.manual_seed(777)

In [None]:
# Select where the training signal comes from:
#   'procedural' - a synthetic 3D texture sampled on a dim x dim x dim grid,
#                  used as a single-stage dataset for both training and testing;
#   'file'       - the voxel volume at `filepath`, wrapped by create_MR_structure
#                  (presumably a multiresolution pyramid) for training, and again
#                  with its decimation argument set to False for testing.
SIGNAL_SOURCE = 'procedural'

dim = hyper['width']
if SIGNAL_SOURCE == 'procedural':
    # NOTE(review): the meaning of marble_texture's argument (2/dim) is not visible
    # here -- presumably a scale tied to the grid resolution; confirm.
    # voronoi_texture(4) is an alternative procedural texture.
    proc = marble_texture(2/dim)
    base_signal = Procedural3DSignal(
        proc,
        (dim, dim, dim),
        channels=hyper['channels'],
        domain=hyper['domain'],
        batch_size=hyper['batch_size'],
        color_space=hyper['color_space']
    )
    # The same signal object serves as the only stage for train and test.
    train_dataloader = [base_signal]
    test_dataloader = [base_signal]
elif SIGNAL_SOURCE == 'file':
    base_signal = VolumeSignal.init_fromfile(filepath,
                                             hyper['domain'],
                                             channels=hyper['channels'],
                                             batch_size=hyper['batch_size'])
    train_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                           hyper['filter'], hyper['decimation'])
    test_dataloader = create_MR_structure(base_signal, hyper['max_stages'],
                                          hyper['filter'], False)
else:
    raise ValueError(f"Unknown SIGNAL_SOURCE: {SIGNAL_SOURCE!r}")

In [None]:
# Compact W&B run label: model name + first letter of the filter (upper-cased)
# + first five characters of the data file's base name.
# NOTE(review): the label uses hyper['filename'] even when the training signal
# is procedural rather than loaded from that file -- confirm this is intended.
filename = os.path.basename(hyper['filename'])
run_name = f"{hyper['model']}{hyper['filter'][0].upper()}{filename[:5]}"
wandblogger = WandBLogger3D(project_name, run_name, hyper, BASE_DIR)

# Build the network described by the config and report its concrete class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data, and logger together, then train on the configured device.
mrtrainer = MRTrainer.init_from_dict(mrmodel,
                                     train_dataloader,
                                     test_dataloader,
                                     wandblogger,
                                     hyper)
mrtrainer.train(hyper['device'])