# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import optuna

def calc_overlap3(set_pred, set_gt):
    """
    Calculates if the overlap between prediction and
    ground truth is enough for a potential True positive.

    Both directional overlaps must be at least 0.5:
        |pred & gt| / |gt|   >= 0.5  and
        |pred & gt| / |pred| >= 0.5

    Args:
        set_pred: set of word indices of the prediction, or NaN when the
            caller's outer merge produced no prediction for this row.
        set_gt: set of word indices of the ground truth, or NaN when the
            caller's outer merge produced no ground truth for this row.

    Returns:
        True if both overlaps are >= 0.5, otherwise False (also False when
        either side is missing/NaN or empty).
    """
    try:
        len_gt = len(set_gt)
        len_pred = len(set_pred)
        inter = len(set_gt & set_pred)
        overlap_1 = inter / len_gt
        overlap_2 = inter / len_pred
    except (TypeError, ZeroDivisionError):
        # TypeError: at least one input is NaN (a float has no len / no `&`).
        # ZeroDivisionError: an empty set was passed.
        return False
    return overlap_1 >= 0.5 and overlap_2 >= 0.5

def score_feedback_comp_micro3(pred_df, gt_df, discourse_type):
    """
    A function that scores one discourse type for the Kaggle
        Student Writing Competition (Feedback Prize 2021).

    Uses the steps in the evaluation page here:
        https://www.kaggle.com/c/feedback-prize-2021/overview/evaluation

    Args:
        pred_df: predictions; rows are selected by the 'class' column and
            'predictionstring' holds space-separated word indices.
        gt_df: ground truth; rows are selected by the 'discourse_type'
            column and 'predictionstring' holds space-separated word indices.
        discourse_type: the class to score.

    Returns:
        F1 = 2*TP / (#predictions + #ground truths) for this class, or 0.0
        when the class has neither predictions nor ground truths.
    """
    gt_df = gt_df.loc[gt_df['discourse_type'] == discourse_type,
                      ['id', 'predictionstring']].reset_index(drop=True)
    pred_df = pred_df.loc[pred_df['class'] == discourse_type,
                          ['id', 'predictionstring']].reset_index(drop=True)

    # Step 3 (denominator): unmatched ground truths are false negatives and
    # unmatched predictions are false positives, so TP+FP is the number of
    # predictions and TP+FN is the number of ground truths.
    TPandFP = len(pred_df)
    TPandFN = len(gt_df)
    if TPandFP + TPandFN == 0:
        # Nothing to score for this class: return 0.0 instead of raising
        # ZeroDivisionError below.
        return 0.0

    pred_df['pred_id'] = pred_df.index
    gt_df['gt_id'] = gt_df.index
    pred_df['predictionstring'] = [set(pred.split(' ')) for pred in pred_df['predictionstring']]
    gt_df['predictionstring'] = [set(pred.split(' ')) for pred in gt_df['predictionstring']]

    # Step 1. All ground truths and predictions for this class are compared
    # within the same essay ('id'). The outer join keeps rows without a
    # partner, with NaN on the missing side; calc_overlap3 returns False
    # for those.
    joined = pred_df.merge(gt_df,
                           left_on='id',
                           right_on='id',
                           how='outer',
                           suffixes=('_pred', '_gt'),
                           )
    overlaps = [calc_overlap3(*args) for args in zip(joined.predictionstring_pred,
                                                     joined.predictionstring_gt)]

    # Step 2. A pair whose overlap is >= 0.5 in both directions is a match.
    # The exact best-match assignment is not computed: the number of distinct
    # ground truths with at least one matching prediction is used as TP.
    TP = joined.loc[overlaps]['gt_id'].nunique()

    # Micro F1 for this class.
    return 2 * TP / (TPandFP + TPandFN)

def score_feedback_comp3(pred_df, gt_df, return_class_scores=False):
    """
    Macro-averaged F1 over every discourse type present in the ground truth.

    Args:
        pred_df: predictions (scored per class by score_feedback_comp_micro3).
        gt_df: ground truth; its unique 'discourse_type' values define the
            classes that are averaged.
        return_class_scores: when True, also return the per-class scores.

    Returns:
        The mean of the per-class F1 scores, or the tuple
        (mean F1, {discourse_type: F1}) when return_class_scores is True.
    """
    per_class = {
        dtype: score_feedback_comp_micro3(pred_df, gt_df, dtype)
        for dtype in gt_df.discourse_type.unique()
    }
    macro_f1 = np.mean(list(per_class.values()))
    return (macro_f1, per_class) if return_class_scores else macro_f1


def add_pred(predictstring, weight, threshold):
    """
    Trim the tail of a long prediction string.

    If the prediction has more than ``threshold`` words, the last
    ``int(n_words * weight)`` word indices are dropped; predictions with at
    most ``threshold`` words are returned unchanged (re-joined with single
    spaces).

    Args:
        predictstring: whitespace-separated word indices.
        weight: fraction of the words to drop from the end (expected in
            [0, 1], as sampled by opt_weight_threshold).
        threshold: only predictions with more words than this are trimmed.

    Returns:
        The (possibly trimmed) prediction as a space-separated string.
    """
    tokens = predictstring.split()
    if len(tokens) > threshold:
        n_drop = int(len(tokens) * weight)
        # Use an explicit end index: ``tokens[:-n_drop]`` evaluates to
        # ``tokens[:0]`` (empty) when n_drop == 0, wiping the whole
        # prediction instead of keeping it. Clamp at 0 so over-trimming
        # yields an empty prediction rather than wrapping around.
        tokens = tokens[:max(len(tokens) - n_drop, 0)]
    return " ".join(tokens)


def opt_weight_threshold(trial, name):
    """
    Optuna objective: F1 of class ``name`` after trimming its predictions.

    Samples ``weight_<name>`` in [0, 1] and ``threshold_<name>`` in [0, 200],
    applies add_pred with those values to the rows of class ``name`` in a
    copy of the global ``oof_df``, and scores that class against the global
    ``train`` ground truth.
    """
    weight = trial.suggest_float('weight_' + name, 0, 1)
    threshold = trial.suggest_int('threshold_' + name, 0, 200)

    def trim(pred_string):
        return add_pred(pred_string, weight, threshold)

    trimmed_df = oof_df.copy()
    is_class = trimmed_df['class'] == name
    trimmed_df.loc[is_class, 'predictionstring'] = (
        trimmed_df.loc[is_class, 'predictionstring'].apply(trim)
    )
    return score_feedback_comp_micro3(trimmed_df, train, name)



# Discourse types whose predictions get their own (weight, threshold) pair.
classes = ["Lead", "Claim", "Position", "Evidence", "Counterclaim",
           "Concluding Statement", "Rebuttal"]

# Ground-truth annotations and out-of-fold predictions.
train = pd.read_csv('../input/feedback-prize-2021/train.csv')
oof_df = pd.read_csv('../input/expv2-en-038-xgb-mlp-lstm-fe-fix/preds/df_all.csv')

# Baseline score of the untouched OOF predictions.
f1, scores = score_feedback_comp3(oof_df, train, return_class_scores=True)
print(f1)
display(scores)
print()

# Tune one (weight, threshold) pair per class with a seeded TPE sampler.
weightclass_dict = {}

# Silence Optuna's default log handler once, before the first study is
# created (calling it inside the loop after create_study let the first
# study's creation message through and repeated a loop-invariant call).
optuna.logging.disable_default_handler()

for name in classes:
    print(name)
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
    )
    # optimize() runs synchronously, so the lambda sees this iteration's name.
    study.optimize(lambda trial: opt_weight_threshold(trial, name=name), n_trials=100)
    display(study.best_trial.number)
    display(study.best_trial.values)
    display(study.best_trial.params)
    print()
    # Param keys are 'weight_<class>' / 'threshold_<class>', so classes
    # never overwrite each other.
    weightclass_dict.update(study.best_trial.params)

print()
display(weightclass_dict)

# Re-score the OOF predictions with each class trimmed by its tuned
# (weight, threshold) pair.
score_df = oof_df.copy()
for name in classes:
    class_mask = score_df['class'] == name
    tuned_args = (
        weightclass_dict[f'weight_{name}'],
        weightclass_dict[f'threshold_{name}'],
    )
    score_df.loc[class_mask, 'predictionstring'] = (
        score_df.loc[class_mask, 'predictionstring'].apply(add_pred, args=tuned_args)
    )

f1, scores = score_feedback_comp3(score_df, train, True)
print()
print(f1)
display(scores)