In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import random, torch.backends.cudnn as cudnn
import re
import math
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from rtdl_revisiting_models import  FTTransformer
from torch.utils.data import WeightedRandomSampler, DataLoader, Dataset
from sklearn.metrics import r2_score, accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from imblearn.over_sampling import SMOTENC
from tabm import TabM, EnsembleView, make_tabm_backbone, LinearEnsemble


In [2]:
df = pd.read_csv("convert_json_data.csv")
# df = df.iloc[:50_000].copy()
LABEL_ONLY_COLS = ["category_group"]
CATEGORICAL_FEATURES = [        
    'currency',
    'country_displayable_name',
    'location_state',
]
NUMERIC_FEATURES = [
    'goal_usd_log',
    'name_len',
    'blurb_len',
    'has_video', 
    'days_diff_launched_at_deadline_log',
    'goal_per_day_log',
    'goal_rank_in_cat',
    'goal_vs_cat_median',
    'goal_vs_country_median',
    'goal_round_100',
    'goal_round_1000',
    'cat_freq',
    'country_freq',
    'cat_x_country_freq',
    'gpd_rank_in_cat',
    'gpd_vs_cat_median',
    'gpd_dist_cat_median',
    'cat_country_share',
    'prep_days',
    'has_photo',
    'launch_dow',
    'deadline_dow',
    'too_short_or_long',
]
CYCLIC_NUMERIC = [ 
    'deadline_mon_sin',
    'deadline_mon_cos',
    'deadline_dom_sin',
    'deadline_dom_cos',  
    'launched_at_mon_sin',
    'launched_at_mon_cos',
    'launched_at_dom_sin',
    'launched_at_dom_cos',  
]

FEATURES = CATEGORICAL_FEATURES + NUMERIC_FEATURES

In [3]:
label_encoders = {}
for col in CATEGORICAL_FEATURES:
    df[col] = df[col].fillna("missing").astype(str)
    le = LabelEncoder()
    df[col] = le.fit_transform(df[col])
    label_encoders[col] = le
for col in LABEL_ONLY_COLS:
    df[col] = df[col].fillna("missing").astype(str)
    le = LabelEncoder()
    df[col] = le.fit_transform(df[col]) 
    label_encoders[col] = le

In [4]:
TARGET_FEATURE = {
    "success_cls": df["state"].to_numpy(dtype=np.int64),
    # "success_rate_cls": df["success_rate_cls"].to_numpy(dtype=np.int64),
    "risk_level": df["risk_level"].to_numpy(dtype=np.int64),
    "days_to_state_change": df["duration_class"].to_numpy(dtype=np.int64),
    "recommend_category": df["category_group"].to_numpy(dtype=np.int64),
    "goal_eval": df["goal_eval"].to_numpy(dtype=np.int64),
    # "shortfall_severity_cls": df["shortfall_severity_cls"].to_numpy(dtype=np.int64),
    "stretch_potential_cls": df["stretch_potential_cls"].to_numpy(dtype=np.int64),
}

key = df["category_group"].astype(str) + "_" + df["state"].astype(str)
idx_train, idx_val = train_test_split(np.arange(len(df)), test_size=0.2, random_state=42, stratify=key)

X_train_df = df.iloc[idx_train]
X_val_df = df.iloc[idx_val]



In [5]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [6]:
cat_cardinalities = [int(X_train_df[col].max()) + 1 for col in CATEGORICAL_FEATURES]
print("✅ FINAL cat_cardinalities =", cat_cardinalities)

✅ FINAL cat_cardinalities = [15, 25, 710]


In [7]:
class MultiTaskDataset(torch.utils.data.Dataset):
    def __init__(self, x_cat, x_cont, y_dict):
        self.x_cat = x_cat
        self.x_cont = x_cont
        self.y_dict = y_dict
        self.keys = list(y_dict.keys())

    def __len__(self):
        return len(self.x_cat)

    def __getitem__(self, idx):
        return (
            self.x_cat[idx],
            self.x_cont[idx],
            {k: self.y_dict[k][idx] for k in self.keys}
        )

y_train_dict = {
    k: torch.tensor(v[idx_train]) for k, v in TARGET_FEATURE.items()
}
y_val_dict = {
    k: torch.tensor(v[idx_val]) for k, v in TARGET_FEATURE.items()
}




combined_train = (
    X_train_df["name"].fillna("").str.lower() + " " +
    X_train_df["blurb"].fillna("").str.lower() + " " +
    X_train_df["category_slug"].fillna("").str.replace("/", " ").str.lower()
).tolist()

combined_val = (
    X_val_df["name"].fillna("").str.lower() + " " +
    X_val_df["blurb"].fillna("").str.lower() + " " +
    X_val_df["category_slug"].fillna("").str.replace("/", " ").str.lower()
).tolist()

tfidf = TfidfVectorizer(min_df=3, max_df=0.9, ngram_range=(1,2), stop_words='english')
X_tfidf_tr = tfidf.fit_transform(combined_train)
X_tfidf_va = tfidf.transform(combined_val)

svd_dim = 128
svd = TruncatedSVD(n_components=svd_dim, random_state=42)
X_txt_tr = svd.fit_transform(X_tfidf_tr)
X_txt_va = svd.transform(X_tfidf_va)

scaler = StandardScaler()
numeric_scaled = scaler.fit_transform(X_train_df[NUMERIC_FEATURES])
numeric_scaled_val = scaler.transform(X_val_df[NUMERIC_FEATURES])

X_cont_train = np.concatenate([numeric_scaled, X_txt_tr, X_train_df[CYCLIC_NUMERIC]], axis=1)
X_cont_val = np.concatenate([numeric_scaled_val, X_txt_va, X_val_df[CYCLIC_NUMERIC]], axis=1)

x_cont_tensor = torch.tensor(X_cont_train, dtype=torch.float32)
x_cont_val = torch.tensor(X_cont_val, dtype=torch.float32)

x_cat_tensor = torch.tensor(X_train_df[CATEGORICAL_FEATURES].values, dtype=torch.long)
x_cat_val = torch.tensor(X_val_df[CATEGORICAL_FEATURES].values, dtype=torch.long)

train_ds = MultiTaskDataset(x_cat_tensor, x_cont_tensor, y_train_dict)
val_ds = MultiTaskDataset(x_cat_val, x_cont_val, y_val_dict)




# ===== DataLoaders =====
BATCH_SIZE = 256  # ของคุณเดิม
idx = np.arange(len(train_ds))
# ปกติ (ไม่ oversample) ใช้เฉพาะช่วง warmup และเอาไว้เทียบ
train_loader_plain = DataLoader(train_ds,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                drop_last=False)

# Validation ไม่แตะ (ห้าม oversample)
val_loader = DataLoader(val_ds,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        drop_last=False)



In [8]:
for k in y_train_dict:
    y_train_dict[k] = y_train_dict[k].long()
for k in y_val_dict:
    y_val_dict[k] = y_val_dict[k].long()

num_classes_map = {
    k: int(y_train_dict[k].max().item() + 1)
    for k in y_train_dict.keys()
}
print(num_classes_map)

{'success_cls': 2, 'risk_level': 3, 'days_to_state_change': 4, 'recommend_category': 10, 'goal_eval': 3, 'stretch_potential_cls': 3}


In [9]:
import torch
import torch.nn as nn

class MultiHeadWrapper(nn.Module):
    def __init__(self, base_model: nn.Module, head_dims: dict, in_dim: int, d_hidden: int = 64):
        super().__init__()
        self.backbone = base_model
        self.in_dim = in_dim

        def make_head(out_dim, hidden=d_hidden, dropout=0.2):
            return nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, out_dim),
            )

        # ---- หัวหลัก (ประกาศแบบแยก ไม่ใช้ลูป) ----
        self.head_success_cls = make_head(head_dims["success_cls"]) if "success_cls" in head_dims else None
        self.head_risk_level = make_head(head_dims["risk_level"]) if "risk_level" in head_dims else None

        self.head_days_to_state_change = (
            nn.Sequential(
                nn.Linear(in_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, head_dims["days_to_state_change"]),
            ) if "days_to_state_change" in head_dims else None
        )

        self.head_recommend_category = (
            nn.Sequential(
                nn.Linear(in_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, head_dims["recommend_category"]),
            ) if "recommend_category" in head_dims else None
        )

        self.head_goal_eval = make_head(head_dims["goal_eval"]) if "goal_eval" in head_dims else None

        # ---- 3 หัวใหม่ (มีคีย์ถึงจะสร้าง) ----
        self.head_stretch_potential_cls = (
            make_head(head_dims["stretch_potential_cls"]) if "stretch_potential_cls" in head_dims else None
        )
        self.head_shortfall_severity_cls = (
            make_head(head_dims["shortfall_severity_cls"]) if "shortfall_severity_cls" in head_dims else None
        )
        self.head_near_miss_cls = (
            make_head(head_dims["near_miss_cls"]) if "near_miss_cls" in head_dims else None
        )

    def forward(self, x_cont, x_cat):
        x = self.backbone(x_cont, x_cat)  # [B, in_dim]
        out = {}

        if self.head_success_cls is not None:
            out["success_cls"] = self.head_success_cls(x)
        if self.head_risk_level is not None:
            out["risk_level"] = self.head_risk_level(x)
        if self.head_days_to_state_change is not None:
            out["days_to_state_change"] = self.head_days_to_state_change(x)
        if self.head_recommend_category is not None:
            out["recommend_category"] = self.head_recommend_category(x)
        if self.head_goal_eval is not None:
            out["goal_eval"] = self.head_goal_eval(x)

        if self.head_stretch_potential_cls is not None:
            out["stretch_potential_cls"] = self.head_stretch_potential_cls(x)
        if self.head_shortfall_severity_cls is not None:
            out["shortfall_severity_cls"] = self.head_shortfall_severity_cls(x)
        if self.head_near_miss_cls is not None:
            out["near_miss_cls"] = self.head_near_miss_cls(x)

        return out


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tabm import EnsembleView, make_tabm_backbone, LinearEnsemble
class _MLP(nn.Module):
    def __init__(self, d_in, out_dim, hidden=128, depth=2, dropout=0.1):
        super().__init__()
        layers, d = [], d_in
        for _ in range(depth-1):
            layers += [nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(dropout)]
            d = hidden
        layers += [nn.Linear(d, out_dim)]
        self.net = nn.Sequential(*layers)
    def forward(self, x):  # x: (B, d_in)
        return self.net(x)

class MLPEnsemble(nn.Module):
    def __init__(self, d_in, out_dim, k, hidden=128, depth=2, dropout=0.1):
        super().__init__()
        self.mlps = nn.ModuleList([_MLP(d_in, out_dim, hidden, depth, dropout) for _ in range(k)])
    def forward(self, z):  # z: (B, k, d_in)
        outs = [self.mlps[i](z[:, i, :]) for i in range(len(self.mlps))]
        return torch.stack(outs, dim=1)  # (B, k, C)
    
class MLPHeadShared(nn.Module):
    def __init__(self, d_in, out_dim, hidden=128, depth=2, dropout=0.1):
        super().__init__()
        layers = []
        d = d_in
        for _ in range(depth-1):
            layers += [nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(dropout)]
            d = hidden
        layers += [nn.Linear(d, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, z):            # z: (B, k, d_in)
        B, k, D = z.shape
        y = self.net(z.reshape(B*k, D))
        return y.view(B, k, -1)  
# ---- Multi-head สำหรับ TabM (ใช้ LinearEnsemble แบบตำแหน่ง) ----
class TabMHeads(nn.Module):
    def __init__(self, d_in: int, k: int, head_dims: dict[str, int]):
        super().__init__()
        self.k = k
        self.heads = nn.ModuleDict({
            name: LinearEnsemble(d_in, out_dim, k=k)  # <- ไม่มี d_in/d_out เป็นคีย์เวิร์ด
            for name, out_dim in head_dims.items()
        })
        # self.heads = nn.ModuleDict({
        #     name: MLPHeadShared(d_in, out_dim, hidden=128, depth=2, dropout=0.1)
        #     for name, out_dim in head_dims.items()
        # })
        # self.heads = nn.ModuleDict({
        #     name: MLPEnsemble(d_in, out_dim, k=k, hidden=128, depth=2, dropout=0.1)
        #     for name, out_dim in head_dims.items()
        # })

    def forward(self, z):  # z: (B, k, d_in)
        return {name: head(z) for name, head in self.heads.items()}  # (B, k, C) ต่อหัว

# ---- Backbone + multi-head ของ TabM ----
class TabMBackboneMultiHead(nn.Module):
    def __init__(
        self,
        n_num_features: int,
        cat_cardinalities: list[int] | None,
        head_dims: dict[str, int],
        k: int = 8,
        d_block: int = 256,
        n_blocks: int = 4,
        dropout: float = 0.0,
        start_scaling_init: str = "normal",
        start_scaling_init_chunks=None,  # ใส่ None ถ้าอินพุตเป็นก้อนเดียว
    ):
        super().__init__()
        self.k = k
        self.n_num = n_num_features
        self.cats = list(cat_cardinalities or [])
        d_in = n_num_features + sum(self.cats)  # one-hot cat ภายในโมเดล

        self.ensemble_view = EnsembleView(k=k)
        self.backbone = make_tabm_backbone(
            d_in=d_in,
            d_block=d_block,
            n_blocks=n_blocks,
            dropout=dropout,
            k=k,
            start_scaling_init=start_scaling_init,
            start_scaling_init_chunks=start_scaling_init_chunks,
        )
        self.heads = TabMHeads(d_in=d_block, k=k, head_dims=head_dims)

    def _one_hot_cat(self, x_cat: torch.Tensor | None):
        if x_cat is None or len(self.cats) == 0:
            return None
        oh = [F.one_hot(x_cat[:, i].long(), num_classes=c) for i, c in enumerate(self.cats)]
        return torch.cat(oh, dim=-1).float()

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor | None = None):
        # x_num: (B, n_num), x_cat: (B, n_cat) เป็นดัชนี 0..card-1
        if x_cat is not None and len(self.cats):
            x = torch.cat([x_num, self._one_hot_cat(x_cat)], dim=-1)
        else:
            x = x_num
        x = self.ensemble_view(x)     # (B, k, D)
        z = self.backbone(x)          # (B, k, d_block)
        return self.heads(z)          # dict: name -> (B, k, C)


In [11]:
class AvgFusionBackbone(nn.Module):
    def __init__(self, tabm, ftt, d_tabm, d_ftt, d_out, learn_alpha=False, alpha_init=0.5):
        super().__init__()
        self.tabm = tabm
        self.ftt  = ftt
        self.d_out  = d_out 
        self.proj_tabm = nn.Linear(d_tabm, d_out, bias=False)
        self.proj_ftt  = nn.Linear(d_ftt,  d_out, bias=False)
        self.norm1 = nn.LayerNorm(d_out)
        self.norm2 = nn.LayerNorm(d_out)

        # ถ้าอยากเรียนรู้ alpha ให้เป็นพารามิเตอร์ 0..1 (ผ่าน sigmoid)
        self.alpha = nn.Parameter(torch.tensor(alpha_init)) if learn_alpha else None

        self._last_z1 = None
        self._last_z2 = None

    @staticmethod
    def _first_tensor(obj):
        import torch
        if torch.is_tensor(obj):
            return obj
        if isinstance(obj, (list, tuple)):
            for v in obj:
                if torch.is_tensor(v):
                    return v
        if isinstance(obj, dict):
            # ลำดับคีย์ที่มักใช้เก็บ embedding/feature
            for k in ("emb", "features", "feature", "hidden", "h", "x", "repr", "tokens", "out", "logits"):
                v = obj.get(k, None)
                if torch.is_tensor(v):
                    return v
            # เผื่อค่าด้านใน dict เป็นเทนเซอร์ตรง ๆ
            for v in obj.values():
                if torch.is_tensor(v):
                    return v
        return None

    @staticmethod
    def _call_any_signature(model, x_cont, x_cat):
        # พยายามหลาย signature
        try:
            return model(x_cont, x_cat)
        except TypeError:
            try:
                return model(x_cont)
            except TypeError:
                try:
                    return model(x_num=x_cont, x_cat=x_cat)
                except TypeError:
                    return model(x_cont)

    @staticmethod
    def _get_emb(model, x_cont, x_cat, name="model"):
        # -------- เคส TabM (มี ensemble_view + backbone) --------
        if hasattr(model, "ensemble_view") and hasattr(model, "backbone"):
            dev = next(model.backbone.parameters()).device
            x_cont = x_cont.to(dev)
            if x_cat is not None:
                x_cat = x_cat.to(dev)
            if x_cat is not None and getattr(model, "cats", None):
                x = torch.cat([x_cont, model._one_hot_cat(x_cat)], dim=-1)
            else:
                x = x_cont
            x = model.ensemble_view(x)    # (B, k, D)
            z = model.backbone(x)         # (B, k, d_block)
            return z.mean(dim=1)          # (B, d_block)

        # -------- โมเดลทั่วไป (เช่น FTT) --------
        dev = next(model.parameters()).device
        x_cont = x_cont.to(dev)
        if x_cat is not None:
            x_cat = x_cat.to(dev)

        # 1) ลอง forward ปกติหลายแบบ
        out = AvgFusionBackbone._call_any_signature(model, x_cont, x_cat)
        t = AvgFusionBackbone._first_tensor(out)

        # 2) ถ้ายังไม่ได้ ให้ลองเมธอดที่พบบ่อย
        if t is None:
            for attr in ("forward_features", "encode", "extract", "backbone", "encoder"):
                if hasattr(model, attr):
                    fn = getattr(model, attr)
                    try:
                        out2 = fn(x_cont, x_cat)
                    except TypeError:
                        try:
                            out2 = fn(x_cont)
                        except TypeError:
                            continue
                    t = AvgFusionBackbone._first_tensor(out2)
                    if t is not None:
                        break

        # 3) ถ้ายังไม่ได้อีก ให้แจ้ง error ชัด ๆ
        if t is None:
            raise TypeError(
                f"{name}.forward/encode did not return a Tensor. "
                f"Please expose a feature method (e.g., forward_features) or return logits/features."
            )

        # 4) ถ้าเป็น (B, k, D) ให้เฉลี่ย k
        if t.dim() == 3:
            t = t.mean(dim=1)
        return t
        
    def forward(self, x_cont, x_cat):
        z1 = self._get_emb(self.tabm, x_cont, x_cat, name="tabm")  # (B, d_tabm)
        z2 = self._get_emb(self.ftt,  x_cont, x_cat, name="ftt")   # (B, d_ftt)

        z1 = F.normalize(self.norm1(self.proj_tabm(z1)), dim=-1)
        z2 = F.normalize(self.norm2(self.proj_ftt(z2)), dim=-1)

        if self.alpha is None:
            z = 0.5 * (z1 + z2)
        else:
            a = torch.clamp(self.alpha, 0, 1)
            z = a * z1 + (1 - a) * z2

        self._last_z1, self._last_z2 = z1, z2
        return z  # (B, d_out)



In [12]:

d_model_out_dim = 512
n_cont_features = X_cont_train.shape[1] 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
head_dims = {
    "success_cls": 2,
    # "success_rate_cls": 6,
    "risk_level": 3,
    "days_to_state_change": 4,
    "recommend_category": len(label_encoders["category_group"].classes_),
    "goal_eval": 3,
    # "shortfall_severity_cls": 4,
    "stretch_potential_cls": 3,
}
tabm_enc = TabMBackboneMultiHead(
    n_num_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,    # ถ้าใช้ one-hot ไปแล้ว ให้ใส่ None และไม่ส่ง x_cat ตอน forward
    head_dims=head_dims,
    k=14,
    d_block=d_model_out_dim,
    n_blocks=5,
    dropout=0.3,
    start_scaling_init="normal",            # ค่านิยมใน ref
    start_scaling_init_chunks=None,         # หรือ [n_cont_features, sum(emb_dim)] เพื่อสเกลแยกกลุ่ม
).to(device)
ftt_enc = FTTransformer(
    n_cont_features=n_cont_features,  # ใช้ตามจริงจาก data
    cat_cardinalities=cat_cardinalities,
    d_out=128,
    n_blocks=5,
    d_block=d_model_out_dim,
    attention_n_heads=2,
    attention_dropout=0.25,
    ffn_d_hidden_multiplier=4 / 3,
    ffn_dropout=0.25,
    residual_dropout=0.1,
).to(device)
fuse_backbone = AvgFusionBackbone(tabm_enc, ftt_enc,
                                  d_tabm=512,
                                  d_ftt=128,
                                  d_out=256)

model = MultiHeadWrapper(
    base_model=fuse_backbone,   # <<<<<< ใช้ตัวรวมเป็น backbone
    head_dims=head_dims,
    in_dim=fuse_backbone.d_out,
    d_hidden=64
).to(device)




In [13]:
print("Model is running on:", next(model.parameters()).device)
print("CUDA available:", torch.cuda.is_available())
print(torch.__version__)
print(torch.version.cuda)

Model is running on: cuda:0
CUDA available: True
2.7.1+cu118
11.8


In [14]:
def feature_dropout(x_cont, drop_prob=0.1):
    if drop_prob <= 0: return x_cont
    mask = (torch.rand_like(x_cont) < drop_prob).float()
    return x_cont * (1 - mask)   


In [15]:
ORDINAL_HEADS = {}
def corn_targets(y, num_classes):
    B = y.size(0); K = num_classes
    t = torch.arange(K-1, device=y.device).unsqueeze(0).expand(B, -1)
    return (y.unsqueeze(1) > t).float()

def corn_loss(logits, y, pos_weight=None):
    tgt = corn_targets(y, logits.size(1) + 1)
    return F.binary_cross_entropy_with_logits(
        logits, tgt, pos_weight=pos_weight, reduction="mean"
    )

def corn_predict(logits):
    p = torch.sigmoid(logits)
    return (p > 0.5).sum(dim=1)

In [16]:
SORD_HEADS = {}   
def _sord_targets(y, C, tau=0.75, device=None):
    y = y.view(-1, 1).long()
    cls = torch.arange(C, device=device).view(1, -1)
    dist = (cls - y).abs().float()
    p = torch.softmax(-dist / tau, dim=1)
    return p

_kl = nn.KLDivLoss(reduction='batchmean')

def sord_loss(logits, y, tau=0.75):
    C = logits.size(1)
    p = _sord_targets(y, C, tau=tau, device=logits.device)   # soft target
    return _kl(F.log_softmax(logits, dim=1), p)



In [17]:
class AutomaticWeightedLoss(nn.Module):
    """automatically weighted multi-task loss

    Params：
        num: int，the number of loss
        x: multi-task loss
    Examples：
        loss1=1
        loss2=2
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """
    def __init__(self, num=2):
        super(AutomaticWeightedLoss, self).__init__()
        params = torch.ones(num, requires_grad=True)
        self.params = torch.nn.Parameter(params)

    def forward(self, *x):
        loss_sum = 0
        for i, loss in enumerate(x):
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum

if __name__ == '__main__':
    awl = AutomaticWeightedLoss(2)
    print(awl.parameters())
    
PRIOR_HEADS  = {"stretch_potential_cls", "risk_level"}
HEADS_ORDER  = ["success_cls", "risk_level", "days_to_state_change",
                "recommend_category", "goal_eval", "stretch_potential_cls"]
OTHERS_ORDER = [h for h in HEADS_ORDER if h not in PRIOR_HEADS]

# สร้าง AWL ด้วยจำนวนหัวที่ตรงกับ OTHERS_ORDER
awl = AutomaticWeightedLoss(len(OTHERS_ORDER)).to(device)

<generator object Module.parameters at 0x000001C4A52DB610>


In [18]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, weight: torch.Tensor | None = None, reduction='mean'):
        super().__init__()
        self.gamma = float(gamma)
        self.reduction = reduction
        # เก็บ weight เป็น buffer (ย้าย device ตามโมดูล, ไม่อัปเดตกราด)
        if weight is not None:
            self.register_buffer('weight', weight.float())
        else:
            self.weight = None

    def forward(self, logits: torch.Tensor, target: torch.Tensor):
        logp = F.log_softmax(logits, dim=-1)
        p = logp.exp()
        w = self.weight
        if w is not None and w.device != logits.device:
            w = w.to(logits.device)
        loss = F.nll_loss(((1 - p) ** self.gamma) * logp,
                          target, weight=w, reduction='none')
        return loss.mean() if self.reduction == 'mean' else loss.sum()

# --- คำนวณ class weight จาก label ทั้งชุด (ไม่ต้องอ้าง logits) ---
def class_weights_from_labels(y: torch.Tensor, num_classes: int) -> torch.Tensor:
    y_cpu = y.detach().view(-1).cpu().long()
    cnt = torch.bincount(y_cpu, minlength=num_classes).float()
    inv = cnt.sum() / cnt.clamp_min(1)     # inverse frequency
    w = inv / inv.mean()                    # normalize ให้มีค่าเฉลี่ย ~1
    return w

# ตัวอย่างการสร้าง weight และผูกเข้ากับ loss ฟังก์ชัน
C_risk    = num_classes_map['risk_level']
C_stretch = num_classes_map['stretch_potential_cls']

w_risk    = class_weights_from_labels(y_train_dict['risk_level'],            C_risk)
w_stretch = class_weights_from_labels(y_train_dict['stretch_potential_cls'], C_stretch)


In [19]:
def _repeat_k_targets(target: torch.Tensor, k: int) -> torch.Tensor:
    # target: (B,) -> (B*k,)
    return target.view(-1, 1).repeat(1, k).reshape(-1)

def _tabm_ce_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pred: (B, k, C)
    B, k, C = pred.shape
    return F.cross_entropy(pred.reshape(B*k, C), _repeat_k_targets(target, k).long())

def _tabm_sord_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pred: (B, k, C) , ฟังก์ชัน sord_loss ของคุณคาด (B, C)
    B, k, C = pred.shape
    return sord_loss(pred.reshape(B*k, C), _repeat_k_targets(target, k))

def _tabm_corn_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pred: (B, k, C-1) สำหรับ CORN
    B, k, D = pred.shape
    return corn_loss(pred.reshape(B*k, D), _repeat_k_targets(target, k))

def _ensemble_probs(pred_logits: torch.Tensor) -> torch.Tensor:
    # ถ้าเป็น (B,k,C) -> เฉลี่ย probs ข้าม k, ถ้า (B,C) -> softmax ตรงๆ
    if pred_logits.dim() == 3:
        return pred_logits.softmax(-1).mean(1)  # (B, C)
    return pred_logits.softmax(-1)              # (B, C)

def _ensemble_argmax(pred_logits: torch.Tensor) -> torch.Tensor:
    return _ensemble_probs(pred_logits).argmax(-1)

def _ensemble_corn_predict(pred_logits: torch.Tensor) -> torch.Tensor:
    # ทางง่ายและได้ผลดี: เอา logits เฉลี่ยข้าม k แล้วใช้ corn_predict เดิม
    if pred_logits.dim() == 3:
        return corn_predict(pred_logits.mean(1))
    return corn_predict(pred_logits)


In [20]:

# ==== CONFIG ====
num_epochs     = 50
warmup_epochs  = 5        
patience       = 10
focal_gamma = 1.5
FOCUS_CFG = {
    "stretch_potential_cls": {"max": 12.0, "ramp": 3},  # ดันแรงกว่านิดหน่อย
    "risk_level":            {"max": 12.0, "ramp": 3},  # ค่อยๆ ไล่ (ถ้ายังนิ่งค่อยเพิ่ม max เป็น 8)
}

def get_focus_scale(head: str, epoch: int) -> float:
    # ช่วง warmup ไม่บูสต์
    if epoch < warmup_epochs:
        return 1.0
    cfg = FOCUS_CFG.get(head)
    if cfg is None:  # หัวอื่นๆ ไม่บูสต์
        return 1.0
    t = min(1.0, (epoch - warmup_epochs) / max(1, cfg["ramp"]))
    return 1.0 + (cfg["max"] - 1.0) * t
# ==== Loss dicts ====
loss_fn_warmup = {
    "success_cls":            nn.CrossEntropyLoss(),
    # "success_rate_cls":       nn.CrossEntropyLoss(label_smoothing=0.1),  
    "risk_level":             nn.CrossEntropyLoss(),  
    "days_to_state_change":   nn.CrossEntropyLoss(),
    "recommend_category":     nn.CrossEntropyLoss(),
    "goal_eval":              nn.CrossEntropyLoss(),
    # "shortfall_severity_cls": nn.CrossEntropyLoss(),
    "stretch_potential_cls":  nn.CrossEntropyLoss(),
}

# ช่วง post: คมขึ้นเล็กน้อย และมี class weight สำหรับสองหัวสำคัญ (สมมติ cw มีอยู่แล้ว)
loss_fn_post = {
    "success_cls":            nn.CrossEntropyLoss(label_smoothing=0.1),
    # "success_rate_cls":       nn.CrossEntropyLoss(label_smoothing=0.1),
    "risk_level":             FocalLoss(gamma=2.0, weight=w_risk),
    "days_to_state_change":   nn.CrossEntropyLoss(label_smoothing=0.1),
    "recommend_category":     nn.CrossEntropyLoss(label_smoothing=0.1),
    "goal_eval":              nn.CrossEntropyLoss(label_smoothing=0.1),
    # "shortfall_severity_cls": FocalLoss(gamma=gamma_per_class, weight=w_sf, reduction='mean'),
    "stretch_potential_cls":  FocalLoss(gamma=2.0, weight=w_stretch),
}

head_names = sorted(list(set(list(loss_fn_warmup.keys()) + list(loss_fn_post.keys()))))

# ==== History ====
history = {
    "lr_step": [],
    "lr_epoch": [],
    "train_loss": [],
    "val_loss": [],
    "train_acc": {k: [] for k in head_names},
    "val_acc":   {k: [] for k in head_names},
    "train_f1":  {k: [] for k in head_names},
    "val_f1":    {k: [] for k in head_names},
    "train_precision": {k: [] for k in head_names},
    "val_precision":   {k: [] for k in head_names},
    "train_recall":    {k: [] for k in head_names},
    "val_recall":      {k: [] for k in head_names},
    "train_precision_by_class": {k: [] for k in head_names},
    "val_precision_by_class":   {k: [] for k in head_names},
    "train_recall_by_class":    {k: [] for k in head_names},
    "val_recall_by_class":      {k: [] for k in head_names},
    "train_loss_by_head":      {k: [] for k in head_names},
    "val_loss_by_head":        {k: [] for k in head_names},
    "raw_train_loss_by_head":  {k: [] for k in head_names},
    "raw_val_loss_by_head":    {k: [] for k in head_names},
}

# optimizer = torch.optim.AdamW(
#     [
#         {"params": model.parameters(), "weight_decay": 7.5e-4, "lr": 1e-3},
#         {"params": awl.parameters(),   "weight_decay": 0.0,    "lr": 1e-3},
#     ]
# )
base_lr = 1e-3
wd      = 7.5e-4

stretch_params = [p for n,p in model.named_parameters()
                  if 'stretch_potential_cls' in n and p.requires_grad]
risk_params    = [p for n,p in model.named_parameters()
                  if 'risk_level' in n and p.requires_grad]
other_params   = [p for n,p in model.named_parameters()
                  if ('stretch_potential_cls' not in n)
                  and ('risk_level' not in n)
                  and p.requires_grad]

optimizer = torch.optim.AdamW([
    {"params": other_params,   "lr": base_lr, "weight_decay": wd, "name": "backbone"},
    {"params": stretch_params, "lr": base_lr, "weight_decay": wd, "name": "stretch"},
    {"params": risk_params,    "lr": base_lr, "weight_decay": wd, "name": "risk"},
    {"params": awl.parameters(),"lr": base_lr, "weight_decay": 0.0, "name": "awl"},
])

# ติดค่า base_lr/last_scale เริ่มต้นให้ทุก group
for pg in optimizer.param_groups:
    pg["base_lr"]    = pg["lr"]
    pg["last_scale"] = 1.0


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)


best_val_loss = float("inf")
patience_counter = 0

# with torch.no_grad():
#     cnt = torch.bincount(y_train_dict['shortfall_severity_cls'].cpu(), minlength=4).float()
#     pi_sf = (cnt / cnt.sum()).clamp_min(1e-12)
# log_prior_sf = pi_sf.log().to(device)

# def tau_schedule(epoch, num_epochs, tau_max=0.5, warm_frac=0.3):
#     t = min(1.0, epoch/(warm_frac*num_epochs))
#     return tau_max * t
def ramp_scale(epoch: int, max_scale: float, ramp_epochs: int):
    # ค่อย ๆ ไต่จาก 1.0 -> max_scale ภายใน ramp_epochs
    t = min(1.0, epoch / max(1.0, float(ramp_epochs)))
    return 1.0 + (max_scale - 1.0) * t

# ค่าเริ่มต้น (จูนได้)
STRETCH_MAX  = 12.0   # 6–10
STRETCH_RAMP = 6     # 3–8
RISK_MAX     = 10.0   # 3–6
RISK_RAMP    = 6

# =============== TRAINING LOOP ===============
for epoch in range(num_epochs):

    loss_fn_dict = loss_fn_warmup if epoch < warmup_epochs else loss_fn_post
    
        # --- boost lr เฉพาะพารามิเตอร์ของหัว (ต้น epoch) ---
    stretch_scale = ramp_scale(epoch, STRETCH_MAX, STRETCH_RAMP)
    risk_scale    = ramp_scale(epoch, RISK_MAX,    RISK_RAMP)

    for pg in optimizer.param_groups:
        name = pg.get("name", "")
        if name == "stretch":
            pg["lr"] = pg["base_lr"] * stretch_scale
            pg["last_scale"] = stretch_scale
        elif name == "risk":
            pg["lr"] = pg["base_lr"] * risk_scale
            pg["last_scale"] = risk_scale
        else:
            pg["lr"] = pg["base_lr"]
            pg["last_scale"] = 1.0


    # ========= TRAIN =========
    model.train()
    total_losses_by_head_train = {k: 0.0 for k in loss_fn_dict}
    raw_losses_by_head_train   = {k: 0.0 for k in loss_fn_dict}
    train_acc = {k: [] for k in loss_fn_dict}
    train_f1  = {k: [] for k in loss_fn_dict}
    train_cache_y = {k: [] for k in loss_fn_dict}
    train_cache_p = {k: [] for k in loss_fn_dict}
    n_batches_train = 0

    # ใช้ oversample เฉพาะช่วง warmup
    # train_loader = train_loader_os if (epoch < warmup_epochs) else train_loader_plain
    train_loader = train_loader_plain
    for x_cat, x_cont, y_dict in train_loader:
        n_batches_train += 1
        x_cat, x_cont = x_cat.to(device), x_cont.to(device)
        y_dict = {k: v.to(device) for k, v in y_dict.items()}

        x_cont = feature_dropout(x_cont, drop_prob=0.05)

        preds = model(x_cont, x_cat)

        raw_losses = {}
        for key, pred in preds.items():
            target = y_dict[key]

            if key in ORDINAL_HEADS:  # CORN
                # ---- LOSS ----
                if pred.dim() == 3:  # (B, k, C-1)
                    loss_value = _tabm_corn_loss(pred, target)
                else:                 # (B, C-1)
                    loss_value = corn_loss(pred, target)
                # ---- PRED ----
                pred_class = _ensemble_corn_predict(pred)

            elif key in SORD_HEADS:   # SORD
                # ---- LOSS ----
                if pred.dim() == 3:   # (B, k, C)
                    loss_value = _tabm_sord_loss(pred, target)
                else:                 # (B, C)
                    loss_value = sord_loss(pred, target)
                # ---- PRED ----
                pred_class = _ensemble_argmax(pred)

            else:                     # ปกติ (CrossEntropy)
                # ---- LOSS ----
                if pred.dim() == 3:   # (B, k, C)
                    loss_value = _tabm_ce_loss(pred, target)
                else:                 # (B, C)
                    loss_value = loss_fn_dict[key](pred, target.long())
                # ---- PRED ----
                pred_class = _ensemble_argmax(pred)

            raw_losses[key] = loss_value

            # --- logging (เดิมของคุณ) ---
            raw_losses_by_head_train[key]   += loss_value.item()
            total_losses_by_head_train[key] += loss_value.item()
            tgt = target.detach().cpu()
            prd = pred_class.detach().cpu()
            train_acc[key].append(accuracy_score(tgt, prd))
            train_f1[key].append(f1_score(tgt, prd, average="macro", zero_division=0))
            train_cache_y[key].append(tgt.numpy().ravel())
            train_cache_p[key].append(prd.numpy().ravel())

        # active_heads = [h for h in HEADS_ORDER if h in raw_losses]
        # if epoch < warmup_epochs:
        #     loss = torch.stack([raw_losses[h] for h in active_heads]).sum()
        # else:
        #     loss = awl(*[raw_losses[h] for h in active_heads])
        # scaled_losses = [
        #     raw_losses[h] * get_focus_scale(h, epoch)
        #     for h in active_heads
        # ]

        # if epoch < warmup_epochs:
        #     loss = torch.stack(scaled_losses).sum()
        # else:
        #     loss = awl(*scaled_losses)
        
        awl_losses   = [raw_losses[h] * get_focus_scale(h, epoch) for h in OTHERS_ORDER if h in raw_losses]
        prior_losses = [raw_losses[h] * get_focus_scale(h, epoch) for h in PRIOR_HEADS  if h in raw_losses]

        if epoch < warmup_epochs:
            loss = torch.stack(awl_losses + prior_losses).sum()
        else:
            loss = awl(*awl_losses) + (torch.stack(prior_losses).sum() if len(prior_losses) else 0.0)


        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # --- สรุป train (เดิมของคุณ) ---
    for k in total_losses_by_head_train:
        total_losses_by_head_train[k] /= max(1, n_batches_train)
        raw_losses_by_head_train[k]   /= max(1, n_batches_train)
    total_loss_train = sum(total_losses_by_head_train.values())

    # macro precision/recall per head (เดิมของคุณ)
    for k in loss_fn_dict:
        if len(train_cache_y[k]):
            y_all = np.concatenate(train_cache_y[k])
            p_all = np.concatenate(train_cache_p[k])
            tr_prec = precision_score(y_all, p_all, average="macro", zero_division=0)
            tr_rec  = recall_score(y_all,   p_all, average="macro", zero_division=0)
            history["train_precision"][k].append(float(tr_prec))
            history["train_recall"][k].append(float(tr_rec))
            labels = np.arange(num_classes_map[k]) if k in num_classes_map else None
            tr_prec_cls = precision_score(y_all, p_all, average=None, labels=labels, zero_division=0)
            tr_rec_cls  = recall_score(y_all,   p_all, average=None, labels=labels, zero_division=0)
            history["train_precision_by_class"][k].append(tr_prec_cls.tolist())
            history["train_recall_by_class"][k].append(tr_rec_cls.tolist())
        else:
            history["train_precision"][k].append(np.nan)
            history["train_recall"][k].append(np.nan)
            history["train_precision_by_class"][k].append(None)
            history["train_recall_by_class"][k].append(None)

    # ========= VALIDATION =========
    model.eval()
    total_losses_by_head_val = {k: 0.0 for k in loss_fn_dict}
    raw_losses_by_head_val   = {k: 0.0 for k in loss_fn_dict}
    val_acc = {k: [] for k in loss_fn_dict}
    val_f1  = {k: [] for k in loss_fn_dict}
    val_cache_y = {k: [] for k in loss_fn_dict}
    val_cache_p = {k: [] for k in loss_fn_dict}
    n_batches_val = 0

    # (NEW) เก็บ logits/targets ของ shortfall เพื่อจะเอาไป calibrate ได้ภายหลัง
    sf_logits_all, sf_targets_all = [], []

    with torch.no_grad():
        for x_cat, x_cont, y_dict in val_loader:
            n_batches_val += 1
            x_cat, x_cont = x_cat.to(device), x_cont.to(device)
            y_dict = {k: v.to(device) for k, v in y_dict.items()}

            preds = model(x_cont, x_cat)
            for key, pred in preds.items():
                target = y_dict[key]

                if key in ORDINAL_HEADS:  # CORN
                    # ---- LOSS ----
                    if pred.dim() == 3:  # (B, k, C-1)
                        val_loss_value = _tabm_corn_loss(pred, target)
                    else:                 # (B, C-1)
                        val_loss_value = corn_loss(pred, target)
                    # ---- PRED ----
                    pred_class = _ensemble_corn_predict(pred)

                elif key in SORD_HEADS:   # SORD
                    # ---- LOSS ----
                    if pred.dim() == 3:   # (B, k, C)
                        val_loss_value = _tabm_sord_loss(pred, target)
                    else:                 # (B, C)
                        val_loss_value = sord_loss(pred, target)
                    # ---- PRED ----
                    pred_class = _ensemble_argmax(pred)

                else:                     # ปกติ (CrossEntropy)
                    # ---- LOSS ----
                    if pred.dim() == 3:   # (B, k, C)
                        val_loss_value = _tabm_ce_loss(pred, target)
                    else:                 # (B, C)
                        val_loss_value = loss_fn_dict[key](pred, target.long())
                    # ---- PRED ----
                    pred_class = _ensemble_argmax(pred)

                raw_losses_by_head_val[key]   += val_loss_value.item()
                total_losses_by_head_val[key] += val_loss_value.item()

                tgt = target.cpu()
                prd = pred_class.cpu()
                val_acc[key].append(accuracy_score(tgt, prd))
                val_f1[key].append(f1_score(tgt, prd, average="macro", zero_division=0))
                val_cache_y[key].append(tgt.numpy().ravel())
                val_cache_p[key].append(prd.numpy().ravel())

    for k in total_losses_by_head_val:
        total_losses_by_head_val[k] /= max(1, n_batches_val)
        raw_losses_by_head_val[k]   /= max(1, n_batches_val)
    total_loss_val = sum(total_losses_by_head_val.values())

    for k in loss_fn_dict:
        if len(val_cache_y[k]):
            y_all = np.concatenate(val_cache_y[k])
            p_all = np.concatenate(val_cache_p[k])
            va_prec = precision_score(y_all, p_all, average="macro", zero_division=0)
            va_rec  = recall_score(y_all,   p_all, average="macro", zero_division=0)
            history["val_precision"][k].append(float(va_prec))
            history["val_recall"][k].append(float(va_rec))
            labels = np.arange(num_classes_map[k]) if k in num_classes_map else None
            va_prec_cls = precision_score(y_all, p_all, average=None, labels=labels, zero_division=0)
            va_rec_cls  = recall_score(y_all,   p_all, average=None, labels=labels, zero_division=0)
            history["val_precision_by_class"][k].append(va_prec_cls.tolist())
            history["val_recall_by_class"][k].append(va_rec_cls.tolist())
        else:
            history["val_precision"][k].append(np.nan)
            history["val_recall"][k].append(np.nan)
            history["val_precision_by_class"][k].append(None)
            history["val_recall_by_class"][k].append(None)

    # --- Scheduler (ของเดิม) ---
    focus_heads = ["stretch_potential_cls", "risk_level"]
    if all(h in total_losses_by_head_val for h in focus_heads):
        focus_val = sum(total_losses_by_head_val[h] for h in focus_heads)
    else:
        focus_val = total_loss_val
    scheduler.step(focus_val)
        
    # ให้ ReduceLROnPlateau คุม "ฐาน" แล้วอัปเดต base_lr สำหรับรอบหน้า
    for pg in optimizer.param_groups:
        # lr ตอนนี้ = (base_lr หลัง scheduler) * scale ของ epoch นี้
        # แปลงกลับไปเป็น base_lr เพื่อใช้คูณ scale ใหม่ใน epoch หน้า
        pg["base_lr"] = pg["lr"] / pg.get("last_scale", 1.0)


    # --- log lr/ history/ print (ของเดิม) ---
    last_epoch_train_steps = n_batches_train
    if last_epoch_train_steps > 0 and len(history["lr_step"]) >= last_epoch_train_steps:
        history["lr_epoch"].append(float(np.mean(history["lr_step"][-last_epoch_train_steps:])))
    else:
        history["lr_epoch"].append(float(optimizer.param_groups[0]['lr']))

    history["train_loss"].append(total_loss_train)
    history["val_loss"].append(total_loss_val)
    for k in head_names:
        history["train_loss_by_head"][k].append(total_losses_by_head_train.get(k, np.nan))
        history["val_loss_by_head"][k].append(total_losses_by_head_val.get(k, np.nan))
        history["raw_train_loss_by_head"][k].append(raw_losses_by_head_train.get(k, np.nan))
        history["raw_val_loss_by_head"][k].append(raw_losses_by_head_val.get(k, np.nan))
        history["train_acc"][k].append(float(np.mean(train_acc.get(k, []))) if len(train_acc.get(k, []))>0 else np.nan)
        history["val_acc"][k].append(float(np.mean(val_acc.get(k, []))) if len(val_acc.get(k, []))>0 else np.nan)
        history["train_f1"][k].append(float(np.mean(train_f1.get(k, []))) if len(train_f1.get(k, []))>0 else np.nan)
        history["val_f1"][k].append(float(np.mean(val_f1.get(k, []))) if len(val_f1.get(k, []))>0 else np.nan)

    # with torch.no_grad():
    #     inv_sigma2 = 1.0 / (awl.params**2 + 1e-8)
    #     awl_weights = {h: float(inv_sigma2[i].item()) for i, h in enumerate(HEADS_ORDER)}
    #     print(f"[epoch {epoch:02d}] AWL inverse-var weights: {awl_weights}")
    with torch.no_grad():
        inv_sigma2 = (1.0 / (awl.params**2 + 1e-8)).view(-1)
        K = min(len(inv_sigma2), len(OTHERS_ORDER))
        awl_weights = {h: float(inv_sigma2[i].item()) for i, h in enumerate(OTHERS_ORDER[:K])}
        for h in PRIOR_HEADS:
            awl_weights[h] = "(not in AWL; LR-ramped)"
        print(f"[epoch {epoch:02d}] AWL inverse-var weights: {awl_weights}")


    print(f"Epoch {epoch:02d} | Train Loss: {total_loss_train:.4f} | Val Loss: {total_loss_val:.4f}")
    for k in loss_fn_dict:
        ta = np.mean(train_acc[k]) if len(train_acc[k]) else float('nan')
        va = np.mean(val_acc[k])   if len(val_acc[k])   else float('nan')
        tf = np.mean(train_f1[k])  if len(train_f1[k])  else float('nan')
        vf = np.mean(val_f1[k])    if len(val_f1[k])    else float('nan')
        tp = history["train_precision"][k][-1]
        vp = history["val_precision"][k][-1]
        tr = history["train_recall"][k][-1]
        vr = history["val_recall"][k][-1]
        print(
            f"  - {k:20s} loss: {total_losses_by_head_train[k]:.4f} | "
            f"val_loss: {total_losses_by_head_val[k]:.4f} "
            f"raw_loss: {raw_losses_by_head_train[k]:.4f} | raw_val: {raw_losses_by_head_val[k]:.4f} "
            f"| train_acc: {ta:.4f} | val_acc: {va:.4f} "
            f"| train_F1: {tf:.4f} | val_F1: {vf:.4f} "
            f"| train_P: {tp:.4f} | val_P: {vp:.4f} | train_R: {tr:.4f} | val_R: {vr:.4f}"
        )

    # print("shortfall val precision by class:", history["val_precision_by_class"]["stretch_potential_cls"][-1])
    # print("shortfall val recall by class:",    history["val_recall_by_class"]["stretch_potential_cls"][-1])


[epoch 00] AWL inverse-var weights: {'success_cls': 1.0, 'days_to_state_change': 1.0, 'recommend_category': 1.0, 'goal_eval': 1.0, 'risk_level': '(not in AWL; LR-ramped)', 'stretch_potential_cls': '(not in AWL; LR-ramped)'}
Epoch 00 | Train Loss: 6.8440 | Val Loss: 6.0810
  - success_cls          loss: 0.6365 | val_loss: 0.6029 raw_loss: 0.6365 | raw_val: 0.6029 | train_acc: 0.6488 | val_acc: 0.6858 | train_F1: 0.5159 | val_F1: 0.6376 | train_P: 0.6451 | val_P: 0.6745 | train_R: 0.5718 | val_R: 0.6370
  - risk_level           loss: 1.0706 | val_loss: 1.0444 raw_loss: 1.0706 | raw_val: 1.0444 | train_acc: 0.4140 | val_acc: 0.4546 | train_F1: 0.3583 | val_F1: 0.4545 | train_P: 0.4095 | val_P: 0.4631 | train_R: 0.4140 | val_R: 0.4544
  - days_to_state_change loss: 1.0785 | val_loss: 0.7711 raw_loss: 1.0785 | raw_val: 0.7711 | train_acc: 0.4920 | val_acc: 0.6197 | train_F1: 0.4360 | val_F1: 0.5753 | train_P: 0.4695 | val_P: 0.6145 | train_R: 0.4921 | val_R: 0.6206
  - recommend_category   

KeyboardInterrupt: 

In [None]:

plt.plot(history["train_loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Val Loss")
plt.legend()
plt.title("Total Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

for head in history["train_acc"]:
    plt.figure()
    plt.plot(history["train_acc"][head], label=f"{head} Train Acc")
    plt.plot(history["val_acc"][head], label=f"{head} Val Acc")
    plt.legend()
    plt.title(f"{head} Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.show()

for head in history["train_f1"]:
    plt.figure()
    plt.plot(history["train_f1"][head], label=f"{head} Train F1")
    plt.plot(history["val_f1"][head], label=f"{head} Val F1")
    plt.legend()
    plt.title(f"{head} Macro F1")
    plt.xlabel("Epoch")
    plt.ylabel("F1-score")
    plt.show()
