In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class STpconv3D(nn.Module):
    """3D partial convolution: convolves an image together with a validity mask.

    The image is multiplied by the mask before convolution, and the result is
    re-scaled at every output location by ``window_size / valid_count``, where
    ``valid_count`` is the sum of the mask over the receptive field (all input
    channels included) and ``window_size = in_channels * kD * kH * kW``.
    With an all-ones mask the ratio is 1 in the interior, so the layer reduces
    to an ordinary convolution there.

    Args:
        in_channels: Number of channels of both ``img`` and ``mask``.
        out_channels: Number of channels of both returned tensors.
        kernel_size: Int, or a tuple/list of three ints.
        stride: Int, or a tuple/list of three ints.
        use_bias: If True, a learnable per-channel bias (initialised to zero)
            is added after the re-scaling.
        activation: One of ``"leaky"`` (LeakyReLU, slope 0.1), ``"relu"``,
            ``"sigmoid"`` or ``None``.

    Raises:
        ValueError: If ``activation`` is not one of the supported values, or if
            ``kernel_size`` / ``stride`` is a sequence whose length is not 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3,3,3), stride=(1,1,1),
                 use_bias=True, activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._as_triple(kernel_size, "kernel_size")
        self.stride = self._as_triple(stride, "stride")
        self.use_bias = use_bias

        # Image kernel
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, *self.kernel_size))
        nn.init.xavier_uniform_(self.weight)

        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        # Activation
        if activation == "leaky":
            self.activation = nn.LeakyReLU(0.1, inplace=True)
        elif activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation is None:
            self.activation = None
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # All-ones kernel used to count valid mask entries per output location.
        # Registered as a buffer so it follows the module across devices but is
        # never trained.
        self.register_buffer("mask_kernel", torch.ones(out_channels, in_channels, *self.kernel_size))

        # Number of mask entries summed by ``mask_kernel`` for one output
        # location: every input channel times the kernel volume.
        self.window_size = (in_channels * self.kernel_size[0]
                            * self.kernel_size[1] * self.kernel_size[2])

        # Padding: (k - 1) // 2 per axis (keeps the spatial size for odd kernels
        # at stride 1).
        self.padding = tuple((k-1)//2 for k in self.kernel_size)

    @staticmethod
    def _as_triple(value, name):
        """Return ``value`` as a 3-tuple; a scalar is repeated three times."""
        if isinstance(value, (tuple, list)):
            triple = tuple(value)
            if len(triple) != 3:
                raise ValueError(f"{name} must have 3 elements, got {len(triple)}")
            return triple
        return (value,) * 3

    def forward(self, img, mask):
        """Apply the partial convolution.

        Args:
            img: Tensor of shape ``(N, in_channels, D, H, W)``.
            mask: Tensor broadcastable against ``img`` in ``img * mask`` and
                with ``in_channels`` channels; non-zero marks valid voxels
                (presumably binary 0/1 -- confirm against the caller).

        Returns:
            Tuple ``(img_out, mask_out)``, both with ``out_channels`` channels.
            ``mask_out`` is 1 where at least one valid voxel fell inside the
            receptive field and 0 elsewhere. Note that the bias, when present,
            is added at every location, including those where ``mask_out`` is 0.
        """
        with torch.no_grad():
            # Raw count of valid mask entries under the kernel at each output
            # location.
            valid_count = F.conv3d(mask, self.mask_kernel, stride=self.stride, padding=self.padding)
            # The ratio must be derived from the *unclamped* count; clamping
            # first would turn it into a constant wherever the count is >= 1.
            mask_ratio = self.window_size / (valid_count + 1e-8)
            # Updated mask: 1 where anything valid was seen, else 0.
            mask_out = torch.clamp(valid_count, 0, 1)
            # Zero the (huge) ratio at locations with no valid input.
            mask_ratio = mask_ratio * mask_out

        # Convolve image * mask
        img_out = F.conv3d(img * mask, self.weight, bias=None, stride=self.stride, padding=self.padding)

        # Re-normalize by the fraction of the window that was valid.
        img_out = img_out * mask_ratio

        if self.bias is not None:
            img_out = img_out + self.bias.view(1, -1, 1, 1, 1)

        if self.activation is not None:
            img_out = self.activation(img_out)

        return img_out, mask_out


In [2]:
class STpconvUNet3D(nn.Module):
    """U-Net built from ``STpconv3D`` partial-convolution layers.

    The encoder is a stack of ``n_conv_layers`` blocks. Each block holds
    ``n_conv_per_block - 1`` stride-1 layers followed by one layer that uses the
    block's configured stride. The decoder mirrors the encoder: at every stage
    the features and mask are upsampled to the spatial size of the matching
    skip connection (the output of the next-shallower encoder block, or the
    network input for the last stage), concatenated with it along the channel
    axis, and passed through one stride-1 ``STpconv3D``. A plain 1x1x1
    ``nn.Conv3d`` produces the final output.

    Args:
        in_channels: Channels of the input image (and of the input mask).
        out_channels: Channels of the returned tensor.
        n_conv_layers: Number of encoder blocks (and decoder stages).
        n_filters: Per-block output channels; defaults to
            ``min(2 ** (i + 5), 256)`` for block ``i`` (32, 64, 128, 256, ...).
        kernel_sizes: Per-block encoder kernel sizes; defaults to ``(3,3,3)``.
        strides: Per-block stride of the last layer of each encoder block;
            defaults to ``(2,2,2)``.
        n_conv_per_block: Number of ``STpconv3D`` layers per encoder block.
    """

    def __init__(self, in_channels=1, out_channels=1,
                 n_conv_layers=4, n_filters=None,
                 kernel_sizes=None, strides=None,
                 n_conv_per_block=1):
        super().__init__()

        self.n_conv_layers = n_conv_layers
        self.n_conv_per_block = n_conv_per_block

        if n_filters is None:
            n_filters = [min(2**(i+5), 256) for i in range(n_conv_layers)]  # [32,64,...]

        if kernel_sizes is None:
            kernel_sizes = [(3,3,3)] * n_conv_layers

        if strides is None:
            strides = [(2,2,2)] * n_conv_layers

        # Encoder
        self.encoders = nn.ModuleList()
        prev_ch = in_channels
        for i in range(n_conv_layers):
            block = nn.ModuleList()
            # Leading layers keep the resolution; only the last one strides.
            for j in range(n_conv_per_block-1):
                block.append(STpconv3D(prev_ch, n_filters[i], kernel_size=kernel_sizes[i],
                                       stride=(1,1,1), activation="leaky"))
                prev_ch = n_filters[i]
            block.append(STpconv3D(prev_ch, n_filters[i], kernel_size=kernel_sizes[i],
                                   stride=strides[i], activation="leaky"))
            self.encoders.append(block)
            prev_ch = n_filters[i]

        # Decoder: stage for encoder level i consumes the upsampled features
        # (prev_ch channels) concatenated with the skip from level i-1, or with
        # the raw input when i == 0.
        self.decoders = nn.ModuleList()
        for i in reversed(range(n_conv_layers)):
            block = nn.ModuleList()
            nf = out_channels if i == 0 else n_filters[i-1]
            skip_ch = in_channels if i == 0 else n_filters[i-1]
            block.append(STpconv3D(prev_ch + skip_ch,
                                   nf, kernel_size=(3,3,3), stride=(1,1,1), activation="leaky"))
            prev_ch = nf
            self.decoders.append(block)

        # Final 1x1x1 conv
        self.final = nn.Conv3d(prev_ch, out_channels, kernel_size=1)

    def forward(self, img, mask):
        """Run the network.

        Args:
            img: Input tensor of shape ``(N, in_channels, D, H, W)``.
            mask: Validity mask with the same shape as ``img``.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``img``.
        """
        # Encoder forward: remember every block's output for the skip paths.
        skips = []
        skip_masks = []
        x, m = img, mask
        for block in self.encoders:
            for layer in block:
                x, m = layer(x, m)
            skips.append(x)
            skip_masks.append(m)

        # Decoder forward
        for i, block in enumerate(self.decoders):
            # Decoder stage i pairs with the encoder output one level shallower
            # than the features it receives; past the shallowest level the
            # network input itself is the skip connection.
            skip_idx = self.n_conv_layers - 2 - i
            if skip_idx >= 0:
                e, em = skips[skip_idx], skip_masks[skip_idx]
            else:
                e, em = img, mask

            # Upsample to the skip tensor's exact spatial size rather than by a
            # fixed factor of 2, so odd input sizes and non-2 encoder strides
            # still concatenate cleanly. For even sizes with stride 2 this is
            # the same as scale_factor=2.
            target_size = tuple(e.shape[2:])
            x = F.interpolate(x, size=target_size, mode="trilinear", align_corners=False)
            m = F.interpolate(m, size=target_size, mode="nearest")

            x = torch.cat([x, e], dim=1)
            m = torch.cat([m, em], dim=1)

            for layer in block:
                x, m = layer(x, m)

        return self.final(x)
