In [20]:
import logging
import os
import sys
sys.path.append("../")
import glob
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import monai
from monai.data import ImageDataset, DataLoader
import monai.transforms as transforms
from monai.transforms import EnsureChannelFirst, Compose, RandRotate90, Resize, ScaleIntensity
import nibabel as nib
import pandas as pd

from utils.custom_transforms import ScaleIntensityFromHistogramPeak, SetBackgroundToZero

In [2]:
#ROOT_DIR = "/home/fehrdelt/bettik/"
ROOT_DIR = "/bettik/PROJECTS/pr-gin5_aini/fehrdelt/"

In [3]:
SUB_EXPERIMENT_NAME = "densenet3d_exp_0_0"

### Clinical data (class label) 1: nihss>15, 0: nihss<=15

In [4]:
participants_tsv_path = ROOT_DIR+"datasets/final_soop_dataset_small/participants.tsv"

participants_df = pd.read_csv(participants_tsv_path, sep="\t")

In [5]:
participants_df.head()

Unnamed: 0,participant_id,sex,age,race,acuteischaemicstroke,priorstroke,bmi,nihss,gs_rankin_6isdeath
0,sub-2,M,78.0,w,1.0,0.0,22.84,17.0,
1,sub-3,F,87.0,w,1.0,1.0,19.23,15.0,6.0
2,sub-5,M,58.0,w,1.0,0.0,37.29,1.0,
3,sub-6,,,,,,,,
4,sub-7,,,,,,,,


In [6]:
# drop rows where 'nihss' is NaN
participants_df = participants_df.dropna(subset=["nihss"])

In [7]:
participants_df.head()

Unnamed: 0,participant_id,sex,age,race,acuteischaemicstroke,priorstroke,bmi,nihss,gs_rankin_6isdeath
0,sub-2,M,78.0,w,1.0,0.0,22.84,17.0,
1,sub-3,F,87.0,w,1.0,1.0,19.23,15.0,6.0
2,sub-5,M,58.0,w,1.0,0.0,37.29,1.0,
5,sub-8,F,34.0,b,1.0,0.0,23.14,19.0,
6,sub-9,F,70.0,w,1.0,0.0,26.89,18.0,4.0


In [8]:
participants_df["high_nihss"] = (participants_df["nihss"] > 15).astype(np.int64)

participants_df.head()

Unnamed: 0,participant_id,sex,age,race,acuteischaemicstroke,priorstroke,bmi,nihss,gs_rankin_6isdeath,high_nihss
0,sub-2,M,78.0,w,1.0,0.0,22.84,17.0,,1
1,sub-3,F,87.0,w,1.0,1.0,19.23,15.0,6.0,0
2,sub-5,M,58.0,w,1.0,0.0,37.29,1.0,,0
5,sub-8,F,34.0,b,1.0,0.0,23.14,19.0,,1
6,sub-9,F,70.0,w,1.0,0.0,26.89,18.0,4.0,1


In [9]:
counts = participants_df["high_nihss"].value_counts()
count_high_nihss_1 = int(counts.get(1, 0))
count_high_nihss_0 = int(counts.get(0, 0))

print(f"high_nihss==1: {count_high_nihss_1}")
print(f"high_nihss==0: {count_high_nihss_0}")

high_nihss==1: 209
high_nihss==0: 897


### Make the images used for classification
Stacking the T2 FLAIR, the anomaly map and the atlas together

In [10]:
anomaly_maps_flair_dir = ROOT_DIR+"datasets/anomaly_maps/exp_1_0/"
flair_dir = ROOT_DIR+"datasets/final_soop_dataset_small/flair_registered/"
registered_atlases_dir = ROOT_DIR+"datasets/final_soop_dataset_small/registered_atlases/"

anomaly_maps_paths = glob.glob(anomaly_maps_flair_dir+"*.nii.gz")
flair_paths = glob.glob(flair_dir+"*.nii.gz")
registered_atlases_paths = glob.glob(registered_atlases_dir+"*.nii.gz")

In [11]:
# read the csv to get files to exclude
exclude_csv_path = ROOT_DIR+"StrokeUADiag/exclude_failed_registration.csv"
exclude_df = pd.read_csv(exclude_csv_path, header=None)



exclude_files = exclude_df[0].tolist()

print(exclude_files)

['sub-185', 'sub_1303', 'sub-199', 'sub-984', 'sub-1138', 'sub-767', 'sub-1251', 'sub-855', 'sub-1660', 'sub-512', 'sub-1698', 'sub-617', 'sub-1119', 'sub-1183', 'sub-1558', 'sub-279', 'sub-846', 'sub-1610', 'sub-1261', 'sub-1308', 'sub-1717', 'sub-7', 'sub-1041', 'sub-343', 'sub-989', 'sub-605', 'sub-234', 'sub-1203', 'sub-1491', 'sub-949', 'sub-1727']


### Make stacked images (flair image, anomaly map & registered atlas)

In [22]:
def normalize_from_histogram_peak(image, hist_norm_target_value=200.0):
    hist, bins = np.histogram(image, bins=100, range=(np.max(image)/15.0, np.max(image)))

    # Find the value corresponding to the maximum of the histogram
    most_occurred_pixel_value = bins[np.argmax(hist)]

    image_norm = image/most_occurred_pixel_value*hist_norm_target_value # scale it so the peak is always at hist_norm_target_value
    
    return image_norm

In [None]:

for image_path in tqdm(anomaly_maps_paths):

    id = os.path.basename(image_path).replace(".nii.gz", "").split('_')[-2]
    
    if id not in exclude_files:

        flair_img_path = flair_dir + id + "_FLAIR.nii.gz"
        registered_atlas_path = registered_atlases_dir + f"registered_atlas_{id}_FLAIR.nii.gz"


        ano_map_image = nib.load(image_path).get_fdata()*500.0
        flair_nii = nib.load(flair_img_path)
        flair_image = normalize_from_histogram_peak(flair_nii.get_fdata(), hist_norm_target_value=200.0)
        registered_atlas_image = nib.load(registered_atlas_path).get_fdata()*10.0

        stacked_data = np.stack([flair_image, ano_map_image, registered_atlas_image], axis=-1)

        stacked_img = nib.Nifti1Image(stacked_data, affine=flair_nii.affine)

        nib.save(stacked_img, ROOT_DIR+f"datasets/StrokeUADiag_classification_inputs/stacked_{id}.nii.gz")
    
    

  0%|                                                                                                                                            | 0/1216 [00:00<?, ?it/s]


AttributeError: 'numpy.ndarray' object has no attribute 'affine'

In [18]:
img = nib.load(ROOT_DIR+f"datasets/StrokeUADiag_classification_inputs/stacked_sub-139.nii.gz")
print(img.shape)

(3, 128, 128, 128)


### Model training

In [None]:
num_channels = 3
image_size = 128

In [None]:
labels = []

for image_path in glob.glob(ROOT_DIR+"datasets/classification_inputs/stacked_*.nii.gz"):
    id = os.path.basename(image_path).replace(".nii.gz", "").split('_')[-1]
    print(id)
    label = participants_df.loc[participants_df['participant_id'] == int(id), 'high_nihss'].values[0]
    labels.append(label)

print(labels)
    

In [None]:


images = glob.glob(ROOT_DIR+"datasets/classification_inputs/stacked_*.nii.gz")

# Define transforms
train_transforms = Compose([
    transforms.LoadImage(image_only=True),
    transforms.EnsureChannelFirst(),
    transforms.RandAffine(prob=0.5, rotate_range=(0.1, 0.1, 0.1)),
    ScaleIntensityFromHistogramPeak(), 
    transforms.RandScaleCrop(roi_scale=0.9, max_roi_scale=1.1, random_size=True),
    transforms.ResizeWithPadOrCrop((num_channels, image_size, image_size, image_size)),
    transforms.ScaleIntensityRange(a_min=0.0, a_max=700.0, b_min=0.0, b_max=1.0, clip=True),
    transforms.RandFlip(spatial_axis=1, prob=0.5),
    SetBackgroundToZero()
    ])

val_transforms = Compose([
    transforms.LoadImage(image_only=True),
    transforms.EnsureChannelFirst(),
    ScaleIntensityFromHistogramPeak(), 
    transforms.ResizeWithPadOrCrop((num_channels, image_size, image_size, image_size)),
    transforms.ScaleIntensityRange(a_min=0.0, a_max=700.0, b_min=0.0, b_max=1.0, clip=True),
    SetBackgroundToZero()
    ])

# Define image dataset, data loader
check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label)

# create a training data loader
train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

# create a validation data loader
val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())



**Vérif image que le randflip ce soit bon**

In [None]:

# Create DenseNet121, CrossEntropyLoss and Adam optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=num_channels, out_channels=2).to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)


### Training

In [None]:

# start a typical PyTorch training
val_interval = 2
best_metric = -1
epoch_loss_values = list()
metric_values = list()

os.makedirs(ROOT_DIR+f"StrokeUADiag/tensorboard/{SUB_EXPERIMENT_NAME}", exist_ok=True)
writer = SummaryWriter(ROOT_DIR+f"StrokeUADiag/tensorboard/{SUB_EXPERIMENT_NAME}")


for epoch in range(5):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{5}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_classification3d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_accuracy", metric, epoch + 1)
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
