In [4]:
import os
import shutil
import tempfile
import numpy as np

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    CopyItemsd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    ConvertToMultiChannelBasedOnBratsClassesd, 
    RandRotated, 
    RandScaleIntensityd, 
    ToTensord, 
    AsDiscreted,
    Invertd,
    SaveImaged,
    EnsureChannelFirstd
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR, UNet, VNet, DynUNet

from monai.data import (
    ThreadDataLoader, Dataset, DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)

from monai.utils import set_determinism
from monai.handlers.utils import from_engine
#from codecarbon import EmissionsTracker
from datetime import datetime

import torch
import glob
import shutil

import csv


In [10]:
images_path = "/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_images/"
labels_path = "/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_labels/"

datasets = {}

# csv file generated by the make_dataset_distributions.ipynb file
with open('/home/fehrdelt/data_ssd/2024_Projet_Anais/datasets.csv', mode ='r') as file:    
    csvFile = csv.reader(file)
    for line in csvFile:
        datasets[line[0]]=line[1:]


test_images = [images_path + s for s in datasets["test_patients"]]
test_labels = [labels_path + s for s in datasets["test_patients"]]

val_images = [images_path + s for s in datasets["step1_val_patients"]]
val_labels = [labels_path + s for s in datasets["step1_val_patients"]]

train_images = [images_path + s for s in datasets["step1_train_patients"]]
train_labels = [labels_path + s for s in datasets["step1_train_patients"]]

print(train_labels)

['/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_labels/aini-stroke-20280-aguorshlcg-20210612-0-1_1CraneStandardMAR-CraneSansIV-000243_844.nii.gz', '/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_labels/aini-stroke-18911-vvfvqggjzh-20210216-0-3_6TSA+Thrombophl�biteMAR-CraneSansIV-000241_754.nii.gz', '/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_labels/aini-stroke-10061-pindeuwojy-20180416-0-1_1CraneStandardMAR-CraneSansIV-000304_818.nii.gz', '/home/fehrdelt/data_ssd/2024_Projet_Anais/Clinical_database/all_labels/aini-stroke-21117-lwyawydrul-20211025-0-3_4TroncsSupra-AortiquesVasculaireMAR-CraneSansIV-000314_016.nii.gz']


In [11]:
test_transforms = Compose(
    [
        LoadImaged(keys="image", ensure_channel_first=False),
        EnsureChannelFirstd(keys="image"),
        CropForegroundd(keys="image", source_key="image"),
        Orientationd(keys="image", axcodes="RAS"),
        Spacingd(
            keys="image",
            pixdim=(1, 1, 1),
            mode="bilinear",
        ),
    ])

test_dict_ds = [{"image": image_name, "label": label_name} for image_name,label_name in zip(test_images, test_labels)]
test_ds = Dataset(data=test_dict_ds, transform=test_transforms)

test_loader = DataLoader(test_ds, batch_size=1)# num_workers=4)

monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


In [12]:
val_transforms = Compose(
    [
        LoadImaged(keys="image", ensure_channel_first=False),
        EnsureChannelFirstd(keys="image"),
        CropForegroundd(keys="image", source_key="image"),
        Orientationd(keys="image", axcodes="RAS"),
        Spacingd(
            keys="image",
            pixdim=(1, 1, 1),
            mode="bilinear",
        ),
    ])

val_dict_ds = [{"image": image_name, "label": label_name} for image_name,label_name in zip(val_images, val_labels)]
val_ds = Dataset(data=val_dict_ds, transform=val_transforms)

val_loader = DataLoader(val_ds, batch_size=1)# num_workers=4)

In [13]:
train_transforms = Compose(
    [
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        EnsureTyped(keys=("image", "label"), device=context.device),
        Orientationd(keys=("image", "label"), axcodes="RAS"),
        Spacingd(keys=("image", "label"), pixdim=self.target_spacing, mode=("bilinear", "nearest")),
        NormalizeIntensityd(keys="image", nonzero=True),
        #CropForegroundd(keys=("image", "label"), source_key="image", margin=10, k_divisible=[self.roi_size[0], self.roi_size[1], self.roi_size[2]],),
        #GaussianSmoothd(keys="image", sigma=0.4),
        ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[self.roi_size[0], self.roi_size[1], self.roi_size[2]],
            random_size=False,
        ),

        RandFlipd(
            keys=["image", "label"], spatial_axis=[0], prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"], spatial_axis=[1], prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"], spatial_axis=[2], prob=0.10,
        ),
        RandRotate(
            prob=0.8,
            range_x=0.3,
            range_y=0.3,
            range_z=0.3,
            keys=["image", "label"]
        ),
        RandShiftIntensityd(
            keys=["image"], offsets=0.10, prob=0.50,
        ),
        SelectItemsd(keys=("image", "label")),
    ])

train_dict_ds = [{"image": image_name, "label": label_name} for image_name,label_name in zip(train_images, train_labels)]
train_ds = Dataset(data=train_dict_ds, transform=train_transforms)

train_loader = DataLoader(train_ds, batch_size=1)# num_workers=4)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:

model = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=8,
        kernel_size=[3, 3, 3, 3, 3, 3],
        strides=[1, 2, 2, 2, 2, [2, 2, 1]],
        upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
        norm_name="instance",
        deep_supervision=False,
        res_block=True,
    ).to(device)

loss_function = torch.nn.DiceCELoss()
# loss_function = torch.nn.BCEWithLogitsLoss()  # also works with this data

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# start a typical PyTorch training
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter()
max_epochs = 500

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    
    validation_loss = []

    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()

        num_correct = 0.0
        metric_count = 0
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            with torch.no_grad():
                val_outputs = model(val_images)               
                val_loss = loss_function(val_outputs, val_labels)                          # ajout
                validation_loss.append(val_loss.item())                                    # par
                writer.add_scalar("val_loss", val_loss.item(), epoch_len * epoch + step)   # moi
                
                value = torch.eq(val_outputs.argmax(dim=1), val_labels.argmax(dim=1))
                metric_count += len(value)
                num_correct += value.sum().item()

        metric = num_correct / metric_count
        metric_values.append(metric)

        if metric >= best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), f"6_mai_best_metric_model_classification3d_array{epoch}.pth")
            print("saved new best metric model")

        print(f"Current epoch: {epoch+1} current accuracy: {metric:.4f} ")
        print(f"Best val accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
        writer.add_scalar("val_accuracy", metric, epoch + 1)

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()