In [1]:
import torch
from torch import nn

In [5]:
class PeriodicPadding2D(nn.Module):
    """Wrap-around padding on the last axis plus zero padding on the second-to-last.

    Along the last axis, the trailing ``pad_width`` columns are prepended and the
    leading ``pad_width`` columns are appended (periodic boundary). Along the
    second-to-last axis (the "lat" direction) ``pad_width`` rows of zeros are
    added on each side. Indexing uses four axes, so a 4-D tensor is expected.

    Args:
        pad_width: number of elements added on each side of both padded axes;
            when it equals 0 the input is returned unchanged (same object).
    """

    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def forward(self, inputs, **kwargs):
        p = self.pad_width
        # A zero width must short-circuit: ``inputs[..., -0:]`` would select the
        # whole axis rather than nothing.
        if p == 0:
            return inputs
        tail_cols = inputs[:, :, :, -p:]  # wrapped to the front
        head_cols = inputs[:, :, :, :p]   # wrapped to the back
        wrapped = torch.cat((tail_cols, inputs, head_cols), dim=-1)
        # Zero padding in the lat direction (second-to-last axis only).
        return nn.functional.pad(wrapped, (0, 0, p, p))


class PeriodicConv2D(nn.Module):
    """``nn.Conv2d`` applied after ``PeriodicPadding2D``.

    The padding module performs all padding (periodic on the last axis, zeros
    on the second-to-last), so the wrapped convolution itself uses
    ``padding=0``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: width handed to ``PeriodicPadding2D``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, **kwargs
    ):
        super().__init__(**kwargs)
        self.padding = PeriodicPadding2D(padding)
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=0)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, inputs):
        padded = self.padding(inputs)
        return self.conv(padded)

class ResidualBlock(nn.Module):
    """Two periodic 3x3 convolution stages plus a shortcut connection.

    Each stage computes ``drop(norm(activation(conv(h))))``. The block returns
    ``stage2(stage1(x)) + shortcut(x)``. The shortcut is a 1x1 ``nn.Conv2d`` when
    ``in_channels != out_channels`` and ``nn.Identity`` otherwise.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions and the shortcut.
        activation: ``"gelu"``, ``"relu"``, ``"silu"`` or ``"leaky"``
            (``nn.LeakyReLU`` with negative slope 0.3).
        norm: if true, ``nn.BatchNorm2d`` follows each activation; otherwise
            ``nn.Identity``.
        dropout: probability for the ``nn.Dropout`` ending each stage.
        n_groups: accepted but not used by this block.

    Raises:
        NotImplementedError: if ``activation`` is not one of the names above.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: str = "leaky",
        norm: bool = False,
        dropout: float = 0.1,
        n_groups: int = 1,
    ):
        super().__init__()
        activation_factories = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "silu": nn.SiLU,
            "leaky": lambda: nn.LeakyReLU(0.3),
        }
        if activation not in activation_factories:
            raise NotImplementedError(f"Activation {activation} not implemented")
        self.activation = activation_factories[activation]()

        self.conv1 = PeriodicConv2D(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = PeriodicConv2D(out_channels, out_channels, kernel_size=3, padding=1)
        # The skip path needs a learned 1x1 projection only when the channel
        # count changes; otherwise the input is added back unchanged.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
            if in_channels != out_channels
            else nn.Identity()
        )

        def build_norm():
            return nn.BatchNorm2d(out_channels) if norm else nn.Identity()

        # Two independent instances: each stage keeps its own BatchNorm statistics.
        self.norm1 = build_norm()
        self.norm2 = build_norm()

        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # Post-activation ordering: conv -> activation -> norm -> dropout, twice.
        h = x
        for conv, norm_layer in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            h = self.drop(norm_layer(self.activation(conv(h))))
        return h + self.shortcut(x)

In [15]:
# Seed first: weight init, dropout masks and the dummy input below are all
# random, so Restart & Run All must reproduce the same numbers.
SEED = 42
torch.manual_seed(SEED)

in_channels = 3        # channels per time step of the input
out_channels = 3       # channels produced by the final head
history = 1            # time steps T; [B, T, C, H, W] inputs are flattened to T*C channels
hidden_channels = 128  # width of the residual trunk
activation = "leaky"   # one of "gelu", "relu", "silu", "leaky" (see ResidualBlock)
norm = True            # use BatchNorm2d normalisation layers (see the model-construction cell)
dropout = 0.1          # dropout probability inside each residual block
n_blocks = 4           # number of residual blocks


In [28]:
# Model pieces: stem -> n_blocks residual blocks -> norm -> activation -> head.

# `norm` arrives as the boolean config flag, but this cell rebinds it to the final
# normalisation layer (the forward-pass cell calls `norm(x)`). Recover the flag from
# either form so re-running this cell cannot silently turn an Identity into BatchNorm.
use_norm = norm if isinstance(norm, bool) else isinstance(norm, nn.BatchNorm2d)

# Stem: [B, T, C, H, W] inputs are flattened to T * C channels before projection,
# so the stem must accept history * in_channels channels.
image_proj = PeriodicConv2D(
    history * in_channels, hidden_channels, kernel_size=7, padding=3
)
blocks = nn.ModuleList(
    [
        ResidualBlock(
            hidden_channels,
            hidden_channels,
            activation=activation,
            norm=use_norm,  # honour the `norm` config flag
            dropout=dropout,
        )
        for _ in range(n_blocks)
    ]
)

norm = nn.BatchNorm2d(hidden_channels) if use_norm else nn.Identity()
final = PeriodicConv2D(
    hidden_channels, out_channels, kernel_size=7, padding=3
)

In [20]:
# Dummy batch whose channel count matches the stem (history * in_channels) instead
# of a hard-coded 3; last axis is the periodic one, second-to-last is "lat".
batch_size, n_lat, n_lon = 1, 64, 64
x = torch.randn(batch_size, history * in_channels, n_lat, n_lon)

In [21]:
# A 5-D input [B, T, C, H, W] folds its time axis into channels: [B, T*C, H, W].
if x.ndim == 5:
    x = x.flatten(start_dim=1, end_dim=2)


In [22]:
# Stem projection to hidden_channels. NOTE: this rebinds `x`, so re-running the
# cell feeds hidden features back into the stem; re-run from the input cell.
stem_out = image_proj(x)
x = stem_out

In [23]:
# Shape after the stem (spatial grid preserved by the 7x7 conv with padding 3).
x.size()

torch.Size([1, 128, 64, 64])

In [24]:
# Run the residual trunk, printing the feature shape after every block.
for residual_block in blocks:
    x = residual_block(x)
    print(x.shape)

torch.Size([1, 128, 64, 64])
torch.Size([1, 128, 64, 64])


In [29]:
# Head: final norm -> ReLU -> periodic 7x7 conv down to out_channels.
# NOTE(review): ReLU is hard-coded here although the blocks use the configured
# `activation`; confirm this mismatch is intended.
yhat = final(nn.functional.relu(norm(x)))


In [31]:
# End-to-end contract: same batch and spatial grid as the trunk, out_channels channels.
assert yhat.shape == (x.shape[0], out_channels, *x.shape[-2:]), yhat.shape
yhat.shape

torch.Size([1, 3, 64, 64])