In [1]:
import pandas as pd
import json
from pathlib import Path
from typing import List
from functools import reduce
import numpy as np
from sklearn.metrics import f1_score
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from xgboost import XGBClassifier


In [33]:
def ensemble(true: np.array, preds: np.array, strategy='OR'):
    """Combine per-model binary predictions into one prediction per sample.

    Parameters
    ----------
    true : np.array
        Gold labels. Only used by the ``blend_*`` strategies, where it is
        flattened and used as the fitting target of the meta-classifier.
    preds : np.array
        One row per sample, one entry per base model in each row.
    strategy : str
        * ``'OR'``        - 1 if any base model predicted a truthy value.
        * ``'AND'``       - 1 only if every base model predicted a truthy value.
        * ``'blend_rf'``  - fit a ``RandomForestClassifier`` on ``preds``.
        * ``'blend_xgb'`` - fit an ``XGBClassifier`` on ``preds``.

    Returns
    -------
    tuple
        ``(final, clf)``. For ``'OR'``/``'AND'`` ``final`` is a list of 0/1
        ints and ``clf`` is ``None``. For the blend strategies ``final`` is
        whatever ``clf.predict(preds)`` returns and ``clf`` is the fitted
        meta-classifier.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of the four supported names (previously
        this silently produced an empty prediction list).

    Notes
    -----
    The blend strategies fit and predict on the same ``preds``, so the
    returned ``final`` is an in-sample prediction; scores computed from it
    are optimistic.
    """
    # Rule-based strategies need no fitting and therefore return no classifier.
    if strategy == 'OR':
        return [1 if any(row) else 0 for row in preds], None
    if strategy == 'AND':
        return [1 if all(row) else 0 for row in preds], None

    if strategy == 'blend_rf':
        clf = RandomForestClassifier(max_depth=100,
                                     n_estimators=3,
                                     random_state=0)
    elif strategy == 'blend_xgb':
        clf = XGBClassifier(n_estimators=200,
                            max_depth=1000,
                            learning_rate=0.3,
                            verbosity=1)
    else:
        raise ValueError(
            "Unknown ensemble strategy {!r}; expected one of "
            "'OR', 'AND', 'blend_rf', 'blend_xgb'".format(strategy)
        )

    # np.ravel also accepts lists / column vectors, unlike the .ravel() method.
    clf.fit(preds, np.ravel(true))
    return clf.predict(preds), clf

In [34]:
def merge_cnn_preds_with_df(cnn_preds_fp, dev, test):
    """Attach pickled CNN predictions to copies of the dev and test frames.

    ``cnn_preds_fp[0]`` is the pickle holding the dev predictions and
    ``cnn_preds_fp[1]`` the pickle holding the test predictions. Each is
    loaded and stored, cast to ``int``, in a new ``'Prediction'`` column.
    The input frames are not modified; the two augmented copies are
    returned as ``(dev_copy, test_copy)``.

    NOTE: ``pickle.load`` can execute arbitrary code - only point this at
    prediction files you produced yourself.
    """
    augmented = [dev.copy(), test.copy()]
    for position, frame in enumerate(augmented):
        with open(cnn_preds_fp[position], 'rb') as pkl_file:
            frame['Prediction'] = pickle.load(pkl_file)
        frame['Prediction'] = frame['Prediction'].astype(int)
    dev_with_preds, test_with_preds = augmented
    return dev_with_preds, test_with_preds

In [35]:
# Gold-labelled dev and test splits; both carry a `unique_id` column that the
# later cells use as the join key for the per-model predictions.
dev_csv_path = Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/train_on_trial_test_on_practice_v2/dev.csv')
test_csv_path = Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/train_on_trial_test_on_practice_v2/test.csv')

dev = pd.read_csv(dev_csv_path)
test = pd.read_csv(test_csv_path)

In [51]:
# code for transformers prediction
#
# Every transformer run stores its predictions under
#   <models dir>/<run id>/output/best_model/inference_dev/predictions.csv  (dev)
#   <models dir>/<run id>/output/best_model/inference/predictions.csv      (test)
# so the path pairs are derived from the run ids instead of being spelled out.
# Run set used in an earlier experiment: 61_1, 62_1, 57_1, 67_1, 64_1, 69_1.
FNP_MODELS_DIR = Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models')
fnp_run_ids = ['92_1', '93_1', '84_1', '85_1', '86_1', '87_1']

# One [dev_csv, test_csv] pair per run, in the order of `fnp_run_ids`.
file_paths = [
    [FNP_MODELS_DIR / run_id / 'output' / 'best_model' / 'inference_dev' / 'predictions.csv',
     FNP_MODELS_DIR / run_id / 'output' / 'best_model' / 'inference' / 'predictions.csv']
    for run_id in fnp_run_ids
]

fnp_pred_dfs_dev = []
fnp_pred_dfs_test = []
for dev_csv, test_csv in file_paths:
    fnp_pred_dfs_dev.append(pd.read_csv(dev_csv))
    fnp_pred_dfs_test.append(pd.read_csv(test_csv))

In [52]:
# code for CNN predictions
#
# Each entry is a [dev_pickle, test_pickle] pair of CNN prediction files.
# An earlier experiment used .../experiments/fincausal_all_combined/preds_val.pkl
# and .../experiments/fincausal_all_combined/preds.pkl instead.
file_paths = [['/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_train_trial_test_prac/preds_val.pkl', 
               '/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_train_trial_test_prac/preds_test.pkl']]

cnn_pred_dfs_dev = []
cnn_pred_dfs_test = []
for fp in file_paths:
    # The CNN pickles carry no ids of their own, so they are attached to
    # copies of the dev/test frames.
    cnn_pred_df_dev, cnn_pred_df_test = merge_cnn_preds_with_df(fp, dev, test)
    cnn_pred_dfs_dev.append(cnn_pred_df_dev)
    cnn_pred_dfs_test.append(cnn_pred_df_test)

In [53]:
# Collect every base model's frame (transformers first, then CNN) per split.
data_frames = fnp_pred_dfs_dev + cnn_pred_dfs_dev
data_frames_test = fnp_pred_dfs_test + cnn_pred_dfs_test

# String keys '0', '1', ... identify each base model in the merged columns.
k = np.arange(len(data_frames)).astype(str)


def _merge_on_unique_id(frames, keys):
    """Inner-join `frames` side by side on `unique_id`, naming columns '<key>_<column>'."""
    indexed_frames = [frame.set_index('unique_id') for frame in frames]
    wide = pd.concat(indexed_frames, axis=1, join='inner', keys=keys)
    wide.columns = wide.columns.map('_'.join)
    return wide


df_merged = _merge_on_unique_id(data_frames, k)
df_merged_test = _merge_on_unique_id(data_frames_test, k)

# Keep the gold labels of the first model's frame plus every model's prediction.
cols_gold = ['0_Gold'] 
cols_pred = [model_key + '_Prediction' for model_key in k]
kept_columns = cols_gold + cols_pred
df_merged = df_merged[kept_columns]
df_merged_test = df_merged_test[kept_columns]

In [54]:
# Preview the first five rows of the gold-labelled test split.
test.head(n=5)

Unnamed: 0,Index,Text,Gold,unique_id
0,1.00001,Florida raking in billions as Americans abando...,0,f3fb9fd3-58d1-4321-a877-3b313989c2a3
1,1.00002,"Recently, changes to the U.S. tax code have en...",0,cefc6481-f2c7-4f7c-8974-faa3ed310fc4
2,1.00003,"MORE FROM FOXBUSINESS.COM... As it turns out, ...",0,269f9c6b-0552-460b-8248-f1318efa6435
3,1.00004,"According to a new study from LendingTree, whi...",0,2269c204-d93b-44a6-93f3-e510f71bc7c0
4,1.00005,The Sunshine State drew in a net influx of abo...,1,7ed4375d-1e1a-4ac6-a3bf-4dccdc69b497


In [55]:
# Preview the merged dev frame: gold label plus one prediction column per model.
df_merged.head(n=5)

Unnamed: 0_level_0,0_Gold,0_Prediction,1_Prediction,2_Prediction,3_Prediction,4_Prediction,5_Prediction,6_Prediction
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
6d85e627-d23c-4b6c-9021-e412ef2027ac,0,0,0,0,0,0,0,0
5bcbe512-80b4-4f76-9d52-211e5acf7a0e,0,0,0,0,0,0,0,0
76b7e135-55e5-43fb-b2b0-8252de67ebbd,0,0,0,0,0,0,0,0
505b0aa1-edd5-4dab-9c23-e4d15206edaa,0,0,0,0,0,0,0,0
c5c1deb0-1d50-4de8-8c65-8056434c701e,0,0,0,0,0,0,0,0


In [56]:
# Preview the merged test frame: gold label plus one prediction column per model.
df_merged_test.head(n=5)

Unnamed: 0_level_0,0_Gold,0_Prediction,1_Prediction,2_Prediction,3_Prediction,4_Prediction,5_Prediction,6_Prediction
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
f3fb9fd3-58d1-4321-a877-3b313989c2a3,0,0,0,0,0,0,0,0
cefc6481-f2c7-4f7c-8974-faa3ed310fc4,0,0,0,0,0,0,0,0
269f9c6b-0552-460b-8248-f1318efa6435,0,0,0,0,0,0,0,0
2269c204-d93b-44a6-93f3-e510f71bc7c0,0,0,0,0,0,0,0,0
7ed4375d-1e1a-4ac6-a3bf-4dccdc69b497,1,0,0,1,0,0,1,1


In [57]:
# Stack the base models with an XGBoost meta-classifier fitted on the dev split.
# NOTE: `ensemble` fits and predicts on the same dev rows, so the 'val' score
# below is an in-sample score; only the 'test' score is on unseen rows.
dev_gold = df_merged[cols_gold].values
dev_base_preds = df_merged[cols_pred].values
df_merged['Prediction'], meta_clf = ensemble(true=dev_gold,
                                             preds=dev_base_preds,
                                             strategy='blend_xgb')

# Apply the fitted meta-classifier to the test split's base predictions.
test_base_preds = df_merged_test[cols_pred].values
df_merged_test['Prediction'] = meta_clf.predict(test_base_preds)

for split_label, split_frame in (('val: ', df_merged), ('test: ', df_merged_test)):
    print(split_label, f1_score(split_frame['0_Gold'].tolist(), split_frame['Prediction'], average='weighted'))

val:  0.9627388700408501
test:  0.9443319036049485


In [43]:
# Rule-based baseline: predict 1 whenever any base model predicts 1.
# `ensemble` returns None as its classifier for the 'OR' strategy, so that
# second return value is bound to a throwaway name. Binding it to `meta_clf`
# (as before) replaced a previously fitted meta-classifier with None and made
# the later `meta_clf.feature_importances_` cell fail.
df_merged['Prediction'], _or_clf_dev = ensemble(true=df_merged[cols_gold].values,
                                                preds=df_merged[cols_pred].values,
                                                strategy='OR')
df_merged_test['Prediction'], _or_clf_test = ensemble(true=df_merged_test[cols_gold].values,
                                                      preds=df_merged_test[cols_pred].values,
                                                      strategy='OR')

print('val: ', f1_score(df_merged['0_Gold'].tolist(), df_merged['Prediction'], average='weighted'))
print('test: ', f1_score(df_merged_test['0_Gold'].tolist(), df_merged_test['Prediction'], average='weighted'))

val:  0.9020318468745557
test:  0.9241445688396995


In [236]:
# Importance the fitted meta-classifier assigns to each base model's prediction
# column. Requires `meta_clf` to come from a 'blend_*' strategy; `ensemble`
# returns None as the classifier for 'OR'/'AND'.
# (The stray trailing comma that wrapped the array in a 1-tuple was removed.)
meta_clf.feature_importances_

(array([0.02110144, 0.6654096 , 0.3067513 , 0.00673764], dtype=float32),)

In [361]:
# generate final vals and tests
ensembled_dev = dev.set_index('unique_id').join(df_merged[['Prediction']])
# print(f1_score(ensembled_dev['Gold'].tolist(), ensembled_dev['Prediction'], average='weighted'))
ensembled_test = test.set_index('unique_id').join(df_merged_test[['Prediction']])
# print(f1_score(ensembled_test['Gold'].tolist(), ensembled_test['Prediction'], average='weighted'))

In [363]:
# Persist the ensembled predictions; `unique_id` is moved from the index back
# into a regular column so it is written even though index=False.
dev_out_csv = '/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/74_1/output/best_model/inference_dev/predictions.csv'
test_out_csv = '/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/74_1/output/best_model/inference/predictions.csv'

for ensembled_frame, out_csv in ((ensembled_dev, dev_out_csv), (ensembled_test, test_out_csv)):
    ensembled_frame.reset_index(level=0).to_csv(out_csv, index=False)

# Generate Prediction

In [163]:
# Load the full practice set that the final predictions are written for.
predict_df_path = Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/practice/all.csv')
predict_df = pd.read_csv(predict_df_path)

rows_and_cols = predict_df.shape
print('shape of data is: {}'.format(rows_and_cols))
predict_df.head(n=5)

shape of data is: (13547, 4)


Unnamed: 0,Index,Text,Gold,unique_id
0,1.00001,Florida raking in billions as Americans aband...,0,94bdc788-8aaa-49c1-9473-8222392e0c6e
1,1.00002,"Recently, changes to the U.S. tax code have e...",0,0af1d3f6-6aa9-461c-b921-0f8b90c83328
2,1.00003,"MORE FROM FOXBUSINESS.COM... As it turns out,...",0,5166917b-fe82-4e82-ba73-c67db70d957b
3,1.00004,"According to a new study from LendingTree, wh...",0,a52522d7-e25e-4b43-a510-44446d500443
4,1.00005,The Sunshine State drew in a net influx of ab...,1,517e2ae2-7d16-40af-a2a7-eb567841f9f8


In [165]:
# Store the ensemble output as the 'Prediction' column of the practice frame.
# NOTE(review): `merged` is not defined in any cell visible in this notebook,
# so this line raises NameError on a fresh kernel. It presumably came from a
# since-deleted cell; confirm which frame supplies 'ensemble_preds' and that
# its row order matches `predict_df` (values are assigned by position).
predict_df['Prediction'] = merged['ensemble_preds'].tolist()

In [166]:
# Display the practice frame, now including the 'Prediction' column.
predict_df

Unnamed: 0,Index,Text,Gold,unique_id,Prediction
0,1.00001,Florida raking in billions as Americans aband...,0,94bdc788-8aaa-49c1-9473-8222392e0c6e,0
1,1.00002,"Recently, changes to the U.S. tax code have e...",0,0af1d3f6-6aa9-461c-b921-0f8b90c83328,0
2,1.00003,"MORE FROM FOXBUSINESS.COM... As it turns out,...",0,5166917b-fe82-4e82-ba73-c67db70d957b,0
3,1.00004,"According to a new study from LendingTree, wh...",0,a52522d7-e25e-4b43-a510-44446d500443,0
4,1.00005,The Sunshine State drew in a net influx of ab...,1,517e2ae2-7d16-40af-a2a7-eb567841f9f8,1
...,...,...,...,...,...
13542,590.00061,Contact transpo@gmu.edu with questions.,0,0294a30b-c9a5-4f89-ad3e-1188daa17543,0
13543,590.00062,Campus Fire Safety Month September is Campus...,0,77c8d28b-f1cf-4d04-b1fe-5ce06408d212,0
13544,590.00063,"Review the university's Fire Safety Plan, whi...",0,2e56e743-d171-43fd-a18c-1b4a4eca2aae,0
13545,590.00064,Contact Meredith Muckerman at 703-993-9715 or...,0,c1357ffa-d742-4f14-9f29-1d76808ba3c1,0


In [181]:
# Write the submission file: only Index, Text and Prediction, ';'-separated.
submission_csv = Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/37/output/best_model/inference/predictions.csv')
submission_columns = ['Index', 'Text', 'Prediction']
predict_df[submission_columns].to_csv(submission_csv, index=False, sep=';')