In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import os
# Point TORCH_HOME at a shared cache directory; this must be set before any
# torch.hub call so downloaded repos/weights are stored there.
os.environ.update(TORCH_HOME='/public_bme2/bme-dgshen/ZhaoyuQiu/.cache')


# Atrous Spatial Pyramid Pooling (ASPP) module
class ASPP(nn.Module):
    """Multi-rate context aggregation over a feature map.

    Five branches are evaluated on the same input and fused:
      * a 1x1 convolution,
      * three 3x3 dilated convolutions with rates 6, 12 and 18 (padding equals
        the dilation rate, so the spatial size is preserved),
      * a global-average-pooled branch (pool to 1x1 -> 1x1 conv) that is
        bilinearly upsampled back to the input's spatial size.

    The branch outputs are concatenated along the channel axis and projected
    back to ``out_channels`` with a final 1x1 convolution.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every branch and by the
            fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def dilated_conv(rate):
            # 3x3 convolution whose padding matches its dilation rate.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=rate,
                dilation=rate,
            )

        # NOTE: modules are created in a fixed order so that parameter
        # initialisation is reproducible under a fixed RNG seed.
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.atrous_block6 = dilated_conv(6)
        self.atrous_block12 = dilated_conv(12)
        self.atrous_block18 = dilated_conv(18)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        # Fuses the five concatenated branches back to ``out_channels``.
        self.final_conv = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply all branches to ``x`` and return the fused feature map.

        The output has ``out_channels`` channels and the same spatial size
        as ``x``.
        """
        spatial = x.shape[-2:]

        conv_branches = (
            self.atrous_block1,
            self.atrous_block6,
            self.atrous_block12,
            self.atrous_block18,
        )
        branch_outputs = [branch(x) for branch in conv_branches]

        # Image-level context: pooled to 1x1, then stretched back to H x W.
        pooled = self.global_avg_pool(x)
        branch_outputs.append(
            F.interpolate(pooled, size=spatial, mode='bilinear', align_corners=False)
        )

        fused = torch.cat(branch_outputs, dim=1)
        return self.final_conv(fused)

# Improved segmentation model
class ImprovedDinoSegmentationModel(nn.Module):
    """Segmentation model: DINOv2-style backbone followed by an ASPP decoder.

    The backbone must expose ``forward_features(x)`` returning a dict whose
    ``'x_norm_patchtokens'`` entry is a ``(batch, num_patches, feat_dim)``
    tensor of patch embeddings.

    Args:
        backbone: feature extractor providing ``forward_features``.
        num_classes: number of output channels (one logit map per class).
        patch_size: side length, in pixels, of one backbone patch.
        feat_dim: embedding dimension of the backbone's patch tokens.
    """

    # Patches per side after preprocessing: inputs are resized to a square of
    # ``patch_size * _GRID_SIDE`` pixels so they divide evenly into patches.
    _GRID_SIDE = 32

    def __init__(self, backbone, num_classes=4, patch_size=14, feat_dim=384):
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.feat_dim = feat_dim

        # Decoder: 3x3 conv -> ReLU -> ASPP -> 1x1 classifier.
        self.decoder = nn.Sequential(
            nn.Conv2d(feat_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            ASPP(256, 128),
            nn.Conv2d(128, num_classes, kernel_size=1),  # per-class logits
        )

    def preprocess(self, x):
        """Convert a batch of images into backbone-ready input.

        Args:
            x: ``(batch, channels, H, W)`` tensor with 1 or 3 channels.
                Single-channel input is replicated to 3 channels.

        Returns:
            A ``(batch, 3, S, S)`` tensor with ``S = patch_size * 32``,
            normalised with mean 0.5 and std 0.5 on every channel.

        Raises:
            ValueError: if ``x`` has a channel count other than 1 or 3.
        """
        channels = x.shape[1]
        if channels == 1:
            # Replicate the single channel to form a 3-channel input.
            x = x.repeat(1, 3, 1, 1)
        elif channels != 3:
            raise ValueError(
                f"expected input with 1 or 3 channels, got {channels}"
            )

        side = self.patch_size * self._GRID_SIDE
        transform = T.Compose([
            T.Resize((side, side)),
            # A single mean/std value is broadcast across all channels.
            T.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        return transform(x)

    def forward(self, x):
        """Return per-class logits at the spatial resolution of ``x``.

        Args:
            x: ``(batch, channels, H, W)`` image tensor (see ``preprocess``).

        Returns:
            A ``(batch, num_classes, H, W)`` tensor of logits.
        """
        # Remember the caller's resolution so the logits can be mapped back
        # to it instead of to a hard-coded size.
        out_size = x.shape[-2:]

        x = self.preprocess(x)

        # Patch embeddings: (batch, num_patches, feat_dim).
        tokens = self.backbone.forward_features(x)['x_norm_patchtokens']

        # Lay the token sequence back out on its 2-D patch grid. The grid
        # shape follows from the preprocessed image size and the patch size.
        batch_size = tokens.shape[0]
        grid_h = x.shape[-2] // self.patch_size
        grid_w = x.shape[-1] // self.patch_size
        feature_map = tokens.transpose(1, 2).reshape(
            batch_size, self.feat_dim, grid_h, grid_w
        )

        logits = self.decoder(feature_map)
        return F.interpolate(
            logits, size=out_size, mode='bilinear', align_corners=False
        )

# Load the pretrained DINOv2 model from torch.hub and switch it to eval mode
backbone = torch.hub.load(
    repo_or_dir='facebookresearch/dinov2',
    model='dinov2_vits14',
).eval()

# Build the segmentation model around the backbone (also in eval mode)
model = ImprovedDinoSegmentationModel(
    backbone=backbone,
    num_classes=4,
    patch_size=14,
    feat_dim=384,
).eval()

# Random single-channel test batch
batch_size = 2
test_data = torch.randn((batch_size, 1, 512, 512))

# Inference without gradient tracking
with torch.no_grad():
    output = model(test_data)

# Inspect the result; expected shape is (batch_size, num_classes, 512, 512)
print(f"Output shape: {output.shape}")



KeyboardInterrupt: 