In [19]:
import os
import glob
import copy

import numpy as np
import rasterio
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# ==============================
# USER SETTINGS
# ==============================

# Folder with the GEE-exported composite tiles.
# Set the PRESTO_TILES_DIR environment variable to point the notebook at a
# different machine/location without editing this cell; when it is unset the
# original location below is used.
TILES_DIR = os.environ.get(
    "PRESTO_TILES_DIR",
    r"C:\Users\hvt632\Presto_embedded_model\input_img",
)

# Output folder for embedding rasters (sub-folder of TILES_DIR)
OUT_DIR = os.path.join(TILES_DIR, "embeddings_64d")

# Embedding model parameters
EMBED_DIM   = 64       # size of embedding vector (can change to 128, etc.)
HIDDEN_DIM  = 256      # hidden layer size in autoencoder

# Training parameters
N_TRAIN_PIXELS   = 200_000   # total number of pixels to sample for training
BATCH_SIZE_TRAIN = 1024
N_EPOCHS         = 50

# Inference parameters
BATCH_SIZE_INFER = 4096

# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Random seed (optional, for reproducibility); one constant feeds both
# NumPy's global RNG and PyTorch's RNG.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ==============================
# 1. LIST TILES
# ==============================

# Create the destination folder for the embedding rasters up front
os.makedirs(OUT_DIR, exist_ok=True)

# All .tif files located directly in TILES_DIR, in sorted order
tile_pattern = os.path.join(TILES_DIR, "*.tif")
tile_paths = sorted(glob.iglob(tile_pattern))
if len(tile_paths) == 0:
    raise RuntimeError(f"No .tif files found in {TILES_DIR}")

print(f"Found {len(tile_paths)} tiles.")

Found 10 tiles.


In [20]:

# ==============================
# 2. COLLECT TRAINING PIXELS
# ==============================

def clean_pixel_matrix(data):
    """
    Filter a pixel matrix down to its usable rows.

    data: (N_pixels, n_bands) array; it is converted to float32 first.

    A row is discarded when it
    - contains any non-finite value (NaN/inf), or
    - consists of zeros only (likely outside ROI).

    Returns (kept_rows, keep_mask): the float32 rows that survive and the
    boolean mask of length N_pixels that selected them.
    """
    pixels = data.astype(np.float32)

    # Rows holding at least one NaN/inf value
    has_bad_value = ~np.isfinite(pixels).all(axis=1)

    # Rows in which every band equals zero
    is_all_zero = (pixels == 0).all(axis=1)

    keep = ~(has_bad_value | is_all_zero)
    return pixels[keep], keep


def collect_training_samples(tile_paths, n_train_pixels):
    """
    Build the training matrix by walking the tiles in the given order.

    Each tile is read in full, flattened to (N_pixels, bands) and filtered with
    clean_pixel_matrix. While the budget allows, all valid pixels of a tile are
    kept; the tile that would exceed n_train_pixels contributes a random subset
    (drawn without replacement) of exactly the remaining size, and the walk
    stops there.

    Returns the stacked (N, bands) array of collected pixels.
    Raises RuntimeError when no tile yields a valid pixel.
    """
    print(f"\nCollecting up to {n_train_pixels} training pixels...")

    chunks = []
    n_collected = 0

    for tile_path in tile_paths:
        print(f"  Reading {tile_path}")
        with rasterio.open(tile_path) as src:
            cube = src.read()  # shape: (bands, H, W)

        n_bands, _, _ = cube.shape
        pixels = cube.reshape(n_bands, -1).T  # (N_pixels, bands)

        valid_pixels, _ = clean_pixel_matrix(pixels)
        n_valid = valid_pixels.shape[0]
        if n_valid == 0:
            continue

        remaining = n_train_pixels - n_collected
        if n_valid > remaining:
            # This tile overflows the budget: take a random subset and stop
            pick = np.random.choice(n_valid, size=remaining, replace=False)
            chunks.append(valid_pixels[pick])
            n_collected += remaining
            break

        # The whole tile fits into the budget
        chunks.append(valid_pixels)
        n_collected += n_valid
        if n_collected >= n_train_pixels:
            break

    if not chunks:
        raise RuntimeError("No valid pixels collected for training.")

    X_train = np.vstack(chunks)
    print(f"Collected {X_train.shape[0]} pixels for training, "
          f"feature dimension = {X_train.shape[1]}")
    return X_train


X_train = collect_training_samples(tile_paths, N_TRAIN_PIXELS)

# ==============================
# 3. PER-FEATURE NORMALIZATION
# ==============================

# Column-wise statistics over the training pixels (NaN-aware reductions)
feat_mean = np.nanmean(X_train, axis=0)
feat_std  = np.nanstd(X_train, axis=0)

# A feature with zero spread would divide by zero below; use 1.0 for it instead
zero_spread = feat_std == 0
feat_std[zero_spread] = 1.0

# Standardize the training matrix: subtract the mean, divide by the std
X_train_norm = np.divide(X_train - feat_mean, feat_std)

print("\nFeature stats:")
print("  mean (first 5):", feat_mean[:5])
print("  std  (first 5):", feat_std[:5])

# ==============================
# 4. DATASET & MODEL
# ==============================

class PixelDataset(Dataset):
    """Dataset that serves one row of a pixel matrix per sample."""

    def __init__(self, X):
        # X: numpy array (N, D); wrapped as a tensor via torch.from_numpy
        self.X = torch.from_numpy(X)

    def __len__(self):
        # Number of rows (samples) held by the dataset
        n_rows = self.X.size(0)
        return n_rows

    def __getitem__(self, idx):
        # The idx-th row of the stored tensor
        sample = self.X[idx]
        return sample


# Number of feature columns in the normalized training matrix; used as the
# autoencoder's input (and reconstruction) width.
input_dim = X_train_norm.shape[1]

class Autoencoder(nn.Module):
    """
    Fully connected autoencoder with one hidden layer on each side.

    encoder: input_dim -> hidden_dim -> embed_dim (ReLU after the first layer)
    decoder: embed_dim -> hidden_dim -> input_dim (ReLU after the first layer)
    """

    def __init__(self, input_dim, embed_dim, hidden_dim):
        super().__init__()

        # Encoder layers are created before decoder layers
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return (reconstruction, embedding) for the input batch x."""
        embedding = self.encoder(x)
        return self.decoder(embedding), embedding

# Build the network and move it to the selected device
model = Autoencoder(
    input_dim=input_dim,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
)
model = model.to(DEVICE)

# Adam over all model parameters, trained against mean-squared reconstruction error
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()


Collecting up to 200000 training pixels...
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_0.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_1.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_2.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_3.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_4.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_5.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_6.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_7.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_8.tif
  Reading C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_9.tif
Collec

In [21]:
# ==============================
# 5. TRAIN AUTOENCODER
# ==============================

print("\nTraining autoencoder...")
train_ds = PixelDataset(X_train_norm)
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE_TRAIN)
n_train_samples = len(train_ds)

for epoch in range(N_EPOCHS):
    model.train()

    # Sum of per-sample losses for this epoch (batch mean * batch size)
    total_loss = 0.0
    for batch in train_dl:
        batch = batch.to(DEVICE)

        # Forward pass: the reconstruction target is the input itself
        recon, z = model(batch)
        loss = loss_fn(recon, batch)

        # Backward pass on freshly zeroed gradients, then one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += batch.size(0) * loss.item()

    avg_loss = total_loss / n_train_samples
    print(f"Epoch {epoch+1}/{N_EPOCHS} - Loss: {avg_loss:.6f}")

# Persist the trained weights together with everything needed to rebuild the
# network and to normalize new data the same way as the training data.
# NOTE(review): feat_mean / feat_std are stored as NumPy arrays; a loader using
# torch.load(..., weights_only=True) may need to allow-list NumPy types -
# confirm against whatever code reads this checkpoint.
model_path = os.path.join(TILES_DIR, f"ae_model_{EMBED_DIM}d_norm.pth")
checkpoint = {
    "state_dict": model.state_dict(),
    "input_dim": input_dim,
    "embed_dim": EMBED_DIM,
    "hidden_dim": HIDDEN_DIM,
    "feat_mean": feat_mean,
    "feat_std": feat_std,
}
torch.save(checkpoint, model_path)
print(f"Saved model to {model_path}")

# ==============================
# 6. APPLY MODEL TO EACH TILE
# ==============================

def compute_embeddings_for_tile(tile_path, model, feat_mean, feat_std, out_dir):
    """
    Loads one composite tile, encodes all valid pixels into embeddings,
    and writes an embedding raster with EMBED_DIM bands.
    Uses the same normalization (feat_mean/std) as training.

    Parameters
    ----------
    tile_path : path of the input raster tile.
    model     : module exposing `.encoder`; only the encoder is evaluated.
                The model is switched to eval mode as a side effect.
    feat_mean : per-band mean, broadcast against the (N_pixels, bands) matrix.
    feat_std  : per-band std, broadcast against the (N_pixels, bands) matrix.
    out_dir   : folder receiving `<tile name>_embed_<EMBED_DIM>d.tif`.

    Returns
    -------
    None. Nothing is written when the output file already exists or when the
    tile has no valid pixels.

    Notes
    -----
    The raster is first written to a `.partial.tif` file next to the final
    path and renamed only after the write succeeded, so an interrupted run
    cannot leave a truncated output that later runs would skip as "already
    exists".
    """
    tile_name = os.path.splitext(os.path.basename(tile_path))[0]
    out_path = os.path.join(out_dir, f"{tile_name}_embed_{EMBED_DIM}d.tif")

    if os.path.exists(out_path):
        print(f"  [SKIP] {out_path} already exists")
        return

    print(f"  Processing tile: {tile_path}")
    with rasterio.open(tile_path) as src:
        arr = src.read().astype(np.float32)  # (bands, H, W)
        profile = src.profile

    bands, H, W = arr.shape
    data = arr.reshape(bands, -1).T  # (N_pixels, bands)

    # Clean as in training
    data_clean, mask_valid = clean_pixel_matrix(data)
    n_valid = data_clean.shape[0]
    print(f"    Valid pixels: {n_valid} / {data.shape[0]}")

    if n_valid == 0:
        print("    No valid pixels, skipping.")
        return

    # Normalize using training stats. The result is forced to float32 so that
    # stats supplied as float64 do not produce a float64 batch, which the
    # float32 encoder weights would reject; for float32 stats this is a no-op.
    data_norm = ((data_clean - feat_mean) / feat_std).astype(np.float32, copy=False)

    # Compute embeddings batch by batch
    model.eval()
    emb_valid = np.zeros((n_valid, EMBED_DIM), dtype=np.float32)

    with torch.no_grad():
        for start in range(0, n_valid, BATCH_SIZE_INFER):
            stop = start + BATCH_SIZE_INFER
            batch = torch.from_numpy(data_norm[start:stop]).to(DEVICE)
            z = model.encoder(batch)
            emb_valid[start:stop] = z.cpu().numpy()

    # Fill full grid with 0 for invalid pixels (can treat as nodata later)
    emb_full = np.zeros((data.shape[0], EMBED_DIM), dtype=np.float32)
    emb_full[mask_valid] = emb_valid

    # Reshape to raster format: (bands, H, W)
    emb_raster = emb_full.T.reshape(EMBED_DIM, H, W)

    # Output profile: same georeferencing as the input, new band count/dtype
    out_profile = copy.deepcopy(profile)
    out_profile.update({
        "count": EMBED_DIM,
        "dtype": "float32",
        "nodata": 0.0    # all-zero vector = invalid pixel
    })

    # Write to a temporary file and rename on success (see Notes)
    tmp_path = os.path.join(out_dir, f"{tile_name}_embed_{EMBED_DIM}d.partial.tif")
    try:
        with rasterio.open(tmp_path, "w", **out_profile) as dst:
            dst.write(emb_raster)
        os.replace(tmp_path, out_path)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is re-raised
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        raise

    print(f"    Saved embedding tile: {out_path}")


print("\nApplying encoder to all tiles...")
# One embedding raster per input tile, with a progress bar over the tiles
for tile_path in tqdm(tile_paths):
    compute_embeddings_for_tile(
        tile_path=tile_path,
        model=model,
        feat_mean=feat_mean,
        feat_std=feat_std,
        out_dir=OUT_DIR,
    )

print("\nDONE. All embedding tiles saved in:")
print(OUT_DIR)



Training autoencoder...
Epoch 1/50 - Loss: 0.276827
Epoch 2/50 - Loss: 0.075764
Epoch 3/50 - Loss: 0.050439
Epoch 4/50 - Loss: 0.037159
Epoch 5/50 - Loss: 0.028715
Epoch 6/50 - Loss: 0.023744
Epoch 7/50 - Loss: 0.020185
Epoch 8/50 - Loss: 0.017397
Epoch 9/50 - Loss: 0.015361
Epoch 10/50 - Loss: 0.013710
Epoch 11/50 - Loss: 0.012498
Epoch 12/50 - Loss: 0.011516
Epoch 13/50 - Loss: 0.010773
Epoch 14/50 - Loss: 0.010152
Epoch 15/50 - Loss: 0.009529
Epoch 16/50 - Loss: 0.009048
Epoch 17/50 - Loss: 0.008634
Epoch 18/50 - Loss: 0.008178
Epoch 19/50 - Loss: 0.007875
Epoch 20/50 - Loss: 0.007437
Epoch 21/50 - Loss: 0.007251
Epoch 22/50 - Loss: 0.006921
Epoch 23/50 - Loss: 0.006789
Epoch 24/50 - Loss: 0.006570
Epoch 25/50 - Loss: 0.006460
Epoch 26/50 - Loss: 0.006255
Epoch 27/50 - Loss: 0.006143
Epoch 28/50 - Loss: 0.006056
Epoch 29/50 - Loss: 0.005895
Epoch 30/50 - Loss: 0.005918
Epoch 31/50 - Loss: 0.005713
Epoch 32/50 - Loss: 0.005633
Epoch 33/50 - Loss: 0.005547
Epoch 34/50 - Loss: 0.00550

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_0.tif
    Valid pixels: 8255 / 250000


 10%|████████▎                                                                          | 1/10 [00:01<00:09,  1.02s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_0_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_1.tif
    Valid pixels: 9999 / 250000


 20%|████████████████▌                                                                  | 2/10 [00:02<00:08,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_1_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_2.tif
    Valid pixels: 10736 / 250000


 30%|████████████████████████▉                                                          | 3/10 [00:03<00:07,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_2_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_3.tif
    Valid pixels: 11199 / 250000


 40%|█████████████████████████████████▏                                                 | 4/10 [00:04<00:06,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_3_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_4.tif
    Valid pixels: 11500 / 250000


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:05<00:05,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_4_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_5.tif
    Valid pixels: 11500 / 250000


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:06<00:04,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_5_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_6.tif
    Valid pixels: 11199 / 250000


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:07<00:03,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_6_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_7.tif
    Valid pixels: 10736 / 250000


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:08<00:02,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_7_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_8.tif
    Valid pixels: 9999 / 250000


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:09<00:01,  1.00s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_8_embed_64d.tif
  Processing tile: C:\Users\hvt632\Presto_embedded_model\input_img\presto_input_SK_2019_2020_tile_9.tif
    Valid pixels: 9037 / 250000


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:10<00:00,  1.01s/it]

    Saved embedding tile: C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d\presto_input_SK_2019_2020_tile_9_embed_64d.tif

DONE. All embedding tiles saved in:
C:\Users\hvt632\Presto_embedded_model\input_img\embeddings_64d



