# Evaluating the retrieval pipeline

## On Top Level MITRE ATT&CK Techniques

In [None]:
import os
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
from libs.pygaggle.data.relevance import RelevanceExample
from libs.pygaggle.model import StepEvaluator

from libs import resources as res, rank

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cuda':
    # Preferred GPU index (hard-coded for this notebook).
    cuda_device = 7
    # torch.cuda.set_device raises for an index that does not exist, so fall
    # back to the highest visible GPU when the host has fewer than 8 devices.
    visible_gpus = torch.cuda.device_count()
    if cuda_device >= visible_gpus:
        print(f'GPU {cuda_device} is not available ({visible_gpus} visible); '
              f'using GPU {visible_gpus - 1} instead')
        cuda_device = visible_gpus - 1
    torch.cuda.set_device(cuda_device)

cache_dir = 'cache'

## Load data

In [None]:
# Load the annotated sentences and keep only the rows whose `tech_id` is
# exactly five characters long (presumably the top-level technique IDs this
# section is about -- confirm against the data).
sentences = res.load_annotated()
top_level_mask = sentences['tech_id'].str.len() == 5
sentences = sentences.loc[top_level_mask]

# Last expression of the cell: shows the first rows as the cell output.
sentences.head()

In [None]:
def _build_corpus(kb):
    """Collapse `kb` into one row per (tech_id, tech_name) pair.

    Each row's `text` is the technique name followed by every `text` value
    of that group, all joined with single spaces.
    """
    rows = []
    for (tech_id, tech_name), group in kb.groupby(['tech_id', 'tech_name']):
        joined_texts = ' '.join(group['text'].values)
        rows.append({
            'tech_id': tech_id,
            'text': tech_name + ' ' + joined_texts,
        })
    return pd.DataFrame(rows)


# Load the knowledge base and keep only rows whose `tech_id` is exactly five
# characters long (same length filter as applied to the annotated sentences).
dataset = res.load_mitre_kb()
dataset = dataset.loc[dataset['tech_id'].str.len() == 5]

corpus = _build_corpus(dataset)

In [None]:
# Keep only the annotated sentences whose technique ID also occurs in the
# knowledge-base frame.
kb_tech_ids = dataset['tech_id'].values
in_kb = sentences['tech_id'].isin(kb_tech_ids)
queries = sentences[in_kb]

# Cell output: how many distinct technique IDs the remaining queries cover.
distinct_query_ids = queries['tech_id'].drop_duplicates()
len(distinct_query_ids)

## Rerank

In [None]:
# Metric names handed to every StepEvaluator in this notebook: recall at a
# series of cut-offs, followed by 'mrr'.
recall_cutoffs = (3, 5, 10, 20, 40, 50, 100)
metrics = [f'recall@{k}' for k in recall_cutoffs] + ['mrr']

### Generate Initial Examples

In [None]:
def _relevance_flags(relevant_idx, n_texts):
    """Return `n_texts` booleans, True exactly where the position equals `relevant_idx`."""
    # bool(...) keeps plain Python booleans even if `relevant_idx` is a NumPy integer.
    return [bool(relevant_idx == idx) for idx in range(n_texts)]


texts, label_map = rank.get_texts(corpus)
# NOTE: this rebinds `queries` from the filtered DataFrame to whatever
# rank.get_queries returns (objects exposing `.metadata['label']`, as used
# below). Re-running this cell on its own therefore feeds the converted
# queries back into rank.get_queries; re-run the cell that builds `queries` first.
queries = rank.get_queries(queries, label_col='tech_id')

# One RelevanceExample per query: the full list of texts plus a boolean flag
# per text. The label -> index lookup is done once per query rather than once
# per (query, text) pair.
examples = [
    RelevanceExample(
        query,
        texts,
        _relevance_flags(label_map[query.metadata['label']], len(texts)),
    )
    for query in queries
]

### First Stage
#### BM25

In [None]:
# Value passed as `n_hits` to the BM25 StepEvaluator. It is also embedded in
# the cache file name, so changing it here cannot silently reuse results that
# were cached for a different setting.
bm25_n_hits = 100


def stage1_runner(examples):
    """Evaluate the BM25 reranker on `examples` using the notebook's `metrics`.

    Returns the result of `StepEvaluator.evaluate`, which the call site below
    unpacks as a pair (re-ranked examples, metrics).
    """
    bm25_reranker = rank.construct_bm25()
    bm25_eval = StepEvaluator(bm25_reranker, metrics, n_hits=bm25_n_hits)
    return bm25_eval.evaluate(examples)


ranker_suffix = f'bm25__{bm25_n_hits}'
stage1_cache_file = f'../../data/cache/top_level__{ranker_suffix}.pkl'
# The runner is handed to rank.load_cache_or_run together with the cache path
# (presumably it is only invoked when no cached result exists -- see libs.rank).
bm25_exmp, bm25_metrics = rank.load_cache_or_run(stage1_cache_file, stage1_runner, examples=examples)

for metric in bm25_metrics:
    print(f'{metric.name} = {metric.value:.5}')

### Second Stage
#### SentSecBert

In [None]:
# Value passed as `n_hits` to the SentSecBert StepEvaluator. It is also part of
# the cache file name, so the two can no longer drift apart.
sentsecbert_n_hits = 50


def stage2_runner(examples, seg_size, stride):
    """Evaluate the SentSecBert reranker on `examples`, segment by segment.

    `seg_size` and `stride` are forwarded unchanged to
    `StepEvaluator.evaluate_by_segments`, with aggregate_method='max'.
    Returns that call's result, unpacked below as a pair
    (re-ranked examples, metrics).
    """
    sentsecbert_sim_reranker = rank.construct_sentsecbert('../../models/SentSecBert_10k_AllDataSplit')
    sentsecbert_eval = StepEvaluator(sentsecbert_sim_reranker, metrics, n_hits=sentsecbert_n_hits)
    return sentsecbert_eval.evaluate_by_segments(
        examples,
        seg_size=seg_size,
        stride=stride,
        aggregate_method='max'
    )


seg_size = 14
overlap = 0.25
# Step between consecutive segments: 14 - int(14 * 0.25) = 14 - 3 = 11.
stride = seg_size - int(seg_size * overlap)

# NOTE(review): the suffix already ends in '.pkl' and the path below appends
# '.pkl' again, so the cache file ends in '.pkl.pkl'. Kept as-is so existing
# cache files are still found; drop one of the two (and rename the files on
# disk) to clean this up.
ranker_suffix = f'sentsecbert__{sentsecbert_n_hits}__{seg_size}__{str(overlap).replace(".", "_")}.pkl'
stage2_cache_file = f'../../data/cache/top_level__{ranker_suffix}.pkl'
# This stage takes the examples produced by the BM25 stage as its input.
sentsecbert_exmp, sentsecbert_metrics = rank.load_cache_or_run(stage2_cache_file, stage2_runner,
                                                               examples=bm25_exmp, seg_size=seg_size, stride=stride)

for metric in sentsecbert_metrics:
    print(f'{metric.name} = {metric.value:.5}')

### Third Stage
#### MonoT5

In [None]:
# Value passed as `n_hits` to the MonoT5 StepEvaluator. It is also part of the
# cache file name, so the two can no longer drift apart.
monot5_n_hits = 10


def stage3_runner(examples, seg_size, stride):
    """Evaluate the MonoT5 reranker on `examples`, segment by segment.

    `seg_size` and `stride` are forwarded unchanged to
    `StepEvaluator.evaluate_by_segments`, with aggregate_method='max'.
    Returns that call's result, unpacked below as a pair whose second item
    holds the metrics.
    """
    monot5_tram_reranker = rank.construct_monot5('../../models/monot5_AllDataSplit')
    monot5_tram_eval = StepEvaluator(monot5_tram_reranker, metrics, n_hits=monot5_n_hits)
    return monot5_tram_eval.evaluate_by_segments(
        examples,
        seg_size=seg_size,
        stride=stride,
        aggregate_method='max'
    )


seg_size = 14
overlap = 0.25
# Step between consecutive segments: 14 - int(14 * 0.25) = 14 - 3 = 11.
stride = seg_size - int(seg_size * overlap)

# NOTE(review): the suffix already ends in '.pkl' and the path below appends
# '.pkl' again, so the cache file ends in '.pkl.pkl'. Kept as-is so existing
# cache files are still found; drop one of the two (and rename the files on
# disk) to clean this up.
ranker_suffix = f'monot5__{monot5_n_hits}__{seg_size}__{str(overlap).replace(".", "_")}.pkl'
stage3_cache_file = f'../../data/cache/top_level__{ranker_suffix}.pkl'
# Final stage: takes the SentSecBert output as input; only the metrics are
# kept, the re-ranked examples are discarded.
_, monot5_tram_metrics = rank.load_cache_or_run(stage3_cache_file, stage3_runner,
                                                examples=sentsecbert_exmp, seg_size=seg_size, stride=stride)

for metric in monot5_tram_metrics:
    print(f'{metric.name} = {metric.value:.5}')