In [None]:
import torch
from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder.scalers import scaler_by_name


from nn.trx_encoder.trx_encoder_with_client_item_embeddings import TrxEncoder_WithCIEmbeddings
from nn.trx_encoder.client_item_encoder import DummyGNNClientItemEncoder

# Smoke test for TrxEncoder_WithCIEmbeddings: build an encoder with one
# categorical feature ('mcc_code'), one log-scaled numeric feature ('amount')
# and one client-item embedding source. Then check the output shape and
# that the scaled 'amount' column matches the 'log' scaler applied directly.

# Seed the RNG so random inputs (and any randomly initialised weights) are
# reproducible when an assertion fails.
torch.manual_seed(42)

B, T = 5, 20          # batch size, padded sequence length
MCC_VOCAB = 100       # number of 'mcc_code' embedding rows
MCC_EMB_DIM = 5       # 'mcc_code' embedding width
CI_EMB_DIM = 10       # client-item embedding width

dummy_gnn_encoder = DummyGNNClientItemEncoder(output_size=CI_EMB_DIM)

trx_encoder = TrxEncoder_WithCIEmbeddings(
    embeddings={'mcc_code': {'in': MCC_VOCAB, 'out': MCC_EMB_DIM}},
    numeric_values={'amount': 'log'},
    col_item_ids='mcc_code',
    client_item_embeddings=[dummy_gnn_encoder],
    use_batch_norm=False,
)

feats_pb = PaddedBatch(
    payload={
        # torch.randint's upper bound is exclusive, so MCC_VOCAB samples every
        # valid index 0..MCC_VOCAB-1 (the previous bound of 99 never hit 99).
        'mcc_code': torch.randint(0, MCC_VOCAB, (B, T)),
        'amount': torch.randn(B, T),
    },
    # Lengths in [10, T]; the upper bound T + 1 lets full-length sequences
    # occur too.
    length=torch.randint(10, T + 1, (B,)),
)

client_ids = torch.arange(0, B, dtype=torch.int64)

z = trx_encoder((feats_pb, client_ids))

# Hidden size = mcc embedding + one scaled numeric column + client-item embedding.
expected_hidden = MCC_EMB_DIM + 1 + CI_EMB_DIM
assert z.payload.shape == (B, T, expected_hidden), (
    f"unexpected output shape {tuple(z.payload.shape)}, "
    f"expected {(B, T, expected_hidden)}"
)

# The scaled 'amount' column is expected to come right after the mcc embedding
# columns (index MCC_EMB_DIM). The log scaler's output carries a trailing
# singleton dim, which squeeze(-1) removes before comparing.
expected_amount = scaler_by_name('log')(feats_pb.payload['amount']).squeeze(-1)
assert torch.allclose(z.payload[:, :, MCC_EMB_DIM], expected_amount), (
    "scaled 'amount' column does not match scaler_by_name('log') output"
)