<a href="https://colab.research.google.com/github/pushkar-hue/MarsSimNav/blob/main/mars_terrain_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import requests
import zipfile
from tqdm.notebook import tqdm

# Remote archive of the merged AI4Mars dataset, plus the local paths used for it.
url = "https://data.nasa.gov/docs/legacy/ai4mars-dataset-merged-0.1.zip"
zip_path = "ai4mars-dataset.zip"  # local file the archive is downloaded to
extract_dir = "ai4mars-dataset"  # directory the archive is unpacked into

# Download with progress bar
def download_file(url, dest_path):
    """Stream `url` to `dest_path` while showing a tqdm progress bar.

    The body is written to ``dest_path + ".part"`` first and only renamed to
    `dest_path` once the transfer finished, so an interrupted download never
    leaves a truncated file that a later "already downloaded?" check would
    mistake for a complete one.

    Args:
        url: HTTP(S) address of the file to fetch.
        dest_path: Local path the finished download is stored at.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Connection problems or timeouts.
        OSError: The destination could not be written or renamed.
    """
    tmp_path = dest_path + ".part"
    try:
        # (connect timeout, read timeout) in seconds; without one a stalled
        # connection would hang the cell forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as the archive.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as f, tqdm(
                desc=dest_path,
                # Unknown size -> let tqdm show a plain counter instead of 0/0.
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                # 1 MiB chunks: far fewer Python-level iterations than 1 KiB.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    bar.update(f.write(chunk))
        os.replace(tmp_path, dest_path)
    finally:
        # Only still present when the download failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Fetch the archive only when no local copy is present yet
if os.path.exists(zip_path):
    print("ZIP file already downloaded.")
else:
    print("Downloading dataset...")
    download_file(url, zip_path)

# Unpack the archive unless the target directory already exists
if os.path.exists(extract_dir):
    print("Dataset already extracted.")
else:
    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print("Done.")


In [None]:
import os

def walk_through(dir_path):
    """Print one summary line per directory found under `dir_path`.

    Each line reports how many sub-directories and how many files the
    directory contains (files are reported as "images" in the message).
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in '{current_dir}'.")

In [None]:
#walk_through(extract_dir)

In [None]:
import os
import shutil
import random
from tqdm.notebook import tqdm

# Input paths
root_dir = "ai4mars-dataset/ai4mars-dataset-merged-0.1/msl"
images_dir = os.path.join(root_dir, "images", "edr")
labels_dir = os.path.join(root_dir, "labels", "train")

# Output paths
subset_dir = "ai4mars-subset"
subset_images = os.path.join(subset_dir, "images")
subset_labels = os.path.join(subset_dir, "labels")

# Make output dirs
os.makedirs(subset_images, exist_ok=True)
os.makedirs(subset_labels, exist_ok=True)

# Map base name (filename without its extension) -> real filename.
# Extensions are matched case-insensitively.
image_dict = {f.rsplit(".", 1)[0]: f for f in os.listdir(images_dir) if f.lower().endswith(".jpg")}
label_dict = {f.rsplit(".", 1)[0]: f for f in os.listdir(labels_dir) if f.lower().endswith(".png")}

# Only base names that have BOTH an image and a label are usable.
# Sorting makes the population order (and thus the seeded sample) deterministic.
common_basenames = sorted(image_dict.keys() & label_dict.keys())
print(f"✅ Found {len(common_basenames)} matched image-label pairs.")

# Select up to 5000 pairs.
# A dedicated, seeded RNG makes the selection reproducible: re-running this
# cell picks the same files again instead of copying a *different* random
# sample on top of the previous one (which would silently grow the subset
# past the intended size). The global `random` state is left untouched.
SUBSET_SEED = 42
subset_size = min(5000, len(common_basenames))
subset_basenames = random.Random(SUBSET_SEED).sample(common_basenames, subset_size)

# Copy matched files
print(f"Copying {subset_size} pairs to 'ai4mars-subset/'...")
for base in tqdm(subset_basenames):
    shutil.copy(os.path.join(images_dir, image_dict[base]), os.path.join(subset_images, image_dict[base]))
    shutil.copy(os.path.join(labels_dir, label_dict[base]), os.path.join(subset_labels, label_dict[base]))

print("Subset creation complete!")


In [None]:
# Sanity check: print directory/file counts for the subset tree
walk_through(subset_dir)

In [None]:
import os
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Paths
subset_images = "ai4mars-subset/images"
subset_labels = "ai4mars-subset/labels"

# Get all matching base filenames (extension matched case-insensitively)
image_files = sorted(f for f in os.listdir(subset_images) if f.lower().endswith(".jpg"))
label_files = sorted(f for f in os.listdir(subset_labels) if f.lower().endswith(".png"))
base_names = sorted({f.rsplit('.', 1)[0] for f in image_files} &
                    {f.rsplit('.', 1)[0] for f in label_files})

# Pick a few random pairs.
# random.sample raises ValueError when asked for more items than exist, so
# cap the request at the number of available pairs (possibly zero).
sample_bases = random.sample(base_names, min(3, len(base_names)))

# Label class mapping (visual only): class id -> display colour
label_colors = {
    0: [0, 0, 0],        # soil
    1: [100, 100, 100],  # bedrock
    2: [255, 255, 0],    # sand
    3: [255, 0, 0],      # big rock
    255: [255, 255, 255] # null (white)
}

def decode_label_mask(mask):
    """Turn a class-id mask into a 3-channel colour image for display.

    Args:
        mask: Array of class ids, either ``(H, W)`` (single channel) or
            ``(H, W, C)`` where every channel of a pixel carries the same id
            (e.g. a grayscale PNG that was loaded as a 3-channel image).

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Pixels whose id is not a key of
        `label_colors` (or whose channels disagree) stay black.
    """
    mask = np.asarray(mask)
    if mask.ndim == 2:
        # Give single-channel masks a trailing axis; otherwise the reduction
        # over axis -1 below would collapse the width axis and paint whole rows.
        mask = mask[..., np.newaxis]

    decoded = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    for val, color in label_colors.items():
        # A pixel belongs to `val` only when all of its channels equal it.
        decoded[(mask == val).all(axis=-1)] = color
    return decoded

# Display
# Resolve each base name back to its real on-disk filename. The listings above
# match extensions case-insensitively, so hard-coding ".JPG" / ".png" here
# would miss files such as "x.jpg" and make cv2.imread return None.
image_by_base = {f.rsplit('.', 1)[0]: f for f in image_files}
label_by_base = {f.rsplit('.', 1)[0]: f for f in label_files}

for base in sample_bases:
    img_path = os.path.join(subset_images, image_by_base[base])
    lbl_path = os.path.join(subset_labels, label_by_base[base])

    img = cv2.imread(img_path)
    label = cv2.imread(lbl_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None or label is None:
        print(f"Skipping '{base}': could not read image or label file.")
        continue

    # OpenCV loads BGR; matplotlib expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    label = decode_label_mask(label)
    # Blend needs both inputs to share shape and dtype.
    overlay = cv2.addWeighted(img, 0.7, label, 0.3, 0)

    plt.figure(figsize=(12, 4))
    panels = (("Image", img), ("Decoded Label", label), ("Overlay", overlay))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis("off")

    plt.show()


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import torch

In [None]:
# Maps the (R, G, B) value of a label pixel to its class index. Label pixels
# are gray, i.e. all three channels hold the class id itself:
#   0 = soil, 1 = bedrock, 2 = sand, 3 = big rock, 255 = ignore
RGB_CLASS_MAP = {
    (class_id, class_id, class_id): class_id
    for class_id in (0, 1, 2, 3, 255)
}

In [None]:
def encode_mask(mask_img):
    """Convert a colour-coded label image into a 2-D array of class indices.

    Args:
        mask_img: Image (or array-like) that ``np.array`` turns into an
            ``(H, W, C)`` array.

    Returns:
        ``uint8`` array of shape ``(H, W)``. Pixels whose colour is a key of
        ``RGB_CLASS_MAP`` receive the mapped index; every other pixel is 255.
    """
    pixels = np.array(mask_img)
    height, width, _ = pixels.shape

    # Start from "ignore" everywhere, then stamp each known colour on top.
    encoded = np.full((height, width), 255, dtype=np.uint8)
    for color, class_id in RGB_CLASS_MAP.items():
        same_color = np.all(pixels == color, axis=-1)
        encoded[same_color] = class_id
    return encoded

In [None]:

class MarsSegmentationDataset(Dataset):
    """Image / label-mask pairs for Mars terrain segmentation.

    Every ``*.jpg`` file (any letter case) in `image_dir` is one sample; its
    label is the file ``<same base name>.png`` in `label_dir`.

    Args:
        image_dir: Directory containing the JPEG images.
        label_dir: Directory containing the PNG label masks.
        transform: Stored on the instance as ``self.transform`` but NOT applied
            by ``__getitem__``; kept so existing callers that pass it still work.
        img_size: Images and masks are resized to ``img_size x img_size``.
    """

    def __init__(self, image_dir, label_dir, transform=None, img_size=256):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # Case-insensitive match: a strict ".JPG" check would silently drop
        # files named "*.jpg".
        self.images = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(".jpg")
        )
        self.transform = transform
        self.img_size = img_size

        # Resize -> float tensor in [0, 1] -> normalise with the ImageNet
        # channel mean/std.
        self.image_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of images found in `image_dir`."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample `idx`.

        Returns:
            ``(img, label)`` where ``img`` is the normalised float tensor of
            shape ``(3, img_size, img_size)`` and ``label`` is an int64 tensor
            of shape ``(img_size, img_size)`` holding the class indices
            produced by ``encode_mask``.

        Raises:
            FileNotFoundError: The image has no matching ``.png`` label file.
        """
        img_name = self.images[idx]
        label_name = img_name.rsplit(".", 1)[0] + ".png"

        # convert() returns an independent copy, so the source files can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(os.path.join(self.image_dir, img_name)) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(os.path.join(self.label_dir, label_name)) as raw_label:
            label_img = raw_label.convert("RGB")

        # Nearest-neighbour keeps label values intact; interpolating would
        # invent colours that belong to no class.
        label_img = label_img.resize((self.img_size, self.img_size), resample=Image.NEAREST)

        # Encode label
        label = encode_mask(label_img)

        # Apply transforms
        img = self.image_transform(img)
        label = torch.from_numpy(label).long()

        return img, label


In [None]:
from torch.utils.data import random_split, DataLoader
# Full dataset (your 5k subset)
# `transform` is left at its default (None); images and masks are resized to 256x256.
full_dataset = MarsSegmentationDataset(
    image_dir="ai4mars-subset/images",
    label_dir="ai4mars-subset/labels",
    img_size=256
)

# Split sizes: 80% for training; the remainder is held out, so both sizes always sum to the total
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Random split (no generator is passed here, so this cell does not pin the partition to a seed)
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Loaders: only the training data is shuffled; the held-out data keeps a fixed order
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Bare expression: shows both loader objects as the notebook cell output
train_loader, test_loader

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression: shows the chosen device as the notebook cell output
device

In [None]:
import torchvision
import torch.nn as nn

# Load model
# `pretrained=True` is deprecated in torchvision (0.13+) in favour of an
# explicit `weights=` enum. COCO_WITH_VOC_LABELS_V1 is the checkpoint that
# `pretrained=True` used to select, so the loaded weights are unchanged.
weights = torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = torchvision.models.segmentation.deeplabv3_resnet50(weights=weights)

# Replace classifier with 4 output classes (soil, bedrock, sand, big rock).
# DeepLabHead(in_channels, num_classes): 2048 is the channel width of the
# ResNet-50 backbone's final feature map.
# NOTE(review): any auxiliary head on the pretrained model is left untouched;
# only model(...)['out'] is read by the training/evaluation code.
model.classifier = torchvision.models.segmentation.deeplabv3.DeepLabHead(2048, 4)


In [None]:
# Move the model's parameters and buffers to the selected device
model = model.to(device)

In [None]:
import torch.optim as optim

# Loss: CrossEntropy with ignore_index
# 255 is the value encode_mask assigns to unlabeled/unknown pixels, so those
# pixels contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Optimizer (updates every model parameter, backbone included)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Halve the learning rate every 10 calls to scheduler.step().
# NOTE(review): with scheduler.step() called once per epoch and EPOCHS = 3,
# the decay never triggers in this notebook.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [None]:
from tqdm.notebook import tqdm

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'``
            entry holding the logits.
        loader: Iterable of ``(images, masks)`` batches; must support ``len()``.
        optimizer: Optimizer that updates the parameters of `model`.
        criterion: Loss called as ``criterion(logits, masks)``.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []

    for images, masks in tqdm(loader, desc="Training"):
        # Batches must live on the same device as the model (global `device`).
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)['out']
        batch_loss = criterion(logits, masks)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses, 0.0) / len(loader)


In [None]:
def evaluate(model, loader, criterion):
    """Compute mean loss and pixel accuracy of `model` over `loader`.

    Pixels labelled 255 (the ignore value) are excluded from the accuracy.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'``
            entry holding per-class logits with classes on dimension 1.
        loader: Iterable of ``(images, masks)`` batches. It does not need to
            support ``len()``; batches are counted while iterating.
        criterion: Loss called as ``criterion(logits, masks)``.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the
        per-batch losses and ``accuracy`` the fraction of non-ignored pixels
        predicted correctly. Either value is ``nan`` when it is undefined
        (no batches, or no non-ignored pixels) instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # inference_mode disables autograd tracking for the whole pass.
    with torch.inference_mode():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            num_batches += 1

            # Highest-scoring class per pixel.
            preds = outputs.argmax(1)
            valid = masks != 255
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()

    mean_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy


In [None]:
# Number of full passes over the training loader
EPOCHS = 3

for epoch in range(EPOCHS):
    # One optimisation pass, then a gradient-free evaluation on the held-out loader
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, test_loader, criterion)

    print(f"[{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
    # Advance the LR schedule once per epoch, after that epoch's optimizer steps
    scheduler.step()


In [None]:
# Persist only the weights (state_dict); loading them later requires rebuilding
# the same architecture first and then calling load_state_dict.
torch.save(model.state_dict(), "deeplabv3_mars.pth")


In [None]:
# google.colab only exists inside a Colab runtime. Guard the import so the
# notebook still runs to completion elsewhere (local Jupyter, plain scripts).
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; skipping browser download of 'deeplabv3_mars.pth'.")
else:
    # Triggers a browser download of the saved weights from the Colab VM.
    files.download("deeplabv3_mars.pth")
