In [2]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import math
import torch.nn as nn

In [3]:
# Path to the chest-ct-segmentation dataset folder
data_dir = "dataset\Lucchi++"
train_images_dir = os.path.join(data_dir, "Train_In")
train_masks_dir = os.path.join(data_dir, "Train_Out")
test_images_dir = os.path.join(data_dir, "Test_In")
test_masks_dir = os.path.join(data_dir, "Test_Out")
PATCH_SIZE = 128

i = 0
# Prepare the training data, Append image and corresponding mask paths
train_data = []
for image_file in os.listdir(train_images_dir):
    image_path = os.path.join(train_images_dir, image_file)
    mask_path = os.path.join(train_masks_dir, f"{i}.png")
    i += 1
    train_data.append(
    { 
        "image" : image_path, 
        "annotation" : mask_path
    })

i = 0
# Prepare the test data, Append image and corresponding mask paths
test_data = []
for image_file in os.listdir(test_images_dir):
    image_path = os.path.join(test_images_dir, image_file)
    mask_path = os.path.join(test_masks_dir, f"{i}.png")
    i += 1
    test_data.append(
    { 
        "image" : image_path, 
        "annotation" : mask_path
    })
print(train_data)

[{'image': 'dataset\\Lucchi++\\Train_In\\mask0000.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\0.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0001.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\1.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0002.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\2.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0003.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\3.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0004.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\4.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0005.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\5.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0006.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\6.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0007.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\7.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0008.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\8.png'}, {'image': 'dataset

In [4]:
class SegmentationModel(nn.Module):
    def __init__(self, num_classes):
        super(SegmentationModel, self).__init__()
        
        # Encoder (downsampling)
        self.encoder = nn.Sequential(
            self.conv_block(3, 32, stride=2),
            self.conv_block(32, 64, stride=2),
            self.conv_block(64, 128, stride=2),
            self.conv_block(128, 256, stride=2)
        )
        
        # Decoder (upsampling)
        self.decoder = nn.Sequential(
            self.upconv_block(256, 128),
            self.upconv_block(128, 64),
            self.upconv_block(64, 32),
            self.upconv_block(32, 32)
        )
        
        # Final classification layer
        self.final = nn.Conv2d(32, num_classes, kernel_size=3, padding=1)
        
    def conv_block(self, in_channels, out_channels, stride=1):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        features = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            features.append(x)
        
        # Decoder
        for i, decoder_layer in enumerate(self.decoder):
            x = decoder_layer(x)
            if i < len(self.decoder) - 1:
                x = x + features[-i-2]  # Skip connection
        
        # Final classification
        x = self.final(x)
        return x

In [5]:
# Custom Dataset class
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        # Load image and mask
        image = cv2.imread(self.data_list[idx]["image"])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.data_list[idx]["annotation"], cv2.IMREAD_GRAYSCALE)

        # Resize image and mask to 1024x768
        image = cv2.resize(image, (1024, 768))
        mask = cv2.resize(mask, (1024, 768))

        # Split image and mask into 64x64 patches
        patch_size = PATCH_SIZE
        image_patches = []
        mask_patches = []

        for i in range(0, 1024, patch_size):
            for j in range(0, 768, patch_size):
                image_patch = image[i:i+patch_size, j:j+patch_size]
                mask_patch = mask[i:i+patch_size, j:j+patch_size]
                if image_patch.size > 0 and mask_patch.size > 0:
                    # Transpose image patch to (C, H, W) format
                    image_patch = image_patch.transpose(2, 0, 1)
                    image_patches.append(image_patch)
                    mask_patches.append(mask_patch)

        # Convert lists to numpy arrays and normalize
        image_patches = np.array(image_patches, dtype=np.float32) / 255.0
        mask_patches = np.array(mask_patches, dtype=np.float32) / 255.0

        # Convert to tensors
        image_patches = torch.from_numpy(image_patches)
        mask_patches = torch.from_numpy(mask_patches).unsqueeze(1)

        return image_patches, mask_patches

def calculate_metrics(pred_mask, true_mask, threshold=0.5):
    # Convert predictions to binary
    pred_mask = (pred_mask > threshold).float()

    # Calculate intersection and union
    intersection = (pred_mask * true_mask).sum()
    union = pred_mask.sum() + true_mask.sum() - intersection

    # Calculate IoU
    iou = (intersection + 1e-7) / (union + 1e-7)

    # Calculate Dice coefficient
    dice = (2. * intersection + 1e-7) / (pred_mask.sum() + true_mask.sum() + 1e-7)

    return iou.item(), dice.item()

# Training function
def train_model(model, train_loader, val_loader, num_epochs=50, device="cuda"):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_val_loss = float('inf')
    best_val_iou = 0.0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_iou = 0
        train_dice = 0
        num_batches = 0

        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            images = images.view(-1, 3, PATCH_SIZE, PATCH_SIZE)
            masks = masks.view(-1, 1, PATCH_SIZE, PATCH_SIZE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Calculate metrics
            batch_iou, batch_dice = calculate_metrics(torch.sigmoid(outputs), masks)

            train_loss += loss.item()
            train_iou += batch_iou
            train_dice += batch_dice
            num_batches += 1

        avg_train_loss = train_loss / num_batches
        avg_train_iou = train_iou / num_batches
        avg_train_dice = train_dice / num_batches

        # Validation phase
        model.eval()
        val_loss = 0
        val_iou = 0
        val_dice = 0
        num_val_batches = 0

        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)

                images = images.view(-1, 3, PATCH_SIZE, PATCH_SIZE)
                masks = masks.view(-1, 1, PATCH_SIZE, PATCH_SIZE)

                outputs = model(images)
                loss = criterion(outputs, masks)

                # Calculate metrics
                batch_iou, batch_dice = calculate_metrics(torch.sigmoid(outputs), masks)

                val_loss += loss.item()
                val_iou += batch_iou
                val_dice += batch_dice
                num_val_batches += 1

        avg_val_loss = val_loss / num_val_batches
        avg_val_iou = val_iou / num_val_batches
        avg_val_dice = val_dice / num_val_batches

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training - Loss: {avg_train_loss:.4f}, IoU: {avg_train_iou:.4f}, Dice: {avg_train_dice:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}, IoU: {avg_val_iou:.4f}, Dice: {avg_val_dice:.4f}')

        # Save best model based on IoU
        if avg_val_iou > best_val_iou:
            best_val_iou = avg_val_iou
            torch.save(model.state_dict(), 'best_model_iou.pth')

        # Also save based on loss if needed
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model_loss.pth')

        print('-' * 60)

# Inference function
def predict(model, image_path, device="cuda"):
    model.eval()

    # Load and preprocess image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (1024, 768))

    # Split image into 64x64 patches
    patch_size = PATCH_SIZE
    patches_list = []

    for i in range(0, 768, patch_size):  # 修改遍历顺序
        for j in range(0, 1024, patch_size):
            # 提取patch
            patch = image[i:i+patch_size, j:j+patch_size]

            # 确保patch大小一致
            if patch.shape[0] != patch_size or patch.shape[1] != patch_size:
                patch = cv2.resize(patch, (patch_size, patch_size))

            # 标准化并转换通道顺序
            patch = patch / 255.0
            patch = patch.transpose(2, 0, 1)  # (H,W,C) -> (C,H,W)
            patches_list.append(patch)

    # 转换为numpy数组，确保形状正确
    patches_array = np.stack(patches_list)  # (N, C, H, W)

    # 转换为tensor并移至设备
    patches_tensor = torch.from_numpy(patches_array).float().to(device)

    # 运行推理
    with torch.no_grad():
        outputs = model(patches_tensor)
        pred_masks = torch.sigmoid(outputs) > 0.5

    # 重建完整mask
    full_mask = np.zeros((768, 1024))
    patch_idx = 0

    for i in range(0, 768, patch_size):
        for j in range(0, 1024, patch_size):
            mask_patch = pred_masks[patch_idx, 0].cpu().numpy()

            # 处理边界情况
            h = min(patch_size, 768-i)
            w = min(patch_size, 1024-j)

            if mask_patch.shape != (h, w):
                mask_patch = cv2.resize(mask_patch, (w, h))

            full_mask[i:i+h, j:j+w] = mask_patch
            patch_idx += 1

    return full_mask

In [6]:
# Create datasets
train_dataset = SegmentationDataset(train_data)
test_dataset = SegmentationDataset(test_data)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = SegmentationModel(num_classes=1).to(device)

# # Train model
# train_model(model, train_loader, test_loader, num_epochs=50, device=device)


In [7]:
# https://huggingface.co/CompVis/ldm-super-resolution-4x-openimages1
# Example inference
test_image_path = test_data[0]["image"]
# pred_mask = predict(model, test_image_path)

# # Visualize results
# plt.figure(figsize=(12, 4))
# plt.subplot(131)
# plt.imshow(cv2.imread(test_image_path))
# plt.title('Original Image')
# plt.subplot(132)
# plt.imshow(cv2.imread(test_data[0]["annotation"], cv2.IMREAD_GRAYSCALE))
# plt.title('Ground Truth')
# plt.subplot(133)
# plt.imshow(pred_mask)
# plt.title('Prediction')
# plt.show()



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, x):
        # 确保输入是浮点类型
        x = x.float()
        return self.layers(x)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_channels, dropout=0.0):
        super().__init__()
        # 确保group数量能够整除通道数
        groups = min(32, in_channels) # 动态调整groups数量

        self.in_layers = nn.Sequential(
            nn.GroupNorm(groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )

        self.out_layers = nn.Sequential(
            nn.GroupNorm(min(32, out_channels), out_channels), # 这里也要调整
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb)[:, :, None, None]
        h = h + emb_out
        h = self.out_layers(h)
        return h + self.skip_connection(x)

class MemoryEfficientCrossAttention(nn.Module):
    def __init__(self, dim, context_dim=None, heads=8, dropout=0.0):
        super().__init__()
        inner_dim = dim
        context_dim = context_dim if context_dim is not None else dim

        self.heads = heads
        self.scale = (dim // heads) ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None):
        q = self.to_q(x)
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)

        return self.to_out(F.scaled_dot_product_attention(q, k, v))

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, context_dim=None):
        super().__init__()
        self.attn1 = MemoryEfficientCrossAttention(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 8),
            nn.GELU(),
            nn.Linear(dim * 8, dim)
        )
        self.attn2 = MemoryEfficientCrossAttention(dim, context_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context) + x
        x = self.ff(self.norm3(x)) + x
        return x

class SpatialTransformer(nn.Module):
    def __init__(self, in_channels, n_heads, d_head, context_dim=None):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)
        inner_dim = n_heads * d_head
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, context_dim)
            for _ in range(1)
        ])
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
        x = self.proj_in(x)

        for block in self.transformer_blocks:
            x = block(x, context)

        x = self.proj_out(x)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        return x + x_in

class UNetModel(nn.Module):
    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        out_channels=4,
        time_embed_dim=1280,
        context_dim=2048
    ):
        super().__init__()

        # Time embedding
        self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)

        # Input blocks
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, model_channels, 3, padding=1)
        ])

        # Add remaining input blocks
        current_channels = model_channels
        channel_multipliers = [1, 2, 4]  # 控制通道数增长

        for i in range(len(channel_multipliers)):
            out_channels = model_channels * channel_multipliers[i]
            for _ in range(2):
                layers = [ResBlock(current_channels, out_channels, time_embed_dim)]
                if i % 3 == 2:
                    layers.append(nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1))
                if i > 0:
                    layers.append(SpatialTransformer(out_channels, 8, 64, context_dim))
                self.input_blocks.append(nn.Sequential(*layers))
                current_channels = out_channels

        # Middle block
        self.middle_block = nn.Sequential(
            ResBlock(current_channels, current_channels, time_embed_dim),
            SpatialTransformer(current_channels, 8, 64, context_dim),
            ResBlock(current_channels, current_channels, time_embed_dim)
        )

        # Output blocks
        self.output_blocks = nn.ModuleList([])
        for i in range(len(channel_multipliers)):
            out_channels = model_channels * channel_multipliers[-(i+1)]
            for _ in range(3):
                layers = [ResBlock(current_channels, out_channels, time_embed_dim)]
                if i > 0:
                    layers.append(SpatialTransformer(out_channels, 8, 64, context_dim))
                if i < len(channel_multipliers)-1:
                    layers.append(nn.ConvTranspose2d(out_channels, out_channels, 4, 2, 1))
                self.output_blocks.append(nn.Sequential(*layers))
                current_channels = out_channels

        # Final output
        self.out = nn.Sequential(
            nn.GroupNorm(min(32, model_channels), model_channels),
            nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, 3, padding=1)
        )

    def forward(self, x, timesteps, context):
        # 确保timesteps是正确的形状和类型
        if len(timesteps.shape) == 1:  # (batch_size,)
            timesteps = timesteps.unsqueeze(1)  # (batch_size, 1)
        timesteps = timesteps.float()

        # Time embedding
        temb = self.time_embed(timesteps)

        # Input blocks
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(h) if not isinstance(module[0], ResBlock) else module[0](h, temb)
            hs.append(h)

        # Middle block
        h = self.middle_block[0](h, temb)
        h = self.middle_block[1](h)
        h = self.middle_block[2](h, temb)

        # Output blocks
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module[0](h, temb)
            if len(module) > 1:
                if isinstance(module[1], SpatialTransformer):
                    h = module[1](h)
                    if len(module) > 2:
                        h = module[2](h)
                else:
                    h = module[1](h)

        return self.out(h)



In [14]:
class PartialTrainUNet(UNetModel):
    def __init__(self):
        super().__init__(in_channels=3, out_channels=1)  # 假设输入是RGB图像,输出是单通道mask

        # 冻结前半部分参数
        for param in self.input_blocks.parameters():
            param.requires_grad = False
        for param in self.middle_block.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Time embedding (我们不使用,所以传入零张量)
        temb = torch.zeros(x.shape[0], 1280).to(x.device)  # 使用固定的 time_embed_dim

        # Input blocks
        hs = []
        h = x
        for module in self.input_blocks:
            if isinstance(module, nn.Sequential):
                if isinstance(module[0], ResBlock):
                    h = module[0](h, temb)
                    for layer in module[1:]:
                        h = layer(h)
                else:
                    h = module(h)
            else:
                h = module(h)
            hs.append(h)

        # Middle block
        if isinstance(self.middle_block, nn.Sequential):
            for layer in self.middle_block:
                if isinstance(layer, ResBlock):
                    h = layer(h, temb)
                else:
                    h = layer(h)
        else:
            h = self.middle_block(h)

        # Output blocks (这部分参数是可训练的)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            if isinstance(module, nn.Sequential):
                if isinstance(module[0], ResBlock):
                    h = module[0](h, temb)
                    for layer in module[1:]:
                        if isinstance(layer, SpatialTransformer):
                            h = layer(h)
                        else:
                            h = layer(h)
                else:
                    h = module(h)
            else:
                h = module(h)

        return self.out(h)


In [16]:
import torch
from safetensors import safe_open

# 创建模型实例
model = UNetModel(in_channels=3, out_channels=1)  # 根据需要修改输入和输出通道

# 加载权重
model_path = "output_sdxl_unet.safetensors"

with safe_open(model_path, framework="pt") as f:
    # 这里我们假设文件中保存的状态字典为模型的状态字典
    state_dict = f.get_tensors()
    model.load_state_dict(state_dict)

# 将模型设置为评估模式
model.eval()

# 如果需要，移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


AttributeError: 'builtins.safe_open' object has no attribute 'get_tensors'

In [15]:
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PartialTrainUNet().to(device)

# Train model
train_model(model, train_loader, test_loader, num_epochs=50, device=device)


KeyboardInterrupt: 