In [1]:
import math
from typing import Callable, List, Optional, Tuple, Union

import torch

In [2]:
""" Layer/Module Helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    """Build a converter that normalises its argument into a tuple.

    The returned callable turns any non-string iterable into ``tuple(x)``
    (its length is NOT checked against ``n``) and repeats every other value,
    including strings, ``n`` times.
    """
    def parse(x):
        # Strings are iterable but are treated as scalars on purpose.
        already_a_collection = (
            isinstance(x, collections.abc.Iterable) and not isinstance(x, str)
        )
        return tuple(x) if already_a_collection else tuple(repeat(x, n))
    return parse


# Ready-made converters for the common arities; `to_ntuple` exposes the
# factory itself for any other length.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

In [3]:
# Choose the attention implementation label used by the rest of the notebook:
#   'flash'    - torch.nn.functional exposes scaled_dot_product_attention
#   'xformers' - the optional xformers package (and xformers.ops) imports
#   'math'     - neither of the above is available
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except Exception:
        # Best-effort fallback for any import-time failure of xformers.
        # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt
        # and SystemExit from being silently swallowed.
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')

attention mode is math


In [4]:
# Use torch's `_assert` when this torch build exports it; otherwise define a
# local stand-in with the same (condition, message) call signature.
try:
    from torch import _assert
except ImportError:
    def _assert(condition: bool, message: str):
        """Fallback: plain Python ``assert`` of ``condition`` with ``message``.

        Being a bare ``assert``, it is skipped when Python runs with ``-O``.
        """
        assert condition, message


def _float_to_int(x: float) -> int:
    """
    Symbolic tracing helper to substitute for inbuilt `int`.
    Hint: Inbuilt `int` can't accept an argument of type `Proxy`
    """
    return int(x)

In [5]:
from enum import Enum
from typing import Union

import torch


class Format(str, Enum):
    """Tensor layout names understood by the layout helpers in this cell.

    Each letter's position in the name is the axis index of that dimension:
    ``C`` is the channel axis (see `get_channel_dim`) and ``H``/``W``/``L``
    are the spatial axes (see `get_spatial_dim`). Members also subclass
    ``str``, so a member compares equal to its plain string value.
    """
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    NCL = 'NCL'
    NLC = 'NLC'


# Type accepted wherever a layout is expected: a Format member or a string
# that `Format(...)` can convert.
FormatT = Union[str, Format]


def get_spatial_dim(fmt: FormatT):
    """Return the spatial axis indices of layout ``fmt`` as a tuple.

    ``fmt`` may be a `Format` member or its string value; anything that
    `Format(...)` rejects raises ``ValueError``.
    """
    layout = Format(fmt)
    if layout is Format.NLC:
        return (1,)
    if layout is Format.NCL:
        return (2,)
    if layout is Format.NHWC:
        return (1, 2)
    # Only NCHW remains.
    return (2, 3)


def get_channel_dim(fmt: FormatT):
    """Return the index of the channel axis for layout ``fmt``.

    ``fmt`` may be a `Format` member or its string value; anything that
    `Format(...)` rejects raises ``ValueError``.
    """
    layout = Format(fmt)
    if layout is Format.NHWC:
        return 3
    if layout is Format.NLC:
        return 2
    # NCHW and NCL both keep channels right after the batch axis.
    return 1


def nchw_to(x: torch.Tensor, fmt: Format):
    """Rearrange ``x`` into layout ``fmt``, treating ``x`` as NCHW.

    Any ``fmt`` other than NHWC, NLC or NCL returns ``x`` unchanged.
    """
    if fmt == Format.NHWC:
        # Move the channel axis to the end.
        return x.permute(0, 2, 3, 1)
    if fmt == Format.NLC:
        # Merge axes 2.. into one, then swap it with the channel axis.
        return x.flatten(2).transpose(1, 2)
    if fmt == Format.NCL:
        # Merge axes 2.. into one; channels stay on axis 1.
        return x.flatten(2)
    return x


def nhwc_to(x: torch.Tensor, fmt: Format):
    """Rearrange ``x`` into layout ``fmt``, treating ``x`` as NHWC.

    Any ``fmt`` other than NCHW, NLC or NCL returns ``x`` unchanged.
    """
    if fmt == Format.NCHW:
        # Move the last (channel) axis to position 1.
        return x.permute(0, 3, 1, 2)
    if fmt == Format.NLC:
        # Merge axes 1 and 2; channels stay last.
        return x.flatten(1, 2)
    if fmt == Format.NCL:
        # Merge axes 1 and 2, then swap the merged axis with the channel axis.
        return x.flatten(1, 2).transpose(1, 2)
    return x

In [6]:
# from functools import partial

# import torch
# import torch.nn as nn
# import torch.utils.checkpoint
# import numpy as np
# import einops
# import math
# class DropPath(nn.Module):
#     """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
#     """
#     def __init__(self, drop_prob=None):
#         super(DropPath, self).__init__()
#         self.drop_prob = drop_prob

#     def forward(self, x):
#         return drop_path(x, self.drop_prob, self.training)


# class Mlp(nn.Module):
#     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
#         super().__init__()
#         out_features = out_features or in_features
#         hidden_features = hidden_features or in_features
#         self.fc1 = nn.Linear(in_features, hidden_features)
#         self.act = act_layer()
#         self.fc2 = nn.Linear(hidden_features, out_features)
#         self.drop = nn.Dropout(drop)

#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.act(x)
#         x = self.drop(x)
#         x = self.fc2(x)
#         x = self.drop(x)
#         return x
# class Attention(nn.Module):
#     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
#         super().__init__()
#         self.num_heads = num_heads
#         head_dim = dim // num_heads
#         self.scale = qk_scale or head_dim ** -0.5

#         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj = nn.Linear(dim, dim)
#         self.proj_drop = nn.Dropout(proj_drop)

#     def forward(self, x):
#         B, L, C = x.shape

#         qkv = self.qkv(x)
#         if ATTENTION_MODE == 'flash':
#             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
#             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
#             x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#             x = einops.rearrange(x, 'B H L D -> B L (H D)')
#         elif ATTENTION_MODE == 'xformers':
#             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
#             q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
#             x = xformers.ops.memory_efficient_attention(q, k, v)
#             x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
#         elif ATTENTION_MODE == 'math':
#             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
#             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
#             attn = (q @ k.transpose(-2, -1)) * self.scale
#             attn = attn.softmax(dim=-1)
#             attn = self.attn_drop(attn)
#             x = (attn @ v).transpose(1, 2).reshape(B, L, C)
#         else:
#             raise NotImplemented

#         x = self.proj(x)
#         x = self.proj_drop(x)
#         return x

# class Block(nn.Module):

#     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
#                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
#         super().__init__()
#         self.norm1 = norm_layer(dim)
#         self.attn = Attention(
#             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
#         self.norm2 = norm_layer(dim)
#         mlp_hidden_dim = int(dim * mlp_ratio)
#         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
#         self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
#         self.use_checkpoint = use_checkpoint

#     def forward(self, x, skip=None):
#         if self.use_checkpoint:
#             return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
#         else:
#             return self._forward(x, skip)

#     def _forward(self, x, skip=None):
#         if self.skip_linear is not None:
#             x = self.skip_linear(torch.cat([x, skip], dim=-1))
#         x = x + self.attn(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x


# class PatchEmbed(nn.Module):
#     """ 2D Image to Patch Embedding
#     """
 
#     def __init__(
#             self,
#             img_size: Optional[int] = 224,
#             patch_size: int = 16,
#             in_chans: int = 3,
#             embed_dim: int = 768,
#             norm_layer: Optional[Callable] = None,
#             flatten: bool = True,
#             output_fmt: Optional[str] = None,
#             bias: bool = True,
#             strict_img_size: bool = True,
#     ):
#         super().__init__()
#         self.patch_size = to_2tuple(patch_size)
#         if img_size is not None:
#             self.img_size = to_2tuple(img_size)
#             self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
#             self.num_patches = self.grid_size[0] * self.grid_size[1]
#         else:
#             self.img_size = None
#             self.grid_size = None
#             self.num_patches = None

#         if output_fmt is not None:
#             self.flatten = False
#             self.output_fmt = Format(output_fmt)
#         else:
#             # flatten spatial dim and transpose to channels last, kept for bwd compat
#             self.flatten = flatten
#             self.output_fmt = Format.NCHW
#         self.strict_img_size = strict_img_size

#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

#     def forward(self, x):
#         B, C, H, W = x.shape
#         if self.img_size is not None:
#             if self.strict_img_size:
#                 _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
#                 _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
#             else:
#                 _assert(
#                     H % self.patch_size[0] == 0,
#                     f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
#                 )
#                 _assert(
#                     W % self.patch_size[1] == 0,
#                     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
#                 )

#         x = self.proj(x)
#         if self.flatten:
#             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
#         elif self.output_fmt != Format.NCHW:
#             x = nchw_to(x, self.output_fmt)
#         x = self.norm(x)
#         return x
# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
#     """
#     grid_size: int of the grid height and width
#     return:
#     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
#     """
#     grid_h = np.arange(grid_size, dtype=np.float32)
#     grid_w = np.arange(grid_size, dtype=np.float32)
#     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
#     grid = np.stack(grid, axis=0)

#     grid = grid.reshape([2, 1, grid_size, grid_size])
#     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
#     if cls_token:
#         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
#     return pos_embed


# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
#     assert embed_dim % 2 == 0

#     # use half of dimensions to encode grid_h
#     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
#     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

#     emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
#     return emb


# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
#     """
#     embed_dim: output dimension for each position
#     pos: a list of positions to be encoded: size (M,)
#     out: (M, D)
#     """
#     assert embed_dim % 2 == 0
#     omega = np.arange(embed_dim // 2, dtype=float)
#     omega /= embed_dim / 2.
#     omega = 1. / 10000**omega  # (D/2,)

#     pos = pos.reshape(-1)  # (M,)
#     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

#     emb_sin = np.sin(out) # (M, D/2)
#     emb_cos = np.cos(out) # (M, D/2)

#     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
#     return emb

In [7]:
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to fit ``model``.

    Does nothing when the key is absent or when the checkpoint grid side
    already equals the model's. Otherwise the leading extra tokens are kept
    as they are, the remaining tokens are reshaped to a square grid,
    bicubically interpolated to the model's grid side, and the result is
    written back under ``'pos_embed'``.

    :param model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
    :param checkpoint_model: mutable mapping (state dict) that may hold ``'pos_embed'``.
    :return: None; ``checkpoint_model`` is mutated.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_pos_embed = checkpoint_model['pos_embed']
    embed_dim = ckpt_pos_embed.shape[-1]
    patch_count = model.patch_embed.num_patches
    # Tokens in the model's embedding that are not patches (e.g. class token).
    extra_count = model.pos_embed.shape[-2] - patch_count
    # Grids are taken to be square: side = sqrt(number of position tokens).
    old_side = int((ckpt_pos_embed.shape[-2] - extra_count) ** 0.5)
    new_side = int(patch_count ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    prefix_tokens = ckpt_pos_embed[:, :extra_count]
    # Only the position tokens are interpolated; lay them out as (B, D, H, W).
    grid_tokens = ckpt_pos_embed[:, extra_count:]
    grid_tokens = grid_tokens.reshape(-1, old_side, old_side, embed_dim).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_side, new_side), mode='bicubic', align_corners=False)
    # Back to (B, H*W, D) before re-attaching the untouched extra tokens.
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((prefix_tokens, grid_tokens), dim=1)
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings, laid out as
             ``dim // 2`` cosine columns, then ``dim // 2`` sine columns,
             then one zero column when ``dim`` is odd.
    """
    half = dim // 2
    # Geometric frequency ladder: exp(-log(max_period) * i / half), i in [0, half).
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad odd dims with an explicit [N, 1] zero column. Slicing
        # `embedding[:, :1]` would be empty when half == 0 (dim == 1) and
        # would return [N, 0] instead of the documented [N, dim].
        zero_col = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, zero_col], dim=-1)
    return embedding

In [8]:
# class MaskedAutoencoderViT(nn.Module):
#     """ Masked Autoencoder with VisionTransformer backbone
#     """
#     def __init__(self, img_size=224, patch_size=16, in_chans=3,
#                  embed_dim=1024, depth=24, num_heads=16,
#                  qkv_bias=True, qk_scale=None,
#                  decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
#                  mlp_ratio=4., mlp_time_embed=False, norm_layer=nn.LayerNorm, 
#                  num_classes=None, norm_pix_loss=False,
#                  use_checkpoint=False, conv=True, skip=False):
#         super().__init__()

#         # --------------------------------------------------------------------------
#         # MAE encoder specifics
#         self.embed_dim = embed_dim
#         self.num_classes = num_classes
#         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
#         num_patches = self.patch_embed.num_patches
#         self.time_embed = nn.Sequential(
#             nn.Linear(embed_dim, 4 * embed_dim),
#             nn.SiLU(),
#             nn.Linear(4 * embed_dim, embed_dim),
#         ) if mlp_time_embed else nn.Identity()
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         if self.num_classes is not None:
#             self.label_emb = nn.Embedding(self.num_classes, embed_dim)
#             self.extras = 2
#         else:
#             self.extras = 1
        
#         self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim), requires_grad=False)  # fixed sin-cos embedding

#         self.in_blocks = nn.ModuleList([
#             Block(
#                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
#                 norm_layer=norm_layer, use_checkpoint=use_checkpoint)
#             for _ in range(depth // 2)])

#         self.norm = norm_layer(embed_dim)
#         self.mid_block = nn.Sequential(
#             Block(
#                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
#                 norm_layer=norm_layer, use_checkpoint=use_checkpoint)
#         )
#         # --------------------------------------------------------------------------

#         # --------------------------------------------------------------------------
#         # MAE decoder specifics
#         self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

#         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

#         self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

#         self.out_blocks = nn.ModuleList([
#             Block(
#                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
#                 norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
#             for _ in range(depth // 2)])

#         self.decoder_norm = norm_layer(decoder_embed_dim)
#         self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
#         # --------------------------------------------------------------------------

#         self.norm_pix_loss = norm_pix_loss

#         self.initialize_weights()

#     def initialize_weights(self):
#         # initialization
#         # initialize (and freeze) pos_embed by sin-cos embedding
#         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
#         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

#         decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
#         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

#         # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
#         w = self.patch_embed.proj.weight.data
#         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

#         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
#         torch.nn.init.normal_(self.cls_token, std=.02)
#         torch.nn.init.normal_(self.mask_token, std=.02)

#         # initialize nn.Linear and nn.LayerNorm
#         self.apply(self._init_weights)

#     def _init_weights(self, m):
#         if isinstance(m, nn.Linear):
#             # we use xavier_uniform following official JAX ViT:
#             torch.nn.init.xavier_uniform_(m.weight)
#             if isinstance(m, nn.Linear) and m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#         elif isinstance(m, nn.LayerNorm):
#             nn.init.constant_(m.bias, 0)
#             nn.init.constant_(m.weight, 1.0)

#     def patchify(self, imgs):
#         """
#         imgs: (N, 3, H, W)
#         x: (N, L, patch_size**2 *3)
#         """
#         p = self.patch_embed.patch_size[0]
#         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

#         h = w = imgs.shape[2] // p
#         x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
#         x = torch.einsum('nchpwq->nhwpqc', x)
#         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
#         return x

#     def unpatchify(self, x):
#         """
#         x: (N, L, patch_size**2 *3)
#         imgs: (N, 3, H, W)
#         """
#         p = self.patch_embed.patch_size[0]
#         h = w = int(x.shape[1]**.5)
#         assert h * w == x.shape[1]
        
#         x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
#         x = torch.einsum('nhwpqc->nchpwq', x)
#         imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
#         return imgs

#     def random_masking(self, x, mask_ratio):
#         """
#         Perform per-sample random masking by per-sample shuffling.
#         Per-sample shuffling is done by argsort random noise.
#         x: [N, L, D], sequence
#         """
#         N, L, D = x.shape  # batch, length, dim
#         len_keep = int(L * (1 - mask_ratio))
        
#         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
#         # sort noise for each sample
#         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
#         ids_restore = torch.argsort(ids_shuffle, dim=1)

#         # keep the first subset
#         ids_keep = ids_shuffle[:, :len_keep]
#         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

#         # generate the binary mask: 0 is keep, 1 is remove
#         mask = torch.ones([N, L], device=x.device)
#         mask[:, :len_keep] = 0
#         # unshuffle to get the binary mask
#         mask = torch.gather(mask, dim=1, index=ids_restore)

#         return x_masked, mask, ids_restore

#     def forward_encoder(self, x, timesteps, mask_ratio, y=None):
#         # embed patches
#         print(x.shape)
#         x = self.patch_embed(x)
#         print(x.shape)
#         # add pos embed w/o cls token
#         B, L, D = x.shape
# #         print(x)
#         time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
#         time_token = time_token.unsqueeze(dim=1)
#         x = torch.cat((time_token, x), dim=1)
#         if y is not None:
#             label_emb = self.label_emb(y)
#             label_emb = label_emb.unsqueeze(dim=1)
#             x = torch.cat((label_emb, x), dim=1)
# #             print("Yess")
#         x = x + self.pos_embed
#         print(x.shape)
#         # masking: length -> length * mask_ratio
#         x, mask, ids_restore = self.random_masking(x, mask_ratio)
#         print(x.shape)
#         print("mask", mask.shape)
#         skips = []
#         for blk in self.in_blocks:
#             x = blk(x)
#             skips.append(x)
# #         bf = x
#         x = self.mid_block(x)
#         x = self.norm(x)
#         print(x.shape)
#         return x, mask, ids_restore

#     def forward_decoder(self, x, ids_restore):
#         # embed tokens
#         print(x.shape)
#         x = self.decoder_embed(x)
#         print(x.shape)
#         # append mask tokens to sequence
#         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
#         print(mask_tokens.shape)
#         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
#         print(x_.shape)
#         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
#         print(x_.shape)
#         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
#         print(x.shape)

#         # add pos embed
#         x = x + self.decoder_pos_embed
#         print(x.shape)
#         # apply Transformer blocks
#         for blk in self.out_blocks:
#             x = blk(x)
#         x = self.decoder_norm(x)
#         print(x.shape)
#         # predictor projection
#         x = self.decoder_pred(x)
#         print(x.shape)
#         # remove cls token
#         x = x[:, 1:, :]
#         print(x.shape)
#         return x

#     def forward_loss(self, imgs, pred, mask):
#         """
#         imgs: [N, 3, H, W]
#         pred: [N, L, p*p*3]
#         mask: [N, L], 0 is keep, 1 is remove, 
#         """
#         target = self.patchify(imgs)
#         if self.norm_pix_loss:
#             mean = target.mean(dim=-1, keepdim=True)
#             var = target.var(dim=-1, keepdim=True)
#             target = (target - mean) / (var + 1.e-6)**.5

#         loss = (pred - target) ** 2
#         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

#         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
#         return loss

#     def forward(self, imgs, timesteps, mask_ratio=0.75, y=None):
#         latent, mask, ids_restore = self.forward_encoder(imgs, timesteps, mask_ratio, y=None)
#         pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
#         loss = self.forward_loss(imgs, pred, mask)
#         return loss, pred, mask


# def mae_vit_base_patch16_dec512d8b(**kwargs):
#     model = MaskedAutoencoderViT(
#         patch_size=16, embed_dim=768, depth=12, num_heads=12,
#         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
#         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model


# def mae_vit_large_patch16_dec512d8b(**kwargs):
#     model = MaskedAutoencoderViT(
#         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
#         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
#         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model


# def mae_vit_huge_patch14_dec512d8b(**kwargs):
#     model = MaskedAutoencoderViT(
#         patch_size=14, embed_dim=1280, depth=32, num_heads=16,
#         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
#         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model


# # set recommended archs
# mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
# mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
# mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

In [10]:
# model = mae_vit_base_patch16(img_size = 256)

In [11]:
# Random single-image batch plus a matching batch of timesteps, all equal to 1
# (int64, on the same device as the input).
x_test = torch.rand(1, 3, 32, 32)
t_batch = torch.ones(x_test.shape[0], dtype=torch.long, device=x_test.device)

In [12]:
# NOTE(review): redundant cell — it recomputes t_batch with the same value the
# previous cell already assigned; safe to delete.
t_batch = torch.tensor([1], device=x_test.device).repeat(x_test.shape[0])

In [13]:
# y_test = model(x_test, t_batch, 0.5)

In [None]:
# NOTE(review): `mprint` is not defined or imported in any visible cell and
# this cell has no execution count, so it will presumably raise NameError on
# Restart & Run All — define it or remove the cell.
mprint()

In [16]:
import maskdit

In [17]:
# Build the model via maskdit.DiT_B_2 with all-default arguments.
model = maskdit.DiT_B_2()

In [20]:
# Call the model on the test input and timestep batch. The third positional
# argument is passed as None; its meaning is defined in maskdit (not visible
# here) — TODO confirm.
y_test = model(x_test, t_batch, None)

In [22]:
# Show the shape of the 'x' entry of the model output.
# NOTE(review): the saved output reports 4 channels while x_test is created
# with 3, and the execution counts skip numbers — the output may stem from a
# different x_test; confirm with Restart & Run All.
y_test['x'].shape

torch.Size([1, 4, 32, 32])