In [1]:
from models.segformer_simple import EfficientSelfAttention
import torch
import torch.nn as nn

class EfficientVit(nn.Module):
    """ViT-style image classifier whose token mixers are ``EfficientSelfAttention`` layers.

    Pipeline: strided-conv patch embedding -> prepend a learnable class token
    -> add a learnable positional embedding -> dropout -> ``num_layers``
    attention blocks -> LayerNorm -> linear head applied to the class token.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input image the positional
            embedding is sized for; inputs of another size will not match it.
        patch_size: Side length of each square patch (conv kernel and stride).
        num_layers: Number of stacked ``EfficientSelfAttention`` blocks.
        num_heads: Forwarded to each block as ``num_heads``.
        embed_dim: Token embedding width (also the blocks' ``dim``).
        mlp_ratio: Accepted but not used by this class.
        qkv_bias: Forwarded to each block as ``qkv_bias``.
        qk_scale: Forwarded to each block as ``qk_scale``.
        drop_rate: Dropout after the positional embedding; also forwarded to
            each block as ``proj_drop``.
        attn_drop_rate: Forwarded to each block as ``attn_drop``.
        drop_path_rate: Accepted but not used by this class.
    """

    def __init__(self, num_classes, image_size, patch_size, num_layers,
                 num_heads, embed_dim, mlp_ratio, qkv_bias, qk_scale,
                 drop_rate, attn_drop_rate, drop_path_rate):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        # Non-overlapping patches: kernel size equals stride.
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One learnable position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList(
            EfficientSelfAttention(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop_rate,
                proj_drop=drop_rate,
            )
            for _ in range(num_layers)
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map images ``x`` of shape (batch, 3, H, W) to (batch, num_classes) logits."""
        # (B, E, H/P, W/P) -> (B, N, E) token sequence.
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = self.pos_drop(torch.cat([cls, tokens], dim=1) + self.pos_embed)

        # NOTE(review): each block is called with the token sequence only;
        # confirm EfficientSelfAttention.forward does not also require the
        # spatial H, W of the patch grid.
        for layer in self.blocks:
            tokens = layer(tokens)

        # Classify from the (normalised) class token at position 0.
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)
    
class SimpleVit(nn.Module):
    """Minimal ViT-style classifier built from stacked ``nn.MultiheadAttention`` layers.

    Pipeline: strided-conv patch embedding -> prepend a learnable class token
    -> add a learnable positional embedding -> dropout -> ``num_layers``
    self-attention layers -> LayerNorm -> linear head applied to the class token.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input image the positional
            embedding is sized for; inputs of another size will not match it.
        patch_size: Side length of each square patch (conv kernel and stride).
        num_layers: Number of stacked attention layers.
        num_heads: Attention heads per layer; must divide ``embed_dim``.
        embed_dim: Token embedding width.
        mlp_ratio: Accepted for signature parity; not used.
        qkv_bias: Whether each attention layer's input (q/k/v) and output
            projections carry a bias (``nn.MultiheadAttention``'s ``bias``).
        qk_scale: Accepted for signature parity; ignored, because
            ``nn.MultiheadAttention`` always scales by ``head_dim ** -0.5``.
        drop_rate: Dropout applied after adding the positional embedding.
        attn_drop_rate: Dropout on the attention weights.
        drop_path_rate: Accepted for signature parity; not used.
    """

    def __init__(self, num_classes, image_size, patch_size, num_layers, num_heads, embed_dim, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate):
        super(SimpleVit, self).__init__()
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One learnable position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, (image_size // patch_size) ** 2 + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # nn.MultiheadAttention has no `qkv_bias`/`qk_scale` kwargs (passing
        # them raised TypeError); `bias` is the equivalent switch.
        # batch_first=True because tokens are laid out as (batch, seq, embed).
        self.blocks = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=attn_drop_rate,
                bias=qkv_bias,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map images ``x`` of shape (batch, 3, image_size, image_size) to (batch, num_classes) logits."""
        x = self.patch_embed(x)                # (B, E, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)       # (B, N, E)
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # NOTE(review): layers are stacked without residual connections or
        # MLP sub-blocks, unlike a standard transformer encoder block.
        for block in self.blocks:
            # Self-attention: query = key = value. MultiheadAttention returns
            # (output, weights); only the output is carried forward.
            x, _ = block(x, x, x, need_weights=False)
        x = self.norm(x)
        x = x[:, 0]                            # class token
        x = self.head(x)
        return x

In [None]:
# Smoke test: build both models from the same hyper-parameters and print the
# logits shape each produces for a single 224x224 RGB image.
vit_config = dict(
    num_classes=10,
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    embed_dim=768,
    mlp_ratio=4,
    qkv_bias=True,
    # NOTE(review): if EfficientSelfAttention honours qk_scale, 1 disables
    # the usual head_dim ** -0.5 scaling — confirm this is intended.
    qk_scale=1,
    drop_rate=0.1,
    attn_drop_rate=0.1,
    drop_path_rate=0.1,
)
effi_model = EfficientVit(**vit_config)
simple_model = SimpleVit(**vit_config)
input_tensor = torch.randn(1, 3, 224, 224)

# Only shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    print(effi_model(input_tensor).shape)
    print(simple_model(input_tensor).shape)