In [None]:
import os
import random
import math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import cv2
import io
import base64
import pandas as pd
!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp


# Setting the limit to None disables Pillow's decompression-bomb check, so
# arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None


# Kaggle input / output locations.
DATA_ROOT = Path("/kaggle/input/dis5k-data")
TEST_ROOT = Path("/kaggle/input/kagle-test")
MODEL_DIR = Path("/kaggle/working")
VAL_ROOT = Path("/kaggle/input/validation/validation")

# ==================== SEED / DEVICE ====================
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Seed torch too: weight initialisation and DataLoader shuffling draw from
# torch's RNG, so seeding only `random` and numpy leaves runs non-reproducible.
torch.manual_seed(SEED)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ==================== PATHS / CONFIG ====================

# Make sure the output directory for checkpoints, plots and submissions exists.
MODEL_DIR.mkdir(parents=True, exist_ok=True)


IMG_SIZE = (1024, 1024)  # size every image and mask is resized to
BATCH_SIZE = 4
NUM_EPOCHS = 20


LR = 1e-4            # initial learning rate
ETA_MIN = 1e-7       # minimum learning rate (reported in the training log)
WEIGHT_DECAY = 1e-4  # weight decay passed to the optimizer
RUN_TRAINING = True  # set to False to skip training and reuse a saved checkpoint



# ==================== DATASET ====================
class SegmentationDataset(Dataset):
    """Paired image / mask dataset read from two sub-folders of ``data_root``.

    Samples are discovered by listing the mask folder: every file whose name
    ends with ``mask_format`` defines one sample, and the image with the same
    base name and ``image_format`` is read from the image folder.

    Args:
        data_root: Directory containing the image and mask sub-folders.
        size: Size passed to ``transforms.Resize`` for both image and mask.
        mask_subfolder: Mask folder, relative to ``data_root``.
        image_subfolder: Image folder, relative to ``data_root``.
        image_format: Image file extension, including the dot.
        mask_format: Mask file extension, including the dot.
        num_mask_channels: 1 loads the mask as single-channel ("L"),
            3 loads it as "RGB"; any other value leaves the file's mode as is.

    Raises:
        ValueError: If ``data_root`` does not exist.
    """

    def __init__(
        self,
        data_root,
        size=(1024, 1024),
        mask_subfolder="gt/gt",
        image_subfolder="im/im",
        image_format=".jpg",
        mask_format=".png",
        num_mask_channels=1,
    ):
        self.size = size
        self.data_root = Path(data_root)
        if not self.data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.mask_subfolder = mask_subfolder
        self.image_subfolder = image_subfolder
        self.image_format = image_format
        self.mask_format = mask_format
        self.num_mask_channels = num_mask_channels

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same sample on every run and every machine.
        self.data = sorted(
            i.rsplit(".", 1)[0]
            for i in os.listdir(str(self.data_root / self.mask_subfolder))
            if i.endswith(self.mask_format)
        )
        self._length = len(self.data)

        # Images: resize, scale to [0, 1], then shift/scale with mean 0.5, std 0.5.
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Masks: resize and convert to a tensor, without normalisation.
        self.mask_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of mask files found at construction time."""
        return self._length

    def __getitem__(self, index):
        """Load and transform one sample.

        Returns:
            dict with keys ``"mask"`` and ``"img"`` holding the transformed tensors.
        """
        obj = self.data[index]
        mask_path = self.data_root / self.mask_subfolder / f"{obj}{self.mask_format}"
        img_path = self.data_root / self.image_subfolder / f"{obj}{self.image_format}"

        # Context managers close the underlying files deterministically;
        # convert()/copy() return fully loaded, independent images.
        with Image.open(mask_path) as mask_file:
            if self.num_mask_channels == 3:
                mask = mask_file.convert("RGB")
            elif self.num_mask_channels == 1:
                # Force a single channel so RGB / RGBA / palette PNGs do not
                # produce a mask whose channel count differs from the logits.
                mask = mask_file.convert("L")
            else:
                mask = mask_file.copy()

        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        mask = self.mask_transforms(mask)
        img = self.image_transforms(img)

        return {"mask": mask, "img": img}



def get_dataloaders(train_root: Path, size=(1024, 1024), batch_size=4, val_root=None):
    """Build the training and validation data loaders.

    Args:
        train_root: Root of the training set; images are read from ``im/im``
            and masks from ``gt/gt`` beneath it.
        size: Size every image and mask is resized to.
        batch_size: Batch size used by both loaders.
        val_root: Root of the validation set (``im`` / ``gt`` sub-folders).
            Defaults to the module-level ``VAL_ROOT``.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader keeps dataset order and yields every sample.
    """
    if val_root is None:
        val_root = VAL_ROOT

    train_dataset = SegmentationDataset(
        data_root=train_root,
        size=size,
        mask_subfolder="gt/gt",
        image_subfolder="im/im",
        image_format=".jpg",
        mask_format=".png",
        num_mask_channels=1,
    )
    val_dataset = SegmentationDataset(
        data_root=val_root,
        size=size,
        mask_subfolder="gt",
        image_subfolder="im",
        image_format=".jpg",
        mask_format=".png",
        num_mask_channels=1,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    # No drop_last here: dropping the final partial batch would silently
    # exclude validation samples and skew metrics averaged over the dataset.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
    return train_loader, val_loader


class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed over the whole batch.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - dice`` for raw ``logits`` against ``targets``.

        Sigmoid is applied to ``logits`` here. All elements of the batch are
        flattened together, so a single scalar is returned for the batch.
        """
        probs = torch.sigmoid(logits)
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are flattened instead of raising a RuntimeError.
        probs_flat = probs.reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (probs_flat * targets_flat).sum()
        # Sum of both "areas" (|P| + |T|): the Dice denominator, not a set union.
        denominator = probs_flat.sum() + targets_flat.sum()

        return 1 - (2 * intersection + self.smooth) / (denominator + self.smooth)



class CombinedLoss(nn.Module):
    """Weighted sum of Dice loss and binary cross-entropy on logits.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        bce_weight: Multiplier applied to the BCE term.
        smooth: Smoothing constant forwarded to ``DiceLoss``.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1.0):
        super().__init__()
        self.dice_weight, self.bce_weight = dice_weight, bce_weight

        self.dice_loss = DiceLoss(smooth=smooth)
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """Return ``dice_weight * dice + bce_weight * bce`` for the batch."""
        dice_term = self.dice_weight * self.dice_loss(logits, targets)
        bce_term = self.bce_weight * self.bce_loss(logits, targets)
        return dice_term + bce_term


# ==================== MODEL ====================

# Build the segmentation network: U-Net++ decoder on an EfficientNet-B3 encoder.
# The model emits raw logits (activation=None); sigmoid is applied later in the
# losses, the metric and at inference time.
# NOTE(review): `dropout` does not appear to be a documented UnetPlusPlus
# argument - confirm it actually has an effect in the installed smp version.
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b3",        
    encoder_weights="imagenet",     
    in_channels=3,                 
    classes=1,                     
    activation=None,
    dropout=0.5
)

# Replicate the model across all visible GPUs when more than one is available.
# Note that a DataParallel wrapper prefixes state_dict keys with "module.".
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)
# ==================== METRICS ====================
def mse_metric(logits, target):
    """Mean squared error between ``sigmoid(logits)`` and ``target``."""
    return F.mse_loss(logits.sigmoid(), target)



def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of batches, each a dict with ``"img"`` and ``"mask"`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``(logits, masks) -> scalar loss``.

    Returns:
        ``(mean_loss, mean_mse)`` averaged per sample over the samples
        actually processed.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_mse = 0.0
    seen = 0

    for batch in tqdm(loader, desc="train", leave=False):
        imgs = batch["img"].to(device)
        masks = batch["mask"].to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        # The metric is only reported, never back-propagated: skip autograd.
        with torch.no_grad():
            total_mse += mse_metric(logits, masks).item() * batch_size
        seen += batch_size

    if seen == 0:
        raise ValueError("train_epoch: the loader yielded no batches")

    # Normalise by the samples seen rather than len(loader.dataset), so the
    # average stays correct if the loader drops or subsamples batches.
    return total_loss / seen, total_mse / seen



def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of batches, each a dict with ``"img"`` and ``"mask"`` tensors.
        criterion: Callable ``(logits, masks) -> scalar loss``.

    Returns:
        ``(mean_loss, mean_mse)`` averaged per sample over the samples
        actually processed.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_mse = 0.0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="val", leave=False):
            imgs = batch["img"].to(device)
            masks = batch["mask"].to(device)

            logits = model(imgs)
            loss = criterion(logits, masks)

            batch_size = imgs.size(0)
            total_loss += loss.item() * batch_size
            total_mse += mse_metric(logits, masks).item() * batch_size
            seen += batch_size

    if seen == 0:
        # Reporting 0.0 here would look like a perfect score and could be
        # saved as the "best" model, so fail loudly instead.
        raise ValueError("eval_epoch: the loader yielded no batches")

    # Divide by the samples actually evaluated. A loader built with
    # drop_last=True skips the final partial batch, so dividing by
    # len(loader.dataset) would under-report both metrics.
    return total_loss / seen, total_mse / seen



# ==================== PLOTS  ====================
def plot_metrics(train_losses, val_losses, train_mses, val_mses, learning_rates, save_dir):
    """Plot training curves on a 2x2 grid, save them and display the figure.

    Panels: train/val loss, train/val MSE, learning rate (log scale) and the
    validation loss on its own.

    Args:
        train_losses, val_losses: Per-epoch loss values.
        train_mses, val_mses: Per-epoch MSE values.
        learning_rates: Per-epoch learning rate.
        save_dir: Directory (``str`` or ``Path``) where ``training_plots.png``
            is written.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    (ax_loss, ax_mse), (ax_lr, ax_val) = axes

    # Loss
    ax_loss.plot(train_losses, label="Train Loss", linewidth=2, marker="o")
    ax_loss.plot(val_losses, label="Val Loss", linewidth=2, marker="s")
    ax_loss.set(xlabel="Epoch", ylabel="Loss", title="Training Loss")
    ax_loss.legend()

    # MSE
    ax_mse.plot(train_mses, label="Train MSE", linewidth=2, marker="o")
    ax_mse.plot(val_mses, label="Val MSE", linewidth=2, marker="s")
    ax_mse.set(xlabel="Epoch", ylabel="MSE", title="MSE Metric")
    ax_mse.legend()

    # Learning rate. The title is scheduler-agnostic: this function only
    # receives the recorded values and cannot know which schedule made them.
    ax_lr.plot(learning_rates, linewidth=2, color="green", marker="o")
    ax_lr.set(xlabel="Epoch", ylabel="Learning Rate", title="Learning Rate Schedule")
    ax_lr.set_yscale("log")

    # Validation loss on its own axis
    ax_val.plot(val_losses, linewidth=2.5, color="red", marker="o")
    ax_val.set(xlabel="Epoch", ylabel="Val Loss", title="validation loss ")

    for ax in (ax_loss, ax_mse, ax_lr, ax_val):
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(Path(save_dir) / "training_plots.png", dpi=150, bbox_inches="tight")
    plt.show()
    # This runs once per epoch; close the figure so figures do not pile up in memory.
    plt.close(fig)



# ==================== TRAIN ====================


train_loader, val_loader = get_dataloaders(DATA_ROOT, size=IMG_SIZE, batch_size=BATCH_SIZE)

# Equal-weight Dice + BCE objective computed on raw logits.
criterion = CombinedLoss(dice_weight=0.5, bce_weight=0.5, smooth=1.0)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Halve the learning rate when validation MSE stops improving, but never go
# below ETA_MIN (the value reported as "Min LR" in the banner below).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
    min_lr=ETA_MIN,
)

best_val_mse = float("inf")
best_path = MODEL_DIR / "unet_dis_best.pth"

# Per-epoch history, re-plotted after every epoch.
train_losses, val_losses, train_mses, val_mses, learning_rates = [], [], [], [], []

if RUN_TRAINING:
    print(f"Training {NUM_EPOCHS} epochs (ReduceLROnPlateau on val MSE, no warmup)")
    print(f"Base LR: {LR}, Min LR: {ETA_MIN}")
    print(f"Weight decay (AdamW): {WEIGHT_DECAY}\n")

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_mse = train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_mse = eval_epoch(model, val_loader, criterion)

        # Learning rate that was in effect during this epoch.
        current_lr = optimizer.param_groups[0]["lr"]

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_mses.append(train_mse)
        val_mses.append(val_mse)
        learning_rates.append(current_lr)

        print(
            f"Epoch {epoch:2d}/{NUM_EPOCHS}: "
            f"train_loss={train_loss:.6f} val_loss={val_loss:.6f} val_mse={val_mse:.6f} LR={current_lr:.8f}"
        )

        # Model selection is driven by validation MSE.
        if val_mse < best_val_mse:
            best_val_mse = val_mse
            torch.save({"model_state": model.state_dict()}, best_path)
            print("  Saved best model")

        plot_metrics(train_losses, val_losses, train_mses, val_mses, learning_rates, MODEL_DIR)

        scheduler.step(val_mse)
        new_lr = optimizer.param_groups[0]["lr"]
        if new_lr != current_lr:
            print(f" LR reduced: {current_lr:.8f} → {new_lr:.8f}")

    print(f"\n Training completed! Best val_mse: {best_val_mse:.6f}\n")

# Restore the best checkpoint in both cases. After training, the in-memory
# weights are those of the LAST epoch, which is not necessarily the best one.
# NOTE(review): the checkpoint keys carry a "module." prefix when the model was
# wrapped in nn.DataParallel at save time; loading with a different GPU count
# would need the prefix added or stripped.
if best_path.exists():
    checkpoint = torch.load(best_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    print(f"Loaded best checkpoint from {best_path}")
else:
    print(f"WARNING: no checkpoint found at {best_path}; continuing with the current model weights")


# ==================== POSTPROCESSING ====================
def postprocess_alpha_mask(mask_prob):
    """Post-process an alpha mask: morphological cleanup followed by smoothing.

    The input is scaled by 255 and cast to uint8, opened and then closed with
    a 3x3 elliptical kernel, and finally blurred with a 3x3 Gaussian
    (sigma 1.0).

    Args:
        mask_prob: numpy array of mask values - presumably probabilities in
            [0, 1]; verify against the caller.

    Returns:
        The processed uint8 mask.
    """
    alpha = (mask_prob * 255).astype(np.uint8)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        alpha = cv2.morphologyEx(alpha, operation, ellipse, iterations=1)
    return cv2.GaussianBlur(alpha, (3, 3), 1.0)


# ==================== INFERENCE ====================
model.eval()

# Enumerate and preprocess the test images directly in this cell.
# TestImageDataset is only defined in a later cell, so relying on it here
# fails with a NameError on a fresh "Restart & Run All".
test_extensions = (".png", ".jpg", ".jpeg")
test_files = sorted(
    name for name in os.listdir(str(TEST_ROOT)) if name.lower().endswith(test_extensions)
)
test_transforms = transforms.Compose(
    [
        transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

rows = []
with torch.no_grad():
    for filename in tqdm(test_files, desc="test", leave=False):
        with Image.open(TEST_ROOT / filename) as test_image:
            img = test_transforms(test_image.convert("RGB"))
        imgs = img.unsqueeze(0).to(device)  # add the batch dimension

        logits = model(imgs)
        probs = torch.sigmoid(logits)

        mask_prob = probs[0, 0].cpu().numpy()
        mask = postprocess_alpha_mask(mask_prob)

        # Encode the mask as a base64 PNG string for the submission file.
        pil_mask = Image.fromarray(mask, mode="L")
        buf = io.BytesIO()
        pil_mask.save(buf, format="PNG")
        image_utf = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Path.stem removes only the extension; split(".")[0] would cut the
        # name at the first dot and mangle names such as "a.b.jpg".
        rows.append({"filename": Path(filename).stem, "image_utf": image_utf})


submission = pd.DataFrame(rows)
submission_path = MODEL_DIR / "submission.csv"
submission.to_csv(submission_path, index=False)
print(f"Saved submission to {submission_path}")

In [None]:
import io
import base64
import pandas as pd

class TestImageDataset(Dataset):
    """Unlabelled images read from a single flat directory.

    Every file in ``test_root`` whose lower-cased name ends in ``.png``,
    ``.jpg`` or ``.jpeg`` is a sample; samples are ordered by file name.

    Raises:
        ValueError: If ``test_root`` does not exist.
    """

    def __init__(self, test_root, size=(1024, 1024)):
        self.size = size
        self.test_root = Path(test_root)

        if not self.test_root.exists():
            raise ValueError("Test images root doesn't exist")

        # Collect every image file in the folder
        extensions = ('.png', '.jpg', '.jpeg')
        self.image_files = sorted(
            entry
            for entry in os.listdir(str(self.test_root))
            if entry.lower().endswith(extensions)
        )

        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        normalize = transforms.Normalize([0.5], [0.5])
        self.image_transforms = transforms.Compose([resize, transforms.ToTensor(), normalize])

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``{"img": transformed tensor, "path": file name}`` for one image."""
        name = self.image_files[index]
        rgb = Image.open(self.test_root / name).convert("RGB")
        return {"img": self.image_transforms(rgb), "path": name}

In [None]:
# Inference without post-processing: the sigmoid output is written as-is.
# NOTE(review): this writes to the same submission.csv as the post-processed
# inference earlier in the notebook and therefore overwrites it - confirm
# which of the two variants is meant to be submitted.
model.eval()


test_dataset = TestImageDataset(TEST_ROOT, size=(1024, 1024))
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


rows = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="test", leave=False):
        imgs = batch["img"].to(device)
        names = batch["path"]
        logits = model(imgs)
        probs = torch.sigmoid(logits)
        # batch_size is 1, so take the single mask of the batch.
        mask = (probs[0, 0].cpu().numpy() * 255.0).astype(np.uint8)

        # Encode the mask as a base64 PNG string for the submission file.
        pil_mask = Image.fromarray(mask, mode="L")
        buf = io.BytesIO()
        pil_mask.save(buf, format="PNG")
        image_utf = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Path.stem removes only the extension; split(".")[0] would cut the
        # name at the first dot and mangle names such as "a.b.jpg".
        rows.append({"filename": Path(names[0]).stem, "image_utf": image_utf})


submission = pd.DataFrame(rows)
submission_path = MODEL_DIR / "submission.csv"
submission.to_csv(submission_path, index=False)
print(f"Saved submission to {submission_path}")