In [1]:
!pip install rasterio
!pip install pytorch_lightning
!pip install pytorch_tabular

Collecting rasterio
  Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4.3
Collecting pytorch_light

In [None]:
import os
import gc
import json
import glob
import joblib
import pickle
import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from rasterio.transform import from_bounds
from sklearn.preprocessing import StandardScaler

# Install pyproj on demand if it is missing from the runtime.
# Only ImportError triggers the install: a bare `except:` would also hide
# unrelated failures (and KeyboardInterrupt) behind a silent reinstall.
try:
    from pyproj import Transformer
except ImportError:
    print("Installing pyproj...")
    import subprocess
    import sys
    # Use the kernel's own interpreter so the package lands in the environment
    # this notebook is actually running in.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pyproj'])
    from pyproj import Transformer


import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorflow as tf

from pytorch_tabular import TabularModel
from pytorch_tabular.models import FTTransformerConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

In [2]:
# Mount Google Drive at /content/drive (Google Colab only). The model, TFRecord
# and output paths configured below all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# Select the torch compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [6]:
# VARIABLES
# Input/output locations on the mounted Google Drive.
MODEL_PATH = '/content/drive/MyDrive/Bathymetry/MODEL/PINN/base_model.keras'
TFRECORD_DIR = '/content/drive/MyDrive/Bathymetry/COASTALTFRecord'
OUTPUT_DIR = '/content/drive/MyDrive/Bathymetry/RESULT'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # create the results folder if it does not exist yet

PATCH_SIZE = 256  # side length (pixels) used to reshape each TFRecord patch
BATCH_SIZE = 2048  # Reduced for memory efficiency
# NOTE(review): TILE_CHUNK_SIZE and TFRECORD_CRS are not referenced by the cells
# visible in this notebook; the CRS written to the GeoTIFFs is read from each
# tile's mixer.json instead. Confirm whether they are still needed.
TILE_CHUNK_SIZE = 512  # Process image in 512x512 chunks
TFRECORD_CRS = 'EPSG:32651'
EXPECTED_BANDS = 8  # a patch is skipped unless it provides this many bands

# Predicted depths are clipped to [DEPTH_MIN, DEPTH_MAX] (reported in metres).
DEPTH_MIN = 0
DEPTH_MAX = 40.0

In [7]:
# =============================
# LOAD MODEL
# =============================

print("="*70)
print("LOADING BATHYMETRY MODEL")
print("="*70)

model = tf.keras.models.load_model(MODEL_PATH)
print(f"✓ Model loaded from: {MODEL_PATH}")
print(f"✓ Expected bands: {EXPECTED_BANDS}")

# The optional feature scaler is expected next to the model file. Building the
# path from the directory (rather than str.replace on the file name) keeps this
# correct even if the model file is renamed.
scaler_path = os.path.join(os.path.dirname(MODEL_PATH), 'scaler.pkl')

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load a scaler.pkl that you produced yourself or otherwise trust.
try:
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    print(f"✓ Scaler loaded")
except FileNotFoundError:
    # A missing scaler is an accepted configuration: predictions use raw values.
    scaler = None
    print("⚠ No scaler found, will use raw values")
except Exception as exc:
    # Still best-effort, but say WHY the scaler was dropped: a corrupt or
    # incompatible pickle silently changes the model inputs otherwise.
    scaler = None
    print(f"⚠ Scaler at {scaler_path} could not be loaded ({exc!r}), will use raw values")

model.summary()

LOADING BATHYMETRY MODEL
✓ Model loaded from: /content/drive/MyDrive/Bathymetry/MODEL/PINN/base_model.keras
✓ Expected bands: 8
⚠ No scaler found, will use raw values


In [8]:
# =============================
# TFRECORD PROCESSING (BATHYMETRY VERSION)
# =============================

def process_tfrecord_streaming_bathymetry(tfrecord_base_pattern, model, scaler, tile_bounds, expected_bands=8):
    """
    Stream one tile's TFRecord shards through the depth model, patch by patch.

    Modified from the classification version: the output is a depth raster
    instead of a probability raster.

    Shards are discovered as ``{tfrecord_base_pattern}-*.tfrecord.gz`` and the
    patch grid is read from ``{tfrecord_base_pattern}-mixer.json``. Records are
    consumed in sorted shard order and laid out row-major on that grid.

    Reads the module-level constants PATCH_SIZE, BATCH_SIZE, DEPTH_MIN and
    DEPTH_MAX.

    Args:
        tfrecord_base_pattern: Path prefix shared by the tile's shards and its
            mixer.json.
        model: Object with a Keras-style ``predict(x, batch_size=..., verbose=0)``
            returning one value per input row.
        scaler: Object with ``transform(features)``, or None to feed raw band
            values to the model.
        tile_bounds: Four values unpacked as (min_lon, min_lat, max_lon, max_lat).
        expected_bands: Number of bands a patch must contain; patches with a
            different number of the known bands are skipped.

    Returns:
        ``(depth_map, confidence_map, actual_bounds)`` where both maps are
        float32 arrays of shape (rows * PATCH_SIZE, cols * PATCH_SIZE). Pixels
        that were not predicted are NaN in ``depth_map`` and 0 in
        ``confidence_map``. Returns ``(None, None, None)`` when no shard is
        found or mixer.json is missing or does not describe a usable grid.
    """
    import warnings  # hoisted: was re-imported for every patch inside the loop

    print(f"\nStreaming processing: {os.path.basename(tfrecord_base_pattern)}")

    # Find all files
    tfrecord_files = sorted(glob.glob(f"{tfrecord_base_pattern}-*.tfrecord.gz"))

    if not tfrecord_files:
        return None, None, None

    print(f"  Found {len(tfrecord_files)} file(s)")

    # Load mixer.json
    mixer_path = f"{tfrecord_base_pattern}-mixer.json"
    mixer = None
    if os.path.exists(mixer_path):
        with open(mixer_path, 'r') as f:
            mixer = json.load(f)
        print(f"  ✓ Loaded mixer.json")

    if not mixer or 'patchesPerRow' not in mixer:
        print("  ⚠️ No mixer.json found")
        return None, None, None

    # Get dimensions. A missing 'totalPatches' or a zero-sized grid used to
    # surface as KeyError / ZeroDivisionError further down; report it through
    # the same "could not process" return value as the other failure paths.
    patches_per_row = mixer['patchesPerRow']
    total_patches = mixer.get('totalPatches', 0)

    if patches_per_row <= 0 or total_patches <= 0:
        print("  ⚠️ mixer.json has no usable patch grid (patchesPerRow/totalPatches)")
        return None, None, None

    num_rows = int(np.ceil(total_patches / patches_per_row))

    print(f"  Grid: {num_rows} rows × {patches_per_row} cols ({total_patches} patches)")

    # Calculate output dimensions
    output_h = num_rows * PATCH_SIZE
    output_w = patches_per_row * PATCH_SIZE

    print(f"  Output size: {output_h} × {output_w} pixels")

    # Calculate bounds. The "adjusted" maxima are the input maxima recomputed
    # from the per-pixel size, i.e. equal to them up to floating-point rounding.
    min_lon, min_lat, max_lon, max_lat = tile_bounds

    actual_pixel_size_lon = (max_lon - min_lon) / output_w
    actual_pixel_size_lat = (max_lat - min_lat) / output_h

    actual_max_lon = min_lon + (output_w * actual_pixel_size_lon)
    actual_max_lat = min_lat + (output_h * actual_pixel_size_lat)

    actual_bounds = [min_lon, min_lat, actual_max_lon, actual_max_lat]

    print(f"  Adjusted bounds: [{actual_bounds[0]:.6f}, {actual_bounds[1]:.6f}, {actual_bounds[2]:.6f}, {actual_bounds[3]:.6f}]")

    # Initialize output arrays (BATHYMETRY: depth instead of probability)
    depth_map = np.full((output_h, output_w), np.nan, dtype=np.float32)
    confidence_map = np.zeros((output_h, output_w), dtype=np.float32)

    # Band names, in the order they are stacked into the model's feature vector
    band_names_ordered = ['B1', 'B2', 'B3', 'B4', 'B8', 'B8A', 'B11', 'B12']

    # Process patches. patch_idx counts every record read (including skipped
    # ones) because it determines the patch's position on the grid.
    patch_idx = 0
    total_valid_pixels = 0

    for file_idx, tfrecord_file in enumerate(tfrecord_files):
        print(f"  Processing file {file_idx+1}/{len(tfrecord_files)}")

        dataset = tf.data.TFRecordDataset(tfrecord_file, compression_type='GZIP')

        for raw_record in dataset:
            # Parse patch
            example = tf.train.Example()
            example.ParseFromString(raw_record.numpy())
            features = example.features.feature

            # Extract bands
            patch_data = {}
            for band_name in band_names_ordered:
                if band_name in features:
                    values = np.array(features[band_name].float_list.value)
                    patch_data[band_name] = values.reshape(PATCH_SIZE, PATCH_SIZE)

            if len(patch_data) != expected_bands:
                # Incomplete patch: leave its area as NaN / 0 but keep the grid
                # position counter in step with the record stream.
                patch_idx += 1
                continue

            # Stack bands -> (PATCH_SIZE, PATCH_SIZE, n_bands)
            patch = np.stack([patch_data[bn] for bn in band_names_ordered], axis=-1)

            # Calculate position (row-major layout)
            row_idx = patch_idx // patches_per_row
            col_idx = patch_idx % patches_per_row

            start_h = row_idx * PATCH_SIZE
            start_w = col_idx * PATCH_SIZE

            # Reshape to pixels: one row of band values per pixel
            pixels = patch.reshape(-1, expected_bands)
            n_pixels = len(pixels)

            # Find valid pixels: every band must be non-NaN and non-zero
            valid_mask = ~np.any(np.isnan(pixels) | (pixels == 0), axis=1)
            valid_indices = np.where(valid_mask)[0]
            n_valid = len(valid_indices)

            if n_valid > 0:
                total_valid_pixels += n_valid

                # Normalize
                if scaler is not None:
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore')
                        valid_features = scaler.transform(pixels[valid_indices])
                else:
                    valid_features = pixels[valid_indices]

                # BATHYMETRY: Predict depths instead of probabilities
                patch_depths = np.full(n_pixels, np.nan, dtype=np.float32)

                # One predict() call per patch; the model splits the input into
                # BATCH_SIZE chunks itself. This replaces a Python loop that
                # called predict() once per chunk, which pays predict()'s
                # per-call setup cost dozens of times per patch.
                predicted = model.predict(
                    valid_features, batch_size=BATCH_SIZE, verbose=0
                ).flatten()

                # Clip to valid range and scatter back to the pixel positions
                patch_depths[valid_indices] = np.clip(predicted, DEPTH_MIN, DEPTH_MAX)

                # Reshape and store
                patch_depth_map = patch_depths.reshape(PATCH_SIZE, PATCH_SIZE)
                depth_map[start_h:start_h+PATCH_SIZE, start_w:start_w+PATCH_SIZE] = patch_depth_map

                # Calculate confidence (simple: inverse normalized depth);
                # unpredicted (NaN) pixels get confidence 0.
                patch_confidence = 1.0 - (patch_depth_map - DEPTH_MIN) / (DEPTH_MAX - DEPTH_MIN)
                patch_confidence = np.nan_to_num(patch_confidence, nan=0.0)
                confidence_map[start_h:start_h+PATCH_SIZE, start_w:start_w+PATCH_SIZE] = patch_confidence

            # Clean up
            del patch, pixels

            patch_idx += 1

            # Progress
            if patch_idx % 50 == 0:
                progress = (patch_idx / total_patches) * 100
                print(f"    Progress: {progress:.1f}% ({patch_idx}/{total_patches})")

            # Memory cleanup
            if patch_idx % 20 == 0:
                gc.collect()

    print(f"  ✓ Processed all {patch_idx} patches")

    # Statistics
    valid_depths = depth_map[~np.isnan(depth_map)]

    if len(valid_depths) > 0:
        print(f"\n  Depth Statistics:")
        print(f"    Valid pixels: {total_valid_pixels:,}")
        print(f"    Min depth:    {valid_depths.min():.2f} m")
        print(f"    Max depth:    {valid_depths.max():.2f} m")
        print(f"    Mean depth:   {valid_depths.mean():.2f} m")
        print(f"    Median depth: {np.median(valid_depths):.2f} m")

    return depth_map, confidence_map, actual_bounds

In [9]:
# =============================
# SAVE GEOTIFF (SAME AS BEFORE)
# =============================

def save_geotiff_aligned(array, output_path, bounds, crs):
    """Save a 2-D array as a single-band, LZW-compressed GeoTIFF.

    Args:
        array: 2-D numpy array (height, width); written as band 1 with its own
            dtype.
        output_path: Destination file path.
        bounds: Four values unpacked as (min_x, min_y, max_x, max_y), expressed
            in the coordinates of `crs`. The array is mapped onto this
            rectangle.
        crs: Coordinate reference system passed through to rasterio
            (e.g. 'EPSG:4326').

    The nodata tag is NaN for float32 arrays and -9999 otherwise, except for
    integer dtypes that cannot represent -9999 (e.g. uint8), where no nodata
    value is set instead of requesting an out-of-range one.
    """
    h, w = array.shape

    min_x, min_y, max_x, max_y = bounds

    transform = rasterio.transform.from_bounds(
        min_x, min_y, max_x, max_y, w, h
    )

    if array.dtype == np.float32:
        nodata = np.nan
    elif np.issubdtype(array.dtype, np.integer) and np.iinfo(array.dtype).min > -9999:
        # -9999 does not fit this integer type; leave nodata unset.
        nodata = None
    else:
        nodata = -9999

    print(f"    Saving: {os.path.basename(output_path)}")
    print(f"      Size: {w} x {h}")
    print(f"      Bounds: {bounds}")

    with rasterio.open(
        output_path, 'w',
        driver='GTiff',
        height=h,
        width=w,
        count=1,
        dtype=array.dtype,
        crs=crs,
        transform=transform,
        compress='lzw',
        nodata=nodata
    ) as dst:
        dst.write(array, 1)

    print(f"    ✓ Saved")


# =============================
# READ TILE MIXER (SAME AS BEFORE)
# =============================

def read_tile_mixer(tile_base_path):
    """Read ``{tile_base_path}-mixer.json`` and derive the tile's georeferencing.

    Args:
        tile_base_path: Path prefix of the tile; the mixer file is looked up at
            ``{tile_base_path}-mixer.json``.

    Returns:
        dict with keys:
            'crs'        -- ``mixer['projection']['crs']``
            'mixer'      -- the parsed mixer.json content
            'patch_dims' -- ``mixer['patchDimensions']`` (default [256, 256])
            'bounds'     -- [min_x, min_y, max_x, max_y] of the full patch grid
            'grid_size'  -- (patches_per_row, patches_per_col)
            'pixel_size' -- (scale_x, abs(scale_y))

    Raises:
        FileNotFoundError: if the mixer file does not exist.
        KeyError: if the projection / affine entries are missing.
    """
    mixer_path = f"{tile_base_path}-mixer.json"

    if not os.path.exists(mixer_path):
        raise FileNotFoundError(f"mixer.json not found: {mixer_path}")

    with open(mixer_path, 'r') as f:
        mixer = json.load(f)

    # Extract info
    crs = mixer['projection']['crs']
    patch_dims = mixer.get('patchDimensions', [256, 256])
    patches_per_row = mixer.get('patchesPerRow', 0)
    total_patches = mixer.get('totalPatches', 0)

    # Ceiling division, so a partially filled last row still counts as a row.
    # This matches the row count the TFRecord processing step allocates for
    # its output raster (ceil(totalPatches / patchesPerRow)); floor division
    # would give bounds that are one patch row too short in that case.
    patches_per_col = -(-total_patches // patches_per_row) if patches_per_row > 0 else 0

    # Extract affine transform: [scale_x, shear, translate_x, shear, scale_y, translate_y]
    affine_matrix = mixer['projection']['affine']['doubleMatrix']

    scale_x = affine_matrix[0]
    translate_x = affine_matrix[2]
    scale_y = affine_matrix[4]
    translate_y = affine_matrix[5]

    # Calculate bounds
    patch_width_pixels = patch_dims[0]
    patch_height_pixels = patch_dims[1]

    total_width_pixels = patches_per_row * patch_width_pixels
    total_height_pixels = patches_per_col * patch_height_pixels

    # The translation is the grid's origin corner; scale_y is added as-is, so a
    # negative scale_y places min_y below max_y.
    min_x = translate_x
    max_y = translate_y
    max_x = min_x + (total_width_pixels * scale_x)
    min_y = max_y + (total_height_pixels * scale_y)

    bounds = [min_x, min_y, max_x, max_y]

    return {
        'crs': crs,
        'mixer': mixer,
        'patch_dims': patch_dims,
        'bounds': bounds,
        'grid_size': (patches_per_row, patches_per_col),
        'pixel_size': (scale_x, abs(scale_y))
    }

In [10]:
# =============================
# MAIN PROCESSING
# =============================

import re

print("\n" + "="*70)
print("DISCOVERING TILES")
print("="*70)

# Find tiles. Every shard is named "<tile base>-<shard>.tfrecord.gz", so
# stripping the last "-..." part and de-duplicating yields one entry per tile.
all_files = glob.glob(f"{TFRECORD_DIR}/*.tfrecord.gz")
tile_bases = sorted({
    os.path.join(TFRECORD_DIR, os.path.basename(file).rsplit('-', 1)[0])
    for file in all_files
})
print(f"\nFound {len(tile_bases)} unique tiles")

# Compiled once instead of importing `re` and re-parsing the pattern per tile.
tile_number_pattern = re.compile(r'tile[_-](\d+)')

# Read mixer.json for each tile; tiles whose mixer cannot be read are reported
# and left out of tile_info.
tile_info = []

for tile_base in tile_bases:
    basename = os.path.basename(tile_base)

    # Extract tile number (None when the name carries no "tile_<n>" part)
    match = tile_number_pattern.search(basename)
    tile_num = int(match.group(1)) if match else None

    try:
        mixer_data = read_tile_mixer(tile_base)

        tile_info.append({
            'base': tile_base,
            'number': tile_num,
            'bounds': mixer_data['bounds'],
            'crs': mixer_data['crs'],
            'mixer': mixer_data['mixer'],
            'grid_size': mixer_data['grid_size'],
            'pixel_size': mixer_data['pixel_size']
        })

        print(f"\n  Tile {tile_num}: {basename}")
        print(f"    CRS: {mixer_data['crs']}")
        print(f"    Grid: {mixer_data['grid_size'][0]} x {mixer_data['grid_size'][1]}")
        print(f"    Bounds: {mixer_data['bounds']}")

    except Exception as e:
        print(f"\n  ⚠ Error: {e}")

print(f"\n✓ Loaded {len(tile_info)} tiles with georeferencing")

if len(tile_info) == 0:
    raise ValueError("No valid tiles found!")


DISCOVERING TILES

Found 81 unique tiles

  Tile 28: S2_composite_2025_tile_028
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 7.720301214779991, 117.99353290604232, 8.249229254069565]

  Tile 29: S2_composite_2025_tile_029
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 8.220303501920917, 117.99353290604232, 8.749231541210492]

  Tile 30: S2_composite_2025_tile_030
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 8.720305789061843, 117.99353290604232, 9.249233828351418]

  Tile 31: S2_composite_2025_tile_031
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 9.220308076202768, 117.99353290604232, 9.749236115492343]

  Tile 32: S2_composite_2025_tile_032
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 9.720310363343694, 117.99353290604232, 10.249238402633269]

  Tile 33: S2_composite_2025_tile_033
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [117.46460486675275, 10.2203126504

In [None]:
# =============================
# PROCESS ALL TILES
# =============================

import traceback

print("\n" + "="*70)
print("PROCESSING TILES FOR BATHYMETRY")
print("="*70)

# One summary record per tile that produced at least one valid depth pixel.
results = []

for idx, tile_data in enumerate(tile_info):
    tile_base = tile_data['base']
    bounds = tile_data['bounds']
    tile_num = tile_data['number']
    crs = tile_data['crs']

    print(f"\n{'='*70}")
    print(f"TILE {tile_num} ({idx+1}/{len(tile_info)})")
    print(f"Base: {os.path.basename(tile_base)}")
    print(f"CRS: {crs}")
    print(f"Bounds: {bounds}")
    print('='*70)

    # Fall back to the file base name when no tile number could be parsed, so
    # several such tiles do not all write to (and overwrite) "tile_None_*".
    if tile_num is not None:
        tile_name = f"tile_{tile_num}"
    else:
        tile_name = os.path.basename(tile_base)

    fig = None  # tracked so the figure is closed even if a later step fails

    try:
        # Process tile
        depth_map, confidence_map, actual_bounds = process_tfrecord_streaming_bathymetry(
            tile_base, model, scaler, bounds, EXPECTED_BANDS
        )

        if depth_map is None:
            print("  ⚠️ Failed to process")
            continue

        # Use actual bounds
        if actual_bounds:
            bounds = actual_bounds

        # Statistics
        valid_depths = depth_map[~np.isnan(depth_map)]
        n_valid = len(valid_depths)

        if n_valid > 0:
            mean_depth = valid_depths.mean()
            min_depth = valid_depths.min()
            max_depth = valid_depths.max()
            median_depth = np.median(valid_depths)

            print(f"\n  Tile Statistics:")
            print(f"    Valid pixels: {n_valid:,}")
            print(f"    Mean depth: {mean_depth:.2f} m")
            print(f"    Min depth: {min_depth:.2f} m")
            print(f"    Max depth: {max_depth:.2f} m")

            results.append({
                'tile': tile_num,
                'valid_pixels': n_valid,
                'mean_depth_m': mean_depth,
                'min_depth_m': min_depth,
                'max_depth_m': max_depth,
                'median_depth_m': median_depth
            })

        # Save outputs
        print(f"\n  Saving outputs...")
        save_geotiff_aligned(
            depth_map,
            f"{OUTPUT_DIR}/{tile_name}_depth.tif",
            bounds,
            crs
        )
        save_geotiff_aligned(
            confidence_map,
            f"{OUTPUT_DIR}/{tile_name}_confidence.tif",
            bounds,
            crs
        )

        # Create visualization
        print(f"  Creating visualization...")

        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Depth map
        im1 = axes[0].imshow(depth_map, cmap='viridis_r', vmin=DEPTH_MIN, vmax=DEPTH_MAX)
        axes[0].set_title(f'Bathymetry (Tile {tile_num})', fontsize=14)
        axes[0].axis('off')
        cbar1 = fig.colorbar(im1, ax=axes[0], fraction=0.046)
        cbar1.set_label('Depth (m)', rotation=270, labelpad=20)

        # Confidence map
        im2 = axes[1].imshow(confidence_map, cmap='RdYlGn', vmin=0, vmax=1)
        axes[1].set_title('Confidence', fontsize=14)
        axes[1].axis('off')
        cbar2 = fig.colorbar(im2, ax=axes[1], fraction=0.046)
        cbar2.set_label('Confidence', rotation=270, labelpad=20)

        if n_valid > 0:
            fig.suptitle(f'Mean Depth: {mean_depth:.2f}m | Range: {min_depth:.2f}-{max_depth:.2f}m',
                         fontsize=12)

        fig.tight_layout()
        fig.savefig(f"{OUTPUT_DIR}/{tile_name}_result.png", dpi=150, bbox_inches='tight')

        print(f"  ✓ Saved all outputs")

        # Release the tile-sized rasters before the next tile is allocated
        del depth_map, confidence_map

    except Exception as e:
        # Keep going with the remaining tiles, but show what went wrong.
        print(f"  ✗ Error: {e}")
        traceback.print_exc()

    finally:
        # Close the figure on success AND on failure; otherwise figures from
        # failed tiles stay open and accumulate over the run.
        if fig is not None:
            plt.close(fig)
        gc.collect()


PROCESSING TILES FOR BATHYMETRY

TILE 28 (1/81)
Base: S2_composite_2025_tile_028
CRS: EPSG:4326
Bounds: [117.46460486675275, 7.720301214779991, 117.99353290604232, 8.249229254069565]

Streaming processing: S2_composite_2025_tile_028
  Found 11 file(s)
  ✓ Loaded mixer.json
  Grid: 23 rows × 23 cols (529 patches)
  Output size: 5888 × 5888 pixels
  Adjusted bounds: [117.464605, 7.720301, 117.993533, 8.249229]
  Processing file 1/11
  Processing file 2/11
    Progress: 9.5% (50/529)
  Processing file 3/11
    Progress: 18.9% (100/529)
  Processing file 4/11
    Progress: 28.4% (150/529)
  Processing file 5/11
    Progress: 37.8% (200/529)
  Processing file 6/11
    Progress: 47.3% (250/529)
  Processing file 7/11
    Progress: 56.7% (300/529)
  Processing file 8/11
    Progress: 66.2% (350/529)
  Processing file 9/11
    Progress: 75.6% (400/529)
  Processing file 10/11
    Progress: 85.1% (450/529)
  Processing file 11/11
    Progress: 94.5% (500/529)
  ✓ Processed all 529 patches

  S