# In [None]:
# %% ===================== Imports & Global Setup =====================
# pip: transformers==4.41.1
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import sys
import re
import math
import warnings
import traceback
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter  # optional, used only on rank 0

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from typing import Tuple, Dict, Any

import torch.backends.cuda as cuda_backends
cuda_backends.matmul.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_info()
warnings.filterwarnings("ignore")


# %% ===================== Units & canonicalization =====================
# Lookup table: normalized unit spelling -> (canonical unit, multiplier).
# Multiplying a raw quantity by the multiplier expresses it in the canonical
# unit of its family: grams ('g'), millilitres ('ml'), centimetres ('cm'),
# or a plain item count ('count').
UNIT_SCALES: Dict[str, Tuple[str, float]] = {
    # Mass -> grams
    'lb': ('g', 453.59237), 'lbs': ('g', 453.59237),
    'pound': ('g', 453.59237), 'pounds': ('g', 453.59237),
    'kg': ('g', 1000.0), 'kilogram': ('g', 1000.0), 'kilograms': ('g', 1000.0),
    'g': ('g', 1.0), 'gram': ('g', 1.0), 'grams': ('g', 1.0),
    'mg': ('g', 1e-3), 'milligram': ('g', 1e-3), 'milligrams': ('g', 1e-3),
    'oz': ('g', 28.349523125), 'ounce': ('g', 28.349523125), 'ounces': ('g', 28.349523125),

    # Volume -> millilitres
    'l': ('ml', 1000.0), 'liter': ('ml', 1000.0), 'litre': ('ml', 1000.0),
    'liters': ('ml', 1000.0), 'litres': ('ml', 1000.0),
    'ml': ('ml', 1.0), 'milliliter': ('ml', 1.0), 'milliliters': ('ml', 1.0),
    'millilitre': ('ml', 1.0), 'millilitres': ('ml', 1.0),
    'fl oz': ('ml', 29.5735295625), 'floz': ('ml', 29.5735295625),

    # Length -> centimetres
    'in': ('cm', 2.54), 'inch': ('cm', 2.54), 'inches': ('cm', 2.54),
    'ft': ('cm', 30.48), 'foot': ('cm', 30.48), 'feet': ('cm', 30.48),
    'cm': ('cm', 1.0), 'mm': ('cm', 0.1),

    # Discrete item counts (no conversion)
    'count': ('count', 1.0), 'ct': ('count', 1.0),
    'pcs': ('count', 1.0), 'piece': ('count', 1.0), 'pieces': ('count', 1.0),
}

def _norm_unit(u: str) -> str:
    u = u.lower().strip()
    u = re.sub(r'[.\s]+', ' ', u)
    u = u.replace('fluid ounce', 'fl oz').replace('fluid ounces', 'fl oz')
    u = u.replace('fl. oz', 'fl oz').replace('fl-oz', 'fl oz')
    u = u.replace('ounces', 'oz').replace('ounce', 'oz')
    u = u.replace('ct.', 'ct')
    return u

def canonicalize_unit(raw_unit: str) -> Tuple[str, float]:
    """Map a raw unit string to ``(canonical_unit, multiplier)``.

    Lookup is attempted, in order, on the normalized spelling, the same
    spelling with spaces removed, and the spelling with one trailing 's'
    dropped. Anything else (including non-strings and blank strings)
    yields ``('<unk>', 1.0)``.
    """
    unknown = ('<unk>', 1.0)
    if not (isinstance(raw_unit, str) and raw_unit.strip()):
        return unknown

    normalized = _norm_unit(raw_unit)
    candidates = [normalized, normalized.replace(' ', '')]
    if normalized.endswith('s'):
        candidates.append(normalized[:-1])

    for key in candidates:
        if key in UNIT_SCALES:
            return UNIT_SCALES[key]
    return unknown


# %% ===================== Patterns & feature extraction =====================
# "Value: <number>" — optionally signed integer or decimal.
VALUE_RE = re.compile(r'Value:\s*([+-]?\d+(?:\.\d+)?)', re.IGNORECASE)
# "Unit: <letters>" — the capture is limited to letters, dots, hyphens, spaces
# and tabs so it stops at the end of the line instead of running into the
# text that follows the unit on the next line.
UNIT_RE  = re.compile(r'Unit:[ \t]*([A-Za-z.\-][A-Za-z.\- \t]*)', re.IGNORECASE)

# Alternative spellings of a multi-pack quantity. Each alternative has exactly
# one capture group holding the count; they are OR-ed into PACK_RE, so for any
# match only the group of the alternative that matched is non-None.
PACK_PATTERNS = [
    r'[Pp]ack\s*of\s*(\d+)',
    r'(\d+)\s*[Pp]ack\b',
    r'(\d+)\s*-\s*[Cc]ount',
    r'(\d+)\s*(?:ct|count)\b',
    r'[x×]\s*(\d+)\b',
    r'\b(\d+)\s*[x×]\s*\d+\s*(?:oz|fl\s*oz|g|ml)\b',
    r'\(.*?[Pp]ack\s*of\s*(\d+).*?\)',
    r'\bcase\s*of\s*(\d+)\b',
]
# NOTE(review): compiled without re.IGNORECASE, so only the spellings with an
# explicit [Pp]/[Cc] class accept a capital letter.
PACK_RE = re.compile('|'.join(PACK_PATTERNS))

# Keyword lists counted as plain substrings of the lowercased text.
PREMIUM_KWS = ['organic', 'premium', 'gourmet', 'artisan', 'natural', 'handcrafted', 'imported', 'luxury']
BULK_KWS    = ['pack', 'case', 'bulk', 'bundle', 'family size', 'wholesale']

# (category, regex) pairs tried in order against the lowercased text; the first
# pattern that matches decides the category, so earlier rules take precedence.
# Several alternatives carry no word boundary (e.g. 'tea', 'rice', 'mint') and
# therefore also match inside longer words.
CAT_RULES = [
    ('soup',     r'\bsoup\b|ramen|broth'),
    ('sauce',    r'\bsauce\b|ketchup|mustard|mayo|dressing|marinara|salsa'),
    ('cookies',  r'\bcookies?\b|biscuit'),
    ('candy',    r'candy|chocolate|gummy|toffee|mint'),
    ('snack',    r'\bchips?\b|pretzel|popcorn|cracker|snack'),
    ('spice',    r'\bspice\b|seasoning|masala|herb|salt|pepper'),
    ('beverage', r'coffee|tea|soda|drink|beverage|juice'),
    ('grains',   r'rice|pasta|noodle|flour|oats|cereal'),
    ('oil',      r'\boil\b|olive oil|canola|sunflower'),
    ('gift',     r'gift|basket|hamper'),
    ('dairy',    r'cheese|milk|butter|yogurt'),
]

# Nutrition facts written as "<number> g protein", "<number> calories", etc.
# Group 1 is the numeric amount.
PROTEIN_RE  = re.compile(r'(\d+(?:\.\d+)?)\s*(?:g|grams?)\s*(?:of\s*)?protein', re.IGNORECASE)
FIBER_RE    = re.compile(r'(\d+(?:\.\d+)?)\s*(?:g|grams?)\s*(?:of\s*)?fiber', re.IGNORECASE)
CALORIES_RE = re.compile(r'(\d+(?:\.\d+)?)\s*calories?', re.IGNORECASE)
SUGAR_RE    = re.compile(r'(\d+(?:\.\d+)?)\s*(?:g|grams?)\s*(?:of\s*)?sugar', re.IGNORECASE)
# Text after "Item Name:" up to the first newline, "Bullet Point",
# "Product Description", or end of input (whichever comes first).
ITEM_NAME_RE = re.compile(r'Item Name:\s*(.+?)(?=\n|Bullet Point|Product Description|$)', re.IGNORECASE | re.DOTALL)

def clean_text(text: str) -> str:
    """Strip the "Value: ... Unit: ..." metadata and collapse whitespace.

    The metadata is removed whether "Value:" and "Unit:" share a line or
    "Unit:" sits on the line(s) directly after "Value:" separated only by
    whitespace. Every whitespace run is then squeezed to a single space.

    Args:
        text: Raw catalog text; non-strings are treated as empty.

    Returns:
        The cleaned, single-line text ("" for non-string input).
    """
    if not isinstance(text, str):
        return ""
    # `[^\n]` keeps each half on its own line; the optional newline plus `\s*`
    # bridges the two-line layout without swallowing unrelated text between
    # a stray "Value:" and a distant "Unit:".
    text = re.sub(r'Value:[^\n]*?\n?\s*Unit:[^\n]*(?:\n|$)', ' ', text, flags=re.IGNORECASE)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def extract_value_unit(text: str) -> Tuple[float, str]:
    """Pull the numeric "Value:" and textual "Unit:" fields out of *text*.

    Returns ``(value, unit)``. A missing value is reported as ``0.0`` and a
    missing unit (or non-string input) as ``'Unknown'``.
    """
    if not isinstance(text, str):
        return 0.0, 'Unknown'

    value_match = VALUE_RE.search(text)
    unit_match = UNIT_RE.search(text)

    if value_match is None:
        value = 0.0
    else:
        value = float(value_match.group(1))

    if unit_match is None:
        unit = 'Unknown'
    else:
        unit = unit_match.group(1).strip()

    return value, unit

def detect_pack_size(text: str) -> int:
    """Return the pack count mentioned in *text*, or 1 when none is found.

    Uses the first match of ``PACK_RE`` and reads the count from whichever
    capture group participated. The result is never below 1.

    Args:
        text: Catalog text; non-strings yield 1.

    Returns:
        The detected pack size (>= 1).
    """
    if not isinstance(text, str):
        return 1
    match = PACK_RE.search(text)
    if match is None:
        return 1
    for captured in match.groups():
        if captured is None:
            continue
        try:
            return max(1, int(captured))
        except ValueError:
            # Only a failed int conversion is tolerated; a bare `except`
            # would also hide KeyboardInterrupt/SystemExit.
            continue
    return 1

def extract_brand(text: str) -> str:
    """Guess a brand token from the item name.

    Takes the text after "Item Name:" on its line (or the whole text when
    that label is absent), keeps the part before the first '|', ' - ' and
    ',', and returns its first word. When that word has two characters or
    fewer, the whole leading segment (at most 40 characters) is returned
    instead. Non-strings and empty results give "<unk>".
    """
    if not isinstance(text, str):
        return "<unk>"

    name_match = re.search(r'Item Name:\s*(.+)', text, flags=re.IGNORECASE)
    candidate = text if name_match is None else name_match.group(1)

    # Cut at the first occurrence of each separator, in this fixed order.
    for separator in ('|', ' - ', ','):
        candidate = candidate.split(separator)[0]
    candidate = re.sub(r'\s+', ' ', candidate.strip())

    if not candidate:
        return "<unk>"

    leading_word = candidate.split(' ')[0]
    if len(leading_word) > 2:
        return leading_word[:40]
    return candidate[:40].strip() or "<unk>"

def category_bucket(text: str) -> str:
    """Return the label of the first ``CAT_RULES`` pattern found in *text*.

    Matching is done on the lowercased text; non-strings are treated as
    empty. Falls back to ``'other'`` when no rule matches.
    """
    lowered = text.lower() if isinstance(text, str) else ""
    matching_labels = (label for label, rule in CAT_RULES if re.search(rule, lowered))
    return next(matching_labels, 'other')

def extract_item_name(text: str) -> str:
    """Return the stripped "Item Name:" value, or "" when absent or not a string."""
    if not isinstance(text, str):
        return ""
    found = ITEM_NAME_RE.search(text)
    if found is None:
        return ""
    return found.group(1).strip()

def extract_sub_category(item_name: str) -> str:
    """Use the first 50 characters of the item name as a sub-category label.

    An empty or missing name maps to "Other".
    """
    prefix = item_name[:50] if item_name else ""
    return prefix if prefix else "Other"

def extract_flavor_profile(item_name: str, text: str) -> str:
    """Return the first known flavor keyword found in *text*.

    Args:
        item_name: Accepted for interface compatibility; not consulted.
        text: Catalog text, searched case-insensitively (via lowercasing).

    Returns:
        The lowercased flavor keyword, or "Other" when no keyword occurs or
        *text* is not a string (same non-string tolerance as the other
        extractors in this module).
    """
    if not isinstance(text, str):
        return "Other"
    flavors = re.search(r'(mild|original|creamy|blue cheese|sherry|basil|key lime|black raspberry|strawberry banana|cookie dough|other flavors)', text.lower())
    return flavors.group(1) if flavors else "Other"

def _safe_log1p(x: float) -> float:
    """log1p clamped at zero so negative inputs cannot yield NaN or -inf."""
    return float(np.log1p(max(float(x), 0.0)))


def extract_features_row(text: str) -> Dict[str, Any]:
    """Build the hand-crafted feature dictionary for one catalog entry.

    Combines the parsed quantity (value/unit, canonicalized), pack size,
    log / inverse transforms of those quantities, simple text statistics,
    keyword flags, parsed nutrition numbers and a few categorical labels
    (brand, category bucket, sub-category, flavor profile).

    Args:
        text: Raw catalog text; non-strings are treated as "".

    Returns:
        A dict of feature name -> value. Key order is fixed, which keeps the
        column order of any DataFrame built from these rows stable.
    """
    if not isinstance(text, str): text = ""
    value_raw, unit_raw = extract_value_unit(text)
    canon_unit, scale = canonicalize_unit(unit_raw)
    value_canon = value_raw * scale

    pack_size = detect_pack_size(text)
    value_per_pack = value_canon / max(pack_size, 1)
    low = text.lower()

    total_ml = value_canon if canon_unit == 'ml' else 0.0
    total_g  = value_canon if canon_unit == 'g'  else 0.0
    is_bulk = 1.0 if (pack_size >= 6 or total_ml >= 1500 or total_g >= 1000) else 0.0

    item_name = extract_item_name(text)
    sub_category = extract_sub_category(item_name)
    flavor_profile = extract_flavor_profile(item_name, text)

    protein_m = PROTEIN_RE.search(text)
    fiber_m = FIBER_RE.search(text)
    calories_m = CALORIES_RE.search(text)
    sugar_m = SUGAR_RE.search(text)
    protein_grams = float(protein_m.group(1)) if protein_m else 0.0
    fiber_grams = float(fiber_m.group(1)) if fiber_m else 0.0
    calories_per_serving = float(calories_m.group(1)) if calories_m else 0.0
    sugar_grams = float(sugar_m.group(1)) if sugar_m else 0.0

    is_high_protein = 1.0 if 'high protein' in low or protein_grams > 10 else 0.0
    is_low_calorie = 1.0 if 'low calorie' in low or (calories_per_serving > 0 and calories_per_serving < 50) else 0.0

    # Log/ratio/inverse features.
    # The value pattern accepts a leading sign, so a parsed quantity can be
    # negative; plain log1p would then produce NaN (< -1) or -inf (== -1).
    log_value_canon = _safe_log1p(value_canon)
    log_pack_size = _safe_log1p(pack_size)
    log_value_per_pack = _safe_log1p(value_per_pack)
    log_total_g = _safe_log1p(total_g)
    log_total_ml = _safe_log1p(total_ml)
    # Kept as a separate column although it equals value_per_pack: removing
    # it would change the feature layout.
    value_to_pack_ratio = value_canon / max(pack_size, 1)

    is_weight_unit = 1.0 if canon_unit == 'g' else 0.0
    is_volume_unit = 1.0 if canon_unit == 'ml' else 0.0
    is_count_unit = 1.0 if canon_unit == 'count' else 0.0

    inv_total_g = 1.0 / total_g if total_g > 0 else 0.0
    inv_total_ml = 1.0 / total_ml if total_ml > 0 else 0.0
    inv_value_canon = 1.0 / value_canon if value_canon > 0 else 0.0
    inv_pack_size = 1.0 / pack_size  # pack_size is always >= 1

    return {
        'value': float(value_raw),
        'unit': unit_raw,
        'value_canon': float(value_canon),
        'canon_unit': canon_unit,
        'pack_size': float(pack_size),
        'value_per_pack': float(value_per_pack),
        'total_ml': float(total_ml),
        'total_g': float(total_g),

        'log_value_canon': float(log_value_canon),
        'log_pack_size': float(log_pack_size),
        'log_value_per_pack': float(log_value_per_pack),
        'log_total_g': float(log_total_g),
        'log_total_ml': float(log_total_ml),
        'value_to_pack_ratio': float(value_to_pack_ratio),
        'is_weight_unit': float(is_weight_unit),
        'is_volume_unit': float(is_volume_unit),
        'is_count_unit': float(is_count_unit),
        'inv_total_g': float(inv_total_g),
        'inv_total_ml': float(inv_total_ml),
        'inv_value_canon': float(inv_value_canon),
        'inv_pack_size': float(inv_pack_size),

        'text_length': len(text),
        'word_count': len(text.split()),
        'bullet_count': float(text.count('Bullet Point')),
        'premium_count': float(sum(kw in low for kw in PREMIUM_KWS)),
        'bulk_count': float(sum(kw in low for kw in BULK_KWS)),
        'has_number': float(1 if re.search(r'\d', text) else 0),
        'is_bulk': is_bulk,
        'brand': extract_brand(text),
        'category_bucket': category_bucket(text),
        'sub_category': sub_category,
        'flavor_profile': flavor_profile,

        'is_organic': 1.0 if 'organic' in low else 0.0,
        'is_gluten_free': 1.0 if 'gluten-free' in low or 'gluten free' in low else 0.0,
        'is_vegan': 1.0 if 'vegan' in low or 'plant-based' in low else 0.0,
        'is_keto_friendly': 1.0 if 'keto' in low or 'low carb' in low else 0.0,
        'is_non_gmo': 1.0 if 'non-gmo' in low or 'gmo free' in low else 0.0,
        'is_kosher': 1.0 if 'kosher' in low else 0.0,
        'is_dairy_free': 1.0 if 'dairy free' in low else 0.0,
        'is_nut_free': 1.0 if 'nut-free' in low else 0.0,
        'is_soy_free': 1.0 if 'soy free' in low else 0.0,
        'is_sugar_free': 1.0 if 'sugar free' in low or 'no sugar added' in low else 0.0,
        'is_low_fat': 1.0 if 'low fat' in low or '0 grams trans fat' in low else 0.0,
        'is_low_sodium': 1.0 if 'low sodium' in low or 'no salt added' in low else 0.0,
        'is_all_natural': 1.0 if 'all natural' in low or 'natural' in low else 0.0,
        'is_fair_trade': 1.0 if 'fair trade' in low else 0.0,
        'is_usda_certified': 1.0 if 'usda organic' in low else 0.0,

        'protein_grams': protein_grams,
        'fiber_grams': fiber_grams,
        'calories_per_serving': calories_per_serving,
        'sugar_grams': sugar_grams,
        'has_vitamins': 1.0 if 'vitamin' in low or 'vitamins' in low else 0.0,
        'has_minerals': 1.0 if any(k in low for k in ['potassium', 'iron', 'magnesium', 'zinc']) else 0.0,
        'has_antioxidants': 1.0 if 'antioxidants' in low else 0.0,
        'is_high_protein': is_high_protein,
        'is_low_calorie': is_low_calorie,

        'is_ready_to_eat': 1.0 if 'ready to eat' in low else 0.0,
        'is_easy_to_prepare': 1.0 if 'easy to prepare' in low or 'instant' in low else 0.0,
        'is_versatile': 1.0 if 'versatile' in low or 'multiple uses' in low else 0.0,
        'is_snack': 1.0 if 'snack' in low or 'on-the-go' in low else 0.0,
        'is_beverage': 1.0 if any(k in low for k in ['drink', 'beverage', 'tea', 'coffee', 'juice']) else 0.0,
        'is_baking_ingredient': 1.0 if any(k in low for k in ['baking', 'flour', 'powder']) else 0.0,
        'has_no_preservatives': 1.0 if 'no preservatives' in low else 0.0,
        'is_shelf_stable': 1.0 if any(k in low for k in ['shelf stable', 'canned']) else 0.0,

        'has_ingredients_list': 1.0 if 'ingredients:' in low else 0.0,
        'has_product_description': 1.0 if 'product description:' in low else 0.0,
    }

def build_numeric_table(catalog_series: pd.Series) -> pd.DataFrame:
    """Run ``extract_features_row`` over every entry and stack the results.

    Returns a DataFrame with one row per input entry and one column per
    feature key.
    """
    rows = []
    for entry in catalog_series:
        rows.append(extract_features_row(entry))
    return pd.DataFrame(rows)


# %% ===================== Metrics & Loss =====================
def smape(y_true, y_pred) -> float:
    """Symmetric mean absolute percentage error, in percent.

    Both inputs are converted to float arrays and replaced by their absolute
    values before the error is computed; a 1e-8 term guards the denominator
    against division by zero.
    """
    actual = np.abs(np.asarray(y_true, dtype=float))
    forecast = np.abs(np.asarray(y_pred, dtype=float))
    half_sum = (actual + forecast) / 2.0
    ratios = np.abs(forecast - actual) / (half_sum + 1e-8)
    return float(np.mean(ratios) * 100.0)

def smape_np(pred, true):
    """SMAPE in percent for array-likes; note the (pred, true) argument order.

    Absolute values of both inputs are used throughout, and 1e-8 is added to
    the denominator to avoid division by zero.
    """
    abs_pred, abs_true = np.abs(pred), np.abs(true)
    midpoint = (abs_pred + abs_true) / 2.0
    per_item = np.abs(abs_pred - abs_true) / (midpoint + 1e-8)
    return float(np.mean(per_item) * 100.0)

def smape_loss(pred, true):
    """Differentiable SMAPE (in percent) for torch tensors.

    The numerator uses the signed difference ``|pred - true|``; the
    denominator is the mean of ``|pred|`` and ``|true|`` plus 1e-8.
    """
    scale = (torch.abs(pred) + torch.abs(true)) / 2.0
    relative_error = torch.abs(pred - true) / (scale + 1e-8)
    return torch.mean(relative_error) * 100.0

def combined_loss(pred_log, true_log, pred_raw, true_raw, alpha=0.7):
    """Blend L1 loss on the log targets with SMAPE on the raw targets.

    Returns ``alpha * L1(pred_log, true_log)`` plus ``(1 - alpha)`` times the
    SMAPE of the raw values rescaled from percent to a fraction.
    """
    log_l1_term = alpha * torch.nn.functional.l1_loss(pred_log, true_log)
    smape_term = (1.0 - alpha) * smape_loss(pred_raw, true_raw) / 100.0
    return log_l1_term + smape_term


# %% ===================== Dataset =====================
class PriceDataset(Dataset):
    """Torch dataset pairing tokenized catalog text with tabular features.

    At construction time every row's hand-crafted features are extracted,
    the numeric columns are standardized, and the three categorical columns
    (canonical unit, category bucket, flavor profile) are turned into integer
    ids. Passing ``scaler`` / ``*_encoder`` reuses objects fitted elsewhere
    (typically on the training split); leaving them ``None`` fits new ones on
    this data.

    Labels that an encoder has never seen map to that encoder's fallback
    class ('<unk>', 'other', 'Other') instead of raising, so a validation or
    test split may contain categories absent from the training split.

    Attributes read by callers: ``unit_le``, ``cat_le``, ``flav_le``,
    ``scaler``, ``numeric_cols``, ``feat_df``, ``X_num``.
    """

    def __init__(self, df, tokenizer, unit_encoder=None, cat_encoder=None, flav_encoder=None, scaler=None,
                 max_length=384, is_test=False):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.max_length = max_length
        self.is_test = is_test

        feats = [extract_features_row(t) for t in self.df['catalog_content'].fillna("")]
        self.feat_df = pd.DataFrame(feats)

        # Numeric feature matrix, standardized with a new or a supplied scaler.
        self.numeric_cols = [col for col in self.feat_df.columns
                             if pd.api.types.is_numeric_dtype(self.feat_df[col])]
        X_num = self.feat_df[self.numeric_cols].fillna(0).values.astype(np.float32)
        if scaler is None:
            self.scaler = StandardScaler()
            X_num = self.scaler.fit_transform(X_num)
        else:
            self.scaler = scaler
            X_num = self.scaler.transform(X_num)
        self.X_num = X_num.astype(np.float32)

        # Categorical encoders; the fallback class is always part of a fit.
        if unit_encoder is None:
            self.unit_le = self._fit_encoder(self.feat_df['canon_unit'], '<unk>')
        else:
            self.unit_le = unit_encoder

        if cat_encoder is None:
            self.cat_le = self._fit_encoder(self.feat_df['category_bucket'], 'other')
        else:
            self.cat_le = cat_encoder

        if flav_encoder is None:
            self.flav_le = self._fit_encoder(self.feat_df['flavor_profile'], 'Other')
        else:
            self.flav_le = flav_encoder

        # Encode once up front instead of calling LabelEncoder.transform for
        # every __getitem__.
        self.unit_ids = self._encode_column(self.feat_df['canon_unit'].tolist(), self.unit_le, '<unk>')
        self.cat_ids = self._encode_column(self.feat_df['category_bucket'].tolist(), self.cat_le, 'other')
        self.flav_ids = self._encode_column(self.feat_df['flavor_profile'].tolist(), self.flav_le, 'Other')

    @staticmethod
    def _fit_encoder(column, fallback):
        """Fit a LabelEncoder on *column*'s values plus the *fallback* class."""
        encoder = LabelEncoder()
        encoder.fit(sorted(set(column.fillna(fallback).tolist()) | {fallback}))
        return encoder

    @staticmethod
    def _encode_column(values, encoder, fallback):
        """Map labels to encoder ids; empty or unseen labels use *fallback*.

        Raises:
            ValueError: if a label is unseen and *fallback* itself is not one
                of the encoder's classes.
        """
        index_of = {label: idx for idx, label in enumerate(encoder.classes_)}
        fallback_id = index_of.get(fallback)
        ids = np.empty(len(values), dtype=np.int64)
        for pos, raw in enumerate(values):
            label = raw or fallback
            idx = index_of.get(label, fallback_id)
            if idx is None:
                raise ValueError(
                    f"label {label!r} is unknown to the encoder and fallback {fallback!r} is not among its classes"
                )
            ids[pos] = idx
        return ids

    def __len__(self): return len(self.df)

    def __getitem__(self, i):
        """Return the model inputs (and targets unless ``is_test``) for row *i*."""
        row = self.df.iloc[i]
        txt = clean_text(row['catalog_content'])
        enc = self.tok(txt, max_length=self.max_length, padding='max_length',
                       truncation=True, return_tensors='pt')

        item = {
            # The tokenizer returns a batch of one; drop the batch dimension.
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'num': torch.tensor(self.X_num[i], dtype=torch.float32),
            'unit': torch.tensor(int(self.unit_ids[i]), dtype=torch.long),
            'cat': torch.tensor(int(self.cat_ids[i]), dtype=torch.long),
            'flav': torch.tensor(int(self.flav_ids[i]), dtype=torch.long),
            'sample_id': row['sample_id'],
        }
        if not self.is_test:
            price = float(row['price'])
            item['ylog'] = torch.tensor(np.log1p(price), dtype=torch.float32)
            item['y_raw'] = torch.tensor(price, dtype=torch.float32)
        return item


# %% ===================== Model =====================
class Model(nn.Module):
    """Pretrained text encoder fused with numeric and categorical features.

    The encoder's token outputs are mean-pooled over non-padding positions,
    concatenated with the numeric feature vector and three 16-dimensional
    embeddings (unit, category, flavor), and passed through a three-stage
    MLP that outputs one scalar per sample.
    """

    def __init__(self, name, num_num, num_units, num_cats, num_flavs, hidden=768, drop=0.2):
        super().__init__()
        self.enc = AutoModel.from_pretrained(name, trust_remote_code=False)
        text_dim = self.enc.config.hidden_size

        emb_dim = 16
        self.unit_emb = nn.Embedding(num_units, emb_dim)
        self.cat_emb = nn.Embedding(num_cats, emb_dim)
        self.flav_emb = nn.Embedding(num_flavs, emb_dim)

        fused_dim = text_dim + num_num + 3 * emb_dim
        half, quarter = hidden // 2, hidden // 4
        head_layers = [
            nn.Linear(fused_dim, hidden),
            nn.LayerNorm(hidden),
            nn.SiLU(),
            nn.Dropout(drop),
            nn.Linear(hidden, half),
            nn.LayerNorm(half),
            nn.SiLU(),
            nn.Dropout(drop),
            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.SiLU(),
            nn.Dropout(drop / 2),
            nn.Linear(quarter, 1),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def mean_pool(self, last_hidden, mask):
        """Average *last_hidden* along dim 1, counting only positions where *mask* is non-zero."""
        weights = mask.unsqueeze(-1).float()
        summed = (last_hidden * weights).sum(1)
        # Clamp so a fully masked row divides by 1e-6 rather than by zero.
        counts = weights.sum(1).clamp(min=1e-6)
        return summed / counts

    def forward(self, ids, mask, num, unit, cat, flav):
        """Return one prediction per sample (trailing singleton dim removed)."""
        token_states = self.enc(input_ids=ids, attention_mask=mask).last_hidden_state
        pieces = [
            self.mean_pool(token_states, mask),
            num,
            self.unit_emb(unit),
            self.cat_emb(cat),
            self.flav_emb(flav),
        ]
        fused = torch.cat(pieces, 1)
        return self.mlp(fused).squeeze(-1)


# %% ===================== Utils: DDP helpers =====================
def set_seed(seed: int, rank: int = 0):
    """Seed NumPy and torch (CPU, plus all CUDA devices when available).

    The effective seed is ``seed + rank`` so each rank gets its own stream.
    """
    effective = seed + rank
    np.random.seed(effective)
    torch.manual_seed(effective)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(effective)

def is_dist_avail_and_initialized():
    """True only when torch.distributed is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def get_rank():
    """Rank of this process in the default group, or 0 outside distributed runs."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0

def is_main_process():
    """Whether this process is rank 0 (always true outside distributed runs)."""
    rank = get_rank()
    return rank == 0

def all_gather_numpy(arr: np.ndarray) -> np.ndarray:
    """Collect per-rank arrays (lengths may differ) and concatenate on rank 0.

    Every rank takes part in the object all-gather, but only the main
    process returns the concatenation; other ranks get their own *arr* back.
    Outside a distributed run *arr* is returned unchanged.
    """
    if not is_dist_avail_and_initialized():
        return arr

    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, arr)

    if not is_main_process():
        return arr  # local data only; callers ignore it on non-main ranks
    if not gathered:
        return arr
    return np.concatenate([part for part in gathered if part is not None])


# %% ===================== Main (DDP) =====================
def main():
    """Train the text + tabular price model, single-process or under DDP.

    Steps: read ``dataset/train.csv``, split 80/20, build datasets (encoders
    and scaler fitted on the training split), train with the encoder frozen
    for the first ``WARMUP_E`` epochs, evaluate validation SMAPE after every
    epoch, and save the best checkpoint to ``/models/T1/best_model.pt`` from
    rank 0. Training stops early after ``patience`` epochs without
    improvement.

    DDP mode is selected when the ``LOCAL_RANK`` environment variable is set
    (as done by ``torchrun``).

    Raises:
        Exception: any error raised during training is logged on rank 0 and
            re-raised.
    """
    # -------- Config --------
    SEED = 42
    DATA = "dataset"
    OUT = "/models/T1"
    os.makedirs(OUT, exist_ok=True)

    MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
    BATCH = 24                 # per-GPU
    EPOCHS = 15
    WARMUP_E = 2               # epochs during which the encoder stays frozen
    MAXL = 384
    LR_ENC = 1.5e-5
    LR_HEAD = 1.2e-3
    ALPHA = 0.7
    HIDDEN = 768

    # -------- DDP init (or single process fallback) --------
    ddp = False
    local_rank_env = os.environ.get("LOCAL_RANK", None)
    if local_rank_env is not None:
        ddp = True
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = int(local_rank_env)
        torch.cuda.set_device(local_rank)
        dev = torch.device(f"cuda:{local_rank}")
    else:
        local_rank = 0
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    set_seed(SEED, rank=local_rank)
    main_proc = is_main_process()

    if main_proc:
        print(f"[Boot] transformer_tabular_v2 (DDP={ddp}) starting...", flush=True)
        print(f"[Setup] Device: {dev} (rank={get_rank()})", flush=True)
        print(f"[Phase 1] Model: {MODEL_NAME}, MaxLen: {MAXL}, Batch/GPU: {BATCH}", flush=True)

    # -------- Data --------
    train_csv = os.path.join(DATA, "train.csv")
    if not os.path.exists(train_csv):
        if main_proc:
            print(f"[Error] Missing dataset file: {train_csv}", flush=True)
        if ddp:
            dist.barrier()
            dist.destroy_process_group()
        return

    if main_proc:
        print("[Load] Reading train.csv ...", flush=True)
    df = pd.read_csv(train_csv)
    if main_proc:
        print(f"[Load] Rows: {len(df)}, Columns: {list(df.columns)}", flush=True)
        print("[Split] Train/Val split...", flush=True)
    tr, va = train_test_split(df, test_size=0.2, random_state=SEED)
    if main_proc:
        print(f"[Split] Train: {len(tr)}, Val: {len(va)}", flush=True)

    # -------- Tokenizer --------
    if main_proc:
        print(f"[HF] Loading tokenizer/model: {MODEL_NAME}", flush=True)
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)

    # -------- Datasets (built once; val reuses the train encoders/scaler) --------
    dtr = PriceDataset(tr, tok, max_length=MAXL, is_test=False)
    dva = PriceDataset(va, tok, max_length=MAXL, is_test=False,
                       unit_encoder=dtr.unit_le, cat_encoder=dtr.cat_le,
                       flav_encoder=dtr.flav_le, scaler=dtr.scaler)

    num_num = len(dtr.numeric_cols)
    if main_proc:
        print(f"[Model] Building with {num_num} numeric features", flush=True)

    # -------- Model --------
    raw_model = Model(MODEL_NAME, num_num=num_num,
                      num_units=len(dtr.unit_le.classes_),
                      num_cats=len(dtr.cat_le.classes_),
                      num_flavs=len(dtr.flav_le.classes_),
                      hidden=HIDDEN)
    raw_model.to(dev)

    # Gradient checkpointing is deliberately not enabled (original author's
    # note: MPNetModel does not support it); only the cache flag is turned off.
    if hasattr(raw_model, "enc") and hasattr(raw_model.enc, "config"):
        raw_model.enc.config.use_cache = False

    # -------- Samplers & Loaders --------
    if ddp:
        sampler_tr = DistributedSampler(dtr, shuffle=True, drop_last=False)
        sampler_va = DistributedSampler(dva, shuffle=False, drop_last=False)
    else:
        sampler_tr = None
        sampler_va = None

    ltr = DataLoader(dtr, batch_size=BATCH, shuffle=(sampler_tr is None),
                     sampler=sampler_tr, num_workers=2, pin_memory=(dev.type == 'cuda'))
    lva = DataLoader(dva, batch_size=BATCH, shuffle=False,
                     sampler=sampler_va, num_workers=2, pin_memory=(dev.type == 'cuda'))

    # -------- Optimizer & Scheduler --------
    enc_params = list(raw_model.enc.parameters())
    head_params = [p for n, p in raw_model.named_parameters() if not n.startswith('enc')]

    opt = torch.optim.AdamW([
        {'params': enc_params, 'lr': LR_ENC, 'weight_decay': 0.01},
        {'params': head_params, 'lr': LR_HEAD, 'weight_decay': 0.0},
    ])

    total_steps = max(1, len(ltr)) * EPOCHS
    sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=max(1, int(0.05 * total_steps)),
                                            num_training_steps=total_steps)

    scaler = torch.cuda.amp.GradScaler(enabled=(dev.type == 'cuda'))

    # -------- Warmup: freeze encoder initially --------
    for p in raw_model.enc.parameters():
        p.requires_grad = False

    # -------- Wrap with DDP --------
    def _wrap(module, find_unused):
        """Wrap *module* in DDP when running distributed; otherwise return it as-is."""
        if not ddp:
            return module
        return DDP(module, device_ids=[dev.index], output_device=dev.index,
                   find_unused_parameters=find_unused)

    # While the encoder is frozen every trainable (head) parameter is used.
    model = _wrap(raw_model, find_unused=False)

    # -------- Train loop --------
    best = 1e9
    patience = 6
    bad = 0

    if main_proc:
        print("[Train] Starting epochs...", flush=True)
        print(f"[Loss] Using combined loss: {ALPHA}*L1 + {1-ALPHA}*SMAPE", flush=True)

    completed = False
    try:
        for ep in range(EPOCHS):
            if ddp:
                sampler_tr.set_epoch(ep)

            if main_proc:
                print(f"\n[Epoch {ep+1}/{EPOCHS}] ----------------------------", flush=True)
                tbar = tqdm(total=len(ltr), desc="train", mininterval=1.0, leave=False)
            else:
                tbar = None

            # Unfreeze after warmup epochs
            if ep == WARMUP_E:
                if main_proc: print("[Epoch] Unfreezing encoder", flush=True)
                for p in raw_model.enc.parameters():
                    p.requires_grad = True
                # DDP fixes the set of parameters it all-reduces when it is
                # constructed, so the wrapper must be rebuilt for the encoder
                # gradients to be synchronized. find_unused_parameters=True
                # because only last_hidden_state feeds the loss — presumably
                # some encoder parameters (e.g. a pooling head) get no grad.
                model = _wrap(raw_model, find_unused=True)

            # Train
            model.train()
            tot_loss = 0.0
            for b in ltr:
                ids = b['input_ids'].to(dev, non_blocking=True)
                m = b['attention_mask'].to(dev, non_blocking=True)
                num = b['num'].to(dev, non_blocking=True)
                u = b['unit'].to(dev, non_blocking=True)
                c = b['cat'].to(dev, non_blocking=True)
                f = b['flav'].to(dev, non_blocking=True)
                y_log = b['ylog'].to(dev, non_blocking=True)
                y_raw = b['y_raw'].to(dev, non_blocking=True)

                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=(dev.type == 'cuda')):
                    pr_log = model(ids, m, num, u, c, f)
                    pr_raw = torch.expm1(pr_log).clamp(min=0.99)
                    loss = combined_loss(pr_log, y_log, pr_raw, y_raw, alpha=ALPHA)

                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here;
                # unscale first so the 1.0 clip threshold applies to the
                # true gradient norm.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                sched.step()

                tot_loss += float(loss.item())
                if tbar is not None: tbar.update(1)

            if tbar is not None: tbar.close()
            # Reduce train loss across ranks
            tl = torch.tensor([tot_loss, len(ltr)], dtype=torch.float32, device=dev)
            if ddp:
                dist.all_reduce(tl, op=dist.ReduceOp.SUM)
            tr_loss = (tl[0].item() / max(1.0, tl[1].item()))

            # Validation
            model.eval()
            P_local = []
            T_local = []
            with torch.no_grad():
                if main_proc:
                    vbar = tqdm(total=len(lva), desc="val", mininterval=1.0, leave=False)
                else:
                    vbar = None

                for b in lva:
                    ids = b['input_ids'].to(dev, non_blocking=True)
                    m  = b['attention_mask'].to(dev, non_blocking=True)
                    num = b['num'].to(dev, non_blocking=True)
                    u = b['unit'].to(dev, non_blocking=True)
                    c = b['cat'].to(dev, non_blocking=True)
                    f = b['flav'].to(dev, non_blocking=True)
                    y_log = b['ylog'].to(dev, non_blocking=True)

                    with torch.cuda.amp.autocast(enabled=(dev.type == 'cuda')):
                        pr_log = model(ids, m, num, u, c, f)

                    P_local.append(pr_log.float().cpu().numpy())
                    T_local.append(y_log.float().cpu().numpy())
                    if vbar is not None: vbar.update(1)

                if vbar is not None: vbar.close()

            P_local = np.concatenate(P_local) if len(P_local) else np.array([], dtype=np.float32)
            T_local = np.concatenate(T_local) if len(T_local) else np.array([], dtype=np.float32)

            # Gather predictions/targets to rank 0
            if ddp:
                P_all = all_gather_numpy(P_local)
                T_all = all_gather_numpy(T_local)
            else:
                P_all, T_all = P_local, T_local

            stop = False
            if main_proc:
                sm = smape_np(np.expm1(P_all).clip(0.99, None), np.expm1(T_all))
                print(f"[Epoch {ep+1}] train loss: {tr_loss:.4f} | val SMAPE: {sm:.2f}%", flush=True)

                # Save best
                if sm + 1e-6 < best:
                    best = sm
                    torch.save({
                        'state_dict': raw_model.state_dict(),
                        'unit_classes': dtr.unit_le.classes_.tolist(),
                        'cat_classes': dtr.cat_le.classes_.tolist(),
                        'flav_classes': dtr.flav_le.classes_.tolist(),
                        'scaler_mean': dtr.scaler.mean_.tolist(),
                        'scaler_scale': dtr.scaler.scale_.tolist(),
                        'numeric_cols': dtr.numeric_cols,
                        'tokenizer': MODEL_NAME,
                        'best_smape': best,
                        'max_length': MAXL,
                        'hidden_dim': HIDDEN,
                        'alpha': ALPHA,
                    }, os.path.join(OUT, "best_model.pt"))
                    print(f"[Save] Checkpoint saved (SMAPE {best:.2f}%) → {OUT}/best_model.pt", flush=True)
                    bad = 0
                else:
                    bad += 1
                    if bad >= patience:
                        print("[EarlyStop] Patience exceeded. Stopping.", flush=True)
                        stop = True

            # Only rank 0 tracks the metric, so its decision has to be shared;
            # otherwise the other ranks would keep training and block forever
            # in their next collective.
            if ddp:
                stop_flag = torch.tensor([int(stop)], dtype=torch.int32, device=dev)
                dist.broadcast(stop_flag, src=0)
                stop = bool(stop_flag.item())
            if stop:
                break

        if main_proc:
            print(f"\n[Done] Best val SMAPE: {best:.2f}% | Artifacts at {OUT}", flush=True)
        completed = True

    except Exception as e:
        if main_proc:
            print("[Fatal] Unhandled exception:", repr(e), flush=True)
            traceback.print_exc()
        # let it raise — no sys.exit to avoid IPython traceback noise
        raise
    finally:
        if ddp:
            # Synchronize only after a clean finish: after a failure the other
            # ranks may never reach this barrier.
            if completed:
                dist.barrier()
            dist.destroy_process_group()


# %% ===================== Entrypoint =====================
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
