In [19]:
import torch
from torch import nn, einsum
from einops import rearrange
from monai.networks.blocks import TransformerBlock
import torch
import torch.nn as nn
from monai.networks.blocks import ChannelSELayer
from loguru import logger
import monai
import torch.nn.functional as F
# Paper: DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
# Paper URL: https://arxiv.org/abs/2106.06716


# class PreNorm(nn.Module):
#     def __init__(self, dim, fn):
#         super().__init__()
#         self.norm = nn.LayerNorm(dim)
#         self.fn = fn
        
#     def forward(self, x, **kwargs):
#         return self.fn(self.norm(x), **kwargs)


# class FeedForward(nn.Module):
#     def __init__(self, dim, hidden_dim, dropout=0.):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(dim, hidden_dim),
#             nn.GELU(),
#             nn.Dropout(dropout),
#             nn.Linear(hidden_dim, dim),
#             nn.Dropout(dropout)
#         )
        
#     def forward(self, x):
#         return self.net(x)


# class Attention(nn.Module):
#     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
#         super().__init__()
#         inner_dim = dim_head * heads
#         project_out = not (heads == 1 and dim_head == dim)

#         self.heads = heads
#         self.scale = dim_head ** -0.5
#         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, dim),
#             nn.Dropout(dropout)
#         ) if project_out else nn.Identity()

#     def forward(self, x):
#         b, n, _ = x.shape
#         h = self.heads
#         qkv = self.to_qkv(x).chunk(3, dim=-1)
#         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

#         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
#         attn = dots.softmax(dim=-1)
#         out = einsum('b h i j, b h j d -> b h i d', attn, v)
#         out = rearrange(out, 'b h n d -> b n (h d)')
#         return self.to_out(out)


# class Transformer(nn.Module):
#     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
#         super().__init__()
#         self.layers = nn.ModuleList([])
#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
#                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
#             ]))

#     def forward(self, x):
#         for attn, ff in self.layers:
#             x = attn(x) + x
#             x = ff(x) + x
#         return x


# class CrossAttention(nn.Module):
#     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
#         super().__init__()
#         inner_dim = dim_head * heads
#         project_out = not (heads == 1 and dim_head == dim)

#         self.heads = heads
#         self.scale = dim_head ** -0.5
#         self.to_k = nn.Linear(dim, inner_dim, bias=False)
#         self.to_v = nn.Linear(dim, inner_dim, bias=False)
#         self.to_q = nn.Linear(dim, inner_dim, bias=False)
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, dim),
#             nn.Dropout(dropout)
#         ) if project_out else nn.Identity()

#     def forward(self, x_qkv):
#         b, n, _ = x_qkv.shape
#         h = self.heads

#         k = self.to_k(x_qkv)
#         k = rearrange(k, 'b n (h d) -> b h n d', h=h)

#         v = self.to_v(x_qkv)
#         v = rearrange(v, 'b n (h d) -> b h n d', h=h)

#         q = self.to_q(x_qkv[:, 0].unsqueeze(1))
#         q = rearrange(q, 'b n (h d) -> b h n d', h=h)

#         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
#         attn = dots.softmax(dim=-1)
#         out = einsum('b h i j, b h j d -> b h i d', attn, v)
#         out = rearrange(out, 'b h n d -> b n (h d)')
#         return self.to_out(out)


class TIF(nn.Module):
    def __init__(self, in_channels_pathology, in_channels_anatomy):
        super().__init__()
        self.transformer_pathology = TransformerBlock(
            hidden_size=in_channels_pathology,
            num_heads=8,
            mlp_dim=in_channels_pathology * 4,
            qkv_bias=False,
            with_cross_attention=True,
            use_flash_attention=True,
            dropout_rate=0.1
        )
        self.transformer_anatomy = TransformerBlock(
            hidden_size=in_channels_anatomy,
            num_heads=8,
            mlp_dim=in_channels_anatomy * 4,
            qkv_bias=False,
            with_cross_attention=True,
            use_flash_attention=True,
            dropout_rate=0.1
        )
        self.norm_pathology = nn.LayerNorm(in_channels_pathology)
        self.norm_anatomy = nn.LayerNorm(in_channels_anatomy)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_pathology = nn.Linear(in_channels_pathology, in_channels_anatomy)
        self.linear_anatomy = nn.Linear(in_channels_anatomy, in_channels_pathology)

    def forward(self, pathology, anatomy):
        b, c_p, d, h, w = pathology.shape
        # Reshape to [B, C, D*H*W] then permute to [B, D*H*W, C]
        pathology = pathology.reshape(b, c_p, -1).permute(0, 2, 1)
        
        _, c_a, _, _, _ = anatomy.shape
        anatomy = anatomy.reshape(b, c_a, -1).permute(0, 2, 1)
        print(f"pathology.shape: {pathology.shape}")
        print(f"anatomy.shape: {anatomy.shape}")
        # Get token embeddings
        p_token = torch.flatten(self.avgpool(self.norm_anatomy(pathology).transpose(1, 2)), 1)
        a_token = torch.flatten(self.avgpool(self.norm_pathology(anatomy).transpose(1, 2)), 1)
        print(f"p_token.shape: {p_token.shape}")
        print(f"a_token.shape: {a_token.shape}")
        p_token = self.linear_anatomy(p_token).unsqueeze(1)
        a_token = self.linear_pathology(a_token).unsqueeze(1)
        
        # Cross attention
        anatomy = self.transformer_anatomy(torch.cat([p_token, anatomy], dim=1))[:, 1:, :]
        pathology = self.transformer_pathology(torch.cat([a_token, pathology], dim=1))[:, 1:, :]
        
        # Reshape back to 3D
        pathology = pathology.permute(0, 2, 1).reshape(b, c_p, d, h, w)
        anatomy = anatomy.permute(0, 2, 1).reshape(b, c_a, d, h, w)
        return pathology + anatomy


if __name__ == '__main__':
    model = TIF(in_channels_pathology=1024, in_channels_anatomy=1024)
    input1 = torch.randn(1, 1024, 3,3,3)  # Example: small-scale feature image
    input2 = torch.randn(1, 1024, 3,3,3)  # Example: large-scale feature image
    
    # Forward pass to get output
    output = model(input1, input2)

    # Print input and output shapes
    print(input1.size())
    print(input2.size())
    print(output.size())

pathology.shape: torch.Size([1, 27, 1024])
anatomy.shape: torch.Size([1, 27, 1024])
p_token.shape: torch.Size([1, 1024])
a_token.shape: torch.Size([1, 1024])
torch.Size([1, 1024, 3, 3, 3])
torch.Size([1, 1024, 3, 3, 3])
torch.Size([1, 1024, 3, 3, 3])
