# Training Unet model on synthetic data and testing it on AI4Mars dataset

## Importing required libraries

In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt


import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
print("CUDA version:", torch.version.cuda)
print("Torch version:", torch.__version__)
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


Device: cuda
CUDA version: 12.8
Torch version: 2.9.0+cu128
GPU: NVIDIA GeForce RTX 4060 Laptop GPU


In [None]:
IMG_DIR = r"/home/uname/Unity-Mars-Simulation/My project/Assets/TerrainLayerImages/Images"
MASK_DIR = r"/home/uname/Unity-Mars-Simulation/My project/Assets/TerrainLayerImages/Labels"


## The colour maping is different in synthetic and real labels

In [6]:
# --- define color mappings ---

COLOR_MAP = {
    (255,   0,   0): 0,  # red - soil
    (0,   255,   0): 1,  # green - bedrock
    (255, 235,   4): 2,  # yellow (actual Unity export)
    (128, 128, 128): 3,  # gray - bigrock
    (0,     0, 255): 4,  # blue - sky
}


# For testing masks (0,0,0), (1,1,1), (2,2,2), (3,3,3), (255,255,255)
ALT_COLOR_MAP = {
    (0, 0, 0): 0,
    (1, 1, 1): 1,
    (2, 2, 2): 2,
    (3, 3, 3): 3,
    (255, 255, 255): 4
}


def rgb_to_class(mask, color_map):
    """
    Convert RGB mask to single-channel class index mask.
    mask: HxWx3 numpy array
    color_map: dict mapping (r,g,b) -> class_id
    """
    h, w, _ = mask.shape
    class_mask = np.zeros((h, w), dtype=np.int64)
    for rgb, idx in color_map.items():
        matches = np.all(mask == rgb, axis=-1)
        class_mask[matches] = idx
    return class_mask


## Geting the Unity Scenes Dataset

In [11]:
class MarsTerrainDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None, use_alt_map=False):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

        def list_images(folder):
            valid_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
            return sorted([f for f in os.listdir(folder) if f.lower().endswith(valid_exts)])

        self.images = list_images(img_dir)
        self.masks = list_images(mask_dir)



        self.color_map = ALT_COLOR_MAP if use_alt_map else COLOR_MAP
        assert len(self.images) == len(self.masks), "Mismatch between images and masks"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = rgb_to_class(mask, self.color_map)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask


In [12]:
all_images = sorted(os.listdir(IMG_DIR))
train_imgs, val_imgs = train_test_split(all_images, test_size=0.2, random_state=42)
print(f"Training: {len(train_imgs)}, Validation: {len(val_imgs)}")


Training: 1600, Validation: 400


In [13]:
train_transforms = A.Compose([
    A.RandomCrop(512, 512),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=0.5),
    A.Normalize(mean=(0.0,0.0,0.0), std=(1.0,1.0,1.0)),
    ToTensorV2()
])

val_transforms = A.Compose([
    A.Resize(1024, 1024),
    A.Normalize(mean=(0.0,0.0,0.0), std=(1.0,1.0,1.0)),
    ToTensorV2()
])


In [14]:
train_dataset = MarsTerrainDataset(IMG_DIR, MASK_DIR, transform=train_transforms)
val_dataset   = MarsTerrainDataset(IMG_DIR, MASK_DIR, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)


In [None]:
class DoubleConv(nn.Module):
    """(Conv2d ‚Üí BN ‚Üí ReLU) √ó 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=5, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Down part
        for feat in features:
            self.downs.append(DoubleConv(in_channels, feat))
            in_channels = feat

        # Up part (reverse)
        for feat in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feat * 2, feat, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feat * 2, feat))

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final output layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


In [16]:
NUM_CLASSES = 5
model = UNet(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
print("Model initialized on:", DEVICE)


Model initialized on: cuda


## Defining Loss

In [None]:
criterion = nn.CrossEntropyLoss()              # for multi-class segmentation
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [None]:
def train_fn(loader, model, optimizer, loss_fn, device):
    model.train()
    loop = tqdm(loader, leave=False)
    total_loss = 0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device)
        targets = targets.long().to(device)

        # forward
        preds = model(data)
        loss = loss_fn(preds, targets)
        total_loss += loss.item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())

    return total_loss / len(loader)


def validate_fn(loader, model, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device)
            targets = targets.long().to(device)
            preds = model(data)
            loss = loss_fn(preds, targets)
            total_loss += loss.item()
    return total_loss / len(loader)


## Training model

In [None]:

from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os

# --- setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Model loading
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- logging and checkpoint dirs ---
log_dir = "./runs/unet_training"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)

best_val_loss = np.inf
patience = 5
counter = 0
num_epochs = 50

save_path = "best_unet_model.pth"

# --- training loop ---
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
    for images, masks in loop:
        images, masks = images.to(device, dtype=torch.float32), masks.to(device, dtype=torch.long)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    # --- validation ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device, dtype=torch.float32), masks.to(device, dtype=torch.long)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    writer.add_scalars("Loss", {"Train": train_loss, "Val": val_loss}, epoch)

    # --- early stopping ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), save_path)
        print(f"‚úÖ Validation improved; model saved to {save_path}")
    else:
        counter += 1
        print(f"‚ö†Ô∏è No improvement for {counter} epochs.")
        if counter >= patience:
            print("‚èπ Early stopping triggered.")
            break

writer.close()

Using device: cuda


                                                               

Epoch [1/50] | Train Loss: 1.3169 | Val Loss: 2.7051
‚úÖ Validation improved; model saved to best_unet_model.pth


                                                               

Epoch [2/50] | Train Loss: 1.2574 | Val Loss: 1.4072
‚úÖ Validation improved; model saved to best_unet_model.pth


                                                               

Epoch [3/50] | Train Loss: 1.2320 | Val Loss: 2.2951
‚ö†Ô∏è No improvement for 1 epochs.


                                                               

Epoch [4/50] | Train Loss: 1.2352 | Val Loss: 2.0103
‚ö†Ô∏è No improvement for 2 epochs.


                                                               

Epoch [5/50] | Train Loss: 1.2300 | Val Loss: 1.2563
‚úÖ Validation improved; model saved to best_unet_model.pth


                                                               

Epoch [6/50] | Train Loss: 1.2126 | Val Loss: 1.5633
‚ö†Ô∏è No improvement for 1 epochs.


                                                               

Epoch [7/50] | Train Loss: 1.2337 | Val Loss: 1.4939
‚ö†Ô∏è No improvement for 2 epochs.


                                                               

Epoch [8/50] | Train Loss: 1.2168 | Val Loss: 1.3838
‚ö†Ô∏è No improvement for 3 epochs.


                                                               

Epoch [9/50] | Train Loss: 1.1980 | Val Loss: 1.5552
‚ö†Ô∏è No improvement for 4 epochs.


                                                                

Epoch [10/50] | Train Loss: 1.1879 | Val Loss: 1.2752
‚ö†Ô∏è No improvement for 5 epochs.
‚èπ Early stopping triggered.


## Now that training is done, we move on to testing on AI4Mars dataset

In [None]:
import os

# ------------------------------------------------
# CONFIG
# ------------------------------------------------
IMAGE_DIR = "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/images/edr"
LABEL_DIRS = [
    "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/labels/train",
    "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/labels/test/masked-gold-min3-100agree",
]

# ------------------------------------------------
# FAST CHECK
# ------------------------------------------------
image_files = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.jpg', '.png'))]
print(f"‚úÖ Found {len(image_files)} total images in EDR folder.")

# Collect all label filenames from both folders
label_files = set()
for ld in LABEL_DIRS:
    if not os.path.exists(ld):
        print(f"‚ö†Ô∏è Warning: Label directory not found: {ld}")
        continue
    files = [f for f in os.listdir(ld) if f.lower().endswith('.png')]
    print(f"  ‚Ä¢ {len(files)} label files in {os.path.basename(ld)}")
    label_files.update(files)

print(f"‚úÖ Total unique label files: {len(label_files)}")

# Match by base name (ignore extension)
matches = []
for img in image_files:
    base = os.path.splitext(img)[0]
    for suffix in ["_merged.png", ".png"]:
        if base + suffix in label_files:
            matches.append(img)
            break

print(f"\n‚úÖ Found {len(matches)} matching image‚Äìlabel pairs.")
if matches:
    print("‚úÖ Example matches:")
    for m in matches[:10]:
        print("  ", m)
else:
    print("‚ùå No matches found ‚Äî check naming conventions (e.g. '_merged' suffix).")


‚úÖ Found 18127 total images in EDR folder.
  ‚Ä¢ 16064 label files in train
  ‚Ä¢ 322 label files in masked-gold-min3-100agree
‚úÖ Total unique label files: 16386

‚úÖ Found 16386 matching image‚Äìlabel pairs.
‚úÖ Example matches:
   NLB_546279320EDR_F0621530NCAM07753M1.JPG
   NLB_499420274EDR_F0501116NCAM00289M1.JPG
   NLB_620669790EDR_F0763002NCAM00341M1.JPG
   NLB_546462602EDR_F0622026NCAM00260M1.JPG
   NLB_509001376EDR_F0522668NCAM07753M1.JPG
   NLA_408594521EDR_F0051398NCAM05134M1.JPG
   NLB_485828380EDR_F0481530NCAM07753M1.JPG
   NLB_468426251EDR_F0441140NCAM00353M1.JPG
   NLB_559954945EDR_F0660952NCAM00270M1.JPG
   NLB_616601909EDR_F0762194NCAM00259M1.JPG


In [None]:
# ------------------------------------------------
# CONFIG
# ------------------------------------------------
IMAGE_DIR = "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/images/edr"
LABEL_DIRS = [
    "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/labels/train",
    "/home/uname/Unity-Mars-Simulation/training-and-testing/ai4mars-dataset-merged-0.1/msl/labels/test/masked-gold-min3-100agree",
]
MODEL_PATH = "best_unet_model.pth"
NUM_CLASSES = 5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PROGRESS_FILE = "progress_log.txt"

# ------------------------------------------------
# DATASET CLASS
# ------------------------------------------------
class MarsEvalDataset(Dataset):
    def __init__(self, image_dir, label_dirs):
        self.samples = []
        for f in os.listdir(image_dir):
            if not f.lower().endswith((".jpg", ".png")):
                continue
            base = os.path.splitext(f)[0]

            mask_path = None
            for ld in label_dirs:
                for suffix in ["_merged.png", ".png"]:
                    candidate = os.path.join(ld, base + suffix)
                    if os.path.exists(candidate):
                        mask_path = candidate
                        break
                if mask_path:
                    break

            if mask_path:
                self.samples.append((os.path.join(image_dir, f), mask_path))

        print(f"‚úÖ Found {len(self.samples)} matching image-label pairs out of {len(os.listdir(image_dir))} images")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, mask_path = self.samples[idx]
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Could not read {img_path}")
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (1024, 1024))
        img_tensor = torch.tensor(img).permute(2, 0, 1).float() / 255.0

        gt = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        gt = cv2.resize(gt, (1024, 1024), interpolation=cv2.INTER_NEAREST)
        gt[gt == 255] = 4   # ‚úÖ FIX: map 255 ‚Üí class 4
        return img_tensor, torch.tensor(gt, dtype=torch.long), os.path.basename(img_path)

# ------------------------------------------------
# METRICS
# ------------------------------------------------
def compute_iou_and_acc(pred, gt, num_classes):
    ious = []
    acc = (pred == gt).sum() / np.prod(gt.shape)
    for c in range(num_classes):
        inter = np.logical_and(pred == c, gt == c).sum()
        union = np.logical_or(pred == c, gt == c).sum()
        if union > 0:
            ious.append(inter / union)
    return np.mean(ious), acc

# ------------------------------------------------
# MODEL LOAD
# ------------------------------------------------
model = UNet(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# ------------------------------------------------
# DATA LOAD
# ------------------------------------------------
dataset = MarsEvalDataset(IMAGE_DIR, LABEL_DIRS)

# ‚ö° QUICK TEST MODE (20 samples only)
# dataset.samples = dataset.samples[:20]
print(f"‚ö° Running quick evaluation on {len(dataset)} samples only.")

loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

# ------------------------------------------------
# EVALUATION LOOP
# ------------------------------------------------
iou_scores = []
acc_scores = []

if os.path.exists(PROGRESS_FILE):
    os.remove(PROGRESS_FILE)

with torch.no_grad():
    for idx, (imgs, gts, names) in enumerate(tqdm(loader, desc="Evaluating")):
        imgs = imgs.to(DEVICE)
        preds = model(imgs)
        preds = torch.argmax(F.softmax(preds, dim=1), dim=1).cpu().numpy()[0]
        gt = gts.squeeze().numpy()

        iou, acc = compute_iou_and_acc(preds, gt, NUM_CLASSES)
        iou_scores.append(iou)
        acc_scores.append(acc)

        # Save progress every 10 samples ‚úÖ
        if (idx + 1) % 10 == 0:
            mean_iou = float(np.mean(iou_scores))
            mean_acc = float(np.mean(acc_scores))
            with open(PROGRESS_FILE, "a") as f:
                f.write(f"{idx+1},{mean_iou:.4f},{mean_acc:.4f}\n")

print("\n‚úÖ Evaluation complete")
print(f"Samples evaluated: {len(iou_scores)}")
if len(iou_scores) > 0:
    print(f"Mean IoU: {np.mean(iou_scores):.4f}")
    print(f"Pixel Accuracy: {np.mean(acc_scores):.4f}")
    print(f"üìÑ Progress saved in: {PROGRESS_FILE}")
else:
    print("‚ùå No matching samples evaluated ‚Äî check folder structure again.")


‚úÖ Found 16386 matching image-label pairs out of 18130 images
‚ö° Running quick evaluation on 16386 samples only.


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16386/16386 [1:03:38<00:00,  4.29it/s]


‚úÖ Evaluation complete
Samples evaluated: 16386
Mean IoU: 0.1002
Pixel Accuracy: 0.2333
üìÑ Progress saved in: progress_log.txt



