In [None]:
import torch
from torch import nn

from einops import repeat
from einops.layers.torch import Rearrange

from vit_pytorch.vit import Transformer, Attention, FeedForward, PreNorm

class SiT(nn.Module):
    """Transformer model over a sequence of patches.

    Each patch (all channels and vertices of one patch) is flattened and
    linearly embedded, a learnable class token is prepended, a learnable
    positional embedding is added, and the sequence is passed through a
    ``Transformer`` (imported from ``vit_pytorch.vit``). The token sequence
    is then pooled and mapped to ``num_classes`` outputs by a
    LayerNorm + Linear head.

    Parameters (all keyword-only):
        dim (int) - - embedding size of every token
        depth, heads, dim_head, mlp_dim, dropout - - forwarded unchanged to
            ``Transformer``; presumably number of layers, attention heads,
            per-head width, feed-forward width and dropout rate
            (NOTE(review): confirm against the vit_pytorch version in use)
        pool (str) - - how the token sequence is reduced before the head:
            'cls' takes the class token, 'mean' / 'max' / 'sum' reduce over
            every token (class token included)
        num_patches (int) - - maximum number of patches per sample; sizes the
            positional embedding
        num_classes (int) - - size of the output vector
        num_channels (int) - - channels per vertex
        num_vertices (int) - - vertices per patch
        emb_dropout (float) - - dropout rate applied to the embedded tokens
    Input:
        B x num_channels x n x num_vertices tensor, with n <= num_patches
    Return:
        B x num_classes tensor
    """

    def __init__(self, *,
                        dim,
                        depth,
                        heads,
                        mlp_dim,
                        pool = 'cls',
                        num_patches = 20,
                        num_classes= 1,
                        num_channels =4,
                        num_vertices = 2145,
                        dim_head = 64,
                        dropout = 0.,
                        emb_dropout = 0.
                        ):

        super().__init__()

        assert pool in {'cls', 'mean', 'max', 'sum'}, \
            'pool type must be one of cls (cls token), mean, max or sum (pooling over tokens)'

        # One token per patch: every channel of every vertex is flattened.
        patch_dim = num_channels * num_vertices

        # inputs has size = b * c * n * v
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c n v  -> b n (v c)'),
            nn.Linear(patch_dim, dim),
        )

        # +1 position for the class token that is prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Embed the patches, run the transformer and return the head output."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # Broadcast the single learnable class token over the batch.
        cls_tokens = self.cls_token.expand(b, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)
        # Only the first n + 1 positions are used, so shorter sequences work.
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Every accepted pool mode has its own reduction; previously 'max'
        # and 'sum' silently behaved like 'cls'.
        if self.pool == 'mean':
            x = x.mean(dim=1)
        elif self.pool == 'max':
            x = x.max(dim=1).values
        elif self.pool == 'sum':
            x = x.sum(dim=1)
        else:
            x = x[:, 0]

        x = self.to_latent(x)

        return self.mlp_head(x)

In [None]:
class self_attention_layer_swin(nn.Module):
    """Patch-wise multi-head self-attention layer over mesh vertices.

    Vertices are grouped into fixed-size patches (16 vertices for the 'top'
    group, 19 for the 'down' group). Attention is computed independently
    inside every patch; the per-patch results are then gathered back to the
    vertices through ``neigh_orders['reverse']`` and weighted by
    ``1 / neigh_orders['count']``. A linear residual branch, a two-layer MLP
    and a batch norm complete the layer.

    Parameters:
            in_feats (int) - - input features/channels
            out_feats (int) - - output features/channels
            neigh_orders (dict) - - index tables with the keys
                'top'     flat vertex indices, length a multiple of 16
                'down'    flat vertex indices, length a multiple of 19
                'reverse' per-vertex indices into the concatenated patch
                          outputs; the index one past the last patch row
                          selects an all-zero padding row
                'count'   per-vertex divisor applied to the gathered sum
                (presumably the number of patches containing each vertex -
                TODO confirm against the code that builds neigh_orders)
            neigh_orders_2 - - accepted for compatibility, not used
            head_dim (int) - - channels per head; num_heads = in_feats // head_dim
            qkv_bias (bool) - - If True, add a learnable bias to query, key, value. Default: True
            qk_scale (float) - - Override default qk scale of head_dim ** -0.5 if set
            sep_process (bool) - - must be True; no other mode is implemented
            drop_rate (float) - - dropout probability; None means 0.0
    Input:
        B x in_feats x N tensor
    Return:
        B x out_feats x N tensor
    """

    def __init__(self, in_feats, out_feats, neigh_orders, neigh_orders_2=None, head_dim=8,
        qkv_bias=True, qk_scale=None, sep_process=True, drop_rate=None):
        super().__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats

        # Number of vertices per patch in the two patch groups.
        self.top = 16
        self.down = 19

        # Index tables live in non-persistent buffers: they follow the module
        # across devices but stay out of state_dict, so existing checkpoints
        # still load.
        self.register_buffer(
            'neigh_orders_top',
            torch.as_tensor(neigh_orders['top'], dtype=torch.long).reshape(-1, self.top),
            persistent=False)
        self.register_buffer(
            'neigh_orders_down',
            torch.as_tensor(neigh_orders['down'], dtype=torch.long).reshape(-1, self.down),
            persistent=False)
        self.register_buffer(
            'reverse_matrix',
            torch.as_tensor(neigh_orders['reverse'], dtype=torch.long),
            persistent=False)
        # The reciprocal is taken in the input's precision and only then cast
        # to float32; done with torch so the block does not depend on a
        # module-level numpy import.
        self.cnt_matrix = nn.parameter.Parameter(
            torch.as_tensor(1 / neigh_orders['count']).to(torch.float32),
            requires_grad=False)

        # Appends one all-zero row after the last patch row; 'reverse' points
        # at it for unused slots.
        self.padding = nn.ZeroPad2d((0, 0, 0, 1))

        # NOTE(review): if in_feats is not a multiple of head_dim, the head
        # width used in forward (C // num_heads) differs from the head_dim
        # used for the scale below - confirm callers always pass a multiple.
        self.num_heads = in_feats // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.sep_process = sep_process

        self.drop_rate = drop_rate or 0.0

        self.qkv = nn.Linear(in_feats, in_feats * 3, bias=qkv_bias)
        self.proj = nn.Sequential(
            nn.Linear(in_feats, out_feats),
            nn.Dropout(p=self.drop_rate, inplace=True)
            )
        self.residual = nn.Linear(in_feats, out_feats)

        # NOTE(review): there is no non-linearity between the two Linear
        # layers; left unchanged so trained weights keep their meaning.
        mlp_ratio = 2.00
        self.mlp = nn.Sequential(
            nn.Linear(out_feats, int(out_feats * mlp_ratio)),
            nn.Dropout(p=self.drop_rate, inplace=True),
            nn.Linear(int(out_feats * mlp_ratio), out_feats),
            nn.Dropout(p=self.drop_rate, inplace=True)
        )
        self.norm = nn.BatchNorm1d(out_feats, momentum=0.15, affine=True, track_running_stats=False)
        self.softmax = nn.Softmax(dim=-1)

    def _patch_attention(self, q, k, v, patch_index):
        """Softmax attention restricted to the vertices of each patch.

        q, k, v are B x heads x N x head_width; patch_index is
        patches x patch_size. Returns B x heads x patches x patch_size x head_width.
        """
        attn = torch.einsum("binjc,binkc->binjk", q[:, :, patch_index], k[:, :, patch_index])
        attn = self.softmax(attn)
        return torch.einsum("binjk,binkc->binjc", attn, v[:, :, patch_index])

    def forward(self, x):
        """Apply patch-wise attention to a B x in_feats x N tensor."""
        if not self.sep_process:
            # Without this guard the code below failed later with an
            # unrelated-looking permute error.
            raise NotImplementedError(
                'self_attention_layer_swin only implements sep_process=True')

        x = x.permute(0, 2, 1)
        res = self.residual(x)

        B, N, C = x.shape  # batch size x number of vertices x channel
        head_width = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_width).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale

        x_top = self._patch_attention(q, k, v, self.neigh_orders_top)
        x_down = self._patch_attention(q, k, v, self.neigh_orders_down)

        # Flatten (patch, slot) into one row axis, 'top' rows first, then
        # 'down' rows, then the zero padding row.
        x = torch.cat((x_top.reshape(B, self.num_heads, -1, head_width),
                       x_down.reshape(B, self.num_heads, -1, head_width)), dim=2)
        x = self.padding(x)

        # Gather every vertex's rows, weight by 1 / count (broadcast over the
        # trailing vertex axis) and sum the gathered rows per vertex.
        x = x[:, :, self.reverse_matrix, :].permute((0, 1, 3, 4, 2)) * self.cnt_matrix
        x = x.permute((0, 1, 4, 2, 3)).sum(dim=3)

        # Merge the heads back into one channel axis: B x N x C.
        x = x.permute(0, 2, 1, 3).reshape(B, N, -1)

        out_features = self.proj(x) + res
        res2 = self.mlp(out_features)
        # BatchNorm1d normalises over the channel axis, so go back to B x C x N.
        out_features = out_features.permute(0, 2, 1)
        res2 = res2.permute(0, 2, 1)
        out_features = out_features + self.norm(res2)

        return out_features