In [61]:
import sys

sys.path.insert(
    0,
    "/kaggle/input/csiro-timm-latest/pytorch-image-models-1.0.22"
)

import timm
print("version:", timm.__version__)
print("file:", timm.__file__)

version: 1.0.22
file: /kaggle/input/csiro-timm-latest/pytorch-image-models-1.0.22/timm/__init__.py


In [62]:
import os
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
tqdm.pandas()

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from pytorch_lightning import LightningModule
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from sklearn.model_selection import KFold, GroupKFold, StratifiedGroupKFold

from types import SimpleNamespace

In [63]:
timm.list_models("*dino*")[:10]

['vit_7b_patch16_dinov3',
 'vit_base_patch14_dinov2',
 'vit_base_patch14_reg4_dinov2',
 'vit_base_patch16_dinov3',
 'vit_base_patch16_dinov3_qkvb',
 'vit_giant_patch14_dinov2',
 'vit_giant_patch14_reg4_dinov2',
 'vit_huge_plus_patch16_dinov3',
 'vit_huge_plus_patch16_dinov3_qkvb',
 'vit_large_patch14_dinov2']

In [64]:
DATA_ROOT = '/kaggle/input/csiro-biomass/'

# train
train_df = pd.read_csv(f'{DATA_ROOT}/train.csv')
train_df[['sample_id_prefix', 'sample_id_suffix']] = train_df.sample_id.str.split('__', expand=True)

# agg_train_df の作成
cols = ['sample_id_prefix', 'image_path', 'Sampling_Date', 'State', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm']
agg_train_df = train_df.groupby(cols).apply(lambda df: df.set_index('target_name').target)
agg_train_df.reset_index(inplace=True)
agg_train_df.columns.name = None

agg_train_df['image'] = agg_train_df.image_path.progress_apply(
    lambda path: Image.open(DATA_ROOT + path).convert('RGB')
)


  agg_train_df = train_df.groupby(cols).apply(lambda df: df.set_index('target_name').target)


  0%|          | 0/357 [00:00<?, ?it/s]

In [65]:
# 画像サイズ確認
agg_train_df['image_size'] = agg_train_df.image.apply(lambda x: x.size)
agg_train_df['image_size'].value_counts()

# ターゲット合計確認
np.isclose(agg_train_df[['Dry_Green_g', 'Dry_Clover_g']].sum(axis=1),
           agg_train_df['GDM_g'], atol=1e-4).mean()

np.isclose(agg_train_df[['GDM_g', 'Dry_Dead_g']].sum(axis=1),
           agg_train_df['Dry_Total_g'], atol=1e-4).mean()


0.9971988795518207

In [66]:
# test.csv
test_df = pd.read_csv(DATA_ROOT + 'test.csv')
test_df[['sample_id_prefix', 'sample_id_suffix']] = test_df.sample_id.str.split('__', expand=True)

# 推論用 agg_test_df
agg_test_df = test_df.drop_duplicates(subset='sample_id_prefix').copy()

agg_test_df['image'] = agg_test_df.image_path.progress_apply(
    lambda path: Image.open(DATA_ROOT + path).convert('RGB')
)


  0%|          | 0/1 [00:00<?, ?it/s]

In [67]:
class InferenceDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df.reset_index(drop=True)
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = self.df.iloc[idx]["image"]
        width, height = image.size
        mid_point = width // 2

        # 左右に分割
        left_image = image.crop((0, 0, mid_point, height))
        right_image = image.crop((mid_point, 0, width, height))

        if self.transforms:
            left_image = self.transforms(image=np.array(left_image))["image"]
            right_image = self.transforms(image=np.array(right_image))["image"]

        return left_image, right_image


In [68]:
!ls /kaggle/input/model-dinov3-large/model.safetensors

/kaggle/input/model-dinov3-large/model.safetensors


In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os 

class LocalMambaBlock(nn.Module):
    """
    Lightweight Mamba-style block (Gated CNN) from the reference notebook.
    Efficiently mixes tokens with linear complexity.
    """
    def __init__(self, dim, kernel_size=5, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Depthwise conv mixes spatial information locally
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.gate = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: (Batch, Tokens, Dim)
        shortcut = x
        x = self.norm(x)
        # Gating mechanism
        g = torch.sigmoid(self.gate(x))
        x = x * g
        # Spatial mixing via 1D Conv (requires transpose)
        x = x.transpose(1, 2)  # -> (B, D, N)
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # -> (B, N, D)
        # Projection
        x = self.proj(x)
        x = self.drop(x)
        return shortcut + x

class TimmEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        pretrained = True if cfg.model.resume_path is None else False

        # 1. Load Backbone with global_pool='' to keep patch tokens
        #    (B, 197, 1024) instead of (B, 1024)
        self.encoder = timm.create_model(
            cfg.model.backbone,
            in_chans=cfg.task.slice_depth,
            pretrained=False,
            # drop_path_rate=cfg.model.drop_path_rate,
            features_only=False,
            num_classes=0,
            global_pool="",  # 自前でpoolingするのでここは空
        )

        # 2. Enable Gradient Checkpointing (Crucial for ViT-Large memory!)
        if hasattr(self.encoder, 'set_grad_checkpointing'):
            self.encoder.set_grad_checkpointing(True)
            print("✓ Gradient Checkpointing enabled (saves ~50% VRAM)")
        
        nf = self.encoder.num_features
        
        # 3. Mamba Fusion Neck
        #    Mixes the concatenated tokens [Left, Right]
        self.fusion = nn.Sequential(
            LocalMambaBlock(nf, kernel_size=5, dropout=0.1),
            LocalMambaBlock(nf, kernel_size=5, dropout=0.1)
        )
        
        # 4. Pooling & Heads
        self.pool = nn.AdaptiveAvgPool1d(1)
        
        # Heads (using the same logic as before, but on fused features)
        self.head_green_raw  = nn.Sequential(
            nn.Linear(nf, nf//2), nn.GELU(), nn.Dropout(0.2), 
            nn.Linear(nf//2, 1), nn.Softplus()
        )
        self.head_clover_raw = nn.Sequential(
            nn.Linear(nf, nf//2), nn.GELU(), nn.Dropout(0.2), 
            nn.Linear(nf//2, 1), nn.Softplus()
        )
        self.head_dead_raw   = nn.Sequential(
            nn.Linear(nf, nf//2), nn.GELU(), nn.Dropout(0.2), 
            nn.Linear(nf//2, 1), nn.Softplus()
        )
        
        
        if pretrained:
            self.load_pretrained()
    
        if cfg.model.freeze_backbone:
            for p in self.encoder.parameters():
                p.requires_grad = False

    def load_pretrained(self):
        try:
            path = self.cfg.model.backbone_path
            if path and os.path.exists(path):
                print(f"Loading backbone weights from local file: {path}")
                
                if path.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    sd = load_file(path) # safetensors専用のロード
                else:
                    sd = torch.load(path, map_location='cpu')
                
                # wrapperの除去
                if 'model' in sd: sd = sd['model']
                elif 'state_dict' in sd: sd = sd['state_dict']
                
                # load_state_dictを実行
                self.encoder.load_state_dict(sd, strict=False)
                print('Successfully loaded local weights.')
            else:
                print(f"Warning: backbone_path not found at {path}")
        except Exception as e:
            print(f'Warning: pretrained load failed: {e}')

    def forward(self, left_img: torch.Tensor, right_img: torch.Tensor):
        # 1. Extract Tokens (B, N, D)
        #    Note: ViT usually returns [CLS, Patch1, Patch2...]
        #    We remove CLS token for spatial mixing, or keep it. Let's keep it.
        x_l = self.encoder(left_img)
        x_r = self.encoder(right_img)
        # x_l = self.encoder.forward_features(left_img)
        # x_r = self.encoder.forward_features(right_img)

        # 2. Concatenate Left and Right tokens along sequence dimension
        #    (B, N, D) + (B, N, D) -> (B, 2N, D)
        x_cat = torch.cat([x_l, x_r], dim=1)
        
        # 3. Apply Mamba Fusion
        #    This allows tokens from Left image to interact with tokens from Right image
        x_fused = self.fusion(x_cat)
        
        # 4. Global Pooling
        #    (B, 2N, D) -> (B, D, 2N) -> (B, D, 1) -> (B, D)
        x_pool = self.pool(x_fused.transpose(1, 2)).flatten(1)
        
        # 5. Prediction Heads
        green  = self.head_green_raw(x_pool)
        clover = self.head_clover_raw(x_pool)
        dead   = self.head_dead_raw(x_pool)
        
        # Summation logic
        gdm    = green + clover
        total  = gdm + dead
        
        return total, gdm, green, clover, dead

    def set_grad_checkpointing(self, enable: bool = True):
        self.encoder.set_grad_checkpointing(enable)


In [70]:
def get_model_from_cfg(cfg):
    if cfg.model.arch == "timm_encoder":
        model = TimmEncoder(cfg)
    else:
        raise ValueError(f"Unknown model architecture: {cfg.model.arch}")
    return model

In [71]:

def get_loss(cfg):
    return MyLoss(cfg)

class MyLoss(nn.Module):
    def __init__(self, cfg):
        super(MyLoss, self).__init__()
        self.cfg = cfg

        # 基本は SmoothL1（元コードと同じ）
        self.criterion = nn.SmoothL1Loss(beta=5.0, reduction="mean")

        # 将来の拡張用（今は使わないが cfg で制御できる）
        self.use_weights = getattr(cfg.loss, "use_weights", False)
        self.weights = getattr(cfg.loss, "weights", None)

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (tuple of Tensor): (total, gdm, green, clover, dead)
                各 Tensor の形状 = (batch,)
            y_true (Tensor): (batch, 5)
                各列 = [Green, Dead, Clover, GDM, Total]
        Returns:
            dict:
                {
                    "loss": total_loss,
                    "loss_green": ...,
                    "loss_dead": ...,
                    "loss_clover": ...,
                    "loss_gdm": ...,
                    "loss_total": ...,
                }
        """
        return_dict = {}
        total, gdm, green, clover, dead = y_pred

        # 個別損失を計算
        l_green  = self.criterion(green.squeeze(),  y_true[:,0])
        l_dead   = self.criterion(dead.squeeze(),   y_true[:,1])
        l_clover = self.criterion(clover.squeeze(), y_true[:,2])
        l_gdm    = self.criterion(gdm.squeeze(),    y_true[:,3])
        l_total  = self.criterion(total.squeeze(),  y_true[:,4])

        # 辞書に格納
        return_dict["loss_green"]  = l_green
        return_dict["loss_dead"]   = l_dead
        return_dict["loss_clover"] = l_clover
        return_dict["loss_gdm"]    = l_gdm
        return_dict["loss_total"]  = l_total

        # 損失をまとめる
        losses = torch.stack([l_green, l_dead, l_clover, l_gdm, l_total])

        if self.use_weights and self.weights is not None:
            w = torch.as_tensor(self.weights, device=losses.device, dtype=losses.dtype)
            w = w / w.sum()
            total_loss = (losses * w).sum()
        else:
            total_loss = losses.mean()

        return_dict["loss"] = total_loss
        return return_dict


def main():
    pass


if __name__ == '__main__':
    main()


In [72]:
from pathlib import Path
import numpy as np
from pytorch_lightning.core.module import LightningModule
from timm.utils import ModelEmaV2
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
import torch

from timm.utils import ModelEmaV3

class MyModel(LightningModule):
    def __init__(self, cfg, mode="train"):
        super().__init__()
        self.preds = None
        self.gts = None

        self.cfg = cfg
        self.mode = mode
        
        self.model = get_model_from_cfg(cfg)

        # epoch 集計用
        self.val_outputs = []
        self.val_targets = []

        if mode != "test" and cfg.model.ema:
            self.model_ema = ModelEmaV3(
                self.model,
                decay=cfg.model.ema_decay,
                update_after_step=cfg.model.ema_update_after_step,
            )

        self.loss = get_loss(cfg)


    def forward(self, left_img, right_img):
        return self.model(left_img, right_img)

    def training_step(self, batch, batch_idx):
        left_img, right_img, targets = batch  # (B, 5)
        targets = targets.float()

        # outputs = (total, gdm, green, clover, dead)
        outputs = self(left_img, right_img)
        loss_dict = self.loss(outputs, targets)

        self.log_dict(
            loss_dict,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return loss_dict["loss"]

    def on_train_batch_end(self, out, batch, batch_idx):
        if self.cfg.model.ema:
            self.model_ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        left_img, right_img, targets = batch
        targets = targets.float()

        outputs = self(left_img, right_img)
        loss_dict = self.loss(outputs, targets)

        self.log("val_loss", loss_dict["loss"], prog_bar=True, sync_dist=True)
        
        self.val_outputs.append(tuple(o.detach() for o in outputs))
        self.val_targets.append(targets.detach())

        return loss_dict

    def on_validation_epoch_end(self):
        outputs = torch.cat(
            [torch.stack(t, dim=1) for t in self.val_outputs],
            dim=0
        ).cpu().numpy()
        outputs = outputs.squeeze(-1)  # (N, 5)

        targets = torch.cat(self.val_targets).cpu().numpy()

        weighted_r2, r2_scores = calc_metric(self.cfg, outputs, targets)

        # メトリクスをログ
        self.log("val_weighted_r2", weighted_r2, prog_bar=True)

        # 複数ターゲットなら個別ログも可
        for i, r2 in enumerate(r2_scores):
            self.log(f"val_r2_target_{i}", r2)

        # 次epochに向けてクリア
        self.val_outputs.clear()
        self.val_targets.clear()

    def configure_optimizers(self):
        optimizer = create_optimizer_v2(model_or_params=self.model, **self.cfg.opt)

        scheduler, _ = create_scheduler_v2(
            optimizer=optimizer,
            num_epochs=self.cfg.trainer.max_epochs,
            **self.cfg.scheduler
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_weighted_r2",
            },
        }

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(epoch=self.current_epoch)
    #     # scheduler.step_update(num_updates=self.global_step)


In [73]:
!ls /kaggle/input/model-dinov3-large

model.safetensors


In [74]:
from types import SimpleNamespace

cfg = SimpleNamespace()

# --- task ---
cfg.task = SimpleNamespace(
    img_size=224,
    img_depth=16,
    fixed_depth=16,
    slice_depth=3,
    pretrain=False,
    dirname="train_npzs"
)

# --- model ---
cfg.model = SimpleNamespace(
    freeze_end_epoch=0,
    arch="timm_encoder",
    in_channels=16,
    out_channels=1,
    depth=4,
    base_filters=64,
    dropout=0.1,
    use_batchnorm=True,
    activation="relu",
    swa=False,
    freeze_backbone=False,
    backbone="vit_large_patch16_dinov3_qkvb",
    backbone_path="/kaggle/input/model-dinov3-large/model.safetensors",
    ema=False,
    resume_path="loaded",
    drop_path_rate=0.0,
    img_size=128,
    img_depth=16,
    kernel_size=5,
    class_num=5
)

# --- data ---
cfg.data = SimpleNamespace(
    fold_num=5,
    fold_id=0,
    num_workers=8,
    batch_size=32,
    train_all=False,
    input_dir=None,
    output_dir=None,
    val_output_dir=None
)

# --- trainer ---
cfg.trainer = SimpleNamespace(
    max_epochs=30,
    devices="auto",
    strategy="auto",
    check_val_every_n_epoch=5,
    sync_batchnorm=False,
    accelerator="gpu",
    precision=32,
    gradient_clip_val=None,
    accumulate_grad_batches=1,
    deterministic=True
)

# --- test ---
cfg.test = SimpleNamespace(
    mode="test",
    output_dir="preds_results"
)

# --- opt ---
cfg.opt = SimpleNamespace(
    opt="AdamW",
    lr=1e-3,
    weight_decay=0.01
)

# --- scheduler ---
cfg.scheduler = SimpleNamespace(
    sched="cosine",
    min_lr=0.0,
    warmup_epochs=0
)

# --- loss ---
cfg.loss = SimpleNamespace(
    mixup=0.0,
    cutmix=0.0
)

# --- wandb ---
cfg.wandb = SimpleNamespace(
    project="csiro2025",
    name="exp_0",
    fast_dev_run=False
)


In [75]:
!ls /kaggle/input/csiro-simple-exp10

exp_10_epoch079_val_loss9.1485.ckpt


In [76]:
def get_val_transforms(cfg):
    return A.Compose(
        [
            A.Resize(height=cfg.task.img_size, width=cfg.task.img_size, p=1),
            # A.RandomScale(scale_limit=(1.0, 1.0), p=1),
            # A.PadIfNeeded(min_height=cfg.task.img_size, min_width=cfg.task.img_size, p=1.0,
            #              border_mode=cv2.BORDER_CONSTANT, value=0),
            # A.Crop(y_max=self.cfg.data.val_img_h, x_max=self.cfg.data.val_img_w, p=1.0),
            A.Normalize(p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
    )

In [77]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_path = "/kaggle/input/csiro-simple-exp10/exp_10_epoch079_val_loss9.1485.ckpt"

model = MyModel.load_from_checkpoint(ckpt_path, cfg=cfg, mode="test")
model.to(device)
model.eval()

test_dataset = InferenceDataset(agg_test_df, get_val_transforms(cfg))
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

def predict(model, dataloader, device):
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for left_img, right_img in tqdm(dataloader): # tqdmを追加して進捗を表示
            left_img = left_img.to(device)
            right_img = right_img.to(device)
            
            # outputs = (total, gdm, green, clover, dead)
            outputs = model(left_img, right_img)
            
            # submission.csv の作成ロジックに合わせて [green, clover, dead] を抽出
            # 各ヘッドの出力は (batch, 1) なので concat して (batch, 3) にする
            res = torch.cat([outputs[2], outputs[3], outputs[4]], dim=1)
            preds.append(res.cpu())
            
    return torch.cat(preds).numpy()

preds = predict(model, test_loader, device)



/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.6.0, which is newer than your current Lightning version: v2.5.5


✓ Gradient Checkpointing enabled (saves ~50% VRAM)


  0%|          | 0/1 [00:00<?, ?it/s]

In [78]:
agg_test_df[['Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g']] = preds
agg_test_df['GDM_g'] = agg_test_df.Dry_Green_g + agg_test_df.Dry_Clover_g
agg_test_df['Dry_Total_g'] = agg_test_df.GDM_g + agg_test_df.Dry_Dead_g

cols = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Dry_Total_g', 'GDM_g']
sub_df = agg_test_df.set_index('sample_id_prefix')[cols].stack().reset_index()
sub_df.columns = ['sample_id_prefix', 'target_name', 'target']
sub_df['sample_id'] = sub_df.sample_id_prefix + '__' + sub_df.target_name

sub_df[['sample_id', 'target']].to_csv('submission.csv', index=False)


In [79]:
sub_df.head()

Unnamed: 0,sample_id_prefix,target_name,target,sample_id
0,ID1001187975,Dry_Clover_g,2.835814,ID1001187975__Dry_Clover_g
1,ID1001187975,Dry_Dead_g,14.239312,ID1001187975__Dry_Dead_g
2,ID1001187975,Dry_Green_g,13.224589,ID1001187975__Dry_Green_g
3,ID1001187975,Dry_Total_g,30.299717,ID1001187975__Dry_Total_g
4,ID1001187975,GDM_g,16.060404,ID1001187975__GDM_g


In [80]:
!head submission.csv

sample_id,target
ID1001187975__Dry_Clover_g,2.8358138
ID1001187975__Dry_Dead_g,14.239312
ID1001187975__Dry_Green_g,13.224589
ID1001187975__Dry_Total_g,30.299717
ID1001187975__GDM_g,16.060404
