In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ====== 一个简单的 Shared Encoder （toy 版本，随便卷积一下） ======
class SharedEncoder(nn.Module):
    def __init__(self, in_ch=1, out_ch=256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, out_ch, 3, padding=1, stride=2),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.encoder(x)  # (B, out_ch, H/8, W/8)

# ====== 导入你写的 UNetFlow ======
from flow_matching_unet import UNetFlow

# ====== Flow Matching 框架最小示例 ======
class FlowMatchingWrapper(nn.Module):
    def __init__(self, in_channels=256, cond_dim=256):
        super().__init__()
        self.unet = UNetFlow(
            in_channels=in_channels,
            base_channels=64,
            time_emb_dim=128,
            depth=4,
            use_cross_attn=True,
            cond_dim=cond_dim
        )

    def forward(self, x_t, cond_feats, t, cond_mask=None):
        # v_pred: (B, C, H, W)
        return self.unet(x_t, cond_feats, t, cond_mask)


# ====== 测试 pipeline ======
if __name__ == "__main__":
    B, C, H, W = 2, 1, 128, 128   # 两个样本，单通道 MRI slice
    N, C_cond = 3, 256           # 条件模态数 3 (T1/T2/FLAIR)，embedding dim = 256

    # 输入原始 MRI 图像
    img = torch.randn(B, C, H, W)  # toy MRI slice
    encoder = SharedEncoder(in_ch=1, out_ch=256)
    x_0 = encoder(img)             # (B, 256, H/8, W/8)

    # 构造 flow matching 输入
    t = torch.randint(0, 1000, (B,))          # 时间步
    noise = torch.randn_like(x_0)
    x_t = x_0 + noise * 0.1                   # 加噪声的 latent

    cond_feats = torch.randn(B, N, C_cond)    # 条件模态 embedding
    cond_mask = torch.tensor([[1,1,1],[1,0,1]])  # 第二个样本缺一个模态

    # 调用模型
    model = FlowMatchingWrapper(in_channels=256, cond_dim=C_cond)
    v_pred = model(x_t, cond_feats, t, cond_mask)  # (B, 256, H/8, W/8)

    # 构造 target 向量场（这里只是 toy）
    v_target = noise

    # 损失
    loss = F.mse_loss(v_pred, v_target)
    print("v_pred shape:", v_pred.shape)
    print("Loss:", loss.item())


v_pred shape: torch.Size([2, 256, 16, 16])
Loss: 1.1204676628112793


In [None]:
"""
本地CPU测试Demo - FusionTransformer + FlowMatchingUNet
测试两个模块的集成和功能
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, Any
import time
import os
import sys

# 检查文件是否存在
required_files = ['fusion_transformer.py', 'flow_matching_unet.py']
for file in required_files:
    if not os.path.exists(file):
        raise FileNotFoundError(f"Required file '{file}' not found in current directory!")

# 导入自定义模块
try:
    from fusion_transformer import FusionTransformer
    from flow_matching_unet import UNetFlow
    print("✅ Successfully imported FusionTransformer and UNetFlow")
except ImportError as e:
    print(f"❌ Import error: {e}")
    sys.exit(1)

class MRIFlowMatchingDemo:
    """MRI Flow Matching 完整测试Demo"""
    
    def __init__(self, device='cpu'):
        self.device = torch.device(device)
        self.setup_test_config()
        self.create_models()
        
    def setup_test_config(self):
        """设置测试配置"""
        # 模拟MRI数据配置
        self.batch_size = 2
        self.img_height, self.img_width = 64, 64  # 小尺寸便于CPU测试
        self.in_channels = 256  # encoder输出通道数
        self.n_modalities = 4   # T1, T2, FLAIR, DWI
        
        # FusionTransformer配置
        self.C_tok = 256        # token特征维度
        self.C_cond = 128       # 条件特征维度
        
        # UNetFlow配置
        self.base_channels = 64
        self.time_emb_dim = 128
        self.depth = 3          # UNet深度，CPU友好
        
        # 选择器配置
        self.selector_config = {
            "mode": "auto",
            "ratio_shallow": 0.6,
            "ratio_deep": 0.3,
            "temp": 1.0
        }
        
    def create_models(self):
        """创建模型实例"""
        print("\n🔧 Creating models...")
        
        # 创建FusionTransformer
        self.fusion_transformer = FusionTransformer(
            C_tok=self.C_tok,
            C_cond=self.C_cond,
            n_heads=4,
            n_layers=2,
            dropout=0.05,
            max_N=8
        ).to(self.device)
        
        # 创建UNetFlow
        self.unet_flow = UNetFlow(
            in_channels=self.in_channels,
            base_channels=self.base_channels,
            time_emb_dim=self.time_emb_dim,
            depth=self.depth,
            use_cross_attn=True,
            cond_dim=self.C_cond,
            kv_dim=None,  # 默认与dim相同
            selector_config=self.selector_config
        ).to(self.device)
        
        print(f"✅ Models created on device: {self.device}")
        self.print_model_info()
        
    def print_model_info(self):
        """打印模型信息"""
        fusion_params = sum(p.numel() for p in self.fusion_transformer.parameters())
        unet_params = sum(p.numel() for p in self.unet_flow.parameters())
        
        print(f"📊 FusionTransformer parameters: {fusion_params:,}")
        print(f"📊 UNetFlow parameters: {unet_params:,}")
        print(f"📊 Total parameters: {fusion_params + unet_params:,}")
        
    def generate_test_data(self):
        """生成测试数据"""
        print("\n📁 Generating test data...")
        
        # 模拟来自shared encoder的token特征
        # 每个模态压缩成一个token向量
        tokens = torch.randn(
            self.batch_size, 
            self.n_modalities, 
            self.C_tok, 
            device=self.device
        )
        
        # 条件掩码：模拟某些模态缺失的情况
        cond_mask = torch.ones(self.batch_size, self.n_modalities, device=self.device)
        # 模拟第一个样本缺失最后一个模态
        cond_mask[0, -1] = 0
        
        # 模拟输入特征图 (来自shared encoder)
        x_t = torch.randn(
            self.batch_size, 
            self.in_channels, 
            self.img_height, 
            self.img_width,
            device=self.device
        )
        
        # 时间步
        t = torch.rand(self.batch_size, device=self.device) * 1000
        
        return {
            'tokens': tokens,
            'cond_mask': cond_mask,
            'x_t': x_t,
            't': t
        }
        
    def test_fusion_transformer(self, test_data):
        """测试FusionTransformer"""
        print("\n🧪 Testing FusionTransformer...")
        
        tokens = test_data['tokens']
        cond_mask = test_data['cond_mask']
        
        print(f"Input tokens shape: {tokens.shape}")
        print(f"Input cond_mask shape: {cond_mask.shape}")
        print(f"Cond_mask content:\n{cond_mask}")
        
        # 前向传播
        start_time = time.time()
        with torch.no_grad():
            cond_feats = self.fusion_transformer(tokens, cond_mask)
        end_time = time.time()
        
        if cond_feats is not None:
            print(f"✅ Output cond_feats shape: {cond_feats.shape}")
            print(f"⏱️  Forward time: {end_time - start_time:.4f}s")
            print(f"📊 Output stats - Mean: {cond_feats.mean().item():.4f}, Std: {cond_feats.std().item():.4f}")
            
            # 验证输出维度
            expected_shape = (self.batch_size, self.n_modalities, self.C_cond)
            assert cond_feats.shape == expected_shape, f"Shape mismatch: {cond_feats.shape} != {expected_shape}"
            print("✅ Shape validation passed")
        else:
            print("❌ FusionTransformer returned None (disabled)")
            
        return cond_feats
        
    def test_unet_flow(self, test_data, cond_feats):
        """测试UNetFlow"""
        print("\n🧪 Testing UNetFlow...")
        
        x_t = test_data['x_t']
        t = test_data['t']
        cond_mask = test_data['cond_mask']
        
        print(f"Input x_t shape: {x_t.shape}")
        print(f"Input t shape: {t.shape}")
        print(f"Input cond_feats shape: {cond_feats.shape if cond_feats is not None else None}")
        
        # 前向传播
        start_time = time.time()
        with torch.no_grad():
            velocity = self.unet_flow(x_t, cond_feats, t, cond_mask)
        end_time = time.time()
        
        print(f"✅ Output velocity shape: {velocity.shape}")
        print(f"⏱️  Forward time: {end_time - start_time:.4f}s")
        print(f"📊 Output stats - Mean: {velocity.mean().item():.4f}, Std: {velocity.std().item():.4f}")
        
        # 验证输出维度
        assert velocity.shape == x_t.shape, f"Shape mismatch: {velocity.shape} != {x_t.shape}"
        print("✅ Shape validation passed")
        
        return velocity
        
    def test_different_selector_modes(self, test_data, cond_feats):
        """测试不同的选择器模式"""
        print("\n🧪 Testing different selector modes...")
        
        modes = ["all", "auto", "topk_q", "topk_central", "weighted"]
        
        for mode in modes:
            print(f"\n--- Testing selector mode: {mode} ---")
            
            # 更新选择器配置
            selector_config = {"mode": mode, "k": 2, "temp": 1.0}
            self.unet_flow.set_selector_config(selector_config)
            
            try:
                start_time = time.time()
                with torch.no_grad():
                    velocity = self.unet_flow(
                        test_data['x_t'], 
                        cond_feats, 
                        test_data['t'], 
                        test_data['cond_mask']
                    )
                end_time = time.time()
                
                print(f"✅ Mode {mode}: Success, time: {end_time - start_time:.4f}s")
                print(f"   Output stats - Mean: {velocity.mean().item():.4f}, Std: {velocity.std().item():.4f}")
                
            except Exception as e:
                print(f"❌ Mode {mode}: Failed with error: {e}")
                
    def test_missing_modalities(self):
        """测试缺失模态的情况"""
        print("\n🧪 Testing missing modalities scenarios...")
        
        scenarios = [
            ("All modalities", torch.ones(self.batch_size, self.n_modalities)),
            ("Missing last modality", torch.tensor([[1,1,1,0], [1,1,1,1]])),
            ("Missing multiple", torch.tensor([[1,0,1,0], [1,1,0,1]])),
            ("Only one modality", torch.tensor([[1,0,0,0], [0,1,0,0]]))
        ]
        
        for scenario_name, mask in scenarios:
            print(f"\n--- {scenario_name} ---")
            print(f"Mask: {mask}")
            
            # 生成对应的测试数据
            tokens = torch.randn(self.batch_size, self.n_modalities, self.C_tok, device=self.device)
            x_t = torch.randn(self.batch_size, self.in_channels, self.img_height, self.img_width, device=self.device)
            t = torch.rand(self.batch_size, device=self.device) * 1000
            cond_mask = mask.to(self.device)
            
            try:
                with torch.no_grad():
                    cond_feats = self.fusion_transformer(tokens, cond_mask)
                    velocity = self.unet_flow(x_t, cond_feats, t, cond_mask)
                
                print(f"✅ {scenario_name}: Success")
                print(f"   Velocity stats - Mean: {velocity.mean().item():.4f}, Std: {velocity.std().item():.4f}")
                
            except Exception as e:
                print(f"❌ {scenario_name}: Failed with error: {e}")
                
    def test_gradient_flow(self):
        """测试梯度流"""
        print("\n🧪 Testing gradient flow...")
        
        # 启用梯度计算
        self.fusion_transformer.train()
        self.unet_flow.train()
        
        # 生成测试数据
        test_data = self.generate_test_data()
        
        # 需要梯度的输入
        tokens = test_data['tokens'].requires_grad_(True)
        x_t = test_data['x_t'].requires_grad_(True)
        
        # 前向传播
        cond_feats = self.fusion_transformer(tokens, test_data['cond_mask'])
        velocity = self.unet_flow(x_t, cond_feats, test_data['t'], test_data['cond_mask'])
        
        # 计算损失并反向传播
        loss = velocity.mean()
        loss.backward()
        
        # 检查梯度
        fusion_grads = [p.grad for p in self.fusion_transformer.parameters() if p.grad is not None]
        unet_grads = [p.grad for p in self.unet_flow.parameters() if p.grad is not None]
        
        print(f"✅ FusionTransformer gradients: {len(fusion_grads)} parameters have gradients")
        print(f"✅ UNetFlow gradients: {len(unet_grads)} parameters have gradients")
        
        if fusion_grads:
            fusion_grad_norm = torch.sqrt(sum(torch.sum(g**2) for g in fusion_grads))
            print(f"📊 FusionTransformer gradient norm: {fusion_grad_norm.item():.6f}")
            
        if unet_grads:
            unet_grad_norm = torch.sqrt(sum(torch.sum(g**2) for g in unet_grads))
            print(f"📊 UNetFlow gradient norm: {unet_grad_norm.item():.6f}")
            
        # 恢复eval模式
        self.fusion_transformer.eval()
        self.unet_flow.eval()
        
    def run_all_tests(self):
        """运行所有测试"""
        print("🚀 Starting comprehensive tests...")
        print(f"Device: {self.device}")
        print(f"PyTorch version: {torch.__version__}")
        
        try:
            # 1. 基础功能测试
            test_data = self.generate_test_data()
            cond_feats = self.test_fusion_transformer(test_data)
            velocity = self.test_unet_flow(test_data, cond_feats)
            
            # 2. 选择器模式测试
            self.test_different_selector_modes(test_data, cond_feats)
            
            # 3. 缺失模态测试
            self.test_missing_modalities()
            
            # 4. 梯度流测试
            self.test_gradient_flow()
            
            print("\n🎉 All tests completed successfully!")
            
        except Exception as e:
            print(f"\n❌ Test failed with error: {e}")
            import traceback
            traceback.print_exc()

def main():
    """主函数"""
    print("=" * 60)
    print("MRI Flow Matching - Local CPU Demo")
    print("=" * 60)
    
    # 设置随机种子
    torch.manual_seed(42)
    np.random.seed(42)
    
    # 创建并运行测试
    demo = MRIFlowMatchingDemo(device='cpu')
    demo.run_all_tests()

if __name__ == "__main__":
    main()