In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import glob
import os
import numpy as np

class MITBuildingsDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Images and masks are paired purely by position after sorting the file
    paths of each directory, so both directories are expected to contain the
    same number of files whose names sort in the same order.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the segmentation masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.

    Raises:
        AssertionError: If the number of images and masks differ.
    """

    # Glob patterns collected from both the image and the mask directory.
    _PATTERNS = ("*.png", "*.jpg", "*.jpeg", "*.tiff")

    def __init__(self, image_dir, mask_dir, transform=None):

        # --- Collect all image & mask paths ---
        # Both lists go through the same helper so the mask list can never be
        # built from the image directory by mistake (the original globbed
        # "*.tiff" masks from image_dir).
        self.image_paths = self._collect_paths(image_dir)
        self.mask_paths = self._collect_paths(mask_dir)

        self.transform = transform

        # --- Ensure equal number of files ---
        assert len(self.image_paths) == len(self.mask_paths), \
            f"Found {len(self.image_paths)} images but {len(self.mask_paths)} masks"

    @classmethod
    def _collect_paths(cls, directory):
        """Return the sorted paths in ``directory`` matching any supported pattern."""
        found = []
        for pattern in cls._PATTERNS:
            found.extend(glob.glob(os.path.join(directory, pattern)))
        return sorted(found)

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(img, mask)``.

        Without a transform, ``img`` is the ``ToTensor`` output of the RGB
        image resized to 512x512 and ``mask`` is a float32 tensor of shape
        ``(1, 512, 512)`` containing only 0.0 and 1.0.
        """
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # --- Load image and mask ---
        # convert() returns a fully loaded copy, so the file handles can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # --- Resize to 512x512 ---
        img = img.resize((512, 512))
        mask = mask.resize((512, 512))

        # --- Convert to tensor ---
        img = transforms.ToTensor()(img)

        mask = np.array(mask, dtype=np.float32)
        mask = mask / 255.0  # normalize to the 0-1 range
        mask = np.expand_dims(mask, axis=0)  # add the channel axis
        mask = torch.tensor(mask, dtype=torch.float32)
        mask = (mask > 0.5).float()  # binarize

        # --- Apply optional transforms ---
        # NOTE(review): the keyword-call / dict-return convention looks
        # albumentations-style, but the inputs at this point are torch
        # tensors, not numpy arrays - confirm the transform accepts them.
        if self.transform:
            transformed = self.transform(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        return img, mask


In [None]:
import os
import random
import shutil
from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# =========================================================
# 1. U-NET MODEL
# =========================================================
class UNet(nn.Module):
    """Four-level U-Net that maps an image batch to per-pixel logits.

    The encoder applies four double-convolution stages (64, 128, 256, 512
    channels), each followed by 2x2 max pooling, then a 1024-channel
    bottleneck. The decoder mirrors it with stride-2 transposed convolutions,
    concatenating the matching encoder output before every double
    convolution. The final 1x1 convolution emits raw logits; no sigmoid is
    applied.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels in the output logits.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        def double_conv(n_in, n_out):
            # Two (3x3 conv -> batch norm -> ReLU) stages; padding=1 keeps
            # the spatial size unchanged.
            stages = []
            for stage_in in (n_in, n_out):
                stages.append(nn.Conv2d(stage_in, n_out, kernel_size=3, padding=1))
                stages.append(nn.BatchNorm2d(n_out))
                stages.append(nn.ReLU(inplace=True))
            return nn.Sequential(*stages)

        # Contracting path.
        self.enc1 = double_conv(in_channels, 64)
        self.enc2 = double_conv(64, 128)
        self.enc3 = double_conv(128, 256)
        self.enc4 = double_conv(256, 512)
        self.pool = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = double_conv(512, 1024)

        # Expanding path: each decoder block takes twice the up-sampled
        # channel count because the skip connection is concatenated first.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = double_conv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = double_conv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = double_conv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = double_conv(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        # Encoder: keep every pre-pooling activation for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        feat = self.bottleneck(self.pool(skip4))

        # Decoder: upsample, concatenate the skip on the channel axis, refine.
        feat = self.dec4(torch.cat([self.upconv4(feat), skip4], dim=1))
        feat = self.dec3(torch.cat([self.upconv3(feat), skip3], dim=1))
        feat = self.dec2(torch.cat([self.upconv2(feat), skip2], dim=1))
        feat = self.dec1(torch.cat([self.upconv1(feat), skip1], dim=1))

        # Raw logits; no sigmoid here.
        return self.final(feat)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os
import matplotlib.pyplot as plt

def calculate_accuracy(preds, masks, from_logits=True, threshold=0.5):
    """Compute pixel-wise binary segmentation accuracy.

    The U-Net in this file returns raw logits (it is trained with
    ``BCEWithLogitsLoss``), so the predictions must be passed through a
    sigmoid before they are compared with a probability threshold.
    Thresholding the logits themselves at 0.5 would only count pixels whose
    probability exceeds roughly 0.62 as positive.

    Args:
        preds: Model output tensor, same shape as ``masks``.
        masks: Ground-truth tensor of 0/1 values.
        from_logits: When True (default) ``preds`` are raw logits and a
            sigmoid is applied first; pass False if ``preds`` already are
            probabilities.
        threshold: Probability above which a pixel is predicted positive.

    Returns:
        The fraction of elements whose binarized prediction equals the mask,
        as a Python float.
    """
    probs = torch.sigmoid(preds) if from_logits else preds
    preds_bin = (probs > threshold).float()
    correct = (preds_bin == masks).float().sum()
    total = masks.numel()
    return (correct / total).item()

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    The model is put in training mode and updated when ``optimizer`` is given;
    otherwise it is put in eval mode and run without gradient tracking. Both
    means are averages of the per-batch values.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"Cannot run '{desc}' pass: the data loader is empty")

    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, acc_sum = 0.0, 0.0

    with torch.set_grad_enabled(is_training):
        for imgs, masks in tqdm(loader, desc=desc, leave=False):
            imgs, masks = imgs.to(device), masks.to(device)

            if is_training:
                optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, masks)
            if is_training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            acc_sum += calculate_accuracy(preds, masks)

    return loss_sum / num_batches, acc_sum / num_batches


def _plot_history(history):
    """Plot the loss and accuracy curves stored in ``history`` side by side."""
    epochs_range = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(12,5))
    # Loss
    plt.subplot(1,2,1)
    plt.plot(epochs_range, history['train_loss'], label='Train Loss')
    plt.plot(epochs_range, history['val_loss'], label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training & Validation Loss')
    plt.legend()
    # Accuracy (stored as fractions, shown as percentages)
    plt.subplot(1,2,2)
    plt.plot(epochs_range, [a*100 for a in history['train_acc']], label='Train Acc')
    plt.plot(epochs_range, [a*100 for a in history['val_acc']], label='Val Acc')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Training & Validation Accuracy')
    plt.legend()
    plt.show()


def train_model(model, train_loader, val_loader, device, epochs=20, lr=1e-4, save_dir="models"):
    """
    Train U-Net model for binary segmentation and plot loss & accuracy curves.

    Uses ``BCEWithLogitsLoss`` and Adam. After every epoch the model is
    evaluated on ``val_loader``; whenever the mean validation loss improves,
    the ``state_dict`` is written to ``<save_dir>/unet_best.pth``.

    Args:
        model: Network returning raw logits shaped like the masks.
        train_loader: Iterable of ``(imgs, masks)`` batches used for training.
        val_loader: Iterable of ``(imgs, masks)`` batches used for validation.
        device: Device the model and the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        save_dir: Directory for the best checkpoint (created if missing).

    Returns:
        dict with the per-epoch lists ``train_loss``, ``val_loss``,
        ``train_acc`` and ``val_acc`` (accuracies are fractions in [0, 1]).

    Raises:
        ValueError: If a loader yields no batches when an epoch is run.
    """
    os.makedirs(save_dir, exist_ok=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.to(device)
    best_val_loss = float('inf')

    # Store metrics
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    for epoch in range(1, epochs + 1):
        print(f"\nüöÄ Epoch {epoch}/{epochs}")

        # ----- Training -----
        avg_train_loss, avg_train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc="Training"
        )

        # ----- Validation -----
        avg_val_loss, avg_val_acc = _run_epoch(
            model, val_loader, criterion, device, desc="Validating"
        )

        # Save metrics
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_acc'].append(avg_val_acc)

        print(f"üìâ Train Loss: {avg_train_loss:.4f} | üßæ Val Loss: {avg_val_loss:.4f}")
        print(f"‚úÖ Train Acc: {avg_train_acc*100:.2f}% | üîç Val Acc: {avg_val_acc*100:.2f}%")

        # ----- Save best model -----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_path = os.path.join(save_dir, "unet_best.pth")
            torch.save(model.state_dict(), save_path)
            print(f"üíæ Best model saved at: {save_path}")

    print("\nüéØ Training complete!")

    # ----- Plot Loss & Accuracy -----
    _plot_history(history)

    return history


In [None]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import transforms

def predict_single_image(model_path, image_path, device=None, output_path="output_mask.png",
                         image_size=(256, 256)):
    """Predict a binary mask for one image, display it and save it to disk.

    Loads ``UNet(in_channels=3, out_channels=1)`` weights from ``model_path``,
    runs the resized image through the network, applies a sigmoid and a 0.5
    threshold, shows the input next to the mask and writes the mask
    (0/255, uint8) to ``output_path``.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for the U-Net.
        image_path: Path of the image to segment.
        device: Torch device; defaults to CUDA when available, else CPU.
        output_path: Destination file for the predicted mask. Its parent
            directory is created when the path contains one.
        image_size: ``(height, width)`` the image is resized to before
            inference. NOTE(review): the default stays 256x256 for backward
            compatibility, while MITBuildingsDataset resizes to 512x512 -
            consider passing ``(512, 512)`` to match training. Sizes must be
            divisible by 16 for the four pooling levels to line up with the
            skip connections.

    Raises:
        OSError: If OpenCV reports that the mask could not be written.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = UNet(in_channels=3, out_channels=1)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Preprocess image
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])
    # convert() returns a loaded copy, so the file can be closed right away.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)

    # Predict mask
    with torch.no_grad():
        pred = model(tensor)
        pred = torch.sigmoid(pred)  # convert logits to probabilities
        pred_mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255

    # Display input and mask
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    plt.imshow(img)
    plt.title("Input Image")
    plt.axis("off")

    plt.subplot(1,2,2)
    plt.imshow(pred_mask, cmap="gray")
    plt.title("Predicted Mask")
    plt.axis("off")
    plt.show()

    # Save predicted mask
    # os.path.dirname() is "" for a bare file name (the default output_path)
    # and os.makedirs("") raises FileNotFoundError, so only create the
    # directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, pred_mask):
        raise OSError(f"Could not write predicted mask to: {output_path}")
    print(f"‚úÖ Predicted mask saved at: {output_path}")


In [None]:
# Entry point: build the train/validation datasets, wrap them in DataLoaders
# and train a freshly initialised U-Net via train_model().
# The names assigned below (device, model, history, loaders, ...) are
# module-level, so they stay available after this block has run.
if __name__ == "__main__":
    import os
    import torch
    from torch.utils.data import DataLoader

    # --- Device setup ---
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device="cpu"
    print(f"Using device: {device}")

    # --- Paths ---
    # All paths are relative to the current working directory.
    train_img_dir = "Data/train"
    train_mask_dir = "Data/train_mask"
    val_img_dir = "Data/validate"
    val_mask_dir = "Data/validate_mask"
    # model_save_path = "models/unet_model.pth"
    # NOTE(review): prediction_input_dir is not read anywhere in the visible
    # code - presumably intended for a later batch-prediction step; confirm.
    prediction_input_dir = "images"               # all combined tiles
    prediction_output_dir = "output/predicted_masks"

    # --- Ensure model/output dirs exist ---
    # "models" matches train_model's default save_dir.
    os.makedirs("models", exist_ok=True)
    os.makedirs(prediction_output_dir, exist_ok=True)

    # --- Load datasets ---
    # No transform is passed, so samples only get the dataset's built-in
    # resize / tensor conversion.
    train_dataset = MITBuildingsDataset(train_img_dir, train_mask_dir)
    val_dataset = MITBuildingsDataset(val_img_dir, val_mask_dir)

    # Only the training batches are shuffled; validation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # --- Initialize U-Net model ---
    # 3 input channels (RGB) and a single output channel of logits.
    model = UNet(in_channels=3, out_channels=1)

    # --- Train the model ---
    # NOTE(review): epochs=1 looks like a smoke-test setting (train_model's
    # default is 20) - confirm before a real training run.
    history=train_model(model, train_loader, val_loader, device, epochs=1, lr=1e-4)
    
