In [1]:
import os
import cv2
import torch as T
from torch.utils.data import Dataset, DataLoader
import numpy as np
from glob import glob
import re
import torch.nn as nn
import timm
import torch.nn.functional as F
from sklearn.cluster import KMeans
from tqdm import tqdm

In [2]:
# Root folder of the dataset splits. The default is the original local checkout;
# set the CWF788_IMAGE_DIR environment variable to run the notebook on another
# machine without editing this cell.
BASE_PATH = os.environ.get(
    "CWF788_IMAGE_DIR",
    r"D:\AAU Internship\Code\CWF-788\IMAGE512x384",
)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = T.device("cuda" if T.cuda.is_available() else "cpu")
print(f"🚀 Using device: {DEVICE}")

🚀 Using device: cuda


In [3]:
class CropWeedDataset(Dataset):
    """Image-only dataset over the ``*.jpg`` and ``*.png`` files of one folder.

    Paths are ordered with a natural sort, so ``img2`` comes before ``img10``.
    Each item is ``(image_tensor, image_path)``; no labels are loaded.
    """

    def __init__(self, folder_path):
        """Collect and naturally sort the image paths found in ``folder_path``.

        Args:
            folder_path: Directory searched (non-recursively) for ``*.jpg``
                and ``*.png`` files.
        """
        self.image_paths = glob(os.path.join(folder_path, "*.jpg"))
        self.image_paths += glob(os.path.join(folder_path, "*.png"))

        def natural_key(string):
            # Split into alternating text / digit runs and compare the digit
            # runs as integers so numeric file names sort in numeric order.
            return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string)]

        self.image_paths.sort(key=natural_key)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image as a float32 RGB tensor scaled to [0, 1].

        Args:
            idx: Index into the naturally sorted path list.

        Returns:
            ``(tensor, path)`` where ``tensor`` is channel-first (C, H, W) and
            already moved to the module-level ``DEVICE``, and ``path`` is the
            file it was read from.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than with an opaque cvtColor error.
        if img is None:
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        img_tensor = T.from_numpy(img).permute(2, 0, 1).contiguous()

        # NOTE(review): moving to DEVICE inside the dataset is kept because code
        # elsewhere may rely on batches arriving on DEVICE; it is likely to
        # break if the DataLoader is switched to num_workers > 0 with CUDA.
        return img_tensor.to(DEVICE), img_path

def get_loader(folder_path, batch_size=4):
    """Wrap the images in ``folder_path`` in a sequential DataLoader.

    Args:
        folder_path: Directory handed to ``CropWeedDataset``.
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` that yields batches in dataset order (no shuffling).
    """
    return DataLoader(
        CropWeedDataset(folder_path),
        batch_size=batch_size,
        shuffle=False,
    )

# Resolve the three split folders under BASE_PATH, in train / validation / test order.
train_path, val_path, test_path = (
    os.path.join(BASE_PATH, split_dir)
    for split_dir in ("train_new", "validation_new", "test_new")
)
# One unshuffled loader per split, built with get_loader's default batch size.
train_loader, val_loader, test_loader = (
    get_loader(split_path)
    for split_path in (train_path, val_path, test_path)
)

In [4]:
def compute_fused_contours(batch_imgs):
    """Build one edge map per image by fusing Sobel magnitude with Canny edges.

    Args:
        batch_imgs: Tensor indexed as (batch, channel, height, width) holding
            RGB images. Values are scaled by 255 before edge detection, so
            they are expected to lie in [0, 1].

    Returns:
        float32 tensor of shape (batch, 1, height, width) on
        ``batch_imgs.device``: the element-wise maximum of the max-normalised
        Sobel gradient magnitude and the binary Canny map, both in [0, 1].
    """
    T.cuda.empty_cache()

    # Nothing to stack for an empty batch; return an empty map with the same
    # spatial size and device instead of failing in T.stack([]).
    if batch_imgs.shape[0] == 0:
        empty_shape = (0, 1) + tuple(batch_imgs.shape[-2:])
        return batch_imgs.new_zeros(empty_shape, dtype=T.float32)

    fused_maps = []
    batch_imgs_np = batch_imgs.detach().cpu().numpy()

    for img in batch_imgs_np:
        # CHW float -> HWC uint8. Round and clip before the cast: a plain
        # astype(uint8) truncates (a value that lands just below an integer
        # after the float round-trip drops by one) and wraps out-of-range values.
        img_rgb = np.transpose(img, (1, 2, 0)) * 255.0
        img_rgb = np.ascontiguousarray(np.clip(np.rint(img_rgb), 0, 255).astype(np.uint8))
        gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)

        # Sobel gradient magnitude, normalised by its own maximum.
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        sobel = cv2.magnitude(sobelx, sobely)
        peak = sobel.max()
        if peak > 0:
            sobel = np.clip(sobel / peak, 0, 1)
        else:
            # A perfectly flat image has zero gradient everywhere; dividing by
            # the zero maximum would fill the whole map with NaN.
            sobel = np.zeros_like(sobel)

        # Canny returns 0 / 255; rescale to 0 / 1.
        canny = cv2.Canny(gray, 100, 200) / 255.0

        # Fuse Sobel + Canny: keep the stronger response at each pixel.
        fused = np.maximum(sobel, canny)

        fused_tensor = T.tensor(fused, dtype=T.float32).unsqueeze(0)
        fused_maps.append(fused_tensor)

    fused_batch = T.stack(fused_maps).to(batch_imgs.device)
    T.cuda.empty_cache()
    return fused_batch

In [5]:
class HRNetW32FeatureExtractor(nn.Module):
    """Frozen, pretrained HRNet-W32 (from timm) used as a dense feature extractor.

    The backbone is created with ``features_only=True`` and returns a list of
    feature maps. One of them (``selected_feature_idx``) is upsampled back to
    the spatial size of the input, so the result can be concatenated with
    per-pixel maps of the same size.
    """

    def __init__(self, device="cuda"):
        """Create the backbone, move it to ``device`` and freeze it.

        Args:
            device: Device the backbone is moved to.
        """
        super().__init__()
        self.device = device

        # Load pretrained HRNet-W32 from timm
        self.backbone = timm.create_model(
            "hrnet_w32",
            pretrained=True,
            features_only=True
        ).to(self.device)

        # HRNet returns multi-scale outputs; index 0 is intended to be the
        # highest-resolution one.
        # NOTE(review): the channel count and stride of this map depend on the
        # timm version and feature config; check self.backbone.feature_info.
        self.selected_feature_idx = 0

        # Freeze the backbone: no gradients, and inference behaviour for
        # BatchNorm / dropout layers.
        for param in self.backbone.parameters():
            param.requires_grad_(False)
        self.backbone.eval()

    def train(self, mode=True):
        """Switch this module's mode while keeping the backbone in eval mode.

        Without this override, calling ``.train()`` on this module or on a
        parent would put the backbone back into training mode and let its
        BatchNorm statistics drift.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Extract the selected feature map and upsample it to the input size.

        Args:
            x: Input image batch indexed as (batch, channel, height, width).

        Returns:
            Tensor (batch, C, height, width) with the same height and width as
            ``x``. It is computed under ``no_grad``, so it carries no gradient
            history.
        """
        with T.no_grad():
            features = self.backbone(x)  # list of [B, C_i, H_i, W_i]
            feature_map = features[self.selected_feature_idx]

            # Upsample to the input's own spatial size (384x512 for this
            # dataset) rather than a hard-coded one, so the output always
            # lines up with per-pixel maps computed from the same input.
            feature_map_upsampled = F.interpolate(
                feature_map,
                size=tuple(x.shape[-2:]),
                mode='bicubic',  # Better edge preservation
                align_corners=False
            )

            return feature_map_upsampled


In [6]:
class ContourFeatureFusion(nn.Module):
    """Fuse CNN features with an edge map through one 3x3 conv block.

    The two inputs are concatenated along the channel axis and passed through
    Conv2d -> BatchNorm2d -> ReLU.
    """

    def __init__(self, in_channels, out_channels):
        """Build the fusion block.

        Args:
            in_channels: Channel count of the concatenated input, i.e. the
                feature channels plus the edge-map channels.
            out_channels: Channel count produced by the block.
        """
        super().__init__()
        fusion_layers = [
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_fusion = nn.Sequential(*fusion_layers)

    def forward(self, cnn_features, edge_maps):
        """Concatenate both inputs on the channel axis and apply the block.

        Args:
            cnn_features: Tensor [B, C1, H, W].
            edge_maps: Tensor [B, 1, H, W] with the same B, H, W.

        Returns:
            Tensor [B, out_channels, H, W].
        """
        stacked = T.cat((cnn_features, edge_maps), dim=1)  # [B, C1+1, H, W]
        return self.conv_fusion(stacked)


In [None]:
def segment_kmeans(fused_features, n_clusters=2):
    """Cluster the pixels of each feature map with K-Means to get a label mask.

    Every image is clustered independently, so cluster ids are arbitrary and
    not comparable between images.

    Args:
        fused_features: Tensor whose items along the first axis are (C, H, W)
            feature maps.
        n_clusters: Number of K-Means clusters per image.

    Returns:
        List with one ``uint8`` array of shape (H, W) per image, holding the
        cluster index of each pixel.
    """
    batch_np = fused_features.detach().cpu().numpy()

    def _cluster_pixels(feat):
        # Treat every pixel as one sample with C features.
        n_channels, height, width = feat.shape
        pixels = feat.reshape(n_channels, height * width).transpose()

        clusterer = KMeans(n_clusters=n_clusters, n_init='auto', random_state=42)
        pixel_labels = clusterer.fit_predict(pixels)

        # Back to image layout.
        return pixel_labels.reshape(height, width).astype(np.uint8)

    return [
        _cluster_pixels(feat)
        for feat in tqdm(batch_np, desc="🧠 Running K-Means on batch")
    ]
