## Set up the environment

In [1]:

from monai.utils import first, set_determinism
from monai.transforms import (EnsureChannelFirstd, Compose, CropForegroundd, LoadImaged, Orientationd, RandCropByPosNegLabeld, ScaleIntensityRanged, Spacingd)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset
from monai.apps import download_and_extract
from monai.transforms import CenterSpatialCropd
from monai.transforms import Resized
import torch
import matplotlib.pyplot as plt
import os
import glob
import torch.nn as nn
import json
from datetime import datetime
from data_preparation2 import DataHandling 
from UNet_model import create_unet



In [9]:
# Root folder of the project data. Override it with the PRACTIC_ROOT environment
# variable instead of editing the individual paths below.
PRACTIC_ROOT = os.environ.get("PRACTIC_ROOT", "/home/shahpouriz/Data/Practic")

data_dir = os.path.join(PRACTIC_ROOT, "ASC-PET-001")
directory = os.path.join(PRACTIC_ROOT, "LOG")    # training log and saved model checkpoints
output_dir = os.path.join(PRACTIC_ROOT, "OUT")   # destination for exported DL-PET NIfTI files

## Load the config and prepare the test data (dataset paths are set in the cell above)

In [3]:
def read_config(config_path):
    """Load and return the parsed contents of a JSON configuration file.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON value (a dict when the file holds a JSON object).
    """
    with open(config_path, "r") as handle:
        parsed = json.load(handle)
    return parsed


# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("Starting preparing data ...")

config_path = "/home/shahpouriz/Data/Practic/training_params.json"
config = read_config(config_path)

# Only the test loader is requested: this notebook evaluates an already-trained model.
data_prep = DataHandling(config)
loaders, val_files, test_files = data_prep.prepare_data(loaders_to_prepare=["test"])
test_loader = loaders.get("test")

# Freshly constructed network on `device`; saved weights are loaded into it later.
model = create_unet().to(device)


  return torch._C._cuda_getDeviceCount() > 0


Starting preparing data ...
Loading data from: /home/shahpouriz/Data/Practic/ASC-PET-001
Total images loaded: 184


Loading dataset: 100%|██████████| 40/40 [00:51<00:00,  1.29s/it]


In [None]:
# import numpy as np
# check_ds = Dataset(data=test_files, transform=val_transforms)
# check_loader = DataLoader(check_ds, batch_size=1)
# check_data = first(check_loader)
# image, target = (check_data["image"][0][0], check_data["target"][0][0])
# print(f"image shape: {image.shape}, target shape: {target.shape}")
# # plot the slice [:, :, n]
# n = 105

# plt.figure("check", (12, 6))

# plt.subplot(1, 2, 1)
# plt.title("image")
# # Rotate the image slice and then display it
# rotated_image = np.rot90(image[:, n, :])
# plt.imshow(rotated_image, cmap="gist_yarg")

# plt.subplot(1, 2, 2)
# plt.title("target")
# # Rotate the target slice and then display it
# rotated_target = np.rot90(target[:, n, :])
# plt.imshow(rotated_target, cmap='gist_yarg')

# plt.show()


In [4]:
# Compute device: first CUDA GPU if present, else CPU (same rule as the data-preparation cell).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")




In [5]:
from utils import find_last_best_model

# Training log of the run being evaluated; find_last_best_model (project utils)
# reports the last best checkpoint recorded in it.
log_filename = 'log_1_24_15_44.txt'
log_filepath = os.path.join(directory, log_filename)
bestmodel_filename, best_metric, best_epoch = find_last_best_model(log_filepath)
print(f"Last Best Model Saved as: {bestmodel_filename}, Best Metric: {best_metric}, Epoch: {best_epoch}")


Last Best Model Saved as: model_1_24_18_48.pth, Best Metric: 0.1938, Epoch: 230


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to parse the loss values from the log file
def parse_loss_values(log_filepath):
    """Collect training and validation loss values from a training log.

    A line containing ``'average loss:'`` contributes a training loss, and a line
    containing ``'Validation loss:'`` contributes a validation loss. In both cases
    the value is the text after the last ``': '`` on that line.

    Args:
        log_filepath: Path to the text log written during training.

    Returns:
        A ``(train_losses, val_losses)`` tuple of float lists, in file order.
    """
    # Insertion order matters: a line matching both markers is recorded as a
    # training loss first, then as a validation loss.
    buckets = {'average loss:': [], 'Validation loss:': []}
    with open(log_filepath, 'r') as log:
        for entry in log:
            for marker, collected in buckets.items():
                if marker in entry:
                    collected.append(float(entry.split(': ')[-1]))
    return buckets['average loss:'], buckets['Validation loss:']


train_losses, val_losses = parse_loss_values(log_filepath)

max_epochs = len(train_losses)
val_interval = 2  # Update this if your validation interval is different

# Validation is assumed to run every `val_interval` epochs, first at epoch
# `val_interval`. Deriving the x positions from the number of parsed values keeps
# x and y the same length even if the log ends mid-run or `val_interval` changes
# (plt.plot raises when they differ).
val_epochs = [val_interval * (k + 1) for k in range(len(val_losses))]

fig, ax = plt.subplots(figsize=(14, 6))
ax.plot(range(1, max_epochs + 1), train_losses, label='Training Loss', color='blue', alpha=0.9)
ax.plot(val_epochs, val_losses, label='Validation Loss', color='orange', alpha=0.8)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Losses')
ax.legend()
ax.set_xticks(np.arange(1, max_epochs + 1, 20))  # Adjust the x-axis ticks if needed
plt.show()


In [None]:
import torch


def load_model(directory, model_filename, model=None):
    """Load saved weights from ``directory/model_filename`` into a model.

    The checkpoint is deserialized onto the CPU first, so a GPU-saved file also
    loads on a CPU-only machine. ``load_state_dict`` then copies the tensors into
    the model's existing parameters.

    Args:
        directory: Folder containing the checkpoint.
        model_filename: File name of the saved ``state_dict``.
        model: Module to load into. Defaults to the notebook-level ``model``.

    Returns:
        The model the weights were loaded into.

    Raises:
        FileNotFoundError: If the checkpoint does not exist. Raising (rather than
            only printing) prevents visualizing or exporting predictions from an
            untrained network by accident.
    """
    if model is None:
        model = globals()["model"]

    model_path = os.path.join(directory, model_filename)
    if not os.path.exists(model_path):
        print(f"Model file {model_filename} not found.")
        raise FileNotFoundError(f"Model file not found: {model_path}")

    print(f"Model file {model_filename} is loading.")
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model





def find_model_info(log_filepath, model_filename):
    """Look up the metric and epoch logged for a given checkpoint file name.

    Finds the first log line that mentions ``model_filename``. It reads the 2nd
    and 3rd comma-separated fields, each shaped like ``"<label>: <value>"``, as
    the metric (float) and the epoch (int).

    Returns:
        ``(model_filename, best_metric, epoch)``; the last two are ``None`` when
        no line mentions the file.
    """
    with open(log_filepath, 'r') as log:
        match = next((row for row in log if model_filename in row), None)
    if match is None:
        return model_filename, None, None
    fields = match.split(',')
    metric_value = float(fields[1].split(': ')[1])
    epoch_value = int(fields[2].split(': ')[1])
    return model_filename, metric_value, epoch_value


# Checkpoint to evaluate: by default the last "best" model found in the training log.
# To pin a specific file instead, set e.g. model_filename_to_find = 'model_1_22_18_3.pth'.
model_filename_to_find = bestmodel_filename

load_model(directory, model_filename_to_find)

model_info = find_model_info(log_filepath, model_filename_to_find)
bestmodel_filename, best_metric, best_epoch = model_info
print("Model: {}, Best Metric: {}, Epoch: {}".format(*model_info))

In [None]:
def visualize_results_whole(test_data, model, n, title, device=None):
    """Plot input, ground truth and model prediction for one slice, side by side.

    Runs ``model`` on the whole ``test_data["image"]`` batch. For the first batch
    item it shows slice ``n`` taken along tensor dimension 3, rotated by 90
    degrees for display.

    Args:
        test_data: Batch dict with ``"image"`` and ``"target"`` tensors. The
            indexing below requires rank-5 tensors, presumably
            (batch, channel, X, Y, Z) -- TODO confirm against the data pipeline.
        model: Network to evaluate; it is switched to eval mode.
        n: Slice index along tensor dimension 3.
        title: Title of the prediction panel.
        device: Device for the forward pass. Defaults to the device of the model's
            parameters (CPU if it has none), so no notebook-global is required.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        prediction = model(test_data["image"].to(device)).detach().cpu()

    panels = [
        ("Input", test_data["image"]),
        ("Ground_truth", test_data["target"]),
        (title, prediction),
    ]
    # Use a fresh figure per call instead of re-using a named one, so repeated
    # calls never draw on top of a figure left open by an earlier call.
    fig, axes = plt.subplots(1, len(panels), figsize=(12, 6))
    for ax, (panel_title, volume) in zip(axes, panels):
        ax.set_title(panel_title)
        ax.imshow(np.rot90(volume[0, 0, :, n, :]), cmap="gist_yarg")
    plt.show()
    plt.close(fig)

# Preview predictions for the first three test batches at a fixed slice index.
SLICE_INDEX = 57
with torch.no_grad():
    for i, test_data in zip(range(3), test_loader):
        n = SLICE_INDEX
        visualize_results_whole(test_data, model, n, f"{bestmodel_filename}\nepoch: {best_epoch}, best_metric: {best_metric}")


-------------------------
# Exporting DL-PET Images


In [10]:
import nibabel as nib
import os
import numpy as np
import torch

def save_nifti(data, filename, affine=None):
    """Write ``data`` to ``filename`` as a NIfTI-1 image.

    Args:
        data: Array of voxel values.
        filename: Output path.
        affine: 4x4 voxel-to-world matrix. ``None`` (the default) uses the
            identity matrix. A new identity is built on each call rather than
            sharing one array created once at definition time.
    """
    if affine is None:
        affine = np.eye(4)
    nifti_img = nib.Nifti1Image(data, affine)
    nib.save(nifti_img, filename)

def save_output(test_data, model, output_dir, file_name):
    """Run ``model`` on a test batch and save each prediction as a NIfTI file.

    Args:
        test_data: Batch dict with an ``"image"`` tensor. The prediction indexing
            below assumes a rank-5 output with the prediction in channel 0,
            presumably (batch, channel, X, Y, Z) -- TODO confirm.
        model: Trained network; switched to eval mode. Inputs are moved to the
            device of its parameters (CPU if it has none).
        output_dir: Destination folder; created if it does not exist.
        file_name: Sequence of source file names, one per batch item. Each output
            is written to ``output_dir/DL_<name>``.
    """
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        test_outputs = model(test_data["image"].to(run_device))

    os.makedirs(output_dir, exist_ok=True)
    affines = _batch_affines(test_data["image"])

    # Loop over each item in the batch
    for i in range(len(test_outputs)):
        output_data = test_outputs[i, 0, :, :, :].detach().cpu().numpy()  # Assuming single-channel output
        output_file_path = os.path.join(output_dir, f"DL_{file_name[i]}")
        if affines is None:
            # No geometry metadata available: fall back to save_nifti's identity affine.
            save_nifti(output_data, output_file_path)
        else:
            save_nifti(output_data, output_file_path, affines[i])


def _batch_affines(images):
    """Return per-item 4x4 affines from the batch's metadata, or ``None``.

    Reads ``images.affine`` when present (MONAI MetaTensors expose one). Without
    it, every prediction would be written with an identity affine, discarding
    voxel spacing and orientation. NOTE(review): assumes the input's affine also
    describes the network output grid (same spatial size) -- confirm for this
    model.
    """
    affine = getattr(images, "affine", None)
    if affine is None:
        return None
    if torch.is_tensor(affine):
        affine = affine.detach().cpu()
    affine = np.asarray(affine, dtype=np.float64)
    if affine.ndim == 2:  # un-batched: one matrix shared by every item
        affine = np.broadcast_to(affine, (len(images),) + affine.shape)
    if affine.shape != (len(images), 4, 4):
        return None
    return affine



# Export predictions for the first few test batches.
# Output names come from each batch's own metadata when available; otherwise they
# are read from `test_files`, which assumes the loader preserves that order.
MAX_EXPORT_BATCHES = 3  # set to None to export the whole test set

with torch.no_grad():
    offset = 0  # position in test_files of the current batch's first item
    for i, test_data in enumerate(test_loader):
        batch_size = len(test_data["image"])

        meta = getattr(test_data["image"], "meta", None) or {}
        source_paths = meta.get("filename_or_obj")
        if isinstance(source_paths, (str, os.PathLike)):
            source_paths = [source_paths]
        if source_paths is None:
            # A running offset stays correct even when a batch is smaller than earlier ones.
            source_paths = [file_info['image'] for file_info in test_files[offset:offset + batch_size]]
        file_names = [os.path.basename(str(path)) for path in source_paths]
        offset += batch_size

        # Save the output using the derived file names
        save_output(test_data, model, output_dir, file_names)

        # Optional: stop early to export only a few patients
        if MAX_EXPORT_BATCHES is not None and i + 1 >= MAX_EXPORT_BATCHES:
            break
