In [None]:
from google.colab import drive

# Mount Google Drive; the dataset zip is read from, and weights are written to, MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!cp -r /content/drive/MyDrive/TreeAi4Species_Competition/deteset_semantic/12_RGB_SemSegm_640_fL.zip -d /content/

In [None]:
!unzip /content/12_RGB_SemSegm_640_fL -d /content/dataset/

In [None]:
# Local copy of the unzipped dataset, and the Drive folder used for weights/predictions.
dataset_dir = "/content/dataset"
_competition_root = "/content/drive/MyDrive/TreeAi4Species_Competition"
output_dir = f"{_competition_root}/weights/semantic_20250629-2"

In [None]:
%%capture
!pip install -U git+https://github.com/qubvel-org/segmentation_models.pytorch
!pip install lightning albumentations

In [None]:
import torch

# Report the installed PyTorch build and whether a CUDA device is visible.
for env_fact in (torch.__version__, torch.cuda.is_available()):
    print(env_fact)

In [None]:
import cv2
import os
import glob
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path

import torch
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns

import segmentation_models_pytorch as smp

import numpy as np

from tqdm.notebook import tqdm

In [None]:
class SegmentationDataset(Dataset):
    """Semantic Segmentation Dataset for drone images with Albumentations.

    Reads ``<root_path>/<split>/images/*.png`` and ``<root_path>/<split>/labels/*.png``.
    Only samples whose file stem appears in both folders are kept, and both
    path lists are kept in sorted file order, so index ``i`` refers to the same
    sample in ``image_paths`` and ``mask_paths``.

    Args:
        root_path: Dataset root directory.
        split: Sub-folder name. ``"train"`` gets the augmentation pipeline;
            any other split only gets resize + normalize.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            returning a dict with ``"image"`` and ``"mask"``; replaces the defaults.
    """

    def __init__(self, root_path, split="train", transforms=None) -> None:
        super().__init__()

        split_root = Path(root_path) / split
        candidate_images = sorted(list((split_root / "images").glob("*.png")))
        candidate_masks = sorted(list((split_root / "labels").glob("*.png")))

        # Drop any image without a mask (and vice versa).
        paired = {path.stem for path in candidate_images}.intersection(
            path.stem for path in candidate_masks
        )
        self.image_paths = [path for path in candidate_images if path.stem in paired]
        self.mask_paths = [path for path in candidate_masks if path.stem in paired]

        self.transforms = (
            transforms if transforms is not None else self._default_transforms(split)
        )

    @staticmethod
    def _default_transforms(split):
        """Build the default Albumentations pipeline for ``split``."""
        tail = [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        if split != "train":  # val / test: deterministic resize only
            return A.Compose([A.Resize(640, 640), *tail])
        return A.Compose([
            A.Resize(640, 640),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.15, rotate_limit=20, p=0.5, border_mode=cv2.BORDER_REFLECT),
            A.CoarseDropout(num_holes_range=(1, 10), hole_height_range=(16, 64), hole_width_range=(16, 64), fill=(0, 0, 0), fill_mask=0, p=0.4),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.5),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=8, val_shift_limit=5, p=0.5),
            A.OneOf([A.GaussianBlur(blur_limit=(3, 7), p=0.5), A.MedianBlur(blur_limit=(3, 7), p=0.5)], p=0.3),
            *tail,
        ])

    def convert_mask(self, mask):
        """Hook for remapping raw mask values; the base class returns the mask unchanged."""
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, mask)`` after transforms, with the mask cast to ``long``."""
        img_path = self.image_paths[index]
        msk_path = self.mask_paths[index]

        bgr = cv2.imread(str(img_path))
        if bgr is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        raw_mask = cv2.imread(str(msk_path), cv2.IMREAD_UNCHANGED)
        if raw_mask is None:
            raise FileNotFoundError(f"Mask not found at {msk_path}")

        sample = self.transforms(image=rgb, mask=self.convert_mask(raw_mask))
        return sample['image'], sample['mask'].long()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset

def display_data_sample(
    dataset: Dataset,
    num_samples: int = 1,
    num_classes: int = 62,
    cmap: str = "nipy_spectral",
) -> None:
    """Plot ``num_samples`` randomly chosen (image, mask) pairs from ``dataset``.

    The image tensor is assumed to be ImageNet-normalized (C, H, W), as
    produced by ``SegmentationDataset``'s default transforms; it is
    denormalized and clipped to [0, 1] for display. The mask is drawn with a
    ``num_classes``-entry colormap so colours are comparable across samples.

    Args:
        dataset: Indexable dataset returning ``(image_tensor, mask_tensor)``.
        num_samples: Number of random samples to draw (with replacement).
        num_classes: Number of segmentation classes (sets colormap size and range).
        cmap: Matplotlib colormap name for the mask.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("display_data_sample: dataset is empty")

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    colormap = plt.get_cmap(cmap, num_classes)

    for _ in range(num_samples):
        idx = np.random.randint(0, len(dataset))
        image, mask = dataset[idx]

        # Undo Normalize(mean, std) so imshow receives values in [0, 1].
        image_np = np.clip(image.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1)
        mask_np = mask.cpu().numpy()

        fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 4))

        ax_img.imshow(image_np)
        ax_img.set_title("Image")
        ax_img.axis("off")

        im = ax_mask.imshow(mask_np, cmap=colormap, vmin=0, vmax=num_classes - 1)
        ax_mask.set_title("Segmentation Mask")
        ax_mask.axis("off")
        fig.colorbar(im, ax=ax_mask, fraction=0.046, pad=0.04, ticks=range(0, num_classes, 5))

        fig.tight_layout()
        plt.show()
        # Release the figure; many samples otherwise accumulate open figures.
        plt.close(fig)

In [None]:
class EarlyStopping:
    """Early stopping on a validation metric where higher is better (e.g. IoU).

    On each call, ``val_metric`` is compared with the best score so far. A
    value ``>= best_score + delta`` counts as an improvement: the model's
    ``state_dict`` is saved to ``path`` and the patience counter resets.
    Otherwise the counter increments, and ``early_stop`` becomes True once
    it reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, report counter/improvement messages via ``trace_func``.
        delta: Minimum increase over the best score that counts as improvement.
        path: File the best ``state_dict`` is written to with ``torch.save``.
        trace_func: Callable used for verbose messages.
    """

    def __init__(self, patience=10, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Metric value of the last saved checkpoint. Despite the name, it is a
        # running maximum. `np.Inf` was removed in NumPy 2.0; `np.inf` works
        # on all versions.
        self.val_metric_min = -np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_metric, model):
        """Record ``val_metric``; checkpoint on improvement, otherwise count towards ``patience``."""
        score = val_metric
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_metric, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_metric, model)
            self.counter = 0

    def save_checkpoint(self, val_metric, model):
        """Save ``model.state_dict()`` to ``self.path`` and remember ``val_metric``."""
        if self.verbose:
            self.trace_func(f'Validation metric improved ({self.val_metric_min:.6f} --> {val_metric:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_metric_min = val_metric

In [None]:
# Batch size shared by the training and validation loaders.
BS = 6

# Settings common to both loaders; num_workers=0 loads batches in the main process.
_loader_kwargs = dict(batch_size=BS, num_workers=0)

train_ds = SegmentationDataset(root_path=dataset_dir, split='train')
train_dataloader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)

val_ds = SegmentationDataset(root_path=dataset_dir, split='val')
val_dataloader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Sanity-check the training augmentations on a handful of random samples.
display_data_sample(dataset=train_ds, num_samples=5)

In [None]:
# Release cached, unused GPU memory before building the model (no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Train DeepLabV3+ (EfficientNet-B5) with a 50/50 Dice + cross-entropy loss,
# cosine LR decay and early stopping on validation mean IoU. The best
# checkpoint (by val mIoU) is written to `weights_path` and reloaded at the end.
os.makedirs(output_dir, exist_ok=True)
weights_path = os.path.join(output_dir, 'best_model.pth')
print(f"Model will be save in: {weights_path}")

EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 62
# Mixed precision is only enabled on CUDA (the previous torch.cuda.amp helpers
# disabled themselves with a warning on CPU).
USE_AMP = DEVICE.type == "cuda"

print("Using DeepLabV3+ with EfficientNet-B5 backbone")
model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b5",
    encoder_weights="imagenet",
    classes=NUM_CLASSES,
    activation=None,
).to(DEVICE)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

loss1 = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
loss2 = nn.CrossEntropyLoss()
criterion = lambda p, t: 0.5 * loss1(p, t) + 0.5 * loss2(p, t)

patience = 15
early_stopping = EarlyStopping(patience=patience, verbose=True, path=weights_path)

# torch.amp replaces the deprecated torch.cuda.amp.GradScaler / autocast.
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)


def _macro_iou_from_counts(tp, fp, fn):
    """Mean IoU over classes with tp + fp + fn > 0, from per-class count tensors.

    Classes that were neither predicted nor present are skipped instead of
    being scored as a perfect 1.0. Returns NaN if no class occurs.
    """
    denom = tp + fp + fn
    present = denom > 0
    if not bool(present.any()):
        return float("nan")
    return (tp[present].double() / denom[present].double()).mean().item()


train_losses = []
val_losses = []

for e in range(EPOCHS):
    model.train()
    running_train_loss = 0.0

    train_loop = tqdm(train_dataloader, desc=f"Epoch {e+1}/{EPOCHS} [Training]", leave=False)
    for image, mask in train_loop:
        image = image.to(DEVICE)
        mask = mask.to(DEVICE).long()

        optimizer.zero_grad()

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            output = model(image)
            train_loss = criterion(output, mask)

        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_train_loss += train_loss.item()
        train_loop.set_postfix(loss=train_loss.item())

    avg_train_loss = running_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)

    model.eval()
    running_val_loss = 0.0
    # Per-class TP/FP/FN summed over the whole validation split. Averaging a
    # per-batch macro IoU (the previous approach) scored every class absent
    # from a batch as IoU = 1 (smp's default zero_division), which inflated
    # the metric that drives checkpoint selection and early stopping.
    tp_total = torch.zeros(NUM_CLASSES, dtype=torch.long)
    fp_total = torch.zeros_like(tp_total)
    fn_total = torch.zeros_like(tp_total)

    val_loop = tqdm(val_dataloader, desc=f"Epoch {e+1}/{EPOCHS} [Validation]", leave=False)
    with torch.no_grad():
        for image, mask in val_loop:
            image = image.to(DEVICE)
            mask = mask.to(DEVICE).long()

            with torch.amp.autocast("cuda", enabled=USE_AMP):
                output = model(image)
                val_loss = criterion(output, mask)

            running_val_loss += val_loss.item()
            preds = torch.argmax(output, dim=1)

            tp, fp, fn, _ = smp.metrics.get_stats(preds, mask, mode='multiclass', num_classes=NUM_CLASSES)
            # get_stats returns (batch, classes) counts; fold into per-class totals on CPU.
            tp_total += tp.sum(dim=0).cpu()
            fp_total += fp.sum(dim=0).cpu()
            fn_total += fn.sum(dim=0).cpu()

            val_loop.set_postfix(
                val_loss=val_loss.item(),
                val_iou=_macro_iou_from_counts(tp_total, fp_total, fn_total),
            )

    avg_val_loss = running_val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)

    epoch_val_iou = _macro_iou_from_counts(tp_total, fp_total, fn_total)

    print(f"Epoch: {e+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val IoU: {epoch_val_iou:.4f}")

    scheduler.step()

    early_stopping(epoch_val_iou, model)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

print(f"Training finished. Loading best model weights from {weights_path}")
model.load_state_dict(torch.load(weights_path, map_location=DEVICE, weights_only=True))

In [None]:
# Per-epoch loss curves. Each series uses its own 1-based epoch range, so the
# x-axis always matches the data being plotted.
for losses, title in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    ax = sns.lineplot(x=range(1, len(losses) + 1), y=losses)
    ax.set(title=title, xlabel='Epoch', ylabel='Loss')
    plt.show()

In [None]:
# Rebuild the architecture and load the best checkpoint saved by the training cell.
NUM_CLASSES = 62
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Derived from `output_dir` rather than repeating the literal, so this always
# points at the checkpoint the training cell wrote.
output_path = os.path.join(output_dir, "best_model.pth")


model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b5",
    encoder_weights=None,  # all weights come from the checkpoint below
    classes=NUM_CLASSES,
    activation=None,
)

# The file holds only a state_dict, so weights_only=True is sufficient and
# refuses arbitrary pickled objects.
model.load_state_dict(torch.load(output_path, map_location=DEVICE, weights_only=True))

model.to(DEVICE)
model.eval()

In [None]:
# Evaluate the loaded checkpoint on the validation split with class-averaged
# (macro) metrics.
#
# A per-class confusion matrix is accumulated over the whole split. IoU, F1,
# precision and recall are then computed per class and averaged over the
# classes that occur in either the ground truth or the predictions.
#
# The previous version summed TP/FP/FN over the class axis as well as the
# pixel axes. With one-hot tensors that collapses every score into plain
# pixel accuracy (precision == recall == F1 == accuracy, IoU == acc/(2-acc)).
# It also averaged per-batch values, which over-weights a smaller last batch.
num_batches = len(val_dataloader)  # kept for reference; metrics are dataset-level

num_classes = 62  # or model.output_shape[1]


def _safe_ratio(num, den):
    """Elementwise ``num / den``, with 0 wherever ``den == 0``."""
    return torch.where(den > 0, num / den.clamp(min=1e-12), torch.zeros_like(num))


def _macro_scores(tp, fp, fn, beta=1):
    """Macro (IoU, F-beta, precision, recall) from per-class TP/FP/FN counts.

    Classes with tp + fp + fn == 0 are excluded from the average. A class that
    was never predicted gets precision 0, and one absent from the ground truth
    gets recall 0. Returns Python floats (all NaN if no class occurs).
    """
    tp, fp, fn = tp.double(), fp.double(), fn.double()
    present = (tp + fp + fn) > 0
    if not bool(present.any()):
        return (float("nan"),) * 4
    tp, fp, fn = tp[present], fp[present], fn[present]
    b2 = beta ** 2
    iou = _safe_ratio(tp, tp + fp + fn)
    fscore = _safe_ratio((1 + b2) * tp, (1 + b2) * tp + b2 * fn + fp)
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    return tuple(t.mean().item() for t in (iou, fscore, precision, recall))


def _onehot_counts(preds, targets):
    """Per-class TP/FP/FN from one-hot (B, C, H, W) tensors, summed over batch and pixels."""
    dims = (0, 2, 3)
    tp = (preds * targets).sum(dim=dims)
    fp = (preds * (1 - targets)).sum(dim=dims)
    fn = ((1 - preds) * targets).sum(dim=dims)
    return tp, fp, fn


def iou_metric(preds, targets, eps=1e-6):
    """Macro IoU for one-hot (B, C, H, W) tensors. ``eps`` is kept for compatibility and unused."""
    return torch.tensor(_macro_scores(*_onehot_counts(preds, targets))[0])


def fscore_metric(preds, targets, beta=1, eps=1e-6):
    """Macro F-beta for one-hot (B, C, H, W) tensors. ``eps`` is kept for compatibility and unused."""
    return torch.tensor(_macro_scores(*_onehot_counts(preds, targets), beta=beta)[1])


def precision_metric(preds, targets, eps=1e-6):
    """Macro precision for one-hot (B, C, H, W) tensors. ``eps`` is kept for compatibility and unused."""
    return torch.tensor(_macro_scores(*_onehot_counts(preds, targets))[2])


def recall_metric(preds, targets, eps=1e-6):
    """Macro recall for one-hot (B, C, H, W) tensors. ``eps`` is kept for compatibility and unused."""
    return torch.tensor(_macro_scores(*_onehot_counts(preds, targets))[3])


# confusion[t, p] = number of pixels with ground truth t predicted as p.
confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)

with torch.no_grad():
    for image, mask in tqdm(val_dataloader, desc="Validation Progress", unit="batch"):
        image = image.to(DEVICE)
        mask = mask.to(DEVICE).long()  # shape: [B, H, W]

        # argmax of the logits equals argmax of softmax(logits), so no softmax is needed.
        preds = torch.argmax(model(image), dim=1)  # shape: [B, H, W]

        if mask.min() < 0 or mask.max() >= num_classes:
            raise ValueError(f"mask contains labels outside [0, {num_classes - 1}]")

        pair_ids = (mask * num_classes + preds).flatten()
        confusion += torch.bincount(pair_ids, minlength=num_classes ** 2).reshape(num_classes, num_classes).cpu()

tp = confusion.diag()
fp = confusion.sum(dim=0) - tp  # predicted as class c, ground truth is another class
fn = confusion.sum(dim=1) - tp  # ground truth is class c, predicted as another class

val_iou, val_fscore, val_precision, val_recall = _macro_scores(tp, fp, fn)

print(f"Validation IoU:      {val_iou:.4f}")
print(f"Validation F1 Score: {val_fscore:.4f}")
print(f"Validation Precision:{val_precision:.4f}")
print(f"Validation Recall:   {val_recall:.4f}")

In [None]:
def display_val_predictions(
    model,
    dataset: Dataset,
    num_samples: int = 3,
    num_classes: int = 62,
    cmap: str = "nipy_spectral"
) -> None:
    """Show input image, ground-truth mask and model prediction for random samples.

    The dataset is expected to yield ImageNet-normalized image tensors
    (C, H, W), as ``SegmentationDataset``'s default transforms produce. The
    image is denormalized for display; previously it was passed to ``imshow``
    still normalized, so colours were clipped and garbled. Inputs are moved
    to the device of ``model``'s parameters.

    Args:
        model: Segmentation network returning (N, C, H, W) logits.
        dataset: Indexable dataset returning ``(image_tensor, mask_tensor)``.
        num_samples: Number of random samples to show (with replacement).
        num_classes: Number of classes (fixes the colour range of both masks).
        cmap: Matplotlib colormap name for the masks.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    device = next(model.parameters()).device
    model.eval()

    for _ in range(num_samples):
        idx = np.random.randint(0, len(dataset))
        image, mask = dataset[idx]

        with torch.no_grad():
            logits = model(image.unsqueeze(0).to(device))
        pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        # Undo Normalize(mean, std) so imshow receives values in [0, 1].
        image_np = np.clip(image.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1)
        mask_np = mask.squeeze().cpu().numpy()

        fig, (ax_in, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))

        ax_in.imshow(image_np)
        ax_in.set_title("Input Image")

        ax_gt.imshow(mask_np, cmap=cmap, vmin=0, vmax=num_classes - 1)
        ax_gt.set_title("Ground Truth")

        ax_pred.imshow(pred_mask, cmap=cmap, vmin=0, vmax=num_classes - 1)
        ax_pred.set_title("Prediction")

        for ax in (ax_in, ax_gt, ax_pred):
            ax.axis("off")

        fig.tight_layout()
        plt.show()
        plt.close(fig)

In [None]:
# Qualitative check: compare ground truth and predictions on random validation samples.
display_val_predictions(model=model, dataset=val_ds, num_samples=20, num_classes=62)

In [None]:
def remove_small_objects(pred_mask, min_size_pixels, connectivity=8):
    """Zero out connected foreground regions smaller than ``min_size_pixels``.

    Components are found on the foreground of ``pred_mask`` (non-zero pixels
    after a uint8 cast), not per class. Touching regions of different classes
    therefore form a single component.
    NOTE(review): if per-class filtering is intended, apply this per class.
    Values >= 256 would wrap in the uint8 cast; this is presumably fine for
    62 classes.

    Args:
        pred_mask: 2-D integer class map.
        min_size_pixels: Minimum component area (in pixels) to keep.
        connectivity: 4 or 8, forwarded to ``cv2.connectedComponentsWithStats``.

    Returns:
        Array with the same shape and dtype as ``pred_mask``. Kept components
        retain their original values; everything else is 0.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        pred_mask.astype(np.uint8), connectivity=connectivity
    )
    # Lookup table indexed by component label; label 0 is the background.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size_pixels
    keep[0] = False

    # One vectorized pass instead of rescanning the image for every component.
    kept_pixels = keep[labels]
    cleaned_mask = np.zeros_like(pred_mask)
    cleaned_mask[kept_pixels] = pred_mask[kept_pixels]
    return cleaned_mask


def inference_on_folder(
    model,
    test_dir: str,
    save_dir: str = 'predictions',
    device: str = 'cuda',
    plot: bool = False,
    use_tta: bool = True,
    area_thres: int = 100
) -> None:
    """Predict a class map for every ``*.png`` in ``test_dir`` and save it as ``.npy``.

    Each image is resized to 640x640 and ImageNet-normalized (matching the
    validation transform), then passed through ``model``. The class scores
    are resized back to the image's original height and width before
    ``argmax``, so each saved mask has the same spatial size as its source
    image; this is a no-op for 640x640 inputs. Masks are written to
    ``<save_dir>/<image stem>.npy``.

    Args:
        model: Segmentation network returning (N, C, H, W) logits.
        test_dir: Folder containing the ``*.png`` images.
        save_dir: Output folder for the ``.npy`` masks (created if missing).
        device: Device the model and inputs are moved to.
        plot: If True, show each input next to its prediction.
        use_tta: If True, average the softmax probabilities of the image and
            its horizontal flip (the flip is undone before averaging).
        area_thres: Minimum region size for ``remove_small_objects``. Currently
            unused because that post-processing step is disabled.

    Raises:
        FileNotFoundError: If an image in ``test_dir`` cannot be read.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.to(device)
    model.eval()

    image_paths = sorted(glob.glob(os.path.join(test_dir, '*.png')))

    base_transform = A.Compose([
        A.Resize(640, 640),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

    desc = "Inferencing with TTA" if use_tta else "Inferencing"
    for img_path in tqdm(image_paths, desc=desc):
        image_name = Path(img_path).stem
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_hw = image.shape[:2]

        with torch.no_grad():
            if use_tta:
                input_batch = torch.stack([
                    base_transform(image=image)['image'],
                    base_transform(image=cv2.flip(image, 1))['image'],
                ]).to(device)
                probs = torch.softmax(model(input_batch), dim=1)
                # Undo the horizontal flip on the second prediction, then average.
                scores = (probs[0] + torch.flip(probs[1], dims=[-1])) / 2.0
            else:
                input_tensor = base_transform(image=image)['image'].unsqueeze(0).to(device)
                scores = model(input_tensor).squeeze(0)  # raw logits; argmax is unaffected

            # Map the (C, 640, 640) scores back to the source resolution.
            if tuple(scores.shape[-2:]) != tuple(orig_hw):
                scores = torch.nn.functional.interpolate(
                    scores.unsqueeze(0), size=orig_hw, mode='bilinear', align_corners=False
                ).squeeze(0)
            pred = torch.argmax(scores, dim=0).cpu().numpy()

        # Optional post-processing (currently disabled):
        # pred = remove_small_objects(pred, min_size_pixels=area_thres)

        save_path = os.path.join(save_dir, f"{image_name}.npy")
        np.save(save_path, pred)

        if plot:
            plt.figure(figsize=(8, 4))
            plt.subplot(1, 2, 1)
            plt.imshow(image)
            plt.title("Input Image")
            plt.axis('off')

            plot_title = "Prediction (with TTA)" if use_tta else "Prediction"
            plt.subplot(1, 2, 2)
            plt.imshow(pred, cmap='nipy_spectral')
            plt.title(plot_title)
            plt.axis('off')

            plt.tight_layout()
            plt.show()
            plt.close()

In [None]:
# Predict the test split with flip-TTA and store one .npy mask per image on Drive.
inference_on_folder(
    model,
    os.path.join(dataset_dir, 'test', 'images'),
    save_dir=os.path.join(output_dir, 'test_predictions'),
    device=DEVICE,
    plot=False,
    use_tta=True,
    area_thres=100,
)

In [None]:
from time import sleep

# Wait 60 s before releasing the Colab VM (presumably to let pending Drive
# writes finish — NOTE(review): confirm the delay is sufficient).
SHUTDOWN_GRACE_SECONDS = 60
sleep(SHUTDOWN_GRACE_SECONDS)

from google.colab import runtime
runtime.unassign()