In [307]:
import torch
import torch.nn as nn
import scipy.ndimage

In [308]:
def adapt_to_conv2d(ts, n):
  # from [x, 64, 64, 64]
  ts = ts.unsqueeze(dim=0)
  ts = torch.nn.functional.interpolate(ts, scale_factor=[1/n,1,1])
  ts = ts.squeeze(dim=0)
  return ts

def adapt_from_conv2d(ts, n):
  # from [x, 1, 64, 64]
  ts = ts.unsqueeze(dim=0)
  ts = torch.nn.functional.interpolate(ts, scale_factor=[n,1,1])
  # ts = ts.squeeze(dim=0)
  return ts

In [309]:
class HDC_Block(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.one_one_one1 = nn.Conv3d(channels, channels, kernel_size=1, stride=1)
    # self.three_three_one = nn.Conv3d(8, 8, kernel_size=[3,3,1], stride=1)
    self.three_three_one = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) # doesn't work with 2d convs because of dimensions
    self.one_one_one2 = nn.Conv3d(4, channels, kernel_size=1, stride=1)
  def forward(self, x):
    slices = x.shape[2]
    print(slices)
    print(x.shape)

    x1 = self.one_one_one1(x)
    print(x1.shape)

    channel_group1 = x1[:, 0, :, :, :]
    channel_group1 = adapt_to_conv2d(channel_group1, slices)
    print(channel_group1.shape)

    channel_group2 = x1[:, 1, :, :, :]
    channel_group2 = adapt_to_conv2d(channel_group2, slices)
    print(channel_group2.shape)

    channel_group3 = x1[:, 2, :, :, :]
    channel_group3 = adapt_to_conv2d(channel_group3, slices)
    print(channel_group3.shape)

    channel_group4 = x1[:, 3, :, :, :]
    channel_group4 = adapt_to_conv2d(channel_group4, slices)
    print(channel_group4.shape)

    print()

    final_group2 = self.three_three_one(channel_group2)
    print(final_group2.shape)
    
    print()

    group2_group3 = torch.cat([final_group2, channel_group3], dim=0)
    print(group2_group3.shape)
    final_group23 = self.three_three_one(group2_group3)
    print(final_group23.shape)

    print()

    group2_group3_group4 = torch.cat([final_group23, channel_group4], dim=0)
    print(group2_group3_group4.shape)
    final_group234 = self.three_three_one(group2_group3_group4)
    print(final_group234.shape)

    print()

    final_group1234 = torch.cat([final_group234, channel_group1], dim=0)
    print(final_group1234.shape)

    final_group1234 = adapt_from_conv2d(final_group1234, slices)
    print(final_group1234.shape)

    x2 = self.one_one_one2(final_group1234)
    print(x2.shape)

    x3 = x + x2

    return x3

In [310]:
x = torch.rand(size=(1, 32, 32, 32, 32), dtype=torch.float32)
print(x.shape)

model = HDC_Block(32)
print(model)
print()

out = model(x)
# print(out.shape)

torch.Size([1, 32, 32, 32, 32])
HDC_Block(
  (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (three_three_one): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (one_one_one2): Conv3d(4, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)

32
torch.Size([1, 32, 32, 32, 32])
torch.Size([1, 32, 32, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])

torch.Size([1, 1, 32, 32])

torch.Size([2, 1, 32, 32])
torch.Size([2, 1, 32, 32])

torch.Size([3, 1, 32, 32])
torch.Size([3, 1, 32, 32])

torch.Size([4, 1, 32, 32])
torch.Size([1, 4, 32, 32, 32])
torch.Size([1, 32, 32, 32, 32])


In [311]:
class HDC_Net(nn.Module):
  def __init__(self, x):
    super().__init__()
    # self.pds = torch.nn.functional.interpolate(x)
    self.conv1 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1)
    self.HDC = HDC_Block(32)
    self.maxpool = nn.MaxPool3d(kernel_size=2, stride=2)
    self.softmax = nn.Softmax()
  def forward(self, x):
    print(x.shape)
    nimages, channels, width, height, depth = x.shape
    print(nimages, channels, width, height, depth)
    print(x.type)
    # x1 = torch.tensor(scipy.ndimage.zoom(x, [1, 8.0, 0.5, 0.5, 0.5])) # using this function took about 2 minutes and for many images, it's not reasonable
    x1 = torch.nn.functional.interpolate(x, scale_factor=0.5) # interpolate only looks at dim 2,3,4... (doesn't regard for channel and number of images)
    # instead, I used torch.nn.functional.interpolate to interpolate the spatial dimensions, but for the channels I used a 1x1x1 conv
    # because they did want to avoid using 3x3x3 conv and it will work same
    print(x1.shape)
    x1 = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=1, stride=1)(x1)
    print(x1.shape)
    x2 = self.conv1(x1)
    print(x2.shape)

    print()

    x3 = self.HDC(x2)
    x4 = self.maxpool(x3)
    print(x4.shape)

    print()
    
    x5 = self.HDC(x4)
    x6 = self.maxpool(x5)
    print(x6.shape)

    print()

    x7 = self.HDC(x6)
    x8 = self.maxpool(x7)
    print(x8.shape)

    print()

    x9 = self.HDC(x8)
    print(x9.shape)

    print()
    
    x10 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)(x9)
    print(x10.shape)
    x11 = torch.add(x10, x7)
    print(x11.shape)
    x12 = self.HDC(x11)
    print(x12.shape)

    x13 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)(x12)
    print(x13.shape)
    x14 = torch.add(x13, x5)
    print(x14.shape)
    x15 = self.HDC(x14)
    print(x15.shape)

    x16 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)(x15)
    print(x16.shape)
    x17 = torch.add(x16, x3)
    print(x17.shape)
    x18 = self.HDC(x17)
    print(x18.shape)

    print()
    
    x19 = nn.ConvTranspose3d(32, 3, kernel_size=2, stride=2)(x18)
    print(x19.shape)

    prob = self.softmax(x19)

    return x19, prob

In [312]:
x = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
print(x.shape)

model = HDC_Net(x)
print(model)
print()

out = model(x)
# print(out.shape)

torch.Size([1, 4, 128, 128, 128])
HDC_Net(
  (conv1): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (HDC): HDC_Block(
    (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (three_three_one): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (one_one_one2): Conv3d(4, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
  (maxpool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (softmax): Softmax(dim=None)
)

torch.Size([1, 4, 128, 128, 128])
1 4 128 128 128
<built-in method type of Tensor object at 0x7f96b170c590>
torch.Size([1, 4, 64, 64, 64])
torch.Size([1, 32, 64, 64, 64])
torch.Size([1, 32, 64, 64, 64])

64
torch.Size([1, 32, 64, 64, 64])
torch.Size([1, 32, 64, 64, 64])
torch.Size([1, 1, 64, 64])
torch.Size([1, 1, 64, 64])
torch.Size([1, 1, 64, 64])
torch.Size([1, 1, 64, 64])

torch.Size([1, 1, 64, 64])

torch.Size([2, 1, 64, 64])
torch.Size([2, 1, 64, 64])

torch.Size([3, 1, 64, 64



In [313]:
output, probability = out

In [314]:
print(output.shape)
print(output)

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[ 1.3352e-02,  5.1008e-01, -1.2334e-01,  ...,  3.5770e-01,
            -1.8187e-02,  5.3047e-01],
           [-2.1438e-01, -1.6631e-01, -3.2354e-01,  ..., -2.4701e-01,
            -3.4320e-01, -1.0304e-01],
           [-1.8454e-01,  6.6412e-01,  3.9195e-02,  ...,  4.4311e-01,
             1.6506e-01,  4.8279e-01],
           ...,
           [-2.8958e-01, -8.6609e-02, -4.3220e-01,  ..., -1.0983e-01,
            -3.5034e-01,  2.6012e-02],
           [-1.7226e-01,  5.1136e-01,  1.0627e-01,  ...,  3.8220e-01,
             1.7234e-01,  5.2318e-01],
           [ 2.1654e-01, -1.7558e-01,  1.4851e-01,  ..., -2.5490e-01,
             1.4170e-01,  2.2263e-01]],

          [[ 5.3952e-01, -7.9148e-01,  7.1376e-01,  ..., -7.3147e-01,
             7.1752e-01, -6.5223e-01],
           [ 4.2476e-01,  6.2628e-01,  3.4775e-01,  ...,  5.7942e-01,
             2.5304e-01,  4.2800e-01],
           [ 3.1814e-01, -7.5857e-01,  5.5796e-01,  ..., -5.6664e-01,
      

In [315]:
print(probability.shape)
print(probability)

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[0.3418, 0.4801, 0.2617,  ..., 0.4282, 0.2996, 0.5291],
           [0.2868, 0.3414, 0.2207,  ..., 0.3287, 0.2332, 0.3469],
           [0.2786, 0.5644, 0.3153,  ..., 0.4754, 0.3426, 0.4657],
           ...,
           [0.2654, 0.3801, 0.2095,  ..., 0.3836, 0.2351, 0.3635],
           [0.2907, 0.5267, 0.3187,  ..., 0.4719, 0.3546, 0.5120],
           [0.3133, 0.3347, 0.3340,  ..., 0.3228, 0.3546, 0.4417]],

          [[0.4481, 0.1254, 0.6304,  ..., 0.1580, 0.6123, 0.1194],
           [0.3129, 0.5299, 0.2863,  ..., 0.5430, 0.2773, 0.5255],
           [0.4896, 0.1196, 0.5835,  ..., 0.1695, 0.6004, 0.1242],
           ...,
           [0.2993, 0.4987, 0.2839,  ..., 0.5473, 0.3156, 0.5637],
           [0.4678, 0.1573, 0.6047,  ..., 0.1428, 0.5577, 0.1305],
           [0.2907, 0.4869, 0.3106,  ..., 0.5016, 0.3290, 0.5093]],

          [[0.3001, 0.5381, 0.2739,  ..., 0.4983, 0.3043, 0.4293],
           [0.2361, 0.3641, 0.3018,  ..., 0.3233, 0.2914, 0