In [None]:
# 0.65137

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# ==========================================
# 1. CONFIGURATION & FLAGS
# ==========================================
# --- USER MODES (Set ONE to True) ---
START_TRAINING    = False    # Start fresh, split data, train from scratch
CONTINUE_TRAINING = True     # Load best_model.pth, use SAME split, resume training
INFERENCE_ONLY    = False    # Load best_model.pth, predict on test_images

# Exactly one mode must be active. With no flag set the script would silently
# train from scratch and overwrite best_model.pth on the first epoch (the best
# loss starts at inf); with several set, only the first matching branch of the
# mode dispatch would run and the others would be ignored.
if sum(map(bool, (START_TRAINING, CONTINUE_TRAINING, INFERENCE_ONLY))) != 1:
    raise ValueError(
        "Set exactly one of START_TRAINING, CONTINUE_TRAINING, INFERENCE_ONLY to True."
    )

# Paths
BASE_DIR = './data/kaggle_dataset'
TRAIN_IMG_DIR = os.path.join(BASE_DIR, 'train_images')
TEST_IMG_DIR = os.path.join(BASE_DIR, 'test_images')
TRAIN_CSV = os.path.join(BASE_DIR, 'train_ground_truth.csv')
SAMPLE_SUB_CSV = os.path.join(BASE_DIR, 'sample_submission.csv')
MODEL_SAVE_PATH = 'best_model.pth'

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
EPOCHS = 5
VALIDATION_SPLIT = 0.2
NUM_WORKERS = 0

# --- DEVICE SETUP FOR M3 PRO ---
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Apple Silicon (MPS) detected. Using GPU acceleration.")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("CUDA GPU detected.")
else:
    DEVICE = torch.device("cpu")
    print("GPU not found. Using CPU (slower).")

# ==========================================
# 2. DATASET CLASS
# ==========================================
class GeoGuessrDataset(Dataset):
    """Dataset that stitches a sample's four street-view images into one 2x2 image.

    For each row, ``img_{sample_id:06d}_{north,east,south,west}.jpg`` are read
    from ``img_dir`` and pasted as::

        North | East
        ------+------
        West  | South

    A view whose file is missing is replaced by a black image sized like the
    views that did load (256x256 if none loaded).

    Args:
        dataframe: rows with a ``sample_id`` column; when ``is_test`` is False
            also ``state_idx``, ``latitude`` and ``longitude``.
        img_dir: directory holding the image files.
        transform: optional callable applied to the stitched PIL image.
        is_test: if True, items are ``(image, sample_id)``; otherwise
            ``(image, state_idx, gps_target)`` with
            ``gps_target = [latitude / 90, longitude / 180]`` as float32.
    """

    DIRECTIONS = ('north', 'east', 'south', 'west')
    # Placeholder size used only when none of a sample's views could be loaded.
    FALLBACK_SIZE = (256, 256)

    def __init__(self, dataframe, img_dir, transform=None, is_test=False):
        self.data = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.data)

    def _load_views(self, sample_id):
        """Load the views in DIRECTIONS order, using black images for missing files."""
        views = []
        for d in self.DIRECTIONS:
            img_path = os.path.join(self.img_dir, f"img_{sample_id:06d}_{d}.jpg")
            try:
                views.append(Image.open(img_path).convert('RGB'))
            except FileNotFoundError:
                views.append(None)
        # Size placeholders like the real views so the paste offsets (taken
        # from the north view) stay consistent with the other tiles.
        size = next((v.size for v in views if v is not None), self.FALLBACK_SIZE)
        return [v if v is not None else Image.new('RGB', size) for v in views]

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # A row taken from an all-numeric frame comes back as float64, and
        # f"{5.0:06d}" raises ValueError, so coerce the id to int first.
        sample_id = int(row['sample_id'])

        north, east, south, west = self._load_views(sample_id)

        # Stitch 2x2
        w, h = north.size
        stitched_img = Image.new('RGB', (w * 2, h * 2))
        stitched_img.paste(north, (0, 0))
        stitched_img.paste(east, (w, 0))
        stitched_img.paste(west, (0, h))
        stitched_img.paste(south, (w, h))

        if self.transform:
            stitched_img = self.transform(stitched_img)

        if self.is_test:
            return stitched_img, sample_id

        state_idx = int(row['state_idx'])
        # Normalize GPS: latitude by 90, longitude by 180
        gps_target = torch.tensor(
            [row['latitude'] / 90.0, row['longitude'] / 180.0], dtype=torch.float32
        )
        return stitched_img, state_idx, gps_target

# Transforms
# Preprocessing for the stitched 2x2 image: resize to 224x224, convert to a
# tensor, then normalise with the mean/std values conventionally used with
# torchvision's ImageNet-pretrained backbones.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ==========================================
# 3. MODEL DEFINITION
# ==========================================
class MultiTaskGeoNet(nn.Module):
    """ResNet-18 backbone shared by a state-classification head and a GPS-regression head.

    Args:
        num_states: number of state classes (default 50, the value the head
            was originally hard-coded to, so existing checkpoints still load).

    ``forward(x)`` returns ``(state_logits, gps_pred)`` with shapes
    ``(batch, num_states)`` and ``(batch, 2)``.
    """

    def __init__(self, num_states=50):
        super().__init__()
        # ResNet18 is light and fast. `pretrained=True` is deprecated since
        # torchvision 0.13; ResNet18_Weights.DEFAULT selects the same ImageNet
        # weights explicitly.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone emits pooled features.
        self.backbone.fc = nn.Identity()

        self.cls_head = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_states),
        )

        # Trained against [lat / 90, lon / 180]; the output is unbounded, so
        # the submission code rescales and clamps it.
        self.gps_head = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        features = self.backbone(x)
        state_logits = self.cls_head(features)
        gps_pred = self.gps_head(features)
        return state_logits, gps_pred

# ==========================================
# 4. CHECKPOINT HELPERS
# ==========================================
def save_checkpoint(model, optimizer, scheduler, val_loss, epoch, filename):
    """Write a resumable training checkpoint to ``filename``.

    The saved dict holds the epoch index, the model/optimizer/scheduler state
    dicts and ``val_loss`` under the key ``best_val_loss``.
    """
    print(f"--> SAVING BEST MODEL (Val Loss: {val_loss:.4f}) to {filename}")
    payload = dict(
        epoch=epoch,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        best_val_loss=val_loss,
    )
    torch.save(payload, filename)

def load_checkpoint(model, optimizer, scheduler, filename):
    """Restore a checkpoint written by ``save_checkpoint``.

    ``optimizer`` / ``scheduler`` may be None (e.g. for inference) to skip
    restoring them. Note that restoring the optimizer state also restores its
    param-group learning rate, overriding the one it was constructed with.

    Returns:
        ``(start_epoch, best_val_loss)``: the epoch after the saved one and
        the stored best loss, or ``(0, inf)`` if ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        print(f"--> No checkpoint found at '{filename}'. Starting fresh.")
        return 0, float('inf')

    print(f"--> Loading checkpoint '{filename}'...")
    state = torch.load(filename, map_location=DEVICE)
    model.load_state_dict(state['model_state'])

    if optimizer:
        optimizer.load_state_dict(state['optimizer_state'])
    # Older checkpoints may predate scheduler saving.
    if scheduler and 'scheduler_state' in state:
        scheduler.load_state_dict(state['scheduler_state'])

    resume_epoch = state['epoch'] + 1
    best_loss = state['best_val_loss']
    print(f"--> Loaded. Resuming from Epoch {resume_epoch} with Best Val Loss {best_loss:.4f}")
    return resume_epoch, best_loss

# ==========================================
# 5. DATA SETUP
# ==========================================
if not INFERENCE_ONLY:
    full_df = pd.read_csv(TRAIN_CSV)
    # Fixed random_state keeps the split identical across runs, so
    # CONTINUE_TRAINING validates on the same samples as the original run.
    train_df, val_df = train_test_split(
        full_df, test_size=VALIDATION_SPLIT, random_state=42
    )
    print(f"Data Split: {len(train_df)} Training, {len(val_df)} Validation")

    train_dataset = GeoGuessrDataset(train_df, TRAIN_IMG_DIR, transform=transform)
    val_dataset = GeoGuessrDataset(val_df, TRAIN_IMG_DIR, transform=transform)

    # Shared loader settings; only shuffling differs between the splits.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ==========================================
# 6. INIT & MODE SELECTION
# ==========================================
model = MultiTaskGeoNet().to(DEVICE)
criterion_cls = nn.CrossEntropyLoss()
criterion_gps = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the LR by 0.1 when the validation loss plateaus (patience=2 epochs).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

current_best_val_loss = float('inf')
start_epoch = 0

if START_TRAINING:
    print("--- MODE: STARTING FRESH TRAINING ---")
    # Delete any previous checkpoint; the best loss starts at inf, so the
    # first epoch will write a new one.
    if os.path.exists(MODEL_SAVE_PATH):
        os.remove(MODEL_SAVE_PATH)

elif CONTINUE_TRAINING:
    print("--- MODE: CONTINUING TRAINING ---")
    # If the checkpoint exists this restores model, optimizer (including its
    # LR) and scheduler state; otherwise training starts fresh.
    start_epoch, current_best_val_loss = load_checkpoint(model, optimizer, scheduler, MODEL_SAVE_PATH)

elif INFERENCE_ONLY:
    print("--- MODE: INFERENCE ONLY ---")
    # load_checkpoint() falls back to "starting fresh" when the file is
    # absent, which here would silently produce a submission from an
    # untrained model. Fail loudly instead.
    if not os.path.isfile(MODEL_SAVE_PATH):
        raise FileNotFoundError(
            f"INFERENCE_ONLY requires a trained checkpoint, but '{MODEL_SAVE_PATH}' does not exist."
        )
    load_checkpoint(model, None, None, MODEL_SAVE_PATH)

# ==========================================
# 7. TRAINING LOOP
# ==========================================
# Weight of the GPS regression term relative to the state cross-entropy.
# Shared by the train and validation passes so both report the same objective.
GPS_LOSS_WEIGHT = 10.0


def _multitask_loss(state_logits, gps_pred, states, gps_targets):
    """Return state cross-entropy + GPS_LOSS_WEIGHT * MSE on the normalized GPS targets."""
    loss_cls = criterion_cls(state_logits, states)
    loss_gps = criterion_gps(gps_pred, gps_targets)
    return loss_cls + (GPS_LOSS_WEIGHT * loss_gps)


if not INFERENCE_ONLY:
    print(f"Starting training loop from epoch {start_epoch+1}...")

    for epoch in range(start_epoch, start_epoch + EPOCHS):
        # --- TRAIN ---
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")

        for images, states, gps_targets in progress_bar:
            images, states, gps_targets = images.to(DEVICE), states.to(DEVICE), gps_targets.to(DEVICE)

            optimizer.zero_grad()
            state_logits, gps_pred = model(images)
            loss = _multitask_loss(state_logits, gps_pred, states, gps_targets)

            loss.backward()
            optimizer.step()

            # .item() synchronises with the device; fetch the scalar only once.
            batch_loss = loss.item()
            train_loss += batch_loss
            progress_bar.set_postfix({'trn_loss': batch_loss})

        # Mean of per-batch losses (a smaller final batch is weighted like a full one).
        avg_train_loss = train_loss / len(train_loader)

        # --- VALIDATE ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, states, gps_targets in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]"):
                images, states, gps_targets = images.to(DEVICE), states.to(DEVICE), gps_targets.to(DEVICE)
                state_logits, gps_pred = model(images)
                val_loss += _multitask_loss(state_logits, gps_pred, states, gps_targets).item()

        avg_val_loss = val_loss / len(val_loader)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.6f}")

        # Update Scheduler (may lower the LR for the next epoch)
        scheduler.step(avg_val_loss)

        # Save Best Model
        if avg_val_loss < current_best_val_loss:
            current_best_val_loss = avg_val_loss
            save_checkpoint(model, optimizer, scheduler, current_best_val_loss, epoch, MODEL_SAVE_PATH)
        else:
            print(f"Val Loss did not improve (Best: {current_best_val_loss:.4f}).")

    print("\nTraining complete.")

# ==========================================
# 8. SUBMISSION GENERATOR
# ==========================================
if INFERENCE_ONLY:
    print("\nGenerating submission using the BEST saved model...")
    sub_df = pd.read_csv(SAMPLE_SUB_CSV)
    test_dataset = GeoGuessrDataset(sub_df, TEST_IMG_DIR, transform=transform, is_test=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS)

    model.eval()
    results = []

    with torch.no_grad():
        for images, sample_ids in tqdm(test_loader, desc="Testing"):
            state_logits, gps_pred = model(images.to(DEVICE))

            # Move each batch's predictions to the CPU once rather than per sample.
            top5 = torch.topk(state_logits, k=5, dim=1).indices.cpu().numpy()
            # Undo the [lat / 90, lon / 180] target scaling and clamp to valid ranges.
            lats = (gps_pred[:, 0] * 90.0).clamp(-90, 90).cpu().tolist()
            lons = (gps_pred[:, 1] * 180.0).clamp(-180, 180).cpu().tolist()

            for sid, states, lat, lon in zip(sample_ids.tolist(), top5, lats, lons):
                results.append({
                    'sample_id': sid,
                    'predicted_state_idx_1': states[0],
                    'predicted_state_idx_2': states[1],
                    'predicted_state_idx_3': states[2],
                    'predicted_state_idx_4': states[3],
                    'predicted_state_idx_5': states[4],
                    'predicted_latitude': lat,
                    'predicted_longitude': lon,
                })

    # Fill the template (keeping its row order) by sample_id with one
    # vectorised lookup per column, instead of a boolean mask per row (O(n^2)).
    final_df = sub_df.copy()
    if results:
        pred_df = (
            pd.DataFrame(results)
            .drop_duplicates('sample_id', keep='last')
            .set_index('sample_id')
        )
        for col in pred_df.columns:
            mapped = final_df['sample_id'].map(pred_df[col])
            if col in final_df.columns:
                # Samples without a prediction keep the template's value.
                mapped = mapped.where(mapped.notna(), final_df[col])
            final_df[col] = mapped

    final_df.to_csv('submission.csv', index=False)
    print(f"Submission saved to 'submission.csv' with {len(final_df)} rows.")

Apple Silicon (MPS) detected. Using GPU acceleration.
Data Split: 52784 Training, 13196 Validation




--- MODE: CONTINUING TRAINING ---
--> Loading checkpoint 'best_model.pth'...
--> Loaded. Resuming from Epoch 5 with Best Val Loss 1.9849
Starting training loop from epoch 6...


Epoch 6 [Train]: 100%|██████████| 1650/1650 [08:07<00:00,  3.38it/s, trn_loss=2.35]
Epoch 6 [Val]: 100%|██████████| 413/413 [01:25<00:00,  4.83it/s]



Epoch 6: Train Loss: 1.7624 | Val Loss: 1.9507 | LR: 0.001000
--> SAVING BEST MODEL (Val Loss: 1.9507) to best_model.pth


Epoch 7 [Train]: 100%|██████████| 1650/1650 [08:59<00:00,  3.06it/s, trn_loss=1.21] 
Epoch 7 [Val]: 100%|██████████| 413/413 [01:14<00:00,  5.51it/s]



Epoch 7: Train Loss: 1.5793 | Val Loss: 1.9149 | LR: 0.001000
--> SAVING BEST MODEL (Val Loss: 1.9149) to best_model.pth


Epoch 8 [Train]: 100%|██████████| 1650/1650 [07:45<00:00,  3.55it/s, trn_loss=1.83] 
Epoch 8 [Val]: 100%|██████████| 413/413 [01:24<00:00,  4.86it/s]



Epoch 8: Train Loss: 1.4067 | Val Loss: 1.9009 | LR: 0.001000
--> SAVING BEST MODEL (Val Loss: 1.9009) to best_model.pth


Epoch 9 [Train]: 100%|██████████| 1650/1650 [07:51<00:00,  3.50it/s, trn_loss=0.792]
Epoch 9 [Val]: 100%|██████████| 413/413 [01:25<00:00,  4.85it/s]



Epoch 9: Train Loss: 1.2440 | Val Loss: 2.0210 | LR: 0.001000
Val Loss did not improve (Best: 1.9009).


Epoch 10 [Train]: 100%|██████████| 1650/1650 [07:55<00:00,  3.47it/s, trn_loss=1.71] 
Epoch 10 [Val]: 100%|██████████| 413/413 [01:25<00:00,  4.86it/s]


Epoch 10: Train Loss: 1.0691 | Val Loss: 2.1858 | LR: 0.001000
Val Loss did not improve (Best: 1.9009).

Training complete.





In [14]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# ==========================================
# 1. CONFIGURATION & FLAGS
# ==========================================
# --- USER MODES ---
START_TRAINING    = False
CONTINUE_TRAINING = True
INFERENCE_ONLY    = False

# Exactly one mode must be True. With none set, training would silently start
# from scratch and overwrite MODEL_SAVE_PATH on the first epoch; with several
# set, only the first matching branch of the mode dispatch would run.
if sum(map(bool, (START_TRAINING, CONTINUE_TRAINING, INFERENCE_ONLY))) != 1:
    raise ValueError(
        "Set exactly one of START_TRAINING, CONTINUE_TRAINING, INFERENCE_ONLY to True."
    )

# Paths
BASE_DIR = './data/kaggle_dataset'
TRAIN_IMG_DIR = os.path.join(BASE_DIR, 'train_images')
TEST_IMG_DIR = os.path.join(BASE_DIR, 'test_images')
TRAIN_CSV = os.path.join(BASE_DIR, 'train_ground_truth.csv')
SAMPLE_SUB_CSV = os.path.join(BASE_DIR, 'sample_submission.csv')
MODEL_SAVE_PATH = 'best_model_2.pth'

# Hyperparameters & Grid Config
BATCH_SIZE = 32
LEARNING_RATE = 0.00005 # Lower LR for deeper heads
EPOCHS = 10
VALIDATION_SPLIT = 0.2
GRID_SIZE = 25  # 25x25 grid = 625 distinct regions (Classes)

# M3 Pro Settings
NUM_WORKERS = 0
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Apple Silicon (MPS) detected. Using GPU acceleration.")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("CUDA GPU detected.")
else:
    DEVICE = torch.device("cpu")
    print("GPU not found. Using CPU (slower).")

# ==========================================
# 2. GRID SYSTEM HELPER
# ==========================================
class GridSystem:
    """
    Divides the map into a grid_size x grid_size grid of equal lat/lon tiles
    covering the bounding box of the fitting data (the paper's approach), and
    converts Lat/Lon <-> Class Index, where index = row * grid_size + col.

    Bounds and steps are stored as plain Python floats, so ``save_state()``
    contains no NumPy scalars.
    """

    def __init__(self, df=None, grid_size=25):
        self.grid_size = grid_size
        # Unfitted until built from a dataframe or restored via load_state().
        self.min_lat = self.max_lat = self.min_lon = self.max_lon = None
        self.lat_step = self.lon_step = None
        if df is not None:
            # Determine bounds from training data
            self.min_lat = float(df['latitude'].min())
            self.max_lat = float(df['latitude'].max())
            self.min_lon = float(df['longitude'].min())
            self.max_lon = float(df['longitude'].max())
            # The small epsilon keeps the step positive even if every point
            # shares a coordinate; coords_to_grid clips the upper edge.
            self.lat_step = (self.max_lat - self.min_lat + 1e-5) / grid_size
            self.lon_step = (self.max_lon - self.min_lon + 1e-5) / grid_size

    def _check_fitted(self):
        """Raise a clear error if no bounds have been set yet."""
        if self.lat_step is None or self.lon_step is None:
            raise RuntimeError(
                "GridSystem has no bounds: build it from a dataframe or call load_state() first."
            )

    def save_state(self):
        """Return a plain dict (floats/ints only) describing the grid."""
        self._check_fitted()
        return {
            'bounds': (self.min_lat, self.max_lat, self.min_lon, self.max_lon),
            'steps': (self.lat_step, self.lon_step),
            'grid_size': self.grid_size
        }

    def load_state(self, state):
        """Restore a grid from ``save_state()`` output (NumPy scalars also accepted)."""
        self.min_lat, self.max_lat, self.min_lon, self.max_lon = (float(v) for v in state['bounds'])
        self.lat_step, self.lon_step = (float(v) for v in state['steps'])
        self.grid_size = state['grid_size']

    def coords_to_grid(self, lat, lon):
        """Map a coordinate to its tile index; points outside the bounds snap to the border tile."""
        self._check_fitted()
        # Calculate row and col
        row = int((lat - self.min_lat) / self.lat_step)
        col = int((lon - self.min_lon) / self.lon_step)
        # Clip to ensure validity
        row = max(0, min(row, self.grid_size - 1))
        col = max(0, min(col, self.grid_size - 1))
        # Convert 2D -> 1D class index
        return row * self.grid_size + col

    def grid_to_coords(self, grid_idx):
        """Return the (lat, lon) centre of tile ``grid_idx``."""
        self._check_fitted()
        # Convert 1D -> 2D
        row = grid_idx // self.grid_size
        col = grid_idx % self.grid_size
        # Return center of the tile
        center_lat = self.min_lat + (row + 0.5) * self.lat_step
        center_lon = self.min_lon + (col + 0.5) * self.lon_step
        return center_lat, center_lon

# ==========================================
# 3. DATASET CLASS
# ==========================================
class GeoGuessrDataset(Dataset):
    """Dataset that stitches a sample's four street-view images into one 2x2 image.

    For each row, ``img_{sample_id:06d}_{north,east,south,west}.jpg`` are read
    from ``img_dir`` and pasted as North|East over West|South. A view whose
    file is missing is replaced by a black image sized like the views that did
    load (256x256 if none loaded).

    Args:
        dataframe: rows with a ``sample_id`` column; when ``is_test`` is False
            also ``state_idx``, ``latitude`` and ``longitude``.
        img_dir: directory holding the image files.
        grid_system: object whose ``coords_to_grid(lat, lon)`` yields the grid
            class label for training/validation rows.
        transform: optional callable applied to the stitched PIL image.
        is_test: if True, items are ``(image, sample_id)``; otherwise
            ``(image, state_idx, grid_label)``.
    """

    DIRECTIONS = ('north', 'east', 'south', 'west')
    # Placeholder size used only when none of a sample's views could be loaded.
    FALLBACK_SIZE = (256, 256)

    def __init__(self, dataframe, img_dir, grid_system, transform=None, is_test=False):
        self.data = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.grid_system = grid_system

    def __len__(self):
        return len(self.data)

    def _load_views(self, sample_id):
        """Load the views in DIRECTIONS order, using black images for missing files."""
        views = []
        for d in self.DIRECTIONS:
            img_path = os.path.join(self.img_dir, f"img_{sample_id:06d}_{d}.jpg")
            try:
                views.append(Image.open(img_path).convert('RGB'))
            except FileNotFoundError:
                views.append(None)
        # Size placeholders like the real views so the paste offsets (taken
        # from the north view) stay consistent with the other tiles.
        size = next((v.size for v in views if v is not None), self.FALLBACK_SIZE)
        return [v if v is not None else Image.new('RGB', size) for v in views]

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # A row taken from an all-numeric frame comes back as float64, and
        # f"{5.0:06d}" raises ValueError, so coerce the id to int first.
        sample_id = int(row['sample_id'])

        north, east, south, west = self._load_views(sample_id)

        # 2x2 Stitch
        w, h = north.size
        stitched_img = Image.new('RGB', (w * 2, h * 2))
        stitched_img.paste(north, (0, 0))
        stitched_img.paste(east, (w, 0))
        stitched_img.paste(west, (0, h))
        stitched_img.paste(south, (w, h))

        if self.transform:
            stitched_img = self.transform(stitched_img)

        if self.is_test:
            return stitched_img, sample_id

        state_idx = int(row['state_idx'])
        # Convert Lat/Lon to Grid Class ID
        grid_label = self.grid_system.coords_to_grid(row['latitude'], row['longitude'])
        return stitched_img, state_idx, grid_label

# Transforms
# Preprocessing for the stitched 2x2 image: resize to 224x224, convert to a
# tensor, then normalise with the mean/std values conventionally used with
# torchvision's ImageNet-pretrained backbones.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ==========================================
# 4. MODEL DEFINITION (Improved Heads)
# ==========================================
class DeepGeoNet(nn.Module):
    """ResNet-18 backbone with two deep classification heads: state and grid tile.

    Args:
        num_grid_classes: number of grid tiles (e.g. 25 * 25 = 625).
        num_states: number of state classes (default 50, the value the head
            was originally hard-coded to, so existing checkpoints still load).

    ``forward(x)`` returns ``(state_logits, grid_logits)``.
    """

    def __init__(self, num_grid_classes, num_states=50):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13;
        # ResNet18_Weights.DEFAULT selects the same ImageNet weights explicitly.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone emits pooled features.
        self.backbone.fc = nn.Identity()

        # Head 1: State Classification
        self.cls_head = self._make_head(in_features, num_states)
        # Head 2: Grid Classification (Paper's preferred method)
        self.grid_head = self._make_head(in_features, num_grid_classes)

    @staticmethod
    def _make_head(in_features, num_outputs, dropout=0.5):
        """Improved head from the paper: Dense(1024) -> Drop -> Dense(512) -> Drop -> Dense(100) -> Drop -> Output.

        The layer order, and hence state_dict keys such as ``cls_head.0.weight``,
        matches the original inline definition, so existing checkpoints load.
        """
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 100),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(100, num_outputs),
        )

    def forward(self, x):
        features = self.backbone(x)
        state_logits = self.cls_head(features)
        grid_logits = self.grid_head(features)
        return state_logits, grid_logits

# ==========================================
# 5. CHECKPOINT HELPERS
# ==========================================
def save_checkpoint(model, optimizer, scheduler, grid_system, val_loss, epoch, filename):
    """Write a resumable training checkpoint to ``filename``.

    Alongside the model/optimizer/scheduler state dicts it stores
    ``grid_system.save_state()``, so the grid that encoded the labels can be
    restored to decode predictions later.
    """
    print(f"--> SAVING BEST MODEL (Val Loss: {val_loss:.4f}) to {filename}")
    payload = {}
    payload['epoch'] = epoch
    payload['model_state'] = model.state_dict()
    payload['optimizer_state'] = optimizer.state_dict()
    payload['scheduler_state'] = scheduler.state_dict()
    payload['grid_system_state'] = grid_system.save_state()  # grid definitions
    payload['best_val_loss'] = val_loss
    torch.save(payload, filename)

def load_checkpoint(model, optimizer, scheduler, grid_system, filename):
    """Restore a checkpoint written by ``save_checkpoint``.

    The grid definition (if present) is loaded into ``grid_system`` before the
    model weights. ``optimizer`` / ``scheduler`` may be None (e.g. for
    inference) to skip restoring them.

    Returns:
        ``(start_epoch, best_val_loss)``: the epoch after the saved one and
        the stored best loss, or ``(0, inf)`` if ``filename`` does not exist.
    """
    if os.path.isfile(filename):
        print(f"--> Loading checkpoint '{filename}'...")

        # weights_only=False is needed because the saved grid bounds may be
        # NumPy scalars (pandas .min()/.max() results). It unpickles arbitrary
        # objects, so only load checkpoints you created yourself.
        checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)

        # Load Grid System first to ensure model matches
        if 'grid_system_state' in checkpoint:
            grid_system.load_state(checkpoint['grid_system_state'])

        model.load_state_dict(checkpoint['model_state'])

        if optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        if scheduler and 'scheduler_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state'])

        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        print(f"--> Loaded. Resuming from Epoch {start_epoch} with Best Val Loss {best_val_loss:.4f}")
        return start_epoch, best_val_loss
    else:
        print(f"--> No checkpoint found at '{filename}'. Starting fresh.")
        return 0, float('inf')

# ==========================================
# 6. SETUP & DATA PREP
# ==========================================
if INFERENCE_ONLY:
    # No labelled data here: the grid bounds are expected to be restored from
    # the checkpoint's grid_system_state when it is loaded.
    grid_system = GridSystem(grid_size=GRID_SIZE)
else:
    full_df = pd.read_csv(TRAIN_CSV)

    # Grid bounds come from ALL labelled rows (train + val), before splitting.
    grid_system = GridSystem(full_df, grid_size=GRID_SIZE)

    train_df, val_df = train_test_split(
        full_df, test_size=VALIDATION_SPLIT, random_state=42
    )
    print(f"Data Split: {len(train_df)} Training, {len(val_df)} Validation")

    train_dataset = GeoGuessrDataset(train_df, TRAIN_IMG_DIR, grid_system, transform=transform)
    val_dataset = GeoGuessrDataset(val_df, TRAIN_IMG_DIR, grid_system, transform=transform)

    # Shared loader settings; only shuffling differs between the splits.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ==========================================
# 7. INITIALIZATION
# ==========================================
num_grid_classes = GRID_SIZE * GRID_SIZE
model = DeepGeoNet(num_grid_classes=num_grid_classes).to(DEVICE)

# Both tasks are now Classification (CrossEntropy)
criterion_cls = nn.CrossEntropyLoss()
criterion_grid = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the LR when the validation loss plateaus (patience=2 epochs).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

current_best_val_loss = float('inf')
start_epoch = 0

if START_TRAINING:
    print("--- MODE: STARTING FRESH TRAINING ---")
    if os.path.exists(MODEL_SAVE_PATH):
        os.remove(MODEL_SAVE_PATH)

elif CONTINUE_TRAINING:
    print("--- MODE: CONTINUING TRAINING ---")
    start_epoch, current_best_val_loss = load_checkpoint(model, optimizer, scheduler, grid_system, MODEL_SAVE_PATH)

elif INFERENCE_ONLY:
    print("--- MODE: INFERENCE ONLY ---")
    # load_checkpoint() falls back to "starting fresh" when the file is
    # absent, which would silently produce a submission from an untrained model.
    if not os.path.isfile(MODEL_SAVE_PATH):
        raise FileNotFoundError(
            f"INFERENCE_ONLY requires a trained checkpoint, but '{MODEL_SAVE_PATH}' does not exist."
        )
    # We must load the grid system state to know how to decode the predictions
    load_checkpoint(model, None, None, grid_system, MODEL_SAVE_PATH)
    if getattr(grid_system, 'lat_step', None) is None:
        raise RuntimeError(
            f"Checkpoint '{MODEL_SAVE_PATH}' has no grid_system_state; cannot decode grid predictions."
        )

# ==========================================
# 8. TRAINING LOOP
# ==========================================
if not INFERENCE_ONLY:
    print(f"Starting training loop from epoch {start_epoch+1}...")

    for epoch in range(start_epoch, start_epoch + EPOCHS):
        # --- TRAIN ---
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")

        for images, states, grid_targets in progress_bar:
            images, states, grid_targets = images.to(DEVICE), states.to(DEVICE), grid_targets.to(DEVICE)

            optimizer.zero_grad()
            state_logits, grid_logits = model(images)

            # Loss: State Classification + Grid Classification, equally weighted
            loss_cls = criterion_cls(state_logits, states)
            loss_grid = criterion_grid(grid_logits, grid_targets)
            loss = loss_cls + loss_grid

            loss.backward()
            optimizer.step()

            # .item() synchronises with the device; fetch the scalar only once.
            batch_loss = loss.item()
            train_loss += batch_loss
            progress_bar.set_postfix({'loss': batch_loss})

        avg_train_loss = train_loss / len(train_loader)

        # --- VALIDATE ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, states, grid_targets in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]"):
                images, states, grid_targets = images.to(DEVICE), states.to(DEVICE), grid_targets.to(DEVICE)

                state_logits, grid_logits = model(images)

                loss_cls = criterion_cls(state_logits, states)
                loss_grid = criterion_grid(grid_logits, grid_targets)
                val_loss += (loss_cls + loss_grid).item()

        avg_val_loss = val_loss / len(val_loader)

        # Log the LR in effect this epoch so ReduceLROnPlateau reductions are visible.
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.6f}")

        scheduler.step(avg_val_loss)

        if avg_val_loss < current_best_val_loss:
            current_best_val_loss = avg_val_loss
            save_checkpoint(model, optimizer, scheduler, grid_system, current_best_val_loss, epoch, MODEL_SAVE_PATH)
        else:
            print(f"Val Loss did not improve (Best: {current_best_val_loss:.4f}).")

    print("\nTraining complete.")

# ==========================================
# 9. SUBMISSION GENERATOR
# ==========================================
if INFERENCE_ONLY:
    print("\nGenerating submission using the BEST saved model...")
    sub_df = pd.read_csv(SAMPLE_SUB_CSV)
    test_dataset = GeoGuessrDataset(sub_df, TEST_IMG_DIR, grid_system, transform=transform, is_test=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    model.eval()
    results = []

    with torch.no_grad():
        for images, sample_ids in tqdm(test_loader, desc="Testing"):
            state_logits, grid_logits = model(images.to(DEVICE))

            # Move each batch's predictions to the CPU once rather than per sample.
            # 1. State Prediction (Top 5, highest score first)
            top5 = torch.topk(state_logits, k=5, dim=1).indices.cpu().numpy()
            # 2. Grid Prediction (Top 1)
            grid_ids = grid_logits.argmax(dim=1).cpu().tolist()

            for sid, states, grid_idx in zip(sample_ids.tolist(), top5, grid_ids):
                # Decoder: Grid ID -> Center Lat/Lon of the tile
                pred_lat, pred_lon = grid_system.grid_to_coords(grid_idx)
                results.append({
                    'sample_id': sid,
                    'predicted_state_idx_1': states[0],
                    'predicted_state_idx_2': states[1],
                    'predicted_state_idx_3': states[2],
                    'predicted_state_idx_4': states[3],
                    'predicted_state_idx_5': states[4],
                    'predicted_latitude': pred_lat,
                    'predicted_longitude': pred_lon,
                })

    # Save to CSV (keeps the template's rows and order). Predictions are
    # aligned by sample_id with one vectorised lookup per column instead of a
    # boolean mask per row (O(n^2)).
    final_df = sub_df.copy()
    if results:
        pred_df = (
            pd.DataFrame(results)
            .drop_duplicates('sample_id', keep='last')
            .set_index('sample_id')
        )
        for col in pred_df.columns:
            mapped = final_df['sample_id'].map(pred_df[col])
            if col in final_df.columns:
                # Samples without a prediction keep the template's value.
                mapped = mapped.where(mapped.notna(), final_df[col])
            final_df[col] = mapped

    final_df.to_csv('submission_2.csv', index=False)
    print(f"Submission saved to 'submission_2.csv' with {len(final_df)} rows.")

Apple Silicon (MPS) detected. Using GPU acceleration.
Data Split: 52784 Training, 13196 Validation




--- MODE: CONTINUING TRAINING ---
--> Loading checkpoint 'best_model_2.pth'...
Starting training loop from epoch 9...


Epoch 9 [Train]: 100%|██████████| 1650/1650 [08:32<00:00,  3.22it/s, loss=4.64]
Epoch 9 [Val]: 100%|██████████| 413/413 [01:04<00:00,  6.40it/s]



Epoch 9: Train Loss: 3.6865 | Val Loss: 4.9491
Val Loss did not improve (Best: 4.6096).


Epoch 10 [Train]: 100%|██████████| 1650/1650 [08:04<00:00,  3.40it/s, loss=4.47]
Epoch 10 [Val]: 100%|██████████| 413/413 [01:02<00:00,  6.65it/s]



Epoch 10: Train Loss: 3.4444 | Val Loss: 5.0725
Val Loss did not improve (Best: 4.6096).


Epoch 11 [Train]: 100%|██████████| 1650/1650 [08:19<00:00,  3.30it/s, loss=3.21]
Epoch 11 [Val]: 100%|██████████| 413/413 [01:23<00:00,  4.94it/s]



Epoch 11: Train Loss: 3.2074 | Val Loss: 5.0177
Val Loss did not improve (Best: 4.6096).


Epoch 12 [Train]: 100%|██████████| 1650/1650 [08:14<00:00,  3.33it/s, loss=4.38]
Epoch 12 [Val]: 100%|██████████| 413/413 [01:22<00:00,  5.00it/s]



Epoch 12: Train Loss: 2.6384 | Val Loss: 5.0589
Val Loss did not improve (Best: 4.6096).


Epoch 13 [Train]: 100%|██████████| 1650/1650 [08:23<00:00,  3.28it/s, loss=2.53]
Epoch 13 [Val]: 100%|██████████| 413/413 [01:19<00:00,  5.17it/s]



Epoch 13: Train Loss: 2.3351 | Val Loss: 5.5350
Val Loss did not improve (Best: 4.6096).


Epoch 14 [Train]: 100%|██████████| 1650/1650 [08:13<00:00,  3.34it/s, loss=2.73]
Epoch 14 [Val]: 100%|██████████| 413/413 [01:17<00:00,  5.31it/s]



Epoch 14: Train Loss: 2.1336 | Val Loss: 5.8574
Val Loss did not improve (Best: 4.6096).


Epoch 15 [Train]: 100%|██████████| 1650/1650 [08:10<00:00,  3.37it/s, loss=1.56]
Epoch 15 [Val]: 100%|██████████| 413/413 [01:11<00:00,  5.74it/s]



Epoch 15: Train Loss: 1.8389 | Val Loss: 6.1133
Val Loss did not improve (Best: 4.6096).


Epoch 16 [Train]: 100%|██████████| 1650/1650 [08:24<00:00,  3.27it/s, loss=2.7]  
Epoch 16 [Val]: 100%|██████████| 413/413 [01:10<00:00,  5.89it/s]



Epoch 16: Train Loss: 1.6452 | Val Loss: 6.7881
Val Loss did not improve (Best: 4.6096).


Epoch 17 [Train]: 100%|██████████| 1650/1650 [08:19<00:00,  3.30it/s, loss=1.81] 
Epoch 17 [Val]: 100%|██████████| 413/413 [01:24<00:00,  4.90it/s]



Epoch 17: Train Loss: 1.5271 | Val Loss: 7.0991
Val Loss did not improve (Best: 4.6096).


Epoch 18 [Train]: 100%|██████████| 1650/1650 [08:11<00:00,  3.35it/s, loss=1.78] 
Epoch 18 [Val]: 100%|██████████| 413/413 [01:24<00:00,  4.90it/s]


Epoch 18: Train Loss: 1.3927 | Val Loss: 7.5259
Val Loss did not improve (Best: 4.6096).

Training complete.





In [19]:
# GeoGuessr CS-GY 6643 pipeline: train / resume / infer
# Modes:
#   "train"   -> train from scratch
#   "resume"  -> resume training
#   "infer"   -> run inference only and create submission.csv

MODE = "resume"   # change to "resume" or "infer"

DATA_ROOT = "./data/kaggle_dataset"
TRAIN_CSV = f"{DATA_ROOT}/train_ground_truth.csv"
SAMPLE_SUB_CSV = f"{DATA_ROOT}/sample_submission.csv"
CHECKPOINT_DIR = "./checkpoints"
BEST_CKPT_PATH = f"{CHECKPOINT_DIR}/best_model.pth"
SUBMISSION_OUT = "./submission_gpt.csv"

NUM_EPOCHS = 20
BATCH_SIZE = 64
LR = 3e-4
VAL_SPLIT = 0.1
WEIGHT_REG = 0.5
NUM_WORKERS = 0
RANDOM_SEED = 42

import os
# Create the checkpoint directory up front so torch.save(..., BEST_CKPT_PATH) can write into it.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

import time
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import torchvision.models as models

# Pick the accelerator: Apple MPS first, then CUDA, otherwise fall back to CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Using device:", DEVICE)


def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
# Seed once, before the data split and model are built: random_split below
# draws from the global torch RNG when no generator is passed.
set_seed(RANDOM_SEED)

# ============================================================
# Dataset
# ============================================================

class GeoDataset(Dataset):
    """Multi-view dataset: each item stacks one row's four directional images.

    Every row of ``df`` names four image files (columns ``image_north``,
    ``image_east``, ``image_south``, ``image_west``) relative to ``img_dir``.
    ``__getitem__`` loads them in that order and stacks them along a new
    leading dimension, so the views tensor has shape ``[4, *image_shape]``.

    Args:
        df: table of samples; its index is reset so positional indexing works.
        img_dir: directory containing the image files.
        transform: optional callable applied to each RGB ``PIL.Image``; it must
            return a tensor because the views are combined with ``torch.stack``.
        state2internal: mapping from raw ``state_idx`` values to contiguous
            class indices; only used (and then required) when ``with_labels``.
        with_labels: if True, items are ``(views, class_idx, coords)`` with
            ``coords = tensor([latitude, longitude], dtype=float32)``; otherwise
            ``(views, sample_id, north, east, south, west)`` with the filenames.
    """

    # Stacking order of the four views (also the order of the returned filenames).
    _VIEW_COLUMNS = ("image_north", "image_east", "image_south", "image_west")

    def __init__(self, df, img_dir, transform=None, state2internal=None, with_labels=True):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.with_labels = with_labels
        self.state2internal = state2internal

    def load_img(self, fname):
        """Load ``img_dir/fname`` as RGB and apply ``transform`` if one is set."""
        path = os.path.join(self.img_dir, fname)
        # Close the file handle deterministically; convert() returns a new,
        # fully-loaded image that remains valid after the file is closed.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        views = torch.stack([self.load_img(row[col]) for col in self._VIEW_COLUMNS], dim=0)

        if self.with_labels:
            # Remap the raw state id to the contiguous class index of the classifier.
            st = self.state2internal[int(row["state_idx"])]
            coords = torch.tensor([row["latitude"], row["longitude"]], dtype=torch.float32)
            return views, st, coords

        return (views, int(row["sample_id"]), *(row[col] for col in self._VIEW_COLUMNS))

# ============================================================
# Model
# ============================================================

class GeoModel(nn.Module):
    """Shared-encoder multi-view model with a classification and a regression head.

    A pretrained ResNet-18 (final ``fc`` replaced by ``Identity``) encodes each
    view independently; per-view features are averaged and fed to
    ``class_head`` (``num_classes`` logits) and ``reg_head`` (2 outputs, trained
    against [latitude, longitude]).
    """

    def __init__(self, num_classes):
        super().__init__()
        base = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        fd = base.fc.in_features
        base.fc = nn.Identity()  # expose the pooled backbone features
        self.encoder = base
        self.class_head = nn.Linear(fd, num_classes)
        self.reg_head = nn.Linear(fd, 2)

    def forward(self, views):
        """Map ``views`` of shape [B, V, C, H, W] to ``(logits [B, K], coords [B, 2])``."""
        B, V, C, H, W = views.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches, also work.
        f = self.encoder(views.reshape(B * V, C, H, W))  # [B*V, F]
        f = f.reshape(B, V, -1).mean(dim=1)              # [B, F], average over views
        return self.class_head(f), self.reg_head(f)

# ============================================================
# Haversine
# ============================================================

def haversine(lat1, lon1, lat2, lon2):
    """Great-circle distance in kilometres between points given in degrees.

    Arguments are tensors that broadcast against each other; the result has
    the broadcast shape. Uses a mean Earth radius of 6371 km.
    """
    R = 6371
    lat1 = torch.deg2rad(lat1)
    lat2 = torch.deg2rad(lat2)
    lon1 = torch.deg2rad(lon1)
    lon2 = torch.deg2rad(lon2)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Mathematically 0 <= a <= 1, but float rounding (notably for near-antipodal
    # points) can push it marginally outside, turning a sqrt below into NaN.
    a = a.clamp(0.0, 1.0)
    return 2 * R * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))

# ============================================================
# Train helpers
# ============================================================

def train_one_epoch(model, loader, optim, ce, regr):
    """Run one optimisation pass over ``loader`` and report epoch averages.

    Each batch is moved to ``DEVICE`` and optimised on the joint objective
    ``ce(logits, labels) + WEIGHT_REG * regr(pred, coords)``.

    Args:
        model: network returning ``(logits, coord_pred)``.
        loader: yields ``(views, labels, coords)`` batches.
        optim: optimizer over ``model``'s parameters.
        ce: classification loss.
        regr: coordinate regression loss.

    Returns:
        ``(loss, cls_loss, reg_loss, accuracy)``, each a per-sample mean.
        An empty ``loader`` raises ZeroDivisionError.
    """
    model.train()
    n_seen = 0
    n_right = 0
    sum_total = sum_cls = sum_reg = 0

    for views, labels, coords in tqdm(loader, desc="Train", leave=False):
        views, labels, coords = views.to(DEVICE), labels.to(DEVICE), coords.to(DEVICE)
        batch_size = views.size(0)

        optim.zero_grad()
        logits, pred = model(views)
        cls_loss = ce(logits, labels)
        reg_loss = regr(pred, coords)
        loss = cls_loss + WEIGHT_REG * reg_loss
        loss.backward()
        optim.step()

        # Losses are batch means; weight by batch size for a per-sample epoch mean.
        sum_total += loss.item() * batch_size
        sum_cls += cls_loss.item() * batch_size
        sum_reg += reg_loss.item() * batch_size
        n_right += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_size

    return sum_total / n_seen, sum_cls / n_seen, sum_reg / n_seen, n_right / n_seen

def eval_epoch(model, loader, ce, regr):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network returning ``(logits, coord_pred)``.
        loader: yields ``(views, labels, coords)`` batches.
        ce: classification loss.
        regr: coordinate regression loss.

    Returns:
        ``(loss, cls_loss, reg_loss, accuracy, mean_km)`` as per-sample means,
        where ``loss = cls_loss + WEIGHT_REG * reg_loss`` and ``mean_km`` is
        the mean haversine distance between true and predicted coordinates.
        All five values are NaN when ``loader`` yields no samples (e.g. a
        validation split that rounded down to zero) instead of raising
        ZeroDivisionError.
    """
    model.eval()
    total, correct = 0, 0
    loss_sum = cls_sum = reg_sum = 0.0
    dists = []

    with torch.no_grad():
        for views, labels, coords in tqdm(loader, desc="Val", leave=False):
            views = views.to(DEVICE)
            labels = labels.to(DEVICE)
            coords = coords.to(DEVICE)
            n = views.size(0)

            logits, pred = model(views)

            loss_cls = ce(logits, labels)
            loss_reg = regr(pred, coords)
            loss = loss_cls + WEIGHT_REG * loss_reg

            # Losses are batch means; weight by batch size for a per-sample mean.
            loss_sum += loss.item() * n
            cls_sum += loss_cls.item() * n
            reg_sum += loss_reg.item() * n

            correct += (logits.argmax(1) == labels).sum().item()
            total += n

            dists.append(haversine(coords[:, 0], coords[:, 1], pred[:, 0], pred[:, 1]).cpu())

    if total == 0:
        nan = float("nan")
        return nan, nan, nan, nan, nan

    mean_km = torch.cat(dists).mean().item()
    return loss_sum / total, cls_sum / total, reg_sum / total, correct / total, mean_km

# ============================================================
# Data setup
# ============================================================

train_df = pd.read_csv(TRAIN_CSV)

# Map raw state ids (which need not be contiguous) to class indices 0..K-1.
states = sorted(train_df.state_idx.unique())
state2internal = {s:i for i,s in enumerate(states)}
internal2state = {i:s for s,i in state2internal.items()}
num_classes = len(states)

# Deterministic preprocessing only (no augmentation); ImageNet mean/std
# normalisation for the pretrained ResNet encoder.
transform = T.Compose([
    T.Resize((256,256)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

full_ds = GeoDataset(train_df, f"{DATA_ROOT}/train_images", transform, state2internal)
val_sz = int(len(full_ds)*VAL_SPLIT)
train_sz = len(full_ds) - val_sz

# Use a dedicated, explicitly seeded generator so the train/val split is the
# same on every run, regardless of other draws from the global RNG. "resume"
# relies on this: a different split would leak validation samples into
# training. A fresh generator seeded with RANDOM_SEED matches the global RNG
# right after torch.manual_seed(RANDOM_SEED), so previously saved splits are
# reproduced when nothing else consumed the global RNG first.
split_gen = torch.Generator().manual_seed(RANDOM_SEED)
train_ds, val_ds = random_split(full_ds, [train_sz, val_sz], generator=split_gen)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ============================================================
# Model + optim
# ============================================================

model = GeoModel(num_classes).to(DEVICE)
optim = torch.optim.AdamW(model.parameters(), lr=LR)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=NUM_EPOCHS)
ce = nn.CrossEntropyLoss()
regr = nn.SmoothL1Loss()

start_epoch = 0
best_val = float("inf")

# Resume / load weights from the best checkpoint.
if MODE in ["resume","infer"]:
    if os.path.exists(BEST_CKPT_PATH):
        # weights_only=False unpickles arbitrary Python objects (optimizer and
        # scheduler state, the `states` list): only load checkpoints you created.
        ck = torch.load(BEST_CKPT_PATH, map_location=DEVICE, weights_only=False)
        model.load_state_dict(ck["model"])
        if MODE == "resume":
            # Continuing with a different class->state mapping would silently
            # corrupt the labels, so refuse instead.
            saved_states = ck.get("states")
            if saved_states is not None and list(saved_states) != list(states):
                raise ValueError(
                    f"Checkpoint {BEST_CKPT_PATH} was trained with a different set/order "
                    "of states than the current training CSV; cannot resume."
                )
            optim.load_state_dict(ck["optim"])
            sched.load_state_dict(ck["sched"])
            start_epoch = ck["epoch"]+1  # saved epoch is 0-based and already completed
            best_val = ck["best"]
            print("Resumed at epoch", start_epoch)
        else:
            print("Loaded model for inference")
    elif MODE == "infer":
        raise FileNotFoundError(
            f"MODE='infer' needs a trained checkpoint, but {BEST_CKPT_PATH} does not exist"
        )
    else:
        print(f"No checkpoint found at {BEST_CKPT_PATH}; training from scratch")

# ============================================================
# Training
# ============================================================

# Training loop: runs epochs start_epoch .. NUM_EPOCHS-1. After each epoch the
# model is evaluated on the held-out split, the LR scheduler is stepped, and a
# checkpoint is written only when the combined validation loss improves.
# NOTE(review): only the *best* checkpoint is saved, so "resume" restarts from
# the best epoch (ck["epoch"] + 1) with that epoch's optimizer/scheduler state;
# any epochs trained after it (or lost to an interrupt) are repeated.
if MODE in ["train","resume"]:
    print("Training begins")
    for ep in range(start_epoch, NUM_EPOCHS):
        tl, tcls, treg, tacc = train_one_epoch(model, train_loader, optim, ce, regr)
        vl, vcls, vreg, vacc, vdist = eval_epoch(model, val_loader, ce, regr)
        # CosineAnnealingLR (T_max=NUM_EPOCHS) is stepped once per epoch.
        sched.step()

        print(f"Epoch {ep+1}/{NUM_EPOCHS} | "
              f"train_loss={tl:.4f} acc={tacc:.4f} | "
              f"val_loss={vl:.4f} acc={vacc:.4f} dist={vdist:.1f} km")

        # Model selection uses the combined loss (cls + WEIGHT_REG * reg),
        # not accuracy or distance.
        if vl < best_val:
            best_val = vl
            # `epoch` is 0-based; resume continues from epoch + 1.
            # `states` records the class-index -> state_idx order so inference
            # can decode predictions with the same mapping.
            torch.save({
                "epoch": ep,
                "model": model.state_dict(),
                "optim": optim.state_dict(),
                "sched": sched.state_dict(),
                "best": best_val,
                "states": states
            }, BEST_CKPT_PATH)
            print("  Saved new best model\n")

# ============================================================
# Inference
# ============================================================
if MODE in ["infer","train","resume"]:
    print("Running inference...")

    # Always predict with the best saved weights, not whatever the training loop
    # ended on. weights_only=False unpickles arbitrary Python objects: only load
    # checkpoints produced by this notebook.
    ck = torch.load(BEST_CKPT_PATH, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ck["model"])
    # Decode class indices with the mapping the checkpoint was trained with.
    states = ck["states"]
    internal2state = {i:s for i,s in enumerate(states)}

    sub = pd.read_csv(SAMPLE_SUB_CSV)
    test_df = sub[["sample_id","image_north","image_east","image_south","image_west"]].copy()

    test_ds = GeoDataset(
        test_df,
        f"{DATA_ROOT}/test_images",
        transform,
        state2internal,
        with_labels=False
    )
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    preds = {}
    model.eval()

    for views, sids, *_ in tqdm(test_loader, desc="Infer"):
        with torch.no_grad():
            logits, pred = model(views.to(DEVICE))
            probs = torch.softmax(logits, dim=1)
            # NOTE(review): topk(5) assumes at least 5 state classes.
            top5 = probs.topk(5, dim=1).indices.cpu().tolist()
        # Copy the coordinate predictions to the host once per batch; calling
        # .item() on each element forces a device sync for every value.
        coords = pred.cpu().tolist()

        for sid, ints, (lat, lon) in zip(sids, top5, coords):
            mapped = [internal2state[j] for j in ints]
            preds[int(sid)] = {
                "state1": mapped[0],
                "state2": mapped[1],
                "state3": mapped[2],
                "state4": mapped[3],
                "state5": mapped[4],
                # Clamp to the valid latitude/longitude ranges.
                "lat": max(-90,min(90,lat)),
                "lon": max(-180,min(180,lon)),
            }

    out = sub.copy()
    # Submission column -> key in the per-sample prediction dict.
    column_keys = {
        "predicted_state_idx_1": "state1",
        "predicted_state_idx_2": "state2",
        "predicted_state_idx_3": "state3",
        "predicted_state_idx_4": "state4",
        "predicted_state_idx_5": "state5",
        "predicted_latitude": "lat",
        "predicted_longitude": "lon",
    }
    for col, key in column_keys.items():
        out[col] = out["sample_id"].apply(lambda x, key=key: preds[x][key])

    out.to_csv(SUBMISSION_OUT, index=False)
    print("Saved:", SUBMISSION_OUT)

Using device: mps
Resumed at epoch 9
Training begins


Train:   0%|          | 0/928 [00:00<?, ?it/s]

Val:   0%|          | 0/104 [00:00<?, ?it/s]

Epoch 10/20 | train_loss=0.9672 acc=0.9305 | val_loss=1.5360 acc=0.8128 dist=332.0 km
  Saved new best model



Train:   0%|          | 0/928 [00:00<?, ?it/s]

KeyboardInterrupt: 