In [None]:
!pip install SimpleITK
!pip install pydicom

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision
import sys
import numpy as np
import os
import csv
import SimpleITK as sitk
sys.path.append('/content/drive/MyDrive/PROJECT/RMSim_SeqX2Y/SeqX2Y-main')

## **Dataset and DataLoader**

In [None]:
# CT dataset class (CTDataset): one item per image file, read with SimpleITK.
class CTDataset(Dataset):
    """Map-style dataset that loads one CT image file per index.

    Each item is the array returned by ``sitk.GetArrayFromImage`` for the file
    at that index, optionally passed through ``transform``. No target/label is
    returned.

    Args:
        file_paths: Sequence of file paths handed to ``sitk.ReadImage``.
        transform: Optional callable applied to the numpy array before it is
            returned. Skipped when falsy.
    """

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        """Return the number of files, i.e. the number of items."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Read the file at position ``idx`` and return its (optionally transformed) array."""
        volume = sitk.GetArrayFromImage(sitk.ReadImage(self.file_paths[idx]))
        return self.transform(volume) if self.transform else volume

In [None]:
# Preprocessing pipeline: convert the array to a tensor, then normalise with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
# NOTE(review): this transform is not passed to CTDataset when the dataset is
# built, so it is currently unused. Before wiring it in, check that:
#   * ToTensor treats a 3-D numpy array as (H, W, C) and only rescales/casts
#     uint8 input, so raw integer CT arrays keep an integer dtype, and
#   * Normalize requires a floating-point tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # numpy array -> tensor
    # Normalize takes per-channel sequences; `(0.5)` is just a float, so use
    # one-element tuples for the single CT channel.
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# Define the folder path, then instantiate the dataset and dataloader.
folder_path = "/content/drive/MyDrive/DataSet/4DCT_dicom/T00"

# Collect the folder's files in sorted (name) order. Sub-directories and
# dot-files (e.g. ".DS_Store") are skipped because sitk.ReadImage in
# CTDataset.__getitem__ cannot read them as images.
file_paths = [
    os.path.join(folder_path, file_name)
    for file_name in sorted(os.listdir(folder_path))
    if not file_name.startswith(".")
    and os.path.isfile(os.path.join(folder_path, file_name))
]
if not file_paths:
    # Fail early with a clear message instead of the sampler's num_samples=0 error.
    raise FileNotFoundError(f"No image files found in {folder_path}")

# Instantiate the dataset (no transform is applied).
dataset = CTDataset(file_paths)

# Instantiate the dataloader.
# NOTE(review): shuffle=True returns slices in random order; if a batch is
# meant to be an ordered stack of slices, this should be shuffle=False.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

## **seq2seq_4DCT_voxelmorph**

In [None]:
# seq2seq_4DCT_voxelmorph.py
import torch
import torch.nn as nn

from models.ConvLSTMCell3d import ConvLSTMCell
from layers import SpatialTransformer
from models.unet_utils import *

class EncoderDecoderConvLSTM(nn.Module):
    """ConvLSTM encoder-decoder that predicts DVFs and warps the first input frame with them.

    Architecture:
        * Encoder: 3D conv -> max-pool (kernel 2, stride 2) -> two stacked
          ConvLSTM cells (ConvLSTM3d1, ConvLSTM3d2).
        * Encoder vector: final hidden state of the second encoder cell.
        * Decoder: two stacked ConvLSTM cells (ConvLSTM3d3, ConvLSTM3d4) fed
          with the encoder vector, rolled out for ``future_seq`` steps.
        * Head: x2 trilinear upsample -> ``ConvOut`` produces a DVF, and
          ``SpatialTransformer`` warps ``x[:, 0]`` with it.

    At every step, hidden features are multiplied by an RPM value:
    ``rpm_x[0, t - 1]`` in the encoder and ``rpm_y[0, t]`` in the decoder.

    Args:
        nf: Number of feature channels in the conv and ConvLSTM layers.
        in_chan: Number of channels of each input frame.
        size1, size2, size3: Volume size passed to ``SpatialTransformer``.
    """

    def __init__(self, nf, in_chan, size1, size2, size3):
        super().__init__()

        # Per-step input is x[:, t], fed to a Conv3d, so presumably (B, C, D, H, W).
        self.encoder1_conv = nn.Conv3d(in_channels=in_chan,
                                       out_channels=nf,
                                       kernel_size=(3, 3, 3),
                                       padding=(1, 1, 1))

        # Halves the spatial dimensions before the ConvLSTM stack.
        self.down1 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder ConvLSTM cells.
        self.ConvLSTM3d1 = ConvLSTMCell(input_dim=nf,
                                        hidden_dim=nf,
                                        kernel_size=(3, 3, 3),
                                        bias=True)
        self.ConvLSTM3d2 = ConvLSTMCell(input_dim=nf,
                                        hidden_dim=nf,
                                        kernel_size=(3, 3, 3),
                                        bias=True)
        # Decoder ConvLSTM cells.
        self.ConvLSTM3d3 = ConvLSTMCell(input_dim=nf,
                                        hidden_dim=nf,
                                        kernel_size=(3, 3, 3),
                                        bias=True)
        self.ConvLSTM3d4 = ConvLSTMCell(input_dim=nf,
                                        hidden_dim=nf,
                                        kernel_size=(3, 3, 3),
                                        bias=True)

        # align_corners=False is PyTorch's default for trilinear. It is stated
        # explicitly because older PyTorch versions warn when it is omitted.
        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

        self.out = ConvOut(nf)

        self.transformer = SpatialTransformer((size1, size2, size3))

    def autoencoder(self, x, seq_len, rpm_x, rpm_y, future_step, h_t4, c_t4, h_t5, c_t5, h_t6, c_t6, h_t7, c_t7):
        """Encode ``seq_len`` input frames, roll the decoder out ``future_step`` steps, then decode.

        Args:
            x: Input sequence; ``x[:, t, ...]`` is fed to the 3D conv at step t,
                and ``x[:, 0, ...]`` is the volume that gets warped.
            seq_len: Number of input time steps to encode (must be >= 1).
            rpm_x: Encoder RPM tensor, read as ``rpm_x[0, t - 1]``.
            rpm_y: Decoder RPM tensor, read as ``rpm_y[0, t]``.
            future_step: Number of decoder steps / predictions (must be >= 1).
            h_t4 .. c_t7: Initial (hidden, cell) states of ConvLSTM3d1..ConvLSTM3d4.

        Returns:
            ``(output_img, output_dvf)``. The per-step results are stacked on
            dim 1 and then permuted so time sits on dim 2. If each per-step
            tensor is 5-D (b, c, d, h, w), the result is (b, c, T, d, h, w).
            This is an assumption about ConvOut/SpatialTransformer; confirm it.

        Raises:
            ValueError: if ``seq_len`` or ``future_step`` is smaller than 1.
                Without this check, the function fails later with an unbound
                ``encoder_vector`` or an empty ``torch.stack``.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if future_step < 1:
            raise ValueError(f"future_step must be >= 1, got {future_step}")

        # Encoder: process every input frame. Only the last hidden state of the
        # second cell is kept as the encoder vector.
        for t in range(seq_len):
            h_t1 = self.encoder1_conv(x[:, t, ...])
            down1 = self.down1(h_t1)

            h_t4, c_t4 = self.ConvLSTM3d1(input_tensor=down1,
                                          cur_state=[h_t4, c_t4])
            h_t5, c_t5 = self.ConvLSTM3d2(input_tensor=h_t4,
                                          cur_state=[h_t5, c_t5])
            # Plain multiplication of the features by this step's RPM value.
            # NOTE(review): for t == 0 this reads rpm_x[0, -1] (the LAST column),
            # and for seq_len >= rpm_x.shape[1] + 2 it indexes out of range.
            # The decoder uses rpm_y[0, t]. Kept as-is to preserve behaviour;
            # confirm whether rpm_x[0, t] was intended.
            h_t5 = torch.mul(h_t5, torch.squeeze(rpm_x[0, t - 1]))

        encoder_vector = h_t5

        # Decoder: roll out future_step steps, feeding each output back in.
        latent = []
        for t in range(future_step):
            h_t6, c_t6 = self.ConvLSTM3d3(input_tensor=encoder_vector,
                                          cur_state=[h_t6, c_t6])
            h_t7, c_t7 = self.ConvLSTM3d4(input_tensor=h_t6,
                                          cur_state=[h_t7, c_t7])
            # Plain multiplication of the predicted-phase features by RPM value t.
            h_t7 = torch.mul(h_t7, torch.squeeze(rpm_y[0, t]))
            encoder_vector = h_t7
            latent.append(h_t7)

        # Stack decoder steps on dim 1, then move the time axis to dim 2.
        latent = torch.stack(latent, 1)
        latent = latent.permute(0, 2, 1, 3, 4, 5)
        timestep = latent.shape[2]

        output_img = []
        output_dvf = []
        for i in range(timestep):
            # Upsample x2 (undoing the max-pool), predict a DVF, and warp the
            # first input frame with it via the spatial transformer.
            output_ts = self.up1(latent[:, :, i, ...])
            dvf = self.out(output_ts)
            warped_img = self.transformer(x[:, 0, ...], dvf)
            output_img.append(warped_img)
            output_dvf.append(dvf)

        output_img = torch.stack(output_img, 1).permute(0, 2, 1, 3, 4, 5)
        output_dvf = torch.stack(output_dvf, 1).permute(0, 2, 1, 3, 4, 5)

        return output_img, output_dvf

    def forward(self, x, rpm_x, rpm_y, future_seq=0, hidden_state=None):
        """Predict ``future_seq`` warped volumes and DVFs from the input sequence ``x``.

        Args:
            x: 6-D tensor unpacked as (b, seq_len, c, d, h, w): batch, time,
                channel, depth, height, width.
            rpm_x: Encoder RPM tensor (see ``autoencoder``).
            rpm_y: Decoder RPM tensor (see ``autoencoder``).
            future_seq: Number of future steps to predict. The default of 0 is
                rejected by ``autoencoder`` with a ValueError.
            hidden_state: Unused. Kept for interface compatibility; states are
                always created via each cell's ``init_hidden``.

        Returns:
            ``(output_img, output_dvf)`` from ``autoencoder``.
        """
        b, seq_len, _, d, h, w = x.size()

        # The ConvLSTM cells run after the stride-2 max-pool, so all four use
        # the halved spatial size (floor division).
        latent_size = (int(d // 2), int(h // 2), int(w // 2))
        h_t4, c_t4 = self.ConvLSTM3d1.init_hidden(batch_size=b, image_size=latent_size)
        h_t5, c_t5 = self.ConvLSTM3d2.init_hidden(batch_size=b, image_size=latent_size)
        h_t6, c_t6 = self.ConvLSTM3d3.init_hidden(batch_size=b, image_size=latent_size)
        h_t7, c_t7 = self.ConvLSTM3d4.init_hidden(batch_size=b, image_size=latent_size)

        return self.autoencoder(x, seq_len, rpm_x, rpm_y, future_seq,
                                h_t4, c_t4, h_t5, c_t5, h_t6, c_t6, h_t7, c_t7)

# Instantiate the network: 96 feature channels, 1 input channel, and
# (128, 128, 128) as the volume size handed to SpatialTransformer.
model = EncoderDecoderConvLSTM(in_chan=1, nf=96, size1=128, size2=128, size3=128)

## **Loss function and optimizer**

In [None]:
# Number of passes over the dataloader.
num_epochs = 1
# Objective: mean-squared error between the model output and the target.
criterion = torch.nn.MSELoss()
# Optimiser: Adam over all model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## **Loading the RPM data**

In [None]:
# Read the RPM table from CSV into a 2-D float32 array.
# newline='' is how the csv module documents opening files for csv.reader.
with open('/content/drive/MyDrive/PROJECT/RMSim_SeqX2Y/SeqX2Y-main/rpm_max.csv', 'r', newline='') as f:
    # Skip blank lines: csv.reader yields [] for them, which would make the
    # rows ragged and break the array conversion below.
    data = [row for row in csv.reader(f, delimiter=",") if row]

if not data:
    raise ValueError("rpm_max.csv contains no RPM rows")

# Parse the string cells as float32.
RPM = np.array(data).astype(np.float32)
# The training loop reads from test_RPM; it is an alias of RPM, not a copy.
test_RPM = RPM

## **Training**

In [None]:
model.train()  # Put the model in training mode.

# Training loop.
for epoch in range(num_epochs):
    for batch in dataloader:
        # Randomly choose one of the first 20 RPM rows for this batch.
        # np.int was removed in NumPy 1.24, and int() of a size-1 array is
        # deprecated, so draw a scalar and convert it with the builtin int().
        rpm = int(np.random.randint(0, 20))
        # Encoder RPM: first value of the row. Decoder RPM: the whole row.
        # Each gets a leading axis: shapes (1, 1) and (1, row_len).
        test_x_rpm = np.expand_dims(test_RPM[rpm, :1], 0)
        test_y_rpm = np.expand_dims(test_RPM[rpm, 0:], 0)
        test_x_rpm_tensor = torch.Tensor(test_x_rpm)
        test_y_rpm_tensor = torch.Tensor(test_y_rpm)

        # Insert a leading axis and an extra axis at dim 2.
        # NOTE(review): with batch_size=16 and single-slice files this presumably
        # yields (1, 16, 1, 1, 512, 512); confirm against the data.
        bbatch = batch.unsqueeze(dim=0).unsqueeze(dim=2)

        # Placeholder random input, laid out as forward() unpacks it:
        # (b, t, c, d, h, w).
        # NOTE(review): this does not fit the model as configured:
        #   * the model was built with in_chan=1, but this has 3 channels;
        #   * rpm_x has one column, so encoder step t=2 reads rpm_x[0, 1];
        #   * a 3x512x512 volume differs from the (128, 128, 128) SpatialTransformer size.
        # Replace it with real model inputs before training.
        fake = torch.randn([1, 3, 3, 3, 512, 512], dtype=torch.float)

        # Clear accumulated gradients.
        optimizer.zero_grad()
        # Targets: the loaded slices, cast to float.
        inputs = bbatch.float()
        # Forward pass.
        outputs, DVF = model(fake, rpm_x=test_x_rpm_tensor, rpm_y=test_y_rpm_tensor, future_seq=9, hidden_state=None)
        # NOTE(review): MSELoss needs outputs and inputs to have matching shapes.
        loss = criterion(outputs, inputs)
        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        # Report the loss of the current batch.
        print(f"Epoch: {epoch+1}, Batch Loss: {loss.item()}")

# Save the whole (pickled) model object, not just its state_dict.
torch.save(model, "/content/drive/MyDrive/PROJECT/RMSim_SeqX2Y/My_Train_Model/model.pth")