In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys

dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)

os.chdir('..')

%load_ext autoreload
%autoreload

In [2]:
from pathlib import Path

import pandas as pd
import numpy as np

import torch

from hydra import initialize, compose
from hydra.utils import instantiate

from ptls.preprocessing import PandasDataPreprocessor
from ptls.frames import PtlsDataModule

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split

from src.coles import CustomColesDataset, CustomColesValidationDataset, CustomCoLES
from src.local_validation import LocalValidationModel

In [3]:
from src.pooling import PoolingModel

# Example: local validation of a pretrained CoLES model on the churn dataset


In [4]:
# Dataset to run on; selects the Hydra config file `config_<DATASET>`.
DATASET = "churn"

with initialize(config_path="../config", version_base=None):
    cfg = compose(config_name=f"config_{DATASET}")

# Sub-configs used below: dataset/preprocessing settings and model settings.
cfg_preprop, cfg_model = cfg["dataset"], cfg["model"]

In [5]:
# Load the raw transactions table for the selected dataset and preview it.
train_path = Path(cfg_preprop["dir_path"]) / cfg_preprop["train_file_name"]
df = pd.read_parquet(train_path)
df.head()

Unnamed: 0,user_id,mcc_code,timestamp,amount,global_target,holiday_target,weekend_target,churn_target
0,0,5200,2017-10-21 00:00:00,5023.0,0,0,1,0
1,0,6011,2017-10-12 12:24:07,20000.0,0,0,0,0
2,0,5921,2017-12-05 00:00:00,767.0,0,0,0,1
3,0,5411,2017-10-21 00:00:00,2031.0,0,0,1,0
4,0,6012,2017-10-24 13:14:24,36562.0,0,0,0,0


In [6]:
# Name of the column holding per-event (local) targets, from the validation-dataset config.
local_target = cfg_model["validation_dataset"]["local_target_col"]

preprocessor = PandasDataPreprocessor(
    col_id="user_id",
    col_event_time="timestamp",
    # Converts datetimes to numeric timestamps -- this IS a time transformation
    # ("none" would leave the column untouched).
    event_time_transformation="dt_to_timestamp",
    cols_category=["mcc_code"],
    cols_numerical=["amount", local_target],  # carry the (fake) local-target column into the records
    return_records=True,
)

# NOTE(review): the preprocessor (category vocabulary) is fitted on all users before the
# split, so val/test users contribute to the encoding -- confirm this is acceptable.
dataset = preprocessor.fit_transform(df)

# 80 / 10 / 10 split of per-user records into train / val / test.
train, val_test = train_test_split(dataset, test_size=.2, random_state=42)
val, test = train_test_split(val_test, test_size=.5, random_state=42)
assert len(train) + len(val) + len(test) == len(dataset), "split lost or duplicated records"

In [7]:
# Data for CoLES training: the train/val splits wrapped in the original CoLES
# dataset class and bundled into a datamodule.
dataset_cfg = cfg_model["dataset"]
train_data: CustomColesDataset = instantiate(dataset_cfg, data=train)
val_data: CustomColesDataset = instantiate(dataset_cfg, data=val)

train_datamodule: PtlsDataModule = instantiate(
    cfg_model["datamodule"], train_data=train_data, valid_data=val_data
)

In [8]:
# Build the CoLES model from its Hydra config; its weights are loaded from a saved
# checkpoint in a later cell.
model_churn: CustomCoLES = instantiate(cfg_model["model"])

In [9]:
# Train CoLES only when no saved checkpoint exists. On a fresh checkout, "Restart & Run All"
# now produces the weights that the loading cell below expects, instead of failing with
# FileNotFoundError. When the checkpoint already exists, this cell is a no-op.
COLES_CHECKPOINT = Path("saved_models/coles_churn_default.pth")

if not COLES_CHECKPOINT.exists():
    model_checkpoint: ModelCheckpoint = instantiate(
        cfg_model["trainer_coles"]["checkpoint_callback"],
        monitor=model_churn.metric_name,
        mode="max",
    )
    early_stopping: EarlyStopping = instantiate(
        cfg_model["trainer_coles"]["early_stopping"],
        monitor=model_churn.metric_name,
        mode="max",
    )
    logger: TensorBoardLogger = instantiate(cfg_model["trainer_coles"]["logger"])

    trainer: Trainer = instantiate(
        cfg_model["trainer_coles"]["trainer"],
        callbacks=[model_checkpoint, early_stopping],
        logger=logger,
    )
    trainer.fit(model_churn, train_datamodule)

    # Persist the weights so later runs (and the loading cell) reuse them.
    COLES_CHECKPOINT.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model_churn.state_dict(), COLES_CHECKPOINT)

In [10]:
# Uncomment to overwrite the saved CoLES weights after (re)training the model.
# torch.save(model_churn.state_dict(), "saved_models/coles_churn_default.pth")

In [11]:
# map_location="cpu": the checkpoint may contain GPU tensors. Loading onto the CPU works on
# any machine, and load_state_dict then copies the values into the model's own parameters.
model_churn.load_state_dict(torch.load("saved_models/coles_churn_default.pth", map_location="cpu"))

<All keys matched successfully>

In [12]:
# Local-validation datasets, built from the same preprocessed train / val / test splits.
local_dataset_cfg = cfg_model["validation_dataset"]
train_data_local: CustomColesValidationDataset = instantiate(local_dataset_cfg, data=train)
val_data_local: CustomColesValidationDataset = instantiate(local_dataset_cfg, data=val)
test_data_local: CustomColesValidationDataset = instantiate(local_dataset_cfg, data=test)

# Batch size 1 keeps all slices of one user in a single batch, as needed for
# pooling_model; a larger batch would only speed up LocalValidationModel training.
LOCAL_BATCH_SIZE = 1
val_datamodule: PtlsDataModule = instantiate(
    cfg_model["datamodule"],
    train_data=train_data_local,
    valid_data=val_data_local,
    test_data=test_data_local,
    train_batch_size=LOCAL_BATCH_SIZE,
    valid_batch_size=LOCAL_BATCH_SIZE,
    test_batch_size=LOCAL_BATCH_SIZE,
)

In [13]:
# Peek at one validation batch (all slices of a single user) and its local labels.
val_loader = val_datamodule.val_dataloader()
valid_batch, local_labels = next(iter(val_loader))
valid_batch.payload["event_time"].shape

torch.Size([25, 40])

In [14]:
# Embedding size of the pretrained CoLES encoder, read off one forward pass.
# Run it in eval mode without autograd: in train mode the encoder's BatchNorm would
# update its running statistics from this single batch, silently altering the model.
was_training = model_churn.training
model_churn.eval()
with torch.no_grad():
    emb_dim = model_churn(valid_batch).shape[-1]
model_churn.train(was_training)  # restore the previous mode

HIDDEN_SIZE = 32  # passed to PoolingModel as `hidden_size`

pooling_model = PoolingModel(
    val_datamodule.train_dataloader(),
    model_churn,
    backbone_embd_size=emb_dim,
    agregating_model=None,
    pooling_type="mean",
    hidden_size=HIDDEN_SIZE,
)

In [15]:
# Sanity-check shapes on one validation user: one backbone embedding per slice and
# one prediction per local label.
# eval() + no_grad(): a plain forward in train mode would update the backbone's
# BatchNorm running statistics and build an unused autograd graph.
was_training = pooling_model.training
pooling_model.eval()
with torch.no_grad():
    backbone_out = pooling_model.backbone(valid_batch)
    pred_out = pooling_model(valid_batch)
pooling_model.train(was_training)  # restore the previous mode

print("Pooling COLES embeddings:", backbone_out.shape)
print("Predicted labels:", pred_out.shape)
print("True local labels:", local_labels.shape)

assert pred_out.shape == local_labels.shape, "expected one prediction per local label"

Pooling COLES embeddings: torch.Size([25, 1024])
Predicted labels: torch.Size([25])
True local labels: torch.Size([25])


In [16]:
# Show the module tree: the pretrained CoLES backbone plus the trainable prediction head.
pooling_model

PoolingModel(
  (backbone): CustomCoLES(
    (_loss): ContrastiveLoss()
    (_seq_encoder): RnnSeqEncoder(
      (trx_encoder): TrxEncoder(
        (embeddings): ModuleDict(
          (mcc_code): NoisyEmbedding(
            344, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
        )
        (numeric_values): ModuleDict(
          (event_time): LogScaler()
        )
        (numerical_batch_norm): RBatchNorm(
          (bn): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (seq_encoder): RnnEncoder(
        (rnn): LSTM(33, 1024, batch_first=True)
        (reducer): LastStepEncoder()
      )
    )
    (_validation_metric): BatchRecallTopK()
    (_head): Head(
      (model): Sequential(
        (0): L2NormEncoder()
      )
    )
    (sequence_encoder_model): RnnSeqEncoder(
      (trx_encoder): TrxEncoder(
        (embeddings): ModuleDict(
          (mcc_code): NoisyEmbedding(
            344, 32, pad

In [17]:
# Seed python / numpy / torch RNGs so training the prediction head is reproducible
# (same seed as the train/val/test split).
seed_everything(42)

VAL_MAX_EPOCHS = 5

# "auto" uses an available accelerator (e.g. a GPU) and falls back to CPU,
# instead of failing on machines without CUDA as a hard-coded "gpu" does.
val_trainer = Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=VAL_MAX_EPOCHS,
)

val_trainer.fit(pooling_model, val_datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type        | Params
------------------------------------------
0 | backbone  | CustomCoLES | 4.3 M 
1 | pred_head | Sequential  | 65.6 K
2 | loss      | BCELoss     | 0     
------------------------------------------
65.6 K    Trainable params
4.3 M     Non-trainable params
4.4 M     Total params
17.661    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [18]:
# Persist the trained local-validation model (backbone and head weights).
pooling_ckpt_path = "saved_models/validation_coles_pooling_mean_churn.pth"
torch.save(pooling_model.state_dict(), pooling_ckpt_path)

In [19]:
# map_location="cpu" lets a GPU-saved checkpoint load on any machine; load_state_dict
# copies the values into the model's existing parameters.
pooling_model.load_state_dict(torch.load("saved_models/validation_coles_pooling_mean_churn.pth", map_location="cpu"))

<All keys matched successfully>

In [20]:
# Evaluate the trained local-validation model on the held-out test split.
val_trainer.test(pooling_model, val_datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          AUROC             0.6034145355224609
        Accuracy            0.5520769357681274
         F1Score            0.5976631045341492
         PR-AUC             0.6007697582244873
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'AUROC': 0.6034145355224609,
  'PR-AUC': 0.6007697582244873,
  'Accuracy': 0.5520769357681274,
  'F1Score': 0.5976631045341492}]