In [3]:
import torch
from torch import nn


In [48]:
x = torch.tensor([1, 2, 3])
y = torch.tensor([4,5])
z= torch.tensor([6])

In [54]:
torch.stack(torch.meshgrid([x]), dim=-1).shape

torch.Size([3, 1])

In [27]:
def build_nD_grid(x_context, x_target, points_per_unit, grid_multiplier, num_dims):
    n_out = x_target.shape[1]

    x_mins = []
    x_maxs = []
    num_points = []
    x_grids = []
    for i in range(num_dims):
        d = i + 2 
        # Determine the grid on which to evaluate functional representation.
        x_min = min(torch.min(x_context[d]).cpu().numpy(),
                    torch.min(x_target[d]).cpu().numpy(), -2.) - 0.1
        x_max = max(torch.max(x_context[d]).cpu().numpy(),
                    torch.max(x_target[d]).cpu().numpy(), 2.) + 0.1
        n = int(to_multiple(points_per_unit * (x_max - x_min),
                                    grid_multiplier))
        # update the lists
        x_mins.append(x_min)
        x_maxs.append(x_max)
        num_points.append(n)
        
        # compute the x_grid
        x_grids.append(torch.linspace(x_min, x_max, num_points).to(x_context.device))

    x_grid = torch.cartesian_prod(*x_grids)
    x_grid = x_grid[None, :, :].repeat(x_context.shape[0], 1, 1)
    
    
    return x_grid, num_points

In [33]:
c = torch.tensor([[1., 2.],
                [-1, 1],
                [0,1]])

In [34]:
c.shape

torch.Size([3, 2])

In [37]:
torch.linalg.norm(c, dim = 1)

tensor([2.2361, 1.4142, 1.0000])

In [39]:
a = torch.rand(2,3,4,5)

In [40]:
a.shape

torch.Size([2, 3, 4, 5])

In [43]:
torch.linalg.norm(a, dim=-1).shape

torch.Size([2, 3, 4])

In [47]:
b = torch.tensor([1,2,3,4])
b_dims = [2,2]
b.view(*b_dims)

tensor([[1, 2],
        [3, 4]])

In [3]:
x = torch.randn(2, 3, 5)

In [5]:
order = [2,0,1]
x.permute(order).shape

torch.Size([5, 2, 3])

In [6]:
x.shape

torch.Size([2, 3, 5])

In [17]:
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_dims):
        if num_dims == 1:
            convf = nn.Conv1d
        elif num_dims == 2:
            convf = nn.Conv2d
        elif num_dims == 3:
            convf = nn.Conv3d
        else:
            raise ValueError('Number of dimensions > 3 not supported')

        assert kernel_size % 2 == 1
        kernel_size = [kernel_size] * num_dims
        padding = [k // 2 for k in kernel_size]

        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = convf(in_channels, 
                               in_channels, 
                               kernel_size=kernel_size, 
                               padding=padding, 
                               groups=in_channels)

        self.pointwise = convf(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

In [21]:
conv_layer_1d = DepthwiseSeparableConv(3, 5, 5, 1)
conv_layer_2d = DepthwiseSeparableConv(3, 5, 5, 2)
conv_layer_3d = DepthwiseSeparableConv(3, 5, 5, 3)


In [24]:
x = torch.randn(100, 3, 64, 64, 32)

In [26]:
conv_layer_3d(x).shape

torch.Size([100, 5, 64, 64, 32])

In [38]:
class StandardDepthwiseSeparableCNN(nn.Module):
    def __init__(self, in_channels, out_channels, num_dims):
        # Default parameters
        latent_channels = 32
        kernel_size = 5
        num_layers = 12
        super().__init__()
        
        layers = [DepthwiseSeparableConv(in_channels, latent_channels, kernel_size, num_dims), nn.ReLU()]
        for i in range(num_layers - 2):
            layers.append(DepthwiseSeparableConv(latent_channels, latent_channels, kernel_size, num_dims))
            layers.append(nn.ReLU()) 
        layers.append(DepthwiseSeparableConv(latent_channels, out_channels, kernel_size, num_dims))

        self.conv_net = nn.Sequential(*layers)
        # init_sequential_weights(self.conv_net)

    def forward(self, x):
        return self.conv_net(x)

In [42]:
conv_net_1d = StandardDepthwiseSeparableCNN(3, 5, 1)
conv_net_2d = StandardDepthwiseSeparableCNN(3, 5, 2)
conv_net_3d = StandardDepthwiseSeparableCNN(3, 5, 3)

In [43]:
x = torch.randn(100, 3, 64)

In [44]:
conv_net_1d(x).shape

torch.Size([100, 5, 64])

In [46]:
x = torch.randn(100, 3, 64, 32)
conv_net_2d(x).shape

torch.Size([100, 5, 64, 32])

In [47]:
x = torch.randn(100, 3, 64, 32, 8)
conv_net_3d(x).shape

torch.Size([100, 5, 64, 32, 8])

In [60]:
x_context = torch.randn(100, 2, 3)
a, b, c = x_context.shape
x_context = x_context.view(a, b, *([1] * c), c)


In [58]:
[:, :, (None,) * x_context.shape[-1], :]

SyntaxError: invalid syntax (<ipython-input-58-162e408123ff>, line 1)

In [61]:
x_context.shape

torch.Size([100, 2, 1, 1, 1, 3])

In [62]:
nn_f = nn.Linear(4, 5)

In [63]:
x = torch.randn(100, 3, 64, 32, 4)


In [65]:
nn_f(x).shape

torch.Size([100, 3, 64, 32, 5])

In [70]:
def move_channel_idx(x, to_last, num_dims):
    if to_last:
        perm_idx = [0] + [i + 2 for i in range(num_dims)] + [1]
    else:
        perm_idx = [0, num_dims + 1] + [i + 1 for i in range(num_dims)]
    
    return x.permute(perm_idx)

In [73]:
x = torch.randn(100, 3, 32, 32, 32)
y = move_channel_idx(x, True, 3)
move_channel_idx(y, False, 3).shape

[0, 2, 3, 4, 1]
[0, 4, 1, 2, 3]


torch.Size([100, 3, 32, 32, 32])