In [1]:
import pandas as pd
import json
from pathlib import Path
from typing import List
from functools import reduce
import numpy as np
from sklearn.metrics import f1_score
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from xgboost import XGBClassifier


In [182]:
def ensemble(true: np.array, preds: np.array, strategy='OR'):
    """Combine the binary predictions of several models into one label per sample.

    Parameters
    ----------
    true : np.array
        Gold labels, one per row of ``preds``. Only read by the ``blend_*``
        strategies, where it is flattened with ``.ravel()`` and used as the
        training target of the meta-classifier.
    preds : np.array
        2-D array with one row per sample and one column per base model.
    strategy : str
        ``'OR'``        -> 1 if any model predicted a truthy value for the row.
        ``'AND'``       -> 1 only if every model predicted a truthy value.
        ``'blend_rf'``  -> fit a RandomForestClassifier on ``preds``.
        ``'blend_xgb'`` -> fit an XGBClassifier on ``preds``.

    Returns
    -------
    (final, clf)
        ``final`` holds the combined prediction per row: a list of ints for
        'OR'/'AND', or whatever ``clf.predict`` returns for the blend
        strategies. ``clf`` is the fitted meta-classifier, or ``None`` for
        'OR'/'AND'.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of the supported names. Previously this
        case silently returned ``([], None)``.

    Notes
    -----
    For the blend strategies the meta-classifier is fitted and then asked to
    predict on the very same ``preds``, so the returned ``final`` is an
    in-sample prediction.
    """
    # Rule-based strategies need no training and return no classifier.
    if strategy == 'OR':
        return [1 if any(row) else 0 for row in preds], None
    if strategy == 'AND':
        return [1 if all(row) else 0 for row in preds], None

    if strategy == 'blend_rf':
        clf = RandomForestClassifier(max_depth=100,
                                     n_estimators=100,
                                     random_state=0)
    elif strategy == 'blend_xgb':
        clf = XGBClassifier(n_estimators=200,
                            max_depth=1000,
                            learning_rate=0.3,
                            verbosity=1)
    else:
        raise ValueError(
            "Unknown ensemble strategy {!r}; expected one of "
            "'OR', 'AND', 'blend_rf', 'blend_xgb'".format(strategy)
        )

    # Base-model predictions are the features, gold labels the target.
    clf.fit(preds, true.ravel())
    return clf.predict(preds), clf

In [183]:
def merge_cnn_preds_with_df(cnn_preds_fp, dev, test):
    """Attach pickled CNN predictions to copies of the dev and test frames.

    ``cnn_preds_fp[0]`` is the pickle holding the dev predictions and
    ``cnn_preds_fp[1]`` the pickle holding the test predictions. Each loaded
    object is stored in a new ``'Prediction'`` column and cast to ``int``.
    The input frames are not modified; copies are returned as
    ``(dev_copy, test_copy)``.

    The files are read with ``pickle.load``, so only trusted files should be
    passed in.
    """
    annotated = (dev.copy(), test.copy())
    for position, frame in enumerate(annotated):
        with open(cnn_preds_fp[position], 'rb') as pickled:
            loaded = pickle.load(pickled)
        frame['Prediction'] = loaded
        frame['Prediction'] = frame['Prediction'].astype(int)
    annotated_dev, annotated_test = annotated
    return annotated_dev, annotated_test

In [205]:
# Load the test and dev splits (test.csv / dev.csv) that the CNN predictions
# are attached to and that the final ensembled predictions are joined onto.
test = pd.read_csv(Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/train_on_practice_test_on_trial_v2/test.csv'))
dev = pd.read_csv(Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/train_on_practice_test_on_trial_v2/dev.csv'))

In [229]:
# Load the transformer models' predictions: one dev-set and one test-set
# predictions.csv per model run, collected into two parallel lists.
fnp_pred_dfs_dev = []
fnp_pred_dfs_test = []
# The bare string below is a disabled, earlier list of model runs. It is a
# no-op expression (never assigned) kept only for reference.
"""
file_paths = [[Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/61_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/61_1/output/best_model/inference/predictions.csv')],
    [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/62_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/62_1/output/best_model/inference/predictions.csv')],
    [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/57_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/57_1/output/best_model/inference/predictions.csv')],
                [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/67_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/67_1/output/best_model/inference/predictions.csv')],
              [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/64_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/64_1/output/best_model/inference/predictions.csv')],
             [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/69_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/69_1/output/best_model/inference/predictions.csv')]]
"""
# Active model runs. Each entry is [dev predictions path, test predictions path].
file_paths = [ [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/82_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/82_1/output/best_model/inference/predictions.csv')],
    [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/76_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/76_1/output/best_model/inference/predictions.csv')],
              [Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/77_1/output/best_model/inference_dev/predictions.csv'),
             Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/77_1/output/best_model/inference/predictions.csv')]]
# fp[0] is the dev predictions file, fp[1] the test predictions file; the two
# lists stay index-aligned (entry i of each belongs to the same model run).
for fp in file_paths:
    fnp_pred_dfs_dev.append(pd.read_csv(fp[0]))
    fnp_pred_dfs_test.append(pd.read_csv(fp[1]))

In [230]:
# Load the CNN predictions: pickled prediction files are attached to copies of
# the dev/test frames via merge_cnn_preds_with_df.
cnn_pred_dfs_dev = []
cnn_pred_dfs_test = []
# The bare string below is a disabled, earlier experiment path pair; it is a
# no-op expression kept only for reference.
"""
file_paths = [['/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_all_combined/preds_val.pkl', 
               '/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_all_combined/preds.pkl']]
"""
# Active experiment. Each entry is [dev predictions pickle, test predictions pickle].
# NOTE: this rebinds `file_paths`, replacing the transformer path list.
file_paths = [['/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_train_prac_test_trial/preds_val.pkl', 
               '/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_train_prac_test_trial/preds_test.pkl']]

# Each pickle's contents become the 'Prediction' column of a copy of dev/test.
for fp in file_paths:
    cnn_pred_df_dev, cnn_pred_df_test = merge_cnn_preds_with_df(fp, dev, test)
    cnn_pred_dfs_dev.append(cnn_pred_df_dev)
    cnn_pred_dfs_test.append(cnn_pred_df_test)

In [231]:
# Build one wide frame per split with a prediction column per base model.
# Transformer frames come first, CNN frames last, so the numeric prefixes
# below follow that order.
data_frames_test = fnp_pred_dfs_test + cnn_pred_dfs_test
data_frames = fnp_pred_dfs_dev + cnn_pred_dfs_dev
# String indices '0', '1', ... used as concat keys. They are derived from the
# dev list and reused for the test list, so both lists must have equal length.
k = np.arange(len(data_frames)).astype(str)
# Column-wise concat aligned on unique_id; join='inner' keeps only ids that
# appear in every frame. `keys=k` yields (model index, column) column pairs.
df_merged = pd.concat([x.set_index('unique_id') for x in data_frames], axis=1, join='inner', keys=k)
df_merged_test = pd.concat([x.set_index('unique_id') for x in data_frames_test], axis=1, join='inner', keys=k)
# Flatten the two-level column labels into '<model index>_<column>' strings.
df_merged.columns = df_merged.columns.map('_'.join)
df_merged_test.columns = df_merged_test.columns.map('_'.join)

# Gold labels are taken from the first frame only; one prediction column is
# kept per model.
cols_gold = ['0_Gold'] 
cols_pred = [i+'_Prediction' for i in k]
df_merged = df_merged[cols_gold + cols_pred]
df_merged_test = df_merged_test[cols_gold + cols_pred]

In [232]:
# Preview the first rows of the raw test split.
test.head()

Unnamed: 0,Index,Text,Gold,unique_id
0,1.00001,Third Democratic presidential debate Septembe...,0,70a9c0cd-25ee-45a9-8910-4cb8b57bb216
1,1.00002,"On the policy front, Bernie Sanders claimed hi...",0,44f94352-79bc-4b0a-b511-ad4964ee167c
2,1.00003,Joe Biden misrepresented recent history when h...,0,e02319d3-4a80-4b45-98f1-abe121bb02a1
3,1.00004,Here's a look at some of the assertions in the...,0,992fe053-512e-4fa3-8dbb-c132a94f4fa0
4,1.00005,"It killed 22 people, and injured many more, we...",0,28fd05dc-fbb4-4f38-b6f8-a98fab37e619


In [233]:
# Preview the merged dev frame: gold label plus one prediction column per model.
df_merged.head()

Unnamed: 0_level_0,0_Gold,0_Prediction,1_Prediction,2_Prediction,3_Prediction
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
f5aaf044-daea-4f23-846e-cb38ac774fe1,0,0,0,0,0
85aba382-046a-4085-9bd5-fd383bdde338,1,1,1,1,1
22474aef-9aec-4ed1-8676-e7ce19ae380b,0,0,0,0,0
f547ea59-3f92-4290-ad8d-f9118c69185b,0,0,0,0,0
544dd8f2-13f6-49a1-9911-5db730eb23be,0,0,0,0,0


In [234]:
# Preview the merged test frame: gold label plus one prediction column per model.
df_merged_test.head()

Unnamed: 0_level_0,0_Gold,0_Prediction,1_Prediction,2_Prediction,3_Prediction
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
70a9c0cd-25ee-45a9-8910-4cb8b57bb216,0,0,0,0,0
44f94352-79bc-4b0a-b511-ad4964ee167c,0,0,0,0,0
e02319d3-4a80-4b45-98f1-abe121bb02a1,0,0,0,0,0
992fe053-512e-4fa3-8dbb-c132a94f4fa0,0,0,0,0,0
28fd05dc-fbb4-4f38-b6f8-a98fab37e619,0,0,0,0,0


In [235]:
# Fit the meta-classifier on the dev-split base-model predictions and store its
# predictions for those same rows.
# NOTE(review): ensemble() fits and predicts on the same `preds`, so the 'val'
# score printed below is an in-sample score, not a held-out estimate.
# NOTE: only the 'blend_*' strategies return a classifier; with 'OR'/'AND'
# meta_clf is None and the predict call below would fail.
df_merged['Prediction'], meta_clf = ensemble(true=df_merged[cols_gold].values, 
                                             preds=df_merged[cols_pred].values,
                                             strategy='blend_xgb')
# Apply the dev-fitted meta-classifier to the test-split predictions.
df_merged_test['Prediction'] = meta_clf.predict(df_merged_test[cols_pred].values)

# Weighted F1 of the ensembled prediction against the gold labels of frame 0.
print('val: ', f1_score(df_merged['0_Gold'].tolist(), df_merged['Prediction'], average='weighted'))
print('test: ', f1_score(df_merged_test['0_Gold'].tolist(), df_merged_test['Prediction'], average='weighted'))

val:  0.9643797057901942
test:  0.9522721022194831


In [220]:
# Importance the meta-classifier assigns to each base model's prediction
# column, in the order of cols_pred (the feature order used for fitting).
# The stray trailing comma was removed so the array is shown directly rather
# than wrapped in a 1-tuple.
meta_clf.feature_importances_

(array([0.00886776, 0.62589085, 0.3276314 , 0.00484754, 0.02881434,
        0.00394809], dtype=float32),)

In [361]:
# Generate the final dev and test frames: attach the ensembled 'Prediction'
# column to the original rows, matched on unique_id.
# DataFrame.join defaults to a left join, so rows whose unique_id is missing
# from the merged frame are kept with a NaN Prediction.
ensembled_dev = dev.set_index('unique_id').join(df_merged[['Prediction']])
# print(f1_score(ensembled_dev['Gold'].tolist(), ensembled_dev['Prediction'], average='weighted'))
ensembled_test = test.set_index('unique_id').join(df_merged_test[['Prediction']])
# print(f1_score(ensembled_test['Gold'].tolist(), ensembled_test['Prediction'], average='weighted'))

In [363]:
# Write the ensembled dev/test predictions to the 74_1 model directories.
# reset_index turns unique_id back into a regular column so it is written to
# the CSV (index=False drops the new positional index).
ensembled_dev.reset_index(level=0, inplace=False).to_csv('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/74_1/output/best_model/inference_dev/predictions.csv', index=False)
ensembled_test.reset_index(level=0, inplace=False).to_csv('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/74_1/output/best_model/inference/predictions.csv', index=False)

# Generate Prediction

In [163]:
# Load the full practice dataset that the final predictions are attached to,
# report its shape, and preview the first rows.
predict_df_path = Path('/media/sarthak/HDD/data_science/fnp_resources/data/task1/practice/all.csv')
predict_df = pd.read_csv(predict_df_path)
print('shape of data is: {}'.format(predict_df.shape))
predict_df.head()

shape of data is: (13547, 4)


Unnamed: 0,Index,Text,Gold,unique_id
0,1.00001,Florida raking in billions as Americans aband...,0,94bdc788-8aaa-49c1-9473-8222392e0c6e
1,1.00002,"Recently, changes to the U.S. tax code have e...",0,0af1d3f6-6aa9-461c-b921-0f8b90c83328
2,1.00003,"MORE FROM FOXBUSINESS.COM... As it turns out,...",0,5166917b-fe82-4e82-ba73-c67db70d957b
3,1.00004,"According to a new study from LendingTree, wh...",0,a52522d7-e25e-4b43-a510-44446d500443
4,1.00005,The Sunshine State drew in a net influx of ab...,1,517e2ae2-7d16-40af-a2a7-eb567841f9f8


In [165]:
# NOTE(review): `merged` is not defined anywhere in the visible notebook, so
# this cell fails on a fresh kernel (Restart & Run All). It presumably came
# from an earlier session; confirm which frame holds 'ensemble_preds'.
# The assignment goes through .tolist(), so values are matched to predict_df
# by position, not by unique_id: row order and length must agree.
predict_df['Prediction'] = merged['ensemble_preds'].tolist()

In [166]:
# Display predict_df with the newly added 'Prediction' column.
predict_df

Unnamed: 0,Index,Text,Gold,unique_id,Prediction
0,1.00001,Florida raking in billions as Americans aband...,0,94bdc788-8aaa-49c1-9473-8222392e0c6e,0
1,1.00002,"Recently, changes to the U.S. tax code have e...",0,0af1d3f6-6aa9-461c-b921-0f8b90c83328,0
2,1.00003,"MORE FROM FOXBUSINESS.COM... As it turns out,...",0,5166917b-fe82-4e82-ba73-c67db70d957b,0
3,1.00004,"According to a new study from LendingTree, wh...",0,a52522d7-e25e-4b43-a510-44446d500443,0
4,1.00005,The Sunshine State drew in a net influx of ab...,1,517e2ae2-7d16-40af-a2a7-eb567841f9f8,1
...,...,...,...,...,...
13542,590.00061,Contact transpo@gmu.edu with questions.,0,0294a30b-c9a5-4f89-ad3e-1188daa17543,0
13543,590.00062,Campus Fire Safety Month September is Campus...,0,77c8d28b-f1cf-4d04-b1fe-5ce06408d212,0
13544,590.00063,"Review the university's Fire Safety Plan, whi...",0,2e56e743-d171-43fd-a18c-1b4a4eca2aae,0
13545,590.00064,Contact Meredith Muckerman at 703-993-9715 or...,0,c1357ffa-d742-4f14-9f29-1d76808ba3c1,0


In [181]:
# Export only the Index, Text and Prediction columns as a ';'-separated file
# (no row index) into the model 37 inference directory.
predict_df[['Index', 'Text', 'Prediction']].to_csv(Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/37/output/best_model/inference/predictions.csv'), index=False, sep=';')