In [1]:
# ================================================================
# ConvNeXt-L Multi-Task Biomass Regression — FULL WORKING CELL
# ================================================================
import os, numpy as np, pandas as pd, warnings
from tqdm import tqdm
from PIL import Image
import torch, torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings("ignore")

# -------------------------------------------------------------
# ENVIRONMENT PATHS (LOCAL OR KAGGLE)
# -------------------------------------------------------------
# Choose the dataset root and the submission destination for the environment
# we are running in. Only these two values differ between Kaggle and a local
# checkout; every other path is derived from the dataset root below.
if not os.path.exists("/kaggle/input"):
    print("Running in LOCAL environment")
    BASE = "/home/rameyjm7/workspace/datasets/CSIRO"
    SAVE_PATH = f"{BASE}/submission.csv"
else:
    print("Running in KAGGLE environment")
    BASE = "/kaggle/input/csiro-biomass"
    # The Kaggle input directory is not used for output; write to /kaggle/working.
    SAVE_PATH = "/kaggle/working/submission.csv"

# CSV files and image folders sit at the same relative locations under BASE.
TRAIN_CSV, TEST_CSV = f"{BASE}/train.csv", f"{BASE}/test.csv"
TRAIN_IMG, TEST_IMG = f"{BASE}/train", f"{BASE}/test"

# Fall back to the CPU when no CUDA device is visible.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# -------------------------------------------------------------
# LOAD DATA
# -------------------------------------------------------------
# Read the training table and derive `sample_id_base`: the part of each
# `sample_id` that precedes the first "__" separator. The image files are
# looked up by this base id further down.
df = pd.read_csv(TRAIN_CSV)
df["sample_id_base"] = [sid.partition("__")[0] for sid in df["sample_id"]]

# Regression targets, in the column order the model's outputs follow.
targets = [
    "Dry_Green_g",
    "Dry_Dead_g",
    "Dry_Clover_g",
    "GDM_g",
    "Dry_Total_g",
]

# -------------------------------------------------------------
# MULTI-TASK DATASET
# -------------------------------------------------------------
class BiomassDataset(Dataset):
    """Image/target dataset backed by a DataFrame with one row per image.

    Each row must provide a ``sample_id_base`` value (used to build the image
    file name) and the columns listed in the module-level ``targets`` list.
    """

    def __init__(self, df, img_dir, transform=None):
        """Store the table, the image directory and the optional transform.

        Args:
            df: DataFrame holding ``sample_id_base`` and the target columns.
            img_dir: Directory containing ``<sample_id_base>.jpg`` files.
            transform: Optional callable applied to the loaded RGB PIL image.
        """
        self.transform = transform
        self.img_dir = img_dir
        self.df = df

    def __len__(self):
        """Return the number of rows (images) in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, targets)`` for the row at position ``idx``.

        The image is opened from ``img_dir`` and converted to RGB, then passed
        through ``transform`` when one was given. The targets are returned as
        a float32 tensor ordered like the module-level ``targets`` list.
        """
        # Positional lookup: the frame's index labels are not used.
        record = self.df.iloc[idx]

        image_path = os.path.join(self.img_dir, f"{record['sample_id_base']}.jpg")
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image) if self.transform else image

        labels = record[targets].values.astype(np.float32)
        return image, torch.tensor(labels)

# -------------------------------------------------------------
# IMAGE TRANSFORMS FOR CONVNEXT-L
# -------------------------------------------------------------
# Per-channel normalisation statistics.
# NOTE(review): these look like the usual ImageNet mean/std, matching the
# ImageNet-pretrained backbone used below - confirm.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing shared by the train, validation and test images:
# resize to 384x384, convert to a tensor, then normalise each channel.
transform = T.Compose(
    [
        T.Resize(size=(384, 384)),
        T.ToTensor(),
        T.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

# -------------------------------------------------------------
# TRAIN/VAL SPLIT
# -------------------------------------------------------------
# Reshape the long table (one row per sample/target pair) into a wide one
# (one row per image, one column per target) so that the split happens at
# image level and all targets of an image land in the same subset.
pivot = (
    df.pivot(index="sample_id_base", columns="target_name", values="target")
    .reset_index()
)
X, y = pivot, pivot[targets]

# 85/15 split with a fixed seed for reproducibility.
train_df, val_df = train_test_split(pivot, test_size=0.15, random_state=42)

train_ds, val_ds = (
    BiomassDataset(subset, TRAIN_IMG, transform) for subset in (train_df, val_df)
)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# -------------------------------------------------------------
# CONVNEXT-L MULTI-TASK MODEL (FIXED VERSION)
# -------------------------------------------------------------
class ConvNeXtL_MTL(nn.Module):
    """ConvNeXt-Large feature extractor with a shared MLP regression head.

    The torchvision classifier is replaced by ``nn.Identity`` and never
    called; ``forward`` runs ``backbone.features``, global-average-pools the
    feature map and feeds the flattened vector to ``head``.

    Args:
        num_outputs: Number of regression outputs produced by the head.
            Defaults to 5.
        dropout: Dropout probability used after each hidden layer of the
            head. Defaults to 0.3.
        pretrained: When True (default) the backbone is initialised from the
            ``IMAGENET1K_V1`` weights; when False it is randomly initialised,
            which avoids the weight download (e.g. in offline environments).
    """

    # Channel count of the ConvNeXt-Large feature map: (B, 1536, H/32, W/32).
    FEATURE_DIM = 1536

    def __init__(self, num_outputs: int = 5, dropout: float = 0.3,
                 pretrained: bool = True):
        super().__init__()
        weights = models.ConvNeXt_Large_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.convnext_large(weights=weights)

        # Remove classifier; only `backbone.features` is used in forward().
        self.backbone.classifier = nn.Identity()

        # Global average pooling collapses the spatial grid to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Layer order is kept fixed so state-dict keys (head.0, head.3,
        # head.6) stay compatible with previously saved checkpoints.
        self.head = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_outputs),
        )

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_outputs)`` tensor."""
        x = self.backbone.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.head(x)

model = ConvNeXtL_MTL()
model = model.to(device)

# -------------------------------------------------------------
# TRAINING SETUP
# -------------------------------------------------------------
# Gradient scaling is only enabled when running on CUDA; on CPU the scaler
# is constructed disabled.
amp_enabled = device == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

# Mean squared error over all outputs.
criterion = nn.MSELoss()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=1e-4,
)

# -------------------------------------------------------------
# TRAINING LOOP
# -------------------------------------------------------------
EPOCHS = 8

for epoch in range(1, EPOCHS + 1):
    # ------------------------ TRAINING ------------------------
    model.train()
    train_loss = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for imgs, ys in progress:
        imgs = imgs.to(device)
        ys = ys.to(device)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass and loss run under autocast (enabled on CUDA only).
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            preds = model(imgs)
            loss = criterion(preds, ys)

        # Scaled backward pass, optimizer step, then scaler update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch-mean loss by the batch size so that dividing by
        # the dataset size below gives a per-sample average.
        train_loss += imgs.size(0) * loss.item()

    train_loss = train_loss / len(train_loader.dataset)

    # ---------------------- VALIDATION ----------------------
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, ys in val_loader:
            imgs = imgs.to(device)
            ys = ys.to(device)
            preds = model(imgs)
            batch_loss = criterion(preds, ys)
            val_loss += batch_loss.item() * imgs.size(0)

    val_loss = val_loss / len(val_loader.dataset)

    print(f"Epoch {epoch}: Train={train_loss:.4f}  Val={val_loss:.4f}")

# -------------------------------------------------------------
# FULL TEST SET PREDICTION
# -------------------------------------------------------------
# Read the test table and derive the image-level id (the part of `sample_id`
# before the first "__"); each image is predicted once, whatever the number
# of rows that refer to it.
test_df = pd.read_csv(TEST_CSV)
test_df["sample_id_base"] = test_df["sample_id"].apply(lambda x: x.split("__")[0])
unique_ids = test_df["sample_id_base"].unique()

# Maps sample_id_base -> 1-D array of model outputs (ordered like `targets`).
test_preds = {}

# Put the model in eval mode explicitly. Without this, inference only works
# because the training loop's final validation pass happened to leave the
# model in eval mode; if that loop is skipped or changed, Dropout would stay
# active and the predictions would be stochastic.
model.eval()

with torch.no_grad():
    for sid in tqdm(unique_ids, desc="Predicting TEST"):
        img_path = os.path.join(TEST_IMG, f"{sid}.jpg")
        # Close the image file as soon as the RGB copy has been made.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # Add the batch dimension expected by the model.
        x = transform(img).unsqueeze(0).to(device)
        pred = model(x).cpu().numpy()[0]

        test_preds[sid] = pred

# -------------------------------------------------------------
# BUILD FINAL SUBMISSION
# -------------------------------------------------------------
# One output row per test row: look up the prediction vector of the row's
# image and pick the component that corresponds to the row's target name.
# `targets.index` raises ValueError for a target name the model was not
# trained on, and `test_preds[...]` raises KeyError for an unpredicted image.
rows = [
    {
        "sample_id": sample_id,
        "target": float(test_preds[sid][targets.index(tname)]),
    }
    for sample_id, sid, tname in zip(
        test_df["sample_id"], test_df["sample_id_base"], test_df["target_name"]
    )
]

# Pin the columns so that an empty test table still produces a CSV with the
# expected header instead of an empty, column-less file.
submission = pd.DataFrame(rows, columns=["sample_id", "target"])
submission.to_csv(SAVE_PATH, index=False)

print("\nSaved submission:", SAVE_PATH)
print(submission.head())
print("Shape:", submission.shape)


Running in LOCAL environment


Epoch 1/8: 100%|██████████| 19/19 [00:26<00:00,  1.37s/it]


Epoch 1: Train=1275.5123  Val=1256.1846


Epoch 2/8: 100%|██████████| 19/19 [00:06<00:00,  2.93it/s]


Epoch 2: Train=1192.8734  Val=995.8061


Epoch 3/8: 100%|██████████| 19/19 [00:06<00:00,  2.93it/s]


Epoch 3: Train=873.0203  Val=573.0002


Epoch 4/8: 100%|██████████| 19/19 [00:06<00:00,  2.92it/s]


Epoch 4: Train=525.4929  Val=440.3355


Epoch 5/8: 100%|██████████| 19/19 [00:06<00:00,  2.91it/s]


Epoch 5: Train=433.8575  Val=316.3148


Epoch 6/8: 100%|██████████| 19/19 [00:06<00:00,  2.90it/s]


Epoch 6: Train=359.2002  Val=256.5038


Epoch 7/8: 100%|██████████| 19/19 [00:06<00:00,  2.90it/s]


Epoch 7: Train=313.5235  Val=235.2583


Epoch 8/8: 100%|██████████| 19/19 [00:06<00:00,  2.90it/s]


Epoch 8: Train=299.6916  Val=256.0292


Predicting TEST: 100%|██████████| 1/1 [00:01<00:00,  1.63s/it]


Saved submission: /home/rameyjm7/workspace/datasets/CSIRO/submission.csv
                    sample_id     target
0  ID1001187975__Dry_Clover_g   5.726462
1    ID1001187975__Dry_Dead_g  10.555432
2   ID1001187975__Dry_Green_g  27.412813
3   ID1001187975__Dry_Total_g  42.583504
4         ID1001187975__GDM_g  31.700449
Shape: (5, 2)



