In [1]:
import parser
import sparse_matrix
import predict
import rocchio
from tqdm.auto import tqdm

configs = {
    "k": 1.5,
    "b": 0.5,
    "ka": 100,
    "alpha": 1,
    "beta": 0.75,
    "gamma": 0.15,
    "target": 100,
    "use_rocchio": True,
    "output_path": "../prediction.csv",
    "model_path": "../model",
    "corpus_path": "../CIRB010",
    "title_weight": 1,
    "question_weight": 1,
    "concepts_weight": 1,
    "narrative_weight": 1,
    "query_path": "../queries/query-train.xml",
    "cdn": 1,
    "ctc": 1,
    "cte": 1,
    "cts": 1,
    "unigram_weight": 1,
    "bigram_weight": 1,
    "rocchio_iters": 1,
    "use_cosine": False
}

if __name__ == '__main__':
    # parser.parse_arg(configs)
    fname_to_id, id_to_fname = parser.parse_file_list(configs)
    vocab_to_id, id_to_vocab = parser.parse_vocab_list(configs)
    doc_count = len(fname_to_id)
    inverted_files, gram_to_id, gram_count, id_to_doclen = parser.parse_inverted_file(configs, doc_count)
    configs["gram_count"] = gram_count
    configs["doc_count"] = doc_count
    # Save checkpoint for notebook
    avdl = sum(id_to_doclen.values()) / len(id_to_doclen)
    corpus = {
        "fname_to_id": fname_to_id,
        "id_to_doclen": id_to_doclen,
        "id_to_fname": id_to_fname,
        "vocab_to_id": vocab_to_id,
        "id_to_vocab": id_to_vocab,
        "inverted_files": inverted_files,
        "gram_to_id": gram_to_id,
        "avdl": avdl,
    }
    corpus["sparse"] = sparse_matrix.gen_matrix(corpus, configs)

Reading Inverted Files: 


HBox(children=(FloatProgress(value=0.0, max=1193467.0), HTML(value='')))

In [22]:
import predict
import rocchio

%load_ext autoreload
%autoreload 2


configs["rocchio_iters"] = 1
configs["use_rocchio"] = True
configs["gamma"] = 0
configs["query_path"] = "../queries/query-test.xml"
print("Processing Query")
queries = parser.parse_queries(corpus, configs, configs["query_path"])
sparse_queries = []
for query in queries:
    sparse_queries.append( sparse_matrix.gen_query_vector(query, corpus, configs) )
query_responses = []
for sparse_query in tqdm(sparse_queries):
    query_responses.append( predict.predict_query(sparse_query, corpus, configs) )
print("Rocchio Feedback~~~")
for query in query_responses:
    print(query[:10])
# print("query_responses: ", query_responses)
query_responses = load_csv_response()
if configs["use_rocchio"]:
    print("in rocchio")
    for _ in tqdm(range(configs["rocchio_iters"])):
        print("iterations")
        for i in tqdm(range(len(query_responses))):
            feedback_vec = rocchio.rocchio_feedback(query_responses[i], sparse_queries[i],  corpus, configs)
            print((sparse_queries[i] + feedback_vec).sum())
            response = predict.predict_query(sparse_queries[i] + feedback_vec, corpus, configs) 
            query_responses[i] = response
    print("done rocchio")
print("something done~~~")
predict.process_predictions(query_responses, configs, corpus)
predict.write_predictions(query_responses, queries, configs)
predict.calc_MAP(query_responses, configs)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Processing Query


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Rocchio Feedback~~~
[24656, 8301, 8984, 11058, 5926, 5690, 26811, 46598, 45497, 27283]
[25646, 3607, 26747, 21394, 25844, 12327, 3337, 27895, 45768, 21636]
[13591, 13629, 13330, 13189, 21972, 13542, 6794, 10052, 26582, 9506]
[25086, 45822, 28205, 5359, 25953, 10895, 28194, 11520, 22036, 26513]
[28534, 18206, 9402, 17980, 7353, 38893, 18734, 24060, 18202, 28529]
[25240, 26983, 21995, 26347, 27682, 20583, 10942, 45619, 46549, 21691]
[17891, 7851, 18101, 23723, 17745, 20799, 6889, 18551, 24388, 17977]
[13229, 13680, 21581, 13290, 13979, 13203, 13772, 13176, 21802, 13402]
[18566, 18846, 7466, 18166, 19157, 7248, 8739, 8001, 18162, 19386]
[16329, 45253, 16231, 17237, 17196, 7154, 17352, 22608, 17565, 8384]
[15109, 32927, 33011, 21204, 36506, 22058, 28275, 45462, 46398, 12162]
[39119, 43602, 6976, 38806, 15802, 11032, 7366, 43917, 2841, 6480]
[46911, 13402, 25845, 13560, 21722, 29808, 43429, 9368, 31359, 43680]
[25855, 45964, 25488, 45873, 20992, 20943, 28035, 10391, 20833, 20907]
[1518, 23

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

iterations


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

3826.3433656341313
3874.12464945282
3675.389864767482
3875.4290760463055
3249.7656580642774
3901.1365195109893
3520.182294705661
3624.684188017917
3826.369853066775
3963.3844913012013
3830.334671549326
3656.0778049339906
3599.790628456651
4088.7482356325613
3544.920853433383
3461.1596132384234
3184.5169989830274
3884.4820147020187
3776.0436792910814
3456.391616061106


done rocchio
something done~~~
[0.0, 0.006253006253006252, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
0.0006253006253006253


In [21]:
def load_csv_response():
    fake_fname_to_id = fname_to_id.copy()
    for key in fname_to_id:
        processed_name = key.lower().split("/")[-1]
        fake_fname_to_id[processed_name] = fake_fname_to_id[key]
    
    with open("../predictions/ensemble-0.79728.csv", "r") as f:
        ans = []
        for line_no, line in enumerate(f):
            if line_no == 0 or line_no == 21:
                continue
            ranked_docs = line.split(",")[1].split()
            ans.append([fake_fname_to_id[i] for i in ranked_docs])
    print(ans)
    return ans
load_csv_response()

[[24656, 8984, 8301, 46598, 45497, 5926, 26811, 11058, 27283, 9837, 5690, 24989, 27615, 21930, 23807, 9520, 27839, 46205, 27249, 26288, 8204, 12339, 23279, 10687, 27988, 23839, 6692, 21023, 21016, 46681, 22175, 26016, 27090, 24079, 22781, 25227, 27664, 45740, 20474, 25092, 23487, 23988, 26787, 25116, 20707, 26549, 27450, 45233, 46859, 26808, 27140, 25459, 46948, 24971, 27298, 11504, 25093, 26742, 28093, 26451, 4967, 9997, 27291, 5730, 30147, 46184, 25171, 25419, 46288, 21631, 24036, 45300, 45446, 8635, 25428, 45721, 25965, 22436, 26464, 4396, 22132, 26215, 26913, 17720, 27256, 28265, 10271, 26201, 18313, 27267, 21339, 26003, 20510, 46188, 24516, 20793, 6009, 10821, 26457, 25955], [25646, 26747, 3607, 25844, 21394, 12327, 45768, 11984, 21636, 21834, 25578, 3337, 27125, 27895, 27481, 20950, 27905, 20990, 10125, 10556, 25914, 26431, 21125, 21234, 21623, 21829, 3562, 21037, 26194, 21220, 21200, 26953, 667, 13596, 33923, 25977, 32679, 3532, 28223, 32640, 20916, 33593, 30793, 21533, 3707, 30

[[24656,
  8984,
  8301,
  46598,
  45497,
  5926,
  26811,
  11058,
  27283,
  9837,
  5690,
  24989,
  27615,
  21930,
  23807,
  9520,
  27839,
  46205,
  27249,
  26288,
  8204,
  12339,
  23279,
  10687,
  27988,
  23839,
  6692,
  21023,
  21016,
  46681,
  22175,
  26016,
  27090,
  24079,
  22781,
  25227,
  27664,
  45740,
  20474,
  25092,
  23487,
  23988,
  26787,
  25116,
  20707,
  26549,
  27450,
  45233,
  46859,
  26808,
  27140,
  25459,
  46948,
  24971,
  27298,
  11504,
  25093,
  26742,
  28093,
  26451,
  4967,
  9997,
  27291,
  5730,
  30147,
  46184,
  25171,
  25419,
  46288,
  21631,
  24036,
  45300,
  45446,
  8635,
  25428,
  45721,
  25965,
  22436,
  26464,
  4396,
  22132,
  26215,
  26913,
  17720,
  27256,
  28265,
  10271,
  26201,
  18313,
  27267,
  21339,
  26003,
  20510,
  46188,
  24516,
  20793,
  6009,
  10821,
  26457,
  25955],
 [25646,
  26747,
  3607,
  25844,
  21394,
  12327,
  45768,
  11984,
  21636,
  21834,
  25578,
  3337,
  27125

In [None]:
    [0.7952003565194296, 0.8277743643256356, 0.8127736254964418, 0.8609187109187109, 0.8659613394461879, 0.6254554564545596, 0.7427256030958908, 0.7967283585762662, 0.7071669071669072, 0.6806975827134235]
0.7715402304713453