In [None]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size equals its stride maps every
    ``patch_size`` x ``patch_size`` patch to one ``embed``-dimensional vector.

    Args:
        img_size: Side length of the input image; used only to compute the
            number of patches (assumes a square image).
        patch_size: Side length of each square patch.
        in_chans: Number of channels of the input image.
        embed: Size of each patch embedding vector.
    """

    def __init__(self, img_size, patch_size=16, in_chans=3, embed=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Misspelled attribute kept as an alias: existing code reads it.
        self.num_pathces = self.num_patches
        self.proj = nn.Conv2d(
            in_chans, embed, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_chans, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches, embed)``.
        """
        x = self.proj(x)       # (batch, embed, height/patch, width/patch)
        x = x.flatten(2)       # (batch, embed, num_patches)
        x = x.transpose(1, 2)  # (batch, num_patches, embed)
        return x

In [None]:
class attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: Embedding size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the reshape in forward() fails with an opaque error.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale dot products by 1/sqrt(head_dim) before the softmax.
        self.scale = head_dim ** -0.5

        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.att_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        # Must be a module: forward() calls it on the projected output.
        self.proj_dropout = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention weights per head: (B, heads, N, N).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.att_drop(attn)

        # Merge the heads back into a single embedding per token.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_dropout(x)

        return x


In [None]:
class transformerblock(nn.Module):
    """Transformer encoder block: self-attention then an MLP, each pre-normalised.

    Both sub-layers apply LayerNorm to their input and add their output back
    onto the input (residual connection).

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention query/key/value projection has a bias.
        drop: Dropout probability used in the MLP and after the attention
            output projection.
        attn_droup: Dropout probability applied to the attention weights.
        act_layer: Activation for the MLP, given as a module class (it is
            instantiated here) or as an already built module instance.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_droup=0., act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = attention(dim, num_heads, qkv_bias, attn_droup, drop)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # nn.Sequential only accepts module instances, so build the
        # activation when a class was passed.
        activation = act_layer() if isinstance(act_layer, type) else act_layer
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            activation,
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
class visiontransformer(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into patches and embedded, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through a stack of transformer blocks. The normalised class token
    is fed to a linear classification head.

    Args:
        img_size: Side length of the input images.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input images.
        num_classes: Number of output classes.
        embed_dim: Embedding size of each token.
        dept: Number of transformer blocks.
        num_heads: Number of attention heads per block.
        mlp_ratio: Hidden size of each block's MLP as a multiple of ``embed_dim``.
        qkv_bias: Whether the attention query/key/value projections have a bias.
        drop_rate: Dropout probability for embeddings, MLPs and attention output.
        attn_drop_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, img_size=192, patch_size=16, in_channels=3, num_classes=3, embed_dim=768, dept=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_channels, embed=embed_dim)
        num_patches = self.patch_embed.num_pathces
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position embedding per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            transformerblock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             qkv_bias=qkv_bias, drop=drop_rate, attn_droup=attn_drop_rate)
            for _ in range(dept)
        ])

        self.Norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every sub-module via ``apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias. Other module types
        are left with their default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with the class logits.
        """
        B = x.shape[0]
        x = self.patch_embed(x)  # (batch, num_patches, embed_dim)

        cls_token = self.cls_token.expand(B, -1, -1)  # (batch, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)  # (batch, num_patches + 1, embed_dim)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)
        x = self.Norm(x)
        # Only the class token (position 0) is used for classification.
        x = x[:, 0]
        x = self.head(x)
        return x

