In [None]:
import pickle

# numpy is used as `np` throughout this notebook; import it explicitly instead of
# relying on one of the star imports below to leak the name.
import numpy as np
from matplotlib import pyplot as plt

from TextRepresenter import PorterStemmer
from Index import Index
from Weighter import *
from IRModel import *
from EvalIRModel import *
from GridSearch import GridSearch

In [None]:
# Create the Index from the collection name "cisi" and the file cisi/cisi.txt.
index = Index("cisi", "cisi/cisi.txt")
# The indexing call is disabled here — presumably the index was already built by an
# earlier run; TODO confirm before running on a fresh checkout.
#index.indexation()
# Wrap the index in the vector weighter; calculeNorms() presumably precomputes the
# norms the weighter needs — verify against Weighter.
weighter = WeighterVector(index)
weighter.calculeNorms()

# NOTE: despite its name, this string is appended to the pickle file names (it is used
# as a suffix, including the extension) when results are saved and reloaded.
prepend = '_cisi_w2.pickle'

In [None]:
# Sample free-text query used for the quick manual checks in this notebook.
query_text = 'graph exploration'
ps = PorterStemmer()
# Representation of the sample query as produced by the PorterStemmer object;
# this is what gets passed to the models' getScores().
query = ps.getTextRepresentation(query_text)
# Evaluator built from the cisi query file and the cisi .rel file
# (presumably the relevance judgements — TODO confirm).
eval_irm = EvalIRModel('cisi/cisi.qry', 'cisi/cisi.rel')

# 1. Modèle de langue

In [None]:
# Language model built on top of the shared weighter.
language_model = LanguageModel(weighter)

In [None]:
# Quick check on the sample query: getScores returns two values here.
# 0.01 is passed positionally — presumably the smoothing parameter lambd; TODO confirm.
# score_absent is presumably the score given to documents missing from `scores`; verify.
scores, score_absent = language_model.getScores(query, 0.01)
print(scores, score_absent)

In [None]:
# First 10 entries of the ranking returned for the sample query with lambd=0.009.
language_model.getRanking(query_text, lambd=0.009)[:10]

In [None]:
# Evaluate the language model through eval_irm with lambd fixed to 1.
results = eval_irm.evalModel(
    language_model,
    ranking_call=lambda m, text: m.getRanking(text, lambd=1),
)
results

In [None]:
# Evaluate the language model through eval_irm with lambd fixed to 0.01.
results = eval_irm.evalModel(
    language_model,
    ranking_call=lambda m, text: m.getRanking(text, lambd=0.01),
)
results

In [None]:
# Evaluate the language model through eval_irm with lambd fixed to 0.
results = eval_irm.evalModel(
    language_model,
    ranking_call=lambda m, text: m.getRanking(text, lambd=0),
)
results

# 2. Modèle BM25

In [None]:
# BM25 model built on top of the same shared weighter as the language model.
bm25_model = BM25Model(weighter)

In [None]:
# Quick check on the sample query: BM25 scores with the model's default parameters.
# Note that this rebinds `scores`, replacing the language-model scores computed earlier.
scores = bm25_model.getScores(query)
print(scores)

In [None]:
# First 10 entries of the BM25 ranking returned for the sample query.
bm25_model.getRanking(query_text)[:10]

In [None]:
# Evaluate the BM25 model through eval_irm using its default parameters.
results = eval_irm.evalModel(
    bm25_model,
    ranking_call=lambda m, text: m.getRanking(text),
)
results

# 3. Optimisation des paramètres
## 3.1 Language Model

In [None]:
# 100 candidate lambda values, geometrically spaced from 1e-5 up to 1 (inclusive).
lambda_values = np.geomspace(start=1e-5, stop=1, num=100)

In [None]:
# Grid-search the single parameter "lambd" of the language model over lambda_values.
search = GridSearch(
    param_a_name="lambd",
    param_a_values=lambda_values,
)
results = search.search(language_model, train_prop=0.8, seed=42)
# Convert the search output to an array so it can be sliced by column and pickled.
lang_results = np.array(results)

In [None]:
# Save the language-model grid-search results so the search need not be re-run.
with open('models/lang_results' + prepend, 'wb') as out_file:
    pickle.dump(lang_results, out_file, protocol=4)

In [None]:
# Reload the saved language-model grid-search results.
# Only unpickle files written by this notebook: pickle.load can execute arbitrary code.
with open('models/lang_results' + prepend, 'rb') as in_file:
    lang_results = pickle.load(in_file)

In [None]:
# Plot column 1 of the language-model grid-search results against lambda.
# The grid is rebuilt here (same geomspace call as the one given to the search) so the
# cell also works right after reloading lang_results from the pickle.
lambda_values = np.geomspace(1e-5, 1, 100)
fig, ax = plt.subplots()
ax.plot(lambda_values, lang_results[:, 1])
# The lambdas span five decades geometrically. On a linear axis almost every point is
# squeezed against zero, and one tick per value makes 100 overlapping labels, so use a
# log axis with its default decade ticks.
ax.set_xscale("log")
ax.set_title("Modéle de langue - GridSearch")
ax.set_xlabel("Lambda");

## 3.2 BM25 Model

In [None]:
# Search grids: 20 evenly spaced values each, over [1, 2] and [0.5, 1] respectively.
# The first grid stays a NumPy array; the second is converted to a plain list.
param_a_values = np.linspace(start=1, stop=2, num=20)
param_b_values = list(np.linspace(start=0.5, stop=1, num=20))

In [None]:
# Grid-search the BM25 parameters: "k1" over param_a_values and "b" over param_b_values.
search = GridSearch(
    param_a_name="k1",
    param_a_values=param_a_values,
    param_b_name="b",
    param_b_values=param_b_values,
)
results = search.search(bm25_model, train_prop=0.8, seed=42)
# Convert the search output to an array so it can be sliced by column and pickled.
bm25_results = np.array(results)
bm25_results

In [None]:
# Save the BM25 grid-search results so the search need not be re-run.
with open('models/bm25_results' + prepend, 'wb') as out_file:
    pickle.dump(bm25_results, out_file, protocol=4)

In [None]:
# Reload the saved BM25 grid-search results.
# Only unpickle files written by this notebook: pickle.load can execute arbitrary code.
with open('models/bm25_results' + prepend, 'rb') as in_file:
    bm25_results = pickle.load(in_file)

In [None]:
# Arrange column 1 of the BM25 grid-search results as a matrix with one row per k1
# value (param_a_values) and one column per b value (param_b_values).
# The grid sizes are taken from the parameter lists instead of a hard-coded 20, so the
# reshape and the index arithmetic below cannot drift apart.
n_k1, n_b = len(param_a_values), len(param_b_values)
# Use the builtin float: the np.float alias was deprecated in NumPy 1.20 and removed in
# 1.24, where it raises AttributeError.
bm25_matrix = np.array(bm25_results[:, 1].reshape(n_k1, n_b), dtype=float)
# k1: rows, b: columns
# argmax indexes the flattened (row-major) matrix; split it back into (row, column).
idx = np.argmax(bm25_matrix)
best_k1_idx, best_b_idx = divmod(idx, n_b)
print(param_a_values[best_k1_idx], param_b_values[best_b_idx], np.max(bm25_matrix))

In [None]:
# Heat map of the BM25 grid-search matrix: rows are k1 values, columns are b values.
fig, ax = plt.subplots()
im = ax.imshow(bm25_matrix)
# Colour scale, so the cell colours can be read back as values.
fig.colorbar(im, ax=ax)

# One tick per grid value, labelled with the parameter value it stands for.
ax.set_xticks(np.arange(len(param_b_values)))
ax.set_yticks(np.arange(len(param_a_values)))
ax.set_xticklabels(['{:.2f}'.format(b_value) for b_value in param_b_values])
ax.set_yticklabels(['{:.2f}'.format(k1_value) for k1_value in param_a_values])

# Name the axes so the figure can be read without the surrounding code.
ax.set_xlabel("b")
ax.set_ylabel("k1")
ax.set_title("bm25 accuracy")

fig.tight_layout()
plt.show()