In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys

dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)

os.chdir('..')

%load_ext autoreload
%autoreload

In [2]:
from pathlib import Path

import pandas as pd
import numpy as np

import torch

from hydra import initialize, compose
from hydra.utils import instantiate

from ptls.preprocessing import PandasDataPreprocessor
from ptls.frames import PtlsDataModule

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split

from src.coles import CustomColesDataset, CustomColesValidationDataset, CustomCoLES
from src.local_validation import LocalValidationModel

In [3]:
from src.pooling import PoolingModel

# Example of usage with churn dataset


In [4]:
DATASET = "age"

# Compose the Hydra config whose name is derived from the chosen dataset.
with initialize(config_path="../config", version_base=None):
    cfg = compose(config_name=f"config_{DATASET}")

# Shortcuts to the two top-level config groups.
cfg_preprop, cfg_model = cfg["dataset"], cfg["model"]

In [5]:
# Show the data directory and training file name stored in the dataset config.
cfg["dataset"]["dir_path"], cfg["dataset"]["train_file_name"]

('data/preprocessed_new', 'age.parquet')

In [23]:
# Read the raw transactions table from the configured data directory and
# preview its first rows.
data_dir = Path(cfg["dataset"]["dir_path"])
df = pd.read_parquet(data_dir / "raif_with_users_categories_10k_users.parquet")
df.head()

Unnamed: 0,timestamp,amount,mcc_code,user_id,holiday_target,weekend_target,global_target,categorycode,gender,age,married_,residenttype
43850,2019-11-18,1081005.0,187,3955,0,0,0,50,F,33,not_married,R
43851,2019-11-20,374827.0,187,3955,0,0,0,50,F,33,not_married,R
43852,2019-11-18,250000.0,187,3955,0,0,0,50,F,33,not_married,R
43853,2019-10-01,200000.0,281,3955,0,0,0,50,F,33,not_married,R
43854,2019-12-18,164100.0,231,3955,0,0,0,50,F,33,not_married,R


In [24]:
# Turn the flat transactions table into per-user records.
preprocessor = PandasDataPreprocessor(
    col_id="user_id",
    col_event_time="timestamp",
    event_time_transformation="dt_to_timestamp",
    cols_category=["mcc_code", "married_", "gender", "categorycode", "residenttype", "age"],
    category_transformation="frequency",
    cols_numerical=["amount"],
    # user-level attributes, taken once per user (ptls `cols_first_item`)
    cols_first_item=["global_target", "married_", "gender", "categorycode", "residenttype", "age"],
    return_records=True,
)

dataset = preprocessor.fit_transform(df)

# Fixed seed: without it the train/validation membership changes on every run,
# so metrics are not comparable and a reloaded checkpoint may be evaluated on
# users it was trained on.
train, val = train_test_split(dataset, test_size=.2, random_state=42)

In [25]:
# Build the CoLES model from the config.
model: CustomCoLES = instantiate(cfg_model["model"])
# Presumably switches the sequence encoder to one embedding per sequence
# (ptls `is_reduce_sequence`) — TODO confirm.
model.seq_encoder.is_reduce_sequence = True

# Restore the saved weights. map_location="cpu" lets a checkpoint written from a
# GPU run load on a machine without CUDA; load_state_dict then copies the
# values into the model's own parameters.
model.load_state_dict(
    torch.load("saved_models/coles_raif_default.pth", map_location="cpu")
)

<All keys matched successfully>

In [28]:
# Collect the encoded "gender" value of every training record and list the
# distinct codes.
ax = [record["gender"] for record in train]

np.unique(ax)

array([1, 2])

In [19]:
from ptls.frames.coles.split_strategy import NoSplit
from ptls.frames.coles import ColesDataset

from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter

min_len = 20
col_time = "event_time"

# Wrap both splits in in-memory datasets with a sequence-length filter
# parameterised by min_len.
data_train, data_val = (
    MemoryMapDataset(records, [SeqLenFilter(min_len)]) for records in (train, val)
)

# CoLES datasets that use the NoSplit strategy.
train_data_no_split, val_data_no_split = (
    ColesDataset(memory_data, NoSplit(), col_time)
    for memory_data in (data_train, data_val)
)

# The validation dataset doubles as the test dataset here.
datamodule_no_split = PtlsDataModule(
    train_data=train_data_no_split,
    valid_data=val_data_no_split,
    test_data=val_data_no_split,
    train_batch_size=4,
    valid_batch_size=4,
    test_batch_size=4,
)

In [20]:
# Pull one (x, y) batch from the no-split training dataloader for inspection.
x, y = next(iter(datamodule_no_split.train_dataloader()))

In [21]:
# List the feature names carried in the batch payload.
x.payload.keys()

dict_keys(['holiday_target', 'weekend_target', 'event_time', 'mcc_code', 'amount'])

In [18]:
# Display the second element of the sampled batch.
y

tensor([0])

In [72]:
# Print the first "gender" value of the first sequence in every training batch.
# NOTE(review): the payload keys displayed a few cells above do not include
# "gender", so this recorded output presumably comes from a different kernel
# state — re-run on a fresh kernel to confirm.
for (x, y) in datamodule_no_split.train_dataloader():
    print(x.payload["gender"][0][0])


tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(2)
tensor(1)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(1)
tensor(1)
tensor(2)
tensor(1)
tensor(2)


In [64]:
# Pull one (x, y) batch from the training datamodule.
# NOTE(review): `train_datamodule` is not defined in any cell shown here; it
# must be created before this cell can run on a fresh kernel.
x, y = next(iter(train_datamodule.train_dataloader()))

In [65]:
# Shape of the "gender" feature in the sampled batch.
x.payload["gender"].shape

torch.Size([1280, 60])

In [15]:
# Instantiate the CoLES model described by the model config.
model_churn: CustomCoLES = instantiate(cfg_model["model"])

In [16]:
trainer_cfg = cfg_model["trainer_coles"]

# Checkpointing and early stopping both monitor the model's own metric, and a
# larger value counts as an improvement.
monitor_kwargs = dict(monitor=model_churn.metric_name, mode="max")

model_checkpoint: ModelCheckpoint = instantiate(
    trainer_cfg["checkpoint_callback"], **monitor_kwargs
)
early_stopping: EarlyStopping = instantiate(
    trainer_cfg["early_stopping"], **monitor_kwargs
)
logger: TensorBoardLogger = instantiate(trainer_cfg["logger"])

trainer: Trainer = instantiate(
    trainer_cfg["trainer"],
    callbacks=[model_checkpoint, early_stopping],
    logger=logger,
)

# NOTE(review): `train_datamodule` is not defined in any cell shown here —
# confirm where it is created.
trainer.fit(model_churn, train_datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type            | Params
-------------------------------------------------------
0 | _loss              | ContrastiveLoss | 0     
1 | _seq_encoder       | RnnSeqEncoder   | 12.9 K
2 | _validation_metric | BatchRecallTopK | 0     
3 | _head              | Head            | 0     
-------------------------------------------------------
12.9 K    Trainable params
0         Non-trainable params
12.9 K    Total params
0.052     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved. New best score: 0.164


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.058 >= min_delta = 0.01. New best score: 0.222


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.041 >= min_delta = 0.01. New best score: 0.262


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.049 >= min_delta = 0.01. New best score: 0.311


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.053 >= min_delta = 0.01. New best score: 0.365


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.035 >= min_delta = 0.01. New best score: 0.400


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.023 >= min_delta = 0.01. New best score: 0.422


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.021 >= min_delta = 0.01. New best score: 0.443


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.011 >= min_delta = 0.01. New best score: 0.454


Validation: 0it [00:00, ?it/s]

Metric recall_top_k improved by 0.020 >= min_delta = 0.01. New best score: 0.474


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric recall_top_k did not improve in the last 5 records. Best score: 0.474. Signaling Trainer to stop.


In [17]:
# Write the model's state dict to disk.
torch.save(model_churn.state_dict(), "saved_models/coles_raif_default.pth")

In [18]:
# Reload the saved weights. map_location="cpu" keeps the load working on a
# machine without CUDA even if the checkpoint holds GPU tensors;
# load_state_dict copies the values into the model's own parameters.
model_churn.load_state_dict(
    torch.load("saved_models/coles_raif_default.pth", map_location="cpu")
)

<All keys matched successfully>

In [23]:
# Show the config used to build the local-validation datasets.
cfg_model["validation_dataset"]

{'_target_': 'src.coles.CustomColesValidationDataset', 'min_len': 20, 'seq_len': 40, 'stride': 10, 'local_target_col': 'fake_local_label', 'col_time': 'event_time'}

In [24]:
# Build the local-validation datasets from the already preprocessed splits,
# overriding the configured local target column with "gender".
# NOTE(review): `test` is not defined in any cell shown in this notebook —
# confirm where the test split comes from before running on a fresh kernel.
train_data_local, val_data_local, test_data_local = (
    instantiate(cfg_model["validation_dataset"], data=split, local_target_col="gender")
    for split in (train, val, test)
)

# Batch size 1 keeps all slices of one user in a single batch (needed for the
# pooling model); a larger batch size may be used to speed up
# LocalValidationModel training.
val_datamodule: PtlsDataModule = instantiate(
    cfg_model["datamodule"],
    train_data=train_data_local,
    valid_data=val_data_local,
    test_data=test_data_local,
    train_batch_size=1,
    valid_batch_size=1,
    test_batch_size=1,
)

In [25]:
# Take one validation batch and check the shape of its "event_time" feature.
valid_batch, local_labels = next(iter(val_datamodule.val_dataloader()))
valid_batch.payload['event_time'].shape

torch.Size([16, 40])

In [41]:
# Inspect the "gender" feature of the validation batch.
valid_batch.payload["gender"]

tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 

In [None]:
from typing import Optional, Tuple, Dict

import torch
import torch.nn as nn
import pytorch_lightning as pl

from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.classification import (
    BinaryF1Score,
    BinaryAUROC,
    BinaryAveragePrecision,
    BinaryAccuracy,
    Accuracy,
    AUROC,
    F1Score,
    AveragePrecision,
)

from ptls.data_load.padded_batch import PaddedBatch


class LocalValidationModel(pl.LightningModule):
    """
    PytorchLightningModule for local validation of backbone (e.g. CoLES) model of transactions representations.

    The backbone is frozen; only a small 2-layer prediction head is trained. The task is chosen by `val_mode`:
        * 'downstream'  - binary classification of a prepared local target column (BCE loss),
        * 'return_time' - regression of the time to the next transaction (LogCosh loss),
        * 'event_type'  - multiclass prediction of the next MCC code (cross-entropy loss).

    NOTE(review): `LogCoshLoss` and `SlidingWindowSampler` are used below but are neither defined nor
    imported in this cell; they must be in scope before the 'return_time' mode or 'seq2vec' embeddings are used.
    """

    def __init__(
        self,
        backbone: nn.Module,
        backbone_embd_size: int,
        hidden_size: int,
        val_mode: str,
        num_types: Optional[int] = None,
        learning_rate: float = 1e-3,
        backbone_output_type: str = "tensor",
        backbone_embd_mode: str = "seq2vec",
        seq_len: Optional[int] = None,
        mask_col: str = "mcc_code",
        local_label_col: Optional[str] = None,
        mcc_padd_value: int = 0,
    ) -> None:
        """Initialize LocalValidationModel with pretrained backbone model and 2-layer linear prediction head.

        Args:
            backbone (nn.Module) - backbone model for transactions representations
            backbone_embd_size (int) - size of embeddings produced by backbone model
            hidden_size (int) - hidden size for 2-layer linear prediction head
            val_mode (str) - local validation mode (options: 'downstream', 'return_time' and 'event_type')
            num_types (int) - number of possible event types (MCC-codes) for 'event_type' validation mode
            learning_rate (float) - learning rate for prediction head training
            backbone_output_type (str) - type of output of the backbone model
                                         (e.g. torch.Tensor -- for CoLES, PaddedBatch for BestClassifier)
            backbone_embd_mode (str) - type of backbone embeddings:
                                       'seq2vec', if backbone transforms (bs, seq_len) -> (bs, embd_dim),
                                       'seq2seq', if backbone transforms (bs, seq_len) -> (bs, seq_len, embd_dim),
            seq_len (int) - size of the sliding window (required for 'seq2vec')
            mask_col (str) - name of columns containing zero-padded values for mask creation
            local_label_col (str) - name of the columns containing local targets for 'downstream' validation mode
            mcc_padd_value (int) - MCC-code corresponding to padding

        Raises:
            AssertionError - on an unknown option, or when an argument required by the chosen mode is missing
        """
        super().__init__()

        assert backbone_output_type in [
            "tensor",
            "padded_batch",
        ], f"Unknown output type of the backbone model {backbone_output_type}"

        assert backbone_embd_mode in [
            "seq2vec",
            "seq2seq",
        ], f"Unknown backbone embeddings mode {backbone_embd_mode}"

        if backbone_embd_mode == "seq2vec":
            assert (
                seq_len is not None
            ), "Specify subsequence length for sampling sliding windows"

            self.seq_len = seq_len

        assert val_mode in [
            "downstream",
            "return_time",
            "event_type",
        ], f"Unknown validation mode {val_mode}"

        self.backbone = backbone

        # freeze backbone model
        for param in self.backbone.parameters():
            param.requires_grad = False

        if val_mode == "downstream":
            assert (
                local_label_col is not None
            ), "Specify local_label_col for downstream validation"

            self.pred_head = nn.Sequential(
                nn.Linear(backbone_embd_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
                nn.Sigmoid(),
            )
            # BCE loss for seq2seq binary classification
            self.loss = nn.BCELoss()

            # metrics for binary classification
            self.metrics = {
                "AUROC": BinaryAUROC(),
                "PR-AUC": BinaryAveragePrecision(),
                "Accuracy": BinaryAccuracy(),
                "F1Score": BinaryF1Score(),
            }
            self.local_label_col = local_label_col

        elif val_mode == "return_time":
            self.pred_head = nn.Sequential(
                nn.Linear(backbone_embd_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )
            # Custom LogCosh loss for return-time prediction
            self.loss = LogCoshLoss("mean")

            # regression metrics
            self.metrics = {"MAE": MeanAbsoluteError(), "MSE": MeanSquaredError()}
        else:
            assert (
                num_types is not None
            ), "Specify number of event types for next-event-type prediction"

            self.pred_head = nn.Sequential(
                nn.Linear(backbone_embd_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, num_types),
            )
            # CrossEntropyLoss for next-event-type prediction
            self.loss = torch.nn.CrossEntropyLoss(
                ignore_index=mcc_padd_value, reduction="mean"
            )

            self.metrics = {
                "AUROC": AUROC(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=mcc_padd_value,
                ),
                "PR-AUC": AveragePrecision(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=mcc_padd_value,
                ),
                "Accuracy": Accuracy(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=mcc_padd_value,
                ),
                "F1Score": F1Score(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=mcc_padd_value,
                ),
            }

            self.num_types = num_types

        self.lr = learning_rate

        self.backbone_output_type = backbone_output_type
        self.backbone_embd_mode = backbone_embd_mode
        self.val_mode = val_mode

        self.mask_col = mask_col
        self.mcc_padd_value = mcc_padd_value

    def _get_validation_labels(self, padded_batch: PaddedBatch) -> torch.Tensor:
        """Extract necessary target for local validation from the batch of data.

        Args:
            padded_batch (PaddedBatch) - container with zero-padded data (no sampling), shape of any feature = (batch_size, max_len)

        Returns:
            torch.Tensor containing targets
        """
        if self.val_mode == "downstream":
            # take column with prepared local targets (e.g. 'churn_target' for Churn local validation)
            target = padded_batch.payload[self.local_label_col]

        elif self.val_mode == "return_time":
            # extract event times for return time (next transaction time) prediction
            target = padded_batch.payload["event_time"]
        else:
            # extract event MCC-codes for next transaction types prediction
            target = padded_batch.payload["mcc_code"]

            # every MCC code >= num_types - 1 is merged into the single last category
            target = torch.where(
                target >= self.num_types - 1, self.num_types - 1, target
            )

        if self.backbone_embd_mode == "seq2vec":
            # crop targets, delete first seq_len transactions as there are not history windows for them
            target = target[:, self.seq_len :]
        return target

    def _return_time_target_and_preds(
        self, preds: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Prepare targets for 'return_time' validation.

        Args:
            preds (torch.Tensor) - raw predictions (output of the pred_head)
            target (torch.Tensor) - raw targets (from the dataset with no sampling)
            mask (torch.Tensor) - raw mask indicating non-padding items

        Returns a tuple of:
            * preds (torch.Tensor) - modified predictions
            * target (torch.Tensor) - modified targets
            * mask (torch.Tensor) - modified mask
        """
        # get time differences, do not take the last prediction, crop the mask
        target = target[:, 1:] - target[:, :-1]
        preds = preds.squeeze(-1)[:, :-1]
        mask = mask[:, 1:]
        return preds, target, mask

    def _event_type_target_and_preds(
        self, preds: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Prepare targets for 'event_type' validation.

        Args:
            preds (torch.Tensor) - raw predictions (output of the pred_head)
            target (torch.Tensor) - raw targets (from the dataset with no sampling)

        Returns a tuple of:
            * preds (torch.Tensor) - modified predictions
            * target (torch.Tensor) - modified targets
        """
        # crop predictions and target
        target = target[:, 1:]
        # preds: (batch_size, max_len, num_types) -> (batch_size, num_types, max_len) for loss and metrics
        preds = preds[:, :-1, :].transpose(1, 2)
        return preds, target

    def forward(self, inputs: PaddedBatch) -> Tuple[torch.Tensor, ...]:
        """Do forward pass through the local validation model.

        Args:
            inputs (PaddedBatch) - inputs if ptls format (no sampling)

        Returns a tuple of:
            * torch.Tensor of predicted targets
            * torch.Tensor with mask corresponding to non-padded times
        """
        bs = inputs.payload["event_time"].shape[0]

        if self.backbone_embd_mode == "seq2vec":
            collated_batch = SlidingWindowSampler(inputs, seq_len=self.seq_len)

            out = self.backbone(collated_batch)

            embd_size = out.shape[-1]
            out = out.reshape(bs, -1, embd_size)

            # shape of mask is (batch_size, max_seq_len - seq_len), zeros correspond to windows with padding
            mask = (
                collated_batch.payload[self.mask_col]
                .reshape(bs, -1, self.seq_len)
                .ne(0)
                .all(dim=2)
            )
        else:
            # shape is (batch_size, max_seq_len, embd_dim)
            out = self.backbone(inputs)
            # shape is (batch_size, max_seq_len)
            mask = inputs.payload[self.mask_col].ne(0)

        # in case of baseline models
        if self.backbone_output_type == "padded_batch":
            out = out.payload

        preds = self.pred_head(out)

        return preds, mask

    def shared_step(
        self, batch: Tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> Tuple[torch.Tensor, ...]:
        """Shared step for training, validation and testing.

        Args:
            batch (PaddedBatch) - inputs if ptls format (no sampling)

        Returns a tuple of:
            * preds (torch.Tensor) - model predictions
            * target (torch.Tensor) - true target values
            * mask (torch.Tensor) - binary mask indication non-padding transactions
        """
        inputs, _ = batch
        target = self._get_validation_labels(inputs)
        preds, mask = self.forward(inputs)

        return preds, target, mask

    def _loss_and_metric(
        self, batch: Tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        """Compute the loss and the per-batch monitoring metric shared by training and validation.

        Returns a tuple of:
            * loss (torch.Tensor) - scalar loss of the current validation mode
            * metric_name (str) - short name of the metric ('acc' or 'mae')
            * metric (torch.Tensor) - metric value on the current batch
        """
        preds, target, mask = self.shared_step(batch, batch_idx)

        if self.val_mode == "downstream":
            # The head ends with Linear(hidden_size, 1), so preds[mask] is (k, 1). Drop only that
            # last axis: a bare .squeeze() would also collapse k == 1 to a 0-dim tensor, and
            # comparing the unsqueezed (k, 1) tensor with (k,) targets broadcasts to (k, k).
            probs = preds[mask].squeeze(-1)
            labels = target[mask]

            loss = self.loss(probs, labels.float())
            metric = ((probs > 0.5).long() == labels).float().mean()
            return loss, "acc", metric

        if self.val_mode == "return_time":
            preds, target, mask = self._return_time_target_and_preds(
                preds, target, mask
            )

            loss = self.loss(target[mask], preds[mask])
            # preds were already squeezed to (bs, len), so preds[mask] is 1-D like target[mask]
            metric = MeanAbsoluteError().to(preds.device)(preds[mask], target[mask])
            return loss, "mae", metric

        preds, target = self._event_type_target_and_preds(preds, target)

        # the loss is built with reduction="mean", so it is already a scalar
        loss = self.loss(preds, target)
        metric = Accuracy(
            task="multiclass",
            num_classes=self.num_types,
            ignore_index=self.mcc_padd_value,
        ).to(preds.device)(preds, target)
        return loss, "acc", metric

    def training_step(
        self, batch: Tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> Dict[str, float]:
        """Training step of the LocalValidationModel."""
        train_loss, metric_name, metric = self._loss_and_metric(batch, batch_idx)

        self.log("train_loss", train_loss, prog_bar=True)
        self.log("train_" + metric_name, metric, prog_bar=True)

        return {"loss": train_loss, metric_name: metric}

    def validation_step(
        self, batch: Tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> Dict[str, float]:
        """Validation step of the LocalValidationModel."""
        val_loss, metric_name, metric = self._loss_and_metric(batch, batch_idx)

        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_" + metric_name, metric, prog_bar=True)

        return {"loss": val_loss, metric_name: metric}

    def test_step(
        self, batch: Tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> Dict[str, float]:
        """Test step of the LocalValidationModel.

        Updates every test metric with the current batch and returns the predictions, the labels
        and the running value of each metric.
        """
        preds, target, mask = self.shared_step(batch, batch_idx)

        if self.val_mode == "downstream":
            # squeeze only the trailing unit axis so a single masked element stays 1-D
            preds = preds[mask].squeeze(-1)
            target = target[mask]
        elif self.val_mode == "return_time":
            preds, target, mask = self._return_time_target_and_preds(
                preds, target, mask
            )
            preds = preds[mask]
            target = target[mask]
        else:
            preds, target = self._event_type_target_and_preds(preds, target)

        dict_out = {"preds": preds, "labels": target}

        for name, metric in self.metrics.items():
            # metrics live in a plain dict, so they are not moved with the module automatically
            metric.to(preds.device)
            metric.update(preds, target)
            dict_out[name] = metric.compute().item()

        return dict_out

    def on_test_epoch_end(self) -> Dict[str, float]:
        """Compute and log test metrics for the whole test dataset, then reset the metric states."""
        results = {}
        for name, metric in self.metrics.items():
            value = metric.compute()
            results[name] = value
            self.log(name, value)
            # start the next test run from a clean state instead of accumulating across runs
            metric.reset()
        return results

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Initialize optimizer for the LocalValidationModel (only the prediction head is optimized)."""
        opt = torch.optim.Adam(self.pred_head.parameters(), lr=self.lr)
        return opt

In [None]:
class UsersValidationModel(pl.LightningModule):
    """
    PytorchLightningModule that trains a binary prediction head on top of a frozen backbone
    (e.g. CoLES) model of transactions representations.

    NOTE(review): the head and loss are binary (Sigmoid + BCELoss), which expect labels in {0, 1}.
    Labels produced by a categorical encoder may use other codes (the encoded "gender" values shown
    in this notebook are {1, 2}), and some allowed target columns look multi-valued — confirm the
    labels fed to this model are 0/1.
    """

    def __init__(
        self,
        backbone: nn.Module,
        backbone_embd_size: int,
        hidden_size: int,
        learning_rate: float = 1e-3,
        backbone_output_type: str = "tensor",
        target_col: str = "gender"
    ) -> None:
        """Initialize UsersValidationModel with pretrained backbone model and 2-layer linear prediction head.

        Args:
            backbone (nn.Module) - backbone model for transactions representations
            backbone_embd_size (int) - size of embeddings produced by backbone model
            hidden_size (int) - hidden size for 2-layer linear prediction head
            learning_rate (float) - learning rate for prediction head training
            backbone_output_type (str) - type of output of the backbone model
                                         (e.g. torch.Tensor -- for CoLES, PaddedBatch for BestClassifier)
            target_col (str) - name of the target column; validated and stored, the labels themselves
                               are taken from the second element of each batch

        Raises:
            AssertionError - on an unknown backbone_output_type or target_col
        """
        super().__init__()

        assert backbone_output_type in [
            "tensor",
            "padded_batch",
        ], "Unknown output type of the backbone model"

        assert target_col in [
            "gender",
            "married_",
            "categorycode",
            "residenttype",
            "age",
        ], "Unknown target_col"

        self.target_col = target_col

        self.backbone = backbone

        # freeze backbone model
        for param in self.backbone.parameters():
            param.requires_grad = False

        # The head consumes the backbone embedding directly; the constructor exposes no LSTM
        # settings, so no recurrent layer is built.
        self.pred_head = nn.Sequential(
            nn.Linear(backbone_embd_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        )

        # BCE loss for binary classification
        self.loss = nn.BCELoss()

        self.lr = learning_rate

        self.metrics = {
            "AUROC": BinaryAUROC(),
            "PR-AUC": BinaryAveragePrecision(),
            "Accuracy": BinaryAccuracy(),
            "F1Score": BinaryF1Score(),
        }

        self.backbone_output_type = backbone_output_type

    def forward(self, inputs: PaddedBatch) -> torch.Tensor:
        """Do forward pass through the validation model.

        Args:
            inputs (PaddedBatch) - inputs if ptls format

        Returns:
            torch.Tensor of predicted probabilities with the trailing unit axis removed
        """
        out = self.backbone(inputs)
        if self.backbone_output_type == "padded_batch":
            out = out.payload
        # drop only the last axis produced by Linear(hidden_size, 1)
        out = self.pred_head(out).squeeze(-1)
        return out

    def training_step(
        self, batch: tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> dict[str, float]:
        """Training step of the UsersValidationModel."""
        inputs, labels = batch
        preds = self.forward(inputs)

        train_loss = self.loss(preds, labels.float())
        train_accuracy = ((preds > 0.5).long() == labels).float().mean()

        self.log("train_loss", train_loss, prog_bar=True)
        self.log("train_acc", train_accuracy, prog_bar=True)

        return {"loss": train_loss, "acc": train_accuracy}

    def validation_step(
        self, batch: tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> dict[str, float]:
        """Validation step of the UsersValidationModel."""
        inputs, labels = batch

        preds = self.forward(inputs)

        val_loss = self.loss(preds, labels.float())
        val_accuracy = ((preds > 0.5).long() == labels).float().mean()

        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_accuracy, prog_bar=True)

        return {"loss": val_loss, "acc": val_accuracy}

    def test_step(
        self, batch: tuple[PaddedBatch, torch.Tensor], batch_idx: int
    ) -> dict[str, float]:
        """Test step of the UsersValidationModel.

        Updates every test metric with the current batch and returns the predictions, the labels
        and the running value of each metric.
        """
        inputs, labels = batch
        preds = self.forward(inputs)

        dict_out = {"preds": preds, "labels": labels}
        for name, metric in self.metrics.items():
            # metrics live in a plain dict, so move them to the device of the predictions by hand
            metric.to(preds.device)
            metric.update(preds, labels)

            dict_out[name] = metric.compute().item()

        return dict_out

    def on_test_epoch_end(self) -> dict[str, float]:
        """Compute and log test metrics for the whole test dataset, then reset the metric states."""
        results = {}
        for name, metric in self.metrics.items():
            value = metric.compute()
            results[name] = value
            self.log(name, value)
            # start the next test run from a clean state instead of accumulating across runs
            metric.reset()
        return results

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Initialize optimizer for the UsersValidationModel (only the prediction head is optimized)."""
        opt = torch.optim.Adam(self.pred_head.parameters(), lr=self.lr)
        return opt


In [29]:
HIDDEN_SIZE = 32
# Embedding width = last dimension of the backbone output on one batch.
emb_dim = model_churn(valid_batch).shape[-1]

# NOTE(review): the LocalValidationModel class defined in this notebook
# requires `val_mode`; this call only matches a constructor without it
# (presumably the version imported from src.local_validation) — confirm which
# class is in scope.
val_model = LocalValidationModel(
    model_churn,
    backbone_embd_size=emb_dim,
    hidden_size=HIDDEN_SIZE,
)

In [31]:
# Sanity-check shapes: backbone embeddings, head predictions and true labels
# for the same validation batch.
backbone_out = val_model.backbone(valid_batch)
print("Pooling COLES embeddings:", backbone_out.shape)

# NOTE(review): `.shape` requires the forward pass to return a single tensor.
pred_out = val_model(valid_batch)
print("Predicted labels:", pred_out.shape)

print("True local labels:", local_labels.shape)

Pooling COLES embeddings: torch.Size([16, 32])
Predicted labels: torch.Size([16])
True local labels: torch.Size([16])


In [32]:
# Display the module structure of the validation model.
val_model

LocalValidationModel(
  (backbone): CustomCoLES(
    (_loss): ContrastiveLoss()
    (_seq_encoder): RnnSeqEncoder(
      (trx_encoder): TrxEncoder(
        (embeddings): ModuleDict(
          (mcc_code): NoisyEmbedding(
            202, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
        )
        (numeric_values): ModuleDict(
          (event_time): LogScaler()
        )
        (numerical_batch_norm): RBatchNorm(
          (bn): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (seq_encoder): RnnEncoder(
        (rnn): GRU(33, 32, batch_first=True)
        (reducer): LastStepEncoder()
      )
    )
    (_validation_metric): BatchRecallTopK()
    (_head): Head(
      (model): Sequential(
        (0): L2NormEncoder()
      )
    )
    (sequence_encoder_model): RnnSeqEncoder(
      (trx_encoder): TrxEncoder(
        (embeddings): ModuleDict(
          (mcc_code): NoisyEmbedding(
            202, 32

In [33]:
# Train the validation model for 5 epochs on a single GPU.
val_trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=5,
)
    
val_trainer.fit(val_model, val_datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type        | Params
------------------------------------------
0 | backbone  | CustomCoLES | 12.9 K
1 | pred_head | Sequential  | 1.1 K 
2 | loss      | BCELoss     | 0     
------------------------------------------
1.1 K     Trainable params
12.9 K    Non-trainable params
14.0 K    Total params
0.056     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [37]:
# Evaluate the validation model on the datamodule's test data.
# NOTE(review): the recorded run fails with a RuntimeError because the targets
# contain the value 2 while the binary metrics accept only {0, 1}. The encoded
# "gender" values shown earlier are {1, 2}, so the labels presumably need to be
# remapped to {0, 1} first — confirm.
val_trainer.test(val_model, val_datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

RuntimeError: Detected the following values in `target`: tensor([2], device='cuda:0') but expected only the following values [0, 1].

In [None]:
# Check the type of object returned by the datamodule's train_dataloader().
type(val_datamodule.train_dataloader())

torch.utils.data.dataloader.DataLoader