In [None]:
import torch
import os
import skimage
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from logging_module.wandblogger import WandBLogger2D
from training.trainer import MRTrainer
from datasets.imagesignal import ImageSignal
from networks.mrnet import MRFactory
from networks.siren import Siren
from datasets.pyramids import create_MR_structure
import yaml
from yaml.loader import SafeLoader
from training.loss import gradient
from datasets.sampling import make2Dcoords
import matplotlib.pyplot as plt


#### MR-Net Config & Data Source

In [None]:
# Notebook file name exposed through the WANDB_NOTEBOOK_NAME environment variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "s-net.ipynb"

# BASE_DIR is the parent of the current working directory; images are
# looked up in its "img" sub-directory.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'

In [None]:
# Project identifier for this experiment.
project_name = "test_sgrad"

# Read the hyperparameters with PyYAML's safe loader, then echo them so the
# configuration in use is visible in the notebook output.
with open('../configs/config_base_l_net.yml') as config_file:
    hyper = yaml.safe_load(config_file)

print(hyper)


In [None]:
# Build the image signal from the file named in the config. Width and height
# are passed as requested values and then overwritten in `hyper` with the
# dimensions the signal actually reports.
base_signal = ImageSignal.init_fromfile(
    os.path.join(IMAGE_PATH, hyper['image_name']),
    useattributes=hyper.get('useattributes', False),
    batch_pixels_perc=1,
    width=hyper['width'],
    height=hyper['height'],
)
signal_width, signal_height = base_signal.dimensions()
hyper['width'] = signal_width
hyper['height'] = signal_height


#### SIREN Data Source

In [None]:
def get_mgrid(sidelen, dim=2):
    """Return a flattened regular grid of coordinates spanning [-1, 1].

    Args:
        sidelen: number of samples along every axis.
        dim: number of axes of the grid.

    Returns:
        A tensor of shape (sidelen ** dim, dim). The grid uses 'ij'
        indexing, so the first coordinate varies slowest.
    """
    axis = torch.linspace(-1, 1, steps=sidelen)
    per_axis_grids = torch.meshgrid(*(axis for _ in range(dim)), indexing='ij')
    return torch.stack(per_axis_grids, dim=-1).reshape(-1, dim)

In [None]:
def get_cameraman_tensor(sidelength):
    """Load scikit-image's `camera` sample image as a resized tensor.

    Args:
        sidelength: size handed to torchvision's `Resize`.

    Returns:
        The image after `Resize(sidelength)` followed by `ToTensor()`.
    """
    to_resized_tensor = Compose([Resize(sidelength), ToTensor()])
    cameraman_image = Image.fromarray(skimage.data.camera())
    return to_resized_tensor(cameraman_image)

In [None]:
import scipy.ndimage
    
class PoissonEqn(Dataset):
    """Single-sample dataset built from the cameraman image.

    The only sample pairs a coordinate grid with the flattened pixel values
    and the flattened Sobel gradients of the image.

    Attributes are documented with comments in `__init__`.
    """

    def __init__(self, sidelength):
        """Load the image at `sidelength` and precompute pixels, gradients and coordinates.

        Args:
            sidelength: size passed to `get_cameraman_tensor` and `get_mgrid`.
        """
        super().__init__()
        img = get_cameraman_tensor(sidelength)

        # Sobel derivatives along array axes 1 and 2 of the image tensor
        # (only the gradient is computed here, no Laplacian). The leading
        # axis is squeezed away and a trailing singleton axis is appended.
        img_np = img.numpy()
        grads_x = scipy.ndimage.sobel(img_np, axis=1).squeeze(0)[..., None]
        grads_y = scipy.ndimage.sobel(img_np, axis=2).squeeze(0)[..., None]
        grads_x, grads_y = torch.from_numpy(grads_x), torch.from_numpy(grads_y)

        # One (axis-1 derivative, axis-2 derivative) pair per pixel.
        self.grads = torch.stack((grads_x, grads_y), dim=-1).view(-1, 2)

        # Pixel values flattened to a column, and the matching coordinate grid.
        self.pixels = img.permute(1, 2, 0).view(-1, 1)
        self.coords = get_mgrid(sidelength, 2)

    def __len__(self):
        """The whole image is served as one sample."""
        return 1

    def __getitem__(self, idx):
        """Return the single sample as an (inputs, targets) pair of dicts.

        Raises:
            IndexError: if `idx` does not address the only sample. Without
                this, plain iteration over the dataset would never stop.
        """
        if idx not in (0, -1):
            raise IndexError("PoissonEqn holds a single sample; got index %r" % (idx,))
        return {'coords': self.coords}, {'d0': self.pixels, 'd1': self.grads}

In [None]:
# Cameraman dataset built with a side length of 128.
cameraman = PoissonEqn(sidelength=128)

### Select Data Source

In [None]:
# Data source for training: `base_signal` or `cameraman`.
dataloader = DataLoader(
    cameraman,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
)

### SIREN Net & Training

In [None]:

# SIREN taking 2 input features (the coordinates) and producing 1 output value.
poisson_siren = Siren(
    in_features=2,
    out_features=1,
    first_omega_0=30,
    hidden_omega_0=30,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True,
)

# Prefer the GPU, but fall back to the CPU so this cell does not raise on a
# machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
poisson_siren.to(device)

In [None]:
def gradients_mse(model_output, coords, gt_gradients):
    """Mean squared error between the model's gradient and a reference gradient.

    Args:
        model_output: output of the model evaluated at `coords`.
        coords: coordinates the output is differentiated with respect to
            (both are forwarded to `gradient`).
        gt_gradients: reference gradients; viewed as (1, -1, 2) before comparison.

    Returns:
        The squared difference summed over the last axis and averaged over
        the remaining ones.
    """
    predicted = gradient(model_output, coords)
    target = gt_gradients.view(1, -1, 2)
    squared_error = ((predicted - target) ** 2).sum(-1)
    return squared_error.mean()

In [None]:
total_steps = 100
steps_til_summary = 10

optim = torch.optim.Adam(lr=1e-4, params=poisson_siren.parameters())

# Keep the data on whatever device the model lives on instead of assuming CUDA.
device = next(poisson_siren.parameters()).device

model_input, gt = next(iter(dataloader))
gt = {key: value.to(device) for key, value in gt.items()}
model_input = model_input['coords'].to(device)

# The plots below reshape the flat outputs into a square image. Recover the
# side length from the number of coordinates rather than hard-coding it, so
# it follows the resolution of the dataset.
num_points = model_input.shape[-2]
sidelength = int(round(num_points ** 0.5))
if sidelength * sidelength != num_points:
    raise ValueError(
        "Expected a square grid of coordinates, got %d points" % num_points
    )

for step in range(total_steps):

    # The model returns its output together with the coordinates it was
    # evaluated at; those coordinates are what `gradient` differentiates
    # against.
    model_output, coords = poisson_siren(model_input)
    train_loss = gradients_mse(model_output, coords, gt['d1'])

    if not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, train_loss.item()))

        img_grad = gradient(model_output, coords)

        fig, axes = plt.subplots(1, 2, figsize=(18, 6))

        axes[0].imshow(model_output.detach().cpu().view(sidelength, sidelength).numpy())
        axes[0].set_title("Model output")
        axes[1].imshow(img_grad.detach().cpu().norm(dim=-1).view(sidelength, sidelength).numpy())
        axes[1].set_title("Gradient magnitude of model output")
        plt.show()
        # Release the figure so repeated summaries do not accumulate open figures.
        plt.close(fig)

    optim.zero_grad()
    train_loss.backward()
    optim.step()