In [7]:
import numpy as np

ex_path = '/data2/songwei/Data/processed/Dataset010_IXI_LargeIA_ALL/nnUNetPlans_3d_fullres/IXI046-Guys-0824-MRA.npy'
ex_data = np.load(ex_path)
ex_data.shape

AttributeError: 'NpzFile' object has no attribute 'shape'

In [5]:
ex_label_path = '/data2/songwei/Data/processed/Dataset010_IXI_LargeIA_ALL/nnUNetPlans_3d_fullres/IXI046-Guys-0824-MRA_seg.npy'
ex_label = np.load(ex_label_path)
ex_label.shape

(1, 100, 512, 432)

In [6]:
ex_data.max(), ex_data.min(), ex_label.max(), ex_label.min()

(21.870064, -0.7824818, 5, -1)

In [None]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """(Conv3D -> IN -> LeakyReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, 1, 0),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, 3, 1, 1, groups=8),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, 1, 1, 0),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        self.residual = in_channels == out_channels

    def forward(self, x):
        return self.double_conv(x) + x if self.residual else self.double_conv(x)


class MyNet(nn.Module):
    def __init__(self, in_channels, n_classes, n_channels, depth=4, head_channels=16):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.depth = depth # 添加网络深度参数
        self.head_channels = head_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.encoders = nn.ModuleList() # 使用 ModuleList 存储编码器层
        self.decoders = nn.ModuleList() # 使用 ModuleList 存储解码器层

        # 创建编码器层
        for i in range(self.depth):
            in_channels = n_channels * 2**i
            out_channels = n_channels * 2**(i+1)
            self.encoders.append(Down(in_channels, out_channels))

        # 创建解码器层
        for i in range(self.depth-1, -1, -1):
            in_channels = n_channels * 2**(i+2)
            out_channels = n_channels * 2**i
            self.decoders.append(Up(in_channels, out_channels))

        self.params_list = [
            n_channels * head_channels * 2**(self.depth),
            head_channels * n_classes,
            head_channels,
            n_classes,
        ]

        self.GAP = nn.AdaptiveAvgPool3d(1)
        self.controller = nn.Conv3d(n_channels * 2**self.depth + 2, sum(self.params_list), 1)

    def encoding_task(self, task_id):
        N = task_id.shape[0]
        task_encoding = torch.zeros(size=(N, 2))
        for i in range(N):
            task_encoding[i, task_id[i]] = 1
        return task_encoding.cuda()

    def forward(self, x, task_id):
        x_enc = [self.conv(x)] # 存储编码器输出

        # 编码过程
        for encoder in self.encoders:
            x_enc.append(encoder(x_enc[-1]))

        # 解码过程
        x_dec = x_enc[-1]
        for i, decoder in enumerate(self.decoders):
            x_dec = decoder(x_dec, x_enc[-(i+2)])

        task_embeddings = self.encoding_task(task_id)
        task_embeddings = task_embeddings.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        x_feat = self.GAP(x_enc[-1])
        x_cond = torch.cat([x_feat, task_embeddings], dim=1)
        params = self.controller(x_cond)
        params = params.squeeze(-1).squeeze(-1).squeeze(-1)
        params_split = torch.split_with_sizes(params, self.params_list, dim=1)
        N, _, D, H, W = x_dec.shape
        head_feat = x_dec.view(1, -1, D, H, W)
        head_feat = F.leaky_relu(
            F.conv3d(
                head_feat,
                params_split[0].reshape(N * self.head_channels, -1, 1, 1, 1),
                bias=params_split[2].reshape(N * self.head_channels),
                stride=1,
                padding=0,
                groups=N,
            )
        )
        logits = F.conv3d(
            head_feat,
            params_split[1].reshape(self.n_classes * N, -1, 1, 1, 1),
            bias=params_split[3].reshape(self.n_classes * N),
            stride=1,
            padding=0,
            groups=N,
        )
        logits = logits.reshape(N, -1, D, H, W)
        return logits