## Import Libs

In [1]:
import os
import warnings # 避免一些可以忽略的报错
warnings.filterwarnings('ignore')
import sys
import random
import copy
import math
from tqdm.auto import tqdm
from PIL import Image
import time
import gc
from collections import defaultdict

import pandas as pd
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedGroupKFold

import timm
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler # 学习率调度器
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR

from colorama import Fore, Back, Style
b_ = Fore.BLUE  # ANSI escape: switch terminal text to blue (prefixes the "CV improved" log lines)
sr_ = Style.RESET_ALL  # ANSI escape: reset all terminal colors/styles

In [None]:
# Quick look at the pre-split data table (notebook display only).
# NOTE: `train` is rebound later to the per-image (sample_id, fold) table.
train = pd.read_csv("CSIRO/csiro_data_split.csv")
train

Unnamed: 0,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,Dry_Clover_g,Dry_Dead_g,Dry_Green_g,Dry_Total_g,...,emb1152,bin_total,clover_frac,bin_clover,state_key,key_L1,key_L2,key_L3,final_stratify,fold
0,/kaggle/input/csiro-biomass/train/ID1011485656...,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,0.0000,31.9984,16.2751,48.2735,...,-0.027203,M,0.000000,Lo,Tas,Tas_M_Lo,Tas_M,Tas,Tas_M_Lo,4
1,/kaggle/input/csiro-biomass/train/ID1012260530...,2015/4/1,NSW,Lucerne,0.55,16.0000,0.0000,0.0000,7.6000,7.6000,...,-0.001485,L,0.000000,Lo,NSW,NSW_L_Lo,NSW_L,NSW,NSW_L_Lo,0
2,/kaggle/input/csiro-biomass/train/ID1025234388...,2015/9/1,WA,SubcloverDalkeith,0.38,1.0000,6.0500,0.0000,0.0000,6.0500,...,-0.030815,L,1.000000,Hi,WA,WA_L_Hi,WA_L,WA,WA_L_Hi,3
3,/kaggle/input/csiro-biomass/train/ID1028611175...,2015/5/18,Tas,Ryegrass,0.66,5.0000,0.0000,30.9703,24.2376,55.2079,...,-0.020800,H,0.000000,Lo,Tas,Tas_H_Lo,Tas_H,Tas,Tas_H_Lo,2
4,/kaggle/input/csiro-biomass/train/ID1035947949...,2015/9/11,Tas,Ryegrass,0.54,3.5000,0.4343,23.2239,10.5261,34.1844,...,-0.031969,M,0.039624,Lo,Tas,Tas_M_Lo,Tas_M,Tas,Tas_M_Lo,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
352,/kaggle/input/csiro-biomass/train/ID975115267.jpg,2015/7/8,WA,Clover,0.73,3.0000,40.0300,0.0000,0.8000,40.8300,...,-0.033440,M,0.980407,Hi,WA,WA_M_Hi,WA_M,WA,WA_M_Hi,2
353,/kaggle/input/csiro-biomass/train/ID978026131.jpg,2015/9/4,Tas,Clover,0.83,3.1667,24.6445,4.1948,12.0601,40.8994,...,-0.022025,M,0.671428,Hi,Tas,Tas_M_Hi,Tas_M,Tas,Tas_M_Hi,2
354,/kaggle/input/csiro-biomass/train/ID980538882.jpg,2015/2/24,NSW,Phalaris,0.69,29.0000,0.0000,1.1457,91.6543,92.8000,...,-0.010250,H,0.000000,Lo,NSW,NSW_H_Lo,NSW_H,NSW,NSW_H_Lo,1
355,/kaggle/input/csiro-biomass/train/ID980878870.jpg,2015/7/8,WA,Clover,0.74,2.0000,32.3575,0.0000,2.0325,34.3900,...,-0.028169,M,0.940898,Hi,WA,WA_M_Hi,WA_M,WA,WA_M_Hi,0


## CONFIG

In [None]:
class CONFIG:
    """Global configuration for the CSIRO biomass training notebook.

    Used as a plain attribute namespace (never instantiated). Some attributes
    (``ckpt_save_path``, ``formatted_time``) are assigned at run time.
    """

    is_debug = False
    seed = 308
    n_folds = 5
    # os.cpu_count() may return None when the CPU count cannot be determined;
    # fall back to 2 so n_workers is always an int.
    n_workers = (os.cpu_count() or 2) // 2

    # train_csv = "/kaggle/input/csiro-biomass/train.csv" # official csv
    train_csv = "CSIRO/CSIRO_my_5fold_train_csv.csv" # csv with a precomputed 5-fold split
    train_img_path = "CSIRO/train" # (1000, 2000)
    pretrain_ckpt_path = "CSIRO/ckpt"

    train_batch_size = 4
    valid_batch_size = 16
    now_cv = -np.inf

    epochs = 10
    start_lr_backbone = 1e-5
    start_lr_head = 1e-3
    min_lr_backbone = 1e-8
    min_lr_head = 1e-6
    scheduler = 'CosineAnnealingWithWarmupLR'
    # Number of consecutive batches whose gradients are accumulated before each
    # optimizer step; an int because the training loop tests
    # `(step + 1) % n_accumulate == 0`.
    n_accumulate = 1
    ckpt_save_path = None
    # Image count used only to estimate the schedule length (presumably the
    # number of training images — verify against the csv).
    n_train_images = 357
    # Approximate number of optimizer steps (the schedulers step once per
    # optimizer step): approx. batches per epoch on the (n_folds - 1)/n_folds
    # training split, times epochs, divided by the accumulation factor.
    T_max = math.ceil(
        (n_train_images // (n_folds * train_batch_size) + 1) * (n_folds - 1) * epochs / n_accumulate
    )

    model_name = "convnextv2_tiny.fcmae_ft_in22k_in1k"
    if "dino" in model_name:
        # Two 518x518 halves (518 = 37 * 14, a multiple of the DINOv2 patch size).
        img_size = [518, 1036]
    else:
        img_size = [512, 1024]
    # img_size = [384, 768]
    """
    tf_efficientnet_b0.ns_jft_in1k
    edgenext_base.in21k_ft_in1k
    vit_base_patch14_dinov2.lvd142m
    convnextv2_tiny.fcmae_ft_in22k_in1k

    convnext_large_mlp.laion2b_ft_augreg_inat21
    resnet50.a1_in1k_ft_inat21
    efficientnet_b5.in1k_ft_inat21
    seresnext50_32x4d.racm_in1k_ft_inat21
    vit_base_patch14_dinov2.lvd142m
    vit_small_patch14_dinov2.lvd142m
    convnextv2_base.fcmae_ft_in22k_in1k
    """
    is_pretrained = True
    head_out = 5
    DataParallel = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Set Random Seed

In [None]:
def set_seed(seed=308):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.

    Note: setting ``PYTHONHASHSEED`` at run time only affects child
    processes started afterwards, not the current interpreter.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly (DataParallel may use several).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark=True cuDNN picks algorithms by timing them, which can make
    # results differ between runs even when deterministic=True.
    torch.backends.cudnn.benchmark = False
    
# Seed all RNGs once, at start-up, with the configured seed.
set_seed(CONFIG.seed)

## Data Processing

In [None]:
# ## KFold

# # 读取并生成了 'image_group_id'
# df = pd.read_csv(CONFIG.train_csv)
# df['image_group_id'] = df['sample_id'].apply(lambda x: x.split('__')[0])

# # --- 步骤 1: 统计每个物种有多少个唯一的 Group (图片数) ---
# species_counts = df.groupby('Species')['image_group_id'].nunique()
# print("各物种图片数量统计：")
# print(species_counts)

# # --- 步骤 2: 定义稀有阈值 ---
# # 如果你想做 5 折，理论上至少需要 5 张图才能保证每折一张。
# # 为了保险，我们可以把少于 10 张图的都归为 'Other'
# threshold = 10 
# rare_species = species_counts[species_counts < threshold].index.tolist()

# print(f"\n将被归类为 'Other' 的稀有物种: {rare_species}")

# # --- 步骤 3: 创建用于分层的临时列 ---
# # 如果是稀有物种，标记为 'Other'，否则保持原名
# df['stratify_col'] = df['Species'].apply(lambda x: 'Other' if x in rare_species else x)

# # --- 步骤 4: 重新进行 StratifiedGroupKFold ---
# sgkf = StratifiedGroupKFold(n_splits=CONFIG.n_folds)
# df['fold'] = -1

# for fold_id, (train_idx, val_idx) in enumerate(sgkf.split(X=df, y=df['stratify_col'], groups=df['image_group_id'])):
#     df.loc[val_idx, 'fold'] = fold_id

# # --- 步骤 5: 再次检查分布 ---
# print("\n优化后的分布检查 (注意 Other 类的分布):")
# print(df.groupby(['fold', 'stratify_col']).size().unstack())

In [None]:
# Long-format training table: one row per "<image_id>__<target_name>" sample_id,
# with its `target` value and the precomputed `fold` column.
train_all = pd.read_csv(CONFIG.train_csv)
train_all

In [None]:
# Collapse the long-format table to one row per image: strip the
# "__<target>" suffix from sample_id and keep the fold of the first row seen
# for each image id (vectorised instead of a row-by-row iloc loop).
image_ids = train_all["sample_id"].str.split("_").str[0]
train = (
    pd.DataFrame({"sample_id": image_ids, "fold": train_all["fold"]})
    .drop_duplicates(subset="sample_id", keep="first")
    .reset_index(drop=True)
)
# Same mapping as a plain dict: image id -> fold (Python scalars).
id_and_fold = dict(zip(train["sample_id"].tolist(), train["fold"].tolist()))
train

## Dataset and DataLoader

In [None]:
# 1. Fetch the pretrained-config object timm registers for this model name.
cfg = timm.get_pretrained_cfg(CONFIG.model_name)

# 2. Convert it to a dict (.to_dict()) before passing it on, so that
#    resolve_data_config can call .get() on it.
cfg_dict = cfg.to_dict()
data_config = timm.data.resolve_data_config(pretrained_cfg=cfg_dict)

# 3. Extract the normalization statistics the backbone was pretrained with
#    (used by the image transforms).
_mean = data_config['mean']
_std = data_config['std']

print(f"模型: {CONFIG.model_name}")
print(f"自动获取 Mean: {_mean}")
print(f"自动获取 Std:  {_std}")
# ------------------------------------------------------


def transform(img):
    """Default, non-augmenting pipeline.

    Resizes to ``CONFIG.img_size`` (height, width), normalizes with the
    backbone's pretraining mean/std and converts the image to a tensor.
    """
    height = CONFIG.img_size[0]
    width = CONFIG.img_size[1]
    pipeline = A.Compose(
        [
            A.Resize(height, width),
            A.Normalize(mean=_mean, std=_std),
            ToTensorV2(),
        ]
    )
    result = pipeline(image=img)
    return result["image"]

def transform_train(img):
    """Training-time pipeline for one image.

    Uses the same target size as ``transform_valid`` (``CONFIG.img_size`` =
    (height, width)) so training and validation see the same resolution and
    aspect ratio. This also gives the DINO branch of ``CSIROModel`` (which
    splits the input into left/right width halves) the width it expects.

    With probability 0.5 the image is only resized and normalized. Otherwise
    random flips and, with p=0.5, one photometric augmentation are added.

    Args:
        img: numpy image as produced by ``np.array(PIL image)``; assumed
            HxWx3 uint8 RGB — TODO confirm for other callers.

    Returns:
        Normalized image tensor from ``ToTensorV2``.
    """
    height = CONFIG.img_size[0]
    width = CONFIG.img_size[1]

    if random.random() < 0.5:
        # Path A (50%): plain resize + normalize, no augmentation.
        composition = A.Compose([
            A.Resize(height, width),
            A.Normalize(mean=_mean, std=_std),
            ToTensorV2(),
        ])
    else:
        # Geometric augmentation (top-down pasture photos have no natural
        # orientation).
        geometric = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
        ]
        # RandomRotate90 swaps H and W. For a non-square target this produces
        # tensors of different shapes inside one batch, and default collation
        # fails. So it is only used for square targets. H+V flips together
        # already give a 180-degree rotation.
        if height == width:
            geometric.append(A.RandomRotate90(p=0.5))

        composition = A.Compose([
            A.Resize(height, width),
            *geometric,
            # Pixel/colour augmentation (lighting, colour shifts, sensor noise);
            # exactly one of these is applied, 50% of the time.
            A.OneOf([
                A.RandomBrightnessContrast(p=1.0),
                A.HueSaturationValue(p=1.0),
                A.GaussNoise(p=1.0),
            ], p=0.5),
            A.Normalize(mean=_mean, std=_std),
            ToTensorV2(),
        ])

    return composition(image=img)["image"]

def transform_valid(img):
    """Validation pipeline: resize to CONFIG.img_size (H, W), normalize, convert to tensor.

    No augmentation is applied; e.g. the default (512, 1024) keeps the 1:2
    aspect ratio of the source photos.
    """
    steps = [
        A.Resize(CONFIG.img_size[0], CONFIG.img_size[1]),
        A.Normalize(mean=_mean, std=_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)(image=img)["image"]

In [None]:
class CSIRODataset(Dataset):
    """Image -> 5 biomass targets dataset.

    Args:
        df: DataFrame with one row per image; its ``sample_id`` column holds
            the image id (the file is ``CONFIG.train_img_path/<id>.jpg``).
        original_train: long-format DataFrame with ``sample_id`` values of the
            form ``<image_id>__<target_name>`` and a ``target`` column.
        transform: callable applied to the numpy image, or None.

    ``__getitem__`` returns ``(image, label)``. ``label`` is a float32 tensor
    of the 5 targets in ``TARGET_SUFFIXES`` order.
    """

    # Order of the 5 regression targets (matches the model head's outputs).
    TARGET_SUFFIXES = ("__Dry_Clover_g", "__Dry_Dead_g", "__Dry_Green_g", "__Dry_Total_g", "__GDM_g")

    def __init__(self, df, original_train=train_all, transform=transform):
        super().__init__()
        self.df = df
        self.original_train = original_train
        self.transform = transform
        # sample_id -> target lookup built once. Without it, every item would
        # scan the whole long-format table once per target (5 full scans).
        self._targets = dict(zip(original_train["sample_id"], original_train["target"]))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_id = self.df.iloc[idx]["sample_id"]
        img_path = os.path.join(CONFIG.train_img_path, img_id + ".jpg")

        # The context manager closes the file handle, which PIL would otherwise
        # keep open. convert("RGB") guarantees 3 channels for Normalize.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(
            [self._targets[f"{img_id}{suffix}"] for suffix in self.TARGET_SUFFIXES],
            dtype=torch.float32,
        )
        return img, label

In [None]:
def prepare_loaders(df, fold=0):
    """Build the (train_loader, valid_loader) pair for one CV fold.

    Rows whose ``fold`` equals ``fold`` form the validation split; every other
    row is used for training (shuffled, augmented).
    """
    in_valid = df["fold"] == fold
    shared = dict(num_workers=CONFIG.n_workers, pin_memory=True)

    train_loader = DataLoader(
        CSIRODataset(df=df[~in_valid], transform=transform_train),
        batch_size=CONFIG.train_batch_size,
        shuffle=True,
        **shared,
    )
    valid_loader = DataLoader(
        CSIRODataset(df=df[in_valid], transform=transform_valid),
        batch_size=CONFIG.valid_batch_size,
        shuffle=False,
        **shared,
    )
    return train_loader, valid_loader

In [None]:
# Sanity check: build the fold-0 loaders and pull one batch from each to
# confirm Dataset/DataLoader work and to inspect tensor shapes.
train_loader, valid_loader = prepare_loaders(train, 0)
x_train, y_train = next(iter(train_loader))
x_valid, y_valid = next(iter(valid_loader))
print(f"X_train shape : {x_train.shape}") # (batch_size, channels, H, W)
print(f"y_train shape : {y_train.shape}")
print(f"x_valid shape : {x_valid.shape}")
print(f"y_valid shape : {y_valid.shape}")

# Drop the loaders/batches and run garbage collection.
del train_loader, valid_loader, x_train, y_train, x_valid, y_valid
gc.collect()

## Evaluation

In [None]:
def Calculate_Weighted_R2(y_true, y_pred):
    """
    计算 Kaggle CSIRO Image2Biomass 比赛的加权 R2 分数。
    
    参数:
    y_true: 真实值，形状为 [n_samples, 5]
    y_pred: 预测值，形状为 [n_samples, 5]
    
    列顺序假设:
    0: Dry_Clover_g (w=0.1)
    1: Dry_Dead_g   (w=0.1)
    2: Dry_Green_g  (w=0.1)
    3: Dry_Total_g  (w=0.5)
    4: GDM_g        (w=0.2)
    """
    
    # 1. 定义权重向量
    weights = np.array([0.1, 0.1, 0.1, 0.5, 0.2])
    
    # 2. 确保输入是 numpy 数组
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    
    # 3. 计算全局加权均值 (Global Weighted Mean)
    # 这里的 sum(weights) = 1.0，所以分母实际上就是 样本数 * 1.0
    # 我们利用广播机制将权重应用到每一行
    weighted_sum = np.sum(y_true * weights) 
    total_weight = np.sum(weights) * y_true.shape[0] # weights总和 * 样本数
    y_bar_w = weighted_sum / total_weight
    
    # 4. 计算残差平方和 (SS_res)
    # 公式: sum( w_j * (y_j - y_hat_j)^2 )
    ss_res = np.sum(weights * (y_true - y_pred)**2)
    
    # 5. 计算总离差平方和 (SS_tot)
    # 公式: sum( w_j * (y_j - y_bar_w)^2 )
    # 注意这里减去的是全局加权均值 y_bar_w
    ss_tot = np.sum(weights * (y_true - y_bar_w)**2)
    
    # 6. 计算 R2
    # 避免分母为0的极个别情况
    if ss_tot == 0:
        return 0.0
        
    r2 = 1 - (ss_res / ss_tot)
    
    return r2

# --- Smoke tests for the metric ---
# Toy data
dummy_true = np.array([
    [5, 16, 36, 54, 42],
    [3, 8, 10, 18, 13]
])
# Predictions that are all off by +0.5, i.e. very close to the truth
dummy_pred = dummy_true + 0.5 

score = Calculate_Weighted_R2(dummy_true, dummy_pred)
print(f"验证集得分: {score:.5f}")

# Usage example
# Simulated validation set
n_valid_samples = 16
# Random ground-truth and prediction values for the simulated validation set
y_valid_true = np.random.rand(n_valid_samples, 5)
y_valid_pred = np.random.rand(n_valid_samples, 5)
# Compute the score
score = Calculate_Weighted_R2(y_valid_true, y_valid_pred).item()
score

In [None]:
# def Calculate_Weighted_R2(y_true, y_pred): # K2 感觉不太行
#     """
#     计算Kaggle比赛的全局加权R²分数
    
#     参数:
#     -----------
#     y_true : array-like of shape (n_samples, n_features)
#         真实值，特征顺序必须是：["__Dry_Clover_g", "__Dry_Dead_g", "__Dry_Green_g", "__Dry_Total_g", "__GDM_g"]
#     y_pred : array-like of shape (n_samples, n_features)
#         预测值，形状与y_true相同
    
#     返回:
#     -----------
#     float
#         加权R²分数
#     """
#     # 转换为numpy数组（处理torch.Tensor和numpy数组两种情况）
#     if isinstance(y_true, torch.Tensor):
#         y_true = y_true.cpu().numpy()
#     if isinstance(y_pred, torch.Tensor):
#         y_pred = y_pred.cpu().numpy()
    
#     y_true = np.asarray(y_true)
#     y_pred = np.asarray(y_pred)
    
#     # 验证输入形状
#     assert y_true.shape == y_pred.shape, "y_true和y_pred的形状必须相同"
#     assert y_true.ndim == 2, "输入必须是2维数组，形状为[样本数量, 特征数量]"
#     assert y_true.shape[1] == 5, f"特征数量必须是5，当前是{y_true.shape[1]}"
    
#     # 特征权重映射（根据官方评估方案）
#     # 注意：权重是根据目标类型（特征）而不是样本分配的
#     feature_weights = np.array([
#         0.1,  # __Dry_Clover_g
#         0.1,  # __Dry_Dead_g  
#         0.1,  # __Dry_Green_g
#         0.5,  # __Dry_Total_g
#         0.2   # __GDM_g
#     ])
    
#     # 展平数组（官方要求：所有(image, target)对一起计算）
#     y_true_flat = y_true.flatten()  # 形状: (n_samples * n_features,)
#     y_pred_flat = y_pred.flatten()  # 形状: (n_samples * n_features,)
    
#     # 为每个展平后的值创建对应的权重
#     # 每个样本的5个特征循环使用feature_weights
#     weights_flat = np.tile(feature_weights, y_true.shape[0])
    
#     # 计算全局加权均值 y_w
#     y_w = np.sum(weights_flat * y_true_flat) / np.sum(weights_flat)
    
#     # 计算残差平方和 SSres
#     ss_res = np.sum(weights_flat * (y_true_flat - y_pred_flat) ** 2)
    
#     # 计算总平方和 SStot
#     ss_tot = np.sum(weights_flat * (y_true_flat - y_w) ** 2)
    
#     # 防止除零错误
#     if ss_tot == 0:
#         return 1.0 if ss_res == 0 else 0.0
    
#     # 计算加权R²
#     r2 = 1 - (ss_res / ss_tot)
    
#     return r2

# # 使用示例
# # 模拟验证集数据
# n_valid_samples = 16
# # 随机生成验证集真实值和预测值
# y_valid_true = np.random.rand(n_valid_samples, 5)
# y_valid_pred = np.random.rand(n_valid_samples, 5)
# # 计算分数
# score = Calculate_Weighted_R2(y_valid_true, y_valid_pred).item()
# score

## Model

In [None]:
class CSIROModel(nn.Module):
    """timm backbone + 2-layer MLP head regressing ``CONFIG.head_out`` targets.

    For CNN backbones (EfficientNet / EdgeNeXt / ConvNeXt) the final classifier
    layer is replaced by ``nn.Identity`` so the backbone returns features. For
    DINO backbones the input is split into left/right halves along the width.
    Each half is encoded separately and the two feature vectors are
    concatenated.

    Raises:
        ValueError: if ``CONFIG.model_name`` is not one of the supported families.
    """

    def __init__(self):
        super().__init__()
        model_name = CONFIG.model_name
        self.backbone = timm.create_model(model_name=model_name, pretrained=False)
        if CONFIG.is_pretrained:
            # Locally stored pretrained weights. map_location="cpu" lets a
            # checkpoint saved from a GPU also load on a CPU-only machine; the
            # caller moves the model to its device afterwards.
            ckpt_path = f"{CONFIG.pretrain_ckpt_path}/{model_name}.pth"
            self.backbone.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

        if "efficientnet" in model_name:
            in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif "edgenext" in model_name or "convnext" in model_name:
            in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
        elif "dino" in model_name:
            # Two halves are encoded and concatenated -> twice the embedding
            # width (1536 for ViT-B, 768 for ViT-S).
            in_features = 2 * self.backbone.num_features
        else:
            # `raise "str"` is itself a TypeError; raise a real exception.
            raise ValueError(f"Unsupported model: {model_name}")

        self.head = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.LeakyReLU(),
            nn.Linear(in_features // 2, CONFIG.head_out),
        )

    def forward(self, x):
        if "dino" in CONFIG.model_name:
            # Split along the last (width) axis; assumes x is (B, C, H, W).
            # For the default [518, 1036] input this is CONFIG.img_size[0].
            mid = x.shape[-1] // 2
            left = self.backbone(x[..., :mid])
            right = self.backbone(x[..., mid:])
            features = torch.cat([left, right], dim=1)
        else:
            features = self.backbone(x)
        return self.head(features)

In [None]:
# Build one model instance to inspect its architecture (notebook display).
model = CSIROModel()
model

## Train and Valid Function

In [None]:
# criterion = nn.MSELoss()

# 加权MSE损失
class WeightedMSELoss(nn.Module):
    """Per-target weighted squared error: summed over targets, averaged over the batch.

    loss = mean_b( sum_j w_j * (y_pred[b, j] - y_true[b, j])^2 )

    Args:
        feature_weights: per-target weights (list, array or tensor). The
            default is (0.1, 0.1, 0.1, 0.5, 0.2), the same weights as the
            competition metric.
    """

    def __init__(self, feature_weights=None):
        super().__init__()
        if feature_weights is None:
            feature_weights = [0.1, 0.1, 0.1, 0.5, 0.2]
        # as_tensor accepts lists/arrays/tensors; calling torch.tensor() on an
        # existing tensor emits a copy-construct warning. Force float32 so
        # integer weight lists still give real-valued weights, and clone so
        # the buffer does not alias a caller's tensor.
        weights = torch.as_tensor(feature_weights, dtype=torch.float32).detach().clone()
        self.register_buffer('feature_weights', weights)

    def forward(self, y_pred, y_true):
        # Move the weights to the prediction's device/dtype. The loss module
        # itself is not necessarily moved to the GPU.
        weights = self.feature_weights.to(device=y_pred.device, dtype=y_pred.dtype)

        # (batch, n_targets) weighted squared errors.
        loss = weights * (y_pred - y_true) ** 2

        # Sum over targets (dim=1) -> per-sample weighted error; then average
        # over the batch (dim=0).
        return loss.sum(dim=1).mean()

# Instantiate the loss (read as a module-level global by the train/valid loops).
criterion = WeightedMSELoss()

In [None]:
def train_one_epoch(model, optimizer, train_loader, epoch):
    """Run one training epoch.

    Gradients are accumulated over ``CONFIG.n_accumulate`` consecutive
    batches before each ``optimizer.step()``. For a ``merge_optim`` that call
    also advances its LR schedulers.

    Args:
        model: the network (optionally wrapped in DataParallel).
        optimizer: a ``merge_optim`` (its ``optimizer1``/``optimizer2`` LRs are
            shown in the progress bar).
        train_loader: yields ``(images, labels)`` batches.
        epoch: epoch number, for display only.

    Returns:
        (mean per-sample training loss, weighted R^2 of the epoch's training
        predictions). Both are NaN if the loader yields no batches.
    """
    model.train()

    y_preds = []
    y_trues = []
    dataset_size = 0
    running_loss = 0.0
    epoch_loss = float("nan")
    train_cv = float("nan")
    n_steps = 0

    # Zero once up front. Zeroing at the start of *every* step would throw
    # away the gradients being accumulated when n_accumulate > 1.
    optimizer.zero_grad()
    bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (images, labels) in bar:
        n_steps = step + 1
        batch_size = images.size(0)
        if CONFIG.DataParallel:
            images = images.cuda().float()
            labels = labels.cuda().float()
        else:
            images = images.to(CONFIG.device, dtype=torch.float)
            labels = labels.to(CONFIG.device, dtype=torch.float)

        outputs = model(images)
        loss = criterion(outputs, labels)
        # Scale so the accumulated gradient is the mean over the cycle.
        (loss / CONFIG.n_accumulate).backward()

        if n_steps % CONFIG.n_accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()

        y_preds.append(outputs.detach().cpu().numpy())
        y_trues.append(labels.detach().cpu().numpy())
        # Running R^2 over everything seen so far this epoch (progress display).
        train_cv = Calculate_Weighted_R2(np.concatenate(y_trues), np.concatenate(y_preds))

        # Track the unscaled loss so the reported value does not depend on
        # the accumulation factor.
        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        bar.set_postfix(Epoch=epoch,
                        Train_Loss=epoch_loss,
                        Train_CV_R2=train_cv,
                        LR_backbone=optimizer.optimizer1.param_groups[0]['lr'],
                        LR_head=optimizer.optimizer2.param_groups[0]['lr'])

    # Flush a trailing, incomplete accumulation cycle.
    if n_steps % CONFIG.n_accumulate != 0:
        optimizer.step()
        optimizer.zero_grad()

    return epoch_loss, train_cv

In [None]:
# @torch.inference_mode()
def valid_one_epoch(model, optimizer, valid_loader, epoch):
    """Evaluate ``model`` on ``valid_loader`` without gradients.

    Args:
        model: the network (optionally wrapped in DataParallel).
        optimizer: only read to show the current learning rates in the
            progress bar.
        valid_loader: yields ``(images, labels)`` batches.
        epoch: epoch number, for display only.

    Returns:
        (mean per-sample validation loss, weighted R^2 over the whole loader).
        Both are NaN if the loader yields no batches.
    """
    model.eval()

    y_preds = []
    y_trues = []
    dataset_size = 0
    running_loss = 0.0
    epoch_loss = float("nan")

    bar = tqdm(enumerate(valid_loader), total=len(valid_loader))
    with torch.no_grad():
        for _, (images, labels) in bar:
            batch_size = images.size(0)
            if CONFIG.DataParallel:
                images = images.cuda().float()
                labels = labels.cuda().float()
            else:
                images = images.to(CONFIG.device, dtype=torch.float)
                labels = labels.to(CONFIG.device, dtype=torch.float)

            outputs = model(images)
            # No gradient accumulation happens here, so the loss is reported
            # as is (not divided by n_accumulate).
            loss = criterion(outputs, labels)

            y_preds.append(outputs.detach().cpu().numpy())
            y_trues.append(labels.detach().cpu().numpy())
            # Running R^2 for the progress bar.
            valid_cv = Calculate_Weighted_R2(np.concatenate(y_trues), np.concatenate(y_preds))

            running_loss += loss.item() * batch_size
            dataset_size += batch_size
            epoch_loss = running_loss / dataset_size

            bar.set_postfix(Epoch=epoch,
                            Valid_Loss=epoch_loss,
                            Valid_CV_R2=valid_cv,
                            LR_backbone=optimizer.optimizer1.param_groups[0]['lr'],
                            LR_head=optimizer.optimizer2.param_groups[0]['lr'])

    if not y_preds:
        return epoch_loss, float("nan")

    cv = Calculate_Weighted_R2(np.concatenate(y_trues), np.concatenate(y_preds))
    return epoch_loss, cv

In [None]:
def get_time_fold():
    """Stamp CONFIG with the current local time and create this run's output directory.

    Sets ``CONFIG.formatted_time`` and ``CONFIG.ckpt_save_path`` (which
    embeds the time stamp and model name) and creates that directory if it
    does not exist yet.
    """
    now = time.time()
    print("Current timestamp:", now)

    CONFIG.formatted_time = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(now))
    print("now time:", CONFIG.formatted_time)

    out_dir = f"CSIRO/output/{CONFIG.formatted_time}_{CONFIG.model_name}_output"
    CONFIG.ckpt_save_path = out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

In [None]:
def _cpu_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``copy.copy`` keeps the mapping type and its ``_metadata`` attribute,
    which ``load_state_dict`` reads. Only tensor values are replaced.
    """
    state = model.state_dict()
    snapshot = copy.copy(state)
    for key, value in state.items():
        if torch.is_tensor(value):
            snapshot[key] = value.detach().to("cpu", copy=True)
    return snapshot


def run_training(fold, model, optimizer, train_loader, valid_loader, num_epochs=CONFIG.epochs, now_cv=CONFIG.now_cv):
    """Train for ``num_epochs`` epochs, saving a checkpoint whenever validation R^2 does not get worse.

    Args:
        fold: fold label used in checkpoint file names.
        model, optimizer, train_loader, valid_loader: passed to the
            per-epoch train/valid functions.
        num_epochs: number of epochs to run.
        now_cv: initial "best" CV; an epoch must reach at least this value to
            be saved.

    Returns:
        (model with the best weights loaded, history dict of per-epoch
        metrics, path of the last saved checkpoint, or None if none was saved).
    """
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {} x {}\n".format(torch.cuda.get_device_name(), torch.cuda.device_count()))

    start = time.time()
    # Best weights live on the CPU, so the snapshot does not hold a second
    # full copy of the model in GPU memory. load_state_dict copies them back
    # onto the parameters' device.
    best_model_wts = _cpu_state_dict(model)
    best_epoch_cv = now_cv
    best_model_path = None
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        train_epoch_loss, train_epoch_cv = train_one_epoch(model, optimizer, train_loader, epoch)
        valid_epoch_loss, valid_epoch_cv = valid_one_epoch(model, optimizer, valid_loader, epoch)
        print(f"epoch: {epoch}, LOSS = {valid_epoch_loss}, CV = {valid_epoch_cv}")

        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(valid_epoch_loss)
        history['Train CV'].append(train_epoch_cv)
        history['Valid CV'].append(valid_epoch_cv)
        history['lr_backbone'].append(optimizer.optimizer1.param_groups[0]['lr'])
        history['lr_head'].append(optimizer.optimizer2.param_groups[0]['lr'])

        # `>=`: ties also count, so the latest of equally good epochs is kept.
        if valid_epoch_cv >= best_epoch_cv:
            print(f"{b_}epoch: {epoch}, Validation CV Improved ({best_epoch_cv} ---> {valid_epoch_cv})")
            best_epoch_cv = valid_epoch_cv
            best_model_wts = _cpu_state_dict(model)
            PATH = "{}/{}_CV_{:.4f}_Loss{:.4f}_epoch{:.0f}.pth".format(CONFIG.ckpt_save_path, fold, best_epoch_cv, valid_epoch_loss, epoch)
            best_model_path = PATH
            torch.save(model.state_dict(), PATH)
            print(f"Model Saved{sr_}")

        print()

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best CV: {:.4f}".format(best_epoch_cv))

    # Restore the best weights before returning.
    model.load_state_dict(best_model_wts)

    return model, history, best_model_path

## Optimizer

In [None]:
class CosineAnnealingWithWarmupLR(_LRScheduler):
    """Linear warm-up followed by a half-cosine decay to ``eta_min``.

    With ``t = last_epoch`` (one tick per ``scheduler.step()`` call):
      * ``t < warmup_epochs``: ``lr = base_lr * (t + 1) / warmup_epochs``
      * otherwise: cosine from ``base_lr`` down to ``eta_min`` over
        ``T_max - warmup_epochs`` ticks. After that the LR is held at
        ``eta_min`` instead of rising again along the cosine.

    Args:
        optimizer: wrapped optimizer.
        T_max: total schedule length in ticks (warm-up included).
        eta_min: final learning rate.
        warmup_epochs: number of warm-up ticks.
        last_epoch: index of the last tick (-1 to start fresh).
    """

    def __init__(self, optimizer, T_max, eta_min=0, warmup_epochs=10, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        self.warmup_epochs = warmup_epochs
        # At least 1 tick so the cosine phase never divides by zero
        # (e.g. T_max == warmup_epochs).
        self.cosine_epochs = max(1, T_max - warmup_epochs)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        t = self.last_epoch
        if t < self.warmup_epochs:
            # Linear warmup
            return [(base_lr * (t + 1) / self.warmup_epochs) for base_lr in self.base_lrs]

        # Cosine annealing, clamped at the end of the schedule so the LR stays
        # at eta_min once T_max is exceeded.
        cosine_epoch = min(t - self.warmup_epochs, self.cosine_epochs)
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * cosine_epoch / self.cosine_epochs)) / 2 for base_lr in self.base_lrs]

In [None]:
# lr scheduler
def fetch_scheduler(optimizer, T_max, min_lr):
    """Create the LR scheduler selected by ``CONFIG.scheduler`` for ``optimizer``.

    Args:
        optimizer: optimizer to schedule.
        T_max: total number of scheduler ticks.
        min_lr: final (minimum) learning rate.

    Returns:
        The scheduler, or None when ``CONFIG.scheduler`` is None.

    Raises:
        ValueError: for an unrecognised ``CONFIG.scheduler`` name.
    """
    name = CONFIG.scheduler
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=min_lr)
    if name == "CosineAnnealingWithWarmupLR":
        # Warm-up length T_max // epochs, i.e. roughly one epoch's worth of ticks.
        return CosineAnnealingWithWarmupLR(optimizer, T_max=T_max, eta_min=min_lr, warmup_epochs=T_max // CONFIG.epochs)
    raise ValueError(f"Unknown scheduler: {name!r}")

In [None]:
class merge_optim:
    """Drive two optimizers, plus their optional LR schedulers, as one.

    ``step()`` steps ``optimizer1`` then ``optimizer2``, then each scheduler
    that is not None, in the same order.
    """

    def __init__(self, optimizer1, optimizer2, lr_scheduler1=None, lr_scheduler2=None):
        self.optimizer1, self.optimizer2 = optimizer1, optimizer2
        self.lr_scheduler1, self.lr_scheduler2 = lr_scheduler1, lr_scheduler2

    def zero_grad(self):
        for opt in (self.optimizer1, self.optimizer2):
            opt.zero_grad()

    def step(self):
        for opt in (self.optimizer1, self.optimizer2):
            opt.step()
        for sched in (self.lr_scheduler1, self.lr_scheduler2):
            if sched is not None:
                sched.step()

In [None]:
def get_optimizer(model):
    """Build separate AdamW optimizers and LR schedulers for the backbone and the head.

    Each part gets its own start/min learning rate from CONFIG. The pair is
    wrapped in a single ``merge_optim``.
    """
    # DataParallel keeps the real model under `.module`.
    core = model.module if CONFIG.DataParallel else model

    backbone_opt = optim.AdamW(core.backbone.parameters(), lr=CONFIG.start_lr_backbone)
    head_opt = optim.AdamW(core.head.parameters(), lr=CONFIG.start_lr_head)

    return merge_optim(
        backbone_opt,
        head_opt,
        fetch_scheduler(backbone_opt, T_max=CONFIG.T_max, min_lr=CONFIG.min_lr_backbone),
        fetch_scheduler(head_opt, T_max=CONFIG.T_max, min_lr=CONFIG.min_lr_head),
    )

## Start Training

In [None]:
# K-fold training: train one model per fold, then collect its out-of-fold
# (OOF) predictions on that fold's validation split.
oof = []
true = []
historys = []
get_time_fold()

for fold in range(0, CONFIG.n_folds):
    print(f"==================== Train on Fold {fold+1} ====================")
    # Release the previous fold's model AND optimizer (AdamW keeps references
    # to the old parameters and their state) before building new ones, so the
    # GPU memory can actually be freed. Rebinding (instead of `del model`)
    # also works when no model exists yet.
    model = None
    optimizer = None
    gc.collect()
    torch.cuda.empty_cache()

    model = CSIROModel()
    if CONFIG.DataParallel:
        device_ids = [0, 1]
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda()
    else:
        model = model.to(CONFIG.device)

    optimizer = get_optimizer(model)

    train_loader, valid_loader = prepare_loaders(train, fold)
    model, history, best_model_path = run_training(fold+1, model, optimizer,
                                                   train_loader, valid_loader,
                                                   num_epochs=CONFIG.epochs, now_cv=CONFIG.now_cv)
    historys.append(history)

    # OOF predictions with the best weights. Set eval mode explicitly instead
    # of relying on the state run_training leaves the model in.
    model.eval()
    bar = tqdm(valid_loader, total=len(valid_loader))
    with torch.no_grad():
        for images, labels in bar:
            if CONFIG.DataParallel:
                images = images.cuda().float()
                labels = labels.cuda().float()
            else:
                images = images.to(CONFIG.device, dtype=torch.float)
                labels = labels.to(CONFIG.device, dtype=torch.float)

            outputs = model(images)

            oof.append(outputs.detach().cpu().numpy())
            true.append(labels.detach().cpu().numpy())
        print()

# Rows of `oof` and `true` stay aligned (same loader order, fold by fold).
oof = np.concatenate(oof)
true = np.concatenate(true)

# 第一轮 cv
# 0.5251 # MSELoss       + 没有正则化 + efficientnetb0
# 0.5227 # MSELoss       + 正则化    + efficientnetb0
# 0.5472 # WeightMSELoss + 正则化    + efficientnetb0
# 0.8073 # WeightMSELoss + 正则化    + edgenext
# 0.8073

## Local CV

In [None]:
# Score all out-of-fold predictions together with the competition metric.
local_cv = Calculate_Weighted_R2(true, oof)
print("Local CV : ", local_cv)

# Persist the ground truth and this run's OOF predictions (rows aligned).
np.save("CSIRO/true.npy", true)
np.save(f"{CONFIG.ckpt_save_path}/oof.npy", oof)

# 0.5683843152035679 # MSELoss       + 正则化 + efficientnetb0
# 0.5809311309267369 # WeightMSELoss + 正则化 + efficientnetb0
# Local CV :  0.7153420406058225

## Logs

In [None]:
# fold = 0
# history = historys[fold]

In [None]:
# plt.plot( range(len(history["Train Loss"])), history["Train Loss"], label="Train Loss")
# plt.plot( range(len(history["Valid Loss"])), history["Valid Loss"], label="Valid Loss")
# plt.xlabel("epochs")
# plt.ylabel("Loss")
# plt.grid()
# plt.legend()
# plt.show()

In [None]:
# plt.plot( range(len(history["Train CV(R2)"])), history["Train CV(R2)"], label="Train CV(R2)")
# plt.plot( range(len(history["Valid CV(R2)"])), history["Valid CV(R2)"], label="Valid CV(R2)")
# plt.xlabel("epochs")
# plt.ylabel("CV(R2)")
# plt.grid()
# plt.legend()
# plt.show()

In [None]:
# plt.plot( range(len(history["lr"])), history["lr"], label="lr")
# plt.xlabel("epochs")
# plt.ylabel("lr")
# plt.grid()
# plt.legend()
# plt.show()