In [1]:
import elasticsearch
import elasticsearch.helpers
from elasticsearch_dsl import Search
import json
import os
from scipy import stats

## TREC run

### Basic Functions

In [2]:
# Elasticsearch client shared by the notebook; it is passed explicitly to
# index_documents() and make_trec_run() in the cells below.
# In case you use Docker, the host is 'elasticsearch' instead of 'localhost'.
es = elasticsearch.Elasticsearch(host='localhost')

def read_documents(file_name):
    """
    Returns a generator of documents to be indexed by elastic, read from file_name.

    The file is read line by line; every non-blank line must be a JSON object.
    Two kinds of lines are recognised:

    * an action line containing the key 'index', whose ['index']['_id'] value
      is remembered as the id of the next document;
    * a document line containing the key 'PMID', which is yielded with that
      remembered id stored under '_id'.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped.

    Raises:
        ValueError: if a line is neither an action nor a document line, or if
            a document line is not preceded by its own action line.
    """
    with open(file_name, 'r') as documents:
        # Id announced by the most recent action line; None while no id is pending.
        pending_id = None
        for line in documents:
            if not line.strip():
                continue
            doc_line = json.loads(line)
            if 'index' in doc_line:
                pending_id = doc_line['index']['_id']
            elif 'PMID' in doc_line:
                if pending_id is None:
                    # Without this check the first case is an UnboundLocalError and
                    # the second silently reuses the previous document's id.
                    raise ValueError(
                        'Woops, error in index file: document without a preceding index line'
                    )
                doc_line['_id'] = pending_id
                pending_id = None
                yield doc_line
            else:
                raise ValueError('Woops, error in index file')

def create_index(es, index_name, body={}):
    """
    Drop the index called index_name (if any) and create it again with body.

    The delete call is made with ignore=[400, 404], so a missing index is not
    treated as an error. body is handed to the create call unchanged.
    """
    index_api = es.indices
    # start from a clean slate: remove whatever index is there already
    index_api.delete(index=index_name, ignore=[400, 404])
    # then build it again from the supplied settings/mappings
    index_api.create(index=index_name, body=body)
                
def index_documents(es, file_name, index_name, body={}):
    """
    (Re)create index_name with body, then bulk-index the documents of file_name.

    Returns whatever elasticsearch.helpers.bulk returns for the request.
    """
    create_index(es, index_name, body)
    # read_documents is a generator, so the file is streamed while indexing
    actions = read_documents(file_name)
    return elasticsearch.helpers.bulk(
        es,
        actions,
        index=index_name,
        chunk_size=2000,
        request_timeout=30,
    )

In [3]:
def read_qrels_file(qrels_file):
    """
    Read relevance judgements from qrels_file.

    Each line must hold four whitespace-separated fields: query id, an unused
    field, document id and relevance label. Returns a dict mapping every query
    id seen in the file to the set of its document ids whose label is the
    string "1"; a query with no such document maps to an empty set.
    """
    relevant_by_query = {}
    with open(qrels_file, 'r') as qrels:
        for record in qrels:
            qid, _unused, doc_id, label = record.split()
            judged_relevant = relevant_by_query.setdefault(qid, set())
            if label == "1":
                judged_relevant.add(doc_id)
    return relevant_by_query


def read_run_file(run_file):
    """
    Read a run file into a dict of ranked document lists.

    Each line must hold six whitespace-separated fields: query id, an unused
    field, document id, rank, score and run tag. Only the query id and the
    document id are used; documents are kept in the order they appear in the
    file.
    """
    ranking_by_query = {}
    with open(run_file, 'r') as run:
        for record in run:
            qid, _unused, doc_id, _rank, _score, _tag = record.split()
            ranking_by_query.setdefault(qid, []).append(doc_id)
    return ranking_by_query
    

def read_eval_files(qrels_file, run_file):
    """
    Load both evaluation inputs.

    Returns a pair (relevant, retrieved): the result of read_qrels_file for
    qrels_file and the result of read_run_file for run_file.
    """
    relevant_by_query = read_qrels_file(qrels_file)
    retrieved_by_query = read_run_file(run_file)
    return relevant_by_query, retrieved_by_query

In [4]:
def precision_at_k(relevant, retrieved, k):
    """
    Precision of the first k retrieved documents.

    Args:
        relevant: collection of relevant document ids (membership is tested
            with `in`).
        retrieved: ranked list of retrieved document ids.
        k: cutoff; must be a non-negative int.

    Returns:
        1 when k is 0; otherwise the fraction of retrieved[:k] that is in
        relevant. The denominator is the number of documents actually present
        in retrieved[:k], so a ranking shorter than k is not penalised. When
        nothing was retrieved the result is 0.0. For any other k a message is
        printed and None is returned.
    """
    if k == 0:
        return 1
    elif (k > 0) and (type(k) == int):
        top_k = retrieved[:k]
        if not top_k:
            # Nothing retrieved for this query: avoid dividing by zero.
            return 0.0
        hits = sum(1 for doc in top_k if doc in relevant)
        return hits / len(top_k)
    else:
        print("k has a wrong value")
        
def interpolated_precision_at_recall_X(relevant, retrieved, X):
    """
    Interpolated precision at recall level X.

    Considers every prefix retrieved[:i] for i = 0 .. len(retrieved) and
    returns the highest precision among the prefixes whose recall is at
    least X.

    Args:
        relevant: collection of relevant document ids.
        retrieved: ranked list of retrieved document ids.
        X: recall threshold.

    Returns:
        The maximum qualifying precision, or 0 when no prefix reaches recall X
        or when there are no relevant documents at all (recall is undefined
        in that case).
    """
    total_relevant = len(relevant)
    if total_relevant == 0:
        # Recall cannot be computed; score the query as 0 instead of crashing.
        return 0

    precisions = []
    hits = 0  # running number of relevant documents in retrieved[:i]
    for i in range(len(retrieved) + 1):
        if i > 0 and retrieved[i - 1] in relevant:
            hits += 1
        if hits / total_relevant >= X:
            # Precision of the empty prefix is defined as 1, matching precision_at_k.
            precisions.append(1 if i == 0 else hits / i)
    return max(precisions, default=0)


def average_precision(relevant, retrieved):
    """
    Eleven-point interpolated average precision.

    Averages interpolated_precision_at_recall_X over the recall levels
    0.0, 0.1, ..., 1.0.
    """
    recall_levels = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    interpolated = [
        interpolated_precision_at_recall_X(relevant, retrieved, level)
        for level in recall_levels
    ]
    return sum(interpolated) / len(interpolated)


def mean_metric(measure, all_relevant, all_retrieved):
    """
    Average measure(relevant, retrieved) over every query in all_relevant.

    Queries that are missing from all_retrieved are evaluated with an empty
    result list. Returns a pair ("mean <measure.__name__>", average).
    """
    running_sum = 0
    n_queries = 0
    for qid, relevant_docs in all_relevant.items():
        running_sum += measure(relevant_docs, all_retrieved.get(qid, []))
        n_queries += 1
    label = "mean " + measure.__name__
    return label, running_sum / n_queries

In [5]:
def make_trec_run(es, topics_file_name, run_file_name, run_name="test",
                  index_name="genomics", max_hits=10000):
    """
    Run every topic against an index and write the results as a TREC run file.

    Args:
        es: Elasticsearch client used to execute the searches.
        topics_file_name: file with one topic per line, formatted as
            "<qid><TAB><query>". Blank lines are skipped.
        run_file_name: output file; it is overwritten.
        run_name: tag written in the last column of every line.
        index_name: index to search (defaults to the previously hard-coded
            'genomics').
        max_hits: maximum number of hits requested per topic (defaults to the
            previously hard-coded 10000).

    Every hit produces one line "<qid> Q0 <PMID> <rank> <score> <run_name>",
    with ranks starting at 1 in the order the hits are returned.
    """
    with open(run_file_name, 'w') as run_file, open(topics_file_name, 'r') as test_queries:
        for topic in test_queries:
            topic = topic.strip()
            if not topic:
                continue
            # Split on the first tab only so the query text itself may contain tabs.
            (qid, query) = topic.split('\t', 1)
            search = Search(using=es, index=index_name)[:max_hits].query("multi_match", query=query)
            response = search.execute()
            for rank, hit in enumerate(response.to_dict()['hits']['hits'], start=1):
                fields = [str(qid), "Q0", str(hit['_source']['PMID']), str(rank), str(hit['_score']), run_name]
                run_file.write(" ".join(fields))
                run_file.write("\n")

                    
def trec_eval(qrels_file, run_file):
    """
    Evaluate a run file against a qrels file.

    Returns a list of 17 (label, score) pairs, in this order:

    * "mean precision_at_recall_00" ... "mean precision_at_recall_10":
      interpolated precision at recall 0.0, 0.1, ..., 1.0, averaged over all
      queries of the qrels file;
    * "mean average_precision": the average of the eleven values above. This
      is mathematically the per-query eleven-point average precision averaged
      over queries, but it avoids evaluating every query a second time;
    * "mean precision_at_1", "_5", "_10", "_50", "_100": precision at those
      cutoffs, averaged over all queries.

    Raises:
        ValueError: if the run contains a query id that is absent from the
            qrels file.
    """
    def named(name, measure):
        # mean_metric builds its label from measure.__name__
        measure.__name__ = name
        return measure

    def recall_metric(level):
        # level is the recall threshold in tenths (0 .. 10)
        recall = level / 10
        return named(
            "precision_at_recall_{:02d}".format(level),
            lambda rel, ret: interpolated_precision_at_recall_X(rel, ret, X=recall),
        )

    def cutoff_metric(k):
        return named(
            "precision_at_{}".format(k),
            lambda rel, ret: precision_at_k(rel, ret, k=k),
        )

    (all_relevant, all_retrieved) = read_eval_files(qrels_file, run_file)

    unknown_qids = set(all_retrieved.keys()).difference(all_relevant.keys())
    if len(unknown_qids) > 0:
        raise ValueError("Unknown qids in run: {}".format(sorted(list(unknown_qids))))

    recall_scores = [
        mean_metric(recall_metric(level), all_relevant, all_retrieved)
        for level in range(11)
    ]

    # Derived explicitly from the eleven recall-level means rather than through
    # a pseudo-metric that ignores its arguments and depends on list position.
    recall_values = [score for (_, score) in recall_scores]
    average_precision_score = ("mean average_precision", sum(recall_values) / len(recall_values))

    cutoff_scores = [
        mean_metric(cutoff_metric(k), all_relevant, all_retrieved)
        for k in (1, 5, 10, 50, 100)
    ]

    return recall_scores + [average_precision_score] + cutoff_scores


def print_trec_eval(qrels_file, run_file):
    """
    Evaluate run_file against qrels_file with trec_eval and print the scores.

    Prints a header naming the run file, followed by one line per metric with
    the label left-aligned in 30 columns and the score in 4 significant digits.
    """
    scored_metrics = trec_eval(qrels_file, run_file)
    print("Results for {}".format(run_file))
    row_template = "{:<30} {:.4}"
    for label, value in scored_metrics:
        print(row_template.format(label, value))

### Rebuilt BM25

In [6]:
# Index body for the "rebuilt" BM25 run: the same BM25 similarity settings as
# the simple `bm25` body in the next cell, plus a custom analyzer
# (standard tokenizer -> lowercase -> English stop words -> English stemmer)
# that is applied to the combined `all` field.
rebuilt_bm25 = {
    "settings" : {
        "number_of_shards" : 1,
        "index" : {
            "similarity" : {
                # BM25 similarity registered under the name "my_similarity"
                "my_similarity" : {
                    "type": "BM25",
                    "k1": 1.2,
                    "b": 0.75
                }
            }
        },
        "analysis": {
            # token filters referenced by "my_analyzer" below
            "filter": {
                "english_stop": {
                    "type": "stop",
                    "stopwords": "_english_" 
                },
                "english_stemmer": {
                    "type": "stemmer",
                    "language": "english"
                },
            },
            "analyzer": {
                "my_analyzer": {
                    "tokenizer": "standard",
                    # filters are listed in the order they are applied
                    "filter": [
                        "lowercase",
                        "english_stop",
                        "english_stemmer"
                    ]
                }
            },
        },
    },
    "mappings": {
        "properties": {
            # AB and TI (presumably abstract and title -- confirm against the
            # data file) are both copied into the combined field "all".
            "AB": {
                "type": "text",
                "copy_to": "all"
            },
            "TI": {
                "type": "text",
                "copy_to": "all"
            },
            # Only "all" uses the custom similarity and analyzer; AB and TI
            # themselves do not name either one.
            "all": {
                "type": "text",
                "similarity": "my_similarity",
                "analyzer":"my_analyzer"
            }
        }
    }
}

# index_documents() deletes any existing 'genomics' index before re-creating
# it with this body, so this cell replaces whatever was indexed before.
# NOTE(review): make_trec_run issues a multi_match query without an explicit
# field list, so whether the "all" field (and with it my_analyzer /
# my_similarity) takes part in scoring depends on the index's default query
# fields -- confirm.
index_documents(es, 'data/trec-medline.json', 'genomics', body=rebuilt_bm25)
make_trec_run(es, 'data/training-queries-simple.txt', 'project_rebuilt_bm25.run', run_name='project01')

### Simple BM25

In [7]:
# Index body for the "simple" BM25 run: the BM25 similarity "my_similarity"
# on the combined `all` field, with no custom analyzer configured.
bm25 = {
    "settings" : {
        "number_of_shards" : 1,
        "index" : {
            "similarity" : {
                # BM25 similarity registered under the name "my_similarity"
                "my_similarity" : {
                    "type": "BM25",
                    "k1": 1.2,
                    "b": 0.75
                }
            }
        },
    },
    "mappings": {
        "properties": {
            # AB and TI (presumably abstract and title -- confirm against the
            # data file) are both copied into the combined field "all".
            "AB": {
                "type": "text",
                "copy_to": "all"
            },
            "TI": {
                "type": "text",
                "copy_to": "all"
            },
            # Only "all" names the custom similarity.
            "all": {
                "type": "text",
                "similarity": "my_similarity"
            }
        }
    }
}

# Re-uses the index name 'genomics': index_documents() deletes the index built
# by the previous cell and re-creates it with this body, so the two runs must
# be produced in cell order.
index_documents(es, 'data/trec-medline.json', 'genomics', body=bm25)
make_trec_run(es, 'data/training-queries-simple.txt', 'project_bm25.run', run_name='project01')

## Comparison

In [12]:
def is_signiﬁcance(array_model1, array_model2):
    """
    Print the outcome of a paired t-test between two score arrays.

    Nothing is printed when either array is empty. Otherwise
    stats.ttest_rel is run on the two arrays, its p-value is halved, and two
    lines are printed: the sample size with the t statistic and halved
    p-value, then a verdict. The verdict is "significant" when the halved
    p-value is below 0.05; the sign of t decides which model is named
    (negative: model 2, positive: model 1). Every remaining case is reported
    as "not significant, model 1 seems better".
    """
    if len(array_model1) == 0 or len(array_model2) == 0:
        return

    alpha = 0.05
    t_stat, p_two_sided = stats.ttest_rel(array_model1, array_model2)
    p_one_sided = p_two_sided / 2
    print("n=", len(array_model1), " t=", format(t_stat, '0.4f'), " p=", format(p_one_sided, '0.8f'))

    if t_stat < 0:
        if p_one_sided < alpha:
            verdict = "significant，model 2 better"
        elif p_one_sided >= alpha:
            verdict = "not significant, model 2 seems better"
        else:
            # p is NaN: neither comparison holds
            verdict = "not significant, model 1 seems better"
    elif t_stat > 0 and p_one_sided < alpha:
        verdict = "significant，model 1 better"
    else:
        verdict = "not significant, model 1 seems better"
    print(verdict)

            
def get_scores(results):
    """
    Return the second element of every (label, score) pair in results,
    in the original order.
    """
    scores = []
    for _label, score in results:
        scores.append(score)
    return scores


# Evaluate both runs once and cache the scores in scores.json. The cache is
# keyed only on the file's existence: it is NOT refreshed when the run files
# change, so delete scores.json to force a re-evaluation.
if not os.path.exists("scores.json"):
    # trec_eval is really time consuming, so its results are saved to disk.
    # Note that the *_results variables exist only on this cache-miss path.
    project_rebuilt_bm25_results = trec_eval('data/training-qrels.txt', 'project_rebuilt_bm25.run')
    project_bm25_results = trec_eval('data/training-qrels.txt', 'project_bm25.run')
    
    project_rebuilt_bm25_scores = get_scores(project_rebuilt_bm25_results)
    project_bm25_scores = get_scores(project_bm25_results)
    scores = {
        "rebuilt_bm25": project_rebuilt_bm25_scores,
        "bm25": project_bm25_scores,
    }
    jsObj = json.dumps(scores)
    with open("scores.json", "w") as f:
        f.write(jsObj)
else:
    # Cache hit: restore the two score lists from the JSON file.
    with open('scores.json', 'r') as f:
        scores = json.load(f)
        project_rebuilt_bm25_scores = scores['rebuilt_bm25']
        project_bm25_scores = scores['bm25']
        
# NOTE(review): each list holds one mean value per metric returned by
# trec_eval (n=17 in the recorded output), so the paired t-test below compares
# the two systems across metrics rather than across queries -- confirm this is
# the intended unit of analysis.
print("Let model 1 is rebuilt bm25, model 2 is simple bm25:")
is_signiﬁcance(project_rebuilt_bm25_scores, project_bm25_scores)

Let model 1 is rebuilt bm25, model 2 is simple bm25:
n= 17  t= 3.2491  p= 0.00251548
significant，model 1 better
