In [3]:

import os
import fnmatch
import numpy as np
import torch
from tqdm import tqdm
import random

from torch.utils.data import DataLoader

from dual_network import Dual3DCNN6 as Dual
# from Dataset_json import PXAI_Dataset
from decayLR import DecayLR

from utilities import create_list_from_master_json, read_json_file, split_data

In [6]:
def list_patient_folders(data_path):
    """
    List all directories in the base_directory.
    Each directory represents a patient.
    """
    try:
        patient_folders = [name for name in os.listdir(data_path)
                           if os.path.isdir(os.path.join(data_path, name))]
        return patient_folders
    except FileNotFoundError:
        print(f"Directory {data_path} was not found.")
        return []

# Specify the directory where the patient folders are located
data_path = '/home/shahpouriz/Data/DBP_newDATA/DBP/nrrd/proton'

# Get the list of patient folders
patient_list = list_patient_folders(data_path)


In [8]:
# Split data
total_patients = len(patient_list)
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1
seed_random = random.randint(0,100)
train_data, val_data, test_data = split_data(patient_list, train_ratio, val_ratio, test_ratio, seed=seed_random)
print(len(train_data))
print(len(val_data))
print(len(test_data))

18
5
3


In [14]:
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
import numpy as np
import glob

class PXAI_Dataset(Dataset):
    def __init__(self, pCTs_path, rCTs_path, augment=False):
        """
        Initialize the dataset with lists of paths for planning CTs and repeated CTs.
        """
        self.pCT_paths = pCTs_path
        self.rCT_paths = rCTs_path
        self.augment = augment

    def __len__(self):
        return len(self.pCT_paths)

    def __getitem__(self, idx):
        # Since rCT_paths[idx] could be a list of paths, adjust processing accordingly
        pCT_path = self.pCT_paths[idx]
        rCT_paths = self.rCT_paths[idx]

        pCT = self.process_scan(pCT_path)
        rCTs = [self.process_scan(path) for path in rCT_paths]
        reg = np.array([])  # Placeholder for registration data, adjust as needed

        # Assuming rCTs is a list of rCT images, adjust the return value as per your model's requirement
        return pCT, rCTs, reg
        

    
    
    def process_scan(self, path):
        volume = self.read_nrrd_file(path)
        if not self.dose:
            volume = self.normalize(volume)
        return volume
    
    def read_nrrd_file(self, filepath):
        sitk_img = sitk.ReadImage(filepath)
        img = sitk.GetArrayFromImage(sitk_img).astype(np.float32)
        return img
    
    def normalize(self, volume, min=-1000, max=1000):
        volume[ volume < min ] = min
        volume[ volume > max ] = max
        volume = (volume - min) / (max - min)
        volume = volume.astype("float32")
        return volume
    
    def read_registration(self,path):
        with open(path) as f:
            lines = f.readlines()
        params = lines[3].split(' ')
        params = np.array(params[10::],dtype=np.float32)
        return params
    


def prepare_data(data_dir, patient_ids):
    """
    Scan through the patient folders to find pCT and rCT files.
    """
    pct_paths = []
    rct_paths = []
    
    for patient_id in patient_ids:
        patient_folder = os.path.join(data_dir, patient_id)
        # Find planning CT (pCT) and repeated CTs (rCT)
        planning_ct = glob.glob(os.path.join(patient_folder, 'rtdose_pCT*'))
        repeated_cts = glob.glob(os.path.join(patient_folder, 'rtdose_rCT*'))
        
        if planning_ct:
            pct_paths.extend(planning_ct)
            rct_paths.extend(repeated_cts)
        
    return pct_paths, rct_paths
        

In [15]:
# Set parameters
starting_epoch = 0
decay_epoch = 150
final_epoch = 30
learning_rate = 0.0001
batchsize = 5
device_num = 1
lambda_reg = 0.000001

# Condition for saving list
save_list = False


exception_list = ['']

# Create lists
pct_train = []
rct_train = []
reg_train = []

pct_val = []
rct_val = []
reg_val = []

pct_test = []
rct_test = []
reg_test = []

# Prepare training, validation, and testing datasets
pct_train, rct_train = prepare_data(data_path, train_data)
pct_val, rct_val = prepare_data(data_path, val_data)
pct_test, rct_test = prepare_data(data_path, test_data)

# Initialize datasets
dataset_train = PXAI_Dataset(pct_train, rct_train)
dataset_val = PXAI_Dataset(pct_val, rct_val)
dataset_test = PXAI_Dataset(pct_test, rct_test)

# Initialize DataLoader instances
dataloader_train = DataLoader(dataset_train, batch_size=5, shuffle=True, num_workers=1)
dataloader_val = DataLoader(dataset_val, batch_size=5, shuffle=False, num_workers=1)
dataloader_test = DataLoader(dataset_test, batch_size=5, shuffle=False, num_workers=1)  # If needed

print("Training, validation, and testing datasets are ready.")

Training, validation, and testing datasets are ready.


In [17]:

# Build model
print('Initializing model...')
model = Dual(width=128, height=128, depth=128)
device = torch.device(f"cuda:{device_num}" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss
print('Defining loss...')
mae_loss = torch.nn.L1Loss()
mse_loss = torch.nn.MSELoss()

# Define optimizer
print('Defining optimizer...')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Define scheduler
print('Defining scheduler...')
lr_lambda = DecayLR(epochs=final_epoch, offset=0, decay_epochs=decay_epoch).step
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


Initializing model...
Defining loss...
Defining optimizer...
Defining scheduler...


AssertionError: Decay must start before the training session ends!

In [16]:

# Training
torch.backends.cudnn.benchmark = True

for epoch in range(starting_epoch, final_epoch):
    progress_bar = tqdm(enumerate(dataloader_train), total=len(dataloader_train))
    mae_list = []
    for i, data in progress_bar:
        pCT, rCT, reg = data
        pCT = pCT.unsqueeze(1).to(device)
        rCT = rCT.unsqueeze(1).to(device)
        reg = reg.to(device)
                
        output = model(pCT, rCT)
        # modified L1 loss
        # loss_output = mae_loss(output, reg)
        loss_output = mse_loss(output, reg)
        
        # L1 Regularization
        l1_reg = torch.tensor(0., requires_grad=True).to(device)
        for name, param in model.named_parameters():
            l1_reg = l1_reg + torch.norm(param, 1)
        
        loss_output += lambda_reg * l1_reg
        
        optimizer.zero_grad(set_to_none=True)
        loss_output.backward()
        optimizer.step()
        
        mae_list.append(loss_output.item())
        mean_mae = np.mean(mae_list)
        progress_bar.set_description(f'Epoch: {epoch}/{final_epoch}, Batch: {i}/{len(dataloader_train)}, Loss_avg: {mean_mae.item()}', refresh=True)       
    
    progress_bar_val = tqdm(enumerate(dataloader_val), total=len(dataloader_val))
    val_loss = []
    for j, data2 in progress_bar_val:
        pCT_val, rCT_val, reg_val = data2
        pCT_val = pCT_val.unsqueeze(1).to(device)
        rCT_val = rCT_val.unsqueeze(1).to(device)
        reg_val = reg_val.to(device)
        
        output_val = model(pCT_val, rCT_val)
        loss_output_val = mae_loss(output_val, reg_val)
        # loss_output_val = mse_loss(output_val, reg_val)
        
        val_loss.append(loss_output_val.item())
        mean_val_loss = np.mean(val_loss)
        progress_bar_val.set_description(f'Epoch: {epoch}/{final_epoch}, Batch: {j}/{len(dataloader_val)}, Loss_avg: {mean_val_loss.item()}', refresh=True)
    
    lr_scheduler.step(mean_val_loss)
    
    # Save model
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    
    current_valid_mae = val_loss[-1]    
    if current_valid_mae <= best_mae and epoch > 0:
        best_mae = current_valid_mae
        torch.save(model.state_dict(),f'{save_dir}/model_weights_dose_{epoch+1}_{fname_comment}.pt')
    with open(loss_file, 'a') as f: #a-append
        f.write(f'Epoch: {epoch+1}/{final_epoch}, Loss: {mean_mae}, Val: {mean_val_loss}\n')
            

# Train patients: 31
# ['DBP_HN011', 'DBP_HN005', 'DBP_HN012', 'DBP_HN028', 'DBP_HN041', 'DBP_HN020', 'DBP_HN043', 'DBP_HN027', 'DBP_HN006', 'DBP_HN036', 'DBP_HN024', 'DBP_HN013', 'DBP_HN032', 'DBP_HN023', 'DBP_HN014', 'DBP_HN026', 'DBP_HN039', 'DBP_HN040', 'DBP_HN035', 'DBP_HN015', 'DBP_HN025', 'DBP_HN022', 'DBP_HN031', 'DBP_HN042', 'DBP_HN033', 'DBP_HN021', 'DBP_HN018', 'DBP_HN045', 'DBP_HN037', 'DBP_HN034', 'DBP_HN002']
# Valid patients: 8
# ['DBP_HN004', 'DBP_HN029', 'DBP_HN007', 'DBP_HN008', 'DBP_HN010', 'DBP_HN016', 'DBP_HN017', 'DBP_HN019']
# Test patients: 3
# ['DBP_HN003', 'DBP_HN009', 'DBP_HN044']


Initializing model...
Defining loss...
Defining optimizer...


  return torch._C._cuda_getDeviceCount() > 0


Defining scheduler...


AssertionError: Decay must start before the training session ends!