#  Вступление 

В этом домашнем задании вы поработаете с датасетом MBD [ссылка](https://huggingface.co/datasets/ai-lab/MBD-mini), описание датасета дано по ссылке. Вашей задачей будет предсказать метки классов, которые соответствуют транзакционной активности клиентов. Вам доступны 3 датасета признаков: геоданные, данные о транзакциях, данные о предыдущих покупках и диалогах (эмбеддинги диалогов). Вам нужно будет предсказать метки для одного выбранного дня, признаки для тестового датасета доступны в `test.csv`.

Ниже в разделе `Обучение` приведен бейзлайн, где задействован один датасет с эмбеддингами, для лучших результатов рекомендуется скачать и использовать часть данных из оригинального датасета по ссылке, а также учесть признаки с разных дат, однако учтите, что не для всех клиентов доступны все 3 категории признаков. Вам придется придумать, как грамотно заполнить пропуски.
Вы можете использовать любую библиотеку и дополнять train выборку данными из полного датасета.

In [1]:
from huggingface_hub import hf_hub_download
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import train_test_split

from catboost import CatBoostClassifier, Pool
# раскомментируйте, чтобы скачать полный датасет; он большой, будьте готовы подождать
#hf_hub_download(repo_id="ai-lab/MBD-mini", filename="detail.tar.gz", repo_type="dataset", local_dir='data')
#hf_hub_download(repo_id="ai-lab/MBD-mini", filename="targets.tar.gz", repo_type="dataset", local_dir='data')
#hf_hub_download(repo_id="ai-lab/MBD-mini", filename="ptls.tar.gz", repo_type="dataset", local_dir='data')


In [2]:
# Раскомментируйте, чтобы разархивировать датасеты в директории
#!tar -xvzf  data/detail.tar.gz 
#!tar -xvzf  data/ptls.tar.gz 
#!tar -xvzf  data/targets.tar.gz 

## Загрузка датасета
В этом разделе приведен пример сэмплирования данных из датасета, если у вас не хватает оперативной памяти, можете модернизировать код и скачать его полностью.  

In [2]:
def load_data_from_directory(directory_path):
    """Read every non-empty file of a directory as parquet and stack the rows.

    Zero-byte entries (for example the ``_SUCCESS`` marker found in the
    dataset folders) are reported on stdout and skipped.

    :param directory_path: directory whose entries are read with ``pd.read_parquet``
    :return: a single DataFrame with a fresh 0..n-1 index; an empty DataFrame
        when the directory contains no non-empty files
    """
    df_list = []
    # os.listdir() yields entries in arbitrary order; sorting keeps the row
    # order of the concatenated result reproducible between runs.
    for name in sorted(os.listdir(directory_path)):
        file = os.path.join(directory_path, name)
        if os.path.getsize(file) > 0:  # a zero-byte file cannot be parsed as parquet
            df_list.append(pd.read_parquet(file))
        else:
            print(f"Skipping empty file: {file}")

    if not df_list:
        # pd.concat raises ValueError when given an empty list.
        return pd.DataFrame()
    return pd.concat(df_list, ignore_index=True)


In [4]:
# Load the geo table and move its timestamps to a source-specific column name
# (the column is re-added at the end, so the resulting column order is unchanged).
geo_data = load_data_from_directory('ptls/geo/')
geo_data['event_time_geo'] = geo_data.pop('event_time')

Skipping empty file: ptls/geo/_SUCCESS


In [3]:
# dialog_data is merged into merged_data in a later cell, so it must be loaded
# here; leaving this line commented out makes that merge fail with a NameError.
dialog_data = load_data_from_directory('ptls/dialog/')
trx_data = load_data_from_directory('ptls/trx/')


Skipping empty file: ptls/trx/_SUCCESS


In [4]:
trx_data.shape

(98721, 14)

In [7]:
trx_data.iloc[0]

client_id        145c6b726a2d62545263742c78bcd6082a46092a1403cf...
event_time       [1609492496, 1609503804, 1609587358, 160975868...
amount           [66605.52, 30236.105, 33207.33, 30632.889, 819...
event_type       [2, 2, 2, 2, 4, 4, 2, 2, 2, 1, 4, 2, 2, 2, 4, ...
event_subtype    [2, 2, 2, 2, 4, 4, 2, 2, 2, 1, 4, 2, 2, 2, 4, ...
currency         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
src_type11       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
src_type12       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
dst_type11       [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, ...
dst_type12       [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, ...
src_type21       [170, 170, 170, 170, 170, 170, 170, 170, 170, ...
src_type22       [45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 4...
src_type31       [144, 144, 144, 144, 144, 144, 144, 144, 144, ...
src_type32       [54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 5...
Name: 0, dtype: object

In [8]:
# Keep the transaction timestamps under a source-specific name so they stay
# distinguishable from the event_time columns of the other tables after the
# merge on client_id.
trx_data['event_time_trx'] = trx_data['event_time']

In [9]:
trx_data.iloc[0]

client_id         145c6b726a2d62545263742c78bcd6082a46092a1403cf...
event_time        [1609492496, 1609503804, 1609587358, 160975868...
amount            [66605.52, 30236.105, 33207.33, 30632.889, 819...
event_type        [2, 2, 2, 2, 4, 4, 2, 2, 2, 1, 4, 2, 2, 2, 4, ...
event_subtype     [2, 2, 2, 2, 4, 4, 2, 2, 2, 1, 4, 2, 2, 2, 4, ...
currency          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
src_type11        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
src_type12        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
dst_type11        [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, ...
dst_type12        [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, ...
src_type21        [170, 170, 170, 170, 170, 170, 170, 170, 170, ...
src_type22        [45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 4...
src_type31        [144, 144, 144, 144, 144, 144, 144, 144, 144, ...
src_type32        [54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 5...
event_time_trx    [1609492496, 1609503804, 16095

In [10]:
# The timestamps were copied to event_time_trx, so the generic column is removed in place.
trx_data.drop(columns=['event_time'], inplace=True)
# Load the target table.
targets = load_data_from_directory('targets/')

Skipping empty file: targets/_SUCCESS


In [11]:
# Outer-join the three feature tables on client_id: a client present in any
# table gets a row, with missing values where a table has no record for them.
merged_data = (
    geo_data
    .merge(dialog_data, on='client_id', how='outer')
    .merge(trx_data, on='client_id', how='outer')
)

In [20]:
merged_data.iloc[0]
merged_data.shape

(99756, 20)

In [13]:
# Restrict both tables to the clients they have in common:
# feature rows that have a target, then target rows that have features.
has_target = merged_data['client_id'].isin(targets['client_id'])
exclusive_merged_data = merged_data.loc[has_target]
has_features = targets['client_id'].isin(exclusive_merged_data['client_id'])
exclusive_targets = targets.loc[has_features]

In [14]:
exclusive_targets.iloc[0]

client_id          01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...
mon                                                       2022-02-28
target_1                                                           0
target_2                                                           0
target_3                                                           0
target_4                                                           0
trans_count                                                        1
diff_trans_date                                                  6.0
Name: 0, dtype: object

In [15]:
# Sanity check: the number of distinct client ids must be the same in the
# feature table and in the target table.
n_feature_clients = exclusive_merged_data['client_id'].nunique(dropna=False)
n_target_clients = exclusive_targets['client_id'].nunique(dropna=False)
assert n_feature_clients == n_target_clients

In [17]:
exclusive_merged_data.shape

(99756, 20)

In [18]:
len(exclusive_merged_data['client_id'].unique())

99756

In [16]:
exclusive_targets.shape

(1197072, 8)

# Ниже приведен пример с учетом диалогов 

Берем только колонку с диалогами и составляем обучающую выборку

In [8]:
exclusive_merged_data.head()

Unnamed: 0,client_id,geohash_4,geohash_5,geohash_6,event_time_geo,event_time,embedding,amount,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,event_time_trx
0,000032cc38caee45fe031778bcf6af05aa2aabe476acb8...,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[312, 312, 2, 312, 312, 2, 312, 312, 2, 312, 3...","[1018, 1018, 106, 1018, 1018, 106, 1018, 1018,...","[1641574982, 1641582278, 1641701700, 164176617...","[1656940377, 1657195278, 1657356462, 165796531...","[[0.5291204, -0.3842432, 0.49353755, -0.520477...","[6.668097, 737.65106, 1.3568845, 181.28973, 74...","[3, 20, 3, 14, 3, 1, 12, 3, 14, 3, 23, 12, 1, ...","[3, 21, 3, 13, 3, 1, 10, 3, 13, 3, 28, 10, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 3, 5, 1, 1, 3, 1, 5, 1, 12, 3, 1, 1, 3,...","[1, 1, 11, 10, 1, 1, 16, 1, 10, 1, 26, 16, 1, ...","[6, 4, 3, 8, 6, 1, 3, 3, 8, 6, 6, 3, 1, 2, 3, ...","[7, 5, 3, 9, 7, 1, 3, 3, 9, 15, 15, 3, 1, 4, 3...","[297, 297, 297, 297, 297, 297, 297, 297, 297, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[148, 148, 148, 148, 148, 148, 148, 148, 148, ...","[36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 3...","[1631789510, 1636729704, 1639685826, 164224148..."
1,00005e39da5fb5968996cd49130281c8bd074a42e0e17a...,,,,,[1652963906],"[[0.03743099, -0.11977149, 0.3999211, 0.105271...","[36974.793, 35930.902, 41079.63, 39017.137, 40...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4]","[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 42]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5]","[10001, 10001, 10001, 10001, 10001, 10001, 100...","[30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 3...","[69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 6...","[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 2...","[1660119344, 1661411255, 1662020112, 166227544..."
2,00012766278a7b25ddd3d6de329a44dec4c278d5807fae...,,,,,,,"[1138885.9, 2362.8396, 502637.16, 376197.97, 1...","[1, 3, 1, 1, 1, 3, 14, 23, 4, 15, 3, 3, 1, 3, ...","[1, 3, 1, 1, 1, 3, 13, 3, 4, 15, 3, 3, 1, 3, 4...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 5, 12, 1, 1, 1, 1, 1, 1, 1,...","[1, 1, 1, 1, 1, 1, 10, 26, 1, 1, 1, 1, 1, 1, 1...","[2, 3, 2, 2, 2, 3, 8, 6, 5, 1, 3, 6, 2, 3, 5, ...","[4, 3, 4, 4, 2, 3, 9, 15, 6, 1, 3, 15, 2, 3, 6...","[320, 320, 320, 320, 320, 320, 320, 320, 320, ...","[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 1...","[85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 8...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[1662296374, 1662303399, 1662555762, 166280000..."
3,0001edbc5ab720f70a615ed9e8429df9b6c3f3c3999a51...,"[504, 504, 504, 504, 504, 504, 504, 504, 504, ...","[3387, 7621, 3387, 3387, 2277, 2277, 2277, 338...","[10001, 10001, 10001, 10001, 10001, 10001, 100...","[1641630233, 1641709243, 1641795445, 164191149...",,,"[120123.3, 4095.0522, 4050.3926, 51119.47, 347...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[7907, 7907, 7907, 7907, 7907, 7907, 7907, 790...","[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 1...","[111, 111, 111, 111, 111, 111, 111, 111, 111, ...","[25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 2...","[1611730958, 1611814603, 1612068485, 161224435..."
4,0002ddd816198d32474486d54f4bfe4f7b361119b5dc45...,"[60, 149, 149, 149, 149, 60, 149, 149, 149, 14...","[839, 1065, 1065, 1065, 1377, 839, 1065, 1377,...","[10001, 3659, 3659, 3659, 3926, 10001, 3659, 3...","[1643724003, 1643804335, 1643872253, 164405751...",,,"[14360.89, 112353.516, 791.74634, 1094.8862, 4...","[2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, ...","[2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 1...","[2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, ...","[2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, ...","[215, 215, 215, 215, 215, 215, 215, 215, 215, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[286, 286, 286, 286, 286, 286, 286, 286, 286, ...","[59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 5...","[1609475093, 1609754967, 1609756326, 160983599..."


In [9]:
exclusive_targets.head()

Unnamed: 0,client_id,mon,target_1,target_2,target_3,target_4,trans_count,diff_trans_date
0,01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...,2022-02-28,0,0,0,0,1,6.0
1,01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...,2022-03-31,0,0,0,0,1,34.0
2,01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...,2022-04-30,0,0,0,0,3,16.0
3,01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...,2022-05-31,0,0,0,0,4,26.0
4,01392f49a65f10a6ed7afb0ab7f9405148d9f1ae5eb61b...,2022-06-30,0,0,0,0,6,5.0


In [10]:
# Select the training sample: for every client take the smallest `mon` value
# and pull the matching target row(s) back in.
df_unique_date = exclusive_targets.groupby('client_id', as_index=False)['mon'].min()
df_result = df_unique_date.merge(exclusive_targets, on=['client_id', 'mon'], how='left')

In [11]:
df_result.head()

Unnamed: 0,client_id,mon,target_1,target_2,target_3,target_4,trans_count,diff_trans_date
0,000032cc38caee45fe031778bcf6af05aa2aabe476acb8...,2022-02-28,0,0,0,0,0,
1,00005e39da5fb5968996cd49130281c8bd074a42e0e17a...,2022-02-28,0,0,0,0,0,
2,00012766278a7b25ddd3d6de329a44dec4c278d5807fae...,2022-02-28,0,0,0,0,0,
3,0001edbc5ab720f70a615ed9e8429df9b6c3f3c3999a51...,2022-02-28,0,0,0,0,7,0.0
4,0002ddd816198d32474486d54f4bfe4f7b361119b5dc45...,2022-02-28,0,0,0,0,47,0.0


In [12]:
# Attach the selected target row to each client's features (inner join on client_id).
data = exclusive_merged_data.merge(df_result, on='client_id', how='inner')

In [14]:
data.shape

(99756, 27)

# Обучение
В этом разделе приведен бейзлайн с доступными `train.csv` данными, вы можете использовать дополнительные данные из оригинального датасета.

In [3]:
# Load the prepared training table provided with the assignment.
df_filtered = pd.read_csv('train.csv')

In [4]:
df_filtered.head(10)

Unnamed: 0,client_id,geohash_4,geohash_5,geohash_6,event_time_geo,event_time,embedding,amount,event_type,event_subtype,...,src_type31,src_type32,event_time_trx,mon,target_1,target_2,target_3,target_4,trans_count,diff_trans_date
0,000032cc38caee45fe031778bcf6af05aa2aabe476acb8...,[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3...,[312 312 2 312 312 2 312 312 2 312 312 3...,[ 1018 1018 106 1018 1018 106 1018 10...,[1641574982 1641582278 1641701700 1641766170 1...,[1656940377 1657195278 1657356462 1657965315 1...,[ 0.5291204 -0.3842432 0.49353755 -0.520477...,[6.66809702e+00 7.37651062e+02 1.35688448e+00 ...,[ 3 20 3 14 3 1 12 3 14 3 23 12 1 1 12 ...,[ 3 21 3 13 3 1 10 3 13 3 28 10 1 1 10 ...,...,[148 148 148 148 148 148 148 148 148 148 148 1...,[36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 ...,[1631789510 1636729704 1639685826 1642241482 1...,2022-02-28,0,0,0,0,0,
1,00005e39da5fb5968996cd49130281c8bd074a42e0e17a...,,,,,[1652963906],[ 3.74309905e-02 -1.19771488e-01 3.99921089e-...,[36974.793 35930.902 41079.63 39017.137...,[2 2 2 2 2 2 2 2 2 2 2 2 7],[2 2 2 2 2 2 2 2 2 2 2 2 7],...,[69 69 69 69 69 69 69 69 69 69 69 69 69],[24 24 24 24 24 24 24 24 24 24 24 24 24],[1660119344 1661411255 1662020112 1662275446 1...,2022-02-28,0,0,0,0,0,
2,00012766278a7b25ddd3d6de329a44dec4c278d5807fae...,,,,,,,[1.13888588e+06 2.36283960e+03 5.02637156e+05 ...,[ 1 3 1 1 1 3 14 23 4 15 3 3 1 3 4 ...,[ 1 3 1 1 1 3 13 3 4 15 3 3 1 3 4 ...,...,[85 85 85 85 85 85 85 85 85 85 85 85 85 85 85 ...,[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 ...,[1662296374 1662303399 1662555762 1662800003 1...,2022-02-28,0,0,0,0,0,
3,0001edbc5ab720f70a615ed9e8429df9b6c3f3c3999a51...,[ 504 504 504 504 504 504 504 504 504 ...,[ 3387 7621 3387 3387 2277 2277 2277 33...,[10001 10001 10001 10001 10001 10001 10001 100...,[1641630233 1641709243 1641795445 1641911490 1...,,,[1.20123297e+05 4.09505225e+03 4.05039258e+03 ...,[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2...,[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2...,...,[111 111 111 111 111 111 111 111 111 111 111 1...,[25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 ...,[1611730958 1611814603 1612068485 1612244354 1...,2022-02-28,0,0,0,0,7,0.0
4,0002ddd816198d32474486d54f4bfe4f7b361119b5dc45...,[ 60 149 149 149 149 60 149 149 149 ...,[ 839 1065 1065 1065 1377 839 1065 13...,[10001 3659 3659 3659 3926 10001 3659 39...,[1643724003 1643804335 1643872253 1644057515 1...,,,[1.43608896e+04 1.12353516e+05 7.91746338e+02 ...,[2 1 1 1 1 2 2 1 1 1 1 1 2 1 1 1 1 2 1 1 2 2 2...,[2 1 1 1 1 2 2 1 1 1 1 1 2 1 1 1 1 2 1 1 2 2 2...,...,[286 286 286 286 286 286 286 286 286 286 286 2...,[59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 ...,[1609475093 1609754967 1609756326 1609835992 1...,2022-02-28,0,0,0,0,47,0.0
5,0003d9e89a9a18c15845518e203080761a49de1367b265...,[ 129 2647 129 129 129 129 129 2647 129 ...,[ 1268 10001 361 5487 361 2570 798 100...,[10001 10001 3027 10001 6624 10001 10001 100...,[1640239585 1641098116 1641207506 1641213402 1...,,,[3.67492773e+04 2.35645137e+04 8.96270020e+03 ...,[ 1 1 1 1 1 1 2 1 1 1 1 7 1 2 2 ...,[ 1 1 1 1 1 1 2 1 1 1 1 1 1 2 2 ...,...,[19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 ...,[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6...,[1609486696 1609576897 1609594877 1609842779 1...,2022-02-28,0,0,0,0,16,0.0
6,00041153285f5849bf5014a9b4aa7bcc8912ecf51ea39c...,[ 3 3 3 3 3 3 349 3 3 3 349 3...,[ 96 96 96 96 96 96 5096 96 96 ...,[ 2686 4827 4827 1526 4827 4827 10001 24...,[1650231715 1650295630 1651993397 1652377025 1...,[1615613357],[ 0.36896434 -0.1924966 0.47202018 -0.311184...,[6.06513125e+05 2.32826312e+05 2.85645625e+05 ...,[ 1 1 1 1 4 3 1 1 1 1 1 30 3 1 3 ...,[ 1 1 1 1 4 3 1 1 1 1 1 32 3 1 3 ...,...,[223 223 223 223 223 223 223 223 223 223 223 2...,[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2...,[1614673598 1615460709 1615462333 1615722044 1...,2022-02-28,0,0,0,0,9,6.0
7,0004b2962fc621e5de79f92043e48e3815998158c16f9e...,,,,,[1624423185 1624950735 1625205517],[ 0.5406315 -0.2105584 0.51350933 -0.564700...,[5.85182578e+04 6.78135781e+04 3.72411670e+03 ...,[25 1 1 3 4 1 1 1 4 25 1 25 1 25 1 ...,[24 1 1 1 4 1 1 1 4 24 1 24 1 24 1 ...,...,[90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 ...,[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6...,[1609480509 1609496644 1609578618 1609609622 1...,2022-02-28,0,0,0,0,7,3.0
8,0004dc944b7ab7d62e5732ce56d08a4ed2b0171283ad19...,[1919 425 425 425 1726 1919 1919 1726 1726 ...,[10001 1970 1970 1970 10001 10001 10001 67...,[10001 10001 10001 10001 10001 10001 10001 100...,[1639812075 1639823841 1639826451 1639904049 1...,,,[3.3061306e+05 2.0242895e+04 2.7054509e+03 6.1...,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,...,[20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 ...,[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5...,[1610193040 1611061940 1612777120 1616670407 1...,2022-02-28,0,0,0,0,2,5.0
9,000700766e883532e040c0109ec12d3f64109b599dd364...,[ 12 212 212 ... 12 143 212],[ 244 2001 2001 ... 466 813 2001],[ 1625 10001 10001 ... 4258 3808 10001],[1639477458 1639653352 1639731068 ... 16724595...,,,[4.87791758e+04 3.64662656e+04 6.09522562e+05 ...,[ 1 2 1 1 1 1 2 1 2 1 1 1 11 1 1 ...,[ 1 2 1 1 1 1 2 1 2 1 1 1 12 1 1 ...,...,[86 86 86 86 86 86 86 86 86 86 86 86 86 86 86 ...,[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 ...,[1609736069 1610098321 1610256361 1610574549 1...,2022-02-28,0,0,0,0,9,6.0


In [5]:
def string_to_array(string):
    """Parse the text form of a flat numeric array into a NumPy array.

    The input is expected to look like a printed array: numbers separated by
    whitespace, wrapped in square brackets and possibly spread over several
    lines.

    :param string: text to parse, or a missing value (NaN/None)
    :return: 1-D float array with the parsed numbers, or ``np.nan`` when the
        input is missing
    """
    if pd.isna(string):
        return np.nan  # missing values are passed through unchanged
    # A line break is a separator just like a space. Deleting it outright would
    # glue the last number of one line to the first number of the next one.
    clean_str = string.replace('[', '').replace(']', '').replace('\n', ' ')
    return np.fromstring(clean_str, sep=' ')

# Convert the embedding column from its text form into NumPy arrays
df_filtered['embedding'] = df_filtered['embedding'].map(string_to_array)



In [6]:
# Keep the client id, the four targets and the embedding, then discard every
# row that has a missing value in any of these columns.
baseline_columns = ['client_id', 'target_1', 'target_2', 'target_3', 'target_4', 'embedding']
new_data = df_filtered.loc[:, baseline_columns].dropna()

In [7]:
import pandas as pd

def add_all_embedding_components(data):
    """
    Replace the 'embedding' column by one column per embedding component.

    Component ``i`` of every embedding ends up in a column named
    ``embedding_i``. The input frame is left untouched; a new frame is returned.

    :param data: DataFrame with an 'embedding' column holding one vector per row
    :return: copy of ``data`` without 'embedding', followed by the component columns
    """
    component_frame = pd.DataFrame(data['embedding'].tolist(), index=data.index)
    component_frame.columns = [f'embedding_{position}' for position in component_frame.columns]

    widened = pd.concat([data, component_frame], axis=1)
    return widened.drop(columns=['embedding'])


# Add every embedding component as a separate column
new_data = add_all_embedding_components(new_data)




In [8]:
# 80/20 split, stratified on the joint value of target_1, target_3 and target_4
# (target_2 is not part of the stratification key).
stratify_labels = new_data[['target_1', 'target_3', 'target_4']]
train, val = train_test_split(
    new_data,
    test_size=0.2,
    stratify=stratify_labels,
    random_state=42,
)


In [9]:
train.head(2)

Unnamed: 0,client_id,target_1,target_2,target_3,target_4,embedding_0,embedding_1,embedding_2,embedding_3,embedding_4,...,embedding_758,embedding_759,embedding_760,embedding_761,embedding_762,embedding_763,embedding_764,embedding_765,embedding_766,embedding_767
65234,c2a4dd076863ad647723184ae125304d66f305f2380b27...,0,0,0,0,0.039867,-0.071968,0.363047,0.121511,-0.129254,...,0.06733,0.178045,0.130703,0.665849,0.328491,0.166453,0.121372,-0.128048,0.250428,0.119112
10571,1fe8e07c39916c584b1975b8893ecb8eef433995a8ef38...,0,0,0,0,0.283553,-0.354141,0.439486,-0.272576,-0.212807,...,0.414412,0.302377,0.522225,0.946563,0.598065,0.245536,0.281303,-0.484948,0.407955,0.267245


In [11]:
# Identifier and label columns are not model inputs; everything else is a feature.
target_columns = ['target_1', 'target_2', 'target_3', 'target_4']
non_feature_columns = ['client_id'] + target_columns

train_client_id, val_client_id = train['client_id'], val['client_id']
X_train, X_val = (part.drop(columns=non_feature_columns) for part in (train, val))
y_train, y_val = train[target_columns], val[target_columns]

train_pool = Pool(X_train, y_train)
# Despite its name, test_pool wraps the validation split.
test_pool = Pool(X_val, y_val)

In [16]:
# List every feature column that is not an embedding component
# (nothing is printed when all features are embedding columns).
non_embedding_columns = [name for name in X_train.columns if 'embed' not in name]
for a in non_embedding_columns:
    print(a)

In [None]:
# Baseline classifier configuration: one model for the four target columns.
# NOTE(review): the training log recorded in this notebook reports
# bestIteration = 0 and the validation predictions are all zeros; the choice
# of eval_metric and iterations is worth revisiting.
catboost_params = {
    'loss_function': 'MultiLogloss',
    'eval_metric': 'HammingLoss',
    'iterations': 100,
    'class_names': ['target_1', 'target_2', 'target_3', 'target_4'],
}
clf = CatBoostClassifier(**catboost_params)

In [None]:
# Fit on the training pool; test_pool (the validation split) is passed as the evaluation set.
clf.fit(train_pool, eval_set=test_pool, metric_period=10, verbose=20)

Learning rate set to 0.192437
0:	learn: 0.0046832	test: 0.0046635	best: 0.0046635 (0)	total: 182ms	remaining: 18s
20:	learn: 0.0046832	test: 0.0046635	best: 0.0046635 (0)	total: 2.41s	remaining: 9.06s
40:	learn: 0.0046638	test: 0.0046635	best: 0.0046635 (0)	total: 4.53s	remaining: 6.51s
60:	learn: 0.0046638	test: 0.0046635	best: 0.0046635 (0)	total: 6.73s	remaining: 4.3s
80:	learn: 0.0046444	test: 0.0046635	best: 0.0046635 (0)	total: 8.88s	remaining: 2.08s
99:	learn: 0.0046347	test: 0.0046635	best: 0.0046635 (0)	total: 10.9s	remaining: 0us

bestTest = 0.004663454065
bestIteration = 0

Shrink model to first 1 iterations.


<catboost.core.CatBoostClassifier at 0x74ce20f45f00>

In [None]:
clf.predict(X_val)

array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       ...,
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])

# Пример сабмита

Сабмит представляет собой датафрейм с 4 колонками: 'target_1', 'target_2', 'target_3', 'target_4'. Не меняйте порядок айди клиентов в тестовой выборке.

In [None]:
test = pd.read_csv('test.csv')
test['embedding'] = test['embedding'].apply(string_to_array)
# The classifier was fitted on X_train, i.e. one numeric column per embedding
# component. The test embeddings must be expanded into the same columns; a
# single column holding whole arrays does not match the training features.
# Rows without an embedding (if any) become all-NaN feature rows rather than
# being dropped, so the row count and order of the submission follow test.csv.
n_components = X_train.shape[1]
test_vectors = [
    vec if isinstance(vec, np.ndarray) else np.full(n_components, np.nan)
    for vec in test['embedding']
]
X_test = pd.DataFrame(np.vstack(test_vectors), columns=X_train.columns, index=test.index)
y_test = clf.predict(X_test)

In [None]:
# One column per target; rows stay in the order of the predictions.
submission_columns = ['target_1', 'target_2', 'target_3', 'target_4']
ans = pd.DataFrame(data=y_test, columns=submission_columns)

In [None]:
# Write the submission file; index=False keeps only the four target columns.
ans.to_csv('submit.csv', index=False)