# Сегментация радионуклидных изображений легких

In [1]:
!pip install accelerate

import numpy as np
import albumentations
import random
import torch
import os

from pycocotools.coco import COCO
from albumentations.pytorch.transforms import ToTensorV2
from accelerate import Accelerator
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from matplotlib import pyplot as plt
from os.path import join as pjoin

from dataset import LungsDataset, create_masks_for_all_images
from Model.Unet import UNet, count_model_params
from Model.MeanIoU import MeanIoU
from Model.DiceLoss import DiceLoss
from Model.train import train, CheckpointSaver, load_checkpoint



In [2]:
def seed_everything(seed: int = 314159, torch_deterministic: bool = False) -> None:
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=False)

In [3]:
annFile = 'instances_default.json'
coco = COCO(annFile)

create_masks_for_all_images(coco, "Data/Masks")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [4]:
IMAGE_SIZE = 256
transforms = albumentations.Compose(
    [
        albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
        albumentations.HorizontalFlip(p=0.5),
        ToTensorV2(),
    ]
)

In [5]:
full_dataset = LungsDataset(root_dir="Data", transforms = transforms)
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

print('Количество изображений в полном датасете:',len(full_dataset))
print('Количество изображений в тренировочном датасете:',len(train_dataset))
print('Количество изображений в валидационном датасете:',len(val_dataset))
print('Количество изображений в тестовом датасете:',len(test_dataset))

Количество изображений в полном датасете: 27
Количество изображений в тренировочном датасете: 18
Количество изображений в валидационном датасете: 5
Количество изображений в тестовом датасете: 4


In [6]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_WORKERS = 2
EPOCH_NUM = 20
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
train_dataset[0]

(1348, 2022, 3) (1348, 2022)


(tensor([[[ 93.,  93.,  93.,  ...,  93.,  93.,  93.],
          [255., 255., 255.,  ..., 255., 255., 255.],
          [255., 255., 255.,  ..., 255., 255., 255.],
          ...,
          [255., 255., 255.,  ..., 255., 255., 255.],
          [255., 255., 255.,  ..., 255., 255., 255.],
          [ 93.,  93.,  93.,  ...,  93.,  93.,  93.]],
 
         [[ 93.,  93.,  93.,  ...,  93.,  93.,  93.],
          [255., 255., 255.,  ..., 255., 195., 255.],
          [255., 255., 196.,  ..., 174., 255., 255.],
          ...,
          [255., 255., 255.,  ..., 255., 253., 255.],
          [255., 255., 255.,  ..., 255., 255., 255.],
          [ 93.,  93.,  93.,  ...,  93.,  93.,  93.]],
 
         [[ 93.,  93.,  93.,  ...,  93.,  93.,  93.],
          [255., 255., 255.,  ..., 255., 195., 255.],
          [255., 255., 196.,  ..., 174., 255., 255.],
          ...,
          [255., 255., 255.,  ..., 255., 253., 255.],
          [255., 255., 255.,  ..., 255., 255., 255.],
          [ 93.,  93.,  93.,  .

In [7]:
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
) # (27, 3, 256, 256, 256), (27, 1, 256, 256)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

model = UNet(in_channels=3, out_channels=1)

loss_fn = DiceLoss()
metric_fn = MeanIoU()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=5, gamma=0.8
) # уменьшение скорости обучения
count_model_params(model=model)

17262977

In [8]:
accelerator = Accelerator(cpu=False, mixed_precision="fp16")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name="DICE",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=False
)



In [9]:
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = torch.utils.tensorboard.SummaryWriter(log_dir=TENSORBOARD_DIR)

In [10]:
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

RuntimeError: User specified an unsupported autocast device_type 'mps'

In [None]:
train(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_function=loss_fn,
    metric_function=metric_fn,
    lr_scheduler=lr_scheduler,
    accelerator=accelerator,
    epoch_num=EPOCH_NUM,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
    show_every_x_batch=5
)

In [None]:
model = UNet(in_channels=3, out_channels=1)
model = load_checkpoint(
    model=model, load_path=pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
)
model = model.to(DEVICE)
model.eval()

In [ ]:
thresold = 0.8
for sample_idx in range(test_size):
    image, target = test_dataset[sample_idx]
    preds = F.sigmoid(model(image.unsqueeze(0).to(DEVICE))).squeeze(0)
    binary_preds = (preds > thresold)
    
    a = metric_fn(preds, target)

    fig, ax = plt.subplots(1, 3, figsize=(9, 18))
    ax[0].imshow(image.numpy().transpose(1, 2, 0).astype(np.uint8))
    ax[0].set_title("Image") 
    ax[1].imshow(target.numpy().transpose(1, 2, 0).astype(np.uint8))
    ax[1].set_title("Mask")
    ax[2].imshow(binary_preds.cpu().numpy()[0])
    ax[2].set_title("Target")
    plt.show()
  