In [1]:
import json
import math
import os

import nltk
import numpy as np

### 将文本加载为json格式

In [2]:
input_file = 'data/tweets.txt'
# Every non-blank line of the file is one JSON-encoded tweet record.
# An explicit encoding keeps loading independent of the platform's locale default,
# and blank lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
with open(input_file, 'r', encoding='utf-8') as f:
    items = [json.loads(line) for line in f if line.strip()]
tweets = [x['text'] for x in items]     # raw tweet text, indexed by document number
idx2id = [x['tweetId'] for x in items]  # document number -> tweetId
N = len(tweets)                         # corpus size (used by the idf schemes)

### 定义预处理类
* 大写转小写
* 分词
* 去除标点符号和停用词

In [3]:
class Preprocess:
    """Turn raw text into index terms.

    Steps: lower-case, tokenize with ``nltk.word_tokenize``, then drop
    punctuation tokens and NLTK English stop words.
    """
    def __init__(self):
        # Punctuation tokens to drop. A set gives O(1) membership per token; the
        # previous list also repeated ':', '#' and '|'.
        self.punctuations = {',', ':', '_', '!', '"', '*', '>', '<', '@', '~', '-',
                             '(', ')', '%', '=', '\\', '^', '&', '|', '#', '$',
                             '[', ']', '+'}
        self.stop_words = set(nltk.corpus.stopwords.words('english'))

    def __call__(self, text, query=False):
        """Return the filtered, lower-cased tokens of ``text``.

        Args:
            text: raw input string.
            query: currently unused; documents and queries are processed identically.

        Returns:
            List of tokens in original order, without punctuation and stop words.
        """
        tokens = nltk.word_tokenize(text.lower())
        return [tok for tok in tokens
                if tok not in self.punctuations and tok not in self.stop_words]

In [4]:
preprocess = Preprocess()
# Filtered, lower-cased index terms of every tweet.
tokens = [preprocess(tweet) for tweet in tweets]
# Euclidean (L2) norm of each tweet's raw term-frequency vector; top_k divides
# scores by it for 'c' document normalization.
length = [
    math.sqrt(sum([count * count for count in nltk.FreqDist(doc_terms).values()]))
    for doc_terms in tokens
]

### 构建倒排索引（统计每个词在各文档中的词频）

In [5]:
# Inverted index: term -> list of postings, each posting is (doc_index, term_frequency).
dictionary = {}
for doc_idx, doc_terms in enumerate(tokens):
    for term, freq in nltk.FreqDist(doc_terms).items():
        dictionary.setdefault(term, set()).add((doc_idx, freq))
# Order every posting list by tf descending, ties broken by doc index descending.
for term in dictionary:
    dictionary[term] = sorted(dictionary[term], key=lambda p: (p[1], p[0]), reverse=True)

def get_postings(term):
    """Return the posting list [(doc_index, tf), ...] for ``term``, or [] if unindexed."""
    return dictionary.get(term, [])

### 定义查询运算

In [6]:
def query_parse(query):
    """OR-query: union of the posting lists of every token in ``query``.

    The query is lower-cased and tokenized, but not stop-word filtered. Postings
    are (doc_index, tf) tuples, deduplicated as tuples and returned sorted by tf,
    then doc index, both descending.

    NOTE(review): dedup is per (doc, tf) tuple, so a document matching two terms
    with different tf values appears twice.
    """
    hits = set()
    for term in nltk.word_tokenize(query.lower()):
        hits |= set(get_postings(term))
    return sorted(hits, key=lambda p: (p[1], p[0]), reverse=True)

# Helpers for printing text with red (ANSI escape code) emphasis.
def toRed( s ):
    """Wrap ``str(s)`` in the ANSI escape codes for red text followed by a reset."""
    return '\033[31;2m' + str(s) + '\033[0m'
def print_with_emphasize(line, Q):
    """Print ``line`` token by token, highlighting tokens that occur in query ``Q``.

    Matching is case-insensitive against the tokenized, lower-cased query. Every
    token is followed by a space and the whole output by a blank line.
    """
    line_tokens = nltk.word_tokenize(line)
    query_words = set(nltk.word_tokenize(Q.lower()))
    pieces = []
    for token in line_tokens:
        shown = toRed(token) if token.lower() in query_words else token
        pieces.append(shown + ' ')
    print(''.join(pieces) + '\n')

### 定义top K运算

In [7]:
def _tf_natural(tf):
    # 'n': raw term frequency, returned unchanged.
    return tf


def _tf_logarithm(tf):
    # 'l': 1 + ln(tf) for tf >= 1, otherwise 0.
    return [1 + math.log(x) if x >= 1 else 0 for x in tf]


def _tf_augmented(tf):
    # 'a': 0.5 + 0.5 * tf / max(tf); every entry is 0 when the maximum is not positive.
    if not tf:
        return []
    peak = max(tf)
    return [0.5 + 0.5 * x / peak if peak > 0 else 0 for x in tf]


def _tf_boolean(tf):
    # 'b': 1 if the term occurs, otherwise 0.
    return [1 if x > 0 else 0 for x in tf]


# SMART tf schemes: letter -> function mapping a list of term frequencies to weights.
# TODO: 'L' (log average, (1 + ln tf) / (1 + ln(mean tf))) is not implemented.
TF = {
    'n': _tf_natural,
    'l': _tf_logarithm,
    'a': _tf_augmented,
    'b': _tf_boolean,
}

# SMART document-frequency schemes: letter -> function mapping a list of df values
# to weights. N is the corpus size (number of tweets), read at call time.
DF = {
    # 'n': no df weighting.
    'n': lambda df: [1] * len(df),
    # 't': idf = ln(N / df); 0 for terms that occur in no document.
    't': lambda df: [math.log(N / x) if x >= 1 else 0 for x in df],
    # 'p': probabilistic idf = max(0, ln((N - df) / df)). For df >= N the ratio is
    # <= 0, so math.log would raise ValueError; the clamped value there is 0.
    'p': lambda df: [max(0, math.log((N - x) / x)) if 1 <= x < N else 0 for x in df],
}

def _norm_none(tfdf):
    # 'n': no normalization, every weight is scaled by 1.
    return [1] * len(tfdf)


def _norm_cosine(tfdf):
    # 'c': scale every weight by 1 / ||tfdf||_2 (factor 0 when the weights sum to 0).
    factor = 1 / math.sqrt(sum([w * w for w in tfdf])) if sum(tfdf) > 0 else 0
    return [factor] * len(tfdf)


# SMART normalization schemes: letter -> function returning one scale factor per weight.
# TODO: 'u' (pivoted unique) and 'b' (byte size) are not implemented.
NORM = {
    'n': _norm_none,
    'c': _norm_cosine,
}

def compute_wtq(terms, query_term_freq, notation='ltn'):
    """Weight each query term under a three-letter SMART query notation.

    Args:
        terms: query terms, in the order the weights should be returned.
        query_term_freq: mapping term -> frequency of the term in the query.
        notation: tf letter, df letter, normalization letter (keys of TF/DF/NORM).

    Returns:
        List of weights aligned with ``terms``: tf-weight * df-weight * norm-factor,
        where df is the length of the term's posting list.
    """
    tf_scheme, df_scheme, norm_scheme = notation[0], notation[1], notation[2]
    tf_weights = TF[tf_scheme]([query_term_freq[t] for t in terms])
    df_weights = DF[df_scheme]([len(get_postings(t)) for t in terms])
    combined = [a * b for a, b in zip(tf_weights, df_weights)]
    factors = NORM[norm_scheme](combined)
    return [w * scale for w, scale in zip(combined, factors)]

def compute_wtd(tf, df, notation='lnc'):
    """Un-normalized weight of one term in one document.

    Only the tf and df letters of ``notation`` are used here; the third
    (normalization) letter is applied by the caller (top_k).

    NOTE(review): the TF scheme sees a one-element vector [tf], so 'a'
    (augmented) always yields 1.0 for tf > 0, the same as 'b'. A faithful 'a'
    needs the document's maximum tf.
    """
    tf_weight, = TF[notation[0]]([tf])
    df_weight, = DF[notation[1]]([df])
    return tf_weight * df_weight

def top_k(query, k, notation='lnc.ltn'):
    """Rank tweets for ``query`` with a SMART 'ddd.qqq' weighting scheme.

    Args:
        query: raw query text; it goes through the same ``preprocess`` as documents.
        k: maximum number of results; values <= 0 return nothing.
        notation: document notation, '.', query notation (e.g. 'lnc.ltn').
            Document normalization supports only 'n' and 'c'.

    Returns:
        (results, ids): text and tweetId of at most ``k`` tweets with a positive
        score, best first.
    """
    parts = notation.split('.')
    doc_notation, query_notation = parts[0], parts[1]
    term_freq = nltk.FreqDist(preprocess(query))
    query_terms = list(term_freq.keys())
    wtq = compute_wtq(query_terms, term_freq, query_notation)

    score = [0] * N
    for term, w_q in zip(query_terms, wtq):
        postings = get_postings(term)
        df = len(postings)  # loop-invariant for this term
        for doc, tf in postings:
            # wtd is not normalized here; document normalization is applied below.
            score[doc] += w_q * compute_wtd(tf, df, doc_notation)

    if doc_notation[2] == 'c':
        # NOTE(review): `length` is the L2 norm of the raw tf vector, not of the
        # weighted document vector, so this is only exact cosine for 'nnc'.
        # Tweets with no index terms have length 0 (and score 0); leave them as-is
        # instead of raising ZeroDivisionError.
        score = [score[i] / length[i] if length[i] else score[i] for i in range(N)]

    # Only return tweets that actually match the query; clamp k into [0, #matches].
    k = max(0, min(k, sum(1 for s in score if s > 0)))
    order = np.argsort(np.array(score))[::-1][:k]
    results = [tweets[i] for i in order]
    ids = [idx2id[i] for i in order]
    return results, ids

In [8]:
def display_query_result(query, k, notation='lnc.ltn'):
    """Print a header and the top-k tweets for ``query``, with query words in red."""
    results, ids = top_k(query, k, notation)
    header = "查询%s, 返回前%s条结果(%s)：\n" % (toRed(query), toRed(str(len(results))), toRed(notation))
    print(header)
    for tweet_id, text in zip(ids, results):
        print("tweet id: {}".format(tweet_id))
        print_with_emphasize(text, query)

## 评测指标（MAP、MRR、NDCG）

In [9]:
def generate_tweetid_gain(file_name):
    """Read a relevance-judgment file into ``{query_id: {doc_id: gain}}``.

    Each non-blank line holds whitespace-separated fields: query_id, an unused
    field, doc_id, integer gain. Only positive gains are stored, so the sorted
    values of a query's dict form the ideal (IDCG) ranking. A query whose
    judgments are all non-positive still gets an (empty) entry.
    """
    qrels_dict = {}
    with open(file_name, 'r', errors='ignore') as f:
        for line in f:
            ele = line.split()  # any run of whitespace separates fields
            if not ele:
                continue  # tolerate blank lines
            gains = qrels_dict.setdefault(ele[0], {})
            gain = int(ele[3])
            if gain > 0:
                gains[ele[2]] = gain
    return qrels_dict


def read_tweetid_test(file_name):
    """Read a run file into ``{query_id: [doc_id, ...]}`` preserving rank order.

    Input format, one retrieved document per line (whitespace separated):
        query_id doc_id
    Blank lines are ignored.
    """
    test_dict = {}
    with open(file_name, 'r', errors='ignore') as f:
        for line in f:
            ele = line.split()  # any run of whitespace separates fields
            if not ele:
                continue  # tolerate blank lines
            test_dict.setdefault(ele[0], []).append(ele[1])
    return test_dict


def MAP_eval(qrels_dict, test_dict, k=100):
    """Mean Average Precision over every query in ``qrels_dict``.

    For a query, AP is the sum of precision@i over the ranks i (within the first
    ``k`` retrieved docs) that hold a relevant doc, divided by the total number
    of relevant docs. A query with no retrieved docs (absent from ``test_dict``)
    scores 0 instead of raising KeyError or aborting the whole evaluation.

    Args:
        qrels_dict: {query_id: {doc_id: gain}}; keys of the inner dict are relevant.
        test_dict: {query_id: [doc_id, ...]} in rank order.
        k: rank cutoff.

    Returns:
        Mean AP over all queries (0.0 when ``qrels_dict`` is empty).
    """
    AP_result = []
    for query, gains in qrels_dict.items():
        relevant = set(gains)
        retrieved = test_dict.get(query, [])[:max(k, 0)]
        if not retrieved:
            print('query ', query, ' not found test list')
            AP_result.append(0)
            continue
        hits = 0
        precisions = []
        for rank, doc_id in enumerate(retrieved, start=1):
            if doc_id in relevant:
                hits += 1
                precisions.append(hits / rank)
        AP_result.append(np.sum(precisions) / len(relevant) if precisions else 0)
    return np.mean(AP_result) if AP_result else 0.0


def NDCG_eval(qrels_dict, test_dict, k=100):
    """Mean NDCG over every query in ``qrels_dict``.

    Gain of a doc is ``2**rel - 1`` and the discount at rank r is ``log2(r + 1)``.
    The ideal ranking is the query's gains sorted descending.

    NOTE(review): the cutoff is min(k, #retrieved, #relevant), as in the original
    script; the textbook NDCG@k would not truncate DCG at #relevant.

    Queries with no retrieved docs or no positive judgments score 0 instead of
    raising KeyError or aborting the whole evaluation.

    Returns:
        Mean NDCG over all queries (0.0 when ``qrels_dict`` is empty).
    """
    NDCG_result = []
    for query, gains in qrels_dict.items():
        test_result = test_dict.get(query, [])
        ideal_gains = sorted(gains.values(), reverse=True)
        length_use = min(k, len(test_result), len(ideal_gains))
        if length_use <= 0:
            if not test_result:
                print('query ', query, ' not found test list')
            NDCG_result.append(0)
            continue
        DCG = 0.0
        IDCG = 0.0
        for rank, doc_id in enumerate(test_result[:length_use], start=1):
            discount = math.log(rank + 1, 2)
            DCG += (pow(2, gains.get(doc_id, 0)) - 1) / discount
            IDCG += (pow(2, ideal_gains[rank - 1]) - 1) / discount
        NDCG_result.append(DCG / IDCG if IDCG > 0 else 0)
    return np.mean(NDCG_result) if NDCG_result else 0.0


def MRR_eval(qrels_dict, test_dict, k=100):
    """Mean Reciprocal Rank over every query in ``qrels_dict``.

    RR is 1 / rank of the first relevant doc within the first ``k`` retrieved,
    or 0 if none is found. A query with no retrieved docs scores 0 instead of
    raising KeyError or aborting the whole evaluation.

    Returns:
        Mean RR over all queries (0.0 when ``qrels_dict`` is empty).
    """
    RR_result = []
    for query, gains in qrels_dict.items():
        relevant = set(gains)  # O(1) membership
        retrieved = test_dict.get(query, [])[:max(k, 0)]
        if not retrieved:
            print('query ', query, ' not found test list')
            RR_result.append(0)
            continue
        first_hit = next(
            (rank for rank, doc_id in enumerate(retrieved, start=1) if doc_id in relevant), 0)
        RR_result.append(1 / first_hit if first_hit else 0)
    return np.mean(RR_result) if RR_result else 0.0

class Evaluation:
    """Score run files with MAP, MRR and NDCG against the judgments in data/qrels.txt."""

    def __init__(self, k=100):
        # Rank cutoff handed to every metric.
        self.k = k
        # Relevance judgments, loaded once as {query_id: {doc_id: gain, ...}, ...}.
        self.file_qrels_path = 'data/qrels.txt'
        self.qrels_dict = generate_tweetid_gain(self.file_qrels_path)

    def __call__(self, file_test_path):
        """Evaluate one run file (lines of 'query_id doc_id') and print each metric.

        The printed label is ``file_test_path[-11:-4]``: the 7 characters before a
        4-character extension ('result/my_result_lnc.ltn.txt' -> 'lnc.ltn').

        Returns:
            (MAP, MRR, NDCG)
        """
        test_dict = read_tweetid_test(file_test_path)
        print(file_test_path[-11:-4] + ': ')
        scores = []
        for label, metric_fn in (('MAP', MAP_eval), ('MRR', MRR_eval), ('NDCG', NDCG_eval)):
            value = metric_fn(self.qrels_dict, test_dict, self.k)
            print(label + ' = ' + str(value))
            scores.append(value)
        return tuple(scores)

## 加载问题关键词

In [10]:
queries = {}
with open('data/topics.desc.MB171-225.txt', 'r') as f:
    lines = f.readlines()
    for idx, line in enumerate(lines):
        if not line.startswith('<num>'):
            continue
        # Query id: last three characters of the third field on the <num> line
        # (presumably a topic number such as 'MB171' -> '171'; confirm against the file).
        num = line.split()[2][-3:]
        # Query text: the next line minus 8 leading and 10 trailing characters
        # (presumably '<title> ' and ' </title>\n'; TODO confirm).
        query = lines[idx + 1][8:-10]
        queries[num] = query

## 组合出所有已经实现的SMART notation

In [11]:
# Every implemented three-letter SMART triple (tf letter + df letter + norm letter),
# in TF x DF x NORM key order.
half_notations = [tf + df + norm for tf in TF for df in DF for norm in NORM]
# Every 'document.query' pairing of those triples.
notations = [doc + '.' + qry for doc in half_notations for qry in half_notations]

## 输出所有SMART notation的结果

In [12]:
k = 100
# One run file per notation goes into result/; create it on a fresh checkout so
# Restart & Run All does not fail with FileNotFoundError.
os.makedirs('result', exist_ok=True)
for notation in notations:
    with open('result/my_result_' + notation + '.txt', 'w') as f:
        for query_id in queries:
            results, ids = top_k(queries[query_id], k, notation)
            # Run-file format read by read_tweetid_test: "query_id doc_id" per line.
            for doc_id in ids:
                f.write('{} {}\n'.format(query_id, doc_id))

## 评测所有SMART notation的结果，并分别保存三项指标最高的SMART notation

In [13]:
metric = Evaluation(k=100)  # evaluates each run at rank cutoff 100

In [14]:
# For each metric, track the best score so far and the notation that produced it.
# The strict '>' keeps the earliest notation on ties.
MAP_best = MRR_best = NDCG_best = 0
MAP_best_notation = MRR_best_notation = NDCG_best_notation = ''
for notation in notations:
    f = 'result/my_result_' + notation + '.txt'
    MAP, MRR, NDCG = metric(f)
    if MAP > MAP_best:
        MAP_best, MAP_best_notation = MAP, notation
    if MRR > MRR_best:
        MRR_best, MRR_best_notation = MRR, notation
    if NDCG > NDCG_best:
        NDCG_best, NDCG_best_notation = NDCG, notation

nnn.nnn: 
MAP = 0.462935118027406
MRR = 0.9312801484230055
NDCG = 0.6591187301857276
nnn.nnc: 
MAP = 0.462935118027406
MRR = 0.9312801484230055
NDCG = 0.6591187301857276
nnn.ntn: 
MAP = 0.5193173906812537
MRR = 0.956060606060606
NDCG = 0.7035308065001947
nnn.ntc: 
MAP = 0.5193173906812537
MRR = 0.956060606060606
NDCG = 0.7035308065001947
nnn.npn: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG = 0.7043657469797785
nnn.npc: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG = 0.7043657469797785
nnn.lnn: 
MAP = 0.462935118027406
MRR = 0.9312801484230055
NDCG = 0.6591187301857276
nnn.lnc: 
MAP = 0.462935118027406
MRR = 0.9312801484230055
NDCG = 0.6591187301857276
nnn.ltn: 
MAP = 0.5193173906812537
MRR = 0.956060606060606
NDCG = 0.7035308065001947
nnn.ltc: 
MAP = 0.5193173906812537
MRR = 0.956060606060606
NDCG = 0.7035308065001947
nnn.lpn: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG = 0.7043657469797785
nnn.lpc: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG =

NDCG = 0.7027292854818217
npn.atc: 
MAP = 0.5213132392375933
MRR = 0.92987012987013
NDCG = 0.7027292854818217
npn.apn: 
MAP = 0.5212432319844882
MRR = 0.92987012987013
NDCG = 0.7026174066539839
npn.apc: 
MAP = 0.5212432319844882
MRR = 0.92987012987013
NDCG = 0.7026174066539839
npn.bnn: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG = 0.7043657469797785
npn.bnc: 
MAP = 0.5202911054581345
MRR = 0.956060606060606
NDCG = 0.7043657469797785
npn.btn: 
MAP = 0.5213132392375933
MRR = 0.92987012987013
NDCG = 0.7027292854818217
npn.btc: 
MAP = 0.5213132392375933
MRR = 0.92987012987013
NDCG = 0.7027292854818217
npn.bpn: 
MAP = 0.5212432319844882
MRR = 0.92987012987013
NDCG = 0.7026174066539839
npn.bpc: 
MAP = 0.5212432319844882
MRR = 0.92987012987013
NDCG = 0.7026174066539839
npc.nnn: 
MAP = 0.5401921145492179
MRR = 0.9878787878787879
NDCG = 0.7178674548707963
npc.nnc: 
MAP = 0.5401921145492179
MRR = 0.9878787878787879
NDCG = 0.7178674548707963
npc.ntn: 
MAP = 0.5424076574915945
MRR = 0.9

ltc.ntn: 
MAP = 0.5417580318835811
MRR = 0.9662337662337661
NDCG = 0.7141108362401754
ltc.ntc: 
MAP = 0.5417580318835811
MRR = 0.9662337662337661
NDCG = 0.7141108362401754
ltc.npn: 
MAP = 0.541831824467046
MRR = 0.9662337662337661
NDCG = 0.714045905742791
ltc.npc: 
MAP = 0.541831824467046
MRR = 0.9662337662337661
NDCG = 0.714045905742791
ltc.lnn: 
MAP = 0.5421879440165703
MRR = 0.990909090909091
NDCG = 0.7178485955905279
ltc.lnc: 
MAP = 0.5421879440165703
MRR = 0.990909090909091
NDCG = 0.7178485955905279
ltc.ltn: 
MAP = 0.5417580318835811
MRR = 0.9662337662337661
NDCG = 0.7141108362401754
ltc.ltc: 
MAP = 0.5417580318835811
MRR = 0.9662337662337661
NDCG = 0.7141108362401754
ltc.lpn: 
MAP = 0.541831824467046
MRR = 0.9662337662337661
NDCG = 0.714045905742791
ltc.lpc: 
MAP = 0.541831824467046
MRR = 0.9662337662337661
NDCG = 0.714045905742791
ltc.ann: 
MAP = 0.5421879440165703
MRR = 0.990909090909091
NDCG = 0.7178485955905279
ltc.anc: 
MAP = 0.5421879440165703
MRR = 0.990909090909091
NDCG =

NDCG = 0.7056515798227759
anc.atc: 
MAP = 0.5319334258195053
MRR = 0.9696969696969697
NDCG = 0.7056515798227759
anc.apn: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
anc.apc: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
anc.bnn: 
MAP = 0.5101195353532912
MRR = 0.9606060606060607
NDCG = 0.6866171409596172
anc.bnc: 
MAP = 0.5097462775227387
MRR = 0.9606060606060607
NDCG = 0.6863937845067437
anc.btn: 
MAP = 0.5319334258195053
MRR = 0.9696969696969697
NDCG = 0.7056515798227759
anc.btc: 
MAP = 0.5319334258195053
MRR = 0.9696969696969697
NDCG = 0.7056515798227759
anc.bpn: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
anc.bpc: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
atn.nnn: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
atn.nnc: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
atn.ntn: 
MAP = 0.5314732472

MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.npn: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.npc: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.lnn: 
MAP = 0.4959948379192071
MRR = 0.9627272727272728
NDCG = 0.6950880411154406
bnn.lnc: 
MAP = 0.4959948379192071
MRR = 0.9627272727272728
NDCG = 0.6950880411154406
bnn.ltn: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.ltc: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.lpn: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.lpc: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn.ann: 
MAP = 0.4959948379192071
MRR = 0.9627272727272728
NDCG = 0.6950880411154406
bnn.anc: 
MAP = 0.4959948379192071
MRR = 0.9627272727272728
NDCG = 0.6950880411154406
bnn.atn: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bnn

NDCG = 0.7084684247062444
bpn.atc: 
MAP = 0.5314732472070446
MRR = 0.9698701298701299
NDCG = 0.7084684247062444
bpn.apn: 
MAP = 0.53134019244983
MRR = 0.9698701298701299
NDCG = 0.7077008661245426
bpn.apc: 
MAP = 0.53134019244983
MRR = 0.9698701298701299
NDCG = 0.7077008661245426
bpn.bnn: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bpn.bnc: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
bpn.btn: 
MAP = 0.5314732472070446
MRR = 0.9698701298701299
NDCG = 0.7084684247062444
bpn.btc: 
MAP = 0.5314732472070446
MRR = 0.9698701298701299
NDCG = 0.7084684247062444
bpn.bpn: 
MAP = 0.53134019244983
MRR = 0.9698701298701299
NDCG = 0.7077008661245426
bpn.bpc: 
MAP = 0.53134019244983
MRR = 0.9698701298701299
NDCG = 0.7077008661245426
bpc.nnn: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
bpc.nnc: 
MAP = 0.5320692017890102
MRR = 0.9696969696969697
NDCG = 0.7057274284283247
bpc.ntn: 
MAP = 0.529056839925935
MR

In [16]:
# Re-print the full metrics for the best notation under MAP, MRR and NDCG, in that order.
for best_notation in (MAP_best_notation, MRR_best_notation, NDCG_best_notation):
    metric('result/my_result_{}.txt'.format(best_notation))

ann.npn: 
MAP = 0.5450475431272652
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
lnc.ntn: 
MAP = 0.5421879440165703
MRR = 0.990909090909091
NDCG = 0.7178485955905279
ann.ntn: 
MAP = 0.5444353423293059
MRR = 0.9698701298701299
NDCG = 0.7246718992202057
