In [None]:
!pip install segmentation-models-pytorch
!pip install timm

In [None]:
import cv2
import torch
import kagglehub
import numpy as np

from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import v2
import torchvision.transforms.functional as TF
from torchvision.io import decode_image
from torch.utils.data import random_split

from matplotlib import pyplot as plt
import matplotlib.patches as patches

from tqdm import tqdm

import os
from glob import glob
import json

import segmentation_models_pytorch as smp

DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
DEVICE

In [None]:
# Folder scanned for the input X-ray images (default `dataset_path` of SelfSegDataset).
INPUT_DIR = "/kaggle/input/chest-xray-pneumonia/chest_xray/test/NORMAL"
# Destination folder for exported weak masks; created here if it does not exist yet.
OUTPUT_DIR = "/kaggle/working/weak_masks"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
def generate_weak_mask_cxr(img_path):
    """Build a heuristic ("weak") binary mask for a chest X-ray image.

    The image is processed at a fixed 512x512 working resolution:
    intensity clipping, CLAHE + Gaussian blur, oriented grayscale openings,
    inversion, fixed threshold, border clean-up, morphological smoothing and
    connected-component filtering. The result is resized back to the
    original resolution.

    Args:
        img_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        ``np.uint8`` array of shape ``(h, w)`` (the original image size) whose
        values are 0 (background) or 1 (selected region).

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode ``img_path``.
    """
    work_size = 512      # side of the square working resolution
    top_margin = 40      # rows at the top assumed to lie outside the lungs
    side_margin = 70     # columns on each side assumed to lie outside the lungs
    min_area = 500       # components with area <= this (in pixels) are dropped
    se_size = 15         # side of the oriented structuring element

    # 1. load
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    h, w = img.shape
    img = cv2.resize(img, (work_size, work_size))

    # Saturate very dark and very bright pixels to white, which pushes them
    # toward the background side of the inverted threshold applied below.
    img[img < 50] = 255
    img[img > 180] = 255

    # 2. local contrast equalization + smoothing
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(img)
    enhanced = cv2.GaussianBlur(enhanced, (7, 7), 0)

    # Successive grayscale openings with a thin elliptical element rotated in
    # 30-degree steps (0..150): bright details that cannot contain the element
    # at every orientation are suppressed.
    center = (se_size // 2, se_size // 2)
    for angle in range(0, 180, 30):
        se = np.zeros((se_size, se_size), np.uint8)
        cv2.ellipse(se, center, (se_size // 2, 1), angle, 0, 360, 1, -1)
        opened = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, se)
        enhanced = cv2.min(enhanced, opened)

    # 3. invert so that dark regions of the processed image become bright
    inverted = cv2.normalize(255 - enhanced, None, 0, 255, cv2.NORM_MINMAX)

    # 4. fixed global threshold
    _, bw = cv2.threshold(inverted, 127, 255, cv2.THRESH_BINARY)

    # 5. clear a strip at the top and on both sides, which (usually) does not
    #    contain lung tissue
    bw[:top_margin, :] = 0
    bw[:, -side_margin:] = 0
    bw[:, :side_margin] = 0

    # 6. morphological open/close to smooth the binary regions
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
    bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, kernel)

    # 7. keep only components that are large enough and do not start inside
    #    the top strip (label 0 is the background and is skipped)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(bw)
    areas = stats[1:, cv2.CC_STAT_AREA]
    tops = stats[1:, cv2.CC_STAT_TOP]
    kept_labels = np.flatnonzero((areas > min_area) & (tops >= top_margin)) + 1
    mask_small = np.isin(labels, kept_labels).astype(np.uint8)

    # 8. back to the original size; nearest-neighbour keeps the mask strictly
    #    binary and avoids the dilation caused by interpolating and then
    #    thresholding at > 0
    mask_out = cv2.resize(mask_small, (w, h), interpolation=cv2.INTER_NEAREST)
    return (mask_out > 0).astype(np.uint8)

In [None]:
class SelfSegDataset(Dataset):
    """Dataset of chest X-rays paired with heuristically generated weak masks.

    Every ``*.jpeg`` file found recursively under ``dataset_path`` is one
    sample. The mask is computed on the fly by ``generate_weak_mask_cxr``.

    Args:
        dataset_path: Root folder searched recursively for ``*.jpeg`` files.
        image_size: Optional ``(height, width)`` to resize image and mask to.
            When falsy, samples keep their original resolution.
        device: Device that the returned image and mask tensors are moved to.
    """

    def __init__(self, dataset_path=INPUT_DIR, image_size=None, device=DEVICE):

        self._path_data = dataset_path

        # glob returns files in arbitrary order; sort so that a given index
        # always refers to the same file across runs.
        self.data = sorted(glob(os.path.join(self._path_data, "**", "*.jpeg"), recursive=True))

        self.device = device

        if image_size:
            self.resize_image = v2.Resize(size=image_size, interpolation=v2.InterpolationMode.BILINEAR, antialias=True)
            # Nearest-neighbour keeps the mask values binary after resizing.
            self.resize_mask = v2.Resize(size=image_size, interpolation=v2.InterpolationMode.NEAREST)
        else:
            self.resize_image = lambda x: x
            self.resize_mask = lambda x: x

    def __len__(self):
        """Return the number of image files found."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the file at ``index``.

        Returns:
            image: ``uint8`` tensor of shape ``(1, H, W)`` (grayscale).
            mask: ``uint8`` tensor of shape ``(1, H, W)`` with values 0/1.
            Both tensors live on ``self.device``.
        """
        image_path = self.data[index]

        # Force a single channel: the model is built with in_channels=1 and
        # the weak mask is also derived from a grayscale read of the file.
        image = decode_image(image_path, mode="GRAY").to(self.device)

        # The mask generator returns a 2-D NumPy array. v2 transforms leave
        # non-tensor inputs untouched, so convert to a (1, H, W) tensor first;
        # otherwise the mask would never be resized and could not be batched.
        mask = torch.from_numpy(generate_weak_mask_cxr(image_path)).unsqueeze(0).to(self.device)

        image = self.resize_image(image)
        mask = self.resize_mask(mask)

        return image, mask

    def orig_image(self, index):
        """Return the decoded image at ``index`` without resizing or device transfer."""
        image_path = self.data[index]
        return decode_image(image_path)

In [None]:
# # Generates binary masks with values 0 or 255 and writes them to OUTPUT_DIR.
# for filename in tqdm(os.listdir(INPUT_DIR)):
#     # if filename.lower().endswith(".png"):
#         img_path = os.path.join(INPUT_DIR, filename)
#         mask = generate_weak_mask_cxr(img_path)
#         cv2.imwrite(os.path.join(OUTPUT_DIR, filename.rsplit(".",1)[0] + ".png"), (mask*255))

100%|██████████| 234/234 [00:09<00:00, 25.31it/s]


In [None]:
ds = SelfSegDataset(image_size=(512, 512))

# Preview the first five samples: input image on the left, generated weak
# label on the right. One figure per sample; all are rendered at the end.
for sample_idx in range(5):
    img, mask = ds[sample_idx]
    panels = (("Imagem Original", img), ("Weak Label Gerada", mask))

    plt.figure(figsize=(12,6))
    for position, (panel_title, tensor) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        # Channel 0 of the tensor is the 2-D array to draw.
        plt.imshow(tensor.cpu()[0], cmap="gray")
        plt.axis("off")

plt.show()

### Usando UNet com ativação Sigmóide para treinar as Weak Labels

In [None]:
# U-Net with an ImageNet-pretrained ResNet-34 encoder. One input channel
# (grayscale) and one output channel; the sigmoid head makes the model emit
# probabilities instead of logits.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
    activation="sigmoid"
)
model = model.to(DEVICE)

# The model already applies a sigmoid, so the loss must NOT apply it again:
# DiceLoss defaults to from_logits=True, which would squash the predictions a
# second time and keep them away from 0 and 1.
loss = smp.losses.DiceLoss(mode="binary", from_logits=False)

dataset = SelfSegDataset(image_size=(512,512))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
NUM_EPOCHS = 20
NUM_PREVIEWS = 2  # samples of the last batch shown after each epoch

for epoch in range(NUM_EPOCHS):
    model.train()
    total = 0.0
    last_batch = None

    for imgs, masks in train_loader:
        # Scale the 8-bit pixel values to [0, 1] before feeding the
        # ImageNet-pretrained encoder; the Dice loss expects float targets.
        imgs = imgs.to(DEVICE).float() / 255.0
        masks = masks.to(DEVICE).float()

        preds = model(imgs)
        loss_value = loss(preds, masks)

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        total += loss_value.item()
        last_batch = (imgs, masks, preds.detach())

    # Report the mean loss of the epoch (previously accumulated but never shown).
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - mean Dice loss: {total / max(len(train_loader), 1):.4f}")

    if last_batch is None:
        continue

    # Visual check once per epoch instead of once per batch, so the notebook
    # is not flooded with figures. The last batch may hold fewer samples than
    # NUM_PREVIEWS, hence the min().
    imgs, masks, preds = last_batch
    for k in range(min(NUM_PREVIEWS, len(imgs))):
        plt.figure(figsize=(9,3))

        plt.subplot(1,3,1)
        plt.title("Imagem Original")
        plt.imshow(imgs[k].cpu()[0], cmap="gray")
        plt.axis("off")

        plt.subplot(1,3,2)
        plt.title("Weak Label inicial")
        plt.imshow(masks[k].cpu()[0], cmap="gray")
        plt.axis("off")

        plt.subplot(1,3,3)
        plt.title("Weak Label Gerada")
        plt.imshow(preds[k].cpu()[0], cmap="gray")
        plt.axis("off")

        plt.show()