In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from transformer_rnn import TransformerRNN

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [29]:
class ConvRNNCell(nn.Module):
    """One step of a convolutional RNN: ``h_next = tanh(conv(cat(x, h)))``.

    Args:
        input_channels: number of channels in the per-step input ``x``.
        hidden_channels: number of channels in the hidden state.
        kernel_size: side length of the square convolution kernel.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()
        self.hidden_channels = hidden_channels
        # The input and the previous hidden state are convolved jointly, so the
        # convolution sees both channel groups. A padding of kernel_size // 2
        # keeps the spatial size unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(
            input_channels + hidden_channels,
            hidden_channels,
            kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, x, hidden_state):
        """Return the next hidden state for input ``x`` and the previous ``hidden_state``.

        Both tensors are concatenated along dim 1 (the channel axis).
        """
        stacked = torch.cat((x, hidden_state), dim=1)
        return torch.tanh(self.conv(stacked))

In [30]:
class ConvRNN(nn.Module):
    """Stack of ``ConvRNNCell`` layers unrolled over a sequence of 2D feature maps.

    Args:
        input_channels: channels of each input frame (fed to the first layer).
        hidden_channels: channels of every layer's hidden state.
        kernel_size: convolution kernel size used by every cell.
        num_layers: number of stacked cells.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        # Only the first layer sees the raw input; deeper layers consume the
        # hidden state of the layer below.
        self.cells = nn.ModuleList([
            ConvRNNCell(input_channels if i == 0 else hidden_channels, hidden_channels, kernel_size)
            for i in range(num_layers)
        ])

    def forward(self, x, h_0=None):
        """Run the recurrence over every time step of ``x``.

        Args:
            x: tensor of shape (batch, seq_len, channels, height, width).
            h_0: optional list with one initial hidden state per layer, each of
                shape (batch, hidden_channels, height, width). Defaults to zeros.
                The list passed in is never modified.

        Returns:
            A tuple ``(outputs, hidden_states)``: ``outputs`` stacks the last
            layer's hidden state of every step along dim 1, and
            ``hidden_states`` is the list of final per-layer hidden states.
        """
        batch_size, seq_len, _, height, width = x.size()

        if h_0 is None:
            # Allocate directly on x's device and in x's dtype so the states
            # can be concatenated with the input without a copy or a cast.
            hidden_states = [
                torch.zeros(batch_size, self.hidden_channels, height, width,
                            device=x.device, dtype=x.dtype)
                for _ in range(self.num_layers)
            ]
        else:
            # Work on a shallow copy: the loop below rebinds list entries, and
            # doing that on the caller's own list would silently overwrite it.
            hidden_states = list(h_0)

        outputs = []
        for t in range(seq_len):
            input_t = x[:, t]
            for i, cell in enumerate(self.cells):
                hidden_states[i] = cell(input_t, hidden_states[i])
                input_t = hidden_states[i]
            outputs.append(hidden_states[-1])

        outputs = torch.stack(outputs, dim=1)
        return outputs, hidden_states

In [61]:
# %load_ext autoreload
# %autoreload 2
import importlib

import transformer_rnn

# Reload so edits made to the transformer_rnn module are picked up in this kernel.
importlib.reload(transformer_rnn)

# Random 3D volumes standing in for CT, PTV, OAR, and a target dose.
depth, height, width = 128, 128, 128

# Every volume is laid out as (batch=1, channel=1, depth, height, width).
ct = torch.rand((1, 1, depth, height, width), device=device)
ptv = torch.rand((1, 1, depth, height, width), device=device)
oar = torch.rand((1, 1, depth, height, width), device=device)
# The target dose is what the model will learn to predict.
target_dose = torch.rand((1, 1, depth, height, width), device=device)

unetr = transformer_rnn.TransformerRNN(input_dim=3, output_dim=1).to(device)
# Stack the three volumes along the channel axis to form the 3-channel input.
combined_input = torch.cat([ct, ptv, oar], dim=1).to(device)

In [62]:
# Forward the stacked input volumes through the model and print the result's shape.
# NOTE(review): the recorded output below has 3 channels, although the model was
# constructed with output_dim=1 — confirm that this is intended.
output = unetr(combined_input)
print(output.shape)

x shape:  torch.Size([1, 128, 256, 12, 16])
h0 shape:  torch.Size([1, 96, 12, 16])
output shape:  torch.Size([1, 128, 96, 12, 16])
output shape:  torch.Size([1, 3, 128, 128, 48])
output shape:  torch.Size([1, 3, 128, 128, 128])
torch.Size([1, 3, 128, 128, 128])


In [32]:
# Toy input: B sequences of T frames, each frame shaped (C_in, H, W).
B, T = 1, 10   # batch size, sequence length (number of time steps)
C_in = 3       # number of input channels
H, W = 12, 16  # frame height and width

x = torch.randn(B, T, C_in, H, W)  # example input tensor

# ConvRNN hyper-parameters
input_channels = C_in
hidden_channels = 96
kernel_size = 3
num_layers = 1

model = ConvRNN(
    input_channels=input_channels,
    hidden_channels=hidden_channels,
    kernel_size=kernel_size,
    num_layers=num_layers,
)

# One zero-filled initial hidden state per layer, shaped (B, hidden_channels, H, W).
h_0 = [torch.zeros(B, hidden_channels, H, W) for _ in range(num_layers)]

print("Input shape: ", x.shape)
print("Hidden state shape: ", [state.shape for state in h_0])

# Forward pass over the whole sequence
outputs, hidden_states = model(x, h_0)
print("Output shape: ", outputs.shape)
print("Hidden state shape: ", [state.shape for state in hidden_states])

Input shape:  torch.Size([1, 10, 3, 12, 16])
Hidden state shape:  [torch.Size([1, 96, 12, 16])]
Output shape:  torch.Size([1, 10, 96, 12, 16])
Hidden state shape:  [torch.Size([1, 96, 12, 16])]


In [None]:
# feature_map, z3, z6, z9 = unetr(combined_input)

# # concatenate z3, z6, z9
# z = torch.cat((z3, z6, z9), dim=2) # shape: [1, 96, 12, 4, 4])

# # flatten
# h = z.view(1, 1, -1)
# print(f"h shape: {h.shape}")

# step_size_input = 1
# seq_len = combined_input.shape[2] // step_size_input
# print(f"Sequence length: {seq_len}")

# x = [combined_input[:, :, i:i+step_size_input, ...] for i in range(0, combined_input.shape[2], step_size_input)]
# x = torch.stack(x, dim=1) # we now have batch, seq_len, channel, 1, height, width
# x = x.squeeze(3)
# print(f"x shape: {x.shape}")
# batch, seq_len, channel, height, width = x.shape
# # # flatten the last 3 dims
# # x = x.view(batch, seq_len, channel * step_size_input * height * width)

# # print(f"encoded states shape: {x.shape}")
# input_channels = channel
# hidden_channels = 5
# hidden_size = h.shape[2]
# num_layers = 1
# kernel_size = 3

# conv_rnn = ConvRNN(input_channels, hidden_channels, kernel_size, num_layers).to(device)

# # h with shape (batch_size, self.hidden_channels, height, width)

# outputs, hidden_states = conv_rnn(x)
# print(outputs.shape)
# print(hidden_states[0].shape)



# from torch.nn import RNN

# rnn = RNN(input_size, hidden_size, num_layers, batch_first=True).to(device)

# output, h_n = rnn(x, h)

# print(f"RNN output shape: {output.shape}")
# print(f"RNN h_n shape: {h_n.shape}")

h shape: torch.Size([1, 1, 2304])
Sequence length: 128
x shape: torch.Size([1, 128, 3, 128, 128])
torch.Size([1, 128, 5, 128, 128])
torch.Size([1, 5, 128, 128])


In [None]:
from torch.nn import RNN

# NOTE(review): this cell unpacks four values from the model, so it expects a
# TransformerRNN version that returns (feature_map, z3, z6, z9) — confirm.
feature_map, z3, z6, z9 = unetr(combined_input)

# Concatenate z3, z6, z9 along dim 2, then flatten the result into a
# (1, 1, hidden_size) tensor that is passed to the RNN as its initial hidden state.
# NOTE(review): an earlier note gave z's shape as [1, 96, 12, 4, 4], which does not
# match the 2304 elements recorded for h in the output below — TODO confirm.
z = torch.cat([z3, z6, z9], dim=2)
h = z.view(1, 1, -1)
print(f"h shape: {h.shape}")

# Cut the depth axis (dim 2) into consecutive chunks of `step_size_input` slices;
# each chunk becomes one step of the sequence.
step_size_input = 8
seq_len = combined_input.shape[2] // step_size_input
print(f"Sequence length: {seq_len}")

x = torch.stack(torch.split(combined_input, step_size_input, dim=2), dim=1)
# x is now (batch, seq_len, channel, step_size_input, height, width)
batch, seq_len, channel, step_size_input, height, width = x.shape

# Collapse the last four dims (channel, step, height, width) into a single
# feature vector per sequence step.
x = x.view(batch, seq_len, -1)
print(f"encoded states shape: {x.shape}")

input_size = x.size(2)
hidden_size = h.size(2)  # length of the flattened z (z.shape[1] was tried before)
num_layers = 1

rnn = RNN(input_size, hidden_size, num_layers, batch_first=True).to(device)
output, h_n = rnn(x, h)

print(f"RNN output shape: {output.shape}")
print(f"RNN h_n shape: {h_n.shape}")

  return F.conv3d(


h shape: torch.Size([1, 1, 2304])
Sequence length: 16
encoded states shape: torch.Size([1, 16, 393216])
RNN output shape: torch.Size([1, 16, 2304])
RNN h_n shape: torch.Size([1, 1, 2304])


In [None]:
# class RNN(nn.Module):
#     def __init__(self, input_size, hidden_size, num_layers, output_size):
#         super(RNN, self).__init__()
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_size, output_size)
        
#     def forward(self, x):
#         # Set initial hidden and cell states 
#         h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
#         c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        
#         # Forward propagate LSTM
#         out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)
        
#         # Decode the hidden state of the last time step
#         out = self.fc(out[:, -1, :])
#         return out
    
    
# import torch rnn


In [None]:
class ConvolutionalMGRU(nn.Module):
    """Convolutional GRU-style recurrence stepped over a sequence of CT slices.

    At every step the current CT slice is concatenated with the (step-invariant)
    grid and fluence channels, scalar geometry features are projected and
    broadcast over the spatial dims, and a GRU-like update produces the next
    hidden state. A linear head maps every hidden vector to 8 output values.

    Args:
        source_axis: source-axis distance used for the conic scaling factor.
        res_axis: spacing between consecutive slices along the stepping axis.
        n_hidden: number of hidden channels.
    """

    def __init__(self, source_axis=100.0, res_axis=0.25, n_hidden=64):
        super().__init__()

        self.SourceAxisDistance = source_axis
        self.res_y = res_axis
        self.NH = n_hidden

        n_kernel = 3
        # Scalar features per step: sin/cos of the collimator angle, sin/cos of
        # the gantry angle, plus (ycoor, xz_fac).
        self.NY = 6
        # Image channels per step: 1 CT slice + Grid channels + 1 fluence
        # channel, so Grid's last dimension must have size 2.
        self.NX = 4

        self.convXR = nn.Conv2d(self.NX, self.NH, n_kernel, padding='same', padding_mode='replicate')
        self.convXZ = nn.Conv2d(self.NX, self.NH, n_kernel, padding='same', padding_mode='replicate')
        self.convXN = nn.Conv2d(self.NX, self.NH, n_kernel, padding='same', padding_mode='replicate')

        self.linYR = nn.Linear(self.NY, self.NH)
        self.linYZ = nn.Linear(self.NY, self.NH)
        self.linYN = nn.Linear(self.NY, self.NH)

        self.convHR = nn.Conv2d(self.NH, self.NH, n_kernel, padding='same', padding_mode='replicate')
        self.convHZ = nn.Conv2d(self.NH, self.NH, n_kernel, padding='same', padding_mode='replicate')
        self.convHN = nn.Conv2d(self.NH, self.NH, n_kernel, padding='same', padding_mode='replicate')

        self.linearOut = nn.Linear(self.NH, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, gantry, Fluence, Grid, Slices_CT, ymin):
        """Run the recurrence over all slices.

        Args:
            gantry: scalar gantry angle in degrees, shared by the whole batch.
            Fluence: tensor of shape (B, Z, X).
            Grid: tensor of shape (B, Z, X, 2).
            Slices_CT: sequence of Y tensors, each of shape (B, Z, X).
            ymin: index offset of the first slice; slice ``y`` sits at
                coordinate ``(ymin + y) * res_y``.

        Returns:
            Tensor of shape (B, Z, Y, X, 8).
        """
        dimB = Grid.size(0)
        dimZ = Grid.size(1)
        dimX = Grid.size(2)
        dimY = len(Slices_CT)
        # Follow the input's device instead of assuming a particular GPU.
        device0 = Grid.device

        # One angle entry per batch element, so the features below can be
        # concatenated with per-batch tensors for any batch size.
        # The collimator angle is a fixed 20 degrees.
        collimatorAngle = torch.full((dimB,), 20.0, dtype=torch.float32, device=device0)
        gantryAngle = torch.full((dimB,), float(gantry), dtype=torch.float32, device=device0)

        # Step-invariant image channels: Grid moved to channels-first, plus fluence.
        XI_ = torch.cat((Grid.permute(0, 3, 1, 2), Fluence.unsqueeze(1)), dim=1)

        # Step-invariant scalar features, shape (B, 4).
        YI_ = torch.stack(
            (collimatorAngle.deg2rad().sin(), collimatorAngle.deg2rad().cos(),
             gantryAngle.deg2rad().sin(),
             gantryAngle.deg2rad().cos()
             ), dim=-1)

        # Zero initial hidden state, expanded (not copied) to (B, NH, Z, X).
        h0 = torch.zeros((1, 1, 1, 1), dtype=torch.float32, device=device0).expand(dimB, self.NH, dimZ, dimX)

        out = []
        for y in range(dimY):
            # Conic effect: coordinate of this slice and the resulting scaling
            # factor relative to the source-axis distance.
            ycoor = (ymin + y) * self.res_y
            xz_fac = (self.SourceAxisDistance + ycoor) / self.SourceAxisDistance

            XI = torch.cat((Slices_CT[y].unsqueeze(1), XI_), dim=1)
            YI = torch.cat(
                (YI_, torch.tensor((ycoor, xz_fac), dtype=torch.float32, device=device0).unsqueeze(0).expand(dimB, -1)),
                dim=-1)

            # GRU-style gates; the projected scalar features are broadcast over
            # the two spatial dims.
            r = self.sigmoid(self.convXR(XI) + self.linYR(YI).unsqueeze(-1).unsqueeze(-1) + self.convHR(h0))
            z = self.sigmoid(self.convXZ(XI) + self.linYZ(YI).unsqueeze(-1).unsqueeze(-1) + self.convHZ(h0))
            n = torch.tanh(self.convXN(XI) + self.linYN(YI).unsqueeze(-1).unsqueeze(-1) + r * self.convHN(h0))

            ht = (1 - z) * n + z * h0

            out.append(ht)

            h0 = ht

        # Stack steps as a new axis -> (B, NH, Z, Y, X), move channels last,
        # then apply the linear head over the channel axis.
        return self.linearOut(torch.stack(out, dim=3).permute(0, 2, 3, 4, 1))