# BAD


In [1]:
import os
import pandas as pd
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

# Step 1: Load DICOM files

def load_dicom_series(folder_path):
    """
    Read the DICOM series stored in ``folder_path`` into a numpy array.

    The series file names are discovered through SimpleITK's GDCM lookup,
    the files are read as a single SimpleITK image, and that image is
    converted with ``sitk.GetArrayFromImage``.

    Raises:
        ValueError: if no DICOM files are found in ``folder_path``.
    """
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(folder_path)
    if not file_names:
        raise ValueError(f"No DICOM files found in the directory: {folder_path}")

    series_reader.SetFileNames(file_names)
    # Execute() yields a SimpleITK image; hand it back as a numpy array.
    return sitk.GetArrayFromImage(series_reader.Execute())

# Step 2: Dataset definition

class MultiModalDataset(Dataset):
    """
    Dataset yielding one stacked CT / PET / MR volume per scan session.

    Every patient listed in the labels CSV may own several scan-session
    sub-folders; each session directory becomes one sample. A modality whose
    folder is absent is replaced by an all-zero volume.

    Note:
        ``__getitem__`` returns ``(None, None)`` when a session fails to load.
        PyTorch's default collate function cannot batch ``None`` entries, so a
        ``DataLoader`` wrapping this dataset needs a ``collate_fn`` that drops
        such samples.
    """

    # Spatial size (depth, height, width) every modality volume is resized to.
    VOLUME_SHAPE = (64, 64, 64)

    def __init__(self, data_dir, labels_csv, transform=None):
        """
        Args:
            data_dir (str): Path to the root directory containing patient folders.
            labels_csv (str): Path to the CSV file containing patient IDs and labels
                (columns ``PatientID`` and ``Label``).
            transform (callable, optional): Optional transform applied to the
                stacked float32 tensor of a sample.
        """
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_csv)
        self.transform = transform
        self.samples = self._generate_samples()

    def _generate_samples(self):
        """
        Build the sample list, one entry per scan-session directory.

        Returns:
            list[dict]: dicts with keys ``session_path`` and ``label``.
            Patients without a folder under ``data_dir`` are skipped.
        """
        samples = []
        for _, row in self.labels.iterrows():
            # Purely numeric IDs are parsed by pandas as integers, which
            # os.path.join rejects; always join with the string form.
            patient_folder = os.path.join(self.data_dir, str(row["PatientID"]))
            if not os.path.isdir(patient_folder):
                continue
            # Sorted so the sample order does not depend on the file system.
            for scan_session in sorted(os.listdir(patient_folder)):
                session_path = os.path.join(patient_folder, scan_session)
                if os.path.isdir(session_path):
                    samples.append({"session_path": session_path, "label": row["Label"]})
        return samples

    def __len__(self):
        """Return the number of scan sessions found."""
        return len(self.samples)

    def _load_modalities(self, session_path):
        """
        Load and preprocess the CT, PET and MR volumes of one session.

        Only the first MR sub-folder (in sorted order) is read, because a
        single MR volume ends up in the stacked sample.

        Returns:
            list[np.ndarray]: ``[ct, pet, mr]``, each of shape ``VOLUME_SHAPE``.
        """
        ct_path = os.path.join(session_path, 'CT/CTAC')
        pet_path = os.path.join(session_path, 'PT/PET')
        mr_path = os.path.join(session_path, 'MR')

        # None marks a missing modality; preprocess() turns it into zeros.
        ct_data = load_dicom_series(ct_path) if os.path.exists(ct_path) else None
        pet_data = load_dicom_series(pet_path) if os.path.exists(pet_path) else None

        mr_data = None
        if os.path.isdir(mr_path):
            for subfolder in sorted(os.listdir(mr_path)):
                series_path = os.path.join(mr_path, subfolder)
                if os.path.isdir(series_path):
                    mr_data = load_dicom_series(series_path)
                    break

        return [self.preprocess(volume) for volume in (ct_data, pet_data, mr_data)]

    def __getitem__(self, idx):
        """
        Return ``(volume, label)`` for sample ``idx``.

        ``volume`` is a float32 tensor of shape (modalities, depth, height,
        width) with the modalities ordered CT, PET, MR; ``label`` is a long
        tensor. If loading raises ``ValueError`` the error is printed and
        ``(None, None)`` is returned.
        """
        sample = self.samples[idx]
        session_path = sample["session_path"]
        label = sample["label"]

        try:
            volumes = self._load_modalities(session_path)
        except ValueError as e:
            print(f"Error loading data for session: {session_path} - {e}")
            return None, None

        # Always hand back a float32 tensor, whether or not a transform is
        # configured, so batches match the dtype of the model weights.
        combined_data = torch.tensor(np.stack(volumes, axis=0), dtype=torch.float32)
        if self.transform:
            combined_data = self.transform(combined_data)

        return combined_data, torch.tensor(label, dtype=torch.long)

    @staticmethod
    def preprocess(image):
        """
        Min-max normalise ``image`` and resize it to ``VOLUME_SHAPE``.

        ``None`` or an all-zero image yields an all-zero volume.
        """
        target_shape = MultiModalDataset.VOLUME_SHAPE
        if image is None or np.all(image == 0):
            return np.zeros(target_shape)

        # Convert first: integer voxel dtypes can overflow in ``image - min``.
        image = np.asarray(image, dtype=np.float64)
        lowest, highest = image.min(), image.max()
        image = (image - lowest) / (highest - lowest + 1e-8)  # Normalize

        # Imported here because this cell does not import skimage at the top.
        from skimage.transform import resize
        return resize(image, target_shape, mode='constant', anti_aliasing=True)

# Step 3: Define the multimodal model

class MultiModalNet(nn.Module):
    """
    Small 3D CNN classifier: two conv + ReLU + max-pool stages followed by
    two fully connected layers.

    Args:
        in_channels (int): Number of input channels (stacked modalities).
        num_classes (int): Number of output logits.
        input_shape (tuple): Spatial size (depth, height, width) of the input
            volumes; it fixes the input width of the first linear layer.

    Raises:
        ValueError: if ``input_shape`` is not three dimensions of at least 4
            voxels each (smaller volumes vanish after the two poolings).
    """

    def __init__(self, in_channels=3, num_classes=2, input_shape=(64, 64, 64)):
        super().__init__()
        if len(input_shape) != 3 or any(dim < 4 for dim in input_shape):
            raise ValueError(
                f"input_shape must be three dimensions of size >= 4, got {input_shape}"
            )

        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(2, 2)

        # The padded convolutions keep the spatial size; the two 2x poolings
        # floor-divide every spatial dimension by 4.
        flat_features = 32
        for dim in input_shape:
            flat_features *= dim // 4
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, D, H, W) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)  # keep the batch dimension, flatten everything else
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Step 4: Training loop

def train_model(model, dataloader, criterion, optimizer, num_epochs=20, save_path="best_model.pth"):
    """
    Train ``model`` and checkpoint the weights with the lowest epoch loss.

    Args:
        model: Network to train. Batches are moved to the device the model's
            parameters already live on.
        dataloader: Iterable of ``(inputs, labels)`` batches. Batches where
            either element is ``None`` are skipped.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs (int): Number of passes over ``dataloader``.
        save_path (str): Where the best ``state_dict`` is written.

    Returns:
        float: The best (lowest) average epoch loss, or ``inf`` if no batch
        was ever processed.
    """
    # Follow the model's device rather than assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        batches_seen = 0
        for inputs, labels in dataloader:
            if inputs is None or labels is None:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batches_seen += 1

        if batches_seen == 0:
            # Nothing to average (empty loader or every batch skipped).
            print(f"Epoch {epoch+1}/{num_epochs}, no valid batches - skipped")
            continue

        # Average over the batches actually trained on, not len(dataloader),
        # so skipped batches do not deflate the reported loss.
        avg_loss = epoch_loss / batches_seen
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")

        # Save the model if it has the best loss so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model saved with loss {best_loss}")

    return best_loss

# Step 5: Main script

# Dataset location, label table and the per-sample transform.
data_dir = "../datasets/PyDownloader/QIN-BREAST/"
labels_csv = "./QIN_labels.csv"
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])


def collate_skip_failed(batch):
    """
    Collate a batch after dropping samples that failed to load.

    MultiModalDataset.__getitem__ returns (None, None) for a session it
    cannot load, and the default collate function raises on None entries.
    Returns (None, None) when every sample of the batch failed; train_model
    skips such batches.
    """
    from torch.utils.data.dataloader import default_collate

    usable = [pair for pair in batch if pair[0] is not None and pair[1] is not None]
    if not usable:
        return None, None
    return default_collate(usable)


# Use the GPU when one is available instead of failing on .to('cuda').
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = MultiModalDataset(data_dir, labels_csv, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_skip_failed)

model = MultiModalNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_model(model, dataloader, criterion, optimizer, num_epochs=20, save_path="best_model.pth")


KeyboardInterrupt: 

5-channel data training

Fill a channel with zeros when its modality does not exist. Channel order:

CT
PET
MR_DWI
MR_T1
MR_dynamic

In [4]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
from skimage.transform import resize

def load_dicom_series(folder_path):
    """
    Read a DICOM series from ``folder_path`` into a numpy array.

    Args:
        folder_path: Directory holding the series. ``None``, an empty string
            or a non-path value (e.g. the NaN pandas uses for an empty CSV
            cell) is treated as "no series".

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` (laid out as
        (D, H, W) according to the original code comment), or ``None`` when
        the path is missing, is not a directory, contains no DICOM series,
        or the series cannot be read.
    """
    # Reject non-path values up front; os.path.isdir raises TypeError on them.
    if not isinstance(folder_path, (str, bytes, os.PathLike)):
        return None
    if not folder_path or not os.path.isdir(folder_path):
        return None

    reader = sitk.ImageSeriesReader()
    dicom_files = reader.GetGDCMSeriesFileNames(folder_path)
    if not dicom_files:
        return None

    reader.SetFileNames(dicom_files)
    try:
        image = reader.Execute()  # SimpleITK image
    except RuntimeError as err:
        # SimpleITK surfaces read errors as RuntimeError. Honour the
        # "None on failure" contract, but warn so bad data is not hidden.
        import warnings
        warnings.warn(f"Failed to read DICOM series in {folder_path}: {err}")
        return None
    return sitk.GetArrayFromImage(image)

def normalize_and_resize(image_np, output_shape=(64, 64, 64)):
    """
    Min-max normalise a volume to [0, 1] and resize it to ``output_shape``.

    Args:
        image_np: Volume as a numpy array, or ``None`` for a missing channel.
        output_shape (tuple): Shape of the returned volume.

    Returns:
        np.ndarray: float32 array of shape ``output_shape``. An all-zero
        placeholder is returned when ``image_np`` is ``None``, empty, or
        constant (which includes all-zero input).
    """
    if image_np is None:
        return np.zeros(output_shape, dtype=np.float32)

    # Do the arithmetic in float64: with integer voxel dtypes ``max - min``
    # and ``image - min`` can overflow and corrupt the normalisation.
    volume = np.asarray(image_np, dtype=np.float64)
    if volume.size == 0:
        return np.zeros(output_shape, dtype=np.float32)

    min_val, max_val = volume.min(), volume.max()
    value_range = max_val - min_val
    if value_range < 1e-8:
        # A constant volume normalises to all zeros, and resizing zeros gives
        # zeros again, so return the placeholder without resizing.
        return np.zeros(output_shape, dtype=np.float32)

    volume = (volume - min_val) / (value_range + 1e-8)
    image_resized = resize(volume, output_shape, mode='constant', anti_aliasing=True)
    return image_resized.astype(np.float32)

class MultiModal3DDataset(Dataset):
    """
    Five-channel 3D dataset: [CT, PET, MR_DWI, MR_T1, MR_DYNAMIC].

    Each CSV row describes one sample through the columns ``label``,
    ``ct_path``, ``pet_path``, ``mr_dwi``, ``mr_t1`` and ``mr_dynamic``.
    The three MR columns may list several folders separated by ``;``; only
    the first non-empty entry is used. A channel whose cell is empty, or
    whose series cannot be loaded, is filled with zeros.
    """

    # (CSV column, list separator or None) for every channel, in channel order.
    _CHANNEL_COLUMNS = (
        ("ct_path", None),
        ("pet_path", None),
        ("mr_dwi", ";"),
        ("mr_t1", ";"),
        ("mr_dynamic", ";"),
    )

    def __init__(self, csv_path, transform=None, output_shape=(64,64,64)):
        """
        Args:
            csv_path (str): CSV file with one row per sample.
            transform (callable, optional): Applied to the stacked tensor.
            output_shape (tuple): Spatial shape every channel is resized to.
        """
        self.df = pd.read_csv(csv_path)
        self.transform = transform
        self.output_shape = output_shape

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    @staticmethod
    def _resolve_path(cell, separator=None):
        """Return the folder path stored in a CSV cell, or None if it holds none."""
        if not isinstance(cell, str):
            # pandas represents an empty cell as NaN (a float).
            return None
        candidates = cell.split(separator) if separator else [cell]
        for candidate in candidates:
            candidate = candidate.strip()
            if candidate:
                return candidate
        return None

    def _load_channel(self, cell, separator=None):
        """Load one channel as a float32 volume of ``output_shape`` (zeros if absent)."""
        path = self._resolve_path(cell, separator)
        if path is None:
            return np.zeros(self.output_shape, dtype=np.float32)
        return normalize_and_resize(load_dicom_series(path), self.output_shape)

    def __getitem__(self, idx):
        """
        Return ``(volume, label)`` for row ``idx``.

        ``volume`` is a float32 tensor of shape (5, *output_shape) (after the
        optional transform); ``label`` is a long tensor.
        """
        row = self.df.iloc[idx]
        # int() fails loudly on a missing (NaN) label instead of letting the
        # long cast turn it into a meaningless integer.
        label = int(row["label"])

        channels = [
            self._load_channel(row[column], separator)
            for column, separator in self._CHANNEL_COLUMNS
        ]

        # Stack the 5 channels: (C=5, D, H, W)
        combined = np.stack(channels, axis=0)

        combined_tensor = torch.tensor(combined, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.long)

        if self.transform:
            combined_tensor = self.transform(combined_tensor)

        return combined_tensor, label_tensor


In [7]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
from skimage.transform import resize

def load_dicom_series(folder_path):
    """
    Read a DICOM series from ``folder_path`` into a numpy array.

    Args:
        folder_path: Directory holding the series. ``None``, an empty string
            or a non-path value (e.g. the NaN pandas uses for an empty CSV
            cell) is treated as "no series".

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` (laid out as
        (D, H, W) according to the original code comment), or ``None`` when
        the path is missing, is not a directory, contains no DICOM series,
        or the series cannot be read.
    """
    # Reject non-path values up front; os.path.isdir raises TypeError on them.
    if not isinstance(folder_path, (str, bytes, os.PathLike)):
        return None
    if not folder_path or not os.path.isdir(folder_path):
        return None

    reader = sitk.ImageSeriesReader()
    dicom_files = reader.GetGDCMSeriesFileNames(folder_path)
    if not dicom_files:
        return None

    reader.SetFileNames(dicom_files)
    try:
        image = reader.Execute()  # SimpleITK image
    except RuntimeError as err:
        # SimpleITK surfaces read errors as RuntimeError. Honour the
        # "None on failure" contract, but warn so bad data is not hidden.
        import warnings
        warnings.warn(f"Failed to read DICOM series in {folder_path}: {err}")
        return None
    return sitk.GetArrayFromImage(image)

def normalize_and_resize(image_np, output_shape=(64, 64, 64)):
    """
    Min-max normalise a volume to [0, 1] and resize it to ``output_shape``.

    Args:
        image_np: Volume as a numpy array, or ``None`` for a missing channel.
        output_shape (tuple): Shape of the returned volume.

    Returns:
        np.ndarray: float32 array of shape ``output_shape``. An all-zero
        placeholder is returned when ``image_np`` is ``None``, empty, or
        constant (which includes all-zero input).
    """
    if image_np is None:
        return np.zeros(output_shape, dtype=np.float32)

    # Do the arithmetic in float64: with integer voxel dtypes ``max - min``
    # and ``image - min`` can overflow and corrupt the normalisation.
    volume = np.asarray(image_np, dtype=np.float64)
    if volume.size == 0:
        return np.zeros(output_shape, dtype=np.float32)

    min_val, max_val = volume.min(), volume.max()
    value_range = max_val - min_val
    if value_range < 1e-8:
        # A constant volume normalises to all zeros, and resizing zeros gives
        # zeros again, so return the placeholder without resizing.
        return np.zeros(output_shape, dtype=np.float32)

    volume = (volume - min_val) / (value_range + 1e-8)
    image_resized = resize(volume, output_shape, mode='constant', anti_aliasing=True)
    return image_resized.astype(np.float32)

class MultiModal3DDataset(Dataset):
    """
    Five-channel 3D dataset: [CT, PET, MR_DWI, MR_T1, MR_DYNAMIC].

    Each CSV row describes one sample through the columns ``label``,
    ``ct_path``, ``pet_path``, ``mr_dwi``, ``mr_t1`` and ``mr_dynamic``.
    The three MR columns may list several folders separated by ``;``; only
    the first non-empty entry is used. A channel whose cell is empty, or
    whose series cannot be loaded, is filled with zeros.
    """

    # (CSV column, list separator or None) for every channel, in channel order.
    _CHANNEL_COLUMNS = (
        ("ct_path", None),
        ("pet_path", None),
        ("mr_dwi", ";"),
        ("mr_t1", ";"),
        ("mr_dynamic", ";"),
    )

    def __init__(self, csv_path, transform=None, output_shape=(64,64,64)):
        """
        Args:
            csv_path (str): CSV file with one row per sample.
            transform (callable, optional): Applied to the stacked tensor.
            output_shape (tuple): Spatial shape every channel is resized to.
        """
        self.df = pd.read_csv(csv_path)
        self.transform = transform
        self.output_shape = output_shape

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    @staticmethod
    def _resolve_path(cell, separator=None):
        """Return the folder path stored in a CSV cell, or None if it holds none."""
        if not isinstance(cell, str):
            # pandas represents an empty cell as NaN (a float).
            return None
        candidates = cell.split(separator) if separator else [cell]
        for candidate in candidates:
            candidate = candidate.strip()
            if candidate:
                return candidate
        return None

    def _load_channel(self, cell, separator=None):
        """Load one channel as a float32 volume of ``output_shape`` (zeros if absent)."""
        path = self._resolve_path(cell, separator)
        if path is None:
            return np.zeros(self.output_shape, dtype=np.float32)
        return normalize_and_resize(load_dicom_series(path), self.output_shape)

    def __getitem__(self, idx):
        """
        Return ``(volume, label)`` for row ``idx``.

        ``volume`` is a float32 tensor of shape (5, *output_shape) (after the
        optional transform); ``label`` is a long tensor.
        """
        row = self.df.iloc[idx]
        # int() fails loudly on a missing (NaN) label instead of letting the
        # long cast turn it into a meaningless integer.
        label = int(row["label"])

        channels = [
            self._load_channel(row[column], separator)
            for column, separator in self._CHANNEL_COLUMNS
        ]

        # Stack the 5 channels: (C=5, D, H, W)
        combined = np.stack(channels, axis=0)

        combined_tensor = torch.tensor(combined, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.long)

        if self.transform:
            combined_tensor = self.transform(combined_tensor)

        return combined_tensor, label_tensor


NameError: name 'Simple3DCNN_5ch' is not defined