<a href="https://colab.research.google.com/github/zrmondsc/gedi_waveform_processor/blob/main/notebooks/cnn_regression_3x3_sentinel_to_latent_gedi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Authentication

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
from google.colab import auth
auth.authenticate_user()

## 2. Load and parse tfrecord.gz dataset

In [3]:
# Load dataset from Google Cloud Storage
tfrecord_path = 'gs://ee-gedi-data/tfrecords/gedi_latent_patches_3x3.tfrecord.gz'
raw_dataset = tf.data.TFRecordDataset(tfrecord_path, compression_type='GZIP')

In [8]:
# Sentinel bands and patch size
BANDS = ['VV', 'VH', 'B2', 'B3', 'B4', 'B5', 'B6',
         'B7', 'B8', 'B8A', 'B11', 'B12']
LATENT_KEYS = [f"latent_{i}" for i in range(8)]
PATCH_SIZE = 3

# Feature schema
feature_description = {
    f"{band}_patch": tf.io.FixedLenFeature([PATCH_SIZE * PATCH_SIZE], tf.float32)
    for band in BANDS
}
feature_description.update({
    key: tf.io.FixedLenFeature([], tf.float32) for key in LATENT_KEYS
})

# Parser function
def parse_example(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)

    # Reconstruct patch: [3, 3, 12]
    patch = tf.stack([
        tf.reshape(example[f"{band}_patch"], [PATCH_SIZE, PATCH_SIZE])
        for band in BANDS
    ], axis=-1)

    # Combine latents into [8] vector
    latents = tf.stack([example[key] for key in LATENT_KEYS])

    return patch, latents

In [9]:
# Parse dataset
parsed_dataset = raw_dataset.map(parse_example)
parsed_dataset

<_MapDataset element_spec=(TensorSpec(shape=(3, 3, 12), dtype=tf.float32, name=None), TensorSpec(shape=(8,), dtype=tf.float32, name=None))>

## 3. Split data into training and testing datasets

In [12]:
import numpy as np
from tqdm import tqdm

X, y = [], []

# Convert tensors to numpy arrays
for patch, latents in tqdm(parsed_dataset):
    X.append(patch.numpy())
    y.append(latents.numpy())

X = np.array(X)
y = np.array(y)

10546it [00:06, 1540.05it/s]


In [16]:
from sklearn.model_selection import train_test_split

# Split numpy arrays with sk-learn
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

print(f"X_train shape: {X_train.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (7382, 3, 3, 12)
X_test shape: (3164, 3, 3, 12)
y_train shape: (7382, 8)
y_test shape: (3164, 8)


In [17]:
# Convert back to a tensor and batch
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(64).prefetch(tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64).prefetch(tf.data.AUTOTUNE)

## 4. Build and compile CNN

In [31]:
from tensorflow.keras import layers, models

def build_model():
    model = models.Sequential([
        layers.Input(shape=(3, 3, 12)),         # your Sentinel patch
        layers.Conv2D(64, (2, 2), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(8)  # 8D output = GEDI latent vector
    ])
    return model

model = build_model()
model.summary()

In [32]:
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae']
)

In [33]:
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=150
)

Epoch 1/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - loss: 49344.4531 - mae: 114.8361 - val_loss: 711.0732 - val_mae: 20.0155
Epoch 2/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 640.6285 - mae: 18.8920 - val_loss: 402.0988 - val_mae: 15.0487
Epoch 3/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 376.3176 - mae: 14.5021 - val_loss: 290.9688 - val_mae: 12.8686
Epoch 4/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 271.5208 - mae: 12.3785 - val_loss: 216.9597 - val_mae: 11.1250
Epoch 5/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 210.5602 - mae: 10.9553 - val_loss: 172.9533 - val_mae: 9.9710
Epoch 6/150
[1m116/116[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 170.6009 - mae: 9.8851 - val_loss: 144.5483 - val_mae: 9.1270
Epoch 7/150
[1m116/116[0m [32m━━━━━━━━━━━━━

In [None]:
import matplotlib.pyplot as plt

plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.title('Training vs Validation Loss')
plt.grid(True)
plt.legend()
plt.show()