In [1]:
import torch
import os
import skimage
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from training.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.imageutils import gaussian_pyramid2D, gaussian_tower2D, box_kernel2D

In [2]:
# Export the notebook file name through the WANDB_NOTEBOOK_NAME environment
# variable (presumably read by wandb to tie runs to this notebook -- confirm
# against the wandb documentation).
os.environ["WANDB_NOTEBOOK_NAME"] = "m-net.ipynb"

# BASE_DIR is the parent of the current working directory, so the notebook is
# expected to be launched from a sub-folder one level below it.
BASE_DIR = Path.cwd().parents[0]
# Folder the input image is read from.
IMAGE_PATH = BASE_DIR / 'img'

In [3]:
# Project name handed to the W&B logger further down.
project_name = "lena_baseX"

# Hyperparameters of the experiment. The dictionary is passed unchanged to the
# logger, the model factory and the trainer.
# NOTE(review): every list below has six entries, matching 'max_stages' = 6;
# presumably one value per stage -- confirm against MRFactory / MRTrainer.
hyper = {
    # --- network ---
    'omega_0': [4, 8, 16, 32, 64, 128],
    'in_features': 2,
    'hidden_features': [48, 64, 64, 128, 128, 256],
    'hidden_layers': 1,
    'superposition_w0': False,
    'hidden_omega_0': [30, 40, 50, 60, 70, 80],

    # --- sampling / multiresolution ---
    'sampling_scheme': 'uniform',
    # Compared against 'capacity' and 'pyramid' when the data loaders are built.
    'multiresolution': 'pyramid',

    # --- optimisation ---
    # The six budgets add up to 1800 epochs.
    'max_epochs_per_stage': [500, 450, 350, 250, 150, 100],
    'opt_method': 'Adam',
    'loss_function': 'd0_MSE',
    'lr': 1e-4,
    'loss_tol': 1e-14,
    'diff_tol': 1e-11,
    # 513 * 513 = 263169, i.e. width * height of the configured image.
    'batch_size': 513 * 513,

    # --- input image ---
    'image_name': 'lena513.png',
    # 'width' and 'height' are overwritten after the image has been loaded.
    'width': 513,
    'height': 513,
    'channels': 1,

    # --- staging / model selection ---
    'stage': 1,
    'max_stages': 6,
    'model': 'M',
    'useattributes': True,

    # --- devices: use the GPU whenever torch reports one ---
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'eval_device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Optional keys, currently disabled:
    #  'extrapolate': [-3, 3],
    #  'save_format': 'general',
}

In [4]:
# Kernel passed to the Gaussian pyramid / tower builders in the next cell.
# NOTE(review): 5 is presumably the side length of the box filter -- confirm
# against box_kernel2D.
kernel = box_kernel2D(5)

# Batch size shared by every DataLoader created below.
BATCH_SIZE = hyper['batch_size']

In [5]:

# Load the image named in the configuration from IMAGE_PATH.
base_signal = ImageSignal.init_fromfile(
    os.path.join(IMAGE_PATH, hyper['image_name']),
    useattributes=hyper.get('useattributes', False),
)
# Replace the configured size with the two values reported by the loaded
# signal (stored as width, then height).
hyper['width'], hyper['height'] = base_signal.dimensions()

if hyper['multiresolution'] != 'capacity':
    # Both multiresolution structures are always built: the tower feeds the
    # test loaders, while training uses either the pyramid or the tower.
    pyramid = gaussian_pyramid2D(base_signal, hyper['max_stages'], kernel)
    tower = gaussian_tower2D(base_signal, hyper['max_stages'], kernel)
    if hyper['multiresolution'] == 'pyramid':
        trainsource = pyramid
    else:
        trainsource = tower
    # One loader per element of the chosen structure; only training shuffles.
    train_dataloader = [
        DataLoader(level, shuffle=True, batch_size=BATCH_SIZE)
        for level in trainsource
    ]
    test_dataloader = [
        DataLoader(level, batch_size=BATCH_SIZE)
        for level in tower
    ]
else:
    # 'capacity' mode: a single loader pair built directly on the base signal.
    train_dataloader = DataLoader(
        base_signal,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=0,
    )
    test_dataloader = DataLoader(
        base_signal,
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=0,
    )

In [6]:
# Label passed as the second argument of the logger: model tag, upper-cased
# first letter of the multiresolution mode and the first four characters of
# the image file name, followed by '_' ("MPlena_" for the configuration above).
wandb_label = f"{hyper['model']}{hyper['multiresolution'][0].upper()}{hyper['image_name'][0:4]}_"
wandblogger = WandBLogger2D(project_name, wandb_label, hyper, BASE_DIR)

# Build the network from the hyperparameter dictionary and report its class.
mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Wire model, data loaders and logger into the trainer, then train on the
# configured device.
mrtrainer = MRTrainer.init_from_dict(
    mrmodel,
    train_dataloader,
    test_dataloader,
    wandblogger,
    hyper,
)
mrtrainer.train(hyper['device'])

Model:  <class 'networks.mrnet.MNet'>


[34m[1mwandb[0m: Currently logged in as: [33mlvelho[0m ([33msiren-song[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
D0 loss,█▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00159


0,1
D0 loss,█▆▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.0


0,1
D0 loss,█▆▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.0


0,1
D0 loss,█▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.0


0,1
D0 loss,█▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,7e-05


0,1
D0 loss,█▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
D0 loss,0.00034


Training finished after 1800 epochs
