In [None]:
import os
from pathlib import Path
import torch
import yaml
from yaml.loader import SafeLoader
import json

In [None]:
try:
  import wandb
  import trimesh
except ModuleNotFoundError:
  !pip install wandb trimesh

In [None]:
# Ask for the location of a JSON credentials file and load it into `cred_data`.
# A later cell reads cred_data["access_token"] to install the private mrnet package.
json_path = input("path to credentials json")
# Strip whitespace accidentally pasted around the path and expand a leading "~".
# The `with` block closes the file even when json.load raises on malformed JSON.
with open(Path(json_path.strip()).expanduser()) as f:
    cred_data = json.load(f)

In [None]:
!pip install git+https://{cred_data["access_token"]}@github.com/visgraf/mrnet.git@dev

In [None]:
from mrnet.logs.wandblogger import WandBLogger2D
from mrnet.training.trainer import MRTrainer
from mrnet.datasets.signals import ImageSignal
from mrnet.networks.mrnet import MRFactory
from mrnet.datasets.pyramids import create_MR_structure

In [None]:
# --- environment, paths and reproducibility ---
os.environ["WANDB_NOTEBOOK_NAME"] = "mrnet_image_reconstruction.ipynb"
BASE_DIR = Path('.').absolute()
IMAGE_PATH = BASE_DIR.joinpath('img')
MODEL_PATH = BASE_DIR.joinpath('models')
torch.manual_seed(777)

# --- hyperparameters loaded from the YAML config ---
config_file = 'configs/image.yml'
with open(config_file) as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# A string batch_size is evaluated as a Python expression.
# NOTE(review): eval() runs arbitrary code from the config — use trusted configs only.
batch_size_spec = hyper['batch_size']
if isinstance(batch_size_spec, str):
    hyper['batch_size'] = eval(batch_size_spec)

# A missing or zero 'channels' entry falls back to 'out_features'.
channel_count = hyper.get('channels', 0)
if channel_count == 0:
    hyper['channels'] = hyper['out_features']

print(hyper)

imgpath = os.path.join(IMAGE_PATH, hyper['image_name'])
project_name = hyper.get('project_name', 'dev_sandbox')
# Display the configured device as the cell's output.
hyper['device']

In [None]:
# Load the image at `imgpath` as a signal, configured from the hyperparameters.
# A width/height of 0 is passed through as-is here; it is replaced below by the
# size of the loaded signal (presumably 0 means "keep the native size" — confirm
# against ImageSignal.init_fromfile).
base_signal = ImageSignal.init_fromfile(
                    imgpath,
                    domain=hyper['domain'],
                    channels=hyper['channels'],
                    sampling_scheme=hyper['sampling_scheme'],
                    width=hyper['width'], height=hyper['height'],
                    attributes=hyper['attributes'],
                    batch_size=hyper['batch_size'],
                    color_space=hyper['color_space'])

# Training pyramid: uses the configured decimation setting.
train_dataset = create_MR_structure(base_signal,
                                    hyper['max_stages'],
                                    hyper['filter'],
                                    hyper['decimation'],
                                    hyper['pmode'])
# Test pyramid: same stages, filter and pmode, but decimation is always False.
test_dataset = create_MR_structure(base_signal,
                                   hyper['max_stages'],
                                   hyper['filter'],
                                   False,
                                   hyper['pmode'])

# Fill in unspecified dimensions from the loaded signal so the model and logger
# receive the real image size.
# NOTE(review): assumes base_signal.shape ends with (..., height, width), which is
# what taking width from shape[-1] implies; height therefore comes from shape[-2]
# (previously shape[-1], which was wrong for non-square images) — confirm.
if hyper['width'] == 0:
    hyper['width'] = base_signal.shape[-1]
if hyper['height'] == 0:
    hyper['height'] = base_signal.shape[-2]

In [None]:
# Build the model from the hyperparameters, attach a W&B logger and train.
img_name = os.path.basename(hyper['image_name'])
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# W&B run name = model id + upper-cased first letter of the filter
# + first five characters of the image file name + first letter of the color space.
run_name = (f"{hyper['model']}"
            f"{hyper['filter'][0].upper()}"
            f"{img_name[0:5]}"
            f"{hyper['color_space'][0]}")
wandblogger = WandBLogger2D(project_name, run_name, hyper, BASE_DIR)

mrtrainer = MRTrainer.init_from_dict(
    mrmodel, train_dataset, test_dataset, wandblogger, hyper)
mrtrainer.train(hyper['device'])