In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from unetr import UNETR

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Edge lengths of the synthetic volumes, in the order they are used as the
# trailing tensor axes: (depth, height, width).
# (A 3x3x3 debug size used to be assigned here and was immediately
# overwritten; that dead assignment has been removed.)
depth, height, width = 128, 128, 128

# Stride settings. They are not referenced by the cells shown here --
# TODO confirm whether anything else still needs them.
stride1 = 4
stride2 = 2

class SimpleConv3D(nn.Module):
    """Small three-layer 3D convolutional network (4 -> 16 -> 32 -> 1 channels).

    Every layer uses a 3x3x3 kernel with padding 1 and stride (2, 1, 1), so
    the first spatial axis is reduced to ceil(n / 2) by each layer (roughly
    1/8 of its input size after all three) while the other two spatial axes
    keep their size.
    """

    def __init__(self):
        super().__init__()
        # Geometry shared by all three layers; only the channel counts differ.
        conv_geometry = dict(kernel_size=3, padding=1, stride=(2, 1, 1))
        self.conv1 = nn.Conv3d(4, 16, **conv_geometry)
        self.conv2 = nn.Conv3d(16, 32, **conv_geometry)
        self.conv3 = nn.Conv3d(32, 1, **conv_geometry)

    def forward(self, x):
        """Apply conv1 and conv2 with ReLU, then conv3 without an activation.

        Args:
            x: tensor of shape (batch, 4, d, h, w).

        Returns:
            Tensor of shape (batch, 1, d', h, w) with d' as described on the class.
        """
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        return self.conv3(hidden)


# Seed PyTorch's RNG so the synthetic volumes below (and any torch-based
# weight initialisation inside the model) are identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Synthetic stand-ins for a single case. Every volume is shaped
# (batch=1, channel=1, depth, height, width) with values drawn uniformly from [0, 1).
volume_shape = (1, 1, depth, height, width)
ct = torch.rand(volume_shape, device=device)
ptv = torch.rand(volume_shape, device=device)
oar = torch.rand(volume_shape, device=device)
target_dose = torch.rand(volume_shape, device=device)  # This is what the model will learn to predict

# Four input channels: ct, ptv, oar and the dose channel that is filled in
# progressively during training; one output channel for the predicted dose.
# Earlier baseline, kept for reference: model = SimpleConv3D().to(device)
model = UNETR(input_dim=4, output_dim=1).to(device)


Using device: cuda


In [None]:
def visualize_volume(masked_pred, iteration):
    """Show the first sample/channel of a 5-D volume as a 3D voxel plot.

    Args:
        masked_pred: torch tensor indexed as [batch, channel, a, b, c]; only
            element [0, 0] is drawn. Every non-zero entry is rendered as a
            filled voxel, zero entries are left empty.
        iteration: value interpolated into the figure title.
    """
    # detach() first so no autograd history is carried along to the CPU copy.
    volume_np = masked_pred[0, 0, :, :, :].detach().cpu().numpy()
    # Axes3D.voxels takes an occupancy mask; build it explicitly instead of
    # relying on the truthiness of float values.
    filled = volume_np != 0

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.voxels(filled, facecolors='blue', edgecolor='k')
    ax.set_title(f'3D View of the Predicted Volume - Iteration {iteration}')
    plt.show()
    # This function is called once per training step; close the figure so
    # open figures do not accumulate regardless of backend settings.
    plt.close(fig)

def train():
    """Train the global `model` slab by slab on the synthetic volumes.

    The model receives the full 4-channel input (ct, ptv, oar, dose-so-far) at
    every step, and its output is regressed with MSE against the next 8-wide
    slab of `target_dose` along the last axis. After each optimizer step the
    detached prediction is written into that slab of the dose channel, so the
    next step sees everything predicted so far, and the partially filled
    volume is visualised.

    Assumes the model output has the shape of one 8-wide target slab --
    TODO confirm against the UNETR implementation.
    """
    slab_width = 8
    dose_so_far = torch.zeros(1, 1, depth, height, width, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    combined_input = torch.cat([ct, ptv, oar, dose_so_far], dim=1)

    for iteration, z in enumerate(range(0, width, slab_width)):
        window = slice(z, z + slab_width)

        output = model(combined_input)
        print(output.shape)
        loss = F.mse_loss(output, target_dose[..., window])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Iteration ({z}), Loss: {loss.item()}')

        # Feed the prediction back in: both the stand-alone dose volume (used
        # for plotting) and channel 3 of the model input get the new slab.
        prediction = output.detach()
        dose_so_far[..., window] = prediction
        combined_input[:, 3, :, :, window] = prediction

        visualize_volume(dose_so_far, iteration)

# Run the slab-wise training demo. Each step prints the output shape and the
# loss and renders one voxel figure of the dose volume filled in so far.
train()

In [None]:
from tqdm import tqdm

def train_single_epoch(model, data_loader, optimizer, criterion, teacher_forcing=False, slab_size=8):
    """Run one epoch of slab-wise (autoregressive) dose training.

    For every batch the dose channel of the model input starts as zeros. The
    volume is then walked along its last axis in slabs of `slab_size`: the
    model sees the full input, its output is scored against the current slab
    of the target, one optimizer step is taken, and the slab of the dose
    channel is filled in for the following steps.

    Args:
        model: torch module mapping the combined input to one slab. Its output
            is assumed to have the shape of `target[..., slab]`, i.e.
            (batch, 1, a, b, slab_size) -- TODO confirm against the model.
        data_loader: iterable of dicts with keys "features" and "dose".
            "dose" gets a channel axis inserted at position 1; "features" has
            axis 1 swapped with its last axis before concatenation.
        optimizer: optimizer over `model`'s parameters.
        criterion: callable `(prediction, target) -> scalar loss tensor`.
        teacher_forcing: if True the dose channel is filled with the ground
            truth slab, otherwise with the model's own detached prediction.
        slab_size: number of positions along the last axis handled per step.

    Returns:
        The sum of all per-slab losses divided by the number of batches
        (0.0 if the loader yields no batches).
    """
    model.train()

    # Run on the device the model lives on rather than on a notebook global.
    first_param = next(model.parameters(), None)

    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(data_loader, desc="Train", leave=False)

    for batch in pbar:
        num_batches += 1

        target = batch["dose"].unsqueeze(1)
        # NOTE(review): transpose(1, -1) swaps axis 1 with the last axis; for a
        # channels-last input this also reorders the spatial axes compared
        # with a plain "move channels to the front". Confirm this matches the
        # axis order of batch["dose"].
        features = batch["features"].transpose(1, -1)
        if first_param is not None:
            target = target.to(first_param.device)
            features = features.to(first_param.device)

        # The dose channel is concatenated last, so it is addressed as
        # channel -1 below regardless of how many feature channels there are.
        combined_input = torch.cat([features, torch.zeros_like(target)], dim=1)

        for start in range(0, target.shape[-1], slab_size):
            slab = slice(start, start + slab_size)
            target_slab = target[..., slab]

            optimizer.zero_grad()
            outputs = model(combined_input)
            loss = criterion(outputs, target_slab)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Update the input only AFTER backward(): autograd may have saved
            # `combined_input` for the weight gradients, and modifying a saved
            # tensor in place before backward() raises a RuntimeError.
            # `-1:` keeps the channel axis so shapes match for any batch size.
            fill = target_slab if teacher_forcing else outputs.detach()
            combined_input[:, -1:, ..., slab] = fill

    return total_loss / num_batches if num_batches else 0.0

In [None]:

# ct = torch.rand(1, 1, depth, height, width)  # 1 batch, 1 channel, depth, height, width
# ptv = torch.rand(1, 1, depth, height, width)
# oar = torch.rand(1, 1, depth, height, width)
# target_dose = torch.rand(1, 1, depth, height, width)  # This is what the model will learn to predict

# model = SimpleConv3D()
# def train():
#     # Initialize the input tensor with zeros
#     input_target_dose = torch.zeros(1, 1, depth, height, width) 
#     combined_input = torch.cat([ct, ptv, oar, input_target_dose], dim=1)  # Concatenate along the channel dimension

#     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)   
#     optimizer.zero_grad()
    
#     for x in range(depth):
#         for y in range(height):
#             for z in range(width):
#                 combined_input = torch.cat([ct, ptv, oar, input_target_dose], dim=1)
#                 output = model(combined_input)
#                 loss = F.mse_loss(output, target_dose[:, :, x, y, z])
#                 loss.backward()
#                 optimizer.step()
#                 optimizer.zero_grad()
                
#                 input_target_dose[:, :, x, y, z] = target_dose[:, :, x, y, z]
                
#     return loss                

# train()
