In [1]:
import torch
import torch.nn as nn
import scipy.ndimage

In [2]:
class HDC_Block(nn.Module):
  def __init__(self, channels):
    super().__init__()
    # my assumption was that if you wanted a convolution with 3x3x1, you couldn't have it be 3d and specify the kernel size like that
    # however, looking at the code implementation, you can do that
    self.one_one_one1 = nn.Conv3d(channels, channels, kernel_size=1, stride=1)
    self.three_three_one = nn.Conv3d(8, 8, kernel_size=(3,3,1), padding=(1,1,0))
    self.one_three_three = nn.Conv3d(channels, channels, kernel_size=(1,3,3), padding=(0,1,1))
  def forward(self, x):
    x1 = self.one_one_one1(x)
    print(x1.shape)

    # [2, 32, 64, 64, 64]

    print("channel groups")
    channel_group1 = x1[:, 0:8, :, :, :] # one modality
    print(channel_group1.shape)

    channel_group2 = x1[:, 8:16, :, :, :]
    print(channel_group2.shape)

    channel_group3 = x1[:, 16:24, :, :, :]
    print(channel_group3.shape)

    channel_group4 = x1[:, 24:32, :, :, :]
    print(channel_group4.shape)

    x2 = self.three_three_one(channel_group2)
    print(x2.shape)
    x3 = self.three_three_one(channel_group3+x2)
    print(x3.shape)
    x4 = self.three_three_one(channel_group4+x3)
    print(x4.shape)

    end = torch.cat([channel_group1, x2, x3, x4], dim=1)
    print(end.shape)

    x5 = self.one_one_one1(end)
    print(x5.shape)

    out = self.one_three_three(x5)
    print(out.shape)
    return out

In [3]:
x = torch.rand(size=(2, 32, 64, 64, 64), dtype=torch.float32)
print(x.shape)

model = HDC_Block(32)
print(model)
print()

out = model(x)

torch.Size([2, 32, 64, 64, 64])
HDC_Block(
  (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (three_three_one): Conv3d(8, 8, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
  (one_three_three): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
)

torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])


In [4]:
class HDC_Net(nn.Module):
  def __init__(self, x):
    super().__init__()
    # self.pds = torch.nn.functional.interpolate(x)
    self.conv1 = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=3, padding=1, stride=1)
    self.downsample = nn.Conv3d(32, 32, kernel_size=2, stride=2)
    self.HDC = HDC_Block(32)
    self.upsample = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)
    self.upinterpolate = nn.Upsample(scale_factor=2, mode='trilinear')
    self.conv2 = nn.Conv3d(in_channels=32, out_channels=3, kernel_size=1, stride=1)
    self.softmax = nn.Softmax(dim=1)
  def forward(self, x):
    print(x.shape)
    nimages, channels, width, height, depth = x.shape
    print(nimages, channels, width, height, depth)
    print(x.type)
    # x1 = torch.tensor(scipy.ndimage.zoom(x, [1, 8.0, 0.5, 0.5, 0.5])) # using this function took about 2 minutes and for many images, it's not reasonable
    x1 = torch.nn.functional.interpolate(x, scale_factor=[0.5, 0.5, 0.5]) # PDS - interpolate only looks at dim 2,3,4... (doesn't regard for channel and number of images)
    # instead, I used torch.nn.functional.interpolate to interpolate the spatial dimensions, but for the channels I used a 1x1x1 conv
    # because they did want to avoid using 3x3x3 conv and it will work same
    print(x1.shape)
    x1 = self.conv1(x1)
    print(x1.shape)

    x2 = self.HDC(x1)
    print(x2.shape)
    print()
    x3 = self.downsample(x2)
    print(x3.shape)

    x4 = self.HDC(x3)
    print(x4.shape)
    print()
    x5 = self.downsample(x4)
    print(x5.shape)

    x6 = self.HDC(x5)
    print(x6.shape)
    print()
    x7 = self.downsample(x6)
    print(x7.shape)

    x8 = self.HDC(x7)
    print(x8.shape)
    print()

    print("decoder time")

    x9 = self.upsample(x8)
    print(x9.shape)
    x10 = torch.add(x9, x6)
    print(x10.shape)
    x11 = self.HDC(x10)
    print(x11.shape)

    x12 = self.upsample(x11)
    print(x12.shape)
    x13 = torch.add(x12, x4)
    print(x13.shape)
    x14 = self.HDC(x13)
    print(x14.shape)

    x15 = self.upsample(x14)
    print(x15.shape)
    x16 = torch.add(x15, x2)
    print(x16.shape)
    x17 = self.HDC(x16)
    print(x17.shape)

    print("\nupsampling\n") # by this they meant interpolation

    x18 = self.upinterpolate(x17)
    print(x18.shape)

    x19 = self.conv2(x18)
    print(x19.shape)

    out = self.softmax(x19)
    print(out.shape)

    return x19, out

In [5]:
x = torch.rand(size=(2, 4, 128, 128, 128), dtype=torch.float32)
print(x.shape)

model = HDC_Net(x)
print(model)
print()

out = model(x)
# print(out.shape)

torch.Size([2, 4, 128, 128, 128])
HDC_Net(
  (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downsample): Conv3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (HDC): HDC_Block(
    (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (three_three_one): Conv3d(8, 8, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
    (one_three_three): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  )
  (upsample): ConvTranspose3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (upinterpolate): Upsample(scale_factor=2.0, mode=trilinear)
  (conv2): Conv3d(32, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (softmax): Softmax(dim=1)
)

torch.Size([2, 4, 128, 128, 128])
2 4 128 128 128
<built-in method type of Tensor object at 0x7f6f1fa25ef0>
torch.Size([2, 4, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 6

In [6]:
output, probability = out

In [7]:
print(output.shape)
print(output)

torch.Size([2, 3, 128, 128, 128])
tensor([[[[[ 8.7843e-02,  8.7234e-02,  8.6016e-02,  ...,  8.2627e-02,
             7.3189e-02,  6.8470e-02],
           [ 9.2754e-02,  9.0761e-02,  8.6776e-02,  ...,  8.1573e-02,
             7.2789e-02,  6.8398e-02],
           [ 1.0257e-01,  9.7816e-02,  8.8297e-02,  ...,  7.9464e-02,
             7.1990e-02,  6.8253e-02],
           ...,
           [ 1.0986e-01,  1.0385e-01,  9.1842e-02,  ...,  8.0777e-02,
             7.3132e-02,  6.9309e-02],
           [ 1.1145e-01,  1.0424e-01,  8.9831e-02,  ...,  7.7969e-02,
             6.9042e-02,  6.4579e-02],
           [ 1.1225e-01,  1.0444e-01,  8.8826e-02,  ...,  7.6564e-02,
             6.6998e-02,  6.2215e-02]],

          [[ 8.8003e-02,  8.7243e-02,  8.5723e-02,  ...,  8.2546e-02,
             7.3142e-02,  6.8440e-02],
           [ 9.2851e-02,  9.0779e-02,  8.6636e-02,  ...,  8.1515e-02,
             7.2794e-02,  6.8434e-02],
           [ 1.0255e-01,  9.7852e-02,  8.8463e-02,  ...,  7.9453e-02,
      

In [8]:
print(probability.shape)
print(probability)

torch.Size([2, 3, 128, 128, 128])
tensor([[[[[0.3395, 0.3399, 0.3407,  ..., 0.3429, 0.3449, 0.3459],
           [0.3395, 0.3397, 0.3402,  ..., 0.3423, 0.3447, 0.3459],
           [0.3394, 0.3393, 0.3391,  ..., 0.3411, 0.3443, 0.3459],
           ...,
           [0.3415, 0.3411, 0.3403,  ..., 0.3416, 0.3441, 0.3453],
           [0.3460, 0.3449, 0.3429,  ..., 0.3426, 0.3433, 0.3436],
           [0.3482, 0.3469, 0.3442,  ..., 0.3431, 0.3429, 0.3428]],

          [[0.3394, 0.3398, 0.3406,  ..., 0.3429, 0.3449, 0.3459],
           [0.3394, 0.3396, 0.3401,  ..., 0.3423, 0.3447, 0.3459],
           [0.3393, 0.3392, 0.3391,  ..., 0.3410, 0.3443, 0.3460],
           ...,
           [0.3416, 0.3413, 0.3406,  ..., 0.3419, 0.3445, 0.3457],
           [0.3460, 0.3451, 0.3431,  ..., 0.3428, 0.3436, 0.3440],
           [0.3483, 0.3470, 0.3444,  ..., 0.3433, 0.3431, 0.3431]],

          [[0.3393, 0.3396, 0.3404,  ..., 0.3429, 0.3449, 0.3459],
           [0.3392, 0.3395, 0.3399,  ..., 0.3422, 0.3448, 0

In [9]:
import numpy as np

In [10]:
randomized_training_images = []
for i in range(2):
  newx = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_training_images.append(newx)

In [11]:
print(len(randomized_training_images))
print(randomized_training_images[0].shape)
# print(randomized_training_images[0])

2
torch.Size([1, 4, 128, 128, 128])


In [12]:
randomized_training_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 3, 128, 128, 128), dtype=torch.float32)
  randomized_training_segmentations.append(newy)

In [13]:
print(len(randomized_training_segmentations))
print(randomized_training_segmentations[0].shape)
# print(randomized_training_segmentations[1])

2
torch.Size([1, 3, 128, 128, 128])


In [14]:
randomized_training_data = list(zip(randomized_training_images, randomized_training_segmentations))

In [15]:
trainloader = torch.utils.data.DataLoader(dataset=randomized_training_data, batch_size=2, shuffle=True) # batch size should be 10

In [16]:
len(trainloader)

1

In [17]:
randomized_validation_images = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_validation_images.append(newy)

In [18]:
print(len(randomized_validation_images))
print(randomized_validation_images[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [19]:
randomized_validation_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 3, 128, 128, 128), dtype=torch.float32)
  randomized_validation_segmentations.append(newy)

In [20]:
print(len(randomized_validation_segmentations))
print(randomized_validation_segmentations[0].shape)

2
torch.Size([1, 3, 128, 128, 128])


In [21]:
randomized_validation_data = list(zip(randomized_validation_images, randomized_validation_segmentations))

In [22]:
validationloader = torch.utils.data.DataLoader(dataset=randomized_validation_data, batch_size=2, shuffle=True) # batch size should be 10

In [23]:
len(validationloader)

1

In [24]:
randomized_testing_images = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_testing_images.append(newy)

In [25]:
print(len(randomized_testing_images))
print(randomized_testing_images[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [26]:
randomized_testing_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 3, 128, 128, 128), dtype=torch.float32)
  randomized_testing_segmentations.append(newy)

In [27]:
print(len(randomized_testing_segmentations))
print(randomized_testing_segmentations[0].shape)

2
torch.Size([1, 3, 128, 128, 128])


In [40]:
randomized_testing_data = list(zip(randomized_testing_images, randomized_testing_segmentations))

In [41]:
testingloader = torch.utils.data.DataLoader(dataset=randomized_testing_data, batch_size=2, shuffle=True) # batch size should be 10

In [30]:
"""
Training + Validation:
multi-class soft Dice function as the loss function

Testing:
mean accuracy
dice coefficient
hausdorff implementation
"""

'\nTraining + Validation:\nmulti-class soft Dice function as the loss function\n\nTesting:\nmean accuracy\ndice coefficient\nhausdorff implementation\n'

In [31]:
# chose to make this a class because when you call dice loss in criterion, you don't have anything to input, but when u run the prediction through inside the training, then you have params
# also because most sources I saw used a class
class DiceLoss(nn.Module):
  def __init__(self):
    super().__init__()
    self.smooth = 1
  def forward(self, true, pred):
    # flatten to easily do it pixel by pixel
    true = true.view(-1)
    pred = pred.view(-1)
    numerator = 2*(true*pred).sum()
    denominator = true.sum() + pred.sum()
    dice_loss = 1 - (numerator + self.smooth) / (denominator + self.smooth)
    return dice_loss

In [32]:
import torch.optim

In [33]:
# epochs
epochs = 2 # should be 800
# loss
criterion = DiceLoss()
# optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=10**-3, weight_decay=10**-5)

In [42]:
training_losses = []
validation_losses = []

for i in range(epochs):
  training_loss = 0
  validation_loss = 0
  print("training time")
  for images, segs in trainloader:
    optimizer.zero_grad()
    print(len(images), len(segs))
    print(images.shape)
    print(segs.shape)
    images = images.squeeze().clone().detach().requires_grad_(True)
    # segs = segs.long() - no
    segs = segs.squeeze().clone().detach().requires_grad_(True)
    print(images.shape)
    print(segs.shape)
    outputs, softmax_outputs = model(images)
    print(outputs.shape)
    print(softmax_outputs.shape)

    # arg_outputs = outputs.argmax(dim=1)
    # print(arg_outputs.shape)
    # print(arg_outputs)
    # print(segs.shape)
    print()
    loss = criterion(softmax_outputs.float(), segs)
    print(loss) # loss with random tensors will be really high because none of the tensors are related to each other
    
    # loss can be > 1 - https://ai.stackexchange.com/questions/24685/can-the-sparse-categorical-cross-entropy-be-greater-than-one, https://stats.stackexchange.com/questions/392681/cross-entropy-loss-max-value
    loss.backward()
    training_loss += loss.item()
    print()
  print("validation time")
  for images, segs in validationloader:
    optimizer.zero_grad()
    print(len(images), len(segs))
    print(images.shape)
    print(segs.shape)
    images = images.squeeze().clone().detach().requires_grad_(True)
    # segs = segs.long() - no
    segs = segs.squeeze().clone().detach().requires_grad_(True)
    print(images.shape)
    print(segs.shape)
    outputs, softmax_outputs = model(images)
    print(outputs.shape)
    print(softmax_outputs.shape)
    loss = criterion(softmax_outputs.float(), segs)
    print(loss) # loss with random tensors will be really high because none of the tensors are related to each other
    # loss can be > 1 - https://ai.stackexchange.com/questions/24685/can-the-sparse-categorical-cross-entropy-be-greater-than-one, https://stats.stackexchange.com/questions/392681/cross-entropy-loss-max-value
    loss.backward()
    validation_loss += loss.item()
  training_losses.append(training_loss/len(trainloader))
  validation_losses.append(validation_loss/len(validationloader))
  print("Epoch: {}/{}... Training Loss: {}... Validation Loss: {}...".format(i+1,epochs, training_losses[-1], validation_losses[-1]))
  if validation_loss < min(validation_losses):
    print("Validation loss has decreased...saving model")
    torch.save(model.state_dict(), "fcn.pth")
  print()

training time
2 2
torch.Size([2, 1, 4, 128, 128, 128])
torch.Size([2, 1, 3, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
torch.Size([2, 3, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
2 4 128 128 128
<built-in method type of Tensor object at 0x7f6f1d243f50>
torch.Size([2, 4, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])

torch.Size([2, 32, 32, 32, 32])
torch.Size([2, 32, 32, 32, 32])
channel groups
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size(

In [43]:
print(training_losses)
print(validation_losses)

[0.599981963634491, 0.5999820232391357]
[0.5999563932418823, 0.5999563932418823]


In [44]:
def dice_score(outputs, segmentations): # they find individual
  # outputs = torch.Size([2, 3, 128, 128, 128])
  # segmentations = torch.Size([2, 3, 128, 128, 128])
  # print(output.shape)
  # print(segmentations.shape)
  n_classes = segmentations.shape[1]
  region_scores = []
  for i in range(n_classes):
    outputs = outputs.view(-1)
    segmentations = segmentations.view(-1)
    numerator = 2*(outputs*segmentations).sum()
    denominator = outputs.sum() + segmentations.sum()
    dice = (numerator) / (denominator)
    region_scores.append(dice)
  return region_scores

In [45]:
"""
1 NECROTIC TUMOUR CORE (NCR — label 1) - index 0

2 GD-ENHANCING TUMOUR (ET — label 2) - index 1

3 PERITUMORAL EDEMATOUS/INVADED TISSUE (ED — label 3) - index 2
"""

'\n1 NECROTIC TUMOUR CORE (NCR — label 1) - index 0\n\n2 GD-ENHANCING TUMOUR (ET — label 2) - index 1\n\n3 PERITUMORAL EDEMATOUS/INVADED TISSUE (ED — label 3) - index 2\n'

In [47]:
for images, segs in testingloader:
    with torch.no_grad():
      print(len(images), len(segs))
      print(images.shape)
      print(segs.shape)
      images = images.squeeze().clone().detach().requires_grad_(True)
      # segs = segs.long() - no
      segs = segs.squeeze().clone().detach().requires_grad_(True)
      print(images.shape)
      print(segs.shape)
      outputs, softmax_outputs = model(images)
      print(outputs.shape)
      print(softmax_outputs.shape)
      print()

      print("......."*5)

      region_scores = dice_score(softmax_outputs, segs)

      print(len(region_scores))
      print(region_scores)
      print()

      print("1 NECROTIC TUMOUR CORE (NCR — label 1)")
      print(region_scores[0].item())
      print()

      print("2 GD-ENHANCING TUMOUR (ET — label 2)")
      print(region_scores[1].item())
      print()

      print("3 PERITUMORAL EDEMATOUS/INVADED TISSUE (ED — label 3)")
      print(region_scores[2].item())

2 2
torch.Size([2, 1, 4, 128, 128, 128])
torch.Size([2, 1, 3, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
torch.Size([2, 3, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
2 4 128 128 128
<built-in method type of Tensor object at 0x7f6f1d241230>
torch.Size([2, 4, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])

torch.Size([2, 32, 32, 32, 32])
torch.Size([2, 32, 32, 32, 32])
channel groups
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32,