In [41]:
import torch
import torch.nn as nn
import scipy.ndimage

In [89]:
class HDC_Block(nn.Module):
  def __init__(self, channels):
    super().__init__()
    # my assumption was that if you wanted a convolution with 3x3x1, you couldn't have it be 3d and specify the kernel size like that
    # however, looking at the code implementation, you can do that
    self.one_one_one1 = nn.Conv3d(channels, channels, kernel_size=1, stride=1)
    self.three_three_one = nn.Conv3d(8, 8, kernel_size=(3,3,1), padding=(1,1,0))
    self.one_three_three = nn.Conv3d(channels, channels, kernel_size=(1,3,3), padding=(0,1,1))
  def forward(self, x):
    x1 = self.one_one_one1(x)
    print(x1.shape)

    # [2, 32, 64, 64, 64]

    print("channel groups")
    channel_group1 = x1[:, 0:8, :, :, :] # one modality
    print(channel_group1.shape)

    channel_group2 = x1[:, 8:16, :, :, :]
    print(channel_group2.shape)

    channel_group3 = x1[:, 16:24, :, :, :]
    print(channel_group3.shape)

    channel_group4 = x1[:, 24:32, :, :, :]
    print(channel_group4.shape)

    x2 = self.three_three_one(channel_group2)
    print(x2.shape)
    x3 = self.three_three_one(channel_group3+x2)
    print(x3.shape)
    x4 = self.three_three_one(channel_group4+x3)
    print(x4.shape)

    end = torch.cat([channel_group1, x2, x3, x4], dim=1)
    print(end.shape)

    x5 = self.one_one_one1(end)
    print(x5.shape)

    out = self.one_three_three(x5)
    print(out.shape)
    return out

In [92]:
x = torch.rand(size=(2, 32, 64, 64, 64), dtype=torch.float32)
print(x.shape)

model = HDC_Block(32)
print(model)
print()

out = model(x)

torch.Size([2, 32, 64, 64, 64])
HDC_Block(
  (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (three_three_one): Conv3d(8, 8, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
  (one_three_three): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
)

torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])


In [138]:
class HDC_Net(nn.Module):
  def __init__(self, x):
    super().__init__()
    # self.pds = torch.nn.functional.interpolate(x)
    self.conv1 = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=3, padding=1, stride=1)
    self.downsample = nn.Conv3d(32, 32, kernel_size=2, stride=2)
    self.HDC = HDC_Block(32)
    self.upsample = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)
    self.upinterpolate = nn.Upsample(scale_factor=2, mode='trilinear')
    self.conv2 = nn.Conv3d(in_channels=32, out_channels=4, kernel_size=1, stride=1)
    self.softmax = nn.Softmax(dim=1)
  def forward(self, x):
    print(x.shape)
    nimages, channels, width, height, depth = x.shape
    print(nimages, channels, width, height, depth)
    print(x.type)
    # x1 = torch.tensor(scipy.ndimage.zoom(x, [1, 8.0, 0.5, 0.5, 0.5])) # using this function took about 2 minutes and for many images, it's not reasonable
    x1 = torch.nn.functional.interpolate(x, scale_factor=[0.5, 0.5, 0.5]) # PDS - interpolate only looks at dim 2,3,4... (doesn't regard for channel and number of images)
    # instead, I used torch.nn.functional.interpolate to interpolate the spatial dimensions, but for the channels I used a 1x1x1 conv
    # because they did want to avoid using 3x3x3 conv and it will work same
    print(x1.shape)
    x1 = self.conv1(x1)
    print(x1.shape)

    x2 = self.HDC(x1)
    print(x2.shape)
    print()
    x3 = self.downsample(x2)
    print(x3.shape)

    x4 = self.HDC(x3)
    print(x4.shape)
    print()
    x5 = self.downsample(x4)
    print(x5.shape)

    x6 = self.HDC(x5)
    print(x6.shape)
    print()
    x7 = self.downsample(x6)
    print(x7.shape)

    x8 = self.HDC(x7)
    print(x8.shape)
    print()

    print("decoder time")

    x9 = self.upsample(x8)
    print(x9.shape)
    x10 = torch.add(x9, x6)
    print(x10.shape)
    x11 = self.HDC(x10)
    print(x11.shape)

    x12 = self.upsample(x11)
    print(x12.shape)
    x13 = torch.add(x12, x4)
    print(x13.shape)
    x14 = self.HDC(x13)
    print(x14.shape)

    x15 = self.upsample(x14)
    print(x15.shape)
    x16 = torch.add(x15, x2)
    print(x16.shape)
    x17 = self.HDC(x16)
    print(x17.shape)

    print("\nupsampling\n") # by this they meant interpolation

    x18 = self.upinterpolate(x17)
    print(x18.shape)

    x19 = self.conv2(x18)
    print(x19.shape)

    out = self.softmax(x19)
    print(out.shape)

    return x19, out

In [139]:
x = torch.rand(size=(2, 4, 128, 128, 128), dtype=torch.float32)
print(x.shape)

model = HDC_Net(x)
print(model)
print()

out = model(x)
# print(out.shape)

torch.Size([2, 4, 128, 128, 128])
HDC_Net(
  (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downsample): Conv3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (HDC): HDC_Block(
    (one_one_one1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (three_three_one): Conv3d(8, 8, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
    (one_three_three): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  )
  (upsample): ConvTranspose3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (upinterpolate): Upsample(scale_factor=2.0, mode=trilinear)
  (conv2): Conv3d(32, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (softmax): Softmax(dim=1)
)

torch.Size([2, 4, 128, 128, 128])
2 4 128 128 128
<built-in method type of Tensor object at 0x7fab71fcf230>
torch.Size([2, 4, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 6

In [140]:
output, probability = out

In [None]:
print(output.shape)
print(output)

In [None]:
print(probability.shape)
print(probability)

In [143]:
import numpy as np

In [144]:
randomized_training_images = []
for i in range(2):
  newx = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_training_images.append(newx)

In [145]:
print(len(randomized_training_images))
print(randomized_training_images[0].shape)
# print(randomized_training_images[0])

2
torch.Size([1, 4, 128, 128, 128])


In [146]:
randomized_training_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_training_segmentations.append(newy)

In [147]:
print(len(randomized_training_segmentations))
print(randomized_training_segmentations[0].shape)
# print(randomized_training_segmentations[1])

2
torch.Size([1, 4, 128, 128, 128])


In [148]:
randomized_training_data = list(zip(randomized_training_images, randomized_training_segmentations))

In [173]:
trainloader = torch.utils.data.DataLoader(dataset=randomized_training_data, batch_size=2, shuffle=True) # batch size should be 10

In [174]:
len(trainloader)

1

In [175]:
randomized_validation_images = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_validation_images.append(newy)

In [176]:
print(len(randomized_validation_images))
print(randomized_validation_images[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [177]:
randomized_validation_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_validation_segmentations.append(newy)

In [178]:
print(len(randomized_validation_segmentations))
print(randomized_validation_segmentations[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [179]:
randomized_validation_data = list(zip(randomized_validation_images, randomized_validation_segmentations))

In [180]:
validationloader = torch.utils.data.DataLoader(dataset=randomized_validation_data, batch_size=2, shuffle=True) # batch size should be 10

In [181]:
len(validationloader)

1

In [182]:
randomized_testing_images = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_testing_images.append(newy)

In [183]:
print(len(randomized_testing_images))
print(randomized_testing_images[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [184]:
randomized_testing_segmentations = []
for i in range(2):
  newy = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)
  randomized_testing_segmentations.append(newy)

In [185]:
print(len(randomized_testing_segmentations))
print(randomized_testing_segmentations[0].shape)

2
torch.Size([1, 4, 128, 128, 128])


In [186]:
randomized_testing_data = list(zip(randomized_testing_segmentations, randomized_testing_segmentations))

In [187]:
testingloader = torch.utils.data.DataLoader(dataset=randomized_testing_data, batch_size=2, shuffle=True) # batch size should be 10

In [188]:
# def dice_loss (y_preds, y_outputs):
#   # talk more about this

In [189]:
# def multiclass_soft_dice_loss (y_preds, y_outputs):
#   # talk more about this

In [190]:
# def binary_soft_dice_loss (y_preds, y_outputs):
#   # there is both a binary and a multiclass formula

In [191]:
"""
Training + Validation:
multi-class soft Dice function as the loss function

Testing:
mean accuracy
dice coefficient
hausdorff implementation
"""

'\nTraining + Validation:\nmulti-class soft Dice function as the loss function\n\nTesting:\nmean accuracy\ndice coefficient\nhausdorff implementation\n'

In [201]:
# chose to make this a class because when you call dice loss in criterion, you don't have anything to input, but when u run the prediction through inside the training, then you have params
# also because most sources I saw used a class
class DiceLoss(nn.Module):
  def __init__(self):
    super().__init__()
    self.smooth = 1
  def forward(self, true, pred):
    # flatten to easily do it pixel by pixel
    true = true.view(-1)
    pred = pred.view(-1)
    numerator = 2*(true*pred).sum()
    denominator = true.sum() + pred.sum()
    dice_loss = 1 - (numerator + self.smooth) / (denominator + self.smooth)
    return dice_loss

In [202]:
import torch.optim

In [203]:
# epochs
epochs = 2 # should be 800
# loss
criterion = DiceLoss()
# optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=10**-3, weight_decay=10**-5)

In [207]:
training_losses = []
validation_losses = []

for i in range(epochs):
  training_loss = 0
  validation_loss = 0
  print("training time")
  for images, segs in trainloader:
    optimizer.zero_grad()
    print(len(images), len(segs))
    print(images.shape)
    print(segs.shape)
    images = images.squeeze().clone().detach().requires_grad_(True)
    # segs = segs.long() - no
    segs = segs.squeeze().clone().detach().requires_grad_(True)
    print(images.shape)
    print(segs.shape)
    outputs, softmax_outputs = model(images)
    print(outputs.shape)
    print(softmax_outputs.shape)

    # arg_outputs = outputs.argmax(dim=1)
    # print(arg_outputs.shape)
    # print(arg_outputs)
    # print(segs.shape)
    print()
    loss = criterion(softmax_outputs.float(), segs)
    print(loss) # loss with random tensors will be really high because none of the tensors are related to each other
    
    # loss can be > 1 - https://ai.stackexchange.com/questions/24685/can-the-sparse-categorical-cross-entropy-be-greater-than-one, https://stats.stackexchange.com/questions/392681/cross-entropy-loss-max-value
    loss.backward()
    training_loss += loss.item()
    print()
  print("validation time")
  for images, segs in validationloader:
    optimizer.zero_grad()
    print(len(images), len(segs))
    print(images.shape)
    print(segs.shape)
    images = images.squeeze().clone().detach().requires_grad_(True)
    # segs = segs.long() - no
    segs = segs.squeeze().clone().detach().requires_grad_(True)
    print(images.shape)
    print(segs.shape)
    outputs, softmax_outputs = model(images)
    print(outputs.shape)
    print(softmax_outputs.shape)
    loss = criterion(softmax_outputs.float(), segs)
    print(loss) # loss with random tensors will be really high because none of the tensors are related to each other
    # loss can be > 1 - https://ai.stackexchange.com/questions/24685/can-the-sparse-categorical-cross-entropy-be-greater-than-one, https://stats.stackexchange.com/questions/392681/cross-entropy-loss-max-value
    loss.backward()
    validation_loss += loss.item()
  training_losses.append(training_loss/len(trainloader))
  validation_losses.append(validation_loss/len(validationloader))
  print("Epoch: {}/{}... Training Loss: {}... Validation Loss: {}...".format(i+1,epochs, training_losses[-1], validation_losses[-1]))
  if validation_loss < min(validation_losses):
    print("Validation loss has decreased...saving model")
    torch.save(model.state_dict(), "fcn.pth")
  print()

training time
2 2
torch.Size([2, 1, 4, 128, 128, 128])
torch.Size([2, 1, 4, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
torch.Size([2, 4, 128, 128, 128])
2 4 128 128 128
<built-in method type of Tensor object at 0x7fab71863d10>
torch.Size([2, 4, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
channel groups
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 8, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])
torch.Size([2, 32, 64, 64, 64])

torch.Size([2, 32, 32, 32, 32])
torch.Size([2, 32, 32, 32, 32])
channel groups
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size([2, 8, 32, 32, 32])
torch.Size(

In [208]:
print(training_losses)
print(validation_losses)

[0.6666945219039917, 0.6666945219039917]
[0.666683554649353, 0.666683554649353]
