In [1]:
import pandas as pd
import numpy as np
import nltk
from nltk.tokenize import RegexpTokenizer
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from math import log10, sqrt,log

In [98]:
#import dataset
combined_results_T5 = pd.read_csv('WikiHow_sample_all_withsummary.csv')
combined_results_BERT = pd.read_csv('combined_results_BERT_bm25.csv')

In [4]:
def preprocess(documents):
    tokenizer = RegexpTokenizer(r'\w+')
    tokens = tokenizer.tokenize(documents)
    tokens =  [token.lower() for token in tokens if token.isalpha()]
    tokens = remove_stopwords(tokens)
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(token, pos='v') for token in tokens]
    return tokens

In [5]:
def remove_stopwords(tokens):
    new_tokens = []
    stop_words = set(stopwords.words("english"))
    for token in tokens:
        if token not in stop_words:
            new_tokens.append(token)
    return new_tokens

In [8]:
def get_inverted_index(data):
    n = len(data[data['summary']!='empty'])
    df = data.loc[:n,:]
    inverted_index = {}
    for i in range(len(df)):
        if isinstance(df.loc[i,'summary'], str) == False:
            continue
        tokens = preprocess(df.loc[i,'summary'])
        tokens_dist = nltk.FreqDist(tokens)
        for voc in tokens_dist.keys():
            if voc not in inverted_index.keys():
                inverted_index[voc] = [1, tokens_dist[voc]]
            else:
                inverted_index[voc][0] += 1
                inverted_index[voc][1] += tokens_dist[voc]
    return inverted_index

In [None]:
inverted_index_T5 = get_inverted_index(combined_results_T5)
inverted_index_BERT = get_inverted_index(combined_results_BERT)

In [14]:
def length(data):
    total_len = 0
    for i in range(len(data)):
        if isinstance(data.loc[i,'summary'], str) == False:
            continue
        passage_len = len(preprocess(data.loc[i,'summary']))
        total_len += passage_len
    total_pa = len(data)
    return total_len/total_pa, total_pa

avdl_t5, N_t5 = length(combined_results_T5)
avdl_bert, N_bert = length(combined_results_BERT)

In [16]:
def BM25(data, inverted_index,avdl,N, k1 = 1.2, k2 = 100,b = 0.75):
    scores = np.zeros((len(data),3))
    for i in tqdm(range(len(data))):
        if isinstance(data.loc[i,'summary'], str) == False:
            continue
        tokens_p = preprocess(data.loc[i,'summary'])
        tokens_q = preprocess(data.loc[i,'title'])
        f_p = nltk.FreqDist(tokens_p)
        f_q = nltk.FreqDist(tokens_q)
        dl = len(tokens_p)
        K = k1*((1-b)+b*(dl/avdl))
        bm25 = 0
        for token in f_q.keys():
            if token in inverted_index.keys():
                term1 =  log((N-inverted_index[token][0]+0.5)/(inverted_index[token][0]+0.5))
                term2 = (k1+1)*f_p[token]/(K+f_p[token])
                term3 = (k2+1)*f_q[token]/(k2+f_q[token])
                bm25 += term1 *term2 *term3
        data.loc[i,'bm25'] = bm25
    return data

In [22]:
combined_results_T5 = BM25(combined_results_T5, inverted_index,avdl_t5,N_t5, k1 = 1.2, k2 = 100,b = 0.75)

100%|███████████████████████████████████| 49642/49642 [00:36<00:00, 1374.88it/s]


In [17]:
combined_results_BERT = BM25(combined_results_BERT, inverted_index_BERT,avdl_bert,N_bert, k1 = 1.2, k2 = 100,b = 0.75)

100%|████████████████████████████████████| 44400/44400 [00:52<00:00, 843.13it/s]


In [36]:
#save BM25 scores
combined_results_T5.to_csv('WikiHow_sample_all_withsummary.csv')
combined_results_BERT.to_csv('combined_results_BERT.csv')

In [99]:
#merge BERT and T5 results together
bm25_compare = pd.merge(WikiHow_sample_all, combined_results_BERT, on='title')
bm25_compare = bm25_compare[bm25_compare['summary_length'] != 0]
bm25_compare = bm25_compare.reset_index()

In [81]:
#compare T5 and BERT performance
count_t5 = 0
count_bert = 0
equal = 0
test_count = 0 

for i,  (x,y) in enumerate(zip(bm25_compare['bm25_x'],bm25_compare['bm25_y'])):
    test_count += 1
    if x>y:
        count_t5 +=1
    elif x<y:
        count_bert +=1
    else:
        equal +=1  

In [84]:
equal/len(bm25_compare)

0.24493951371421727

In [86]:
count_t5/len(bm25_compare)

0.31596598395017367

In [87]:
count_bert/len(bm25_compare)

0.43909450233560904