In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch import Tensor

In [2]:
class ConvLayer(nn.Module):
    """Conv2d -> GroupNorm (2 groups) -> SiLU.

    Padding is ``(kernel_size - stride + 1) // 2``. With an odd kernel, a stride-1
    layer keeps H and W unchanged and a stride-2 layer halves them (rounded up).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced; GroupNorm(2, ...) requires an even count.
        kernel_size: square convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - stride + 1) // 2,
        )
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class ConvTransposeLayer(nn.Module):
    """2x upsampling block: Conv2d (stride 1) -> PixelShuffle(2) -> GroupNorm -> SiLU.

    Despite the name, no transposed convolution is used. The convolution emits
    ``4 * out_channels`` channels, and ``PixelShuffle(2)`` rearranges them into
    ``out_channels`` channels at twice the height and width.

    NOTE(review): ``stride`` only enters the padding formula. The convolution
    always runs with stride 1, and the upsampling factor is fixed at 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = (kernel_size - stride + 1) // 2
        expand = nn.Conv2d(
            in_channels,
            4 * out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=pad,
        )
        self.conv = nn.Sequential(expand, nn.PixelShuffle(2))
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        upsampled = self.conv(x)
        return self.act(self.norm(upsampled))


class Conv(nn.Module):
    """Depthwise-separable convolution: per-channel (depthwise) conv, then a 1x1 (pointwise) conv.

    Args:
        in_channels: input channels; also the group count of the depthwise conv.
        out_channels: channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        stride: stride of the depthwise conv.
        padding: padding of the depthwise conv, forwarded to ``nn.Conv2d``
            (an int, or a string such as ``"same"``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups=in_channels: each input channel is convolved with its own filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

In [3]:
class SelfAttentionMemory(nn.Module):
    """Self-attention memory (SAM) update computed at full spatial resolution.

    Every spatial position attends to every other position, so the attention maps
    are ``(H*W) x (H*W)`` per batch element. The hidden state ``h`` attends to
    itself, and the same hidden-state query attends to the memory ``m``. The two
    results are fused and drive a gated update of ``m``; the new hidden state is
    the output gate times the new memory.

    NOTE(review): a later cell in this notebook redefines ``SelfAttentionMemory``
    with a patch-based version. Once that cell runs, this definition is shadowed.

    Args:
        input_dim: channels of ``h`` and ``m``.
        hidden_dim: channels of the query/key projections.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()

        # Projections for attention over the hidden state.
        self.query_h = Conv(input_dim, hidden_dim, 1, padding="same")
        self.key_h = Conv(input_dim, hidden_dim, 1, padding="same")
        self.value_h = Conv(input_dim, input_dim, 1, padding="same")
        self.z_h = Conv(input_dim, input_dim, 1, padding="same")

        # Projections for attention over the memory (the query comes from h).
        self.key_m = Conv(input_dim, hidden_dim, 1, padding="same")
        self.value_m = Conv(input_dim, input_dim, 1, padding="same")
        self.z_m = Conv(input_dim, input_dim, 1, padding="same")

        # Fuses the channel-concatenation [Z_h, Z_m].
        self.w_z = Conv(input_dim * 2, input_dim * 2, 1, padding="same")

        # Produces the three gate pre-activations from [Z, h].
        self.w = Conv(input_dim * 3, input_dim * 3, 1, padding="same")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def _attend(self, query, key, value, height, width):
        """Softmax(query @ key) applied to value, reshaped back to a feature map.

        Shapes: query (B, HW, hidden_dim), key (B, hidden_dim, HW),
        value (B, input_dim, HW); result (B, input_dim, height, width).
        """
        weights = torch.softmax(torch.bmm(query, key), dim=-1)  # (B, HW, HW)
        mixed = torch.matmul(weights, value.permute(0, 2, 1))   # (B, HW, input_dim)
        return mixed.transpose(1, 2).view(query.size(0), self.input_dim, height, width)

    def forward(self, h, m):
        """
        Return:
            Tuple(torch.Tensor, torch.Tensor): new hidden state and new memory.
        """
        batch_size, _, height, width = h.shape
        positions = height * width

        q_h = self.query_h(h).view(batch_size, self.hidden_dim, positions).transpose(1, 2)
        k_h = self.key_h(h).view(batch_size, self.hidden_dim, positions)
        v_h = self.value_h(h).view(batch_size, self.input_dim, positions)
        z_h = self.z_h(self._attend(q_h, k_h, v_h, height, width))

        k_m = self.key_m(m).view(batch_size, self.hidden_dim, positions)
        v_m = self.value_m(m).view(batch_size, self.input_dim, positions)
        z_m = self.z_m(self._attend(q_h, k_m, v_m, height, width))

        fused = self.w_z(torch.cat([z_h, z_m], dim=1))
        gates = self.w(torch.cat([fused, h], dim=1))
        # Input, candidate and output gate pre-activations, input_dim channels each.
        mi_conv, mg_conv, mo_conv = torch.chunk(gates, chunks=3, dim=1)

        input_gate = torch.sigmoid(mi_conv)
        new_M = (1 - input_gate) * m + input_gate * torch.tanh(mg_conv)
        new_H = torch.sigmoid(mo_conv) * new_M
        return new_H, new_M


In [3]:

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    Input ``(B, C, H, W)`` becomes ``(B, (H // p) * (W // p), emb_size)``, where
    ``p = patch_size``. ``H`` and ``W`` must be divisible by ``patch_size`` for
    the ``Rearrange`` pattern to apply.

    Args:
        in_channels: channels ``C`` of the input image.
        patch_size: side length ``p`` of each square patch.
        emb_size: size of each patch embedding.
    """

    def __init__(self, in_channels=5, patch_size=16, emb_size=128):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # Cut the image into p x p patches and flatten each patch to a vector.
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.projection(x)

class PatchUnembedding(nn.Module):
    """Inverse of patch embedding: project each token back to a patch and tile the patches.

    Input ``(B, g * g, emb_size)`` becomes
    ``(B, in_channels, g * patch_size, g * patch_size)``, where
    ``g = img_size // patch_size``. The same grid size is used for height and
    width, so only square images are supported.
    """

    def __init__(self, in_channels=5, patch_size=16, emb_size=128, img_size=160):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        self.emb_size = emb_size
        self.in_channels = in_channels

        grid = img_size // patch_size
        to_pixels = nn.Linear(emb_size, patch_size * patch_size * in_channels)
        to_image = Rearrange(
            'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            p1=patch_size, p2=patch_size, h=grid, w=grid,
        )
        self.reconstruction = nn.Sequential(to_pixels, to_image)

    def forward(self, x: Tensor) -> Tensor:
        return self.reconstruction(x)

class Attention(nn.Module):
    """Multi-head attention over a token sequence, with an optional externally supplied query.

    Inputs are batch-first ``(batch, tokens, dim)`` tensors. ``x`` is projected
    to keys and values. The query is ``self.q(x)`` unless ``q`` is given, in
    which case ``q`` is used as-is.

    Args:
        dim: token embedding size.
        n_heads: number of attention heads (must divide ``dim``).
        dropout: dropout on the attention weights.

    Returns (from ``forward``):
        ``(attn_output, q)``: the attention output ``(batch, tokens, dim)`` and the
        query that was used, so it can be reused to attend over another sequence.
    """

    def __init__(self, dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        # batch_first=True: inputs are (batch, tokens, dim). With the default
        # (batch_first=False), the first axis would be read as the sequence axis,
        # so different samples in the batch would attend to each other instead of
        # the tokens within one sample attending to each other.
        self.att = torch.nn.MultiheadAttention(embed_dim=dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)

    def forward(self, x, q=None):
        if q is None:
            q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # need_weights=False: the attention map is discarded, so don't compute it.
        attn_output, _ = self.att(q, k, v, need_weights=False)
        return attn_output, q


class SelfAttentionMemory(nn.Module):
    """Patch-based self-attention memory (SAM) update.

    ``h`` and ``m`` are split into patches and embedded, with learned positional
    embeddings added. ``h`` attends to itself, and its query is reused to attend
    over ``m``. Both results are mapped back to feature maps, fused, and used to
    gate an update of ``m``. The new hidden state is the output gate times the
    new memory.

    NOTE: this cell redefines ``SelfAttentionMemory``. Once executed, it shadows
    the earlier full-resolution definition.

    Args:
        image_size: height == width of ``h`` and ``m``; must be divisible by ``patch_size``.
        in_channels: channels of ``h`` and ``m``.
        patch_size: side length of the square patches.
        emb_size: token embedding size; must be divisible by ``n_heads``.
        n_heads: heads per attention block (default 4, the previously hard-coded value).
        dropout: attention dropout (default 0.1, the previously hard-coded value).

    Raises:
        ValueError: if ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, image_size: int, in_channels: int, patch_size: int, emb_size: int,
                 n_heads: int = 4, dropout: float = 0.1) -> None:
        super().__init__()
        # Fail at construction rather than deep inside the patch rearrangement at runtime.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )

        self.image_size = image_size
        self.input_dim = in_channels
        self.patch_size = patch_size
        self.emb_size = emb_size
        self.num_patches = (image_size // patch_size) ** 2

        self.patch_embed_h = PatchEmbedding(in_channels, patch_size, emb_size)
        self.patch_embed_m = PatchEmbedding(in_channels, patch_size, emb_size)

        # Learned positional embeddings, one vector per patch.
        self.pos_embedding_h = nn.Parameter(torch.randn(1, self.num_patches, emb_size))
        self.pos_embedding_m = nn.Parameter(torch.randn(1, self.num_patches, emb_size))

        self.attention_h = Attention(emb_size, n_heads, dropout)
        self.attention_m = Attention(emb_size, n_heads, dropout)

        self.patch_unembed_h = PatchUnembedding(in_channels, patch_size, emb_size, image_size)
        self.patch_unembed_m = PatchUnembedding(in_channels, patch_size, emb_size, image_size)

        # 1x1 projections of the attended hidden state / memory.
        self.z_h = Conv(in_channels, in_channels, 1, padding="same")
        self.z_m = Conv(in_channels, in_channels, 1, padding="same")

        # Fuses the channel-concatenation [Z_h, Z_m].
        self.w_z = Conv(in_channels * 2, in_channels * 2, 1, padding="same")

        # Produces the three gate pre-activations from [Z, h].
        self.w = Conv(in_channels * 3, in_channels * 3, 1, padding="same")

    def forward(self, h, m):
        """
        Args:
            h, m: feature maps, expected (B, in_channels, image_size, image_size).

        Return:
            Tuple(torch.Tensor, torch.Tensor): new hidden state and new memory.
        """
        tokens_h = self.patch_embed_h(h) + self.pos_embedding_h
        tokens_m = self.patch_embed_m(m) + self.pos_embedding_m

        # The memory attention reuses the hidden-state query.
        tokens_h, q_h = self.attention_h(tokens_h)
        tokens_m, _ = self.attention_m(tokens_m, q_h)

        z_h = self.z_h(self.patch_unembed_h(tokens_h))
        z_m = self.z_m(self.patch_unembed_m(tokens_m))

        Z = self.w_z(torch.cat([z_h, z_m], dim=1))
        gates = self.w(torch.cat([Z, h], dim=1))

        # Input, candidate and output gate pre-activations, in_channels channels each.
        mi_conv, mg_conv, mo_conv = torch.chunk(gates, chunks=3, dim=1)

        input_gate = torch.sigmoid(mi_conv)
        g = torch.tanh(mg_conv)
        new_M = (1 - input_gate) * m + input_gate * g

        output_gate = torch.sigmoid(mo_conv)
        new_H = output_gate * new_M

        return new_H, new_M


In [11]:
class Encoder(nn.Module):
    """Four ``ConvLayer``s; layers 2 and 4 use stride 2.

    With odd kernels, H and W are each reduced by 4x (e.g. 160 -> 40).

    Returns (from ``forward``):
        ``(features, enc1)``: the output of the last layer and the full-resolution
        output of the first layer (used by ``Decoder`` as a skip connection).
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride=1)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, stride=2)
        self.conv3 = ConvLayer(out_channels, out_channels, kernel_size, stride=1)
        self.conv4 = ConvLayer(out_channels, out_channels, kernel_size, stride=2)

    def forward(self, x):
        enc1 = self.conv1(x)
        features = enc1
        for layer in (self.conv2, self.conv3, self.conv4):
            features = layer(features)
        return features, enc1
    

class Decoder(nn.Module):
    """Upsampling decoder mirroring ``Encoder``.

    The two ``ConvTransposeLayer`` stages each upsample 2x. The encoder skip
    ``enc1`` is added before the last ``ConvLayer``, and a 1x1 conv produces
    ``out_channels`` channels.
    """

    def __init__(self, hid_channels, out_channels, kernel_size):
        super().__init__()
        self.conv1 = ConvTransposeLayer(hid_channels, hid_channels, kernel_size, stride=1)
        self.conv2 = ConvLayer(hid_channels, hid_channels, kernel_size, stride=1)
        self.conv3 = ConvTransposeLayer(hid_channels, hid_channels, kernel_size, stride=1)
        self.conv4 = ConvLayer(hid_channels, hid_channels, kernel_size, stride=1)
        self.readout = nn.Conv2d(hid_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x, enc1):
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        # enc1 must match the upsampled resolution for the skip addition.
        return self.readout(self.conv4(x + enc1))


class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell with peephole connections and a self-attention memory.

    One convolution over ``[x, h]`` produces the input, forget, candidate and
    output gate pre-activations. The input, forget and output gates also receive
    element-wise peephole terms ``W_c* * c``. The resulting hidden state and the
    memory ``m`` are then updated by ``SelfAttentionMemory``.

    Args:
        input_channels: channels of ``x``.
        hidden_channels: channels of ``h``, ``c`` and ``m``.
        kernel_size: gate-convolution kernel (padding ``kernel_size // 2``).
        bias: whether the gate convolution has a bias.
        image_size: spatial size (height == width) of the states. It sizes the
            peephole weights and the attention memory (default 160, previously hard-coded).
        patch_size: patch size of the attention memory (default 16, previously hard-coded).
        emb_size: token size of the attention memory (default 128, previously hard-coded).
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True,
                 image_size=160, patch_size=16, emb_size=128):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.image_size = image_size

        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # Per-position peephole weights, zero-initialised.
        peephole_shape = (hidden_channels, image_size, image_size)
        self.W_ci = nn.parameter.Parameter(torch.zeros(peephole_shape, dtype=torch.float))
        self.W_co = nn.parameter.Parameter(torch.zeros(peephole_shape, dtype=torch.float))
        self.W_cf = nn.parameter.Parameter(torch.zeros(peephole_shape, dtype=torch.float))

        self.conv = nn.Conv2d(in_channels=self.input_channels + self.hidden_channels,
                              out_channels=4 * self.hidden_channels,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

        self.attention_memory = SelfAttentionMemory(image_size=image_size,
                                                    in_channels=self.hidden_channels,
                                                    patch_size=patch_size,
                                                    emb_size=emb_size)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, hidden_state):
        """Advance one time step.

        Args:
            x: input frame features.
            hidden_state: tuple ``(h, c, m)`` from the previous step.

        Returns:
            The new ``(h, c, m)``.
        """
        h, c, m = hidden_state

        gates = self.conv(torch.cat([x, h], dim=1))
        ingate, forgetgate, cellgate, outgate = torch.split(gates, self.hidden_channels, dim=1)

        # Input/forget peepholes see the previous cell state...
        ingate = self.sigmoid(ingate + self.W_ci * c)
        forgetgate = self.sigmoid(forgetgate + self.W_cf * c)
        cellgate = self.tanh(cellgate)

        c = c * forgetgate + ingate * cellgate

        # ...while the output peephole sees the updated cell state.
        outgate = self.sigmoid(outgate + self.W_co * c)
        h = outgate * self.tanh(c)

        h, m = self.attention_memory(h, m)
        return h, c, m


class ConvLSTM(nn.Module):
    """Stack of ``ConvLSTMCell`` layers run over a sequence, followed by a 1x1 readout.

    Args:
        input_channels: channels per input frame.
        hidden_channels: hidden channel count of each layer (one entry per layer).
        kernel_size: gate kernel size of each layer, indexed in step with
            ``hidden_channels``.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()

        self.num_layers = len(hidden_channels)
        self.hidden_channels = hidden_channels

        # Layer 0 consumes the frames; layer i consumes layer i-1's hidden state.
        self.layers = nn.ModuleList()
        self.layers.append(ConvLSTMCell(input_channels, hidden_channels[0], kernel_size[0]))
        for i in range(1, self.num_layers):
            self.layers.append(ConvLSTMCell(hidden_channels[i-1], hidden_channels[i], kernel_size[i]))

        # 1x1 readout of the last layer's hidden state to a single channel.
        self.conv = nn.Conv2d(hidden_channels[-1], 1, kernel_size=1)

    def forward(self, x):
        """Run the stack over ``x`` of shape ``(batch, seq_len, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, 1, 1, height, width)``: the readout of the last
            layer's hidden state after the final time step.
        """
        batch_size, seq_len, _, height, width = x.size()

        # Zero (h, c, m) per layer, allocated directly on x's device rather than
        # created on the CPU and then copied.
        hidden_states = [
            tuple(
                torch.zeros(batch_size, channels, height, width, device=x.device)
                for _ in range(3)
            )
            for channels in self.hidden_channels
        ]

        for t in range(seq_len):
            layer_input = x[:, t]
            for layer_idx, cell in enumerate(self.layers):
                state = cell(layer_input, hidden_states[layer_idx])
                hidden_states[layer_idx] = state
                layer_input = state[0]  # this layer's h feeds the next layer

        last_h = hidden_states[-1][0]
        return self.conv(last_h).unsqueeze(dim=1)


In [12]:
torch.manual_seed(0)
x = torch.randn((2, 48, 5, 160, 160))
model = ConvLSTM(5, [64, 32, 16], [5, 3, 3])
# Shape check only: no_grad avoids retaining the autograd graph for all 48 steps.
with torch.no_grad():
    print(model(x).shape)  # (2, 1, 1, 160, 160): 1x1 readout of the last layer's hidden state

torch.Size([2, 1, 1, 160, 160])


In [24]:
class TemporalAttention(nn.Module):
    """A Temporal Attention block for Temporal Attention Unit.

    The block runs 1x1 conv -> GELU -> ``TemporalAttentionModule`` -> 1x1 conv.
    When ``attn_shortcut`` is set, a residual connection wraps the whole block.

    Args:
        d_model: channel count (input and output).
        kernel_size: large-kernel size passed to ``TemporalAttentionModule``.
        attn_shortcut: add the block input back onto the output when True.
    """

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = TemporalAttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        shortcut = x.clone() if self.attn_shortcut else None
        out = self.proj_1(x)
        out = self.spatial_gating_unit(self.activation(out))
        out = self.proj_2(out)
        if shortcut is not None:
            out = out + shortcut
        return out
    

class TemporalAttentionModule(nn.Module):
    """Large Kernel Attention for SimVP, combined with a squeeze-and-excitation channel gate.

    The large ``kernel_size`` receptive field is decomposed into three convs: a
    small depthwise conv (kernel ``2 * dilation - 1``), a depthwise dilated conv,
    and a 1x1 conv. The result is multiplied element-wise by the input and by a
    per-channel gate computed from the globally average-pooled input.

    Args:
        dim: channel count (input and output).
        kernel_size: target large-kernel size.
        dilation: dilation of the depthwise dilated conv.
        reduction: controls the SE bottleneck, whose hidden width is
            ``dim // max(dim // reduction, 4)``.
    """

    def __init__(self, dim, kernel_size, dilation=3, reduction=16):
        super().__init__()
        local_k = 2 * dilation - 1
        local_pad = (local_k - 1) // 2
        dilated_k = kernel_size // dilation
        dilated_k += dilated_k % 2 - 1  # round an even kernel down to the next odd size
        dilated_pad = (dilation * (dilated_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, local_k, padding=local_pad, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dilated_k, stride=1, padding=dilated_pad, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, dim, 1)

        self.reduction = max(dim // reduction, 4)
        bottleneck = dim // self.reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(dim, bottleneck, bias=False),
            nn.ReLU(True),
            nn.Linear(bottleneck, dim, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        identity = x.clone()
        # depthwise conv -> depthwise dilated conv -> 1x1 conv
        f_x = self.conv1(self.conv_spatial(self.conv0(x)))
        # squeeze-and-excitation gate, one scalar per (sample, channel)
        b, c, _, _ = x.size()
        pooled = self.avg_pool(x).view(b, c)
        se_atten = self.fc(pooled).view(b, c, 1, 1)
        return se_atten * f_x * identity


In [25]:
torch.manual_seed(0)
x = torch.randn((8, 48*16, 40, 40))
model = TemporalAttention(48*16)
# Shape check only: no gradients needed.
with torch.no_grad():
    print(model(x).shape)  # expected torch.Size([8, 768, 40, 40]): shape is preserved

torch.Size([8, 768, 40, 40])


In [33]:
class EncDecConvLSTM(nn.Module):
    """Encoder -> temporal attention over all frames -> decoder -> 1x1 conv across time.

    Each frame is encoded independently. The ``seq_len`` latent maps are stacked
    along the channel axis and mixed by ``TemporalAttention``, then decoded frame
    by frame using the encoder's first-layer skip connection. Finally, the
    ``seq_len`` single-channel decoded frames are collapsed into one by a 1x1
    convolution.

    NOTE(review): despite the name, no ConvLSTM is used in this model; temporal
    mixing is done by ``TemporalAttention``.

    Args:
        input_channels: channels per input frame.
        hidden_channels: latent channels of the encoder/decoder.
        kernel_size: conv kernel size in the encoder/decoder.
        seq_len: number of input frames T the model is built for (default 48,
            previously hard-coded). ``forward`` requires ``x.size(1) == seq_len``.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, seq_len=48):
        super().__init__()
        self.seq_len = seq_len
        self.encoder = Encoder(input_channels, hidden_channels, kernel_size)
        self.decoder = Decoder(hidden_channels, 1, kernel_size)
        self.attn = TemporalAttention(hidden_channels * seq_len)
        self.conv = nn.Conv2d(seq_len, 1, kernel_size=1)

    def forward(self, x):
        """Map frames ``(B, T, C, H, W)`` to a prediction of shape ``(B, 1, 1, H, W)``.

        Raises:
            ValueError: if ``T`` differs from the ``seq_len`` the model was built for.
        """
        B, T, C, H, W = x.size()
        if T != self.seq_len:
            raise ValueError(f"expected {self.seq_len} frames, got {T}")

        # Encode every frame independently. reshape (not view) also accepts
        # non-contiguous inputs, such as slices of a longer sequence.
        z, enc1 = self.encoder(x.reshape(B * T, C, H, W))
        _, C_lat, H_lat, W_lat = z.shape

        # Stack the T latent maps along channels so TemporalAttention mixes time.
        z = self.attn(z.reshape(B, T * C_lat, H_lat, W_lat))

        # One latent map per frame again; decode with the encoder skip.
        out = self.decoder(z.reshape(B * T, C_lat, H_lat, W_lat), enc1)

        # The decoder emits one channel per frame: treat the T frames as channels and fuse them.
        out = out.reshape(B, T, H, W)
        return self.conv(out).unsqueeze(2)

In [34]:
torch.manual_seed(0)
x = torch.randn((8, 48, 5, 160, 160))
model = EncDecConvLSTM(5, 16, 3)

y = model(x)
print(y.size())  # expected torch.Size([8, 1, 1, 160, 160])


torch.Size([8, 1, 1, 160, 160])
