In [1]:
import functools
import json
import logging

import config
import lightning as L
import numpy as np
import polars as pl
import torch
from preprocess import fe, load_data, preprocess
from pytorch_tabular.models.common.layers import ODST
from scipy.stats import rankdata
from sklearn.preprocessing import minmax_scale
from torch import nn
from torchmetrics import MeanMetric

from src.customs.fold import add_kfold
from src.customs.metrics import Metric, metric
from src.trainer.tabular.simple import single_inference_fn, single_train_fn

# Configure the root logger at INFO level, then obtain this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:
# Pre-computed target columns (every column prefixed with "t_") are read from this CSV
# and joined onto the train/test frame by ID.
TARGET_CSV_PATH = "./data/extr_output/101/1/101.csv"

train_test_df = load_data(config=config, valid_ratio=config.VALID_RATIO)
target_df = pl.read_csv(TARGET_CSV_PATH).with_columns(
    pl.col(config.SURVIVAL_TIME_COL).log().alias("t_log_efs_time"),
)
# min-max scaled survival time in [0, 1]; used as the ordering target of the ranking loss
target_df = target_df.with_columns(
    pl.Series("t_efs_time_scaled", minmax_scale(target_df[config.SURVIVAL_TIME_COL], feature_range=(0, 1)))
)

target_cols = [x for x in target_df.columns if x.startswith("t_")]
train_test_df = train_test_df.join(
    target_df.select([config.ID_COL, *target_cols]),
    on=config.ID_COL,
    how="left",
)
# Register the target columns as meta columns. Keep META_COLS an ordered, de-duplicated list:
# a set would make the column order of every `.select(config.META_COLS)` (and of the saved
# CSVs) depend on per-process string hash randomization. Re-running the cell is a no-op.
config.META_COLS = list(dict.fromkeys([*config.META_COLS, *target_cols]))

features_df = fe(config=config, train_test_df=train_test_df)
features_df = preprocess(config=config, features_df=features_df)
feature_names = sorted([x for x in features_df.columns if x.startswith(config.FEATURE_PREFIX)])
cat_features = [x for x in feature_names if x.startswith(f"{config.FEATURE_PREFIX}c_")]
cont_features = [x for x in feature_names if x.startswith(f"{config.FEATURE_PREFIX}n_")]


In [3]:
class CatEmbedding(nn.Module):
    """
    Per-column embedding of categorical features followed by a projection MLP.

    Every categorical column gets its own embedding table. The per-column vectors are
    concatenated and mapped through Linear -> GELU -> Linear to ``projection_dim``.
    """

    def __init__(self, projection_dim: int, categorical_cardinality: list[int], embedding_dim: int):
        """
        projection_dim: output width of the projection MLP (also its hidden width).
        categorical_cardinality: number of embedding rows for each categorical column, in column order.
        embedding_dim: length of the embedding vector produced for each column.
        """
        super().__init__()
        n_columns = len(categorical_cardinality)
        # one table per column, created in column order (keeps seeded initialization reproducible)
        self.embeddings = nn.ModuleList(nn.Embedding(size, embedding_dim) for size in categorical_cardinality)
        self.projection = nn.Sequential(
            nn.Linear(n_columns * embedding_dim, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, projection_dim),
        )

    def forward(self, x_cat):
        """
        Embed column i of ``x_cat`` (indexed as ``x_cat[:, i]``) with table i, concatenate
        the results along dim 1 and return their projection.
        """
        per_column = []
        for column_idx, table in enumerate(self.embeddings):
            per_column.append(table(x_cat[:, column_idx]))
        stacked = torch.cat(per_column, dim=1)
        return self.projection(stacked)


class Net(nn.Module):
    """
    Two-headed network over categorical and continuous features.

    Categorical columns go through ``CatEmbedding``. The result is concatenated with the
    continuous features and passed through dropout and an ODST -> BatchNorm -> Dropout
    trunk. Two independent linear heads (main and auxiliary) each emit one value per row.
    """

    def __init__(
        self,
        continuous_dim: int,
        categorical_cardinality: list[int],
        embedding_dim: int,
        projection_dim: int,
        hidden_dim: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        # sub-modules are created in the same order as before so seeded init is unchanged
        self.embeddings = CatEmbedding(projection_dim, categorical_cardinality, embedding_dim)
        trunk_in_features = projection_dim + continuous_dim
        self.mlp = nn.Sequential(
            ODST(trunk_in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
        )
        self.main_out = nn.Linear(hidden_dim, 1)
        self.aux_out = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self._init_linear_weights()

    def _init_linear_weights(self):
        """Xavier-normal weights and zero biases for every nn.Linear in the module tree."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x_cat, x_cont, **kwargs):
        """
        Return ``(main_output, aux_output)``, each with one column per row.

        Extra keyword arguments (e.g. targets carried in the batch dict) are accepted and ignored.
        """
        features = torch.cat([self.embeddings(x_cat), x_cont], dim=1)
        hidden = self.mlp(self.dropout(features))
        return self.main_out(hidden), self.aux_out(hidden)


@functools.lru_cache
def combinations(N, device=None):
    """
    Return every index pair (i, j) with 0 <= i < j < N as a long tensor of shape (N*(N-1)/2, 2).

    Results are memoized with functools.lru_cache (default maxsize 128), keyed on the
    call arguments, so callers must not modify the returned tensor in place.

    :param N: number of indices.
    :param device: device of the returned tensor. When None, CUDA is used if available,
        otherwise CPU, so the helper also works on CPU-only machines.
    :return: long tensor of index pairs; shape (0, 2) when N < 2.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    pairs = torch.combinations(torch.arange(N), r=2)
    # reshape guarantees a 2-column tensor even when there are no pairs (N < 2)
    return pairs.reshape(-1, 2).to(device)


In [4]:
class LitModule(L.LightningModule):
    """
    LightningModule around a two-headed ``net``.

    Training objective: pairwise margin ranking loss on the main head (``calc_main_loss``)
    plus ``aux_weight`` times the MSE between the auxiliary head and ``batch["aux_target"]``.
    During validation it also collects predictions and, at the end of each epoch, scores the
    main head, the auxiliary head and their rank blend with the project ``metric`` helper.
    """

    def __init__(
        self,
        net: nn.Module,
        optimizer: torch.optim.Optimizer,  # partial
        scheduler: torch.optim.lr_scheduler,  # partial
        aux_weight: float = 0.1,
        margin: float = 0.5,
    ):
        """
        :param net: model called as ``net(x_cat=..., x_cont=...)`` returning (main_output, aux_output).
        :param optimizer: partial optimizer factory, called with ``params=`` in configure_optimizers.
        :param scheduler: partial LR-scheduler factory called with ``optimizer=``, or None.
        :param aux_weight: weight of the auxiliary MSE loss in the total loss.
        :param margin: hinge margin of the pairwise ranking loss.
        """
        super().__init__()
        self.save_hyperparameters(logger=False)  # to use self.hparams

        self.model = net
        # per-batch validation outputs; consumed and cleared in on_validation_epoch_end
        self.results = []

        # running means of the losses across batches
        self.train_main_loss = MeanMetric()
        self.train_loss = MeanMetric()

        self.val_main_loss = MeanMetric()
        self.val_loss = MeanMetric()

    def forward(self, batch):
        """Run the wrapped net on a batch dict; return (main_pred, aux_pred) with dim 1 squeezed."""
        x, x_aux = self.model(x_cat=batch["x_cat"], x_cont=batch["x_cont"])
        return x.squeeze(1), x_aux.squeeze(1)

    def model_step(self, batch):
        """Compute predictions and the main, auxiliary and combined losses for one batch."""
        y_pred, aux_pred = self.forward(batch)
        main_loss = self.calc_main_loss(time=batch["main_target"], event=batch["event"], y_pred=y_pred)
        aux_loss = nn.functional.mse_loss(aux_pred, batch["aux_target"], reduction="mean")
        loss = main_loss + (aux_loss * self.hparams.aux_weight)
        return {
            "loss": loss,
            "main_loss": main_loss,
            "aux_loss": aux_loss,
            "y_pred": y_pred,
            "aux_pred": aux_pred,
        }

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Compute and log the training losses for one batch.

        :param batch: dict of collated tensors (x_cat, x_cont, time, event, main_target,
            aux_target, race_group) as produced by the training dataset.
        :param batch_idx: The index of the current batch.
        :return: the combined loss (main + aux_weight * aux) to optimize.
        """
        result = self.model_step(batch)

        # update and log metrics
        self.train_main_loss(result["main_loss"])
        self.train_loss(result["loss"])
        self.log("train/main_loss", self.train_main_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=False)

        return result["loss"]

    def calc_main_loss(self, time, event, y_pred):
        """Pairwise margin ranking (hinge) loss over comparable pairs of the batch.

        The main head is trained as a risk score. For every comparable pair, the sample
        with the shorter time should get the higher prediction by at least ``hparams.margin``.

        :param time: per-sample times used only for ordering (any monotonic transform works).
        :param event: per-sample event indicator (1 = event observed, 0 = censored).
        :param y_pred: per-sample main-head predictions.
        :return: mean hinge loss over valid pairs (0 when the batch has no valid pair).
        """
        N = time.shape[0]
        # all (i, j) pairs with i < j; cached per batch size, placed on the batch's device
        pairs = combinations(N).to(time.device)

        # keep only candidate pairs in which at least one sample had an event
        pairs = pairs[(event[pairs[:, 0]] == 1) | (event[pairs[:, 1]] == 1)]

        left_index, right_index = pairs[:, 0], pairs[:, 1]
        time_left, time_right = time[left_index], time[right_index]
        event_left, event_right = event[left_index], event[right_index]
        y_pred_left, y_pred_right = y_pred[left_index], y_pred[right_index]

        # ±1 pair label: +1 when left outlived right, -1 otherwise. This must be a comparison;
        # truncating the time difference with .int() does not yield a ±1 label.
        y = 2 * (time_left > time_right).int() - 1
        diff = y_pred_right - y_pred_left
        # y=+1 -> zero loss once y_pred_right >= y_pred_left + margin (shorter survival scores higher)
        loss = nn.functional.relu(-y * (diff) + self.hparams.margin)

        # mask of pairs that enter the loss (drops non-comparable censored pairs)
        mask = self._get_mask(
            time_left=time_left,
            time_right=time_right,
            event_left=event_left,
            event_right=event_right,
        )
        # mean over valid pairs; clamp avoids 0/0 = NaN when no pair is comparable
        loss = (loss.double() * (mask.double())).sum() / mask.sum().clamp(min=1)
        return loss

    def _get_mask(self, time_left, time_right, event_left, event_right):
        """Return a boolean mask that is True for pairs whose ordering is known (comparable)."""
        # Case 1: left's time >= right's time, left had the event but right is censored,
        # so it is unknown whether right outlived left.
        left_outlived = time_left >= time_right
        left_1_right_0 = (event_left == 1) & (event_right == 0)

        # Case 2: right's time >= left's time, right had the event but left is censored.
        right_outlived = time_right >= time_left
        right_1_left_0 = (event_right == 1) & (event_left == 0)

        # Combine the non-comparable cases
        mask = (left_outlived & left_1_right_0) | (right_outlived & right_1_left_0)
        return ~mask  # Invert the mask to get the valid pairs

    def validation_step(self, batch, batch_idx):
        """Log validation losses and store predictions plus time/event/race_group for epoch-end scoring."""
        result = self.model_step(batch)

        # update and log metrics
        self.val_main_loss(result["main_loss"])
        self.val_loss(result["loss"])
        self.log("val/main_loss", self.val_main_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=False)

        # store time, event, y_pred, aux_pred for c-index calculation
        self.results.append(
            {
                **{k: v for k, v in result.items() if k in ["y_pred", "aux_pred"]},
                **{k: v for k, v in batch.items() if k in ["time", "event", "race_group"]},
            }
        )

    def calc_metric(
        self,
        time: np.ndarray,
        event: np.ndarray,
        y_pred: np.ndarray,
        aux_pred: np.ndarray,
        blend_pred: np.ndarray,
        race_group: np.ndarray,
    ):
        """Score the main, auxiliary and blended predictions with the project ``metric``."""
        score_y = metric(y_time=time, y_event=event, y_pred=y_pred, race_group=race_group)
        score_aux = metric(y_time=time, y_event=event, y_pred=aux_pred, race_group=race_group)
        score_blend = metric(y_time=time, y_event=event, y_pred=blend_pred, race_group=race_group)
        return score_y, score_aux, score_blend

    def on_validation_epoch_end(self):
        """
        At the end of the validation epoch, compute and log the metric (a concordance index,
        per the original description) for the main head, the aux head and their blend.
        """
        if not self.results:
            # nothing was collected (e.g. no validation batches ran)
            return

        time = np.concatenate([r["time"].cpu().numpy() for r in self.results])
        event = np.concatenate([r["event"].cpu().numpy() for r in self.results])
        y_pred = np.concatenate([r["y_pred"].cpu().numpy() for r in self.results])
        aux_pred = np.concatenate([r["aux_pred"].cpu().numpy() for r in self.results])
        race_group = np.concatenate([r["race_group"].cpu().numpy() for r in self.results])

        # Rank-average ensemble: summing the two rank vectors orders rows exactly like the mean
        # of normalized ranks would, so no explicit normalization is needed.
        # NOTE(review): assumes both heads score rows in the same direction — TODO confirm for the aux target.
        blend_pred = np.sum([rankdata(y_pred), rankdata(aux_pred)], axis=0)

        score_y, score_aux, score_blend = self.calc_metric(
            time=time,
            event=event,
            y_pred=y_pred,
            aux_pred=aux_pred,
            blend_pred=blend_pred,
            race_group=race_group,
        )
        self.log("val/score", score_y, on_epoch=True, prog_bar=True, logger=True)
        self.log("val/aux_score", score_aux, on_epoch=True, prog_bar=True, logger=True)
        self.log("val/blend_score", score_blend, on_epoch=True, prog_bar=True, logger=True)

        self.results = []

    def configure_optimizers(self):
        """Build the optimizer and, when given, an epoch-level scheduler monitoring ``val/loss``."""
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


class TrainDataset(torch.utils.data.Dataset):
    """
    Row-wise dataset over a polars DataFrame, used for training and validation.

    Each item is a dict holding one row's categorical and continuous feature vectors,
    survival time, event flag, main and auxiliary targets, and race-group value.
    """

    def __init__(
        self,
        df: pl.DataFrame,
        cat_feature_cols: list[str],
        cont_feature_cols: list[str],
        main_target_col: str,
        aux_target_col: str,
        time_col: str,
        event_col: str,
        race_group_col: str,
    ):
        # feature matrices: one row per sample, one column per feature
        self.cat_features = df.select(cat_feature_cols).to_numpy()
        self.cont_features = df.select(cont_feature_cols).to_numpy()
        # per-sample columns, materialized once so __getitem__ is plain indexing
        (
            self.time,
            self.event,
            self.main_target,
            self.aux_target,
            self.race_group,
        ) = (df[col].to_numpy() for col in (time_col, event_col, main_target_col, aux_target_col, race_group_col))

    def __len__(self):
        return self.time.shape[0]

    def __getitem__(self, idx):
        # (output key, source array, tensor dtype); None keeps torch's inferred dtype
        fields = (
            ("x_cat", self.cat_features, torch.long),
            ("x_cont", self.cont_features, torch.float),
            ("time", self.time, torch.float),
            ("event", self.event, torch.long),
            ("main_target", self.main_target, torch.float),
            ("aux_target", self.aux_target, torch.float),
            ("race_group", self.race_group, None),
        )
        return {key: torch.tensor(values[idx], dtype=dtype) for key, values, dtype in fields}


class TestDataset(torch.utils.data.Dataset):
    """
    Feature-only row-wise dataset over a polars DataFrame, used for inference.

    Each item is a dict with one row's categorical and continuous feature vectors
    and its race-group value.
    """

    def __init__(
        self,
        df: pl.DataFrame,
        cat_feature_cols: list[str],
        cont_feature_cols: list[str],
        race_group_col: str,
    ):
        self.cat_features = df.select(cat_feature_cols).to_numpy()
        self.cont_features = df.select(cont_feature_cols).to_numpy()
        self.race_group = df[race_group_col].to_numpy()

    def __len__(self):
        # row count of the categorical feature matrix
        return self.cat_features.shape[0]

    def __getitem__(self, idx):
        x_cat = torch.tensor(self.cat_features[idx], dtype=torch.long)
        x_cont = torch.tensor(self.cont_features[idx], dtype=torch.float)
        race_group = torch.tensor(self.race_group[idx])
        return {"x_cat": x_cat, "x_cont": x_cont, "race_group": race_group}


def inference_fn(dataloader, model, device):
    """
    Run ``model`` over ``dataloader`` without gradients and return its first output, flattened.

    :param dataloader: iterable of dict batches of tensors. Every entry is moved to ``device``
        and passed to ``model`` as a keyword argument.
    :param model: module whose forward returns an indexable result; element 0 (the main
        output) is kept. The model is moved to ``device`` and switched to eval mode.
    :param device: device used for both the model and the batches.
    :return: 1-D numpy array of predictions; empty when the dataloader yields no batches.
    """
    # make sure parameters and inputs live on the same device
    model = model.to(device)
    model.eval()
    predictions = []

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)[0]  # multi-output model: keep the main head
            predictions.append(outputs.cpu())

    if not predictions:
        # torch.cat would raise on an empty list
        return np.empty(0, dtype=np.float32)
    return torch.cat(predictions, dim=0).numpy().reshape(-1)

In [None]:
BATCH_SIZE = 2048
MAX_EPOCHS = 32
MAIN_TARGET_COL = "t_efs_time_scaled"  # ordering target of the pairwise ranking loss (main head)
AUX_TARGET_COL = "t_kmf"  # MSE target of the auxiliary head
RACE_GROUP_COL = "f_c_race_group"
# Device for validation inference; falls back to CPU instead of crashing without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

va_result_df, va_scores = pl.DataFrame(), {}

# Embedding sizes from the whole features_df (all folds): n_unique() + 1 rows per column.
# NOTE(review): the +1 presumably reserves a spare code — confirm against the encoding in `preprocess`.
categorical_cardinality = [features_df[f].n_unique() + 1 for f in cat_features]


def build_fold_dataloader(df: pl.DataFrame, train: bool) -> torch.utils.data.DataLoader:
    """
    Build a DataLoader over ``df`` with the feature/target columns configured above.

    train=True shuffles and drops the last incomplete batch (constant batch size, no single-row
    batches for BatchNorm). train=False keeps row order and every row, so predictions line up
    with ``df``.
    """
    dataset = TrainDataset(
        df=df,
        cat_feature_cols=cat_features,
        cont_feature_cols=cont_features,
        main_target_col=MAIN_TARGET_COL,
        aux_target_col=AUX_TARGET_COL,
        time_col=config.SURVIVAL_TIME_COL,
        event_col=config.EVENT_COL,
        race_group_col=RACE_GROUP_COL,
    )
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=train, drop_last=train)


for seed in config.SEEDS:
    L.seed_everything(seed)
    name = f"pair_{seed}"
    i_features_df = add_kfold(
        features_df,
        n_splits=config.N_SPLITS,
        random_state=seed,
        fold_col=config.FOLD_COL,
    )

    _va_result_df, _scores = pl.DataFrame(), {}
    for i_fold in i_features_df[config.FOLD_COL].unique().sort():
        fold_name = f"fold_{i_fold:02}"
        print(f"🚀 Start training: {name} - {fold_name}")
        tr_df = i_features_df.filter(pl.col(config.FOLD_COL) != i_fold)
        va_df = i_features_df.filter(pl.col(config.FOLD_COL) == i_fold)
        tr_dataloader = build_fold_dataloader(tr_df, train=True)
        va_dataloader = build_fold_dataloader(va_df, train=False)

        # Create the model
        net = Net(
            continuous_dim=len(cont_features),
            categorical_cardinality=categorical_cardinality,  # full cardinality
            embedding_dim=16,
            projection_dim=24,
            hidden_dim=64,
            dropout=0,
        )
        optimizer = functools.partial(torch.optim.AdamW, lr=0.001, weight_decay=0)
        scheduler = functools.partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.1, patience=10)

        lit_module = LitModule(
            net=net,
            optimizer=optimizer,
            scheduler=scheduler,
            aux_weight=1,
            margin=0.2,
        )

        output_dir = config.OUTPUT_DIR / name / fold_name
        trainer = L.Trainer(
            # accelerator="cpu",
            max_epochs=MAX_EPOCHS,
            callbacks=[
                L.pytorch.callbacks.ModelCheckpoint(
                    monitor="val/score",
                    dirpath=output_dir.as_posix(),
                    save_top_k=1,
                    mode="max",
                    enable_version_counter=False,
                    auto_insert_metric_name=False,
                    filename="model",
                    save_weights_only=True,
                ),
                L.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
                L.pytorch.callbacks.TQDMProgressBar(),
                # L.pytorch.callbacks.EarlyStopping(monitor="val/score", patience=10, mode="max"),
            ],
            logger=L.pytorch.loggers.CSVLogger(output_dir.as_posix(), name="logs", version="latest"),
        )

        # train
        trainer.fit(lit_module, tr_dataloader, va_dataloader)

        # keep only the best weights (plain state_dict) for inference, overwriting the Lightning checkpoint
        best_model = LitModule.load_from_checkpoint(output_dir / "model.ckpt").model
        torch.save(best_model.state_dict(), output_dir / "model.ckpt")

        # validation (va_dataloader is unshuffled, so predictions align with va_df rows)
        val_preds = inference_fn(dataloader=va_dataloader, model=best_model, device=DEVICE)
        i_va_result_df = va_df.select(config.META_COLS).with_columns(pl.Series("pred", val_preds))
        i_score = Metric()(input_df=i_va_result_df)
        print(f"✅ {name} - {fold_name} - score: {i_score}")
        _scores[fold_name] = i_score
        _va_result_df = pl.concat([_va_result_df, i_va_result_df], how="diagonal_relaxed")

    # save scores
    with open(config.OUTPUT_DIR / name / "va_scores.json", "w") as f:
        json.dump(_scores, f, indent=4)

    va_scores[name] = _scores
    va_result_df = pl.concat([va_result_df, _va_result_df], how="diagonal_relaxed")


# ------------------------------
# final score: average each ID's out-of-fold prediction across seeds
# ------------------------------
va_result_agg_df = (
    va_result_df.group_by(config.ID_COL)
    .agg(pl.col("pred").mean())
    .sort(config.ID_COL)
    .join(train_test_df.select(config.META_COLS), on=config.ID_COL, how="left")
)
final_score = Metric()(input_df=va_result_agg_df)
logger.info(f"✅ final score: {final_score}")
va_scores["final"] = final_score

# save
va_result_agg_df.write_csv(f"{config.OUTPUT_DIR}/va_result.csv")
with open(f"{config.OUTPUT_DIR}/va_scores.json", "w") as f:
    json.dump(va_scores, f, indent=4)