In [2]:
import logging
import os
import sys
from glob import glob
import numpy as np

import torch
import torch.nn.functional as F
from dataloader import DRIVEDataset, Rescale, ToTensor, Normalize, WeightMap
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.visualize import plot_2d_or_3d_image

import matplotlib.pyplot as plt
      
from metrics import Metrics

2024-10-06 23:19:24.608425: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-06 23:19:24.616839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-06 23:19:24.626745: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-06 23:19:24.629638: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-06 23:19:24.637343: I tensorflow/core/platform/cpu_feature_guar

In [3]:
monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

IMAGE_ROOT = "./data_drive/traning/train"
LABEL_ROOT = "./data_drive/traning/train_masks"

# Obter listas de imagens e máscaras
image_files = glob(os.path.join(IMAGE_ROOT, "*_training.png"))
label_files = glob(os.path.join(LABEL_ROOT, "*_manual1.png"))

def get_id_from_filename(filename):
    return int(os.path.basename(filename).split("_")[0])


# Criar um dicionário para mapear IDs para arquivos
image_dict = {get_id_from_filename(f): f for f in image_files}
label_dict = {get_id_from_filename(f): f for f in label_files}

# Criar pares de imagem-máscara com base nos IDs
image_label_pairs = [(image_dict[id], label_dict[id]) for id in image_dict if id in label_dict]

# Dividir em conjuntos de treinamento e validação
train_size = int(0.8 * len(image_label_pairs))
train_files = image_label_pairs[:train_size]
val_files = image_label_pairs[train_size:]

train_transform = transforms.Compose([
    Rescale((512, 512)),
    Normalize(),
    WeightMap(),
    ToTensor(),
])

valid_transform = transforms.Compose([
    Rescale((512, 512)),
    Normalize(),
    WeightMap(),
    ToTensor(),
])


# Desempacotar as tuplas em listas separadas
train_images, train_labels = zip(*train_files)
val_images, val_labels = zip(*val_files)

# Criar os datasets
train_ds = DRIVEDataset(train_images, train_labels, transform=train_transform)
valid_ds = DRIVEDataset(val_images, val_labels, transform=valid_transform)

MONAI version: 1.3.2
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/anaconda3/lib/python3.12/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.4.0
Tensorboard version: 2.17.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: 1.4.1
psutil version: 5.9.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing 

In [4]:
print("Imagens de treinamento:")
print(train_images)

print("\nMáscaras de treinamento:")
print(train_labels)

print("\nImagens de validação:")
print(val_images)

print("\nMáscaras de validação:")
print(val_labels)
# Imprime os índices
print("Índices de Treinamento:", train_images)
print("Índices de Validação:", train_labels)

Imagens de treinamento:
('./data_drive/traning/train/31_training.png', './data_drive/traning/train/33_training.png', './data_drive/traning/train/28_training.png', './data_drive/traning/train/37_training.png', './data_drive/traning/train/21_training.png', './data_drive/traning/train/39_training.png', './data_drive/traning/train/30_training.png', './data_drive/traning/train/24_training.png', './data_drive/traning/train/27_training.png', './data_drive/traning/train/36_training.png', './data_drive/traning/train/26_training.png', './data_drive/traning/train/25_training.png', './data_drive/traning/train/38_training.png', './data_drive/traning/train/34_training.png', './data_drive/traning/train/32_training.png', './data_drive/traning/train/35_training.png')

Máscaras de treinamento:
('./data_drive/traning/train_masks/31_manual1.png', './data_drive/traning/train_masks/33_manual1.png', './data_drive/traning/train_masks/28_manual1.png', './data_drive/traning/train_masks/37_manual1.png', './data_

In [5]:
print("Tamanho do Dataset de Treinamento:", len(train_ds))  

Tamanho do Dataset de Treinamento: 16


Test data set

In [7]:
sample = DRIVEDataset(train_images[:1], train_labels[:1], transform=train_transform)[0]  # Use train_images e train_labels
print("Shape of transformed image:", sample['img'].shape)
print("Shape of transformed label:", sample['seg'].shape)

Shape of transformed image: torch.Size([3, 512, 512])
Shape of transformed label: torch.Size([512, 512])


Shape of transformed image: torch.Size([3, 512, 512])
Shape of transformed label: torch.Size([512, 512])
Shape of transformed weight map: torch.Size([512, 512])

In [9]:
train_loader = DataLoader(
    train_ds, batch_size=4, shuffle=True, num_workers=0, pin_memory=True,
)

val_loader = DataLoader(
    valid_ds, batch_size=2, shuffle=True, num_workers=0, pin_memory=True,
)

In [10]:
# Test the data loaders
for batch_idx, batch_data in enumerate(train_loader):
    print(f"Train Batch {batch_idx+1}:")
    print("  Image Shape:", batch_data['img'].shape)
    print("  Label Shape:", batch_data['seg'].shape)
    print("  Weight Map Shape:", batch_data['map'].shape)

for batch_idx, batch_data in enumerate(val_loader):
    print(f"Validation Batch {batch_idx+1}:")
    print("  Image Shape:", batch_data['img'].shape)
    print("  Label Shape:", batch_data['seg'].shape)
    print("  Weight Map Shape:", batch_data['map'].shape)
    break  # Just display one batch from the validation loader

Train Batch 1:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Train Batch 2:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Train Batch 3:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Train Batch 4:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Validation Batch 1:
  Image Shape: torch.Size([2, 3, 512, 512])
  Label Shape: torch.Size([2, 512, 512])
  Weight Map Shape: torch.Size([2, 512, 512])


Train Batch 1:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Train Batch 2:
  Image Shape: torch.Size([4, 3, 512, 512])
  Label Shape: torch.Size([4, 512, 512])
  Weight Map Shape: torch.Size([4, 512, 512])
Train Batch 3:
  Image Shape: torch.Size([2, 3, 512, 512])
  Label Shape: torch.Size([2, 512, 512])
  Weight Map Shape: torch.Size([2, 512, 512])
Validation Batch 1:
  Image Shape: torch.Size([2, 3, 512, 512])
  Label Shape: torch.Size([2, 512, 512])
  Weight Map Shape: torch.Size([2, 512, 512])

In [12]:
# create UNet and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
loss_function = F.binary_cross_entropy_with_logits
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=5e-5)

# start a typical PyTorch training
epochs_total = 1000
val_interval = 1
best_loss = np.inf
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
      
metrics = Metrics()

In [None]:
for epoch in range(epochs_total):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epochs_total}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels, weights = (
            batch_data["img"].to(device),
            batch_data["seg"].to(device),
            batch_data["map"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs).squeeze()
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            epoch_val_loss = []
            epoch_val_dice = []  # Define epoch_val_dice here
            epoch_val_precision = []
            epoch_val_recall = []  # Define epoch_val_recall aqui
            epoch_val_f1 =[]
            epoch_val_iou= [] 
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels, val_weights = (
                    val_data["img"].to(device),
                    val_data["seg"].to(device),
                    val_data["map"].to(device),
                )
                  
                # ... dentro do loop de validação ...

                val_outputs = model(val_images).squeeze()
                loss = loss_function(val_outputs, val_labels)
                epoch_val_loss.append(loss.item())
                
                # Calcule a pontuação Dice, precisão e revocação para cada imagem no lote
                dice_scores = []
                precision_scores = []
                recall_scores = []
                f1_scores = []
                iou_scores = []  # Lista para armazenar as pontuações IoU
                for i in range(val_outputs.shape[0]):
                    dice_score = metrics.calculate_dice_score(val_outputs[i], val_labels[i])
                    dice_scores.append(dice_score.cpu().item())

                    precision = metrics.calculate_precision(val_outputs[i], val_labels[i])
                    precision_scores.append(precision.cpu().item())

                    recall = metrics.calculate_recall(val_outputs[i], val_labels[i])
                    recall_scores.append(recall.cpu().item())

                    f1_score = metrics.calculate_f1_score(precision, recall)
                    f1_scores.append(f1_score.cpu().item())

                    iou_score = metrics.calculate_iou(val_outputs[i], val_labels[i])  # Calcule o IoU
                    iou_scores.append(iou_score.cpu().item())  # Adicione à lista de pontuações IoU

            # Calcule a média das pontuações Dice, precisão, revocação e F1-Score do lote após o loop
            epoch_val_dice.append(np.mean(dice_scores))
            epoch_val_precision.append(np.mean(precision_scores))
            epoch_val_recall.append(np.mean(recall_scores))
            epoch_val_f1.append(np.mean(f1_scores))
            epoch_val_iou.append(np.mean(iou_scores))  # Calcule a média das pontuações IoU
                
           # Calcule a média das pontuações da época
            epoch_val_loss = np.mean(epoch_val_loss)  # Calcule a média da perda de validação da época
            epoch_val_dice = np.mean(epoch_val_dice)  # Calcule a média da pontuação Dice da época
            epoch_val_precision = np.mean(epoch_val_precision)  # Calcule a média da precisão da época
            epoch_val_recall = np.mean(epoch_val_recall)  # Calcule a média da revocação da época
            epoch_val_f1 = np.mean(epoch_val_f1)  # Calcule a média do F1-Score da época
            epoch_val_iou = np.mean(epoch_val_iou)  # Calcule a média do IoU da época

            # Imprima a pontuação Dice e precisão médias da época
            print(f"Epoch {epoch + 1} - Dice Score: {epoch_val_dice:.4f}, Precision: {epoch_val_precision:.4f}, Recall: {epoch_val_recall:.4f}, F1-Score: {epoch_val_f1:.4f}, IoU: {epoch_val_iou:.4f}")

            # Agora compare a média de epoch_val_loss com best_loss
            if epoch_val_loss < best_loss:
                best_loss = epoch_val_loss
                best_metric_epoch = epoch + 1
                torch.save(
                    model.state_dict(), "best_metric_model_segmentation2d_dict.pth"
                )
                print("saved new best metric model")
            print(
                "current epoch: {} current val loss: {:.4f} best val loss: {:.4f} at epoch {}".format(
                    epoch + 1, epoch_val_loss, best_loss, best_metric_epoch
                )
            )
            
            writer.add_scalar("val_loss", epoch_val_loss, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(
                val_outputs, epoch + 1, writer, index=0, tag="output"
            )

            # Plota a imagem, máscara e resultado usando matplotlib
            # plt.figure(figsize=(15, 5))
            # plt.subplot(1, 3, 1)
            # plt.imshow(val_images[0].cpu().numpy().transpose(1, 2, 0))
            # plt.title("Imagem Original")
            # plt.subplot(1, 3, 2)
            # plt.imshow(val_labels[0].cpu().numpy(), cmap="gray")
            # plt.title("Máscara")
            # plt.subplot(1, 3, 3)
            # plt.imshow(torch.sigmoid(val_outputs[0]).cpu().numpy() > 0.5, cmap="gray")
            # plt.title("Resultado da Segmentação")
            # plt.show()



print(f"train completed, best_loss: {best_loss:.4f} at epoch: {best_metric_epoch}")
writer.close()

----------
epoch 1/1000
1/4, train_loss: 0.7368
2/4, train_loss: 0.7376
3/4, train_loss: 0.7330
4/4, train_loss: 0.7285
epoch 1 average loss: 0.7340
Epoch 1 - Dice Score: 0.0030, Precision: 0.0016, Recall: 0.1396, F1-Score: 0.0031, IoU: 0.0014
saved new best metric model
current epoch: 1 current val loss: 0.7249 best val loss: 0.7249 at epoch 1
----------
epoch 2/1000
1/4, train_loss: 0.7272
2/4, train_loss: 0.7225
3/4, train_loss: 0.7210
4/4, train_loss: 0.7181
epoch 2 average loss: 0.7222
Epoch 2 - Dice Score: -0.0359, Precision: -0.0116, Recall: 0.1090, F1-Score: -0.0392, IoU: -0.0157
saved new best metric model
current epoch: 2 current val loss: 0.7147 best val loss: 0.7147 at epoch 2
----------
epoch 3/1000
1/4, train_loss: 0.7136
2/4, train_loss: 0.7147
3/4, train_loss: 0.7112
4/4, train_loss: 0.7116
epoch 3 average loss: 0.7128
Epoch 3 - Dice Score: 0.0216, Precision: 0.0145, Recall: 0.0768, F1-Score: 0.0223, IoU: 0.0107
saved new best metric model
current epoch: 3 current val l

In [None]:
# Visualizar imagem, rótulo e saída para o primeiro exemplo no lote
f, axarr = plt.subplots(1, 3)

# Imagem original
img = val_images[0].detach().cpu().permute(1, 2, 0)
axarr[0].imshow(img)

# Rótulo (Ground Truth)
label = val_labels[0].detach().cpu()
axarr[1].imshow(label, cmap='gray')

# Saída do Modelo (Previsão)
output = torch.sigmoid(val_outputs[0]).detach().cpu() 
axarr[2].imshow(output, cmap='gray')
axarr[2].set_title('Saída do Modelo')

plt.show()