# Section 1: CNN Classification Logic
This section handles the imports, configuration, and the definition of our **SatelliteCNN** model. We use the **EuroSAT dataset** (64x64 satellite images) to train a classifier that identifies terrain types like forests, highways, and water bodies.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import os

# --- CONFIGURATION ---
BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 10  # More epochs usually raise accuracy, at the cost of training time
DATA_PATH = './eurosat_data'  # Where to download the data
SPLIT_SEED = 42  # Fixed seed so the train/validation split is reproducible

# Check for GPU (runs much faster), else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. DATASET & PREPROCESSING ---
# EuroSAT images are 64x64. Resize defensively, convert to tensors, and normalize
# per channel (these mean/std values are the commonly used ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Ensure size consistency
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Downloading/Loading EuroSAT Dataset...")
# This automatically downloads the dataset if not present
dataset = datasets.EuroSAT(root=DATA_PATH, download=True, transform=transform)

# Determine the class names (Important for our Navigation later)
classes = dataset.classes
print(f"Classes found ({len(classes)}): {classes}")

# Split: 80% Training, 20% Validation.
# A seeded generator makes the split deterministic. Any later evaluation that
# re-creates it with the same seed sees exactly the same held-out images.
# Without a seed, a fresh split mixes training images into "validation".
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_data, val_data = random_split(dataset, [train_size, val_size], generator=split_generator)

# Create Data Loaders
# num_workers=0 is safer for Windows compatibility
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# --- 2. CNN ARCHITECTURE ---
class SatelliteCNN(nn.Module):
    """Three-block CNN that classifies 64x64 RGB satellite tiles.

    Each block is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2). The blocks
    halve the spatial size 64 -> 32 -> 16 -> 8 while the channels grow
    3 -> 32 -> 64 -> 128. The 128x8x8 feature map is flattened into a
    512-unit hidden layer (with dropout), followed by one logit per class.

    The attribute names (conv1..conv3, fc1, fc2) become the state_dict keys,
    so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Block 1 (simple edges/colors): 64x64 -> 32x32 after pooling
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        # Block 2 (textures, e.g. grass vs trees): 32x32 -> 16x16
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Block 3 (larger structures, e.g. buildings, river banks): 16x16 -> 8x8
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Classifier head over 128 channels * 8 * 8 = 8192 flattened features.
        flat_features = 128 * 8 * 8
        self.fc1 = nn.Linear(flat_features, 512)
        self.dropout = nn.Dropout(0.5)  # regularization; disabled by model.eval()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(conv(x)))
        features = torch.flatten(x, start_dim=1)
        hidden = self.dropout(self.relu(self.fc1(features)))
        return self.fc2(hidden)

# --- 3. TRAINING LOOP ---
def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    Switches the model to eval mode (dropout off) and runs without gradient
    tracking. Returns 0.0 for an empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_model(epochs=None, save_path='eurosat_cnn.pth'):
    """Train a SatelliteCNN, report validation accuracy, and save a checkpoint.

    Uses the module-level ``train_loader``/``val_loader``, ``classes``,
    ``device`` and ``LEARNING_RATE``.

    Args:
        epochs: Number of passes over the training data. ``None`` means use
            the module-level ``EPOCHS`` (read at call time).
        save_path: File the checkpoint dict
            ``{'model_state': state_dict, 'classes': classes}`` is written to.

    Returns:
        The trained model, left in eval mode.
    """
    if epochs is None:
        epochs = EPOCHS

    model = SatelliteCNN(num_classes=len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print("\nStarting Training...")

    for epoch in range(epochs):
        model.train()  # training mode: dropout active
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean of per-batch losses; max() avoids ZeroDivisionError on an empty loader.
        avg_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

    print("Training Complete.")

    # --- EVALUATION ---
    print("\nEvaluating on Validation Set...")
    acc = _evaluate_accuracy(model, val_loader)
    print(f'Final Accuracy: {acc:.2f}%')

    # --- SAVE MODEL ---
    # Store the weights AND the class names so inference code can map output
    # indices back to terrain labels.
    save_data = {
        'model_state': model.state_dict(),
        'classes': classes,
    }
    torch.save(save_data, save_path)
    print(f"\nModel and class mapping saved to '{save_path}'")
    return model


if __name__ == "__main__":
    train_model()

# Section 2: Model Evaluation
In this section, we evaluate the performance of our trained CNN using a validation set. We generate a **Confusion Matrix** to visualize which terrain types the model classifies correctly and where it might be making mistakes (e.g., confusing different types of crops).

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# --- SETUP ---
# Locations of the trained checkpoint and dataset, plus evaluation batch size.
MODEL_PATH = 'eurosat_cnn.pth'
DATA_PATH = './eurosat_data'
BATCH_SIZE = 32
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Re-define architecture (Required to load weights)
class SatelliteCNN(nn.Module):
    """Evaluation copy of the 3-block satellite-tile CNN.

    Layer names and shapes must match the trained checkpoint exactly, because
    load_state_dict matches parameters by attribute name.
    Expects 64x64 RGB input: three 2x poolings give the 8x8 map fc1 assumes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)  # 128 channels x 8 x 8 after pooling
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Conv -> ReLU -> pool, three times (64 -> 32 -> 16 -> 8 pixels).
        for layer in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(layer(x)))
        flat = torch.flatten(x, 1)
        return self.fc2(self.dropout(self.relu(self.fc1(flat))))

# --- LOAD DATA ---
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Preparing Validation Set...")
dataset = datasets.EuroSAT(root=DATA_PATH, download=True, transform=transform)
classes = dataset.classes

# Re-create the held-out validation split.
# An unseeded random_split draws a different 20% on every run, so roughly 80%
# of those images would have been used for training, and the metrics below
# would be inflated. A seeded generator makes the split reproducible.
# SPLIT_SEED must equal the seed used when the model was trained; only then
# is this a true held-out set.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
_, val_data = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

# --- LOAD MODEL ---
print("Loading Model...")
checkpoint = torch.load(MODEL_PATH, map_location=device)
model = SatelliteCNN(num_classes=len(classes)).to(device)

# Accept either the {'model_state': ..., 'classes': ...} bundle or a bare
# state_dict, and strip any "module." prefix (e.g. left by nn.DataParallel).
sd = checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint
clean_sd = {k.replace("module.", ""): v for k, v in sd.items()}
model.load_state_dict(clean_sd)
model.eval()

# True labels are dataset indices and predictions are model output indices.
# They only line up if both use the same class order.
saved_classes = checkpoint.get('classes') if 'model_state' in checkpoint else None
if saved_classes is not None and list(saved_classes) != list(classes):
    print("Warning: class order in the checkpoint differs from the dataset; "
          "the confusion matrix and report may be misaligned.")

# --- GENERATE PREDICTIONS ---
y_true = []
y_pred = []

print("Running Inference on Validation Data...")
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

# Pass the full label list explicitly, so that every class gets a row/column
# and target_names stays aligned even if some class never appears in
# y_true or y_pred.
label_ids = list(range(len(classes)))

# --- PLOT CONFUSION MATRIX ---
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix: Where is the Model Failing?')
plt.show()

# --- PRINT DETAILED REPORT ---
print("\n--- CLASSIFICATION REPORT ---")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=classes))

# Section 3: Map Scanning and A* Pathfinding
This is the core of the rover's intelligence. 
1. **Map Scanning:** We slice a large satellite map into grids and use the CNN to predict the terrain for each tile.
2. **Cost Map:** We assign "traversal costs" based on terrain (e.g., Highway = 0.1, Forest = 100, Water = impassable).
3. **A* Search:** We run the A* algorithm to find the most energy-efficient path from the start point to the goal.

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import heapq
import os
from google.colab import files

# --- CONFIGURATION ---
MODEL_PATH = 'eurosat_cnn.pth'

# files.upload() returns a dict keyed by uploaded filename; use the first one.
print("Please upload your map image:")
uploaded = files.upload()

MAP_FILE = next(iter(uploaded), None)
if MAP_FILE is None:
    print("No file uploaded.")
else:
    print(f"Map set to: {MAP_FILE}")

SCAN_STEP = 16        # pixel stride between neighbouring grid cells
START_COORD = (2, 2)  # (row, col) grid cell where the rover starts
END_COORD = None      # None -> the bottom-right grid cell is the goal

# --- 1. CORE MODEL ---
class SatelliteCNN(nn.Module):
    """Inference copy of the 3-block satellite-tile CNN.

    ``num_classes`` defaults to 10, the length of the EuroSAT CLASSES list used
    for scanning. Attribute names must match the trained checkpoint's
    state_dict keys.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # 64x64 input pooled three times -> 8x8 spatial, 128 channels.
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        for conv_layer in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(conv_layer(x)))
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output index -> EuroSAT class name. The order must match the model's logits,
# because scan_terrain looks up CLASSES[predicted_index].
CLASSES = ('AnnualCrop Forest HerbaceousVegetation Highway Industrial '
           'Pasture PermanentCrop Residential River SeaLake').split()

def load_model():
    """Load the trained SatelliteCNN from MODEL_PATH onto ``device``.

    Accepts either a dict with a 'model_state' entry (optionally also
    'classes') or a bare state_dict. Any "module." substrings in parameter
    names are stripped before loading.

    If the checkpoint records its class list and that list differs from
    CLASSES, CLASSES is updated in place. The network is then sized from it,
    so output index i maps to the class the model was trained on.

    Returns:
        The model in eval mode, or None if MODEL_PATH does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"Error: {MODEL_PATH} not found.")
        return None

    checkpoint = torch.load(MODEL_PATH, map_location=device)
    is_bundle = isinstance(checkpoint, dict) and 'model_state' in checkpoint
    sd = checkpoint['model_state'] if is_bundle else checkpoint

    saved_classes = checkpoint.get('classes') if is_bundle else None
    if saved_classes and list(saved_classes) != CLASSES:
        print("Note: using the class list stored in the checkpoint.")
        CLASSES[:] = list(saved_classes)  # in place: other code indexes CLASSES

    model = SatelliteCNN(num_classes=len(CLASSES)).to(device)
    model.load_state_dict({k.replace("module.", ""): v for k, v in sd.items()})
    model.eval()
    return model

# --- 2. PHYSICS & CLEANING ---
def get_surface_cost(terrain_type):
    """Return the traversal cost of entering a cell of the given terrain class.

    Water (999) is treated as impassable by the pathfinder.
    Highway is nearly free (0.1), which strongly pulls routes onto roads.
    Built-up land costs 35, open vegetation 20 and Forest 100.
    Any other class (e.g. PermanentCrop) costs 50.
    """
    # Checked in order; the first matching group wins.
    cost_table = (
        (('River', 'SeaLake'), 999),
        (('Highway',), 0.1),
        (('Industrial', 'Residential'), 35),
        (('Pasture', 'AnnualCrop', 'HerbaceousVegetation'), 20),
        (('Forest',), 100),
    )
    for terrain_group, cost in cost_table:
        if terrain_type in terrain_group:
            return cost
    return 50


def clean_cost_grid(costs):
    """
    Speckle Filter: Removes isolated 'River' pixels (shadows/noise).

    A cell counts as water when its cost is >= 500. An interior water cell
    with fewer than 4 water cells among its 8 neighbours is reset to cost 20.
    Neighbours are counted on the input grid, so one removal does not affect
    another. Border cells are never changed. Returns a new array; ``costs``
    itself is not modified.

    NOTE(review): a straight water line one cell wide has only 2 water
    neighbours per cell, so this filter would erase it. Confirm this is intended.
    """
    n_rows, n_cols = costs.shape
    cleaned = costs.copy()
    water = costs >= 500

    corrections = 0
    if n_rows > 2 and n_cols > 2:
        # For each interior cell, count its water neighbours.
        # Sum the 8 shifted views of the boolean mask.
        neighbour_count = np.zeros((n_rows - 2, n_cols - 2), dtype=int)
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                if dr == 0 and dc == 0:
                    continue
                neighbour_count += water[1 + dr:n_rows - 1 + dr, 1 + dc:n_cols - 1 + dc]

        speckles = water[1:-1, 1:-1] & (neighbour_count < 4)
        cleaned[1:-1, 1:-1][speckles] = 20  # flatten to 'safe terrain'
        corrections = int(speckles.sum())

    print(f"Speckle Filter: Removed {corrections} isolated false-positive blocks.")
    return cleaned

def scan_terrain(model):
    """Classify the map cell by cell and build a despeckled traversal-cost grid.

    The map is divided into SCAN_STEP x SCAN_STEP pixel cells. Each cell is
    classified from a window of up to 64x64 pixels centred on it. The window is
    clipped at the image border and resized to 64x64 by the transform. The
    predicted terrain is mapped to a cost with get_surface_cost, and the grid
    is then cleaned with clean_cost_grid.

    A water prediction (River/SeaLake) below 90% confidence is replaced by the
    runner-up class, so uncertain tiles do not block the route.

    Args:
        model: Trained SatelliteCNN, expected to be in eval mode.

    Returns:
        (full_image, cost_grid, rows, cols), or (None, None, 0, 0) when the
        map file is missing.
    """
    if not MAP_FILE or not os.path.exists(MAP_FILE):
        print("Map file not found.")
        return None, None, 0, 0

    print(f"Scanning {MAP_FILE}...")
    # convert() returns a new image, so the source file can be closed right
    # away instead of staying open until garbage collection.
    with Image.open(MAP_FILE) as source:
        full_image = source.convert('RGB')
    w, h = full_image.size
    cols, rows = w // SCAN_STEP, h // SCAN_STEP

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    half_window = 32  # 2 * 32 = 64 px window, the EuroSAT tile size
    water_classes = ('River', 'SeaLake')
    water_conf_threshold = 0.90

    cost_grid = np.zeros((rows, cols))

    with torch.no_grad():
        for r in range(rows):
            cy = r * SCAN_STEP + SCAN_STEP // 2
            for c in range(cols):
                cx = c * SCAN_STEP + SCAN_STEP // 2

                box = (max(0, cx - half_window), max(0, cy - half_window),
                       min(w, cx + half_window), min(h, cy + half_window))
                tile = full_image.crop(box)
                input_t = transform(tile).unsqueeze(0).to(device)

                probs = torch.nn.functional.softmax(model(input_t), dim=1)
                top2_prob, top2_idx = torch.topk(probs, 2)

                best_class = CLASSES[top2_idx[0][0].item()]
                best_conf = top2_prob[0][0].item()
                second_class = CLASSES[top2_idx[0][1].item()]

                # Confidence filter: only trust a water prediction when it is strong.
                if best_class in water_classes and best_conf < water_conf_threshold:
                    terrain = second_class
                else:
                    terrain = best_class

                cost_grid[r, c] = get_surface_cost(terrain)

    # Post-processing: remove isolated false-positive water cells.
    final_costs = clean_cost_grid(cost_grid)

    return full_image, final_costs, rows, cols


# --- 3. A* PATHFINDING ---
def a_star(costs, start, goal):
    """Find the lowest-cost 8-connected path across a traversal-cost grid.

    Moving into a cell costs ``cell_cost * step``. The step is 1 for an
    orthogonal move and 1.4 (~sqrt(2)) for a diagonal one. Cells with
    cost >= 500 are impassable. The start and goal cells are always treated
    as passable with cost 1. This is applied to a private copy, so the
    caller's ``costs`` array is left unmodified.

    The heuristic is the octile distance scaled by the cheapest passable cell
    cost. It therefore never overestimates the remaining cost (admissible and
    consistent), and the returned path is optimal.

    Args:
        costs: 2-D numpy array of per-cell costs indexed (row, col).
        start, goal: (row, col) tuples; clamped into the grid.

    Returns:
        List of (row, col) tuples from start to goal inclusive, or None when
        the goal is unreachable or the grid is empty.
    """
    rows, cols = costs.shape
    if rows == 0 or cols == 0:
        return None

    def clamp(cell):
        return (max(0, min(cell[0], rows - 1)), max(0, min(cell[1], cols - 1)))

    start, goal = clamp(start), clamp(goal)

    # Private copy: never mutate the caller's grid.
    grid = np.array(costs, dtype=float)
    # Force the endpoints to be enterable even if they were classified as water.
    grid[start] = 1
    grid[goal] = 1

    # The cheapest move costs (cheapest passable cell) * step length. Scaling
    # the octile distance by it gives a lower bound on the true remaining cost.
    # The selection is never empty because the start cell was set to 1 above.
    min_cost = max(float(grid[grid < 500].min()), 0.0)

    def heuristic(cell):
        dr, dc = abs(goal[0] - cell[0]), abs(goal[1] - cell[1])
        return min_cost * (max(dr, dc) - min(dr, dc) + 1.4 * min(dr, dc))

    moves = ((-1, 0, 1), (1, 0, 1), (0, -1, 1), (0, 1, 1),
             (-1, -1, 1.4), (-1, 1, 1.4), (1, -1, 1.4), (1, 1, 1.4))

    frontier = [(heuristic(start), start)]
    came_from = {start: None}
    cost_so_far = {start: 0}
    closed = set()

    while frontier:
        _, current = heapq.heappop(frontier)
        if current == goal:
            break
        if current in closed:
            continue  # stale heap entry superseded by a cheaper one
        closed.add(current)

        for dr, dc, dist in moves:
            nr, nc = current[0] + dr, current[1] + dc
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            surface_cost = grid[nr, nc]
            if surface_cost >= 500:
                continue

            new_cost = cost_so_far[current] + surface_cost * dist
            neighbour = (nr, nc)
            if neighbour not in cost_so_far or new_cost < cost_so_far[neighbour]:
                cost_so_far[neighbour] = new_cost
                came_from[neighbour] = current
                heapq.heappush(frontier, (new_cost + heuristic(neighbour), neighbour))

    if goal not in came_from:
        return None

    path = [goal]
    while path[-1] != start:
        path.append(came_from[path[-1]])
    path.reverse()
    return path

# --- 4. EXECUTION ---
if __name__ == "__main__":
    model = load_model()
    if model is not None:
        img, costs, rows, cols = scan_terrain(model)

        if img is not None:
            # Grid coordinates are (row, col); the default goal is the bottom-right cell.
            start = START_COORD if START_COORD else (0, 0)
            goal = END_COORD if END_COORD else (rows-1, cols-1)

            path = a_star(costs, start, goal)
            if not path:
                print("No traversable path found between start and goal.")

            # Visualize
            fig, ax = plt.subplots(1, 2, figsize=(20, 10))

            ax[0].imshow(img)
            ax[0].set_title("Rover Navigation")
            ax[0].axis('off')

            heatmap = ax[1].imshow(costs, cmap='jet', interpolation='nearest')
            ax[1].set_title("Cost Map (Despeckled)")
            plt.colorbar(heatmap, ax=ax[1], fraction=0.046, pad=0.04)
            ax[1].axis('off')

            if path:
                # Grid cell (r, c) -> pixel at the centre of that cell on the map.
                ys = [r * SCAN_STEP + SCAN_STEP//2 for r, c in path]
                xs = [c * SCAN_STEP + SCAN_STEP//2 for r, c in path]
                ax[0].plot(xs, ys, color='yellow', linewidth=3)
                ax[0].scatter(xs[0], ys[0], c='green', s=150)
                ax[0].scatter(xs[-1], ys[-1], c='red', s=150)

                # On the cost map, one grid cell is one image pixel.
                path_ys = [r for r, c in path]
                path_xs = [c for r, c in path]
                ax[1].plot(path_xs, path_ys, color='white', linewidth=2, linestyle='--')

            plt.tight_layout()
            plt.show()