In [1]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import rasterio
import numpy as np
from sklearn.metrics import f1_score, cohen_kappa_score
import timm

########################################
# Configuration
########################################
# Input directories.
landsat_dir = '/landsat_model'  # per-scene Landsat bands: <scene>_B1.tif ... <scene>_B6.tif
soil_dir = '/soilgrid_model'    # per-scene single-band soil raster: <scene>_B1.tif

# Optimisation hyper-parameters: batch size, Adam learning rate, epoch count.
batch_size, learning_rate, num_epochs = 8, 1e-4, 10

# Binarisation cut-off for F1/Kappa; assigned once training-split statistics exist.
threshold = None

########################################
# Dataset
########################################
class SoilCarbonDataset(Dataset):
    """Pair each scene's 6-band Landsat stack with its mean soil value.

    Landsat bands are read from ``<landsat_dir>/<scene>_B1.tif`` ..
    ``<scene>_B6.tif`` in band order. The soil raster is
    ``<soil_dir>/<scene>_B1.tif`` when that file exists, otherwise the single
    file matching ``<scene>*B1.tif`` in ``soil_dir``.

    Args:
        landsat_dir: Directory holding the per-band Landsat GeoTIFFs.
        soil_dir: Directory holding the single-band soil GeoTIFFs.
        scene_list: Scene base names (e.g. ``"scene_001"``).

    Each item is ``(landsat_img, soil_val)``: a float32 tensor of shape
    ``(6, H, W)`` and a 0-d float32 tensor holding the mean of the soil
    raster's valid (not nodata-masked, finite) pixels.
    """

    NUM_BANDS = 6  # Landsat bands B1..B6

    def __init__(self, landsat_dir, soil_dir, scene_list):
        self.landsat_dir = landsat_dir
        self.soil_dir = soil_dir
        self.scenes = scene_list

    def __len__(self):
        return len(self.scenes)

    def _landsat_paths(self, scene):
        """Return the B1..B6 paths for *scene* in band order; raise if any is missing."""
        paths = [
            os.path.join(self.landsat_dir, f"{scene}_B{band}.tif")
            for band in range(1, self.NUM_BANDS + 1)
        ]
        missing = [p for p in paths if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError(
                f"Scene {scene!r} is missing Landsat band file(s): {missing}"
            )
        return paths

    def _soil_path(self, scene):
        """Return the soil raster for *scene*: the exact name first, else a unique glob match."""
        exact = os.path.join(self.soil_dir, f"{scene}_B1.tif")
        if os.path.isfile(exact):
            return exact
        # Escape so directory/scene names containing glob metacharacters match literally.
        pattern = os.path.join(glob.escape(self.soil_dir), f"{glob.escape(scene)}*B1.tif")
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(
                f"No soil raster found for scene {scene!r} in {self.soil_dir!r}"
            )
        if len(matches) > 1:
            # e.g. "scene_001*" also matches "scene_0010_B1.tif"; refuse to guess.
            raise ValueError(f"Ambiguous soil rasters for scene {scene!r}: {matches}")
        return matches[0]

    def __getitem__(self, idx):
        scene = self.scenes[idx]

        # Resolve every path up front so a missing file fails before any raster I/O.
        landsat_paths = self._landsat_paths(scene)
        soil_path = self._soil_path(scene)

        landsat_bands = []
        for band_path in landsat_paths:
            with rasterio.open(band_path) as src:
                landsat_bands.append(src.read(1))
        # Cast before torch.from_numpy: if the rasters are uint16 (common for
        # Landsat products), older PyTorch releases reject that dtype.
        landsat_img = np.stack(landsat_bands, axis=0).astype(np.float32)  # (6, H, W)

        with rasterio.open(soil_path) as src:
            # masked=True masks the raster's nodata pixels; masked_invalid also
            # masks NaN/inf so neither leaks into the label.
            soil_data = np.ma.masked_invalid(src.read(1, masked=True))
        if soil_data.count() == 0:
            raise ValueError(f"Soil raster {soil_path!r} has no valid pixels")
        soil_val = float(soil_data.mean())

        return torch.from_numpy(landsat_img), torch.tensor(soil_val, dtype=torch.float32)

########################################
# Prepare Data
########################################
# Each scene is identified by a unique base name (e.g. scene_001); the names are
# recovered from the "<scene>_B1.tif" files in the Landsat directory. Sorting
# makes the scene order (and hence the seeded split below) independent of the
# file system's listing order.
b1_suffix = "_B1.tif"
landsat_files = sorted(glob.glob(os.path.join(landsat_dir, "*" + b1_suffix)))
# Slice off only the trailing suffix; str.replace would also remove an
# "_B1.tif" occurring earlier in the name.
scenes = [os.path.basename(f)[: -len(b1_suffix)] for f in landsat_files]
if len(scenes) < 2:
    # With fewer than 2 scenes the 80/20 split leaves train or val empty.
    raise RuntimeError(
        f"Found {len(scenes)} scene(s) matching '*{b1_suffix}' in {landsat_dir!r}; "
        "need at least 2 for a train/val split."
    )

dataset = SoilCarbonDataset(landsat_dir, soil_dir, scenes)

# Split into train/val (80/20) with a fixed seed so the validation set is the
# same on every run and results are comparable.
split_seed = 0
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Threshold for the binary metrics (F1, Kappa): the median soil value of the
# training split only, so validation labels never influence it.
# NOTE(review): each train_dataset[i] also reads all 6 Landsat bands although
# only the soil value is needed here; consider a soil-only accessor if slow.
all_train_soil_vals = [train_dataset[i][1].item() for i in range(len(train_dataset))]
threshold = float(np.median(all_train_soil_vals))
print(f"Using threshold={threshold} for F1 and Kappa computation.")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

########################################
# Model
########################################
# Build a randomly initialised ViT-Small/16 from timm with a single output unit
# (regression head), then rebuild its patch-embedding convolution so it takes
# 6 input channels instead of 3, keeping width, kernel, stride and padding.
model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=1)

original_conv = model.patch_embed.proj
model.patch_embed.proj = nn.Conv2d(
    in_channels=6,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
)

# The model is configured for 224x224 inputs but the tiles are 80x80, so every
# batch is bilinearly upsampled before the forward pass (a quick hack; a model
# built for 80x80 inputs / a different patch size would avoid it).
resize = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)

model = model.cuda()

########################################
# Loss and Optimizer
########################################
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

########################################
# Utility functions for metrics
########################################
def compute_metrics(y_true, y_pred, threshold):
    """Score regression predictions both as regression and as a binary split.

    Both inputs are flattened to 1-D float arrays. Values ``>= threshold``
    form the positive class for F1 and Cohen's kappa.

    Args:
        y_true: Ground-truth values (any array-like).
        y_pred: Predicted values (any array-like) with as many elements as y_true.
        threshold: Cut-off used to binarize both arrays.

    Returns:
        ``(mse, f1, kappa)`` as Python floats. F1 is 0.0 when it is undefined
        (no positive labels or predictions). Kappa may be NaN when it is
        undefined, e.g. when both binarized arrays hold one identical class.

    Raises:
        ValueError: if the inputs are empty or contain different numbers of values.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.size == 0:
        raise ValueError("compute_metrics needs at least one sample")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true has {y_true.size} values but y_pred has {y_pred.size}"
        )

    y_true_bin = (y_true >= threshold).astype(int)
    y_pred_bin = (y_pred >= threshold).astype(int)

    mse = float(np.mean((y_true - y_pred) ** 2))
    # zero_division=0 yields the same value as sklearn's default, without the warning.
    f1 = float(f1_score(y_true_bin, y_pred_bin, zero_division=0))
    kappa = float(cohen_kappa_score(y_true_bin, y_pred_bin))

    return mse, f1, kappa

########################################
# Training Loop
########################################
for epoch in range(num_epochs):
    model.train()
    for landsat, soil_val in train_loader:
        landsat = landsat.cuda()  # (B,6,80,80)
        soil_val = soil_val.cuda()  # (B,)

        optimizer.zero_grad()
        landsat_resized = resize(landsat)  # (B,6,224,224)
        # reshape(-1) maps (B,1) -> (B,); squeeze() would give a 0-d tensor
        # when a batch has B == 1 (e.g. the last, partial batch).
        preds = model(landsat_resized).reshape(-1)
        loss = criterion(preds, soil_val)
        loss.backward()
        optimizer.step()

    # Validation: collect predictions for the whole split, then score once.
    model.eval()
    all_val_true = []
    all_val_pred = []
    with torch.no_grad():
        for landsat, soil_val in val_loader:
            preds = model(resize(landsat.cuda())).reshape(-1)  # (B,)
            # Targets are only used by the numpy metrics, so keep them on the CPU.
            all_val_true.append(soil_val.numpy())
            all_val_pred.append(preds.cpu().numpy())

    all_val_true = np.concatenate(all_val_true)
    all_val_pred = np.concatenate(all_val_pred)
    mse_val, f1_val, kappa_val = compute_metrics(all_val_true, all_val_pred, threshold)

    print(f"Epoch {epoch+1}/{num_epochs}: Val MSE={mse_val:.4f}, F1={f1_val:.4f}, Kappa={kappa_val:.4f}")


ModuleNotFoundError: No module named 'torch'