In [1]:
import pickle
import numpy as np

In [9]:
with open('donut_data.pkl', 'rb') as file:
    data_list = pickle.load(file)
    
with open('donut_data_coord_grid.pkl', 'rb') as file:
    coord_grid = pickle.load(file)

In [3]:
N_data = len(data_list)
# split data 
data_train, data_test, data_val = data_list[0:1000], data_list[1000:1100], data_list[1100:1200]

In [11]:
coord_grid

array([[[0.        , 0.        ],
        [0.        , 0.01538462],
        [0.        , 0.03076923],
        ...,
        [0.        , 0.96923077],
        [0.        , 0.98461538],
        [0.        , 1.        ]],

       [[0.01538462, 0.        ],
        [0.01538462, 0.01538462],
        [0.01538462, 0.03076923],
        ...,
        [0.01538462, 0.96923077],
        [0.01538462, 0.98461538],
        [0.01538462, 1.        ]],

       [[0.03076923, 0.        ],
        [0.03076923, 0.01538462],
        [0.03076923, 0.03076923],
        ...,
        [0.03076923, 0.96923077],
        [0.03076923, 0.98461538],
        [0.03076923, 1.        ]],

       ...,

       [[0.96923077, 0.        ],
        [0.96923077, 0.01538462],
        [0.96923077, 0.03076923],
        ...,
        [0.96923077, 0.96923077],
        [0.96923077, 0.98461538],
        [0.96923077, 1.        ]],

       [[0.98461538, 0.        ],
        [0.98461538, 0.01538462],
        [0.98461538, 0.03076923],
        .

In [5]:
# Based on https://github.com/EmilienDupont/coin
from math import sqrt

import einops
import torch
from torch import nn
import numpy as np

#from coral.utils.interpolate import knn_interpolate_custom, rescale_coordinate


class Sine(nn.Module):
    """Sine activation with scaling.

    Args:
        w0 (float): Omega_0 parameter from SIREN paper.
    """
    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


class SirenLayer(nn.Module):
    """Implements a single SIREN layer.

    Args:
        dim_in (int): Dimension of input.
        dim_out (int): Dimension of output.
        w0 (float):
        c (float): c value from SIREN paper used for weight initialization.
        is_first (bool): Whether this is first layer of model.
        is_last (bool): Whether this is last layer of model. If it is, no
            activation is applied and 0.5 is added to the output. Since we
            assume all training data lies in [0, 1], this allows for centering
            the output of the model.
        use_bias (bool): Whether to learn bias in linear layer.
        activation (torch.nn.Module): Activation function. If None, defaults to
            Sine activation.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        is_last=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first
        self.is_last = is_last

        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)

        # Initialize layers following SIREN paper
        w_std = (1 / dim_in) if self.is_first else (sqrt(c / dim_in) / w0)
        nn.init.uniform_(self.linear.weight, -w_std, w_std)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -w_std, w_std)

        self.activation = Sine(w0) if activation is None else activation

    def forward(self, x):
        out = self.linear(x)
        if self.is_last:
            # We assume target data is in [0, 1], so adding 0.5 allows us to learn
            # zero-centered features
            out += 0
        else:
            out = self.activation(out)
        return out


class Siren(nn.Module):
    """SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out
        self.num_layers = num_layers

        layers = []
        for ind in range(num_layers - 1):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            layers.append(
                SirenLayer(
                    dim_in=layer_dim_in,
                    dim_out=dim_hidden,
                    w0=layer_w0,
                    use_bias=use_bias,
                    is_first=is_first,
                )
            )

        self.net = nn.Sequential(*layers)

        self.last_layer = SirenLayer(
            dim_in=dim_hidden, dim_out=dim_out, w0=w0, use_bias=use_bias, is_last=True
        )

    def forward(self, x):
        """Forward pass of SIREN model.

        Args:
            x (torch.Tensor): Tensor of shape (*, dim_in), where * means any
                number of dimensions.

        Returns:
            Tensor of shape (*, dim_out).
        """
        x = self.net(x)
        return self.last_layer(x)


class ModulatedSiren(Siren):
    """Modulated SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
        modulate_scale (bool): Whether to modulate with scales.
        modulate_shift (bool): Whether to modulate with shifts.
        use_latent (bool): If true, use a latent vector which is mapped to
            modulations, otherwise use modulations directly.
        latent_dim (int): Dimension of latent vector.
        modulation_net_dim_hidden (int): Number of hidden dimensions of
            modulation network.
        modulation_net_num_layers (int): Number of layers in modulation network.
            If this is set to 1 will correspond to a linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
        modulate_scale=False,
        modulate_shift=True,
        use_latent=False,
        latent_dim=64,
        modulation_net_dim_hidden=64,
        modulation_net_num_layers=1,
        mu=0,
        sigma=1,
        last_activation=None,
    ):
        super().__init__(
            dim_in,
            dim_hidden,
            dim_out,
            num_layers,
            w0,
            w0_initial,
            use_bias,
        )
        # Must modulate at least one of scale and shift
        assert modulate_scale or modulate_shift

        self.modulate_scale = modulate_scale
        self.modulate_shift = modulate_shift
        self.w0 = w0
        self.w0_initial = w0_initial
        self.mu = mu
        self.sigma = sigma
        self.last_activation = (
            nn.Identity() if last_activation is None else last_activation
        )

        # We modulate features at every *hidden* layer of the base network and
        # therefore have dim_hidden * (num_layers - 1) modulations, since the
        # last layer is not modulated
        num_modulations = dim_hidden * (num_layers - 1)
        if self.modulate_scale and self.modulate_shift:
            # If we modulate both scale and shift, we have twice the number of
            # modulations at every layer and feature
            num_modulations *= 2

        if use_latent:
            self.modulation_net = LatentToModulation(
                latent_dim,
                num_modulations,
                modulation_net_dim_hidden,
                modulation_net_num_layers,
            )
        else:
            self.modulation_net = Bias(num_modulations)

        # Initialize scales to 1 and shifts to 0 (i.e. the identity)
        if not use_latent:
            if self.modulate_shift and self.modulate_scale:
                self.modulation_net.bias.data = torch.cat(
                    (
                        torch.ones(num_modulations // 2),
                        torch.zeros(num_modulations // 2),
                    ),
                    dim=0,
                )
            elif self.modulate_scale:
                self.modulation_net.bias.data = torch.ones(num_modulations)
            else:
                self.modulation_net.bias.data = torch.zeros(num_modulations)

        self.num_modulations = num_modulations

    def modulated_forward(self, x, latent):
        """Forward pass of modulated SIREN model.

        Args:
            x (torch.Tensor): Shape (batch_size, *, dim_in), where * refers to
                any spatial dimensions, e.g. (height, width), (height * width,)
                or (depth, height, width) etc.
            latent (torch.Tensor): Shape (batch_size, latent_dim). If
                use_latent=False, then latent_dim = num_modulations.

        Returns:
            Output features of shape (batch_size, *, dim_out).
        """
        # Extract batch_size and spatial dims of x, so we can reshape output
        x_shape = x.shape[:-1]
        # Flatten all spatial dimensions, i.e. shape
        # (batch_size, *, dim_in) -> (batch_size, num_points, dim_in)

        x = x.view(x.shape[0], -1, x.shape[-1])

        # Shape (batch_size, num_modulations)
        modulations = self.modulation_net(latent)

        # Split modulations into shifts and scales and apply them to hidden
        # features.
        mid_idx = (
            self.num_modulations // 2
            if (self.modulate_scale and self.modulate_shift)
            else 0
        )
        idx = 0
        for module in self.net:
            if self.modulate_scale:
                # Shape (batch_size, 1, dim_hidden). Note that we add 1 so
                # modulations remain zero centered
                scale = modulations[:, idx: idx +
                                    self.dim_hidden].unsqueeze(1) + 1.0
            else:
                scale = 1.0

            if self.modulate_shift:
                # Shape (batch_size, 1, dim_hidden)
                shift = modulations[
                    :, mid_idx + idx: mid_idx + idx + self.dim_hidden
                ].unsqueeze(1)
            else:
                shift = 0.0

            x = module.linear(x)
            x = scale * x + shift  # Broadcast scale and shift across num_points
            x = module.activation(x)  # (batch_size, num_points, dim_hidden)

            idx = idx + self.dim_hidden

        # Shape (batch_size, num_points, dim_out)
        out = self.last_activation(self.last_layer(x))
        out = out * self.sigma + self.mu
        # Reshape (batch_size, num_points, dim_out) -> (batch_size, *, dim_out)
        return out.view(*x_shape, out.shape[-1])


class LatentToModulation(nn.Module):
    """Maps a latent vector to a set of modulations.
    Args:
        latent_dim (int):
        num_modulations (int):
        dim_hidden (int):
        num_layers (int):
    """

    def __init__(
        self, latent_dim, num_modulations, dim_hidden, num_layers, activation=nn.SiLU
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_modulations = num_modulations
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers
        self.activation = activation

        if num_layers == 1:
            self.net = nn.Linear(latent_dim, num_modulations)
        else:
            layers = [nn.Linear(latent_dim, dim_hidden), self.activation()]
            if num_layers > 2:
                for i in range(num_layers - 2):
                    layers += [nn.Linear(dim_hidden, dim_hidden),
                               self.activation()]
            layers += [nn.Linear(dim_hidden, num_modulations)]
            self.net = nn.Sequential(*layers)

    def forward(self, latent):
        return self.net(latent)


class Bias(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(size), requires_grad=True)
        # Add latent_dim attribute for compatibility with LatentToModulation model
        self.latent_dim = size

    def forward(self, x):
        return x + self.bias


class ConvLatentToModulation(nn.Module):
    """Maps a latent vector to a set of modulations.

    Args:
        latent_dim (int):
        num_modulations (int):
        dim_hidden (int):
        num_layers (int):
    """

    def __init__(
        self, input_dim, grid_size, latent_dim, num_modulations, kernel_size=1
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_modulations = num_modulations
        self.grid_size = grid_size

        if input_dim == 2:
            self.net = nn.Conv2d(latent_dim, num_modulations, kernel_size)
        elif input_dim == 1:
            self.net = nn.Conv1d(latent_dim, num_modulations, kernel_size)
            # self.coords = shape2coordinates([grid_size, grid_size])

    def forward(self, latent):
        return self.net(latent)


class ConvModulatedSiren(Siren):
    """Modulated SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
        modulate_scale (bool): Whether to modulate with scales.
        modulate_shift (bool): Whether to modulate with shifts.
        use_latent (bool): If true, use a latent vector which is mapped to
            modulations, otherwise use modulations directly.
        latent_dim (int): Dimension of latent vector.
        modulation_net_dim_hidden (int): Number of hidden dimensions of
            modulation network.
        modulation_net_num_layers (int): Number of layers in modulation network.
            If this is set to 1 will correspond to a linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
        modulate_scale=False,
        modulate_shift=True,
        use_latent=False,
        latent_dim=64,
        mu=0,
        sigma=1,
        last_activation=None,
        grid_size=8,
        interpolation="bilinear",
        conv_kernel=3,
    ):
        super().__init__(
            dim_in,
            dim_hidden,
            dim_out,
            num_layers,
            w0,
            w0_initial,
            use_bias,
        )
        # Must modulate at least one of scale and shift
        assert modulate_scale or modulate_shift

        self.modulate_scale = modulate_scale
        self.modulate_shift = modulate_shift
        self.w0 = w0
        self.w0_initial = w0_initial
        self.mu = mu
        self.sigma = sigma
        self.grid_size = grid_size
        self.interpolation = interpolation
        self.last_activation = (
            nn.Identity() if last_activation is None else last_activation
        )

        # We modulate features at every *hidden* layer of the base network and
        # therefore have dim_hidden * (num_layers - 1) modulations, since the
        # last layer is not modulated
        num_modulations = dim_hidden * (num_layers - 1)
        if self.modulate_scale and self.modulate_shift:
            # If we modulate both scale and shift, we have twice the number of
            # modulations at every layer and feature
            num_modulations *= 2

        self.modulation_net = ConvLatentToModulation(
            dim_in, grid_size, latent_dim, num_modulations, conv_kernel
        )

        # Initialize scales to 1 and shifts to 0 (i.e. the identity)
        if not use_latent:
            if self.modulate_shift and self.modulate_scale:
                self.modulation_net.bias.data = torch.cat(
                    (
                        torch.ones(num_modulations // 2),
                        torch.zeros(num_modulations // 2),
                    ),
                    dim=0,
                )
            elif self.modulate_scale:
                self.modulation_net.bias.data = torch.ones(num_modulations)
            else:
                self.modulation_net.bias.data = torch.zeros(num_modulations)

        self.num_modulations = num_modulations

        # self.mode = "linear" if dim_in == 1 else "bilinear"

    def modulated_forward(self, x, latent):
        """Forward pass of modulated SIREN model.

        Args:
            x (torch.Tensor): Shape (batch_size, *, dim_in), where * refers to
                any spatial dimensions, e.g. (height, width), (height * width,)
                or (depth, height, width) etc.
            latent (torch.Tensor): Shape (batch_size, latent_dim). If
                use_latent=False, then latent_dim = num_modulations.

        Returns:
            Output features of shape (batch_size, *, dim_out).
        """
        # Extract batch_size and spatial dims of x, so we can reshape output
        x_shape = x.shape[:-1]
        # Flatten all spatial dimensions, i.e. shape
        # (batch_size, *, dim_in) -> (batch_size, num_points, dim_in)

        # Shape (batch_size, num_modulations)
        modulations = self.modulation_net(latent)
        # print('mod1', modulations.shape)
        # v1 modulations = torch.nn.functional.grid_sample(modulations, x, mode='bilinear', padding_mode='zeros', align_corners=None)
        # v2 modulations = grid_sample(modulations, x)
        # vtest
        modulations = torch.nn.functional.interpolate(
            modulations, x.shape[1:-1], mode=self.interpolation
        )
        # print('mod2', modulations.shape)

        x = x.view(x.shape[0], -1, x.shape[-1])
        # modulations = modulations.permute(0, 2, 3, 1)

        # place the channel at the end
        modulations = torch.movedim(modulations, 1, -1)
        modulations = modulations.view(x.shape[0], -1, self.num_modulations)
        # print('mod', modulations.shape)

        # Split modulations into shifts and scales and apply them to hidden
        # features.
        mid_idx = (
            self.num_modulations // 2
            if (self.modulate_scale and self.modulate_shift)
            else 0
        )
        idx = 0
        for module in self.net:
            if self.modulate_scale:
                # Shape (batch_size, 1, dim_hidden). Note that we add 1 so
                # modulations remain zero centered
                scale = modulations[:, :, idx: idx + self.dim_hidden] + 1.0
            else:
                scale = 1.0

            if self.modulate_shift:
                # Shape (batch_size, 1, dim_hidden)
                shift = modulations[
                    :, :, mid_idx + idx: mid_idx + idx + self.dim_hidden
                ]
            else:
                shift = 0.0

            x = module.linear(x)
            x = scale * x + shift  # Broadcast scale and shift across num_points
            x = module.activation(x)  # (batch_size, num_points, dim_hidden)

            idx = idx + self.dim_hidden

        # Shape (batch_size, num_points, dim_out)
        out = self.last_activation(self.last_layer(x))
        # print('out', out.shape)
        # Reshape (batch_size, num_points, dim_out) -> (batch_size, *, dim_out)
        return out.view(*x_shape, out.shape[-1])


class ConvModulatedSiren2(Siren):
    """Modulated SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
        modulate_scale (bool): Whether to modulate with scales.
        modulate_shift (bool): Whether to modulate with shifts.
        use_latent (bool): If true, use a latent vector which is mapped to
            modulations, otherwise use modulations directly.
        latent_dim (int): Dimension of latent vector.
        modulation_net_dim_hidden (int): Number of hidden dimensions of
            modulation network.
        modulation_net_num_layers (int): Number of layers in modulation network.
            If this is set to 1 will correspond to a linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        w0_initial=30.0,
        use_bias=True,
        modulate_scale=False,
        modulate_shift=True,
        use_latent=False,
        latent_dim=64,
        mu=0,
        sigma=1,
        last_activation=None,
        grid_size=8,
        interpolation="bilinear",
        conv_kernel=3,
        rescale_coordinate=False,
    ):
        super().__init__(
            dim_in,
            dim_hidden,
            dim_out,
            num_layers,
            w0,
            w0_initial,
            use_bias,
        )
        # Must modulate at least one of scale and shift
        assert modulate_scale or modulate_shift

        self.modulate_scale = modulate_scale
        self.modulate_shift = modulate_shift
        self.w0 = w0
        self.w0_initial = w0_initial
        self.mu = mu
        self.sigma = sigma
        self.grid_size = grid_size
        self.interpolation = interpolation
        self.last_activation = (
            nn.Identity() if last_activation is None else last_activation
        )
        self.rescale_coordinate = rescale_coordinate

        # We modulate features at every *hidden* layer of the base network and
        # therefore have dim_hidden * (num_layers - 1) modulations, since the
        # last layer is not modulated
        num_modulations = dim_hidden * (num_layers - 1)
        if self.modulate_scale and self.modulate_shift:
            # If we modulate both scale and shift, we have twice the number of
            # modulations at every layer and feature
            num_modulations *= 2

        self.modulation_net = ConvLatentToModulation(
            dim_in, grid_size, latent_dim, num_modulations, conv_kernel
        )

        # Initialize scales to 1 and shifts to 0 (i.e. the identity)
        if not use_latent:
            if self.modulate_shift and self.modulate_scale:
                self.modulation_net.bias.data = torch.cat(
                    (
                        torch.ones(num_modulations // 2),
                        torch.zeros(num_modulations // 2),
                    ),
                    dim=0,
                )
            elif self.modulate_scale:
                self.modulation_net.bias.data = torch.ones(num_modulations)
            else:
                self.modulation_net.bias.data = torch.zeros(num_modulations)

        self.num_modulations = num_modulations

        # create modulation positions
        crds = []
        for i in range(dim_in):
            crds.append(torch.linspace(0.0, 1.0, grid_size))
        self.modulation_grid = torch.stack(
            torch.meshgrid(*crds, indexing="ij"), dim=-1)

    def modulated_forward(self, x, latent):
        """Forward pass of modulated SIREN model.

        Args:
            x (torch.Tensor): Shape (batch_size, *, dim_in), where * refers to
                any spatial dimensions, e.g. (height, width), (height * width,)
                or (depth, height, width) etc.
            latent (torch.Tensor): Shape (batch_size, latent_dim). If
                use_latent=False, then latent_dim = num_modulations.

        Returns:
            Output features of shape (batch_size, *, dim_out).
        """
        # Extract batch_size and spatial dims of x, so we can reshape output
        x_shape = x.shape[:-1]
        x = x.view(x.shape[0], -1, x.shape[-1])
        batch_size, num_points = x.shape[0], x.shape[1]
        # Flatten all spatial dimensions, i.e. shape
        # (batch_size, *, dim_in) -> (batch_size, num_points, dim_in)

        # Shape (batch_size, num_modulations)
        modulations = self.modulation_net(latent)

        # print('toto', modulations.shape, latent.shape)
        modulations = torch.movedim(modulations, 1, -1)

        modulations = modulations.view(
            modulations.shape[0], -1, modulations.shape[-1])

        mod_grid = self.modulation_grid.view(-1, x.shape[-1])
        mod_grid = (
            mod_grid.unsqueeze(0)
            .repeat(x_shape[0], *(1,) * self.modulation_grid.ndim)
            .cuda()
        )

        x_batch = (
            torch.arange(modulations.shape[0])
            .view(-1, 1)
            .repeat(1, modulations.shape[1])
            .flatten()
            .cuda()
        )
        y_batch = (
            torch.arange(x.shape[0]).view(-1, 1).repeat(1,
                                                        x.shape[1]).flatten().cuda()
        )

        # print('tototot')
        # print(einops.rearrange(modulations, 'b d c -> (b d) c').shape,
        #      einops.rearrange(mod_grid, 'b d c -> (b d) c').shape,
        #      x.view(-1, x.shape[-1]).shape,
        #      x_batch.shape,
        #      y_batch.shape)

        modulations = knn_interpolate_custom(
            einops.rearrange(modulations, "b d c -> (b d) c"),
            einops.rearrange(mod_grid, "b d c -> (b d) c"),
            x.view(-1, x.shape[-1]),
            batch_x=x_batch,
            batch_y=y_batch,
        )
        if self.rescale_coordinate:
            x = rescale_coordinate(
                self.modulation_grid.cuda(), x.view(-1, x.shape[-1]))

        x = x.view(batch_size, -1, x.shape[-1])
        modulations = modulations.view(batch_size, -1, modulations.shape[-1])

        # Split modulations into shifts and scales and apply them to hidden
        # features.
        mid_idx = (
            self.num_modulations // 2
            if (self.modulate_scale and self.modulate_shift)
            else 0
        )
        idx = 0
        for module in self.net:
            if self.modulate_scale:
                # Shape (batch_size, 1, dim_hidden). Note that we add 1 so
                # modulations remain zero centered
                scale = modulations[:, :, idx: idx + self.dim_hidden] + 1.0
            else:
                scale = 1.0

            if self.modulate_shift:
                # Shape (batch_size, 1, dim_hidden)
                shift = modulations[
                    :, :, mid_idx + idx: mid_idx + idx + self.dim_hidden
                ]
            else:
                shift = 0.0

            x = module.linear(x)
            x = scale * x + shift  # Broadcast scale and shift across num_points
            x = module.activation(x)  # (batch_size, num_points, dim_hidden)

            idx = idx + self.dim_hidden

        # Shape (batch_size, num_points, dim_out)
        out = self.last_activation(self.last_layer(x))
        # print('out', out.shape)
        # Reshape (batch_size, num_points, dim_out) -> (batch_size, *, dim_out)
        return out.view(*x_shape, out.shape[-1])


class SSN(nn.Module):
    """Simple Sine model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out
        self.num_layers = num_layers
        self.w0 = w0

        layers = [nn.Linear(dim_in, dim_hidden)]
        for j in range(self.num_layers-1):
            layers.append(nn.Linear(dim_hidden, dim_hidden))
        layers.append(nn.Linear(dim_hidden, dim_out))
        self.layers = nn.ModuleList(layers)
        
        self.init_weights()
                          
    def init_weights(self):
        for j, layer in enumerate(self.layers):
            if j == 0:
                nn.init.normal(layer.weight, 0,  np.sqrt(2) / np.sqrt(layer.weight.shape[1]))
                #nn.init.uniform(layer.weight, -self.w0 / layer.weight.shape[1], self.w0 / layer.weight.shape[1])
            else:
                nn.init.normal(layer.weight, 0, np.sqrt(2) / np.sqrt(layer.weight.shape[1]))
            print(layer.weight.std())
        
    def forward(self, x):
        for j, layer in enumerate(self.layers[:-1]):
            if j == 0:
                x = torch.sin(self.w0*layer(x))
            else:
                x = torch.sin(layer(x))
        out = self.layers[-1](x)
        return out 



class ModulatedSSN(nn.Module):
    """Modulated Sine model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega 0 from SIREN paper.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether to learn bias in linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=30.0,
        latent_dim=64,
        modulation_net_dim_hidden=64,
        modulation_net_num_layers=1,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out
        self.num_layers = num_layers
        self.w0 = w0
        self.latent_dim = latent_dim    

        layers = [nn.Linear(dim_in, dim_hidden)]
        for j in range(self.num_layers-1):
            layers.append(nn.Linear(dim_hidden, dim_hidden))
        layers.append(nn.Linear(dim_hidden, dim_out))
        self.layers = nn.ModuleList(layers)
        self.num_modulations = dim_hidden * len(self.layers[:-1])    

        self.modulation_net = LatentToModulation(
                latent_dim,
                self.num_modulations,
                modulation_net_dim_hidden,
                modulation_net_num_layers,
            )
        self.init_weights()
                          
    def init_weights(self):
        for j, layer in enumerate(self.layers):
            if j == 0:
                nn.init.normal(layer.weight, 0,  np.sqrt(2) / np.sqrt(layer.weight.shape[1]))
                #nn.init.uniform(layer.weight, -self.w0 / layer.weight.shape[1], self.w0 / layer.weight.shape[1])
            else:
                nn.init.normal(layer.weight, 0, np.sqrt(2) / np.sqrt(layer.weight.shape[1]))
            print(layer.weight.std())
        
    def modulated_forward(self, x, z):
        x_shape = x.shape[:-1]
        x = x.view(x.shape[0], -1, x.shape[-1])
        modulations = self.modulation_net(z)
        modulations = modulations.reshape(-1, self.latent_dim, len(self.layers[:-1])).unsqueeze(1)

        for j, layer in enumerate(self.layers[:-1]):
            if j == 0:
                x = torch.sin(self.w0*layer(x) + modulations[..., j])
            else:
                x = torch.sin(layer(x) + modulations[..., j])
        out = self.layers[-1](x)

        return out.view(*x_shape, out.shape[-1])

# GPT implementation of encoding

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

class FunctionDataset(Dataset):
    """
    Dataset to handle sampled input functions.
    Args:
        samples (torch.Tensor): Input coordinates of shape (num_samples, dim_in).
        values (torch.Tensor): Corresponding function values of shape (num_samples, dim_out).
    """
    def __init__(self, samples, values):
        super().__init__()
        self.samples = samples
        self.values = values

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, idx):
        return self.samples[idx], self.values[idx]

In [52]:
u_sample = data_train[0][0]

samples = []
values = []

for i in range(u_sample.shape[0]):
    for j in range(u_sample.shape[1]):
        samples.append(coord_grid[i, j])
        values.append(u_sample[i, j])
        
samples_tensor = torch.tensor(samples, dtype=torch.float)
values_tensor = torch.tensor(values, dtype=torch.float).unsqueeze(-1)

In [53]:
samples_tensor.shape, values_tensor.shape

(torch.Size([4356, 2]), torch.Size([4356, 1]))

In [27]:
class HyperNetwork(nn.Module):
    """
    Hyper-network that maps a latent code to modulation parameters.
    Args:
        latent_dim (int): Dimension of the latent code.
        num_modulations (int): Number of modulation parameters.
        hidden_dim (int): Dimension of hidden layers in the hyper-network.
        num_layers (int): Number of layers in the hyper-network.
    """
    def __init__(self, latent_dim, num_modulations, hidden_dim, num_layers):
        super().__init__()
        layers = [nn.Linear(latent_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, num_modulations))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)


class CORALEncoderHyper(nn.Module):
    """
    CORAL Encoder with hyper-network for auto-decoding.
    Args:
        input_dim (int): Dimension of the input space.
        hidden_dim (int): Dimension of hidden layers in the SIREN.
        latent_dim (int): Dimension of the latent code.
        output_dim (int): Dimension of the output space (function values).
        num_layers (int): Number of SIREN layers.
        w0 (float): Omega_0 value for sine activation.
        lr (float): Learning rate for optimization.
        hyper_hidden_dim (int): Hidden dimension in the hyper-network.
        hyper_layers (int): Number of layers in the hyper-network.
    """
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim, num_layers, w0=30.0, lr=1e-3,
                 hyper_hidden_dim=64, hyper_layers=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.siren = Siren(
            dim_in=input_dim,
            dim_hidden=hidden_dim,
            dim_out=output_dim,
            num_layers=num_layers,
            w0=w0
        )
        self.hyper_net = HyperNetwork(
            latent_dim=latent_dim,
            num_modulations=hidden_dim * (num_layers - 1),  # Number of modulations for hidden layers
            hidden_dim=hyper_hidden_dim,
            num_layers=hyper_layers
        )
        self.latent_code = nn.Parameter(torch.zeros(latent_dim), requires_grad=True)
        self.lr = lr

    def forward(self, x):
        """
        Forward pass through the INR modulated by parameters from the hyper-network.
        Args:
            x (torch.Tensor): Input coordinates of shape (batch_size, input_dim).
        Returns:
            torch.Tensor: Reconstructed function values of shape (batch_size, output_dim).
        """
        modulations = self.hyper_net(self.latent_code)
        return self.siren_with_modulation(x, modulations)

    def siren_with_modulation(self, x, modulations):
        """
        Modulate the SIREN with the hyper-network's output.
        Args:
            x (torch.Tensor): Input coordinates.
            modulations (torch.Tensor): Modulation parameters from the hyper-network.
        Returns:
            torch.Tensor: Output of the SIREN.
        """
        idx = 0
        for layer in self.siren.net:
            x = layer.linear(x)
            x = x * modulations[idx: idx + layer.dim_out] + 1.0  # Scale modulation
            x = layer.activation(x)
            idx += layer.dim_out
        return self.siren.last_layer(x)

    def encode(self, dataset, num_epochs=500, batch_size=128):
        """
        Optimize the latent code to encode the input function.
        Args:
            dataset (Dataset): Dataset containing input coordinates and function values.
            num_epochs (int): Number of optimization steps.
            batch_size (int): Batch size for training.
        Returns:
            torch.Tensor: Optimized latent code representing the input function.
        """
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam([self.latent_code], lr=self.lr)
        loss_fn = nn.MSELoss()

        for epoch in range(num_epochs):
            total_loss = 0
            for coords, values in dataloader:
                optimizer.zero_grad()
                preds = self.forward(coords)
                loss = loss_fn(preds, values)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}")

        return self.latent_code.detach()

In [54]:
# Define input function (e.g., a simple 2D sine wave)
#coords = torch.rand(1000, 2)  # Input coordinates in [0, 1]^2
#values = torch.sin(2 * torch.pi * coords[:, 0]) * torch.cos(2 * torch.pi * coords[:, 1])  # Function values
#values = values.unsqueeze(-1)  # Reshape to (num_samples, output_dim)

# Create dataset
dataset = FunctionDataset(samples_tensor, values_tensor)

# Initialize CORAL encoder with hyper-network
coral_encoder = CORALEncoderHyper(
    input_dim=2, hidden_dim=64, latent_dim=32, output_dim=1, num_layers=4, w0=30.0, lr=1e-3
)

# Optimize latent code
latent_code = coral_encoder.encode(dataset)

# Print the learned latent code
print("Optimized Latent Code:", latent_code)

Epoch 50/500, Loss: 631.8710
Epoch 100/500, Loss: 667.1519
Epoch 150/500, Loss: 630.5511
Epoch 200/500, Loss: 629.5670
Epoch 250/500, Loss: 630.9309
Epoch 300/500, Loss: 626.2505
Epoch 350/500, Loss: 628.1308
Epoch 400/500, Loss: 625.9124
Epoch 450/500, Loss: 625.8239
Epoch 500/500, Loss: 631.9530
Optimized Latent Code: tensor([  7.9867,   6.7610,  -6.0202,   3.7552,  -9.0773,   3.3845,   2.6099,
         -0.0196,   7.4714,   0.3657,   2.3088,   9.7401,   6.3893,   7.2908,
          4.9538,  -1.6567,  -4.6801,  -6.8735,   6.2920,   4.6968,  -3.1676,
          7.4630,   4.2138,  -4.1387,  -0.4714,  -4.6254, -11.6401,  -1.9965,
          4.1924,  -0.9238,   2.9957,  -2.4327])


In [46]:
samples_tensor.shape, values_tensor.shape

(torch.Size([4356, 2]), torch.Size([4356, 1]))

In [49]:
coords.shape, values.shape

(torch.Size([1000, 2]), torch.Size([1000, 1]))

In [55]:
class CORALTrainer:
    def __init__(self, model, latent_dim, learning_rate=1e-3, lambda_=1e-2):
        """
        Initialize the CORAL Trainer.
        Args:
            model (ModulatedSSN): The shared INR model.
            latent_dim (int): Dimension of the latent code.
            learning_rate (float): Learning rate for outer loop updates.
            lambda_ (float): Scaling factor for outer updates.
        """
        self.model = model
        self.latent_dim = latent_dim
        self.learning_rate = learning_rate
        self.lambda_ = lambda_
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def train(self, dataset, num_epochs=100, batch_size=16, inner_steps=5, inner_lr=1e-2):
        """
        Train the model following Algorithm 1.
        Args:
            dataset (Dataset): Dataset containing input functions.
            num_epochs (int): Number of training epochs.
            batch_size (int): Number of functions per batch.
            inner_steps (int): Number of inner-loop optimization steps.
            inner_lr (float): Learning rate for the inner loop.
        """
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss_fn = nn.MSELoss()

        for epoch in range(num_epochs):
            for batch in dataloader:
                coords, values = batch

                # Initialize latent codes for all functions in the batch
                batch_latent_codes = torch.zeros(len(values), self.latent_dim, requires_grad=True)
                inner_optimizer = optim.SGD([batch_latent_codes], lr=inner_lr)

                # Inner Encoding Loop: Optimize latent codes
                for _ in range(inner_steps):
                    inner_optimizer.zero_grad()
                    total_loss = 0
                    for i, (coord, value) in enumerate(zip(coords, values)):
                        # Predict function values using the INR and latent code
                        pred = self.model.modulated_forward(coord, batch_latent_codes[i])
                        loss = loss_fn(pred, value)
                        loss.backward()
                        total_loss += loss.item()
                    inner_optimizer.step()

                # Outer Loop: Update shared parameters
                self.optimizer.zero_grad()
                batch_loss = 0
                for i, (coord, value) in enumerate(zip(coords, values)):
                    pred = self.model.modulated_forward(coord, batch_latent_codes[i])
                    loss = loss_fn(pred, value)
                    batch_loss += loss
                batch_loss.backward()
                self.optimizer.step()

            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {batch_loss.item():.4f}")


In [57]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Define the dataset
class SineDataset(Dataset):
    def __init__(self, num_functions=100, num_points=50, noise_level=0.1):
        """
        Initialize the sine wave dataset.
        Args:
            num_functions (int): Number of sine wave functions in the dataset.
            num_points (int): Number of points per sine wave.
            noise_level (float): Standard deviation of noise added to the data.
        """
        self.num_functions = num_functions
        self.num_points = num_points
        self.noise_level = noise_level
        self.data = []
        for _ in range(num_functions):
            # Random amplitude and phase for sine waves
            amplitude = np.random.uniform(0.5, 2.0)
            phase = np.random.uniform(0, np.pi)
            x = np.linspace(-np.pi, np.pi, num_points)
            y = amplitude * np.sin(x + phase) + np.random.normal(0, noise_level, num_points)
            self.data.append((x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return torch.tensor(x, dtype=torch.float32).unsqueeze(-1), torch.tensor(y, dtype=torch.float32)

# Define the ModulatedSSN model
class ModulatedSSN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, latent_dim):
        """
        Initialize the modulated SSN model.
        Args:
            input_dim (int): Dimensionality of the input.
            hidden_dim (int): Number of hidden units in the MLP.
            output_dim (int): Dimensionality of the output.
            latent_dim (int): Dimensionality of the latent code.
        """
        super(ModulatedSSN, self).__init__()
        self.latent_to_modulation = nn.Linear(latent_dim, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def modulated_forward(self, coords, latent_code):
        """
        Forward pass with modulation.
        Args:
            coords (torch.Tensor): Input coordinates (e.g., x values) of shape [B, D].
            latent_code (torch.Tensor): Latent code for modulation of shape [L].
        """
        modulation = self.latent_to_modulation(latent_code)  # Shape: [hidden_dim]
        modulation = modulation.unsqueeze(0).expand(coords.shape[0], -1)  # Shape: [B, hidden_dim]
        modulated_input = torch.cat([coords, modulation], dim=-1)  # Shape: [B, D + hidden_dim]
        return self.mlp(modulated_input)


# Hyperparameters
input_dim = 1
hidden_dim = 128
output_dim = 1
latent_dim = 16
num_epochs = 100
batch_size = 16
inner_steps = 5
inner_lr = 1e-2
outer_lr = 1e-3

# Create the dataset and model
dataset = SineDataset(num_functions=100, num_points=50)
model = ModulatedSSN(input_dim, hidden_dim, output_dim, latent_dim)

# Create the trainer
trainer = CORALTrainer(model=model, latent_dim=latent_dim, learning_rate=outer_lr)

# Train the model
trainer.train(
    dataset=dataset,
    num_epochs=num_epochs,
    batch_size=batch_size,
    inner_steps=inner_steps,
    inner_lr=inner_lr
)


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/100, Loss: 2.9854
Epoch 2/100, Loss: 4.2023
Epoch 3/100, Loss: 4.7112
Epoch 4/100, Loss: 3.9826
Epoch 5/100, Loss: 2.8686
Epoch 6/100, Loss: 1.7769
Epoch 7/100, Loss: 2.2375
Epoch 8/100, Loss: 3.6624
Epoch 9/100, Loss: 4.1444
Epoch 10/100, Loss: 4.4604
Epoch 11/100, Loss: 4.4054
Epoch 12/100, Loss: 4.0973
Epoch 13/100, Loss: 4.6462
Epoch 14/100, Loss: 4.4411
Epoch 15/100, Loss: 4.2909
Epoch 16/100, Loss: 1.4430
Epoch 17/100, Loss: 5.3307
Epoch 18/100, Loss: 2.2648
Epoch 19/100, Loss: 5.7765
Epoch 20/100, Loss: 4.2309
Epoch 21/100, Loss: 6.2891
Epoch 22/100, Loss: 3.7459
Epoch 23/100, Loss: 3.9010
Epoch 24/100, Loss: 2.9079
Epoch 25/100, Loss: 1.5306
Epoch 26/100, Loss: 3.1234
Epoch 27/100, Loss: 2.7318
Epoch 28/100, Loss: 6.1208
Epoch 29/100, Loss: 3.7955
Epoch 30/100, Loss: 5.9055
Epoch 31/100, Loss: 3.2382
Epoch 32/100, Loss: 3.8503
Epoch 33/100, Loss: 3.1335
Epoch 34/100, Loss: 3.3741
Epoch 35/100, Loss: 3.2778
Epoch 36/100, Loss: 3.5445
Epoch 37/100, Loss: 3.5104
Epoch 38/1

In [59]:
trainer

<__main__.CORALTrainer at 0x20bd7ab7790>