In [9]:
# ------ 1. 检查torchvision版本 ------
import torch
import torchvision

print("torch version:", torch.__version__)
print("torchvision version:", torchvision.__version__)

torch version: 2.6.0+cpu
torchvision version: 0.21.0+cpu


In [10]:
# ====== 2. 导入依赖 ======
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from datasets import get_dataset
from models.autoencoder import PreTrained_AutoEncoder


In [11]:
import os
import numpy as np
from PIL import Image

def create_toy_mri_dataset(root_dir="./data/MRI", image_size=(256, 256)):
    """
    创建一个toy MRI数据集，包含4个患者，每个患者有T1和T2加权图像
    """
    
    splits = ["train", "val", "test"]
    
    for split in splits:
        split_dir = os.path.join(root_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        
        # 为每个split创建4个患者
        for i in range(1, 5):
            patient_dir = os.path.join(split_dir, f"Patient{i:03d}")
            os.makedirs(patient_dir, exist_ok=True)
            
            # 生成模拟的MRI图像
            # T1: 通常白质亮，灰质暗，CSF暗
            # T2: 通常白质暗，灰质亮，CSF亮
            
            np.random.seed(42 + i)  # 为了可重现性
            
            # 创建基础解剖结构
            h, w = image_size
            center_x, center_y = w//2, h//2
            
            # 创建一个简单的脑部轮廓
            x, y = np.meshgrid(np.arange(w), np.arange(h))
            
            # 椭圆形脑部轮廓
            brain_mask = ((x - center_x)**2 / (w//3)**2 + 
                         (y - center_y)**2 / (h//3)**2) < 1
            
            # 内部结构（简化的脑室和灰白质）
            ventricle_mask = ((x - center_x)**2 / (w//8)**2 + 
                             (y - center_y)**2 / (h//8)**2) < 1
            
            # T1加权图像 (白质亮，灰质中等，脑室暗)
            t1_image = np.zeros((h, w), dtype=np.float32)
            t1_image[brain_mask] = 0.6 + 0.2 * np.random.randn(np.sum(brain_mask))  # 白质
            t1_image[ventricle_mask] = 0.2 + 0.1 * np.random.randn(np.sum(ventricle_mask))  # 脑室
            
            # 添加一些"灰质"区域
            gray_matter = brain_mask & ~ventricle_mask
            gray_indices = np.random.choice(np.sum(gray_matter), size=np.sum(gray_matter)//3, replace=False)
            gray_coords = np.where(gray_matter)
            selected_coords = (gray_coords[0][gray_indices], gray_coords[1][gray_indices])
            t1_image[selected_coords] = 0.4 + 0.1 * np.random.randn(len(gray_indices))
            
            # T2加权图像 (白质暗，灰质亮，脑室亮)
            t2_image = np.zeros((h, w), dtype=np.float32)
            t2_image[brain_mask] = 0.3 + 0.2 * np.random.randn(np.sum(brain_mask))  # 白质
            t2_image[ventricle_mask] = 0.8 + 0.1 * np.random.randn(np.sum(ventricle_mask))  # 脑室
            t2_image[selected_coords] = 0.7 + 0.1 * np.random.randn(len(gray_indices))  # 灰质
            
            # 确保值在[0,1]范围内
            t1_image = np.clip(t1_image, 0, 1)
            t2_image = np.clip(t2_image, 0, 1)
            
            # 转换为0-255的uint8并保存
            t1_uint8 = (t1_image * 255).astype(np.uint8)
            t2_uint8 = (t2_image * 255).astype(np.uint8)
            
            # 保存图像
            Image.fromarray(t1_uint8, mode='L').save(
                os.path.join(patient_dir, "T1.png")
            )
            Image.fromarray(t2_uint8, mode='L').save(
                os.path.join(patient_dir, "T2.png")
            )
            
            print(f"Created {split}/Patient{i:03d}: T1.png, T2.png")
    
    print(f"Toy dataset created at {root_dir}")
    print("Directory structure:")
    for root, dirs, files in os.walk(root_dir):
        level = root.replace(root_dir, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files:
            print(f"{subindent}{file}")

if __name__ == "__main__":
    create_toy_mri_dataset()

Created train/Patient001: T1.png, T2.png
Created train/Patient002: T1.png, T2.png
Created train/Patient003: T1.png, T2.png
Created train/Patient004: T1.png, T2.png
Created val/Patient001: T1.png, T2.png
Created val/Patient002: T1.png, T2.png
Created val/Patient003: T1.png, T2.png
Created val/Patient004: T1.png, T2.png
Created test/Patient001: T1.png, T2.png
Created test/Patient002: T1.png, T2.png
Created test/Patient003: T1.png, T2.png
Created test/Patient004: T1.png, T2.png
Toy dataset created at ./data/MRI
Directory structure:
MRI/
  test/
    Patient001/
      T1.png
      T2.png
      .ipynb_checkpoints/
        T1-checkpoint.png
        T2-checkpoint.png
    Patient002/
      T1.png
      T2.png
      .ipynb_checkpoints/
        T1-checkpoint.png
    Patient003/
      T1.png
      T2.png
    Patient004/
      T1.png
      T2.png
  train/
    Patient001/
      T1.png
      T2.png
    Patient002/
      T1.png
      T2.png
    Patient003/
      T1.png
      T2.png
    Patient004/
   

In [12]:
from datamodules.mri_datamodule import MRIDataModule

# ====== 3. 构建小数据集 ======
dataset = get_dataset(
    dataset_name="mri_contrastive",
    root="./data/MRI",     # 你的数据根目录
    split="train",
    modalities=("T1", "T2"),
    fixed_pair=True        # 强制配对 (T1, T2)
)

# 从MRIDataModule获取静态方法
collate_fn = MRIDataModule._collate_fn

# 注意：参数名是 collate_fn，不是 _collate_fn
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# 检查一下
x_anchor, x_pos = next(iter(loader))
print("x_anchor shape:", x_anchor.shape)  # (B, 1, 256, 256)
print("x_pos shape:", x_pos.shape)        # (B, 1, 256, 256)


x_anchor shape: torch.Size([2, 1, 256, 256])
x_pos shape: torch.Size([2, 1, 256, 256])


In [13]:
# ====== 4. 初始化模型 ======
model = PreTrained_AutoEncoder(
    lr=1e-3,
    lambda_rec=1.0,   # 先只用重建损失
    lambda_nce=0.0,   # NCE 先关掉，避免复杂
    proj_dim=128,
    temperature=0.1,
)

In [14]:
from pytorch_lightning.callbacks import TQDMProgressBar
# ====== 4. Trainer 训练 ======
trainer = pl.Trainer(
    max_epochs=3,        # 小规模调试
    accelerator="cpu",   # 本地 CPU 就好
    devices=1,
    log_every_n_steps=1
)

trainer.fit(model, loader, loader)  # train/val 都用同一个 loader 先跑通

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type           | Params | Mode 
-----------------------------------------------------
0 | encoder   | Shared_Encoder | 388 K  | train
1 | decoder   | Shared_Decoder | 174 K  | train
2 | projector | Sequential     | 49.4 K | train
-----------------------------------------------------
613 K     Trainable params
0         Non-trainable params
613 K     Total params
2.453     Total estimated model params size (MB)
40        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                               | 0/? [00:00<…

Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=3` reached.


In [15]:
# ====== 5. Debug 验证输出 ======
# 随便取一条数据跑 forward
x_anchor, x_pos = dataset[0]
x_anchor = x_anchor.unsqueeze(0)  # (1,1,256,256)
with torch.no_grad():
    recon = model(x_anchor)
print("Input:", x_anchor.shape, "-> Recon:", recon.shape)

Input: torch.Size([1, 1, 256, 256]) -> Recon: torch.Size([1, 1, 256, 256])
