# CSIRO Image2Biomass v1 Kaggle Inference



In [None]:
# 설정 및 경로

import glob
import os
import sys

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append(os.path.abspath("."))

from src.config import Config, PathConfig, TrainConfig, OptunaConfig
from src.data import RegressionDataset, load_long_dataframe, to_wide
from src.metrics import expand_targets
from src.model import build_model


def ensure_timm_installed() -> None:
    """Fail fast with an actionable message when ``timm`` cannot be imported.

    The model backbone is built through ``timm``, so a missing package should be
    reported before any data loading starts.

    Raises:
        ImportError: if ``import timm`` fails; the original error is chained as
            ``__cause__``.
    """
    hint = (
        "timm이 설치되어 있어야 합니다. Kaggle 환경에 timm 패키지가 포함된 런타임/데이터셋을 추가해주세요."
    )
    try:
        import timm  # noqa: F401
    except ImportError as missing:
        raise ImportError(hint) from missing


# Kaggle input/output locations. Change WEIGHTS_ROOT to match the mount path of
# the Kaggle Dataset that holds the fold checkpoints.
DATA_ROOT = "/kaggle/input/csiro-biomass"
WEIGHTS_ROOT = "/kaggle/input/csiroi2b-weights"
# The run name can be overridden through the RUN_NAME environment variable.
RUN_NAME = os.environ.get("RUN_NAME", "v1_inference")
OUTPUT_ROOT = "/kaggle/working/outputs"
RUN_DIR = os.path.join(OUTPUT_ROOT, RUN_NAME)
SUBMISSION_PATH = os.path.join(RUN_DIR, "submission", "submission.csv")
WORKING_SUBMISSION = "/kaggle/working/submission.csv"

# Create <RUN_DIR>/submission, i.e. the parent directory of SUBMISSION_PATH.
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)

cfg = Config(
    paths=PathConfig(
        data_root=DATA_ROOT,
        train_csv="train.csv",
        test_csv="test.csv",
        output_root=OUTPUT_ROOT,
        run_name=RUN_NAME,
    ),
    train=TrainConfig(),
    optuna=OptunaConfig(use_optuna=False),
    device="cuda",
)

# Honour the configured device only when CUDA is actually available;
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(cfg.device)
else:
    device = torch.device("cpu")
print("Using device:", device)



In [None]:
# Data checks and loading

ensure_timm_installed()

train_csv = cfg.paths.resolve_train_csv()
test_csv = cfg.paths.resolve_test_csv()

# Report CSV presence before loading so a missing file is obvious in the log.
print(f"Train CSV exists: {os.path.exists(train_csv)} - {train_csv}")
print(f"Test CSV exists: {os.path.exists(test_csv)} - {test_csv}")

train_long = load_long_dataframe(train_csv)
test_long = load_long_dataframe(test_csv)

print(f"Train rows: {len(train_long)}, columns: {train_long.columns.tolist()}")
print(f"Test rows: {len(test_long)}, columns: {test_long.columns.tolist()}")

# Count the JPEGs present in each split directory (a sanity check only).
train_images, test_images = (
    glob.glob(os.path.join(DATA_ROOT, split_dir, "*.jpg"))
    for split_dir in ("train", "test")
)
print(f"Detected train images: {len(train_images)}")
print(f"Detected test images: {len(test_images)}")

# Wide test frame used by the DataLoader; targets are not available for test.
test_wide = to_wide(test_long, include_targets=False)
print("test_wide shape:", test_wide.shape)
test_wide.head()



In [None]:
# Dataset and transforms (inference-only DataLoader)


def get_inference_loader(test_df: pd.DataFrame) -> DataLoader:
    """Build a deterministic, target-free DataLoader over ``test_df``.

    Images are read from ``cfg.paths.resolve_image_root()`` by
    ``RegressionDataset`` with ``cfg.train.image_size``, augmentation disabled
    and targets not requested. Batch size and worker count come from
    ``cfg.train``.

    Args:
        test_df: Test frame to iterate, e.g. the output of ``to_wide``.

    Returns:
        A DataLoader with ``shuffle=False``. Batch order therefore follows the
        dataset index order (presumably ``test_df`` row order). Positional
        alignment of predictions with ``test_df`` downstream depends on this.
    """
    ds = RegressionDataset(
        test_df,
        cfg.paths.resolve_image_root(),
        cfg.train.image_size,
        augment=False,
        use_targets=False,
    )
    return DataLoader(
        ds,
        batch_size=cfg.train.batch_size,
        shuffle=False,
        num_workers=cfg.train.num_workers,
        # Pinned host memory only speeds up host->GPU copies. On a CPU
        # fallback PyTorch merely warns and ignores it, so enable it only
        # when running on CUDA.
        pin_memory=device.type == "cuda",
    )


# Built eagerly for inspection; run_inference_and_save creates its own loader.
inference_loader = get_inference_loader(test_wide)



In [None]:
# Model definition (pretrained=False; the trained state_dict is loaded afterwards)


def load_model(checkpoint_path: str) -> torch.nn.Module:
    """Build the configured backbone and restore one fold's trained weights.

    The architecture comes from ``build_model(cfg.train.backbone, pretrained=False)``;
    no pretrained weights are requested because the checkpoint supplies them.
    The checkpoint is deserialised with ``torch.load`` onto the module-level
    ``device`` and applied with ``load_state_dict`` (strict by default). It must
    therefore be a state dict whose keys match the model built here.

    Args:
        checkpoint_path: Path to a saved state dict (e.g. ``fold0_best.pth``).

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    net = build_model(cfg.train.backbone, pretrained=False)
    weights = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(weights)
    # nn.Module.to() and .eval() both return the module itself.
    return net.to(device).eval()



In [None]:
# Inference and submission-file generation


def predict_wide(loader: DataLoader) -> np.ndarray:
    """Average the predictions of every fold checkpoint over ``loader``.

    Checkpoints matching ``fold*_best.pth`` under ``WEIGHTS_ROOT`` are processed
    in sorted filename order. Each one is loaded with ``load_model`` and run over
    the whole loader without gradients. The per-checkpoint prediction arrays are
    then averaged element-wise.

    Args:
        loader: Inference DataLoader. Each batch is unpacked as a 3-tuple whose
            first element is the image batch; the other two are ignored.

    Returns:
        The element-wise mean of the per-checkpoint predictions. It has the same
        shape as one checkpoint's concatenated outputs.

    Raises:
        FileNotFoundError: if no checkpoint matches the glob.
        ValueError: if ``loader`` yields no batches.
    """
    ckpts = sorted(glob.glob(os.path.join(WEIGHTS_ROOT, "fold*_best.pth")))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoint files found in {WEIGHTS_ROOT}.")

    preds_stack = []
    for ckpt_path in ckpts:
        model = load_model(ckpt_path)
        fold_preds = []
        with torch.no_grad():
            for images, _, _ in tqdm(loader, desc=f"Predict {os.path.basename(ckpt_path)}"):
                outputs = model(images.to(device))
                fold_preds.append(outputs.cpu().numpy())
        if not fold_preds:
            # np.concatenate([]) would fail with an opaque message here.
            raise ValueError("Inference loader produced no batches; is the test set empty?")
        preds_stack.append(np.concatenate(fold_preds))

        # Drop this fold's model before the next checkpoint is built. Otherwise
        # `model` keeps the old weights alive while load_model runs, and two
        # models are resident at peak.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return np.mean(preds_stack, axis=0)


def build_submission(test_long_df: pd.DataFrame, test_wide_df: pd.DataFrame, preds: np.ndarray, run_dir: str) -> str:
    """Turn wide model predictions into the long-format submission CSV.

    ``expand_targets(preds)`` must yield five columns, interpreted in the order
    Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g. Row ``i`` of the
    predictions is paired positionally with row ``i`` of ``test_wide_df``
    (via ``sample_id_prefix``). The predictions are then melted to one row per
    (prefix, target_name) and left-joined onto ``test_long_df``. Rows with no
    matching prediction therefore get a NaN ``target``.

    Args:
        test_long_df: Long test frame with ``sample_id``, ``sample_id_prefix``
            and ``target_name`` columns; its row order is kept.
        test_wide_df: Wide test frame whose rows align with ``preds``.
        preds: Raw model outputs, as returned by ``predict_wide``.
        run_dir: Run directory; the CSV is written to
            ``<run_dir>/submission/submission.csv``.

    Returns:
        Path of the written CSV (``sample_id``, ``target`` columns).

    Raises:
        pandas.errors.MergeError: if the predictions contain duplicate
            (sample_id_prefix, target_name) keys. Without this check the
            submission would silently gain duplicate rows.
    """
    full_preds = expand_targets(preds)
    pred_df = pd.DataFrame(full_preds, columns=["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"])
    pred_df["sample_id_prefix"] = test_wide_df["sample_id_prefix"].values

    pred_long = pred_df.melt(id_vars="sample_id_prefix", var_name="target_name", value_name="target")

    # A pre-existing "target" column would make the merge emit target_x/target_y
    # and break the final column selection, so drop it if present.
    base = test_long_df.drop(columns=["target"], errors="ignore")
    merged = base.merge(
        pred_long[["sample_id_prefix", "target_name", "target"]],
        on=["sample_id_prefix", "target_name"],
        how="left",
        validate="many_to_one",
    )

    # Honour the run_dir argument for the output location as well, not only
    # for directory creation.
    submission_dir = os.path.join(run_dir, "submission")
    os.makedirs(submission_dir, exist_ok=True)
    submission_path = os.path.join(submission_dir, "submission.csv")
    merged[["sample_id", "target"]].to_csv(submission_path, index=False)
    return submission_path


def run_inference_and_save(test_long_df: pd.DataFrame, test_wide_df: pd.DataFrame) -> str:
    """Run the fold ensemble, write the submission and copy it to the Kaggle root.

    Steps:
        1. Build an inference loader over ``test_wide_df``.
        2. Average fold predictions with ``predict_wide``.
        3. Write the CSV under ``RUN_DIR`` with ``build_submission``.
        4. Copy it to ``WORKING_SUBMISSION``, which is the file Kaggle scores.

    Args:
        test_long_df: Long test frame (one row per sample_id).
        test_wide_df: Wide test frame fed to the model.

    Returns:
        Path of the submission written under ``RUN_DIR``.
    """
    import shutil

    loader = get_inference_loader(test_wide_df)
    preds = predict_wide(loader)
    submission_path = build_submission(test_long_df, test_wide_df, preds, RUN_DIR)

    os.makedirs(os.path.dirname(WORKING_SUBMISSION), exist_ok=True)
    # Copy the bytes verbatim. Re-reading and re-writing through pandas may
    # change float text or turn NA-like strings into NaN.
    if os.path.abspath(submission_path) != os.path.abspath(WORKING_SUBMISSION):
        shutil.copyfile(submission_path, WORKING_SUBMISSION)
    print("Saved submission to:", submission_path)
    print("Copied submission to:", WORKING_SUBMISSION)
    return submission_path



In [None]:
# Validate and summarise the submission file

submission_path = run_inference_and_save(test_long, test_wide)

submission = pd.read_csv(submission_path)
print("Submission shape:", submission.shape)
print("Submission columns:", submission.columns.tolist())
print("NaN present:", submission["target"].isna().any())

# Cross-check IDs against the official template when it is available.
sample_submission_path = os.path.join(DATA_ROOT, "sample_submission.csv")
if not os.path.exists(sample_submission_path):
    print("sample_submission.csv not found at", sample_submission_path)
else:
    sample_sub = pd.read_csv(sample_submission_path)
    sample_ids, submission_ids = set(sample_sub["sample_id"]), set(submission["sample_id"])
    missing_ids = sample_ids.difference(submission_ids)
    extra_ids = submission_ids.difference(sample_ids)
    print("Missing IDs compared to sample_submission:", len(missing_ids))
    print("Extra IDs compared to sample_submission:", len(extra_ids))

print("Final submission path:", submission_path)
print("Working copy path:", WORKING_SUBMISSION)

