In [1]:
import pandas as pd
import numpy as np
import random
import os
import SimpleITK as sitk
from tqdm import tqdm
from matplotlib import pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch import optim
import pydicom
import cv2
import random
import sys
from torch import nn

sys.path.append("/mnt/nas32/forGPU/jegal/Workspace/Work/0_CNN_total_Pytorch_new/")

In [2]:
selected_data_df = pd.read_csv("../d02_pp_nz_paap_splited.csv")
fvcf_normal_col_name = "FVC_norm"
fvcf_measured_col_name = "FVC_meas"

fevf_normal_col_name = "FEV_norm"
fevf_measured_col_name = "FEV_meas"

task_name = "d02"
get_recon = False

log_folder = f"../results/{task_name}"
if get_recon:
    log_folder = f"{log_folder}_recon"
    
log_csv_path = f"{log_folder}/log.csv"
log_plot_folder = f"{log_folder}/plots"
log_weight_folder = f"{log_folder}/weights"
os.makedirs(log_plot_folder, exist_ok=True)
os.makedirs(log_weight_folder, exist_ok=True)

num_epochs = 100
num_gpu = torch.cuda.device_count()
batch_size = 8 * num_gpu
print(f"batch_size: {batch_size}")
DEVICE = "cuda"


batch_size: 8


  selected_data_df = pd.read_csv("../d02_pp_nz_paap_splited.csv")


In [3]:
def load_dicom_array_with_rescale(dicom_path):
    """
    Load a DICOM file and return the rescaled pixel array.
    Applies Rescale Intercept and Rescale Slope if available.
    
    Parameters:
        dicom_path (str): Path to the DICOM file.
        
    Returns:
        np.ndarray: Rescaled pixel array.
    """
    dicom_data = pydicom.dcmread(dicom_path, force=True)
    pixel_array = dicom_data.pixel_array
    intercept = getattr(dicom_data, "RescaleIntercept", 0)
    slope = getattr(dicom_data, "RescaleSlope", 1)
    return pixel_array * slope + intercept

class PFTDataset(Dataset):
    def __init__(self, dcm_path_list, fvcf_array, fevf_array):
        assert len(dcm_path_list) == len(fvcf_array)
        assert len(fvcf_array) == len(fevf_array)
        
        self.dcm_path_list = dcm_path_list
        self.fvcf_array = fvcf_array
        self.fevf_array = fevf_array
        
    def __len__(self):
        return len(self.dcm_path_list)
    
    def __getitem__(self, idx):
        
        dcm_path = self.dcm_path_list[idx]
        fvcf_value = self.fvcf_array[idx]
        fevf_value = self.fevf_array[idx]
        # xray_array.shape = [1, 512, 512]
        xray_array = load_dicom_array_with_rescale(dcm_path)
        xray_array = cv2.resize(xray_array, (512, 512), interpolation=cv2.INTER_LINEAR)
        
        xray_array = (xray_array - xray_array.mean()) / xray_array.std()
        xray_array = torch.tensor(xray_array[None], dtype=torch.float32)
        pft_value = torch.tensor([fvcf_value, fevf_value], dtype=torch.float32)
        return xray_array, pft_value

In [4]:
def convert_dcm_path(dcm_path):
    dcm_path = dcm_path.replace("/workspace/nas216", "/mnt/nas216")
    return dcm_path

def get_phase_dcm_path_list(selected_data_df, phase_str):
    phase_row = selected_data_df["phase"] == phase_str
    phase_df = selected_data_df[phase_row]
    phase_dcm_path_list = list(phase_df["dcm_path"])
    phase_dcm_path_list = [convert_dcm_path(dcm_path) for dcm_path in phase_dcm_path_list]
    return phase_dcm_path_list

def get_pft_value_list(selected_data_df, phase_str):
    phase_row = selected_data_df["phase"] == phase_str
    phase_df = selected_data_df[phase_row]
    
    phase_fvcf_normal_col_array = np.array(phase_df[fvcf_normal_col_name])
    phase_fvcf_measured_col_array = np.array(phase_df[fvcf_measured_col_name])
    
    phase_fevf_normal_col_array = np.array(phase_df[fevf_normal_col_name])
    phase_fevf_measured_col_array = np.array(phase_df[fevf_measured_col_name])
    
    fvcf_array = (phase_fvcf_measured_col_array - phase_fvcf_normal_col_array) / phase_fvcf_normal_col_array
    fevf_array = (phase_fevf_measured_col_array - phase_fevf_normal_col_array) / phase_fevf_normal_col_array
    return fvcf_array, fevf_array

train_dcm_path_list = get_phase_dcm_path_list(selected_data_df, "train")
val_dcm_path_list = get_phase_dcm_path_list(selected_data_df, "val")
test_dcm_path_list = get_phase_dcm_path_list(selected_data_df, "test")

train_fvcf_array, train_fevf_array = get_pft_value_list(selected_data_df, "train")
val_fvcf_array, val_fevf_array = get_pft_value_list(selected_data_df, "val")
test_fvcf_array, test_fevf_array = get_pft_value_list(selected_data_df, "test")

train_dataset = PFTDataset(train_dcm_path_list, train_fvcf_array, train_fevf_array)
val_dataset = PFTDataset(val_dcm_path_list, val_fvcf_array, val_fevf_array)
test_dataset = PFTDataset(test_dcm_path_list, test_fvcf_array, test_fevf_array)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8, shuffle=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8, shuffle=False, pin_memory=True)

In [5]:
from src.model.inception_resnet_v2.multi_task.multi_task_2d_v2 import InceptionResNetV2MultiTask2D
from src.model.inception_resnet_v2.diffusion.diff_ae.diffusion_layer import GroupNorm32
model = InceptionResNetV2MultiTask2D(input_shape=(1, 512, 512), class_channel=2, seg_channels=None, validity_shape=(1, 8, 8), inject_class_channel=None,
                                     block_size=16, include_cbam=False, decode_init_channel=None,
                                     norm="group", act="silu", dropout_proba=0.05,
                                     seg_act="softmax", class_act="tanh", recon_act=None, validity_act="sigmoid",
                                     get_seg=False, get_class=True, get_recon=get_recon, get_validity=False,
                                     use_class_head_simple=True, include_upsample=False,
                                     use_decode_simpleoutput=True, use_seg_conv_transpose=True,
                                     use_checkpoint=False).to(DEVICE)
with torch.no_grad():
    xray_array, pft_value = train_dataset[0]
    xray_array, pft_value = xray_array[None].to(DEVICE), pft_value[None].to(DEVICE)
    print(xray_array.shape, pft_value.shape)
    if get_recon:
        pred_pft_value, pred_xray_array = model(xray_array)
        print(pred_xray_array.shape, pred_pft_value.shape)
    else:
        pred_pft_value = model(xray_array)
        print(pred_pft_value.shape)
        
model = nn.DataParallel(model)

torch.Size([1, 1, 512, 512]) torch.Size([1, 2])
torch.Size([1, 2])


In [6]:
from src.model.train_util.logger import CSVLogger
from src.model.train_util.scheduler import OneCycleLR

epoch_col = ["epoch"]
train_col = ["train_loss", "pft_l1_loss"]
val_col = ["val_loss", "val_pft_l1_loss"]
if get_recon:
    train_col.append("recon_l1_loss")
    val_col.append("val_recon_l1_loss")

csv_logger = CSVLogger(log_csv_path, epoch_col + train_col + val_col)
loss_fn = F.l1_loss

optimizer = optim.AdamW(model.parameters(), lr=2e-5)
step_size = len(train_dataloader)  # 매 30 step마다

# StepLR 스케줄러 정의
scheduler_params = {
"step_size": step_size,
"first_epoch": 2,
"second_epoch": 68,
"total_epoch": num_epochs
}
scheduler = OneCycleLR(optimizer, **scheduler_params)

../results/d02/log.csv check exist...
../results/d02/log.csv exist.
../results/d02/log.csv has been deleted.


In [7]:
for epoch in range(1, num_epochs+1):
    model.train()
    train_pbar = tqdm(train_dataloader, total=len(train_dataloader))
    
    train_loss_list = []
    train_pft_loss_list = []
    train_recon_loss_list = []
    val_loss_list = []
    val_pft_loss_list = []
    val_recon_loss_list = []
    
    for xray_array, pft_value in train_pbar:
        
        optimizer.zero_grad()
        
        xray_array, pft_value = xray_array.to(DEVICE), pft_value.to(DEVICE)
        if get_recon:
            pred_pft_value, pred_xray_array = model(xray_array)
            recon_loss = loss_fn(pred_xray_array, xray_array)
        else:
            pred_pft_value = model(xray_array)
            recon_loss = torch.tensor(0)
        pft_loss = loss_fn(pred_pft_value, pft_value)
        
        loss = pft_loss * 0.9 + recon_loss * 0.1
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        train_loss_list.append(loss.item())
        train_pft_loss_list.append(pft_loss.item())
        train_recon_loss_list.append(recon_loss.item())
        train_pbar.set_postfix({'loss_mean': f"{np.mean(train_loss_list):.4f}", 'loss_current': f"{loss.item():.4f}"})

    with torch.no_grad():
        for xray_array, pft_value in val_dataloader:
            xray_array, pft_value = xray_array.to(DEVICE), pft_value.to(DEVICE)
            if get_recon:
                pred_pft_value, pred_xray_array = model(xray_array)
                recon_loss = loss_fn(pred_xray_array, xray_array)
            else:
                pred_pft_value = model(xray_array)
                recon_loss = torch.tensor(0)
            pft_loss = loss_fn(pred_pft_value, pft_value)

            loss = pft_loss * 0.9 + recon_loss * 0.1

            val_loss_list.append(loss.item())
            val_pft_loss_list.append(pft_loss.item())
            val_recon_loss_list.append(recon_loss.item())
    
    ######################## Write Csv row #####################
    epoch_row_value_list = [f"{epoch}"]
    train_row_value_list = [f"{np.mean(train_loss_list):.3f}", 
                            f"{np.mean(train_pft_loss_list):.3f}",
                            f"{np.mean(train_recon_loss_list):.3f}"]
    val_row_value_list = [f"{np.mean(val_loss_list):.3f}", 
                            f"{np.mean(val_pft_loss_list):.3f}",
                            f"{np.mean(val_recon_loss_list):.3f}"]
    
    row_value_list = epoch_row_value_list + train_row_value_list + val_row_value_list
    csv_logger.writerow(row_value_list)
    
    ######################## Save model ########################
    torch.save({
    "model": diffusion_model.state_dict(),
    "optimizer": optimizer.state_dict(),
    },
    f"./{log_weight_folder}/{epoch:03d}.ckpt")
    ######################## Plot sample ########################
    with torch.no_grad():
        xray_array, pft_value = random.choice(val_dataset)
        xray_array, pft_value = xray_array[None].to(DEVICE), pft_value[None].to(DEVICE)
        if get_recon:
            pred_pft_value, pred_xray_array = model(xray_array)
        else:
            pred_pft_value = model(xray_array)
            
        pft_value_list = list(pft_value.cpu().numpy().round(3)[0])
        pred_pft_value_list = list(pred_pft_value.cpu().numpy().round(3)[0])
        _, ax = plt.imshow(1, 1, figsize=(8, 8))
        ax.imshow(xray_array, cmap="gray")
        ax.set_title(f"GT: {pft_value_list}, PRED: {pred_pft_value_list}")
        plt.tight_layout()
        plt.savefig(f"./{log_plot_folder}/polt_{epoch:03d}.png")
        plt.clf()        

  2%|██▊                                                                                                                           | 71/3252 [02:01<1:30:29,  1.71s/it, loss_mean=0.1300, loss_current=0.1537]


FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 7.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_1770/2366529230.py", line 36, in __getitem__
    xray_array = load_dicom_array_with_rescale(dcm_path)
  File "/tmp/ipykernel_1770/2366529230.py", line 12, in load_dicom_array_with_rescale
    dicom_data = pydicom.dcmread(dicom_path, force=True)
  File "/opt/conda/lib/python3.10/site-packages/pydicom/filereader.py", line 1042, in dcmread
    fp = open(fp, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/mnt/nas216/ds_pft_cxr/data/d02/CR09477/1/172_2023-0481_CR09477.1.1.dcm'


In [9]:
torch.tensor(0)

tensor(0)