# In [None]:
def _max_score_per_convo(corpus, utt_scores, convo_ids, default=-1):
    """Map each id in ``convo_ids`` to the highest score any of its utterances got.

    ``utt_scores`` is a frame as returned by ``evaluateDataset``: its index holds
    utterance ids and it has a ``score`` column. Conversations with no scored
    utterance keep ``default``. A scored utterance whose conversation is not in
    ``convo_ids`` raises ``KeyError``.
    """
    highest = dict.fromkeys(convo_ids, default)
    for utt_id in utt_scores.index:
        convo_id = corpus.get_utterance(utt_id).get_conversation().id
        utt_score = utt_scores.loc[utt_id].score
        # Whether a conversation triggers is decided by the highest score it ever reached.
        if utt_score > highest[convo_id]:
            highest[convo_id] = utt_score
    return highest


def _binary_metrics(labels, preds):
    """Return accuracy, precision, recall, F1 and false-positive rate for 0/1 labels/predictions.

    Ratios with a zero denominator (e.g. a model that never predicts the
    positive class, or an empty test set) are reported as 0.0 instead of NaN,
    so the result remains plain floats that serialize to standard JSON.
    """
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    tp = int(((labels == 1) & (preds == 1)).sum())
    fp = int(((labels == 0) & (preds == 1)).sum())
    tn = int(((labels == 0) & (preds == 0)).sum())
    fn = int(((labels == 1) & (preds == 0)).sum())

    def ratio(num, den):
        return float(num) / den if den else 0.0

    return {
        'accuracy': ratio((labels == preds).sum(), labels.size),
        'precision': ratio(tp, tp + fp),
        'recall': ratio(tp, tp + fn),
        # Harmonic mean of precision and recall, written without intermediate divisions.
        'f1': ratio(2 * tp, 2 * tp + fp + fn),
        'false positive rate': ratio(fp, fp + tn),
    }


def full_evaluate(args, corpus, tokenized_dataset, device="cuda"):
    """Select the best saved checkpoint on validation data, then evaluate it on test.

    Every subdirectory of ``args.output_dir`` is loaded with
    ``AutoModelForSequenceClassification.from_pretrained`` and scored on
    ``tokenized_dataset['val_for_tuning']``. A conversation's score is the
    highest score of any of its utterances. The decision threshold is chosen
    among the ``roc_curve`` cutoffs to maximise ``acc_with_threshold``. The
    checkpoint/threshold pair with the best validation accuracy is then applied
    to ``tokenized_dataset['test']``.

    Args:
        args: namespace providing ``corpus_name`` and ``output_dir``.
            ``corpus_name`` selects the conversation label metadata:
            ``"conversation_has_personal_attack"`` for ``"wikiconv"``,
            ``"has_removed_comment"`` otherwise. ``output_dir`` holds the
            checkpoints and receives all outputs.
        corpus: conversation corpus exposing ``iter_conversations``,
            ``get_conversation`` and ``get_utterance``. Conversations are split
            via ``meta['split']`` (``"val"`` / ``"test"``).
        tokenized_dataset: mapping with ``'val_for_tuning'`` and ``'test'`` entries.
        device: device string passed to ``evaluateDataset`` (default ``"cuda"``).

    Side effects:
        Writes ``pred_val.csv``, ``pred_test.csv`` and ``result.json`` into
        ``args.output_dir``. Sets ``meta['forecast_score']`` on every test
        utterance that received a forecast.

    Returns:
        dict: the test-set metrics that were written to ``result.json``.

    Raises:
        ValueError: if no checkpoint subdirectory could be selected.
    """
    label_metadata = "conversation_has_personal_attack" if args.corpus_name == "wikiconv" else "has_removed_comment"

    # The validation conversations and their labels do not depend on the
    # checkpoint, so collect them once.
    val_convo_ids = [c.id for c in corpus.iter_conversations(lambda convo: convo.meta['split'] == "val")]
    val_labels = np.asarray([int(corpus.get_conversation(c).meta[label_metadata]) for c in val_convo_ids])

    # Only subdirectories are checkpoints: output_dir also receives this
    # function's own CSV/JSON outputs, which from_pretrained cannot load.
    # NOTE(review): any other non-checkpoint subdirectory (e.g. logs) would
    # still be attempted. Sorting makes tie-breaking deterministic.
    checkpoints = sorted(
        cp for cp in os.listdir(args.output_dir)
        if os.path.isdir(os.path.join(args.output_dir, cp))
    )

    # Loop through all saved models to find the best one on val_for_tuning.
    best_val_accuracy = -np.inf
    best_threshold = None
    best_model = None
    for cp in checkpoints:
        finetuned_model = AutoModelForSequenceClassification.from_pretrained(os.path.join(args.output_dir, cp))
        utt_scores = evaluateDataset(tokenized_dataset["val_for_tuning"], finetuned_model, device)
        convo_scores = _max_score_per_convo(corpus, utt_scores, val_convo_ids)
        val_scores = np.asarray([convo_scores[c] for c in val_convo_ids])

        # use scikit-learn to find candidate threshold cutoffs
        _, _, thresholds = roc_curve(val_labels, val_scores)
        accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds]

        best_acc_idx = int(np.argmax(accs))
        if accs[best_acc_idx] > best_val_accuracy:
            best_val_accuracy = accs[best_acc_idx]
            best_threshold = thresholds[best_acc_idx]
            best_model = finetuned_model

    if best_model is None:
        raise ValueError(f"no usable checkpoint found in {args.output_dir!r}")

    val_forecasts = evaluateDataset(tokenized_dataset["val_for_tuning"], best_model, device, threshold=best_threshold)
    val_forecasts.to_csv(os.path.join(args.output_dir, "pred_val.csv"))

    test_forecasts = evaluateDataset(tokenized_dataset["test"], best_model, device, threshold=best_threshold)
    test_forecasts.to_csv(os.path.join(args.output_dir, "pred_test.csv"))

    # We will add a metadata entry to each test-set utterance signifying whether, at the time
    # that CRAFT saw the context *up to and including* that utterance, CRAFT forecasted the
    # conversation would derail. Note that in datasets where the actual toxic comment is
    # included (such as wikiconv), we explicitly do not show that comment to CRAFT (since
    # that would be cheating!), so that comment will not have an associated forecast.
    convo_rows = {"convo_id": [], "label": [], "score": [], "prediction": []}
    for convo in corpus.iter_conversations():
        # only consider test set conversations (we did not make predictions for the other ones)
        if convo.meta['split'] != "test":
            continue
        forecast_scores = []
        for utt in convo.iter_utterances():
            if utt.id in test_forecasts.index:
                utt.meta['forecast_score'] = test_forecasts.loc[utt.id].score
                forecast_scores.append(utt.meta['forecast_score'])
        # Same rule as on validation: the conversation's score is its highest
        # utterance score, or -1 if none of its utterances was scored.
        convo_score = np.max(forecast_scores) if forecast_scores else -1
        convo_rows["convo_id"].append(convo.id)
        convo_rows["label"].append(int(convo.meta[label_metadata]))
        convo_rows["score"].append(convo_score)
        # NOTE(review): strict '>' here; confirm acc_with_threshold uses the same comparison.
        convo_rows["prediction"].append(int(convo_score > best_threshold))

    conversational_forecasts_df = pd.DataFrame(convo_rows).set_index("convo_id")

    result_dict = _binary_metrics(conversational_forecasts_df.label, conversational_forecasts_df.prediction)
    with open(os.path.join(args.output_dir, "result.json"), 'w') as f:
        json.dump(result_dict, f)
    return result_dict