## Download the pretrained model from Hugging Face


In [1]:
# !wget https://huggingface.co/facebook/sapiens-pretrain-0.3b/resolve/main/sapiens_0.3b_epoch_1600_clean.pth


## Define model


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Patchify-and-project is a single ``Conv2d`` whose kernel size and stride
    both equal ``patch_size``; its output grid is flattened into a token
    sequence.

    Args:
        img_size: Side length of the square image the patch grid is sized for.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Dimension of each patch token.
        dropout_rate: Dropout probability applied to the token sequence.
    """

    def __init__(self, img_size=1024, patch_size=16, in_chans=3, embed_dim=1024, dropout_rate=0.1):
        super().__init__()
        grid = img_size // patch_size
        self.img_size, self.patch_size = img_size, patch_size
        self.grid_size = grid
        self.num_patches = grid ** 2
        # Attribute names below become state-dict keys (e.g. "projection.weight"); keep them.
        self.projection = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Turn an image batch into patch tokens.

        Args:
            x: Image batch, assumed (B, in_chans, H, W) since it feeds a Conv2d.

        Returns:
            Tensor of shape (B, (H // patch_size) * (W // patch_size), embed_dim).
        """
        feats = self.projection(x)  # (B, embed_dim, H/P, W/P)
        tokens = feats.flatten(start_dim=2).transpose(1, 2)  # row-major grid -> (B, N, embed_dim)
        return self.dropout(tokens)

class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        embed_dim: Token dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).
        dropout_rate: Dropout probability applied to the attention weights and
            to the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=0.1):
        super().__init__()
        # Fail fast here: with a non-divisible split the reshape in forward()
        # cannot fit 3 * embed_dim features into 3 * num_heads * head_dim slots
        # and would only fail later with an opaque shape error.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim) dot-product scaling

        # Attribute names are state-dict keys; keep them stable.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply self-attention over the token axis.

        Args:
            x: Tokens of shape (B, N, C) with C == embed_dim (unpacked below).

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))

class FFN(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The two linear layers live in a nested ``ModuleList`` only so their
    parameter names come out as ``layers.0.0.*`` and ``layers.1.*``, the
    state-dict keys the pretrained checkpoint is loaded with. Keep this layout.

    Args:
        in_dim: Input and output feature dimension.
        hidden_dim: Width of the hidden layer.
        dropout_rate: Dropout probability after the activation and after the
            output projection.
    """

    def __init__(self, in_dim, hidden_dim, dropout_rate=0.1):
        super().__init__()
        fc_in = nn.Linear(in_dim, hidden_dim)   # registered as layers.0.0
        fc_out = nn.Linear(hidden_dim, in_dim)  # registered as layers.1
        self.layers = nn.ModuleList([nn.ModuleList([fc_in]), fc_out])
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout_rate)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        (fc_in,), fc_out = self.layers
        hidden = self.drop1(self.act(fc_in(x)))
        return self.drop2(fc_out(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm encoder layer: ``x + attn(LN(x))`` followed by ``x + ffn(LN(x))``.

    Args:
        embed_dim: Token dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the FFN as a multiple of ``embed_dim``.
        dropout_rate: Dropout probability used by the sub-layers and by
            ``drop_path``.

    Note:
        ``drop_path`` is an ordinary element-wise ``nn.Dropout`` applied to both
        residual branches, not per-sample stochastic depth; like all dropout it
        is inactive in ``eval()`` mode.

    NOTE(review): both LayerNorms use PyTorch's default eps; confirm this
    matches the eps the pretrained weights were trained with.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
        super().__init__()
        ffn_width = int(embed_dim * mlp_ratio)
        # Registration order (ln1, attn, ln2, ffn) fixes parameter-init order;
        # attribute names are state-dict keys.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, num_heads, dropout_rate)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, ffn_width, dropout_rate)
        self.drop_path = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Run attention then FFN, each as a pre-normalized residual branch."""
        attn_branch = self.attn(self.ln1(x))
        x = x + self.drop_path(attn_branch)
        ffn_branch = self.ffn(self.ln2(x))
        return x + self.drop_path(ffn_branch)

class SapiensEncoder(nn.Module):
    """ViT-style encoder: patch embedding + class token + transformer stack.

    Args:
        img_size: Side length of the square image the positional embeddings
            are sized for.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        dropout_rate: Dropout probability used throughout the model.
    """

    def __init__(self, img_size=1024, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, dropout_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim, dropout_rate)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One entry for the class token plus one per patch of the img_size grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout_rate)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, 4.0, dropout_rate) for _ in range(depth)
        ])
        self.ln1 = nn.LayerNorm(embed_dim)  # final norm over all tokens

    def _resized_pos_embed(self, grid_h, grid_w):
        """Return ``pos_embed`` adapted to a ``grid_h`` x ``grid_w`` patch grid.

        The stored table covers a square ``grid_size`` x ``grid_size`` grid,
        assumed to be in row-major order (h * grid_size + w), the order in which
        PatchEmbedding flattens its conv output. If the requested grid matches,
        the parameter is returned unchanged; otherwise the patch entries are
        resampled bicubically and the class-token entry is kept as is.
        """
        base = self.patch_embed.grid_size
        if (grid_h, grid_w) == (base, base):
            return self.pos_embed
        cls_pos, patch_pos = self.pos_embed[:, :1], self.pos_embed[:, 1:]
        dim = patch_pos.shape[-1]
        patch_pos = patch_pos.reshape(1, base, base, dim).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(patch_pos, size=(grid_h, grid_w), mode="bicubic", align_corners=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)
        return torch.cat((cls_pos, patch_pos), dim=1)

    def forward(self, x):
        """Encode an image batch into a token sequence.

        Args:
            x: Image batch, assumed (B, in_chans, H, W). H and W need not equal
                ``img_size``: when the patch grid differs from the one the
                positional embeddings were built for, they are resampled.

        Returns:
            Tensor of shape (B, 1 + (H // patch_size) * (W // patch_size),
            embed_dim): the class token followed by the patch tokens, after
            the final LayerNorm.
        """
        B = x.shape[0]
        patch = self.patch_embed.patch_size
        # Output grid of a Conv2d with kernel_size == stride == patch_size.
        grid_h, grid_w = x.shape[-2] // patch, x.shape[-1] // patch
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self._resized_pos_embed(grid_h, grid_w)
        x = self.pos_drop(x)

        for block in self.layers:
            x = block(x)

        return self.ln1(x)

def vit_base_patch16_1024(dropout_rate=0.1):
    """Build a SapiensEncoder for 1024x1024 RGB inputs with 16x16 patches.

    The architecture is fixed at embed_dim=1024, depth=24, num_heads=16. This
    is the configuration the Sapiens 0.3B pretrain checkpoint is loaded into
    in this notebook.

    NOTE(review): despite "base" in the name, these sizes are larger than the
    usual ViT-Base (768 dim / 12 layers / 12 heads); the name is kept for
    compatibility with existing callers.

    Args:
        dropout_rate: Dropout probability forwarded to every dropout layer.

    Returns:
        A freshly initialised ``SapiensEncoder``.
    """
    arch = dict(
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
    )
    return SapiensEncoder(dropout_rate=dropout_rate, **arch)

# Example of setting a custom dropout rate
model = vit_base_patch16_1024(dropout_rate=0.2)  # Set dropout to 0.2 for the entire model

# To check all parameters and their shapes
# for name, param in model.named_parameters():
#     print(f"{name}: {param.shape}")

torch.manual_seed(42)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the state dictionary. map_location="cpu" makes loading independent of the
# device the checkpoint was saved from; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions accept weights_only=True to restrict this).
state_dict = torch.load("sapiens_0.3b_epoch_1600_clean.pth", map_location="cpu")
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # switches every Dropout layer off for inference
inputs = torch.ones(1, 3, 1024, 1024, device=device)
with torch.no_grad():
    outputs = model(inputs)
    print(outputs.shape)
    print(outputs)

# Run the TorchScript export (presumably of the same checkpoint) on the same input.
torch.cuda.empty_cache()
ts_model = torch.jit.load("sapiens_0.3b_epoch_1600_torchscript.pt2", map_location=device)
ts_model.eval()
# Inference only: no_grad avoids recording an autograd graph (and keeping every
# activation alive) in case the loaded parameters require grad.
with torch.no_grad():
    ts_output = ts_model(inputs)
ts_output

## Load pretrained weights


In [12]:
import torch
torch.manual_seed(42)
# Load the state dictionary. map_location="cpu" avoids depending on the device the
# checkpoint was saved from; load_state_dict copies the tensors into the model's
# parameters on whatever device they currently live.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions accept weights_only=True to restrict this).
state_dict = torch.load("sapiens_0.3b_epoch_1600_clean.pth", map_location="cpu")
model.load_state_dict(state_dict)

<All keys matched successfully>

In [13]:
# Fall back to CPU so this cell still runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Allocate directly on the target device instead of creating on CPU and copying.
inputs = torch.ones(1, 3, 1024, 1024, device=device)
model.eval()  # switches every Dropout layer off for inference
with torch.no_grad():
    outputs = model(inputs)
    print(outputs.shape)
    print(outputs)

torch.Size([1, 4097, 1024])


tensor([[[ 0.4483,  0.1892,  0.3767,  ..., -0.2551, -0.3712,  0.3022],
         [ 0.4364,  0.1943,  0.3770,  ..., -0.2410, -0.3435,  0.2985],
         [ 0.4363,  0.1940,  0.3770,  ..., -0.2411, -0.3436,  0.2986],
         ...,
         [ 0.4379,  0.1944,  0.3791,  ..., -0.2385, -0.3422,  0.2984],
         [ 0.4379,  0.1945,  0.3790,  ..., -0.2384, -0.3422,  0.2984],
         [ 0.4378,  0.1945,  0.3789,  ..., -0.2383, -0.3421,  0.2984]]],
       device='cuda:0')


In [14]:
# !wget https://huggingface.co/facebook/sapiens-pretrain-0.3b-torchscript/resolve/main/sapiens_0.3b_epoch_1600_torchscript.pt2

## Load the TorchScript version of the weights for efficient inference


In [3]:
import torch
# Fall back to CPU so this cell still runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
# map_location loads the weights straight onto the target device.
ts_model = torch.jit.load("sapiens_0.3b_epoch_1600_torchscript.pt2", map_location=device)
ts_model.eval()
inputs = torch.ones(1, 3, 1024, 1024, device=device)
# Inference only: no_grad avoids recording an autograd graph (and keeping every
# activation alive) in case the loaded parameters require grad.
with torch.no_grad():
    ts_output = ts_model(inputs)
ts_output

(tensor([[[[ 1.2733e-01,  4.7221e-02,  5.2698e-02,  ..., -1.0587e-01,
            -1.0880e-01, -2.7345e-02],
           [ 2.4000e-01,  1.5818e-01,  1.4922e-01,  ...,  1.4420e-02,
             5.9127e-03,  5.7294e-02],
           [ 3.3432e-01,  2.4099e-01,  2.2941e-01,  ...,  1.0003e-01,
             9.6673e-02,  1.4691e-01],
           ...,
           [ 1.4163e-01,  3.4757e-02,  2.8229e-02,  ..., -3.0150e-02,
            -3.9315e-02,  2.2697e-02],
           [ 2.0953e-01,  1.0161e-01,  9.8985e-02,  ...,  2.3701e-02,
             1.6074e-02,  7.9371e-02],
           [ 3.8715e-01,  2.9326e-01,  2.7447e-01,  ...,  1.9486e-01,
             1.8994e-01,  2.3577e-01]],
 
          [[-8.6676e-02, -1.3280e-03,  4.7174e-02,  ...,  6.8181e-02,
             1.6999e-02, -1.7683e-02],
           [ 6.6773e-02,  1.2355e-01,  1.6407e-01,  ...,  2.0209e-01,
             1.4679e-01,  1.4856e-01],
           [ 1.3793e-01,  1.9407e-01,  2.3197e-01,  ...,  2.7402e-01,
             2.2278e-01,  2.1753e-01],
