In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.layers import (
    Mlp,
    Format,
    PatchDropout,
    LayerNorm2d,
    RotaryEmbeddingCat,
    to_2tuple,
    nchw_to
)

# Standard library
from functools import partial
from typing import Callable, Optional, Tuple

In [2]:
""" Position Embedding Utilities

Hacked together by / Copyright 2022 Ross Wightman
"""
import logging
import math
from typing import List, Tuple, Optional, Union


_logger = logging.getLogger(__name__)


def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample a flattened absolute position embedding to a new 2D grid size.

    Args:
        posemb: embedding of shape (1, num_prefix_tokens + old_h * old_w, C). The grid
            part is reshaped with a leading dim of 1, so a batch dim of 1 is assumed.
        new_size: target grid (new_h, new_w).
        old_size: source grid (old_h, old_w). If None, the source grid is assumed
            square and inferred from the number of non-prefix tokens.
        num_prefix_tokens: leading (class, etc.) tokens that are copied through unchanged.
        interpolation: ``F.interpolate`` mode used for the 2D resample.
        antialias: forwarded to ``F.interpolate``.
        verbose: log the resize at INFO level (skipped under TorchScript).

    Returns:
        Tensor of shape (1, num_prefix_tokens + new_h * new_w, C) in posemb's dtype.
        ``posemb`` itself is returned when no resize is needed.
    """
    num_pos_tokens = posemb.shape[1]
    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens

    if old_size is None:
        # Square-source assumption: equal token counts and a square target mean
        # the grid already has the requested size.
        if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
            return posemb
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw
    elif (
        old_size[0] == new_size[0]
        and old_size[1] == new_size[1]
        and num_new_tokens == num_pos_tokens
    ):
        # With an explicit source grid, only an exact size match is a no-op; equal
        # token counts alone are not (e.g. a (2, 8) grid resized to (4, 4)).
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb


In [3]:
try:
    from torch import _assert
except ImportError:
    def _assert(condition: bool, message: str):
        assert condition, message

In [4]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final

from timm.layers import (
    Mlp,
    DropPath,
    apply_rot_embed_cat,
    use_fused_attn,
)

class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    ``proj`` is an ``nn.Conv2d`` whose kernel and stride both equal the patch size,
    so each output position corresponds to one patch. Output layout:
      * ``output_fmt`` given: the NCHW conv output is converted with ``nchw_to``
        unless the format is NCHW (``flatten`` is ignored);
      * otherwise with ``flatten``: (B, num_patches, embed_dim);
      * otherwise: NCHW as produced by the conv.
    ``norm_layer`` (if any) is applied to the final tensor.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        if img_size is None:
            # Image size only known at call time: no fixed patch grid.
            self.img_size = None
            self.grid_size = None
            self.num_patches = None
        else:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple(side // patch for side, patch in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]

        if output_fmt is None:
            # Legacy path: `flatten` chooses between NCHW and NLC, kept for bwd compat.
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        else:
            # An explicit output format always wins over `flatten`.
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        _batch, _chans, height, width = x.shape
        ph = self.patch_size[0]
        pw = self.patch_size[1]

        if self.img_size is not None:
            if self.strict_img_size:
                # Fixed-size model: input must match img_size exactly.
                _assert(height == self.img_size[0], f"Input height ({height}) doesn't match model ({self.img_size[0]}).")
                _assert(width == self.img_size[1], f"Input width ({width}) doesn't match model ({self.img_size[1]}).")
            elif not self.dynamic_img_pad:
                # Flexible size without padding: each side must tile exactly into patches.
                _assert(height % ph == 0, f"Input height ({height}) should be divisible by patch size ({ph}).")
                _assert(width % pw == 0, f"Input width ({width}) should be divisible by patch size ({pw}).")

        if self.dynamic_img_pad:
            # Zero-pad bottom/right up to the next multiple of the patch size.
            extra_h = (ph - height % ph) % ph
            extra_w = (pw - width % pw) % pw
            x = F.pad(x, (0, extra_w, 0, extra_h))

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return self.norm(x)

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Gather relative-position embeddings for every (query, key) index pair.

    Args:
        q_size (int): number of query positions along this axis.
        k_size (int): number of key positions along this axis.
        rel_pos (Tensor): relative position table of shape (L, C). If L differs from
            2 * max(q_size, k_size) - 1 it is linearly resampled to that length.

    Returns:
        Tensor of shape (q_size, k_size, C): the table row for each pair's
        (scaled) relative offset.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # (L, C) -> (1, C, L), resample along L, then back to (max_rel_dist, C).
        resampled = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        table = resampled.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k lengths differ, stretch the shorter axis' coordinates.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    offsets = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[offsets.long()]


def get_decomposed_rel_pos_bias(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

    The height and width terms are computed separately and summed over the key
    grid, instead of materialising a full 2D relative table.

    Args:
        q (Tensor): queries of shape (B, q_h * q_w, C); B may already include heads.
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for the height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for the width axis.
        q_size (Tuple): spatial size (q_h, q_w) of the queries.
        k_size (Tuple): spatial size (k_h, k_w) of the keys.

    Returns:
        Tensor of shape (B, q_h * q_w, k_h * k_w) to add to the attention logits.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    batch, _, dim = q.shape
    q_grid = q.reshape(batch, q_h, q_w, dim)
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)  # (B, q_h, q_w, k_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)  # (B, q_h, q_w, k_w)

    # Outer sum: (B, q_h, q_w, k_h, 1) + (B, q_h, q_w, 1, k_w).
    combined = bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    return combined.reshape(-1, q_h * q_w, k_h * k_w)


class Attention(nn.Module):
    """Multi-head self-attention over a channels-last (B, H, W, C) feature map.

    The H*W positions are flattened into one token sequence per image. Optional,
    mutually exclusive extras: decomposed relative-position bias (``use_rel_pos``)
    or rotary embeddings (``rope``). Uses ``F.scaled_dot_product_attention`` when
    ``use_fused_attn()`` enables it, otherwise an explicit softmax(q k^T) v.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
        use_rel_pos: bool = False,
        input_size: Optional[Tuple[int, int]] = None,
        rope: Optional[nn.Module] = None,
    ):
        """
        Args:
            dim: channel count C of input and output; must be divisible by num_heads.
            num_heads: number of attention heads.
            qkv_bias: add a bias to the fused q/k/v projection.
            qk_norm: apply ``norm_layer`` (over head_dim) to q and k.
            attn_drop: dropout on the attention weights (training only).
            proj_drop: dropout applied after the output projection.
            norm_layer: normalization used when ``qk_norm`` is set.
            use_rel_pos: add decomposed relative-position bias; requires
                ``input_size`` and cannot be combined with ``rope``.
            input_size: (H, W) the relative-position tables are sized for.
            rope: module providing ``get_embed()`` for rotary embeddings, or None.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert rope is None
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, self.head_dim)
            )
            self.rel_pos_w = nn.Parameter(
                torch.zeros(2 * input_size[1] - 1, self.head_dim)
            )
        self.rope = rope

    def forward(self, x):
        """x: (B, H, W, C) -> (B, H, W, C)."""
        B, H, W, _ = x.shape
        N = H * W
        x = x.reshape(B, N, -1)
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # qkv with shape (3, B, nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.use_rel_pos:
            attn_bias = get_decomposed_rel_pos_bias(
                q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
        else:
            attn_bias = None
            if self.rope is not None:
                rope = self.rope.get_embed()
                q = apply_rot_embed_cat(q, rope).type_as(v)
                k = apply_rot_embed_cat(k, rope).type_as(v)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_bias is not None:
                attn = attn + attn_bias
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.view(B, self.num_heads, N, -1).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        # Output dropout: constructed in __init__ and must actually be applied here,
        # otherwise `proj_drop` silently has no effect.
        x = self.proj_drop(x)
        x = x.view(B, H, W, -1)
        return x


class LayerScale(nn.Module):
    """Scale the last dimension by a learnable per-channel vector ``gamma``.

    ``gamma`` has shape (dim,) and starts at ``init_values`` everywhere. With
    ``inplace=True`` the input tensor is multiplied in place and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma


class Block(nn.Module):
    """Pre-norm transformer block on channels-last (B, H, W, C) feature maps.

        x = x + drop_path1(ls1(attn(norm1(x))))   # attention, optionally windowed
        x = x + drop_path2(ls2(mlp(norm2(x))))    # MLP on flattened (B, H*W, C) tokens

    With ``window_size > 0`` attention runs inside non-overlapping
    window_size x window_size windows (the map is padded as needed and cropped
    back afterwards); with ``window_size == 0`` attention is global.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        proj_drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        mlp_layer=Mlp,
        use_rel_pos=False,
        window_size=0,
        input_size=None,
        rope=None,
    ):
        super().__init__()
        self.window_size = window_size
        # Windowed attention only ever sees window_size x window_size tokens, so the
        # attention's relative-position tables are sized for the window, not the grid.
        if window_size == 0:
            attn_input_size = input_size
        else:
            attn_input_size = (window_size, window_size)

        # --- attention branch ---
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            use_rel_pos=use_rel_pos,
            input_size=attn_input_size,
            rope=rope,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # --- MLP branch ---
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        B, H, W, _ = x.shape

        # Attention residual branch.
        residual = x
        y = self.norm1(x)
        pad_hw: Optional[Tuple[int, int]] = None
        if self.window_size > 0:
            y, pad_hw = window_partition(y, self.window_size)
        y = self.drop_path1(self.ls1(self.attn(y)))
        if self.window_size > 0:
            y = window_unpartition(y, self.window_size, (H, W), pad_hw)
        x = residual + y

        # MLP residual branch; MLP is faster for N, L, C tensor.
        tokens = x.reshape(B, H * W, -1)
        tokens = tokens + self.drop_path2(self.ls2(self.mlp(self.norm2(tokens))))
        return tokens.reshape(B, H, W, -1)


def window_partition(
    x: torch.Tensor, window_size: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Split a (B, H, W, C) map into non-overlapping square windows.

    The map is zero-padded at the bottom/right so both spatial sizes become
    multiples of ``window_size``.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window side length.

    Returns:
        windows: [B * num_windows, window_size, window_size, C], row-major over windows.
        (Hp, Wp): padded height and width before partition.
    """
    batch, height, width, chans = x.shape

    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    # F.pad lists pads from the last dim backwards: (C, C, W_left, W_right, H_top, H_bottom).
    x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_h = height + pad_bottom
    padded_w = width + pad_right

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, chans)
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C) -> flat windows.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, chans)
    return windows, (padded_h, padded_w)


def window_unpartition(
    windows: torch.Tensor,
    window_size: int,
    hw: Tuple[int, int],
    pad_hw: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """Reassemble windows into a (B, H, W, C) map and crop away any padding.

    Args:
        windows (tensor): [B * num_windows, window_size, window_size, C].
        window_size (int): window side length.
        hw (Tuple): original (H, W) before padding.
        pad_hw (Tuple): padded (Hp, Wp); defaults to ``hw`` when None.

    Returns:
        x: unpartitioned map with [B, H, W, C].
    """
    Hp, Wp = pad_hw if pad_hw is not None else hw
    H, W = hw

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image
    rows = Hp // window_size
    cols = Wp // window_size

    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) -> (B, Hp, Wp, C)
    full = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    return full[:, :H, :W, :].contiguous()


def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
        preserve_rng_state=True,
        use_reentrant=True,
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. Checkpointed segments run in
    :func:`torch.no_grad` manner during the forward pass, i.e., they do not
    store intermediate activations. The input of each checkpointed segment is
    saved and the segment is re-run in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True): stash and restore the RNG
            state around each checkpoint so the backward re-run sees the same random
            numbers. Set to False to skip this.
        use_reentrant (bool, optional, default=True): forwarded to
            :func:`torch.utils.checkpoint.checkpoint`. Passed explicitly (newer PyTorch
            releases warn when it is left unspecified); True keeps the historical default.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`x`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    # Local import: `chain` is not imported at file level in this notebook.
    from itertools import chain

    def run_function(start, end, functions):
        # Closure running functions[start..end] (inclusive) in order.
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = torch.utils.checkpoint.checkpoint(
            run_function(start, end, functions),
            x,
            preserve_rng_state=preserve_rng_state,
            use_reentrant=use_reentrant,
        )
    if skip_last:
        # Run whatever was left un-checkpointed (normally just the last function).
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x

In [6]:
# ---- Data / task --------------------------------------------------------------
img_size = (64, 128)   # (H, W) of the input images
in_channels = 3        # channels per input frame
out_channels = 3       # channels predicted per pixel
history = 1            # frames stacked on the channel axis

# ---- Model size ---------------------------------------------------------------
patch_size = 2
embed_dim = 128
depth = 4              # number of transformer blocks
decoder_depth = 1      # hidden Linear + GELU layers in the prediction head
num_heads = 4
mlp_ratio = 4
pre_norm = False

# Not read by the model-construction cell; it uses the *_rate settings below.
drop_path = 0.1
drop_rate = 0.1

# ---- Regularisation -----------------------------------------------------------
qkv_bias: bool = True
qk_norm: bool = False
init_values: Optional[float] = None   # LayerScale init value; None disables LayerScale
pos_drop_rate: float = 0.1
patch_drop_rate: float = 0.0          # > 0 enables PatchDropout
proj_drop_rate: float = 0.1
attn_drop_rate: float = 0.1
drop_path_rate: float = 0.1           # top of the linear stochastic-depth ramp

# ---- Building blocks ----------------------------------------------------------
embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False)
norm_layer: Optional[Callable] = nn.LayerNorm
act_layer: Optional[Callable] = nn.GELU
block_fn: Callable = Block
mlp_layer: Callable = Mlp

# ---- Positional encoding / attention layout -----------------------------------
use_abs_pos: bool = True
use_rel_pos: bool = False
use_rope: bool = False
window_size: int = 14                       # windowed attention; the grid is padded to a multiple
global_attn_indexes: Tuple[int, ...] = ()   # block indices that use global attention instead
neck_chans: int = 0                         # 0 -> LayerNorm2d-only neck
ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None

In [7]:
# Effective input channels: `history` frames of `in_channels` channels each are
# stacked on the channel axis. NOTE: this rebinds `in_channels`, so re-run the
# config cell before re-running this cell when history > 1.
in_channels = in_channels * history
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU

# num_features: channels of the encoder output (embed_dim, or neck_chans with a conv neck).
num_features = embed_dim
grad_checkpointing = False

patch_embed = embed_layer(
    img_size=img_size,
    patch_size=patch_size,
    in_chans=in_channels,
    embed_dim=embed_dim,
    bias=not pre_norm,  # disable bias if pre-norm is used
)
grid_size = patch_embed.grid_size

if use_abs_pos:
    # Initialize absolute positional embedding with pretrain image size.
    pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
else:
    pos_embed = None
pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
    patch_drop = PatchDropout(
        patch_drop_rate,
        num_prefix_tokens=0,
    )
else:
    patch_drop = nn.Identity()
norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

if use_rope:
    assert (
        not use_rel_pos
    ), "ROPE and relative pos embeddings should not be enabled at same time"
    if ref_feat_shape is not None:
        assert len(ref_feat_shape) == 2
        ref_feat_shape_global = to_2tuple(ref_feat_shape[0])
        ref_feat_shape_window = to_2tuple(ref_feat_shape[1])
    else:
        ref_feat_shape_global = ref_feat_shape_window = None
    rope_global = RotaryEmbeddingCat(
        embed_dim // num_heads,
        in_pixels=False,
        feat_shape=grid_size,
        ref_feat_shape=ref_feat_shape_global,
    )
    rope_window = RotaryEmbeddingCat(
        embed_dim // num_heads,
        in_pixels=False,
        feat_shape=to_2tuple(window_size),
        ref_feat_shape=ref_feat_shape_window,
    )
else:
    rope_global = None
    rope_window = None

# stochastic depth decay rule: per-block drop-path rate ramps linearly 0 -> drop_path_rate
dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]
blocks = nn.Sequential(
    *[
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            init_values=init_values,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            mlp_layer=mlp_layer,
            use_rel_pos=use_rel_pos,
            window_size=window_size if i not in global_attn_indexes else 0,
            input_size=grid_size,
            rope=(
                rope_window
                if i not in global_attn_indexes
                else rope_global
            ),
        )
        for i in range(depth)
    ]
)
if neck_chans:
    neck = nn.Sequential(
        nn.Conv2d(
            embed_dim,
            neck_chans,
            kernel_size=1,
            bias=False,
        ),
        LayerNorm2d(neck_chans),
        nn.Conv2d(
            neck_chans,
            neck_chans,
            kernel_size=3,
            padding=1,
            bias=False,
        ),
        LayerNorm2d(neck_chans),
    )
    num_features = neck_chans
else:
    # Channel LayerNorm only. `neck_chans` is deliberately NOT rebound here:
    # doing so would make a re-run of this cell take the conv branch above.
    neck = LayerNorm2d(embed_dim)

# Per-patch prediction head acting on the channel axis: `decoder_depth` x
# (Linear, GELU), then a Linear to out_channels * patch_size**2 values per patch.
head_layers = []
for _ in range(decoder_depth):
    head_layers.append(nn.Linear(num_features, num_features))
    head_layers.append(nn.GELU())
head_layers.append(nn.Linear(num_features, out_channels * patch_size**2))
head = nn.Sequential(*head_layers)

In [8]:
def unpatchify(
    x: torch.Tensor,
    patch: Optional[int] = None,
    channels: Optional[int] = None,
    image_hw: Optional[Tuple[int, int]] = None,
):
    """Fold per-patch predictions back into images.

    Args:
        x: (B, h, w, channels * patch**2), e.g. [1, 32, 64, 12]. Within the last axis
            values are laid out as (patch_row, patch_col, channel).
        patch: patch side length; defaults to the notebook's ``patch_size``.
        channels: output channels; defaults to the notebook's ``out_channels``.
        image_hw: target (H, W); defaults to the notebook's ``img_size``.

    Returns:
        imgs: (B, channels, h * patch, w * patch) with h = H // patch, w = W // patch.

    Raises:
        AssertionError: if x's patch grid or per-patch size does not match.
    """
    p = patch_size if patch is None else patch
    c = out_channels if channels is None else channels
    image_h, image_w = img_size if image_hw is None else image_hw
    h = image_h // p
    w = image_w // p
    # Check each grid axis, not only h * w: a transposed grid with the same
    # patch count would otherwise be silently scrambled by the reshape below.
    assert x.shape[1] == h and x.shape[2] == w, (
        f"expected a ({h}, {w}) patch grid, got ({x.shape[1]}, {x.shape[2]})"
    )
    assert x.shape[3] == p * p * c, (
        f"expected {p * p * c} values per patch, got {x.shape[3]}"
    )
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    imgs = x.reshape(x.shape[0], c, h * p, w * p)
    return imgs

def forward_encoder(self, x: torch.Tensor):
    """Encode an image batch into a channels-last feature map.

    Args:
        self: model-like object providing ``patch_embed``, ``pos_embed`` (or None),
            ``pos_drop``, ``patch_drop``, ``norm_pre``, ``blocks``, ``neck`` and
            ``grad_checkpointing``.
        x: input images, [B, C, H, W].

    Returns:
        Features [B, Hp, Wp, C_out] (NHWC), so the per-patch ``head`` Linear layers
        act on the channel axis.
    """
    # With the NHWC embed_layer configured in this notebook: [B, Hp, Wp, embed_dim].
    x = self.patch_embed(x)
    # Use the object's own embedding (not the notebook global) and honour use_abs_pos=False.
    if self.pos_embed is not None:
        x = x + self.pos_embed
    x = self.pos_drop(x)
    x = self.patch_drop(x)
    x = self.norm_pre(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    # The neck (LayerNorm2d / convs) expects NCHW; convert there and back to NHWC.
    x = self.neck(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    return x

def forward(self, x):
    """Predict output images from an input batch.

    A 5-D input [B, T, in_channels, H, W] is first merged to [B, T*in_channels, H, W].
    The batch is then encoded by ``self.forward_encoder``, mapped per patch by
    ``self.head`` and folded back to pixels by ``self.unpatchify``.
    """
    if x.dim() == 5:
        x = x.flatten(1, 2)  # [B, T, C, H, W] -> [B, T*C, H, W]
    features = self.forward_encoder(x)
    patch_preds = self.head(features)
    return self.unpatchify(patch_preds)

In [9]:
# Random demo batch shaped from the config (in_channels already includes history).
x = torch.randn(1, in_channels, *img_size)

In [10]:
# Mirror `forward`: merge a time axis, [B, T, C, H, W] -> [B, T*C, H, W].
if x.ndim == 5:
    x = x.flatten(1, 2)

In [11]:
x.size()

torch.Size([1, 3, 64, 128])

In [12]:
x = patch_embed(x)  # image -> NHWC patch tokens
x.size()

torch.Size([1, 32, 64, 128])

In [13]:
pos_embed.size()

torch.Size([1, 32, 64, 128])

In [14]:
x = torch.add(x, pos_embed)  # absolute position embedding, same NHWC shape as x

In [15]:
# Apply the pre-block stages in order, printing the shape after each one.
for stage in (pos_drop, patch_drop, norm_pre):
    x = stage(x)
    print(x.shape)

torch.Size([1, 32, 64, 128])
torch.Size([1, 32, 64, 128])
torch.Size([1, 32, 64, 128])


In [16]:
# Patch grid (rows, cols) taken from patch_embed: img_size // patch_size per axis.
grid_size

(32, 64)

In [17]:
x = blocks(x)  # transformer blocks keep the NHWC shape
print(x.size())

torch.Size([1, 32, 64, 128])


In [18]:
# The neck works on NCHW: move channels to dim 1, apply, then move them back last.
x = neck(x.movedim(-1, 1)).movedim(1, -1)

In [19]:
x.size()

torch.Size([1, 32, 64, 128])

In [20]:
# Inspect the per-patch prediction head built in the model-construction cell.
head

Sequential(
  (0): Linear(in_features=128, out_features=128, bias=True)
  (1): GELU(approximate='none')
  (2): Linear(in_features=128, out_features=12, bias=True)
)

In [21]:
x = head(x)  # per-patch Linear decoder on the channel (last) axis
x.size()

torch.Size([1, 32, 64, 12])

In [22]:
preds = unpatchify(x)  # fold patch predictions back to [B, out_channels, H, W]
print(preds.size())

torch.Size([1, 3, 64, 128])
