In [1]:
# --- Environment. These variables are set before torch / ptls are imported,
# since such variables are generally read when the libraries initialise. ---
import os
# OpenMP thread count for CPU-side numeric libraries.
os.environ["OMP_NUM_THREADS"] = "4"
# Number GPUs by PCI bus ID (presumably so the index below matches nvidia-smi).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# NOTE(review): hard-codes physical GPU 7; adjust for the machine you run on.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import pandas as pd
import numpy as np
import torch
from functools import partial
import pytorch_lightning as pl
import warnings
# NOTE(review): silences *all* warnings for the whole session (including pandas
# deprecation / dtype warnings); consider narrowing to specific categories.
warnings.filterwarnings("ignore")

from torch.utils.data import DataLoader

# NOTE(review): the ptls / lightning / lightgbm / DataLoader names below are not
# used by the aggregation cells in this section; presumably other parts of the
# project need them — confirm before pruning.
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.coles import CoLESModule
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.frames.coles import ColesIterableDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames import PtlsDataModule
from ptls.preprocessing import PandasDataPreprocessor
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset

from tqdm.auto import tqdm
import lightgbm as ltb

from datetime import datetime

# Keep wide DataFrame reprs (e.g. the 768 embedding columns) from wrapping onto
# multiple lines.
pd.set_option('display.expand_frame_repr', False)

  from .autonotebook import tqdm as notebook_tqdm


# Dialog embedding aggregation — train (mean per client and month)

In [64]:
# Raw train dialogs: client_id, event_time, mon and a per-dialog `embedding`
# vector (expanded into flat columns in the next cell).
dial_train = pd.read_parquet(
    'dial_train.parquet'
)

In [65]:
# Expand the per-dialog `embedding` vectors into flat numeric columns
# emb_0 … emb_{d-1} (d = 768 in this data) so they can be averaged per group.
# The embedding width is taken from the data instead of being hard-coded.
# NOTE: not idempotent — this cell drops `embedding`, so re-running it without
# re-reading the parquet raises a KeyError.
_train_emb = np.vstack(dial_train['embedding'].to_list())
dial_train = pd.concat(
    [
        dial_train.drop(columns=['embedding']),
        # pd.concat(axis=1) aligns rows on index labels: reuse the source index so
        # rows stay matched even if the parquet restored a non-default index
        # (a fresh RangeIndex would misalign them and NaN-fill the gaps).
        pd.DataFrame(
            _train_emb,
            index=dial_train.index,
            columns=[f'emb_{i}' for i in range(_train_emb.shape[1])],
        ),
    ],
    axis=1,
)
del _train_emb  # release the temporary (n_dialogs x d) array

In [66]:
# Collapse each dialog timestamp to a 'YYYY-MM' label so the groupby below
# pools all of a client's dialogs within the same calendar month.
_month_ts = pd.to_datetime(dial_train['event_time'])
dial_train['event_time'] = _month_ts.dt.strftime('%Y-%m')
del _month_ts

In [67]:
# Inspect the expanded frame: one row per dialog with client_id, event_time
# ('YYYY-MM'), mon and the emb_* columns.
dial_train

Unnamed: 0,client_id,event_time,mon,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,b27b9c54e72728e7bbfbe96ef2f3d49c14c9c5f0900033...,2022-01,1,0.341713,-0.052266,0.550964,-0.252805,-0.274406,0.521302,0.318203,...,0.519438,0.216861,0.259062,0.985674,0.538787,0.315250,0.328254,-0.243025,0.503692,0.394347
1,bff7260208097c052cea083ddc9e961a8d23c1b1c2b268...,2022-01,1,0.251929,-0.057982,0.561294,-0.246828,-0.353653,0.488312,0.205130,...,0.503745,0.385861,0.244470,0.932922,0.552699,0.346178,0.251520,-0.324281,0.527927,0.248990
2,c977ed2889aacd9aa35420cce5652274b1a1d347648d80...,2022-01,1,0.104948,0.128751,0.343332,0.123764,-0.109072,0.313176,0.101924,...,0.013213,0.196502,0.377637,0.326098,0.308402,0.017118,0.112048,-0.150991,0.281751,-0.081843
3,d2e003fda662d4362aed928dea8bdaffaaac002e6cc435...,2022-01,1,0.341845,-0.006655,0.416555,-0.361188,-0.218251,0.558154,0.288070,...,0.498797,0.249988,0.228737,0.956855,0.402395,0.206879,0.303140,-0.369564,0.407305,0.248160
4,d887ecc28f596b1ccf4d9758c1974d2c3058041082ed6a...,2022-01,1,0.208994,-0.203628,0.508068,-0.379986,-0.397421,0.473816,0.247349,...,0.442976,0.398096,0.553015,0.931851,0.529431,0.364997,0.252355,-0.461416,0.467727,0.316865
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
206238,3b28346c9687dc7b7f293f9a232b3372fa0564cacf520b...,2022-09,9,0.391926,-0.042448,0.540796,-0.203592,-0.274254,0.543089,0.356686,...,0.515867,0.200226,0.421370,0.980604,0.530302,0.321455,0.383613,-0.245814,0.509372,0.380903
206239,3b28346c9687dc7b7f293f9a232b3372fa0564cacf520b...,2022-09,9,0.494272,-0.377852,0.480565,-0.498725,-0.243277,0.454829,0.269279,...,0.436748,0.318930,0.549085,0.959765,0.425028,0.246801,0.342152,-0.535821,0.436366,0.330771
206240,3b28346c9687dc7b7f293f9a232b3372fa0564cacf520b...,2022-09,9,0.005900,-0.169089,0.312422,-0.169415,-0.245238,0.307090,0.044754,...,0.300388,0.355059,0.236188,0.982782,0.277640,0.058361,0.346777,-0.224037,0.281865,0.358377
206241,3b28346c9687dc7b7f293f9a232b3372fa0564cacf520b...,2022-09,9,0.286338,0.039379,0.497086,-0.255222,-0.193952,0.449693,0.036264,...,0.419630,0.330606,0.557355,0.948740,0.434777,-0.140887,0.283809,-0.076651,0.405551,0.114356


In [73]:
# Mean-pool the dialog embeddings per (client_id, month).
# Embedding columns are picked by their `emb_` prefix rather than by
# `len(columns) - 3`, which silently assumed exactly three non-embedding columns
# (client_id, event_time, mon): one extra column caused a KeyError, one fewer
# silently dropped the last embedding dimension.
aggregation_functions = {
    col: 'mean' for col in dial_train.columns if str(col).startswith('emb_')
}
assert aggregation_functions, "no 'emb_*' columns found in dial_train"
aggregated_df = dial_train.groupby(['client_id', 'event_time']).agg(aggregation_functions)

In [74]:
# One row per (client_id, event_time) group holding the mean of each emb_* column.
aggregated_df

Unnamed: 0_level_0,Unnamed: 1_level_0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
client_id,event_time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
00021a8c0dd1ffba57d5d690e1f97f11e7a771d1cef5e7b1452556a257e5f030,2022-07,0.499100,-0.215977,0.670235,-0.567517,-0.537667,0.709535,0.471731,0.462270,-0.462564,0.582690,...,0.669366,0.588496,0.593230,0.980690,0.491650,0.278866,0.509793,-0.431969,0.494451,0.542458
00021a8c0dd1ffba57d5d690e1f97f11e7a771d1cef5e7b1452556a257e5f030,2022-08,0.558679,-0.043249,0.535389,-0.431847,-0.314405,0.546827,0.298692,0.469526,-0.444612,0.415379,...,0.510143,0.219774,0.427975,0.985804,0.517400,0.264797,0.368739,-0.408416,0.523512,0.376396
00021a8c0dd1ffba57d5d690e1f97f11e7a771d1cef5e7b1452556a257e5f030,2022-09,0.323487,-0.276744,0.560727,-0.342046,-0.249254,0.475357,0.235547,0.502779,-0.534826,0.514150,...,0.542720,0.449572,0.465460,0.971408,0.439192,0.324697,0.327267,-0.552139,0.469697,0.315450
00021a8c0dd1ffba57d5d690e1f97f11e7a771d1cef5e7b1452556a257e5f030,2022-11,0.180928,0.077817,0.333774,0.061921,-0.194038,0.312977,0.137492,0.136401,-0.167976,0.210617,...,0.197331,0.278634,0.125254,0.880835,0.304740,0.164341,0.229710,-0.227223,0.222957,0.274799
00021a8c0dd1ffba57d5d690e1f97f11e7a771d1cef5e7b1452556a257e5f030,2022-12,0.147317,-0.055934,0.223596,-0.029845,-0.087291,0.310376,0.084504,0.235928,-0.238309,0.262380,...,0.184304,0.116904,0.073277,0.696958,0.243069,0.133785,0.105102,-0.124728,0.312680,0.114867
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ffffab5f6ae1c8d04d83ef12e2ad803298737992698079e8b615bcf7410eb10d,2022-01,0.250947,-0.246172,0.437951,-0.319685,-0.264899,0.448308,0.225199,0.500423,-0.531635,0.289229,...,0.583931,0.284704,0.326064,0.963608,0.497223,0.342319,0.257071,-0.351097,0.505550,0.253633
ffffab5f6ae1c8d04d83ef12e2ad803298737992698079e8b615bcf7410eb10d,2022-02,0.211653,-0.118454,0.479317,-0.160438,-0.159876,0.377764,0.176856,0.287152,-0.317172,0.275343,...,0.305766,0.246688,0.167175,0.652105,0.448935,0.041336,0.221523,-0.227581,0.444711,0.119269
ffffab5f6ae1c8d04d83ef12e2ad803298737992698079e8b615bcf7410eb10d,2022-06,0.057221,-0.029483,0.395019,-0.197280,-0.131425,0.319580,0.032792,0.160288,-0.066244,0.382928,...,0.197924,0.014337,0.036678,0.763321,0.310197,0.170308,0.171753,-0.092938,0.385899,0.052934
ffffab5f6ae1c8d04d83ef12e2ad803298737992698079e8b615bcf7410eb10d,2022-10,0.334689,-0.207153,0.544601,-0.413990,-0.317309,0.450411,0.347005,0.379911,-0.520079,0.456280,...,0.385529,0.343866,0.292791,0.805848,0.465737,0.253829,0.240460,-0.334201,0.475294,0.316320


In [75]:
# Move the (client_id, event_time) group keys out of the index and back into
# ordinary columns.
aggregated_df = aggregated_df.reset_index()

In [76]:
# Build a per-(client, month) key of the form '<client_id>_month=<m>', e.g.
# '..._month=7' for event_time '2022-07'. The vectorised `.dt.month` accessor
# replaces the original per-row `apply(lambda ...)`.
# NOTE: not idempotent — re-running this cell appends a second suffix.
aggregated_df['client_id'] = (
    aggregated_df['client_id']
    + '_month='
    + pd.to_datetime(aggregated_df['event_time']).dt.month.astype(str)
)
# Only the month (not the year) is encoded, so the same month in two different
# years would silently produce duplicate keys; fail loudly instead.
assert aggregated_df['client_id'].is_unique, (
    'client_id + month is not unique; event_time likely spans several years'
)

In [79]:
# event_time is redundant now that its month has been folded into client_id;
# drop it and persist the train features without the positional index.
aggregated_df = aggregated_df.drop('event_time', axis='columns')
aggregated_df.to_parquet(
    'dial_features_train.parquet',
    index=False,
)

# Dialog embedding aggregation — test (mean per client)

In [23]:
# Raw test dialogs: includes client_id and a per-dialog `embedding` vector
# (expanded into flat columns in the next cell).
dial_test = pd.read_parquet(
    'dial_test.parquet'
)

In [24]:
# Expand the per-dialog `embedding` vectors into flat numeric columns
# emb_0 … emb_{d-1} (d = 768 in this data) so they can be averaged per client.
# The embedding width is taken from the data instead of being hard-coded.
# NOTE: not idempotent — this cell drops `embedding`, so re-running it without
# re-reading the parquet raises a KeyError.
_test_emb = np.vstack(dial_test['embedding'].to_list())
dial_test = pd.concat(
    [
        dial_test.drop(columns=['embedding']),
        # pd.concat(axis=1) aligns rows on index labels: reuse the source index so
        # rows stay matched even if the parquet restored a non-default index
        # (a fresh RangeIndex would misalign them and NaN-fill the gaps).
        pd.DataFrame(
            _test_emb,
            index=dial_test.index,
            columns=[f'emb_{i}' for i in range(_test_emb.shape[1])],
        ),
    ],
    axis=1,
)
del _test_emb  # release the temporary (n_dialogs x d) array

In [25]:
# Mean-pool the dialog embeddings per client (no month split for test).
# Embedding columns are picked by their `emb_` prefix rather than by
# `len(columns) - 2`, which silently assumed exactly two non-embedding columns
# (client_id plus one other): any schema change either raised a KeyError or
# silently dropped embedding dimensions.
aggregation_functions = {
    col: 'mean' for col in dial_test.columns if str(col).startswith('emb_')
}
assert aggregation_functions, "no 'emb_*' columns found in dial_test"
aggregated_df_test = dial_test.groupby(['client_id']).agg(aggregation_functions)

In [None]:
# Move the client_id group key out of the index and back into a regular column.
aggregated_df_test = aggregated_df_test.reset_index()

Unnamed: 0_level_0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
client_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
00006c6ed6d81e18051751b68f9cb0d4f31d13ef0ae7fbd48693f05e60c1a204,0.092142,-0.150681,0.324699,-0.014067,-0.041489,0.244475,0.106149,0.151189,-0.148194,0.131524,...,0.243261,0.097592,0.107901,0.668990,0.291956,0.038800,0.082139,-0.025007,0.194962,0.164325
00011c01bb22d8f62d9655f32d123dcca5ae55179f8266bdb8676e25321e8477,0.325750,-0.102299,0.584633,-0.301862,-0.321009,0.436904,0.244961,0.494768,-0.593855,0.324509,...,0.524059,0.348781,0.245613,0.862081,0.567360,0.378372,0.386632,-0.228029,0.495043,0.206350
0001ac6446bf223a094d6514a6c890d82e9aa92104dee0a8afc28b2002b95dac,0.313159,-0.161794,0.530575,-0.327430,-0.262520,0.493394,0.270701,0.418504,-0.436374,0.395218,...,0.469505,0.352366,0.368798,0.922958,0.490719,0.262944,0.299643,-0.330986,0.513865,0.332118
0003304a0f65d675ddfbc0691e0c564d26a4c9e08edf67e1823e833e8b05fa99,0.302135,-0.013960,0.471700,-0.386813,-0.267022,0.490684,0.302831,0.333394,-0.586620,0.584643,...,0.569554,0.334014,0.525866,0.839566,0.429467,0.397954,0.293862,-0.274147,0.401424,0.310477
00037813e71deead5685649d494c9a412391942fe771e2699bcc33029bd5c7dd,0.374263,-0.101210,0.618231,-0.364309,-0.288485,0.504267,0.319840,0.259852,-0.433514,0.395235,...,0.511705,0.316487,0.468458,0.907858,0.562500,0.166101,0.332280,-0.349127,0.519880,0.404450
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fffb4dfad27856f24b9487f76cd9dac40210eea3de58e17f11d5776c28ca87f8,0.506066,-0.281439,0.640810,-0.535692,-0.268165,0.742327,0.345215,0.498198,-0.503286,0.475008,...,0.519946,0.542672,0.463693,0.984409,0.622174,0.228496,0.403753,-0.447912,0.534203,0.510189
fffd127b61250edd960e8751b25f372d164e5cbb89761adb3477e3947eaaad84,0.180644,-0.076154,0.391725,-0.153129,-0.260827,0.446096,0.180598,0.309986,-0.295848,0.298149,...,0.299031,0.205657,0.348035,0.634951,0.360963,0.149529,0.161550,-0.292872,0.345656,0.194827
fffd7b5a53179784d02b5a1e625322f8d36a838eba22f92c869555e6780f63b1,0.049251,-0.169683,0.326835,-0.106828,-0.111616,0.338683,0.163551,0.250411,-0.108024,0.256269,...,0.125575,0.263678,0.202163,0.820367,0.220768,0.019176,0.245175,-0.253359,0.308669,0.172629
fffe8ed2b0c1cdf0992f01cdd4d071edfa2cdf60279dcbf71729b29dfc5eae6f,0.392902,-0.045059,0.726924,-0.206968,-0.283472,0.519860,0.338938,0.247836,-0.286562,0.235624,...,0.536575,0.383584,0.429522,0.983889,0.507465,0.097660,0.377414,-0.234115,0.502277,0.363152


In [29]:
# One row per client: the client_id column plus the mean of each emb_* column.
aggregated_df_test

Unnamed: 0,client_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,00006c6ed6d81e18051751b68f9cb0d4f31d13ef0ae7fb...,0.092142,-0.150681,0.324699,-0.014067,-0.041489,0.244475,0.106149,0.151189,-0.148194,...,0.243261,0.097592,0.107901,0.668990,0.291956,0.038800,0.082139,-0.025007,0.194962,0.164325
1,00011c01bb22d8f62d9655f32d123dcca5ae55179f8266...,0.325750,-0.102299,0.584633,-0.301862,-0.321009,0.436904,0.244961,0.494768,-0.593855,...,0.524059,0.348781,0.245613,0.862081,0.567360,0.378372,0.386632,-0.228029,0.495043,0.206350
2,0001ac6446bf223a094d6514a6c890d82e9aa92104dee0...,0.313159,-0.161794,0.530575,-0.327430,-0.262520,0.493394,0.270701,0.418504,-0.436374,...,0.469505,0.352366,0.368798,0.922958,0.490719,0.262944,0.299643,-0.330986,0.513865,0.332118
3,0003304a0f65d675ddfbc0691e0c564d26a4c9e08edf67...,0.302135,-0.013960,0.471700,-0.386813,-0.267022,0.490684,0.302831,0.333394,-0.586620,...,0.569554,0.334014,0.525866,0.839566,0.429467,0.397954,0.293862,-0.274147,0.401424,0.310477
4,00037813e71deead5685649d494c9a412391942fe771e2...,0.374263,-0.101210,0.618231,-0.364309,-0.288485,0.504267,0.319840,0.259852,-0.433514,...,0.511705,0.316487,0.468458,0.907858,0.562500,0.166101,0.332280,-0.349127,0.519880,0.404450
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
81420,fffb4dfad27856f24b9487f76cd9dac40210eea3de58e1...,0.506066,-0.281439,0.640810,-0.535692,-0.268165,0.742327,0.345215,0.498198,-0.503286,...,0.519946,0.542672,0.463693,0.984409,0.622174,0.228496,0.403753,-0.447912,0.534203,0.510189
81421,fffd127b61250edd960e8751b25f372d164e5cbb89761a...,0.180644,-0.076154,0.391725,-0.153129,-0.260827,0.446096,0.180598,0.309986,-0.295848,...,0.299031,0.205657,0.348035,0.634951,0.360963,0.149529,0.161550,-0.292872,0.345656,0.194827
81422,fffd7b5a53179784d02b5a1e625322f8d36a838eba22f9...,0.049251,-0.169683,0.326835,-0.106828,-0.111616,0.338683,0.163551,0.250411,-0.108024,...,0.125575,0.263678,0.202163,0.820367,0.220768,0.019176,0.245175,-0.253359,0.308669,0.172629
81423,fffe8ed2b0c1cdf0992f01cdd4d071edfa2cdf60279dcb...,0.392902,-0.045059,0.726924,-0.206968,-0.283472,0.519860,0.338938,0.247836,-0.286562,...,0.536575,0.383584,0.429522,0.983889,0.507465,0.097660,0.377414,-0.234115,0.502277,0.363152


In [30]:
# Persist the per-client test features; the positional index carries no
# information once client_id is a column, so it is not written.
aggregated_df_test.to_parquet(
    'dial_features_test.parquet',
    index=False,
)