In [19]:
!pip install segmentation-models-pytorch
!pip install timm



In [20]:
# Part 1/6 - imports e utilitários
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import segmentation_models_pytorch as smp

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def decode_image(path, to_gray=True):
    """
    Carrega imagem com cv2 e retorna tensor torch float32 shape [1, H, W] (scale 0..1).
    - Se imagem já é grayscale, retorna canal único.
    - Se color, converte para gray (pulmão tipicamente gray).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Image not found: {path}")

    # Se alpha presente, descarta
    if img.ndim == 3 and img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

    if to_gray:
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        # manter color se necessário (não usado aqui)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # garantir uint8
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # converter para tensor [1,H,W]
    img = img.astype(np.float32) / 255.0
    img = np.expand_dims(img, axis=0)  # 1,H,W
    return torch.from_numpy(img)


In [21]:
# Para geração de weak masks
def generate_weak_mask_cxr(img_path, out_size=None):
    orig = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if orig is None:
        raise FileNotFoundError(img_path)
    h_orig, w_orig = orig.shape

    # 1. Trabalhar em 512
    I = cv2.resize(orig, (512,512), interpolation=cv2.INTER_AREA)

    # 2. CLAHE + blur leve
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    I = clahe.apply(I)
    I = cv2.GaussianBlur(I, (7,7), 0)

    # 3. Top-hat: reduzir agressividade
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(21,21))
    toph = cv2.morphologyEx(I, cv2.MORPH_TOPHAT, kernel)

    # Mistura com original para não perder anatomia
    I2 = cv2.addWeighted(I, 0.7, toph, 0.3, 0)

    # 4. Threshold adaptativo mais realista
    bw = cv2.adaptiveThreshold(
        I2, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        31, 5
    )

    # 5. Remover bordas
    H, W = bw.shape
    bw[:int(0.03*H), :] = 0
    bw[:, :int(0.03*W)] = 0
    bw[:, -int(0.03*W):] = 0

    # 6. Fechar buracos + suavizar
    kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15))
    bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, kernel2)
    bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel2)

    # 7. Componentes conectados
    num, labels, stats, centroids = cv2.connectedComponentsWithStats(bw, 8)

    # manter todos componentes maiores que 1000 px
    mask_out = np.zeros_like(bw)
    for i in range(1, num):
        area = stats[i, cv2.CC_STAT_AREA]
        if area > 1000:
            mask_out[labels == i] = 255

    # 8. resize final
    if out_size is None:
        mask_final = cv2.resize(mask_out, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
    else:
        mask_final = cv2.resize(mask_out, (out_size[1], out_size[0]), interpolation=cv2.INTER_NEAREST)

    return mask_final

# função QC
def qc_mask(mask,
            min_area_ratio=0.003,   # 0.3% — antes era 2% (!)
            max_area_ratio=0.45,    # reduzido para evitar segmentações gigantes
            min_comp_area=300,      # antes 500 — relaxado
            expected_comps=1):      # aceitar 1 ou mais componentes
    if mask is None:
        return False

    mask_bin = (mask > 127).astype(np.uint8)
    H, W = mask_bin.shape
    total = H * W

    if mask_bin.sum() == 0:
        return False

    # área global
    area_ratio = mask_bin.sum() / total
    if not (min_area_ratio <= area_ratio <= max_area_ratio):
        return False

    # componentes
    num, labels, stats, _ = cv2.connectedComponentsWithStats(mask_bin, 8)
    if num <= 1:
        return False

    # verificar se ao menos 1 componente é relevante
    areas = stats[1:, cv2.CC_STAT_AREA]
    if np.max(areas) < min_comp_area:
        return False

    return True


In [22]:
# Datasets
class SelfSegDatasetQC(Dataset):
    """
    Dataset que gera weak masks e aplica QC. Somente arquivos que passam no QC são mantidos.
    Retorna: img_tensor [1,H,W], mask_tensor [1,H,W], image_path (str)
    """
    def __init__(self, dataset_path, image_size=(256,256), exts=("*.png","*.jpg","*.jpeg","*.bmp")):
        files = []
        for e in exts:
            files += glob(os.path.join(dataset_path, "**", e), recursive=True)
        files = sorted(files)

        self.image_size = image_size
        self.valid_files = []

        print("Gerando weak masks e aplicando QC (isso pode demorar)...")
        for f in tqdm(files):
            try:
                mask = generate_weak_mask_cxr(f, out_size=image_size)
                if qc_mask(mask):
                    self.valid_files.append((f, mask))
            except Exception as e:
                # ignora arquivos problemáticos
                continue

        print(f"Arquivos totais: {len(files)}  ->  Válidos após QC: {len(self.valid_files)}")

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        img_path, mask_np = self.valid_files[idx]
        # carrega imagem e coloca no tamanho correto
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32)/255.0
        img = np.expand_dims(img, axis=0)  # 1,H,W
        img_t = torch.from_numpy(img).float()

        # mask_np já está no tamanho correto porque geramos com out_size em init
        mask = (mask_np > 127).astype(np.float32)
        mask = np.expand_dims(mask, axis=0)
        mask_t = torch.from_numpy(mask).float()

        return img_t, mask_t, img_path


class SelfLearningDataset(Dataset):
    """
    Dataset que combina imagens originais com pseudo-masks (geradas).
    Só carrega pares onde a pseudo-mask existe E passa no QC.
    """
    def __init__(self, img_dir, mask_dir, image_size=(256,256), exts=("*.png","*.jpg","*.jpeg","*.bmp")):
        files = []
        for e in exts:
            files += glob(os.path.join(img_dir, "**", e), recursive=True)
        files = sorted(files)

        self.image_size = image_size
        self.pairs = []

        print("Construindo dataset de self-learning (aplicando QC nas pseudo-masks)...")
        for f in tqdm(files):
            fname = os.path.basename(f)
            mask_path = os.path.join(mask_dir, fname)
            if not os.path.exists(mask_path):
                continue
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue
            # redimensiona máscara para image_size para QC consistente
            mask_rs = cv2.resize(mask, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST)
            if qc_mask(mask_rs):
                self.pairs.append((f, mask_path))
        print(f"Imagens com pseudo-masks válidas: {len(self.pairs)}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32)/255.0
        img = np.expand_dims(img, axis=0)
        img_t = torch.from_numpy(img).float()

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_NEAREST)
        mask = (mask > 127).astype(np.float32)
        mask = np.expand_dims(mask, axis=0)
        mask_t = torch.from_numpy(mask).float()

        return img_t, mask_t, img_path

In [23]:
# Loss combinada 
class DiceBCELoss(nn.Module):
    def __init__(self, bce_weight=1.0, dice_weight=1.0, smooth=1e-5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_w = bce_weight
        self.dice_w = dice_weight
        self.smooth = smooth

    def forward(self, logits, target):
        # logits: [B,1,H,W], target: [B,1,H,W] float 0..1
        bce = self.bce(logits, target)
        pred = torch.sigmoid(logits)
        inter = (pred * target).sum(dim=(2,3))
        denom = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
        dice = (2*inter + self.smooth) / (denom + self.smooth)
        dice_loss = 1 - dice.mean()
        return self.bce_w * bce + self.dice_w * dice_loss

# Criação do modelo
def create_model(encoder_name="resnet34", encoder_weights="imagenet", in_channels=1, classes=1):
    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=None
    )
    return model

In [24]:
# Training loops e geração de pseudo-labels com QC

from torchvision.utils import save_image

def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0.0
    n = 0
    for imgs, masks, _ in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = loss_fn(preds, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n += 1
    return total_loss / (n if n>0 else 1)

def early_learning(model, dataset, lr=1e-4, epochs=5, batch=8, save_path=None):
    loader = DataLoader(dataset, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = DiceBCELoss()

    for ep in range(epochs):
        loss = train_epoch(model, loader, optimizer, criterion)
        print(f"[Early Learning] Epoch {ep+1}/{epochs}  loss={loss:.4f}")
        if save_path:
            torch.save(model.state_dict(), f"{save_path}_early_ep{ep+1}.pth")
    return model

def generate_pseudo_labels(model, dataset_files, out_dir, threshold=0.5, qc_on_save=True):
    """
    dataset_files: lista de tuples ou caminhos. Para simplicidade, pode passar o dataset que itera em (img, mask, path)
    Se passar um Dataset (ex: SelfSegDatasetQC), a função itera sobre arquivos originais.
    """
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    model = model.to(device)

    # construir loader de avaliação que retorna caminhos
    loader = DataLoader(dataset_files, batch_size=1, shuffle=False)

    saved = 0
    skipped = 0
    with torch.no_grad():
        for imgs, _, img_paths in tqdm(loader, desc="Gerando pseudo-labels"):
            # imgs: [1,1,H,W], img_paths: tuple/list com path
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.sigmoid(logits)
            mask_np = (probs.cpu().numpy()[0,0] > threshold).astype(np.uint8) * 255

            fname = os.path.basename(img_paths[0])
            out_path = os.path.join(out_dir, fname)

            # opcionalmente aplicar QC antes de salvar
            if qc_on_save:
                # resize para QC usando same size do loader (imgs)
                H, W = mask_np.shape
                if not qc_mask((mask_np).astype(np.uint8)):
                    skipped += 1
                    # podemos salvar em um diretório de rejeitados se quisermos
                    continue

            cv2.imwrite(out_path, mask_np)
            saved += 1
    print(f"Pseudo masks salvas: {saved}  |  puladas (fail QC): {skipped}")


def self_learning(model, dataset, lr=1e-4, epochs=5, batch=8, save_path=None):
    loader = DataLoader(dataset, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = DiceBCELoss()

    for ep in range(epochs):
        loss = train_epoch(model, loader, optimizer, criterion)
        print(f"[Self Learning] Epoch {ep+1}/{epochs}  loss={loss:.4f}")
        if save_path:
            torch.save(model.state_dict(), f"{save_path}_self_ep{ep+1}.pth")
    return model

In [25]:
# Pipeline

INPUT_DIR = "/kaggle/input/chest-xray-pneumonia/chest_xray/test/NORMAL"   
OUT_ROOT = "/kaggle/working/weak_masks"
PSEUDO_DIR = os.path.join(OUT_ROOT, "pseudo_masks")
PSEUDO_ROUND2 = os.path.join(OUT_ROOT, "pseudo_masks_round2")
MODEL_SAVE = "/kaggle/working/unet_model"

image_size = (256,256)

# 1) Gerar weak dataset com QC 
weak_dataset = SelfSegDatasetQC(INPUT_DIR, image_size=image_size)

# 2) Early learning (treina usando as weak masks aprovadas)
model = create_model()
model = early_learning(model, weak_dataset, lr=1e-4, epochs=5, batch=8, save_path=MODEL_SAVE)

# 3) Gerar pseudo-labels (round1) em PSEUDO_DIR. OBS: a função aplica QC antes de salvar
generate_pseudo_labels(model, weak_dataset, PSEUDO_DIR, threshold=0.5, qc_on_save=True)

# 4) Construir dataset de self-learning usando apenas pseudo-masks válidas
pseudo_dataset = SelfLearningDataset(INPUT_DIR, PSEUDO_DIR, image_size=image_size)

# 5) Self-learning (treina com pseudo-masks filtradas)
model = self_learning(model, pseudo_dataset, lr=1e-4, epochs=5, batch=8, save_path=MODEL_SAVE)

# 6) Gerar pseudo-labels round2 
# Recriar um loader/dataset para evitar confusão. OBS: podemos usar pseudo_dataset (que itera imagens válidas)
generate_pseudo_labels(model, pseudo_dataset, PSEUDO_ROUND2, threshold=0.5, qc_on_save=True)

# Ideia: depois podemos treinar um modelo final com pseudo_masks_round2 + weak_dataset combinado.

Gerando weak masks e aplicando QC (isso pode demorar)...


100%|██████████| 234/234 [00:06<00:00, 35.81it/s]


Arquivos totais: 234  ->  Válidos após QC: 131
[Early Learning] Epoch 1/5  loss=1.1192
[Early Learning] Epoch 2/5  loss=0.8617
[Early Learning] Epoch 3/5  loss=0.7486
[Early Learning] Epoch 4/5  loss=0.6651
[Early Learning] Epoch 5/5  loss=0.6049


Gerando pseudo-labels: 100%|██████████| 131/131 [00:03<00:00, 39.31it/s]


Pseudo masks salvas: 90  |  puladas (fail QC): 41
Construindo dataset de self-learning (aplicando QC nas pseudo-masks)...


100%|██████████| 234/234 [00:00<00:00, 1603.78it/s]

Imagens com pseudo-masks válidas: 115





[Self Learning] Epoch 1/5  loss=0.4370
[Self Learning] Epoch 2/5  loss=0.3761
[Self Learning] Epoch 3/5  loss=0.3286
[Self Learning] Epoch 4/5  loss=0.3043
[Self Learning] Epoch 5/5  loss=0.2742


Gerando pseudo-labels: 100%|██████████| 115/115 [00:03<00:00, 35.90it/s]

Pseudo masks salvas: 113  |  puladas (fail QC): 2



