In [1]:
import pandas as pd
import json
from pathlib import Path
from typing import List
from functools import reduce
import numpy as np
from sklearn.metrics import f1_score
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from xgboost import XGBClassifier


In [2]:
import joblib

In [3]:
def ensemble(true: np.ndarray, preds: np.ndarray, strategy='OR'):
    """Combine per-model binary predictions into one prediction per sample.

    Parameters
    ----------
    true : np.ndarray
        Gold labels. Only read by the ``'blend_rf'`` and ``'blend_xgb'``
        strategies, where it is flattened with ``ravel()`` and used as the
        fitting target. Ignored by ``'OR'`` and ``'AND'``.
    preds : np.ndarray
        One row per sample; each row holds the individual models'
        predictions for that sample.
    strategy : str, default 'OR'
        * ``'OR'``        -- 1 if any entry of the row is truthy, else 0.
        * ``'AND'``       -- 1 if every entry of the row is truthy, else 0.
        * ``'blend_rf'``  -- fit a ``RandomForestClassifier`` on ``preds``.
        * ``'blend_xgb'`` -- fit an ``XGBClassifier`` on ``preds``.

    Returns
    -------
    tuple
        ``(final, clf)``. For ``'OR'``/``'AND'`` ``final`` is a list of 0/1
        ints and ``clf`` is ``None``. For the blend strategies ``final`` is
        the output of ``clf.predict(preds)`` and ``clf`` is the fitted
        estimator.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of the four supported names. Previously
        this case silently returned ``([], None)``.

    Notes
    -----
    The blend strategies fit and predict on the very same ``preds``, so the
    returned ``final`` is an in-sample prediction.
    """
    # Rule-based strategies need neither the gold labels nor an estimator.
    if strategy == 'OR':
        return [1 if any(row) else 0 for row in preds], None
    if strategy == 'AND':
        return [1 if all(row) else 0 for row in preds], None

    if strategy == 'blend_rf':
        clf = RandomForestClassifier(max_depth=100,
                                     n_estimators=3,
                                     random_state=0)
    elif strategy == 'blend_xgb':
        clf = XGBClassifier(n_estimators=1000,
                            max_depth=5,
                            learning_rate=0.1,
                            verbosity=1)
    else:
        # Fail loudly: an empty prediction list would otherwise propagate
        # into downstream scoring unnoticed.
        raise ValueError(
            "Unknown strategy %r; expected one of 'OR', 'AND', "
            "'blend_rf', 'blend_xgb'" % (strategy,)
        )

    clf.fit(preds, true.ravel())
    return clf.predict(preds), clf

In [4]:
def merge_cnn_preds_with_df(cnn_preds_fp, test):
    """Return a copy of ``test`` with pickled CNN predictions attached.

    Parameters
    ----------
    cnn_preds_fp : path-like
        Location of a pickle file holding one prediction per row of ``test``.
        NOTE(review): ``pickle.load`` executes arbitrary code from the file,
        so this must only ever be pointed at trusted artefacts.
    test : pandas.DataFrame
        Frame to attach the predictions to. It is copied, not modified.

    Returns
    -------
    pandas.DataFrame
        Copy of ``test`` with an extra ``'Prediction'`` column cast to int.

    The frame shape and the number of loaded predictions are printed so a
    length mismatch is easy to spot in the cell output.
    """
    merged = test.copy()
    print(merged.shape)

    with open(cnn_preds_fp, 'rb') as pkl_file:
        cnn_preds = pickle.load(pkl_file)
    print(len(cnn_preds))

    merged['Prediction'] = cnn_preds
    merged['Prediction'] = merged['Prediction'].astype(int)
    return merged

In [26]:
# Load the combined task-1 splits; the files are read in the order test, dev.
test, dev = (
    pd.read_csv(Path(csv_fp))
    for csv_fp in (
        '/media/sarthak/HDD/data_science/fnp_resources/data/task1/all_combined/test.csv',
        '/media/sarthak/HDD/data_science/fnp_resources/data/task1/all_combined/dev.csv',
    )
)
# Location of the serialized meta-classifier that is reloaded with joblib further down.
reload_path = Path(
    '/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models/submissions/metaclf_bertsandrulesandcnnandbertbaseuncased.joblib'
)

In [27]:
# code for transformers prediction
# Every run writes its test-set predictions to the same relative location, so
# only the run identifier differs between the files.
MODELS_DIR = Path('/media/sarthak/HDD/data_science/fnp_resources/fincausal_t1_models')
PREDICTIONS_SUBPATH = Path('output/best_model/inference/predictions.csv')

# Order matters: the position of each frame becomes its column prefix when the
# frames are concatenated further down, and those columns are passed to the
# reloaded meta-classifier in exactly this order.
RUN_IDS = ['101_1', '102_1', '103_1', '105_1', '99_1', '104_1', '105_2']

# NOTE: an earlier variant of this cell loaded paired dev/test predictions
# (inference_dev/predictions.csv and inference/predictions.csv) for the runs
# 61_1, 62_1, 64_1, 67_1, 69_1 and 87_1.
file_paths = [MODELS_DIR / run_id / PREDICTIONS_SUBPATH for run_id in RUN_IDS]

fnp_pred_dfs_test = [pd.read_csv(fp) for fp in file_paths]

In [28]:
# code for CNN predictions
# NOTE: an earlier variant of this cell loaded a validation/test pair
# (preds_val.pkl and preds.pkl) from
# /media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_all_combined/.
file_paths = ['/media/sarthak/HDD/TUM/Thesis/thesis-sarthak/src/tc/experiments/fincausal_allcombined_traindev/preds_test.pkl']

# One frame per pickle file: a copy of `test` with the CNN output attached
# as an integer 'Prediction' column.
cnn_pred_dfs_test = []
for fp in file_paths:
    cnn_pred_df_test = merge_cnn_preds_with_df(fp, test)
    cnn_pred_dfs_test.append(cnn_pred_df_test)

(2206, 4)
2206


In [29]:
# Transformer frames first, CNN frames last; a frame's position becomes its column prefix.
data_frames_test = fnp_pred_dfs_test + cnn_pred_dfs_test
k = np.arange(len(data_frames_test)).astype(str)

# Align all frames on unique_id (inner join keeps only ids present in every frame)
# and tag each frame's columns with its position via a two-level column index.
df_merged_test = pd.concat(
    [frame.set_index('unique_id') for frame in data_frames_test],
    axis=1,
    join='inner',
    keys=k,
)
# Flatten the (position, column) pairs into names such as '0_Gold' or '3_Prediction'.
df_merged_test.columns = ['_'.join(level_pair) for level_pair in df_merged_test.columns]

# Keep the gold label from the first frame plus one prediction column per frame.
cols_gold = ['0_Gold']
cols_pred = [prefix + '_Prediction' for prefix in k]
df_merged_test = df_merged_test[cols_gold + cols_pred]

In [30]:
# Preview the first rows of the test frame (Index, Text, Gold, unique_id).
test.head()

Unnamed: 0,Index,Text,Gold,unique_id
0,352.00039,One bit of advice Orton had for young finance ...,0,f56182ea-5095-4934-889c-f18c1a26134c
1,8.00001,President Muhammadu Buhari has disclosed that ...,1,06cc12f1-1df2-43ef-898a-85ebb29bdede
2,53.0001,Advent of compressed exhaust systems based on ...,0,a2705c77-271b-4870-b948-8b688657f39c
3,533.00022,Catholic leaders have also suggested the UK go...,0,7966c427-9695-4e2f-a098-a5d385d6e0c4
4,316.0004,(NasdaqGS:CME) is 66. A company with a value o...,0,14e3e218-9f6b-4949-96c7-62ca68598d55


In [31]:
# Preview the merged frame: the gold label plus one prediction column per source frame.
df_merged_test.head()

Unnamed: 0_level_0,0_Gold,0_Prediction,1_Prediction,2_Prediction,3_Prediction,4_Prediction,5_Prediction,6_Prediction,7_Prediction
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
f56182ea-5095-4934-889c-f18c1a26134c,0,0,0,0,0,0,0,0,0
06cc12f1-1df2-43ef-898a-85ebb29bdede,1,0,0,0,0,0,1,0,0
a2705c77-271b-4870-b948-8b688657f39c,0,0,0,0,0,0,0,0,0
7966c427-9695-4e2f-a098-a5d385d6e0c4,0,0,0,0,0,0,0,0,0
14e3e218-9f6b-4949-96c7-62ca68598d55,0,0,0,0,0,0,0,0,0


In [32]:
# Reload the previously trained meta-classifier.
# NOTE(review): joblib.load unpickles the file, so only trusted artefacts should be loaded.
meta_clf = joblib.load(reload_path)

# The per-model prediction columns are the meta-classifier's input features.
meta_features = df_merged_test[cols_pred].values
df_merged_test['Prediction'] = meta_clf.predict(meta_features)

# Weighted F1 of the ensemble prediction against the gold labels.
test_f1 = f1_score(
    df_merged_test['0_Gold'].tolist(),
    df_merged_test['Prediction'],
    average='weighted',
)
print('test: ', test_f1)

test:  0.9592589640517364
