# CSIRO Biomass TPU Baseline

TPU-friendly notebook that trains a TensorFlow/Keras model end-to-end, so that TPU and GPU performance can be compared without relying on libraries that cannot use the TPU in this setup (GPU-accelerated xgboost, lightgbm and catboost, or the PyTorch-based timm).

In [None]:
import os
import math
import random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf


In [None]:
# TPU strategy --------------------------------------------------------------
# Only the TPU lookup sits inside the try: a ValueError from connect() is
# treated as "no TPU attached" and we fall back to the default strategy
# (CPU / single GPU). Errors raised while building TPUStrategy for a TPU that
# *was* found are no longer swallowed, so a broken TPU setup fails loudly
# instead of silently benchmarking on the fallback device.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
except ValueError:
    strategy = tf.distribute.get_strategy()
    print("TPU not found, falling back to default strategy")
else:
    strategy = tf.distribute.TPUStrategy(resolver)
    print(f"Running on TPU: {resolver.master()}")
print(f"Replicas in sync: {strategy.num_replicas_in_sync}")

def seed_all(seed: int = 42):
    """Seed Python's ``random``, NumPy and TensorFlow with the same ``seed``."""
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)

@dataclass
class CFG:
    """Run configuration.

    In the code shown, fields are read directly as class attributes
    (``CFG.seed``, ``CFG.image_size``, ...); the class is never instantiated.
    """

    # Root directory of the competition data; train.csv, test.csv and the
    # relative image paths listed in them are resolved against it.
    data_root: Path = Path("/kaggle/input/csiro-biomass")
    # Side length (pixels) images are resized to; also the model input size.
    image_size: int = 448
    # Batch size passed to tf.data ``batch``. NOTE(review): under a
    # distribution strategy this is the global batch, split across replicas
    # -- confirm it is large enough per TPU core.
    batch_size: int = 16
    # Maximum number of training epochs (EarlyStopping may stop earlier).
    epochs: int = 5
    # Fraction of the labelled images held out for validation.
    val_split: float = 0.15
    # Seed for the RNGs, the train/val split and dataset shuffling.
    seed: int = 42

# The five regression targets, in the column order used both for selecting
# label columns and for naming the model's prediction columns.
TARGET_NAMES = [
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g",
]
seed_all(seed=CFG.seed)


In [None]:
# Load metadata -------------------------------------------------------------
train_csv, test_csv = (CFG.data_root / name for name in ("train.csv", "test.csv"))
train_df = pd.read_csv(train_csv)
test_df = pd.read_csv(test_csv)

def pivot_targets(df: pd.DataFrame) -> pd.DataFrame:
    """Reshape a long (one row per image x target_name) frame to one row per image.

    Labelled frames (with a ``target`` column) keep the per-image metadata
    columns as identifiers and get one column per ``target_name`` holding the
    mean target value. Unlabelled frames are pivoted on ``image_path`` alone,
    with every target column filled with the placeholder 0.0. The input frame
    is not modified.
    """
    if "target" not in df.columns:
        # No labels: pivot a constant placeholder so the wide frame still has
        # one column per target name. assign() returns a new frame.
        placeholder = df.assign(dummy=0.0)
        return placeholder.pivot_table(
            index="image_path", columns="target_name", values="dummy", aggfunc="mean"
        ).reset_index()

    id_columns = ["image_path", "Sampling_Date", "State", "Species", "Pre_GSHH_NDVI", "Height_Ave_cm"]
    return df.pivot_table(
        index=id_columns, columns="target_name", values="target", aggfunc="mean"
    ).reset_index()

# One row per image for both splits.
train_wide, test_wide = pivot_targets(train_df), pivot_targets(test_df)
print(f"Train records: {len(train_wide)}, Test records: {len(test_wide)}")


In [None]:
# Train/validation split ----------------------------------------------------
from sklearn.model_selection import train_test_split

def make_image_path(rel_path: str) -> str:
    """Join a CSV-relative image path onto ``CFG.data_root`` and return it as ``str``."""
    absolute = CFG.data_root.joinpath(rel_path)
    return str(absolute)

# Absolute image paths, consumed by the tf.data pipelines.
train_wide["abs_path"] = train_wide.image_path.apply(make_image_path)
test_wide["abs_path"] = test_wide.image_path.apply(make_image_path)

# Hold out CFG.val_split of the labelled images by row position; the split is
# reproducible through CFG.seed.
train_idx, val_idx = train_test_split(
    np.arange(len(train_wide)),
    shuffle=True,
    random_state=CFG.seed,
    test_size=CFG.val_split,
)
train_meta, val_meta = (
    train_wide.iloc[positions].reset_index(drop=True)
    for positions in (train_idx, val_idx)
)
print(f"Train split: {len(train_meta)} | Val split: {len(val_meta)}")


In [None]:
# tf.data pipelines ---------------------------------------------------------
AUTOTUNE = tf.data.AUTOTUNE

def decode_image(path: tf.Tensor) -> tf.Tensor:
    """Read the JPEG at ``path`` and resize it to ``CFG.image_size`` x ``CFG.image_size``.

    Returns a float32 ``(image_size, image_size, 3)`` tensor whose pixel
    values stay in the raw ``[0, 255]`` range. They are deliberately NOT
    divided by 255: ``tf.keras.applications.EfficientNetV2S`` is built with
    its own input-rescaling layer (``include_preprocessing=True`` by default)
    and expects ``[0, 255]`` inputs. Scaling here as well would squash every
    image into a near-constant input for the backbone.
    """
    raw = tf.io.read_file(path)
    img = tf.image.decode_jpeg(raw, channels=3)
    img = tf.image.resize(img, (CFG.image_size, CFG.image_size))
    # resize() already returns float32 for its default method; the cast just
    # makes the output dtype explicit.
    return tf.cast(img, tf.float32)

def augment(img: tf.Tensor) -> tf.Tensor:
    """Randomly mirror ``img`` left-right, then randomly mirror it up-down."""
    mirrored = tf.image.random_flip_left_right(img)
    return tf.image.random_flip_up_down(mirrored)

def build_dataset(
    meta_df: pd.DataFrame,
    training: bool,
    labeled: "bool | None" = None,
) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline from a metadata frame.

    Args:
        meta_df: frame with an ``abs_path`` column and, when labelled, one
            column per name in ``TARGET_NAMES``.
        training: enable training-only behaviour: random flips, shuffling,
            and dropping the final partial batch.
        labeled: pair each image with its float32 target vector. Defaults to
            ``training`` so existing ``build_dataset(df, training=...)`` calls
            keep their original meaning; pass ``training=False, labeled=True``
            for validation data.

    Returns:
        Batches of ``(image, targets)`` when labelled, otherwise 1-tuples
        ``(image,)`` for ``Model.predict``.
    """
    if labeled is None:
        labeled = training

    paths = meta_df["abs_path"].to_numpy()
    if labeled:
        targets = meta_df[TARGET_NAMES].to_numpy(dtype="float32")
        ds = tf.data.Dataset.from_tensor_slices((paths, targets))
    else:
        ds = tf.data.Dataset.from_tensor_slices(paths)

    if training:
        # Shuffle the lightweight (path, target) records *before* decoding.
        # Shuffling after decode made the buffer hold up to 2048 decoded
        # float32 images (~2.4 MB each at image_size=448).
        ds = ds.shuffle(2048, seed=CFG.seed)

    def load(path):
        img = decode_image(path)
        return augment(img) if training else img

    if labeled:
        ds = ds.map(lambda path, y: (load(path), y), num_parallel_calls=AUTOTUNE)
    else:
        ds = ds.map(lambda path: (load(path),), num_parallel_calls=AUTOTUNE)

    # TPU compilation works best with static shapes, so training drops the
    # ragged last batch (samples are reshuffled every epoch, so none are lost
    # for good). Evaluation/prediction keep every sample. NOTE: a training
    # frame smaller than CFG.batch_size would yield zero batches.
    ds = ds.batch(CFG.batch_size, drop_remainder=training)
    return ds.prefetch(AUTOTUNE)

train_ds = build_dataset(train_meta, training=True)
# Validation needs labels but no augmentation or shuffling, so the val_loss
# the callbacks monitor is measured on the same fixed inputs every epoch.
val_ds = build_dataset(val_meta, training=False, labeled=True)
test_ds = build_dataset(test_wide, training=False)


In [None]:
# Model definition ----------------------------------------------------------
# Everything that creates variables must be built under strategy.scope(): the
# backbone and head, and also the Adam optimizer and MeanAbsoluteError metric
# handed to compile(). Keras rejects metrics created in a different
# distribution-strategy scope than the model, so compile() lives inside the
# scope too.
with strategy.scope():
    # EfficientNetV2S ships its own input-rescaling layer
    # (include_preprocessing=True by default) and expects [0, 255] pixels.
    base = tf.keras.applications.EfficientNetV2S(
        include_top=False,
        input_shape=(CFG.image_size, CFG.image_size, 3),
        weights='imagenet',
    )
    base.trainable = False  # frozen feature extractor; fine-tune later if needed
    inputs = tf.keras.Input(shape=(CFG.image_size, CFG.image_size, 3))
    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # even if the backbone is unfrozen later for fine-tuning.
    x = base(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    # One linear output per regression target.
    outputs = tf.keras.layers.Dense(len(TARGET_NAMES), activation='linear')(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss='mae',
        metrics=[tf.keras.metrics.MeanAbsoluteError(name='mae')],
    )
model.summary()


In [None]:
# Training -----------------------------------------------------------------
# Both callbacks track val_loss. That is already their default `monitor`; it
# is spelled out here to make the dependency on validation_data explicit.
callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', patience=2, factor=0.5, verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True, verbose=1
    ),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=CFG.epochs,
    callbacks=callbacks,
    verbose=1,
)
# `ehistory` was the original (misspelled) name; keep it as an alias so any
# later cell that still refers to it keeps working.
ehistory = history


In [None]:
# Inference & submission ----------------------------------------------------
# test_ds is built without shuffling, so prediction row i belongs to
# test_wide row i.
preds = model.predict(test_ds, verbose=1)
pred_wide = pd.DataFrame(preds, columns=TARGET_NAMES)
# Insert by position (to_numpy) rather than relying on index alignment.
pred_wide.insert(0, 'image_path', test_wide['image_path'].to_numpy())
pred_long = pred_wide.melt(
    id_vars='image_path',
    value_vars=TARGET_NAMES,
    var_name='target_name',
    value_name='target',
)
# Drive the join from test.csv, so the submission has exactly one row per
# expected sample_id, in test.csv's original order.
sub_df = test_df[['sample_id', 'image_path', 'target_name']].merge(
    pred_long, on=['image_path', 'target_name'], how='left'
)
missing = sub_df['target'].isna()
if missing.any():
    raise ValueError(
        f"No prediction for {int(missing.sum())} test rows; "
        "check that every test target_name is in TARGET_NAMES"
    )
submission = sub_df[['sample_id', 'target']].copy()
# Clip negative predictions to 0. pandas spells the bound `lower`:
# `Series.clip(min=0)` raises TypeError.
submission['target'] = submission['target'].clip(lower=0)
submission.to_csv('submission_tpu.csv', index=False)
submission.head()
