In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ====== A simple shared encoder (toy version: just a small conv stack) ======
class SharedEncoder(nn.Module):
    """Toy convolutional encoder: three stride-2 3x3 convolutions, each followed by ReLU.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels of the produced feature map.
    """

    def __init__(self, in_ch=1, out_ch=256):
        super().__init__()
        # Channel progression in_ch -> 32 -> 64 -> out_ch; each stage uses
        # stride 2, so every stage roughly halves the spatial size.
        widths = (in_ch, 32, 64, out_ch)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
        # A single Sequential stored under ``encoder`` keeps the parameter
        # names as encoder.0 / encoder.2 / encoder.4.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` of shape (B, in_ch, H, W).

        Returns a tensor of shape (B, out_ch, H/8, W/8) when H and W are
        divisible by 8.
        """
        features = self.encoder(x)
        return features

# ====== 导入你写的 UNetFlow ======
from flow_matching_unet import UNetFlow

# ====== Minimal flow-matching wrapper example ======
class FlowMatchingWrapper(nn.Module):
    """Thin ``nn.Module`` that owns a ``UNetFlow`` and delegates ``forward`` to it.

    Args:
        in_channels: Forwarded to ``UNetFlow`` as ``in_channels``.
        cond_dim: Forwarded to ``UNetFlow`` as ``cond_dim``.
    """

    def __init__(self, in_channels=256, cond_dim=256):
        super().__init__()
        # Only in_channels and cond_dim are configurable; the remaining
        # UNetFlow settings are fixed for this example.
        unet_kwargs = dict(
            in_channels=in_channels,
            base_channels=64,
            time_emb_dim=128,
            depth=4,
            use_cross_attn=True,
            cond_dim=cond_dim,
        )
        self.unet = UNetFlow(**unet_kwargs)

    def forward(self, x_t, cond_feats, t, cond_mask=None):
        """Call the wrapped UNet and return its output unchanged.

        The arguments are handed over positionally in the order
        (x_t, cond_feats, t, cond_mask).
        NOTE(review): the result is expected to be the predicted velocity
        with shape (B, C, H, W) -- confirm against UNetFlow.forward.
        """
        v_pred = self.unet(x_t, cond_feats, t, cond_mask)
        return v_pred


# ====== Test pipeline ======
# Smoke test: encode a random image, perturb the latent, run the wrapped UNet
# once with random conditioning, and print the output shape and an MSE loss.
if __name__ == "__main__":
    B, C, H, W = 2, 1, 128, 128   # two samples, single-channel MRI slice
    N, C_cond = 3, 256           # 3 conditioning modalities (T1/T2/FLAIR), embedding dim = 256

    # Raw MRI image input (random data stands in for a real slice)
    img = torch.randn(B, C, H, W)  # toy MRI slice
    encoder = SharedEncoder(in_ch=1, out_ch=256)
    x_0 = encoder(img)             # (B, 256, H/8, W/8)

    # Build the flow-matching inputs
    # NOTE(review): t is an integer step in [0, 1000) and x_t below does not
    # depend on t -- this is a toy perturbation, not a real interpolation path.
    t = torch.randint(0, 1000, (B,))          # time step
    noise = torch.randn_like(x_0)
    x_t = x_0 + noise * 0.1                   # noised latent

    cond_feats = torch.randn(B, N, C_cond)    # conditioning-modality embeddings
    # The second sample is missing one modality (the 0 entry).
    # NOTE(review): presumably 1 = present, 0 = missing -- confirm against UNetFlow.
    cond_mask = torch.tensor([[1,1,1],[1,0,1]])  # second sample lacks one modality

    # Call the model
    model = FlowMatchingWrapper(in_channels=256, cond_dim=C_cond)
    v_pred = model(x_t, cond_feats, t, cond_mask)  # (B, 256, H/8, W/8)

    # Build the target vector field (toy only: the raw noise is used as the target)
    v_target = noise

    # Loss
    loss = F.mse_loss(v_pred, v_target)
    print("v_pred shape:", v_pred.shape)
    print("Loss:", loss.item())


v_pred shape: torch.Size([2, 256, 16, 16])
Loss: 1.1204676628112793
