# CSIRO Biomass (Kaggle) — Baseline + Ablations

## Baseline

### Data & evaluation

* **Tiny data (~357 unique train images)** → overfit risk is the main constraint.
* **Train table is long-format:** ~1785 rows = 357 images × 5 targets → pivot to **wide (357 rows)** for training.
* **Validation:** grouped CV by the true sampling unit (no row-level splits).

  * candidate splits to try (ablation):

    * **Group by `Sampling_Date`** (often makes CV scores track the leaderboard more closely)
    * **Stratify by `State`** (target distributions differ by state)
* **Submission/inference:** images only (treat metadata as train-only).
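
The long-to-wide pivot mentioned above can be sketched on a toy table. The column names `image_path`, `target_name`, and `target` follow the training code later in this notebook; the two images and their values are invented purely for illustration.

```python
import pandas as pd

TARGETS = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g", "GDM_g", "Dry_Total_g"]

# Toy long-format table: 2 images x 5 targets = 10 rows (values are invented).
long_df = pd.DataFrame({
    "image_path": [p for p in ("a.jpg", "b.jpg") for _ in TARGETS],
    "target_name": TARGETS * 2,
    "target": [1.0, 2.0, 3.0, 3.0, 6.0, 0.5, 0.0, 1.5, 0.5, 2.0],
})

# One row per image, one column per target.
wide_df = (
    long_df.pivot(index="image_path", columns="target_name", values="target")
           .reset_index()
)

# Every image should carry all five targets after the pivot.
assert wide_df[TARGETS].notna().all().all()
print(wide_df.shape)
```

On the real table the same check should report 357 rows if every image has a complete set of targets.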

### Model

* **Backbone:** DINOv3 (**frozen**)
* **Neck:** none (baseline)

### Head

* **2-layer MLP** → **5 outputs** (`Green, Clover, Dead, GDM, Total`)
* Use LayerNorm (or keep backbone norm), dropout ~0.1–0.3.

### Targets & loss

* Train on **log1p** targets.
* Baseline loss: **weighted MSE** in log-space.

  * weights roughly: `Green=0.1, Clover=0.1, Dead=0.1, GDM=0.2, Total=0.5`
* Metric: global weighted R² (in original space) computed across all targets.
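
Written out, the metric as this notebook's `eval_global_wr2` computes it looks as follows. Here $w_j$ is the per-target weight (the code reuses the loss weights), and $y_{ij}$ and $\hat{y}_{ij}$ are the true and predicted values for image $i$ and target $j$, both in original space. Whether this matches the competition definition exactly remains one of the must-verify items.

$$
\bar{y}_w = \frac{\sum_{i,j} w_j \, y_{ij}}{\sum_{i,j} w_j},
\qquad
R^2_w = 1 - \frac{\sum_{i,j} w_j \,(y_{ij} - \hat{y}_{ij})^2}{\sum_{i,j} w_j \,(y_{ij} - \bar{y}_w)^2}
$$

The mean $\bar{y}_w$ is a single weighted mean pooled over all targets, not one mean per target, which is what makes the score "global".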

### Augmentation (label-safe)

Goal: improve robustness without breaking the “grams ↔ pixels” relationship.

* Safe geometric:

  * flips
  * 90° rotations
  * small translate/shear (no scale)
  * **jigsaw (patch shuffle):** split into a small grid (e.g., 3×3) and randomly permute patches (use low probability; preserves scale but breaks global layout)
* Photometric:

  * brightness/contrast/saturation/hue jitter
  * mild blur
  * mild autocontrast

Avoid:

* random resized crop / heavy zoom
* cutout / random erasing
* heavy rotations that require re-scaling
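
The jigsaw (patch shuffle) augmentation listed above can be sketched as below. This is a minimal illustration, not the notebook's implementation: `jigsaw_shuffle` is a hypothetical helper, it works on an `H×W×C` NumPy array rather than a PIL image, and it crops a few pixels when the size is not a multiple of the grid.

```python
import numpy as np

def jigsaw_shuffle(img: np.ndarray, grid: int = 3, rng=None) -> np.ndarray:
    """Split an HxWxC array into grid x grid patches and permute them.

    Hypothetical helper for illustration. H and W are cropped to a multiple
    of `grid` so that all patches share one shape; pixels are never rescaled.
    """
    if rng is None:
        rng = np.random.default_rng()
    h, w = img.shape[:2]
    ph, pw = h // grid, w // grid
    img = img[: ph * grid, : pw * grid]
    # Cut the image into patches in row-major order.
    patches = [
        img[r * ph:(r + 1) * ph, c * pw:(c + 1) * pw]
        for r in range(grid) for c in range(grid)
    ]
    order = rng.permutation(len(patches))
    # Reassemble the shuffled patches row by row.
    rows = [
        np.concatenate([patches[order[r * grid + c]] for c in range(grid)], axis=1)
        for r in range(grid)
    ]
    return np.concatenate(rows, axis=0)

img = np.arange(6 * 6 * 3, dtype=np.uint8).reshape(6, 6, 3)
out = jigsaw_shuffle(img, grid=3, rng=np.random.default_rng(0))
```

Because patches are only moved, never resized, the pixel scale is preserved while the global layout is destroyed. Applying it with a low probability would be handled by the caller.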

### Training hygiene

* Use AMP (BF16).
* Consider cosine LR + warmup.
* Consider EMA/SWA once baseline is stable.
* Clip gradients (e.g., 1.0 norm) if training is noisy.

## Ablations

### A) Must-verify (rules + correctness)

* Verify that each image has all 5 targets after pivot (wide).
* Verify metric implementation matches the competition definition.
* Verify no leakage: same `Sampling_Date` never appears in both train and val.
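
The leakage check can be as small as a set intersection per fold. The sketch below assumes a hypothetical helper `assert_no_group_leakage`; the dates and the split are invented for the example.

```python
import numpy as np

def assert_no_group_leakage(groups, tr_idx, va_idx) -> None:
    """Raise if any group value occurs on both sides of a split."""
    groups = np.asarray(groups)
    shared = set(groups[tr_idx]) & set(groups[va_idx])
    if shared:
        raise AssertionError(f"groups in both train and val: {sorted(shared)}")

# Toy example: indices 0-3 are train, 4-5 are validation.
dates = ["2015-01-01", "2015-01-01", "2015-02-01",
         "2015-02-01", "2015-03-01", "2015-03-01"]
assert_no_group_leakage(dates, tr_idx=[0, 1, 2, 3], va_idx=[4, 5])  # passes silently
```

In the CV loop later in this notebook, the same call could be made on each `(tr_idx, va_idx)` pair with `groups` set to the `Sampling_Date` column.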

### B) Baseline implementation decisions

* Compare CV split strategies:

  * GroupKFold by `Sampling_Date`
  * StratifiedGroupKFold by `State` with groups=`Sampling_Date`
* Compare head sizes / dropout.

### C) CV evaluation protocol

* Report mean ± std across folds.
* Track per-target metrics for diagnostics.
* Keep the split fixed across experiments.

### D) Inference stability

* Test-time augmentation (TTA): rotations/flips.
* Average predictions in log-space vs original space.
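
The two averaging options give different answers, which a tiny numeric sketch makes concrete. The three log1p predictions below are invented stand-ins for one image under three TTA views.

```python
import numpy as np

# Hypothetical log1p predictions for one image under three TTA views.
p_log = np.array([2.0, 2.2, 2.6])

# Option 1: average in log space, then invert the transform.
mean_log_space = np.expm1(p_log.mean())

# Option 2: invert each view first, then average in original space.
mean_orig_space = np.expm1(p_log).mean()

print(f"log-space: {mean_log_space:.3f}  original-space: {mean_orig_space:.3f}")
```

By Jensen's inequality the log-space average is never larger than the original-space average, so the choice shifts predictions in a systematic direction. Which one scores better on the weighted R² is an empirical question for CV.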

### E) Augmentation

* Try **jigsaw (patch shuffle)**: 3×3 or 4×4 grid, low probability, validate under CV.

### F) Tiling (multi-crop / multi-instance)

* **Tiled backbone features:** split each image into an `n×n` grid (start with **2×2**) and run the frozen DINO backbone on each tile.
* **Pooling (keep it simple):**

  * mean pool tile embeddings → single feature → head
  * (optional) mean+max concat (still simple)
  * (optional) per-tile head then mean of predictions
* **If images are stitched left/right:** tile each half separately → pool L and R → concat `[L, R]` → head.
* **Ablate:** `n=1` (no tiling) vs `n=2` vs `n=3` (compute-heavy) and pooling choice.
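
The pooling options above can be sketched on a `[B, T, D]` array of tile embeddings. The helper `pool_tiles` is hypothetical and written with NumPy for portability; the same mean and max reductions exist in PyTorch. The random features stand in for real backbone outputs.

```python
import numpy as np

def pool_tiles(feats: np.ndarray, mode: str = "mean") -> np.ndarray:
    """Pool tile embeddings of shape [B, T, D] into one vector per image."""
    if mode == "mean":
        return feats.mean(axis=1)  # [B, D]
    if mode == "mean_max":
        # Concatenate mean and max pooling along the feature axis: [B, 2D]
        return np.concatenate([feats.mean(axis=1), feats.max(axis=1)], axis=-1)
    raise ValueError(f"unknown pooling mode: {mode}")

# 2 images, 2x2 = 4 tiles each, 8-dim embeddings (random stand-ins).
feats = np.random.default_rng(0).normal(size=(2, 4, 8))
print(pool_tiles(feats, "mean").shape, pool_tiles(feats, "mean_max").shape)
```

Note that `mean_max` doubles the input width of the head, so the head's first linear layer has to be sized accordingly.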

### G) Improvements (after baseline is stable)

* Unfreeze last N blocks (careful with overfit).
* SWA/EMA.
* Better heads (depth/width).
* Regularization sweeps.

## Insights

* Host: public/private split is **not fully random**; test includes some **non-overlapping** time/location periods.
* Group by `Sampling_Date` to reduce leakage from date-correlated collection conditions.
* With only ~357 images, CV is noisy; report mean ± std and avoid “seed shopping”.


**Repo note:** core code is also packaged under `src/csiro_biomass/` (see `scripts/train_cv.py`).


# Start

In [1]:
import os, json, pathlib, yaml
yaml_path = "/notebooks/env.yaml"

with open(yaml_path, 'r', encoding='utf-8') as f:
    env = yaml.safe_load(f)

for k, v in env.items():
    os.environ[k] = str(v)  # os.environ only accepts strings; YAML may yield ints/bools

# Train

In [2]:
WB = "https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoidW84aXJvdGQyeThwcGpuNXFveGthZTE4IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjU5NzI4MTd9fX1dfQ__&Signature=H5H5kLVc6V83i-s2euNHx6t9KlVeG27QKX6qtkXNiLwEzuCshJD4RfwUbQv8oBJOZXPezAVJZPRkYRdsb4jh-LQ72DZtEuNkjNKHf7Pn57wzee0bjEYjWdJmOqK4waaSe9TQqELM%7EPgzdAT4LCSHYcFQ%7EleRnHGGGJiHBmTd6e1xZYhvUCfkvVD1TG-zM7R0-P%7EMLetHMvWl%7EUapCMYthsWqZctsYAQKUQxsLrly8Y4EaM8hm5nowpArPZC4myNO1iiXld5Hc3t9CVLEdYT7LIct0x6cf3-B-6WOgxGb7LdLPCcZPPfoGgX3KGtTAgNQYOpGFs-hgILFHRKVOJ7T3A__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1893388161261111"
WL = "https://dinov3.llamameta.net/dinov3_vitl16/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoidW84aXJvdGQyeThwcGpuNXFveGthZTE4IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjU5NzI4MTd9fX1dfQ__&Signature=H5H5kLVc6V83i-s2euNHx6t9KlVeG27QKX6qtkXNiLwEzuCshJD4RfwUbQv8oBJOZXPezAVJZPRkYRdsb4jh-LQ72DZtEuNkjNKHf7Pn57wzee0bjEYjWdJmOqK4waaSe9TQqELM%7EPgzdAT4LCSHYcFQ%7EleRnHGGGJiHBmTd6e1xZYhvUCfkvVD1TG-zM7R0-P%7EMLetHMvWl%7EUapCMYthsWqZctsYAQKUQxsLrly8Y4EaM8hm5nowpArPZC4myNO1iiXld5Hc3t9CVLEdYT7LIct0x6cf3-B-6WOgxGb7LdLPCcZPPfoGgX3KGtTAgNQYOpGFs-hgILFHRKVOJ7T3A__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1893388161261111"
WL_plus = "https://dinov3.llamameta.net/dinov3_vith16plus/dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoidW84aXJvdGQyeThwcGpuNXFveGthZTE4IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjU5NzI4MTd9fX1dfQ__&Signature=H5H5kLVc6V83i-s2euNHx6t9KlVeG27QKX6qtkXNiLwEzuCshJD4RfwUbQv8oBJOZXPezAVJZPRkYRdsb4jh-LQ72DZtEuNkjNKHf7Pn57wzee0bjEYjWdJmOqK4waaSe9TQqELM%7EPgzdAT4LCSHYcFQ%7EleRnHGGGJiHBmTd6e1xZYhvUCfkvVD1TG-zM7R0-P%7EMLetHMvWl%7EUapCMYthsWqZctsYAQKUQxsLrly8Y4EaM8hm5nowpArPZC4myNO1iiXld5Hc3t9CVLEdYT7LIct0x6cf3-B-6WOgxGb7LdLPCcZPPfoGgX3KGtTAgNQYOpGFs-hgILFHRKVOJ7T3A__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1893388161261111"

In [3]:

import os
import sys
import copy
import math
import random
import requests
import numpy as np
import pandas as pd
from PIL import Image
import uuid
from itertools import chain

# Optional: persist torch.compile (Inductor) caches between sessions.
# os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
# os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/notebooks/dinov3/compile_cache"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.swa_utils import AveragedModel, SWALR
import torchvision.transforms as T
from sklearn.model_selection import StratifiedGroupKFold
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import comet_ml


sys.path.insert(0, "/notebooks/dinov3")  # your fork
from dinov3.layers.block import SelfAttentionBlock

# Optional SDPA / torch.compile tuning, disabled by default:
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_math_sdp(False)
#
# dynamo_config = torch._dynamo.config
# dynamo_config.compiled_autograd = True
# dynamo_config.capture_scalar_outputs = False
# dynamo_config.cache_size_limit = 512

#torch.set_float32_matmul_precision("highest")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# -------------------------
# data: long -> wide
# -------------------------
def load_train_wide(csv_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    idx_cols = ["image_path", "Sampling_Date", "State", "Species", "Pre_GSHH_NDVI", "Height_Ave_cm"]
    wide = (
        df.pivot_table(index=idx_cols, columns="target_name", values="target", aggfunc="first")
          .reset_index()
    )
    for t in TARGETS:
        if t not in wide.columns:
            wide[t] = np.nan
    wide = wide.dropna(subset=TARGETS).reset_index(drop=True)
    wide["abs_path"] = wide["image_path"].apply(lambda p: os.path.join(ROOT, p))
    return wide

model_size = "b"
W = WB
plus = ""
COMPILE_MODEL = False
REPO_DIR = "/notebooks/dinov3"
DINO_WEIGHTS = f"/notebooks/dinov3/weights/dinov3_vit{model_size}16_pretrain{plus}.pth"
MODEL = torch.hub.load(REPO_DIR, f'dinov3_vit{model_size}16{plus.replace("_", "")}', source='local', weights=DINO_WEIGHTS, verbose=True)
#MODEL_PLUS = torch.hub.load(REPO_DIR, f'dinov3_vit{model_size}16plus', source='local', weights=WL_plus, verbose=True)
NUM_WORKERS = max((os.cpu_count() or 2) - 2, 0)  # cpu_count() can be None or < 2
ROOT = "/notebooks/kaggle/csiro"
CSV_PATH = os.path.join(ROOT, "train.csv")
TARGETS = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g", "GDM_g", "Dry_Total_g"]
WIDE_DF=load_train_wide(CSV_PATH)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)
IMG_SIZE = 512
SEED = 420
DTYPE = torch.bfloat16  # switch to torch.float16 on GPUs without BF16 support (enables GradScaler)
RUN_SWEEPS = True  # set False to skip the CV sweeps
FEAT_DIM = MODEL.norm.normalized_shape[0]
NUM_HEADS = 10

In [4]:
def download_file(url: str, out_path: str, chunk_size: int = 1024 * 1024):
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # filter keep-alive chunks
                    f.write(chunk)
    return out_path

# example
url = "https://dinov3.llamameta.net/dinov3_vitl16/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoidW84aXJvdGQyeThwcGpuNXFveGthZTE4IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjU5NzI4MTd9fX1dfQ__&Signature=H5H5kLVc6V83i-s2euNHx6t9KlVeG27QKX6qtkXNiLwEzuCshJD4RfwUbQv8oBJOZXPezAVJZPRkYRdsb4jh-LQ72DZtEuNkjNKHf7Pn57wzee0bjEYjWdJmOqK4waaSe9TQqELM%7EPgzdAT4LCSHYcFQ%7EleRnHGGGJiHBmTd6e1xZYhvUCfkvVD1TG-zM7R0-P%7EMLetHMvWl%7EUapCMYthsWqZctsYAQKUQxsLrly8Y4EaM8hm5nowpArPZC4myNO1iiXld5Hc3t9CVLEdYT7LIct0x6cf3-B-6WOgxGb7LdLPCcZPPfoGgX3KGtTAgNQYOpGFs-hgILFHRKVOJ7T3A__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1893388161261111"
out_path = "/notebooks/dinov3/weights/dinov3_vitl16_pretrain.pth"
#download_file(url, out_path)
#print("Saved to:", out_path)


# Utils

In [5]:
def _denorm_img(x: torch.Tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD) -> torch.Tensor:
    """
    x: [3,H,W] float tensor normalized with mean/std.
    returns: [3,H,W] in [0,1]
    """
    mean = torch.tensor(mean, dtype=x.dtype, device=x.device).view(3, 1, 1)
    std  = torch.tensor(std,  dtype=x.dtype, device=x.device).view(3, 1, 1)
    x = x * std + mean
    return x.clamp(0, 1)

@torch.no_grad()
def show_nxn_grid(dataset=None, dataloader=None, n=4, indices=None, seed=0,
                  mean=IMAGENET_MEAN, std=IMAGENET_STD,
                  show_targets=True, targets_are_log1p=True, figsize_per_cell=3.0):
    assert (dataset is not None) ^ (dataloader is not None), "Pass exactly one of dataset or dataloader."
    k = n * n
    xs, ys = [], []
    if dataset is not None:
        if indices is None:
            rng = random.Random(seed)
            indices = [rng.randrange(len(dataset)) for _ in range(k)]
        else:
            assert len(indices) >= k, f"Need at least {k} indices."

        for i in indices[:k]:
            x, y = dataset[i]
            xs.append(x)
            ys.append(y)

        x_batch = torch.stack(xs, dim=0)  # [k,3,H,W]
        y_batch = torch.stack(ys, dim=0) if show_targets else None

    else:
        for xb, yb in dataloader:
            for j in range(xb.shape[0]):
                xs.append(xb[j])
                ys.append(yb[j])
                if len(xs) >= k:
                    break
            if len(xs) >= k:
                break

        x_batch = torch.stack(xs, dim=0)
        y_batch = torch.stack(ys, dim=0) if show_targets else None

    # plot
    # squeeze=False keeps `axes` a 2-D array even when n == 1
    fig, axes = plt.subplots(n, n, figsize=(n * figsize_per_cell, n * figsize_per_cell), squeeze=False)

    for idx in range(k):
        ax = axes[idx // n, idx % n]
        x = _denorm_img(x_batch[idx], mean=mean, std=std)
        img = x.permute(1, 2, 0).cpu().numpy()  # [H,W,3] in [0,1]
        ax.imshow(img)
        ax.axis("off")

        if show_targets and (y_batch is not None):
            y = y_batch[idx].detach().cpu()
            if targets_are_log1p:
                y = torch.expm1(y).clamp_min(0.0)
            # short title
            ax.set_title(" ".join([f"{v:.2f}" for v in y.tolist()]), fontsize=8)

    plt.tight_layout()
    plt.show()


class TileEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_res: int):
        super().__init__()
        self.backbone = backbone
        self.input_res = input_res

    def forward(self, x: torch.Tensor, grid):
        B, C, H, W = x.shape
        r, c = grid
        hs = torch.linspace(0, H, steps=r + 1, device=x.device).round().long()
        ws = torch.linspace(0, W, steps=c + 1, device=x.device).round().long()
        tiles = []
        for i in range(r):
            for j in range(c):
                rs, re = hs[i].item(), hs[i + 1].item()
                cs, ce = ws[j].item(), ws[j + 1].item()
                xt = x[:, :, rs:re, cs:ce]
                if xt.shape[-2:] != (self.input_res, self.input_res):
                    xt = F.interpolate(xt, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                tiles.append(xt)
        tiles = torch.stack(tiles, dim=1)
        flat = tiles.view(-1, C, self.input_res, self.input_res)
        feats = self.backbone(flat)
        feats = feats.view(B, -1, feats.shape[-1])
        return feats

# Train Utils

In [6]:
# -------------------------
# transforms
# -------------------------
class PadToSquare:
    def __init__(self, fill=0):
        self.fill = fill

    def __call__(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        if w == h:
            return img
        s = max(w, h)
        new = Image.new(img.mode, (s, s), color=self.fill)
        new.paste(img, ((s - w) // 2, (s - h) // 2))
        return new

def get_tfms():
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomChoice([
            T.Lambda(lambda x: x),
            T.RandomRotation((90, 90)),
            T.RandomRotation((180, 180)),
            T.RandomRotation((270, 270)),
        ]),
        T.ColorJitter(brightness=0.20, contrast=0.20, saturation=0.20, hue=0.04),

    ])

def post_tfms(): 
    normalize = T.Normalize(mean=(0.485, 0.456, 0.406),
                            std=(0.229, 0.224, 0.225))

    return T.Compose([T.ToTensor(),normalize])
# -------------------------
# dataset
# -------------------------

class BiomassBaseCached(Dataset):
    """Caches resized/padded PIL images + stores y_log once."""
    def __init__(self, wide_df, img_size=IMG_SIZE):
        self.df = wide_df.reset_index(drop=True)
        y = self.df[TARGETS].values.astype(np.float32)
        self.y_log = np.log1p(y)

        # cache at fixed size (PIL)
        pre = T.Compose([
            PadToSquare(fill=0),
            T.Resize((img_size, img_size), antialias=True),
        ])
        self.imgs = []
        for p in self.df["abs_path"].tolist():
            im = Image.open(p).convert("RGB")
            im = pre(im)
            self.imgs.append(im.copy())
            im.close()

    def __len__(self): 
        return len(self.df)

    def __getitem__(self, i):
        return self.imgs[i], torch.from_numpy(self.y_log[i])  # PIL, y_log


class TransformView(Dataset):
    """Applies train/val transforms on top of the same cached base dataset."""
    def __init__(self, base: BiomassBaseCached, tfms):
        self.base = base
        self.tfms = tfms

    def __len__(self): return len(self.base)

    def __getitem__(self, i):
        img, y = self.base[i]          # img is cached PIL
        x = self.tfms(img)             # apply aug+tensor+norm OR tensor+norm
        return x, y

    
# -------------------------
# Loss
# -------------------------  
class WeightedMSELoss(nn.Module):
    def __init__(self, weights=(0.1, 0.1, 0.1, 0.2, 0.5), normalize=True):
        super().__init__()
        w = torch.as_tensor(weights, dtype=torch.float32)
        self.register_buffer("w", w)
        self.normalize = normalize

    def forward(self, pred_log: torch.Tensor, target_log: torch.Tensor) -> torch.Tensor:
        w = self.w.view(1, -1)
        err2 = (pred_log - target_log).pow(2)
        loss = (err2 * w).sum(dim=-1)
        if self.normalize:
            loss = loss / (self.w.sum() + 1e-12)
        return loss.mean()


# -------------------------
# model: frozen DINOv3 + head
# -------------------------
class DINOv3Regressor(nn.Module):
    def __init__(self, backbone: nn.Module, hidden=1024, depth=2, drop=0.1, out_dim=5, feat_dim = None, norm=None, num_neck=1):
        super().__init__()
        self.backbone = backbone
        feat_dim = feat_dim or FEAT_DIM
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.backbone.eval()

        assert not num_neck or feat_dim == 768, "Only VIT B is supported for now for neck"
        neck = []
        for _ in range(num_neck):
            neck.append(SelfAttentionBlock(feat_dim, num_heads=12))
        # ModuleList stays a proper nn.Module when num_neck == 0, so .train() and .modules() keep working
        self.neck = nn.ModuleList(neck)

        if depth < 2:
            raise ValueError(f"depth must be >= 2 (got {depth})")
        
        layers = []
        in_dim = feat_dim
        norm_layer = norm or nn.LayerNorm
        for _ in range(depth - 1):
            layers += [nn.Linear(in_dim, hidden), norm_layer(hidden), nn.GELU(), nn.Dropout(drop)]
            in_dim = hidden
        layers += [nn.Linear(in_dim, out_dim)]
        self.head = nn.Sequential(*layers)
        self.norm = norm_layer(feat_dim)
        self.init()

    def forward(self, x):
        with torch.no_grad():
            x, rope = self.backbone(x)
            x = x["x_prenorm"]
            
        for neck in self.neck:
            x = neck(x, rope)
        x = self.norm(x[: , 0, :])
        return self.head(x)  
    
    def set_train(self, train = True):
        self.neck.train(train)
        self.head.train(train)
        self.norm.train(train)
        
    @torch.no_grad()
    def init(self):
        modules = [*self.head.modules(), *self.neck.modules(), *self.norm.modules()]
        for m in modules:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

            elif isinstance(m, nn.LayerNorm):
                if m.elementwise_affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

# CV

In [7]:
def eval_global_wr2(model, dl_va, w_vec, device="cuda"):
    model.eval()
    w5 = w_vec.to(device).view(1, -1)
    ss_res  = torch.zeros((), device=device)
    sum_w   = torch.zeros((), device=device)
    sum_wy  = torch.zeros((), device=device)
    sum_wy2 = torch.zeros((), device=device)

    with torch.inference_mode(), torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=device.startswith("cuda")):
        for x, y_log in dl_va:
            x = x.to(device, non_blocking=True)
            y_log = y_log.to(device, non_blocking=True)   # log1p targets
            p_log = model(x).float()                      # log1p preds

            y = torch.expm1(y_log)
            p = torch.expm1(p_log).clamp_min(0.0)

            w = w5.expand_as(y)                           # [B, 5]
            diff = (y - p)

            ss_res  += (w * diff * diff).sum()
            sum_w   += w.sum()
            sum_wy  += (w * y).sum()
            sum_wy2 += (w * y * y).sum()

    mu = sum_wy / (sum_w + 1e-12)
    ss_tot = sum_wy2 - sum_w * mu * mu
    score = (1.0 - ss_res / (ss_tot + 1e-12)).item()
    return score

def cos_sin_lr(ep: int, epochs: int, lr_start: float, lr_final: float) -> float:
    if epochs <= 1:
        return lr_final
    t = (ep - 1) / (epochs - 1)  # 0 -> 1
    return lr_final + 0.5 * (lr_start - lr_final) * (1.0 + math.cos(math.pi * t))

def set_optimizer_lr(opt, lr: float):
    for pg in opt.param_groups:
        pg["lr"] = lr

In [8]:
def train_one_fold(
    ds_tr_view,
    ds_va_view,
    backbone,
    tr_idx,
    va_idx,
    wd=1e-4,
    fold_idx=0,
    epochs=5,
    lr_start=3e-4,
    lr_final=5e-5,
    batch_size=128,
    clip_val=3,
    device="cuda",
    save_path=None,
    verbose=False,
    plot_imgs=False,
    early_stopping=6,
    head_hidden=1024,
    head_depth=2,
    head_drop=0.1,
    num_neck=0,
    comet_exp=None,
    skip_log_first_n=5,
    curr_fold=0,

    # --- SWA phase ---
    swa_epochs=15,
    swa_lr=None,
    swa_anneal_epochs=10,
    swa_load_best=True,
    swa_eval_freq=2,
):

    def _trainable_blocks(m: nn.Module):
        parts = []
        if hasattr(m, "neck") and m.neck is not None:
            parts.append(m.neck)
        if hasattr(m, "head") and m.head is not None:
            parts.append(m.head)
        if hasattr(m, "norm") and m.norm is not None:
            parts.append(m.norm)
        return parts

    def _trainable_params_list(m: nn.Module):
        blocks = _trainable_blocks(m)
        params = list(chain.from_iterable(b.parameters() for b in blocks))
        params = [p for p in params if p.requires_grad]
        return params

    def _save_parts(m):
        state = {}
        for name in ("neck", "head", "norm"):
            part = getattr(m, name, None)
            if part is not None:
                state[name] = {k: v.detach().cpu() for k, v in part.state_dict().items()}
        return state

    def _load_parts(m, state):
        for name in ("neck", "head", "norm"):
            part = getattr(m, name, None)
            if part is not None and name in state:
                part.load_state_dict(state[name], strict=True)

    # ---- data ----
    tr_subset = Subset(ds_tr_view, tr_idx)
    va_subset = Subset(ds_va_view, va_idx)

    dl_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=NUM_WORKERS,
        persistent_workers=(NUM_WORKERS > 0),
    )
    dl_tr = DataLoader(tr_subset, shuffle=True, **dl_kwargs)
    dl_va = DataLoader(va_subset, shuffle=False, **dl_kwargs)

    if plot_imgs:
        show_nxn_grid(dataloader=dl_tr, n=4)
        return

    # ---- model ----
    model = DINOv3Regressor(
        backbone, hidden=head_hidden, drop=head_drop, depth=head_depth, num_neck=num_neck
    ).to(device)
    model.init()

    criterion = WeightedMSELoss().to(device)

    trainable_params = _trainable_params_list(model)
    opt = torch.optim.AdamW(trainable_params, lr=lr_start, weight_decay=wd)

    scaler = None
    if device.startswith("cuda") and DTYPE == torch.float16:
        scaler = torch.amp.GradScaler()

    if swa_lr is None:
        swa_lr = lr_final

    # ---- bookkeeping ----
    best_score = -1e9
    best_state = None
    best_opt_state = None
    patience = 0

    # =========================
    # Phase A: normal training
    # =========================
    p_bar = tqdm(range(1, epochs + 1))
    for ep in p_bar:
        lr = cos_sin_lr(ep, epochs, lr_start, lr_final)
        set_optimizer_lr(opt, lr)

        model.set_train(True)
        running = 0.0
        n_seen = 0

        for x, y_log in dl_tr:
            x = x.to(device, non_blocking=True)
            y_log = y_log.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=device.startswith("cuda")):
                p_log = model(x)
                loss = criterion(p_log, y_log)

            if scaler is not None:
                scaler.scale(loss).backward()

                # unscale before clipping
                if clip_val and clip_val > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(clip_val))

                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()

                if clip_val and clip_val > 0:
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(clip_val))

                opt.step()

            bs = x.size(0)
            running += loss.detach() * bs
            n_seen += bs

        train_loss = (running / max(n_seen, 1)).item()
        model.set_train(False)
        score = eval_global_wr2(model, dl_va, criterion.w, device=device)

        if comet_exp is not None and ep > skip_log_first_n:
            comet_exp.log_metrics(
                {f"train_loss_{curr_fold}": train_loss, f"val_wR2_{curr_fold}": score},
                step=ep,
            )

        if score > best_score:
            best_score = score
            patience = 0
            best_state = _save_parts(model)
            best_opt_state = copy.deepcopy(opt.state_dict())
        else:
            patience += 1

        s1 = f"Best score: {best_score:.4f} | Patience: {patience:02d}/{early_stopping:02d} | lr: {lr:.2e}"
        s2 = f"[fold {fold_idx}] | train_loss={train_loss:.4f} | val_wR2={score:.4f} | {s1}"
        if verbose:
            print(s2)
        p_bar.set_postfix_str(s2)

        if patience >= early_stopping:
            p_bar.set_postfix_str(s2 + " | Early stopping -> SWA phase")
            break

    p_bar.close()

    if (swa_epochs <= 0) or (best_state is None):
        if save_path and best_state is not None:
            torch.save(best_state, save_path)
        return best_score

    # =========================
    # Phase B: SWA extra epochs
    # =========================
    if swa_load_best:
        _load_parts(model, best_state)
        if best_opt_state is not None:
            opt.load_state_dict(best_opt_state)

    swa_model = AveragedModel(model).to(device)
    swa_sched = SWALR(opt, swa_lr=swa_lr, anneal_epochs=swa_anneal_epochs, anneal_strategy="cos")

    p_bar = tqdm(range(1, swa_epochs + 1))
    swa_score = None

    for k in p_bar:
        model.set_train(True)
        running = 0.0
        swa_n_seen = 0

        for x, y_log in dl_tr:
            x = x.to(device, non_blocking=True)
            y_log = y_log.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=device.startswith("cuda")):
                p_log = model(x)
                loss = criterion(p_log, y_log)

            if scaler is not None:
                scaler.scale(loss).backward()

                if clip_val and clip_val > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(clip_val))

                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()

                if clip_val and clip_val > 0:
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(clip_val))

                opt.step()

            bs = x.size(0)
            running += loss.detach() * bs
            swa_n_seen += bs

        swa_loss = (running / max(swa_n_seen, 1)).item()

        swa_sched.step()
        swa_model.update_parameters(model)

        if comet_exp is not None:
            comet_exp.log_metrics({f"swa_train_loss_{curr_fold}": swa_loss}, step=k)

        s2 = f"[fold {fold_idx}] | swa_loss={swa_loss:.4f}"
        if verbose:
            print(s2)
        p_bar.set_postfix_str(s2)

        if swa_eval_freq and (k % swa_eval_freq == 0):
            swa_score = eval_global_wr2(swa_model, dl_va, criterion.w, device=device)
            if comet_exp is not None:
                comet_exp.log_metrics({f"swa_wR2_{curr_fold}": swa_score}, step=k)

    p_bar.close()

    if swa_score is None or (swa_eval_freq and (k % swa_eval_freq) != 0):
        swa_score = eval_global_wr2(swa_model, dl_va, criterion.w, device=device)

    if save_path:
        swa_blocks_state = _save_parts(swa_model.module)
        torch.save(swa_blocks_state, save_path)

    return swa_score



In [9]:
def run_groupkfold_cv(
    dataset,
    wide_df,
    n_splits=5,
    group_col="Sampling_Date",
    tfms_fn = get_tfms,
    comet_exp_name = None,
    sweep_config = "",
    **train_kwargs,
):
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    X = wide_df  
    y = wide_df["State"].values             
    groups = wide_df[group_col].values
    ds_tr_view = TransformView(dataset, T.Compose([tfms_fn(), post_tfms()]))
    ds_va_view = TransformView(dataset, post_tfms())
    
    """
    bb_copy = copy.deepcopy(backbone)
    model = DINOv3Regressor(bb_copy, hidden=head_hidden, drop=head_drop, depth=head_depth, norm=head_norm).to(device)
    
    if COMPILE_MODEL:
        model.compile(fullgraph=True, mode="default", backend="inductor", dynamic=True)
    """
    comet_exp = None  # stays None when Comet logging is disabled
    if comet_exp_name is not None:
        comet_exp = comet_ml.start(
            api_key=os.getenv("COMET_API_KEY"),
            project_name=comet_exp_name,
            experiment_key=None
        )
    fold_scores = []
    fold_idx = -1  # defined up front so the except branch can always report it
    try:
        if comet_exp is not None:
            comet_exp.set_name(comet_exp_name + "_" + sweep_config + "_" + str(uuid.uuid4())[:3])
        for fold_idx, (tr_idx, va_idx) in enumerate(sgkf.split(X, y, groups)):
            score = train_one_fold(
                ds_tr_view=ds_tr_view,
                ds_va_view=ds_va_view,
                tr_idx=tr_idx,
                va_idx=va_idx,
                fold_idx=fold_idx,
                device="cuda",
                comet_exp=comet_exp,
                curr_fold=fold_idx,
                **train_kwargs,
            )
            # plot_imgs only previews one augmented batch and skips training
            if train_kwargs.get("plot_imgs"):
                return None, None, None
            fold_scores.append(score)
    except Exception as e:
        print(f"Fold {fold_idx} failed with exception: {e}")
    finally:
        if comet_exp is not None:
            comet_exp.end()

    fold_scores = np.array(fold_scores, dtype=np.float32)
    mean = float(fold_scores.mean())
    std = float(fold_scores.std(ddof=0))

    print("\nCV summary")
    for i, s in enumerate(fold_scores.tolist()):
        print(f"  fold {i}: {s:.4f}")
    print(f"  mean ± std: {mean:.4f} ± {std:.4f}")
    return fold_scores, mean, std


In [10]:
dataset_biomass = BiomassBaseCached(WIDE_DF, img_size=IMG_SIZE)
assert RUN_SWEEPS

# Sweep

In [None]:
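def get_tfms_0():
    """Label-safe augmentation: flips, exact 90° rotations, and mild color jitter."""
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        # Pick one of 0°, 90°, 180°, or 270°; equal bounds make each rotation exact.
        T.RandomChoice([
            T.Lambda(lambda x: x),
            T.RandomRotation((90, 90)),
            T.RandomRotation((180, 180)),
            T.RandomRotation((270, 270)),
        ]),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.035),
    ])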

In [None]:
train_kwargs = dict(
    dataset=dataset_biomass,
    wide_df=WIDE_DF,
    backbone=MODEL,
    epochs=80,
    batch_size=64,
    wd=3e-3,
    head_hidden=2048,
    head_drop=0.1,
    head_depth=5,
    plot_imgs=False,
    early_stopping=15,
    comet_exp_name="cv5",
)

sweeps = [
    dict(num_neck=1, head_depth=4, tfms_fn=get_tfms_0),
    dict(num_neck=1, head_depth=5, tfms_fn=get_tfms_0),
    dict(num_neck=2, head_depth=4, tfms_fn=get_tfms_0),
    dict(num_neck=2, head_depth=5, tfms_fn=get_tfms_0),
]

sweep_id = str(uuid.uuid4())[:4]  # short tag shared by every run in this sweep
for sweep in sweeps:
    # Per-run overrides take precedence over the shared defaults.
    new_train_kwargs = {**train_kwargs, **sweep}
    new_train_kwargs["comet_exp_name"] += f"-{sweep_id}"
    new_train_kwargs["sweep_config"] = str(sweep)
    scores, mean, std = run_groupkfold_cv(**new_train_kwargs)
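
Because `str(sweep)` embeds the function repr (for example `<function get_tfms_0 at 0x...>`) in the experiment name, a more readable label may be easier to scan in the tracker. The sketch below is one possible way to build the same four-way grid and label each entry; `expand_grid` and `describe_sweep` are hypothetical helpers, and the `get_tfms_0` here is only a placeholder for the notebook's real augmentation factory.

```python
from itertools import product

def get_tfms_0():
    """Placeholder standing in for the notebook's augmentation factory."""
    return None

# Hypothetical helper, not part of the notebook.
def expand_grid(grid):
    """Turn {name: [values, ...]} into a list with one dict per combination."""
    keys = list(grid)
    return [dict(zip(keys, combo)) for combo in product(*(grid[k] for k in keys))]

# Hypothetical helper, not part of the notebook.
def describe_sweep(sweep):
    """Readable label: callables are shown by __name__ instead of their repr."""
    parts = []
    for key, value in sweep.items():
        label = getattr(value, "__name__", value)
        parts.append(f"{key}={label}")
    return ",".join(parts)

grid = {"num_neck": [1, 2], "head_depth": [4, 5], "tfms_fn": [get_tfms_0]}
sweeps = expand_grid(grid)
for sweep in sweeps:
    print(describe_sweep(sweep))
```

Passing `describe_sweep(sweep)` as `sweep_config` would keep run names stable across sessions, since memory addresses no longer appear in them.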


COMET INFO: Experiment is live on comet.com https://www.comet.com/v1kstrand/cv5-9e96/47d36f8e6a41406783f68c8d13cc5ce0

COMET ERROR: Error logging git-related information


COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : cv5-9e96_{'num_neck': 1, 'head_depth': 4, 'tfms_fn': <function get_tfms_0 at 0x7f68a6c80a60>}_43b
COMET INFO:     url                   : https://www.comet.com/v1kstrand/cv5-9e96/47d36f8e6a41406783f68c8d13cc5ce0
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     swa_train_loss_0 [15] : (0.08894143998622894, 0.18911533057689667)
COMET INFO:     swa_train_loss_1 [15] : (0.0808490738272667, 0.19342468678951263)
COMET INFO:     swa_train_loss_2 [15] : (0.07483459264039993, 0


CV summary
  fold 0: 0.6072
  fold 1: 0.7279
  fold 2: 0.8853
  fold 3: 0.7125
  fold 4: 0.6563
  mean ± std: 0.7179 ± 0.0940


COMET INFO: Experiment is live on comet.com https://www.comet.com/v1kstrand/cv5-9e96/6844c46295c24ad082f681b2a5d28679

COMET ERROR: Error logging git-related information


COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : cv5-9e96_{'num_neck': 1, 'head_depth': 5, 'tfms_fn': <function get_tfms_0 at 0x7f68a6c80a60>}_3bf
COMET INFO:     url                   : https://www.comet.com/v1kstrand/cv5-9e96/6844c46295c24ad082f681b2a5d28679
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     swa_train_loss_0 [15] : (0.0854143500328064, 0.21025492250919342)
COMET INFO:     swa_train_loss_1 [15] : (0.07407134026288986, 0.181295245885849)
COMET INFO:     swa_train_loss_2 [15] : (0.020146487280726433, 0.


CV summary
  fold 0: 0.5875
  fold 1: 0.7355
  fold 2: 0.8521
  fold 3: 0.7177
  fold 4: 0.7036
  mean ± std: 0.7193 ± 0.0843


COMET INFO: Experiment is live on comet.com https://www.comet.com/v1kstrand/cv5-9e96/a7cea6df28f143ccbf613eadd4e37737

COMET ERROR: Error logging git-related information


COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : cv5-9e96_{'num_neck': 2, 'head_depth': 4, 'tfms_fn': <function get_tfms_0 at 0x7f68a6c80a60>}_6c5
COMET INFO:     url                   : https://www.comet.com/v1kstrand/cv5-9e96/a7cea6df28f143ccbf613eadd4e37737
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     swa_train_loss_0 [15] : (0.0671190395951271, 0.1777181625366211)
COMET INFO:     swa_train_loss_1 [15] : (0.0802803635597229, 0.21668942272663116)
COMET INFO:     swa_train_loss_2 [15] : (0.05535123869776726, 0.1


CV summary
  fold 0: 0.5938
  fold 1: 0.7468
  fold 2: 0.8391
  fold 3: 0.6660
  fold 4: 0.6950
  mean ± std: 0.7081 ± 0.0821
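
With fold-to-fold spread near 0.08 to 0.09, small gaps between configuration means are hard to interpret on their own. One option is to compare configurations fold by fold, as sketched below with the per-fold scores copied from the three completed summaries. This assumes the folds are identical across runs (the splitter uses the same `random_state=SEED` and the same table); `paired_diff` is a hypothetical helper, and because the inputs are rounded to four decimals the recomputed means can differ from the printed ones in the last digit.

```python
from statistics import mean, pstdev

# Per-fold wR2 copied from the three completed CV summaries (rounded to 4 decimals).
fold_scores = {
    "num_neck=1,head_depth=4": [0.6072, 0.7279, 0.8853, 0.7125, 0.6563],
    "num_neck=1,head_depth=5": [0.5875, 0.7355, 0.8521, 0.7177, 0.7036],
    "num_neck=2,head_depth=4": [0.5938, 0.7468, 0.8391, 0.6660, 0.6950],
}

# Hypothetical helper, not part of the notebook.
def paired_diff(a, b):
    """Mean and population std (ddof=0) of the per-fold differences a[i] - b[i]."""
    diffs = [x - y for x, y in zip(a, b)]
    return mean(diffs), pstdev(diffs)

baseline = "num_neck=1,head_depth=4"
for name, scores in fold_scores.items():
    line = f"{name}: mean={mean(scores):.4f} std={pstdev(scores):.4f}"
    if name != baseline:
        d_mean, d_std = paired_diff(scores, fold_scores[baseline])
        line += f" | vs baseline: {d_mean:+.4f} ± {d_std:.4f}"
    print(line)
```

On these numbers the mean paired differences are on the order of 0.01 or smaller while their per-fold spread is roughly 0.03, so this single sweep does not clearly separate the three configurations.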


COMET INFO: Experiment is live on comet.com https://www.comet.com/v1kstrand/cv5-9e96/ce77217fa1494bbf8d4ff1987c460f1a

COMET ERROR: Error logging git-related information



# End