In [11]:
#@title siren.py

# Based on https://github.com/EmilienDupont/coin
from math import sqrt

import einops
import torch
from torch import nn
import numpy as np

import gc

class Sine(nn.Module):
    """Scaled sine non-linearity computing ``sin(w0 * x)``.

    Args:
        w0 (float): Multiplier applied to the input before taking the sine
            (the Omega_0 parameter from the SIREN paper).
    """

    def __init__(self, w0=1.0):
        super().__init__()
        # Stored as a plain attribute, so it is not a learnable parameter.
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return scaled.sin()


class SirenLayer(nn.Module):
    """Implements a single SIREN layer: a linear map followed by an activation.

    Args:
        dim_in (int): Dimension of input.
        dim_out (int): Dimension of output.
        w0 (float): Omega_0 frequency factor. It scales the default Sine
            activation and the weight initialization bound of non-first layers.
        c (float): c value from SIREN paper used for weight initialization.
        is_first (bool): Whether this is first layer of model. The first layer
            is initialized uniformly in [-1 / dim_in, 1 / dim_in].
        is_last (bool): Whether this is last layer of model. If it is, no
            activation is applied and the raw linear output is returned (no
            offset is added to the output).
        use_bias (bool): Whether to learn bias in linear layer.
        activation (torch.nn.Module): Activation function. If None, defaults to
            Sine activation.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        is_last=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first
        self.is_last = is_last

        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)

        # Initialize layers following SIREN paper. Weights and bias share the
        # same uniform bound.
        w_std = (1 / dim_in) if self.is_first else (sqrt(c / dim_in) / w0)
        nn.init.uniform_(self.linear.weight, -w_std, w_std)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -w_std, w_std)

        # The activation is stored even for the last layer (where it is not
        # applied) so that the module structure is the same for every layer.
        self.activation = Sine(w0) if activation is None else activation

    def forward(self, x):
        out = self.linear(x)
        if self.is_last:
            # The last layer is purely linear: no activation and no offset.
            return out
        return self.activation(out)


class Siren(nn.Module):
    """SIREN network: a stack of sine layers followed by a linear output layer.

    Args:
        dim_in (int): Size of the input coordinates.
        dim_hidden (int): Width of every hidden layer.
        dim_out (int): Size of the output features.
        num_layers (int): Total number of layers; ``num_layers - 1`` hidden
            layers are built, plus the final layer.
        w0 (float): Omega 0 from the SIREN paper, used for all but the first
            layer.
        w0_initial (float): Omega 0 used for the first layer.
        use_bias (bool): Whether the linear layers learn a bias.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out
        self.num_layers = num_layers

        hidden_stack = []
        for depth in range(num_layers - 1):
            first = depth == 0
            # Only the first hidden layer sees the raw coordinates and uses
            # the dedicated first-layer frequency.
            hidden_stack.append(
                SirenLayer(
                    dim_in=dim_in if first else dim_hidden,
                    dim_out=dim_hidden,
                    w0=w0_initial if first else w0,
                    use_bias=use_bias,
                    is_first=first,
                )
            )
        self.net = nn.Sequential(*hidden_stack)

        # Output layer, flagged as last so it skips the activation.
        self.last_layer = SirenLayer(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            is_last=True,
        )

    def forward(self, x):
        """Run the hidden stack and then the output layer.

        Args:
            x (torch.Tensor): Tensor of shape (*, dim_in), where * means any
                number of dimensions.

        Returns:
            Tensor of shape (*, dim_out).
        """
        hidden = self.net(x)
        return self.last_layer(hidden)


class ModulatedSiren(Siren):
    """Modulated SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
        modulate_scale (bool): Whether to modulate with scales.
        modulate_shift (bool): Whether to modulate with shifts.
        use_latent (bool): If true, use a latent vector which is mapped to
            modulations by a LatentToModulation network. Otherwise the input
            of ``modulated_forward`` is used as the modulations directly,
            offset by a learned Bias of size num_modulations.
        latent_dim (int): Dimension of latent vector.
        modulation_net_dim_hidden (int): Number of hidden dimensions of
            modulation network.
        modulation_net_num_layers (int): Number of layers in modulation network.
            If this is set to 1 will correspond to a linear layer.
        mu (float or torch.Tensor): Offset added to the model output.
        sigma (float or torch.Tensor): Factor multiplying the model output
            (applied before ``mu`` is added).
        last_activation (torch.nn.Module): Module applied to the output of the
            last layer before the ``sigma`` / ``mu`` rescaling. If None,
            defaults to the identity.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
        modulate_scale=False,
        modulate_shift=True,
        use_latent=True,
        latent_dim=64,
        modulation_net_dim_hidden=64,
        modulation_net_num_layers=1,
        mu=0,
        sigma=1,
        last_activation=None,
    ):
        super().__init__(
            dim_in,
            dim_hidden,
            dim_out,
            num_layers,
            w0,
            w0_initial,
            use_bias,
        )
        # Must modulate at least one of scale and shift
        assert modulate_scale or modulate_shift, (
            "at least one of modulate_scale and modulate_shift must be True"
        )

        self.modulate_scale = modulate_scale
        self.modulate_shift = modulate_shift
        self.w0 = w0
        self.w0_initial = w0_initial
        self.mu = mu
        self.sigma = sigma
        self.last_activation = (
            nn.Identity() if last_activation is None else last_activation
        )

        # We modulate features at every *hidden* layer of the base network and
        # therefore have dim_hidden * (num_layers - 1) modulations, since the
        # last layer is not modulated
        num_modulations = dim_hidden * (num_layers - 1)
        if self.modulate_scale and self.modulate_shift:
            # If we modulate both scale and shift, we have twice the number of
            # modulations at every layer and feature
            num_modulations *= 2

        if use_latent:
            self.modulation_net = LatentToModulation(
                latent_dim,
                num_modulations,
                modulation_net_dim_hidden,
                modulation_net_num_layers,
            )
        else:
            # Without a latent mapping the modulations are passed in directly;
            # a Bias module keeps the `modulation_net` interface available so
            # that modulated_forward works in both configurations.
            self.modulation_net = Bias(num_modulations)

        self.num_modulations = num_modulations

    def modulated_forward(self, x, latent):
        """Forward pass of modulated SIREN model.

        Args:
            x (torch.Tensor): Shape (batch_size, *, dim_in), where * refers to
                any spatial dimensions, e.g. (height, width), (height * width,)
                or (depth, height, width) etc.
            latent (torch.Tensor): Shape (batch_size, latent_dim). If
                use_latent=False, then latent_dim = num_modulations.

        Returns:
            Output features of shape (batch_size, *, dim_out).
        """
        # Extract batch_size and spatial dims of x, so we can reshape output
        x_shape = x.shape[:-1]
        # Flatten all spatial dimensions, i.e. shape
        # (batch_size, *, dim_in) -> (batch_size, num_points, dim_in).
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], -1, x.shape[-1])

        # Shape (batch_size, num_modulations)
        modulations = self.modulation_net(latent)

        # Split modulations into shifts and scales and apply them to hidden
        # features. When both are used, scales occupy the first half of the
        # modulation vector and shifts the second half.
        mid_idx = (
            self.num_modulations // 2
            if (self.modulate_scale and self.modulate_shift)
            else 0
        )
        idx = 0
        for module in self.net:
            if self.modulate_scale:
                # Shape (batch_size, 1, dim_hidden). Note that we add 1 so
                # modulations remain zero centered
                scale = modulations[:, idx: idx +
                                    self.dim_hidden].unsqueeze(1) + 1.0
            else:
                scale = 1.0

            if self.modulate_shift:
                # Shape (batch_size, 1, dim_hidden)
                shift = modulations[
                    :, mid_idx + idx: mid_idx + idx + self.dim_hidden
                ].unsqueeze(1)
            else:
                shift = 0.0

            x = module.linear(x)
            x = scale * x + shift  # Broadcast scale and shift across num_points
            x = module.activation(x)  # (batch_size, num_points, dim_hidden)

            idx = idx + self.dim_hidden

        # Shape (batch_size, num_points, dim_out)
        out = self.last_activation(self.last_layer(x))
        out = out * self.sigma + self.mu
        # Reshape (batch_size, num_points, dim_out) -> (batch_size, *, dim_out)
        return out.reshape(*x_shape, out.shape[-1])

# hypernetwork mapping latent code to modulation
class LatentToModulation(nn.Module):
    """Network turning a latent vector into a vector of modulations.

    Args:
        latent_dim (int): Size of the input latent vector.
        num_modulations (int): Size of the produced modulation vector.
        dim_hidden (int): Width of the hidden layers (unused when
            ``num_layers == 1``).
        num_layers (int): Number of linear layers. With 1 the network is a
            single ``nn.Linear``; otherwise it is an ``nn.Sequential``.
        activation (callable): Factory called with no arguments to create the
            activation placed after every linear layer except the last.
    """

    def __init__(
        self, latent_dim, num_modulations, dim_hidden, num_layers, activation=nn.SiLU
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_modulations = num_modulations
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers
        self.activation = activation

        if num_layers == 1:
            # Purely linear mapping, stored directly (not wrapped).
            self.net = nn.Linear(latent_dim, num_modulations)
            return

        stack = [nn.Linear(latent_dim, dim_hidden), self.activation()]
        # range() is empty when num_layers is 2 or less, so no extra hidden
        # blocks are added in that case.
        for _ in range(num_layers - 2):
            stack.append(nn.Linear(dim_hidden, dim_hidden))
            stack.append(self.activation())
        stack.append(nn.Linear(dim_hidden, num_modulations))
        self.net = nn.Sequential(*stack)

    def forward(self, latent):
        modulations = self.net(latent)
        return modulations


class Bias(nn.Module):
    """Adds a learnable, zero-initialized bias vector to its input.

    Args:
        size (int): Number of entries of the bias vector.
    """

    def __init__(self, size):
        super().__init__()
        # Exposed so this module offers the same `latent_dim` attribute as the
        # LatentToModulation model.
        self.latent_dim = size
        # nn.Parameter tracks gradients by default.
        self.bias = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        shifted = x + self.bias
        return shifted

In [2]:
#@title coral.losses

import torch

# Loss modules shared by the helper functions in this cell.
# Mean squared error averaged over all elements (returns a scalar).
mse_fn = torch.nn.MSELoss(reduction="mean")
# Unreduced squared error: one value per input element.
per_element_mse_fn = torch.nn.MSELoss(reduction="none")
# Unreduced binary cross-entropy that takes logits as its first argument.
per_element_bce_fn = torch.nn.BCEWithLogitsLoss(reduction="none")


def per_element_multi_scale_fn(
    model_output,
    gt,
    loss_name="mse",
    last_element=False,
):
    """Compute an unreduced loss between every model output and the target.

    Args:
        model_output (iterable of torch.Tensor): Outputs to compare with
            ``gt``. Each one is flattened to (batch_size, -1).
        gt (torch.Tensor): Target whose first dimension is the batch size.
        loss_name (str): "mse" for squared error or "bce" for binary
            cross-entropy with logits.
        last_element (bool): Unused; kept for interface compatibility.

    Returns:
        torch.Tensor: The per-output losses stacked along a new first
        dimension, i.e. shape (num_outputs, batch_size, num_features).

    Raises:
        ValueError: If ``loss_name`` is not "mse" or "bce".
    """
    if loss_name == "mse":
        loss_fn = per_element_mse_fn
    elif loss_name == "bce":
        loss_fn = per_element_bce_fn
    else:
        # Fail with an explicit message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown loss_name {loss_name!r}; expected 'mse' or 'bce'."
        )

    batch_size = gt.shape[0]
    target = gt.reshape(batch_size, -1)
    losses = [loss_fn(out.reshape(batch_size, -1), target) for out in model_output]

    return torch.stack(losses)


def batch_multi_scale_fn(model_output, gt, loss_name="mse", use_resized=False):
    """Average the multi-output loss separately for every batch element.

    Args:
        model_output (iterable of torch.Tensor): Outputs to compare with
            ``gt``.
        gt (torch.Tensor): Target whose first dimension is the batch size.
        loss_name (str): Loss selector forwarded to
            ``per_element_multi_scale_fn``.
        use_resized (bool): Unused; kept for interface compatibility.

    Returns:
        torch.Tensor: Loss of shape (batch_size,), averaged over all outputs
        and all features of each batch element.
    """
    # Shape (num_outputs, batch_size, num_features)
    per_element_multi_scale_mse = per_element_multi_scale_fn(
        model_output, gt, loss_name=loss_name
    )
    # Bring the batch dimension to the front before flattening. Flattening the
    # (num_outputs, batch_size, ...) tensor directly to (batch_size, -1) would
    # mix values of different batch elements when there are several outputs.
    per_example = per_element_multi_scale_mse.transpose(0, 1)
    # Shape (batch_size,)
    return per_example.reshape(gt.shape[0], -1).mean(dim=1)


def per_element_nll_fn(x, y):
    """Element-wise binary negative log-likelihood.

    Computes ``-(y * log(x) + (1 - y) * log(1 - x))`` for every element.

    Args:
        x (torch.Tensor): Predictions; the logarithm is taken of ``x`` and
            ``1 - x``, so values are expected to lie in [0, 1].
        y (torch.Tensor): Targets, broadcastable with ``x``.

    Returns:
        torch.Tensor: Unreduced loss with the broadcast shape of ``x`` and
        ``y``.
    """
    # Clamp the logarithms at -100 (the bound torch.nn.BCELoss uses) so that
    # predictions of exactly 0 or 1 give a finite loss instead of inf, or NaN
    # from 0 * -inf.
    log_x = torch.log(x).clamp(min=-100.0)
    log_one_minus_x = torch.log(1 - x).clamp(min=-100.0)

    return -(y * log_x + (1 - y) * log_one_minus_x)


def per_element_rel_mse_fn(x, y, reduction=True):
    """Relative L2 error of ``x`` with respect to ``y`` for each batch element.

    Both inputs are flattened to (batch_size, -1); the L2 norm of their
    difference is divided by the L2 norm of ``y``.

    Args:
        x (torch.Tensor): Predictions, first dimension is the batch.
        y (torch.Tensor): Reference values with the same number of elements
            per batch entry.
        reduction (bool): Accepted but not used by this function.

    Returns:
        torch.Tensor: Relative errors of shape (batch_size,).
    """
    batch = x.size(0)
    flat_x = x.reshape(batch, -1)
    flat_y = y.reshape(batch, -1)

    error_norm = torch.norm(flat_x - flat_y, p=2, dim=1)
    reference_norm = torch.norm(flat_y, p=2, dim=1)

    return error_norm / reference_norm


def batch_mse_rel_fn(x1, x2):
    """Computes the relative error between two batches of signals while
    preserving the batch dimension (one value per batch element).

    The error is whatever ``per_element_rel_mse_fn`` returns for the pair.

    Args:
        x1 (torch.Tensor): Shape (batch_size, *).
        x2 (torch.Tensor): Shape (batch_size, *).
    Returns:
        Relative error tensor of shape (batch_size,).
    """
    batch = x1.shape[0]
    relative_error = per_element_rel_mse_fn(x1, x2)
    # Average whatever remains per batch element; shape (batch_size,)
    return relative_error.view(batch, -1).mean(dim=1)


def batch_mse_fn(x1, x2):
    """Computes MSE between two batches of signals while preserving the batch
    dimension (per batch element MSE).
    Args:
        x1 (torch.Tensor): Shape (batch_size, *).
        x2 (torch.Tensor): Shape (batch_size, *).
    Returns:
        MSE tensor of shape (batch_size,).
    """
    # Shape (batch_size, *)
    per_element_mse = per_element_mse_fn(x1, x2)
    # Shape (batch_size,). reshape (rather than view) also handles the case
    # where the element-wise loss is not contiguous in memory.
    return per_element_mse.reshape(x1.shape[0], -1).mean(dim=1)


def batch_nll_fn(x1, x2):
    """Computes the binary negative log-likelihood between two batches while
    preserving the batch dimension (one mean value per batch element).
    Args:
        x1 (torch.Tensor): Predictions of shape (batch_size, *).
        x2 (torch.Tensor): Targets of shape (batch_size, *).
    Returns:
        Loss tensor of shape (batch_size,).
    """
    # Shape (batch_size, *)
    per_element_nll = per_element_nll_fn(x1, x2)
    # Shape (batch_size,). reshape (rather than view) also handles the case
    # where the element-wise loss is not contiguous in memory.
    return per_element_nll.reshape(x1.shape[0], -1).mean(dim=1)


def mse2psnr(mse):
    """Computes PSNR from MSE, assuming the MSE was calculated between signals
    lying in [0, 1].
    Args:
        mse (torch.Tensor or float): Mean squared error value(s).
    Returns:
        torch.Tensor: ``-10 * log10(mse)``, with the same shape as ``mse``
        (a 0-dim tensor for a Python number).
    """
    # torch.log10 only accepts tensors; as_tensor lets plain Python numbers
    # through as documented and returns tensor inputs unchanged.
    return -10.0 * torch.log10(torch.as_tensor(mse))


def psnr_fn(x1, x2):
    """Computes PSNR between signals x1 and x2 from their mean squared error.

    The values of x1 and x2 are assumed to lie in [0, 1].
    Args:
        x1 (torch.Tensor): Shape (*).
        x2 (torch.Tensor): Shape (*).
    """
    mean_squared_error = mse_fn(x1, x2)
    return mse2psnr(mean_squared_error)


# from fno


# loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute and relative Lp losses computed per batch element.

    Calling an instance returns the relative loss (see ``rel``).

    Args:
        d (int): Dimension used in the mesh scaling of ``abs``; must be
            positive.
        p (int or float): Order of the norm; must be positive.
        size_average (bool): When reducing, take the mean over the batch if
            True, otherwise the sum.
        reduction (bool): If False, return one value per batch element instead
            of a reduced scalar.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super().__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_example):
        """Apply the configured batch reduction to per-example values."""
        if not self.reduction:
            return per_example
        if self.size_average:
            return torch.mean(per_example)
        return torch.sum(per_example)

    def abs(self, x, y):
        """Absolute Lp norm of ``x - y``, scaled by the mesh spacing.

        Args:
            x (torch.Tensor): Shape (batch_size, n, *). The second dimension
                is used as the number of mesh points per axis.
            y (torch.Tensor): Same number of elements per batch entry as ``x``.

        Returns:
            torch.Tensor: Scalar if ``reduction`` is set, else (batch_size,).
        """
        num_examples = x.size()[0]

        # Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        # reshape (as in `rel`) instead of view so that non-contiguous inputs
        # are accepted as well.
        all_norms = (h ** (self.d / self.p)) * torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1
        )

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Lp norm of ``x - y`` divided by the Lp norm of ``y``.

        Args:
            x (torch.Tensor): Shape (batch_size, *).
            y (torch.Tensor): Same number of elements per batch entry as ``x``.

        Returns:
            torch.Tensor: Scalar if ``reduction`` is set, else (batch_size,).
        """
        num_examples = x.size()[0]

        diff_norms = torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1
        )
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)

        return self._reduce(diff_norms / y_norms)

    def __call__(self, x, y):
        return self.rel(x, y)


In [3]:
#@title metalearning.py

from collections import OrderedDict
from functools import partial

import einops
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP

# adapted from https://github.com/EmilienDupont/coinpp/blob/main/coinpp/metalearning.py

def inner_loop(
    func_rep,
    modulations,
    coordinates,
    features,
    inner_steps,
    inner_lr,
    is_train=False,
    gradient_checkpointing=False,
    loss_type="mse",
):
    """Runs the inner loop: repeatedly updates the modulations with
    ``inner_loop_step`` so the function representation fits the target
    features.

    Args:
        func_rep (models.ModulatedSiren): Model evaluated in each step.
        modulations (torch.Tensor): Shape (batch_size, latent_dim). Here the
            modulations are the latents themselves.
        coordinates (torch.Tensor): Coordinates at which the function
            representation is evaluated. Shape
            (batch_size, *, coordinate_dim).
        features (torch.Tensor): Target features for the model to match.
            Shape (batch_size, *, feature_dim).
        inner_steps (int): Number of inner loop steps to take.
        inner_lr (float): Learning rate for the inner loop.
        is_train (bool): Forwarded to every inner step.
        gradient_checkpointing (bool): If True, each step is run through
            ``torch.utils.checkpoint``, which can reduce memory consumption.
        loss_type (str): Loss selector forwarded to every inner step.

    Returns:
        The modulations after ``inner_steps`` updates (the input itself when
        ``inner_steps`` is 0).
    """
    current = modulations

    for _ in range(inner_steps):
        if not gradient_checkpointing:
            current = inner_loop_step(
                func_rep,
                current,
                coordinates,
                features,
                inner_lr,
                is_train,
                gradient_checkpointing,
                loss_type,
            )
            continue

        # Checkpointed variant: the scalar and flag arguments are handed over
        # as tensors.
        current = cp.checkpoint(
            inner_loop_step,
            func_rep,
            current,
            coordinates,
            features,
            torch.as_tensor(inner_lr),
            torch.as_tensor(is_train),
            torch.as_tensor(gradient_checkpointing),
            loss_type,
        )

    return current


def inner_loop_step(
    func_rep,
    modulations,
    coordinates,
    features,
    inner_lr,
    is_train=False,
    gradient_checkpointing=False,
    loss_type="mse",
):
    """Performs a single inner loop step (one gradient step on the latents).

    Args:
        func_rep: Model exposing ``modulated_forward(coordinates, modulations)``.
        modulations (torch.Tensor): Current latents; must be part of the
            autograd graph since the loss is differentiated with respect to
            them.
        coordinates (torch.Tensor): Coordinates passed to the model.
        features (torch.Tensor): Targets; ``len(features)`` is used as the
            batch size.
        inner_lr (float or torch.Tensor): Step size of the update.
        is_train (bool): If True (and not detached, see below), the gradient
            is computed with ``create_graph=True``.
        gradient_checkpointing (bool): Whether the step runs under gradient
            checkpointing.
        loss_type (str): "mse", "bce" (negative log-likelihood on
            probabilities) or "multiscale-<name>", where ``<name>`` is passed
            to ``per_element_multi_scale_fn`` as ``loss_name``.

    Returns:
        torch.Tensor: ``modulations - inner_lr * grad``.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported values.
    """
    if loss_type == "mse":
        element_loss_fn = per_element_mse_fn
    elif loss_type == "bce":
        element_loss_fn = per_element_nll_fn
    elif "multiscale" in loss_type:
        parts = loss_type.split("-")
        if len(parts) < 2:
            raise ValueError(
                f"Multiscale loss_type {loss_type!r} must look like "
                "'multiscale-<loss_name>'."
            )
        element_loss_fn = partial(
            per_element_multi_scale_fn,
            loss_name=parts[1],
            last_element=False,
        )
    else:
        # Fail with an explicit message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; expected 'mse', 'bce' or "
            "'multiscale-<loss_name>'."
        )

    # No graph is kept when this step runs under checkpointing while gradient
    # tracking is globally disabled.
    detach = not torch.is_grad_enabled() and gradient_checkpointing
    batch_size = len(features)

    with torch.enable_grad():
        features_recon = func_rep.modulated_forward(coordinates, modulations)

        # Note we multiply by batch size here to undo the averaging across batch
        # elements from the loss function. Indeed, each set of modulations is
        # fit independently and the size of the gradient should not depend on
        # how many elements are in the batch
        loss = element_loss_fn(features_recon, features).mean() * batch_size

        # If we are training, we should create graph since we will need this to
        # compute second order gradients in the MAML outer loop
        grad = torch.autograd.grad(
            loss,
            modulations,
            create_graph=is_train and not detach,
        )[0]
    # Perform single gradient descent step ON THE LATENTS
    return modulations - inner_lr * grad


def outer_step(
    func_rep,
    coordinates,
    features,
    inner_steps,
    inner_lr,
    is_train=False,
    return_reconstructions=False,
    gradient_checkpointing=False,
    loss_type="mse",
    modulations=0,
    use_rel_loss=False,
):
    """Fits the latents with the inner loop, then evaluates the fitted model.

    Args:
        func_rep: Model exposing ``modulated_forward``; a
            DistributedDataParallel wrapper is unwrapped to its ``module``.
        coordinates (torch.Tensor): Shape (batch_size, *, coordinate_dim). Note this
            _must_ have a batch dimension.
        features (torch.Tensor): Shape (batch_size, *, feature_dim). Note this _must_
            have a batch dimension.
        inner_steps (int): Number of inner loop steps.
        inner_lr (float or torch.Tensor): Inner loop learning rate.
        is_train (bool): Whether gradients are tracked for the outer loss and
            second-order graphs are built in the inner loop.
        return_reconstructions (bool): If True, add the reconstructions to the
            returned dictionary.
        gradient_checkpointing (bool): Forwarded to the inner loop.
        loss_type (str): "mse", "bce" or "multiscale-<loss_name>".
        modulations (torch.Tensor): Initial latents. A tensor must be passed;
            the default value is only a placeholder.
        use_rel_loss (bool): If True, also report the relative loss.

    Returns:
        dict: "loss" (scalar tensor), "psnr" (float), "modulations" (fitted
        latents), plus "reconstructions" and "rel_loss" when requested.

    Raises:
        ValueError: If ``loss_type`` is not supported.
        TypeError: If ``modulations`` is not a tensor.
    """
    if loss_type == "mse":
        loss_fn = batch_mse_fn
    elif loss_type == "bce":
        loss_fn = batch_nll_fn
    elif "multiscale" in loss_type:
        parts = loss_type.split("-")
        if len(parts) < 2:
            raise ValueError(
                f"Multiscale loss_type {loss_type!r} must look like "
                "'multiscale-<loss_name>'."
            )
        loss_fn = partial(batch_multi_scale_fn, loss_name=parts[1])
    else:
        # Fail with an explicit message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; expected 'mse', 'bce' or "
            "'multiscale-<loss_name>'."
        )

    if not torch.is_tensor(modulations):
        # The placeholder default cannot be differentiated.
        raise TypeError(
            "modulations must be a torch.Tensor holding the initial latents, "
            f"got {type(modulations).__name__}."
        )

    func_rep.zero_grad()
    if isinstance(func_rep, DDP):
        func_rep = func_rep.module

    # The inner loop differentiates the loss with respect to the latents.
    modulations = modulations.requires_grad_()

    feat = features.clone()
    coords = coordinates.clone()

    # Run inner loop. HERE MODULATIONS IS LATENT Z
    modulations = inner_loop(
        func_rep,
        modulations,
        coords,
        feat,
        inner_steps,
        inner_lr,
        is_train,
        gradient_checkpointing,
        loss_type,
    )

    with torch.set_grad_enabled(is_train):
        features_recon = func_rep.modulated_forward(coordinates, modulations)
        per_example_loss = loss_fn(features_recon, features)
        loss = per_example_loss.mean()

    outputs = {
        "loss": loss,
        "psnr": mse2psnr(per_example_loss).mean().item(),
        "modulations": modulations,
    }

    is_multiscale = "multiscale" in loss_type

    if return_reconstructions:
        outputs["reconstructions"] = (
            features_recon[-1] if is_multiscale else features_recon
        )

    if use_rel_loss:
        # For multiscale losses only the last output is compared.
        recon = features_recon[-1] if is_multiscale else features_recon
        outputs["rel_loss"] = batch_mse_rel_fn(recon, features).mean()

    return outputs

In [4]:
#@title training logic

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm


def train_inr(
    inr,
    train_loader,
    test_loader,
    latent_dim,
    inner_steps,
    inner_lr,
    lr_code,
    meta_lr_code,
    weight_decay_code,
    lr_inr,
    gamma_step,
    ntrain,
    ntest,
    epochs,
    saved_checkpoint,
    checkpoint,
    RESULTS_DIR,
    run_name,
    device,
):
    """
    Train an INR model using the meta-learning framework with inner and outer loops.

    The inner-loop learning rate ``alpha`` is itself a learned parameter; it is
    initialized from ``lr_code`` and is the value passed to ``outer_step``.

    Args:
        inr (nn.Module): The modulated SIREN model.
        train_loader (DataLoader): DataLoader for training data. Batches are
            (images, modulations, coords, idx) tuples.
        test_loader (DataLoader): DataLoader for testing data (same format).
        latent_dim (int): Dimension of latent codes. Not used in this function.
        inner_steps (int): Number of inner-loop steps.
        inner_lr (float): Learning rate for inner-loop optimization. Not used
            in this function; the learned ``alpha`` is used instead.
        lr_code (float): Initial learning rate for alpha.
        meta_lr_code (float): Learning rate for meta-updates to alpha.
        weight_decay_code (float): Weight decay for alpha optimization.
        lr_inr (float): Learning rate for INR parameters.
        gamma_step (float): Learning rate decay factor for scheduler.
        ntrain (int): Number of training samples.
        ntest (int): Number of testing samples.
        epochs (int): Number of epochs.
        saved_checkpoint (bool): Whether to load from a saved checkpoint.
        checkpoint (dict): The saved checkpoint data. Expected keys are "inr"
            (a state dict or a whole module), "epoch", "alpha" and "loss";
            "optimizer_inr" and "cfg" are used when present.
        RESULTS_DIR (str): Directory to save results.
        run_name (str): Name for the saved model and logs.
        device (torch.device): Device for training (e.g., 'cuda' or 'cpu').

    Returns:
        None
    """
    # Initialize alpha parameter
    alpha = nn.Parameter(torch.Tensor([lr_code]).to(device))

    # Set optimizer
    optimizer = optim.AdamW(
        [
            {"params": inr.parameters(), "lr": lr_inr},
            {"params": alpha, "lr": meta_lr_code, "weight_decay": weight_decay_code},
        ],
        lr=lr_inr,
        weight_decay=0,
    )

    # Load from checkpoint if available
    if saved_checkpoint:
        # Checkpoints written by this function store the whole module under
        # "inr"; accept that as well as a plain state dict.
        inr_state = checkpoint['inr']
        if isinstance(inr_state, nn.Module):
            inr_state = inr_state.state_dict()
        inr.load_state_dict(inr_state)
        if 'optimizer_inr' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_inr'])
        epoch_start = checkpoint['epoch']
        # Copy the stored value into the parameter the optimizer already
        # tracks. Rebinding `alpha` to the loaded object would leave the
        # optimizer updating a parameter that is no longer used.
        with torch.no_grad():
            alpha.copy_(
                torch.as_tensor(checkpoint['alpha']).to(device).reshape(alpha.shape)
            )
        best_loss = checkpoint['loss']
        cfg = checkpoint.get('cfg')
        print("epoch_start, alpha, best_loss", epoch_start, alpha.item(), best_loss)
        print("cfg : ", cfg)
    else:
        epoch_start = 0
        best_loss = np.inf

    # Initialize scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=gamma_step,
        patience=500,
        threshold=0.01,
        threshold_mode="rel",
        cooldown=0,
        min_lr=1e-5,
        eps=1e-08,
        verbose=True,
    )

    # Training loop
    for step in range(epoch_start, epochs):

        rel_train_mse = 0
        rel_test_mse = 0
        fit_train_mse = 0
        fit_test_mse = 0
        # The relative loss is only computed every 10th epoch.
        use_rel_loss = step % 10 == 0
        step_show = step % 1 == 0
        step_show_last = step == epochs - 1

        inr.train()
        for substep, (images, modulations, coords, idx) in enumerate(train_loader):
            images = images.to(device)
            modulations = modulations.to(device)
            coords = coords.to(device)
            n_samples = images.shape[0]

            # The latents always start from zero; the loaded modulations only
            # provide shape, dtype and device.
            outputs = outer_step(
                inr,
                coords,
                images,
                inner_steps,
                alpha,
                is_train=True,
                return_reconstructions=False,
                gradient_checkpointing=False,
                use_rel_loss=use_rel_loss,
                loss_type="mse",
                modulations=torch.zeros_like(modulations),
            )

            optimizer.zero_grad()
            outputs["loss"].backward(create_graph=False)
            nn.utils.clip_grad_value_(inr.parameters(), clip_value=1.0)
            optimizer.step()
            loss = outputs["loss"].cpu().detach()
            # Weight by batch size so the epoch value is a per-sample average.
            fit_train_mse += loss.item() * n_samples

            # MLP regression
            if use_rel_loss:
                rel_train_mse += outputs["rel_loss"].item() * n_samples

        train_loss = fit_train_mse / ntrain

        scheduler.step(train_loss)

        if use_rel_loss:
            # Computed for inspection; not reported below.
            rel_train_loss = rel_train_mse / ntrain

        # evaluating on test
        if step_show or step_show_last:
            inr.eval()
            for images, modulations, coords, idx in test_loader:
                images = images.to(device)
                modulations = modulations.to(device)
                coords = coords.to(device)
                n_samples = images.shape[0]

                outputs = outer_step(
                    inr,
                    coords,
                    images,
                    inner_steps,
                    alpha,
                    is_train=False,
                    return_reconstructions=False,
                    gradient_checkpointing=False,
                    use_rel_loss=use_rel_loss,
                    loss_type="mse",
                    modulations=torch.zeros_like(modulations),
                )

                loss = outputs["loss"]
                fit_test_mse += loss.item() * n_samples

                if use_rel_loss:
                    rel_test_mse += outputs["rel_loss"].item() * n_samples

            test_loss = fit_test_mse / ntest

            if use_rel_loss:
                # Computed for inspection; not reported below.
                rel_test_loss = rel_test_mse / ntest

            print(f'epoch = {step}; train_loss = {train_loss}; test_loss={test_loss}')

        # Save the best model (selected on the training loss)
        if train_loss < best_loss:
            best_loss = train_loss
            print('save model')
            torch.save(
                {
                    "epoch": step,
                    # The whole module is stored (not just its state dict).
                    "inr": inr,
                    "loss": best_loss,
                    "alpha": alpha,
                    # Stored so that training can be resumed from this file.
                    "optimizer_inr": optimizer.state_dict(),
                },
                f"{RESULTS_DIR}/{run_name}.pt",
            )

In [5]:
import torch
from torch.utils.data import DataLoader, Dataset

# Define a synthetic dataset
class SyntheticDataset(Dataset):
    """Random-coordinate dataset whose targets are a sine of the coordinate sum.

    Every item is the tuple (image, modulation, coordinates, index tensor).

    Args:
        num_samples (int): Number of samples in the dataset.
        num_points (int): Number of coordinate points per sample.
        input_dim (int): Size of each coordinate.
        output_dim (int): Stored as an attribute only; the generated targets
            have a trailing dimension of size 1.
        latent_dim (int): Size of the zero-initialized modulation vectors.
    """

    def __init__(self, num_samples, num_points, input_dim=2, output_dim=1, latent_dim=16):
        self.num_samples = num_samples
        self.num_points = num_points
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim

        # Uniform random coordinates in [0, 1)
        coords = torch.rand(num_samples, num_points, input_dim)
        # Target: sin(2 * pi * sum of the coordinates), keeping a last dim of 1
        coord_sum = coords.sum(dim=-1, keepdim=True)
        self.coordinates = coords
        self.images = torch.sin(2 * torch.pi * coord_sum)
        # One all-zero modulation vector per sample
        self.modulations = torch.zeros(num_samples, self.latent_dim)

        print(self.coordinates.shape, self.images.shape, self.modulations.shape)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        index_tensor = torch.tensor([idx])
        return (
            self.images[idx],
            self.modulations[idx],
            self.coordinates[idx],
            index_tensor,
        )
    
    
class FunctionDataset(Dataset):
    """Dataset pairing given coordinates and features with zero modulations.

    Every item is the tuple (features, modulation, coordinates, index tensor).

    Args:
        coordinates (torch.Tensor): Coordinates; the first dimension indexes
            the samples.
        features (torch.Tensor): Features, indexed the same way.
        latent_dim (int): Size of the zero-initialized modulation vectors.
    """

    def __init__(self, coordinates, features, latent_dim):
        self.coordinates = coordinates
        self.features = features
        self.latent_dim = latent_dim

        # One all-zero modulation vector per sample
        sample_count = coordinates.shape[0]
        self.modulations = torch.zeros(sample_count, latent_dim)

    def __len__(self):
        return self.coordinates.shape[0]

    def __getitem__(self, idx):
        index_tensor = torch.tensor([idx])
        return (
            self.features[idx],
            self.modulations[idx],
            self.coordinates[idx],
            index_tensor,
        )

In [6]:
# now train an MLP from latent_a to latent_u

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm


class LatentCodeMLP(nn.Module):
    def __init__(self, latent_dim, hidden_dim=128, num_layers=3):
        """
        A basic MLP to map one latent code to another.

        Linear layers are separated by SiLU activations; no activation follows
        the last layer.

        Args:
            latent_dim (int): Dimension of the input and output latent codes.
            hidden_dim (int): Dimension of the hidden layers.
            num_layers (int): Number of layers in the MLP (including input and output).
                With 1, the model is a single linear map from latent_dim to
                latent_dim. With 0 or less, the model has no layers and
                returns its input unchanged.
        """
        super().__init__()

        layers = []
        if num_layers >= 1:
            # Layer widths from input to output. Both ends are latent_dim, so
            # the output size is correct even for a single layer.
            dims = [latent_dim] + [hidden_dim] * (num_layers - 1) + [latent_dim]
            for i, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
                layers.append(nn.Linear(dim_in, dim_out))
                if i != num_layers - 1:
                    layers.append(nn.SiLU())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


def train_latent_code_mlp(model, dataset_a_train, dataset_u_train, dataset_a_test, dataset_u_test, latent_dim, epochs, lr, batch_size, device):
    """
    Fit ``model`` to regress the latent codes of dataset `u` from those of
    dataset `a`, using an MSE loss and the Adam optimizer.

    Args:
        model (nn.Module): The MLP model.
        dataset_a_train (torch.Tensor): Training latent codes from dataset `a` (shape: [num_train_samples, latent_dim]).
        dataset_u_train (torch.Tensor): Training latent codes from dataset `u` (shape: [num_train_samples, latent_dim]).
        dataset_a_test (torch.Tensor): Testing latent codes from dataset `a` (shape: [num_test_samples, latent_dim]).
        dataset_u_test (torch.Tensor): Testing latent codes from dataset `u` (shape: [num_test_samples, latent_dim]).
        latent_dim (int): Dimension of the latent codes. Not used by this function.
        epochs (int): Number of training epochs.
        lr (float): Learning rate.
        batch_size (int): Batch size for training and testing.
        device (torch.device): Device to use for training (CPU or GPU).

    Returns:
        nn.Module: Trained MLP model.
        list: Training losses per epoch.
        list: Testing losses per epoch.
    """
    # Paired (a, u) datasets; both loaders iterate in a fixed order.
    train_pairs = torch.utils.data.TensorDataset(dataset_a_train, dataset_u_train)
    test_pairs = torch.utils.data.TensorDataset(dataset_a_test, dataset_u_test)
    train_loader = torch.utils.data.DataLoader(train_pairs, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_pairs, batch_size=batch_size, shuffle=False)

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses, test_losses = [], []

    for epoch in range(epochs):
        # ---- optimisation pass over the training pairs ----
        model.train()
        running_train = 0.0
        for source_codes, target_codes in train_loader:
            source_codes = source_codes.to(device)
            target_codes = target_codes.to(device)

            optimizer.zero_grad()
            loss = criterion(model(source_codes), target_codes)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample mean.
            running_train += loss.item() * source_codes.size(0)

        train_loss = running_train / len(train_loader.dataset)
        train_losses.append(train_loss)

        # ---- evaluation pass over the test pairs (no gradients) ----
        model.eval()
        running_test = 0.0
        with torch.no_grad():
            for source_codes, target_codes in test_loader:
                source_codes = source_codes.to(device)
                target_codes = target_codes.to(device)

                loss = criterion(model(source_codes), target_codes)
                running_test += loss.item() * source_codes.size(0)

        test_loss = running_test / len(test_loader.dataset)
        test_losses.append(test_loss)

        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f}")

    return model, train_losses, test_losses

In [7]:
def ExtractEncoding(model_path, dataset):
    """Compute the modulation codes of every sample in ``dataset``.

    Loads the checkpoint stored at ``model_path``, reads its ``'inr'`` and
    ``'alpha'`` entries, and runs ``outer_step`` (3 inner steps, MSE loss)
    on each batch of 16 samples, always starting from all-zero modulations.
    The ``'modulations'`` entry of every result is collected.

    Args:
        model_path: Path (or file-like object) handed to ``torch.load``.
            The checkpoint must be a mapping with ``'inr'`` and ``'alpha'``.
        dataset: Dataset whose batches unpack as
            ``(images, modulations, coords, idx)``.

    Returns:
        torch.Tensor: Per-batch modulations concatenated along dim 0, in
        dataset order (the loader does not shuffle), detached from autograd.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # host (and vice versa) instead of failing inside torch.load.
    # NOTE(review): torch.load unpickles arbitrary objects; only pass
    # checkpoints from a trusted source.
    best_model = torch.load(model_path, map_location=device)

    # Keep the network on the same device the batches are moved to below.
    inr = best_model['inr'].to(device)
    alpha = best_model['alpha']
    inner_steps = 3

    train_loader = DataLoader(dataset, batch_size=16, shuffle=False)

    encodings = []

    for images, modulations, coords, _ in train_loader:
        inr.train()
        images = images.to(device)
        modulations = modulations.to(device)
        coords = coords.to(device)

        outputs = outer_step(
            inr,
            coords,
            images,
            inner_steps,
            alpha,
            is_train=True,
            return_reconstructions=False,
            gradient_checkpointing=False,
            use_rel_loss=False,
            loss_type="mse",
            # The stored modulations only provide shape/dtype/device; the
            # adaptation always starts from zeros.
            modulations=torch.zeros_like(modulations),
        )

        # Detach immediately so this batch's autograd graph can be freed
        # instead of being kept alive until the end of the loop.
        encodings.append(outputs['modulations'].detach())

    # Concatenate all the encodings along the sample dimension.
    encodings_tensor = torch.cat(encodings, dim=0)

    # Drop the references to the checkpoint and network before returning,
    # for memory purposes.
    del best_model, inr

    gc.collect()

    return encodings_tensor

In [10]:
def mask_generator(N, shape='box'):
    """Build an ``N x N`` mask of zeros and ones for a named shape.

    Args:
        N (int): Side length of the square mask.
        shape (str): Which region is set to 1:

            * ``'box'``: everything except the one-pixel outer border.
            * ``'torus1'``: disc of radius ``N / 2.2`` around
              ``(N // 2, N // 2)`` with a central hole of radius ``N / 6``.
            * ``'torus2'``: the same disc with two holes of radius
              ``N / 8`` centred on rows ``N // 4`` and ``3 * N // 4``
              (column ``N // 2``).
            * ``'triangle'``: ``N - 5`` rows starting at row 0, row ``y``
              spanning ``y // 2`` pixels either side of column ``N // 2``.
            * ``'bean'``: union of discs of radius 11.5 pixels centred on
              50 samples of the parabola ``1.2 * x**2 - 0.2`` for
              ``x`` in ``[-0.6, 0.6]``, mapped to pixel coordinates.

            Any other value leaves the mask entirely zero.

    Returns:
        np.ndarray: Array of shape ``(N, N)`` (default float dtype of
        ``np.zeros``) containing only 0 and 1.
    """
    mask = np.zeros((N, N))

    # Open (broadcastable) row/column index grids, built once and shared by
    # every shape that tests pixel distances.
    rows, cols = np.ogrid[:N, :N]
    mid = N // 2

    if shape == 'box':
        mask[1:-1, 1:-1] = 1

    elif shape == 'torus1':
        dist_sq = (rows - mid) ** 2 + (cols - mid) ** 2
        mask[dist_sq < (N / 2.2) ** 2] = 1
        # Carve the hole after filling the disc so the hole always wins.
        mask[dist_sq < (N / 6) ** 2] = 0

    elif shape == 'torus2':
        col_sq = (cols - mid) ** 2
        mask[(rows - mid) ** 2 + col_sq < (N / 2.2) ** 2] = 1
        # Two holes stacked along the row axis, both on the centre column.
        for hole_row in (N // 4, 3 * N // 4):
            mask[(rows - hole_row) ** 2 + col_sq < (N / 8) ** 2] = 0

    elif shape == 'triangle':
        triangle_height = N - 5

        # The apex sits at row 0; each row widens by one pixel per side
        # every second row. A non-positive height draws nothing.
        for y in range(triangle_height):
            half_width = y // 2
            mask[y, mid - half_width:mid + half_width + 1] = 1

    elif shape == 'bean':
        num_points = 50
        a, b, c = 1.2, 0, -0.2
        circle_radius = 11.5  # in pixels; does not scale with N

        # Sample the quadratic curve in normalized coordinates.
        x_vals = np.linspace(-0.6, 0.6, num_points)
        y_vals = a * x_vals**2 + b * x_vals + c

        # Map [-1, 1] to pixel indices; y is flipped because row indices
        # grow downwards.
        x_pixel = ((x_vals + 1) * (N // 2)).astype(int)
        y_pixel = ((1 - y_vals) * (N // 2)).astype(int)

        # Stamp a filled disc at every sampled point of the curve.
        radius_sq = circle_radius ** 2
        for x, y in zip(x_pixel, y_pixel):
            mask[(cols - x) ** 2 + (rows - y) ** 2 <= radius_sq] = 1

    return mask