In [9]:
import pandas as pd
import numpy as np
import os

# Controls weights when combining predictions
# 0: equal average of all inputs; 
# 1: up to 50% of weight going to least correlated input
DENSITY_COEFF = 0.1
assert DENSITY_COEFF >= 0.0 and DENSITY_COEFF <= 1.0

# When merging 2 files with corr > OVER_CORR_CUTOFF 
# the result's density is the max instead of the sum of the merged files' densities
OVER_CORR_CUTOFF = 0.98
assert OVER_CORR_CUTOFF >= 0.0 and OVER_CORR_CUTOFF <= 1.0

INPUT_DIR = 'StackPreds/TopN_XGB/'#'NewRnnPreds/'

def load_submissions():
    files = os.listdir(INPUT_DIR)
    csv_files = []
    for f in files:
        try:
            if int(f[:4]) > 9867: #f.endswith(".csv"):
                csv_files.append(f)
                print(f)
        except ValueError:
            pass
    frames = {f:pd.read_csv(INPUT_DIR+f).sort_values('id') for f in csv_files}
    return frames

In [11]:
def get_corr_mat(col,frames):
    c = pd.DataFrame()
    for name,df in frames.items():
        c[name] = df[col]
    cor = c.corr()
    for name in cor.columns:
        cor.set_value(name,name,0.0)
    return cor


def highest_corr(mat):
    n_cor = np.array(mat.values)
    corr = np.max(n_cor)
    idx = np.unravel_index(np.argmax(n_cor, axis=None), n_cor.shape)
    f1 = mat.columns[idx[0]]
    f2 = mat.columns[idx[1]]
    return corr,f1,f2


def get_merge_weights(m1,m2,densities):
    d1 = densities[m1]
    d2 = densities[m2]
    d_tot = d1 + d2
    weights1 = 0.5*DENSITY_COEFF + (d1/d_tot)*(1-DENSITY_COEFF)
    weights2 = 0.5*DENSITY_COEFF + (d2/d_tot)*(1-DENSITY_COEFF)
    return weights1, weights2


def ensemble_col(col,frames,densities):
    if len(frames) == 1:
        _, fr = frames.popitem()
        return fr[col]

    mat = get_corr_mat(col,frames)
    corr,merge1,merge2 = highest_corr(mat)
    new_col_name = merge1 + '_' + merge2

    w1,w2 = get_merge_weights(merge1,merge2,densities)
    new_df = pd.DataFrame()
    new_df[col] = (frames[merge1][col]*w1) + (frames[merge2][col]*w2)
    del frames[merge1]
    del frames[merge2]
    frames[new_col_name] = new_df

    if corr >= OVER_CORR_CUTOFF:
        print('\t',merge1,merge2,'  (OVER CORR)')
        densities[new_col_name] = max(densities[merge1],densities[merge2])
    else:
        print('\t',merge1,merge2)
        densities[new_col_name] = densities[merge1] + densities[merge2]

    del densities[merge1]
    del densities[merge2]
    #print(densities)
    return ensemble_col(col,frames,densities)


ens_submission = pd.read_csv('~/data/toxic/data/sample_submission.csv').sort_values('id')
#print(get_corr_mat('toxic',load_submissions()))

for col in ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]:
    frames = load_submissions()
    print('\n\n',col)
    densities = {k:1.0 for k in frames.keys()}
    ens_submission[col] = ensemble_col(col,frames,densities)

print(ens_submission)    
ens_submission.to_csv(INPUT_DIR + 'All9868Above_lazy_ensemble_submission.csv.gz', index=False, compression='gzip')

9869_PureRNN_0991846_0958260_2_0993922.csv.gz
9872_PureRNN_w_layer2Logreg25_30_165_165_12_1520888020_1520572538_1520444908_1520438638_1520460207__5_1520963719.csv.gz
9870_MAD3_10_1520957514.csv.gz
9869_xgb_meta_ensemble_1520435392_1520460999__2_1520461403.csv.gz
9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz
9876_NewRNN_w_9874_37blend_2_1398299.csv.gz
9870_RNN_CapNwei_85_15_1080470_0958260_2_1081126.csv.gz
9868_xgb_meta_ensemble_1520435392_1520435392_2_1520446721.csv.gz
9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz
9875_CNN_w_9876_19lend_2_1574427.csv.gz
9869_RNN_and_capftc6_15_85_1055020_0958260_2_1055171.csv.gz
9870_MAD_al_0888020_0572538_0444908_0438638_0460207_6_0966327.csv.gz
9872_RNN_w_layer2Logregandxgb_0888020_0885607_0572538_0444908_0438638_0460207_6_0964560.csv.gz
9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz
9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz
9868_PureRNN_0991846_0958260_2_099



 severe_toxic
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz 9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz   (OVER CORR)
	 9869_PureRNN_0991846_0958260_2_0993922.csv.gz 9868_PureRNN_0991846_0958260_2_0994171.csv.gz   (OVER CORR)
	 9873_RNN_Normal_50_50_1081874_1081126_2_1082929.csv.gz 9874_RNN_Normal_60_40_1081874_1081126_2_1082835.csv.gz   (OVER CORR)
	 9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz 9873_MAD151524181810_ensemble_1520437955_1520005508_1520560823_1520444908_1520438638_1520460207_6_1520562334.csv.gz   (OVER CORR)
	 9875_CNN_w_9876_19lend_2_1574427.csv.gz 9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9869_PureRNN40222216_1520572538_1520444908_1520438638_1520460207_4_1520958260.csv.gz 9868_RNN_bagging_7_1520866595.csv.gz   (OVER CORR)
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz_9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz 9873_RNN_Normal_50_50_



 obscene
	 9869_PureRNN_0991846_0958260_2_0993922.csv.gz 9868_PureRNN_0991846_0958260_2_0994171.csv.gz   (OVER CORR)
	 9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz 9873_MAD151524181810_ensemble_1520437955_1520005508_1520560823_1520444908_1520438638_1520460207_6_1520562334.csv.gz   (OVER CORR)
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz 9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz   (OVER CORR)
	 9873_RNN_Normal_50_50_1081874_1081126_2_1082929.csv.gz 9874_RNN_Normal_60_40_1081874_1081126_2_1082835.csv.gz   (OVER CORR)
	 9875_CNN_w_9876_19lend_2_1574427.csv.gz 9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9875_NewRNN_w_9874_55blend_2_1397748.csv.gz 9875_CNN_w_9876_19lend_2_1574427.csv.gz_9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9869_RNN_and_capftc6_15_85_1055020_0958260_2_1055171.csv.gz 9869_PureRNN40222216_1520572538_1520444908_1520438638_1520460207_4_



 threat
	 9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz 9873_MAD151524181810_ensemble_1520437955_1520005508_1520560823_1520444908_1520438638_1520460207_6_1520562334.csv.gz   (OVER CORR)
	 9869_PureRNN_0991846_0958260_2_0993922.csv.gz 9868_PureRNN_0991846_0958260_2_0994171.csv.gz   (OVER CORR)
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz 9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz   (OVER CORR)
	 9875_CNN_w_9876_19lend_2_1574427.csv.gz 9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9873_RNN_Normal_50_50_1081874_1081126_2_1082929.csv.gz 9874_RNN_Normal_60_40_1081874_1081126_2_1082835.csv.gz   (OVER CORR)
	 9870_RNN_CapNwei_85_15_1080470_0958260_2_1081126.csv.gz 9869_RNN_and_capftc6_15_85_1055020_0958260_2_1055171.csv.gz   (OVER CORR)
	 9869_PureRNN40222216_1520572538_1520444908_1520438638_1520460207_4_1520958260.csv.gz 9870_RNN_CapNwei_85_15_1080470_0958260_2_1081126.csv.gz_986



 insult
	 9869_PureRNN_0991846_0958260_2_0993922.csv.gz 9868_PureRNN_0991846_0958260_2_0994171.csv.gz   (OVER CORR)
	 9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz 9873_MAD151524181810_ensemble_1520437955_1520005508_1520560823_1520444908_1520438638_1520460207_6_1520562334.csv.gz   (OVER CORR)
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz 9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz   (OVER CORR)
	 9873_RNN_Normal_50_50_1081874_1081126_2_1082929.csv.gz 9874_RNN_Normal_60_40_1081874_1081126_2_1082835.csv.gz   (OVER CORR)
	 9875_CNN_w_9876_19lend_2_1574427.csv.gz 9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9870_RNN_CapNwei_85_15_1080470_0958260_2_1081126.csv.gz 9869_RNN_and_capftc6_15_85_1055020_0958260_2_1055171.csv.gz   (OVER CORR)
	 9872_PureRNN_w_layer2Logreg25_30_165_165_12_1520888020_1520572538_1520444908_1520438638_1520460207__5_1520963719.csv.gz 9872_RNN_w_layer2Logrega



 identity_hate
	 9869_PureRNN_0991846_0958260_2_0993922.csv.gz 9868_PureRNN_0991846_0958260_2_0994171.csv.gz   (OVER CORR)
	 9873_MAD131330181808_ensemble_1520437955_1520005508_1520572538_1520444908_1520438638_1520460207_6_1520573203.csv.gz 9873_MAD151524181810_ensemble_1520437955_1520005508_1520560823_1520444908_1520438638_1520460207_6_1520562334.csv.gz   (OVER CORR)
	 9873_RNN_Normal_80_20_1081874_1081126_2_1082136.csv.gz 9874_RNN_Normal_70_30_1081874_1081126_2_1082320.csv.gz   (OVER CORR)
	 9873_RNN_Normal_50_50_1081874_1081126_2_1082929.csv.gz 9874_RNN_Normal_60_40_1081874_1081126_2_1082835.csv.gz   (OVER CORR)
	 9875_CNN_w_9876_19lend_2_1574427.csv.gz 9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OVER CORR)
	 9870_RNN_CapNwei_85_15_1080470_0958260_2_1081126.csv.gz 9869_RNN_and_capftc6_15_85_1055020_0958260_2_1055171.csv.gz   (OVER CORR)
	 9875_NewRNN_w_9874_55blend_2_1397748.csv.gz 9875_CNN_w_9876_19lend_2_1574427.csv.gz_9876_CapsultNet_w_9876_28blend_2_1413594.csv.gz   (OV