In [None]:
import torch
import os
import skimage
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from logs.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from datasets.sampler import make2Dcoords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt

In [None]:
os.environ["WANDB_NOTEBOOK_NAME"] = "eval-net.ipynb"
BASE_DIR = Path('.').absolute().parents[0]
IMAGE_PATH = BASE_DIR.joinpath('img')
MODEL_PATH = BASE_DIR.joinpath('models')

In [None]:
project_name = "test_eval"
with open('../configs/config_base_m_net.yml') as f:
    hyper = yaml.load(f, Loader=SafeLoader)


In [None]:
# Load the trained multiresolution network from its checkpoint file.
# str(...) keeps the argument a plain string, exactly as os.path.join produced.
checkpoint_file = str(MODEL_PATH / 'LSlena.pth')
mrmodel = MRFactory.load_state_dict(checkpoint_file)

In [None]:
print("Model: ", type(mrmodel))

In [None]:
# Per-tensor breakdown of the model's parameters, then two totals to compare:
# a manual count and the model's own total_parameters() report.
param_tensors = list(mrmodel.parameters())
for tensor in param_tensors:
    print("p: ", tensor.shape, " = ", tensor.numel())

# NOTE(review): one element per stage is subtracted here, presumably to match
# how total_parameters() counts (e.g. a per-stage scalar it excludes) —
# TODO confirm against the MRNet implementation.
raw_count = sum(tensor.numel() for tensor in param_tensors)
total_params = raw_count - mrmodel.n_stages()
print("TOTAL = ", total_params)
print("MODEL TOTAL = ", mrmodel.total_parameters())

In [None]:
# Evaluate the network on a regular img_size x img_size coordinate grid and
# display the result as an image.
img_size = 256  # must match between make2Dcoords() and the view() below

# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    output = mrmodel(make2Dcoords(img_size, img_size))

# Clamp to the displayable [0, 1] range.
model_out = torch.clamp(output['model_out'], 0.0, 1.0)

# view(img_size, img_size) implies one value per pixel, i.e. a single-channel
# image, so render it in grayscale with fixed limits matching the clamp above
# (otherwise imshow would apply a false-colour map and autoscale the values).
fig, ax = plt.subplots()
ax.imshow(model_out.cpu().view(img_size, img_size).detach().numpy(),
          cmap='gray', vmin=0.0, vmax=1.0)
ax.set_title('Model output (LSlena.pth)')
ax.axis('off')
plt.show()
