# Notebook for training information retrieval models

### Import packages

In [37]:
# Import packages
from pyspark.sql import SparkSession
from pyspark import SQLContext
from pyspark.sql.functions import size, explode, col
from pyspark.ml import Pipeline

from nltk.corpus import stopwords
from gensim.parsing.preprocessing import STOPWORDS as gensim_words
import spacy
sp = spacy.load('en_core_web_sm')

from sparknlp.base import Finisher, DocumentAssembler
from sparknlp.annotator import Tokenizer, Normalizer, LemmatizerModel, StopWordsCleaner

import matplotlib.pyplot as plt

In [38]:
nltk_stopwords = set(stopwords.words('english')) \
                    .union(set(stopwords.words('german'))) \
                    .union(set(stopwords.words('french')))
gensim_stopwords = set(gensim_words)
spacy_stopwords = sp.Defaults.stop_words
# https://countwordsfree.com/stopwords
cwf_stopwords = set(line.strip() for line in open('stop_words.txt'))

all_stopwords = list( nltk_stopwords \
                        .union(gensim_stopwords) \
                        .union(spacy_stopwords) \
                        .union(cwf_stopwords) )

### Create Spark Context and SQL Context

In [39]:
# Start spark session configured for spark nlp
spark = SparkSession.builder \
        .master('local[*]') \
        .appName('SDDM') \
        .config('spark.driver.memory', '64g') \
        .config('spark.executor.memory', '32g') \
        .config('spark.executor.cores', '8') \
        .config('spark.jars.packages', 'com.johnsnowlabs.nlp:spark-nlp_2.11:2.5.0') \
        .getOrCreate()
print("Created a SparkSession")
sc = spark.sparkContext
print("Created a SparkContext")
sqlContext = SQLContext(sc)
print("Created a SQLContext")

# .config('spark.memory.fraction', '0.8') \

Created a SparkSession
Created a SparkContext
Created a SQLContext


### Load the data into a SQLContext Dataframe

In [40]:
df = sqlContext.read.format('csv').options(header='true', maxColumns=2000000) \
      .load('/data/s1847503/SDDM/newdata/data.csv')
df.show()

+--------------------+--------------------+--------------------+--------------------+--------------------+
|            paper_id|               title|        list_authors|           full_text|            sections|
+--------------------+--------------------+--------------------+--------------------+--------------------+
|1329bb2f949e74925...|Generation of pre...|['Xue Wu Zhang', ...|"The infection of...| 30 drugs were se...|
|dc079a2e9cf98fad0...|Zoonotic disease ...|['Charlotte Robin...|"Veterinary profe...| based on the par...|
|75af9aa0e63889abd...|Current and Novel...|['Erasmus Kotey',...|"Influenza viruse...| although LAIVs a...|
|1755c4785f87bca19...|MERS: Progress on...|['*', 'Ryan Aguan...|Since its identif...|['Since its ident...|
|cc829c0f2ab2e110b...|Hepatologie Akute...|['Karoline Rutter...|"Das akute Leberv...| nach Ausschluss ...|
|ece3d68d9b996c917...|Novel approach to...|['Ivan Timokhin',...|"Introduction | T...|      diameter 12 mm|
|9cd0f74020b0db181...|On the electrif

### Initialize Annotators

In [41]:
# Pipeline for text
document_assembler = DocumentAssembler() \
                        .setInputCol('full_text') \
                        .setOutputCol('document')

# Tokenizer divides the text into tokens
tokenizer = Tokenizer() \
                .setInputCols(['document']) \
                .setOutputCol('tokens')

# Finisher converts tokens to human-readable output (we need the tokens for determining the text lengths)
finisher_tokens = Finisher() \
                        .setInputCols(['tokens']) \
                        .setCleanAnnotations(False)

# Normalizer removes punctuation, numbers etc.
normalizer = Normalizer() \
                .setInputCols(['tokens']) \
                .setOutputCol('normalized') \
                .setLowercase(True)

# Lemmatizer changes each word to its lemma
lemmatizer = LemmatizerModel.pretrained() \
                .setInputCols(['normalized']) \
                .setOutputCol('lemma')

# StopWordsCleaner removes stop words    
stopwords_cleaner = StopWordsCleaner() \
                        .setInputCols(['lemma']) \
                        .setOutputCol('clean_lemma') \
                        .setCaseSensitive(False).setStopWords(all_stopwords)

# Finisher converts clean tokens to human-readable output
finisher = Finisher() \
            .setInputCols(['clean_lemma']) \
            .setCleanAnnotations(False)

lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]


### Create Pipeline

In [42]:
# Pipeline for fully preprocessing the text
pipeline = Pipeline() \
            .setStages([
                document_assembler,
                tokenizer,
                normalizer,
                lemmatizer,
                stopwords_cleaner,
                finisher_tokens,
                finisher
             ])

### Preprocess questions

In [43]:
questions = sqlContext.read.format('csv').options(header='true').load('/data/s1847503/SDDM/newdata/questions.csv')
questions_clean = pipeline.fit(questions).transform(questions)
questions_clean = questions_clean.select(col('finished_clean_lemma').alias('clean_question'))

questions = [q.full_text for q in questions.collect()]
questions_clean = [q.clean_question for q in questions_clean.collect()]
print(questions)
print()
print(questions_clean)

['Effectiveness of inter_inner travel restriction ', 'Methods to understand and regulate the spread in communities', 'Evidence that domesticated_farm animals can be infected and maintain transmissibility of the disease', 'Effectiveness of school distancing', 'Effectiveness of workplace distancing to prevent secondary transmission', 'Effectiveness of community contact reduction', 'Effectiveness of case isolation_isolation of exposed individuals to prevent secondary transmission', 'Effectiveness of personal protective equipment (PPE)', 'Effectiveness of a multifactorial strategy to prevent secondary transmission', 'Seasonality of transmission ', 'How does temperature and humidity affect the transmission of 2019-nCoV', 'What is the likelihood of significant changes in transmissibility in changing seasons']

[['effectiveness', 'interinner', 'travel', 'restriction'], ['method', 'understand', 'regulate', 'spread', 'community'], ['evidence', 'domesticatedfarm', 'animal', 'infect', 'maintain',

### Preprocess text

In [44]:
# Peprocess the data
df = pipeline.fit(df).transform(df)
df = df.select('*', size('finished_tokens').alias('text_length'))

# Keep only papers with a text length of greater than 10
print("Before removing empty papers: {} rows.".format(df.count()))
df = df.dropna(subset='full_text')
# df = df.dropduplicates(subset='title')
# print("Removed duplicates")
# df = df.filter(df['text_length'] > 10)
print("After removing empty papers: {} rows.".format(df.count()))
print()

df = df.select(
                'paper_id',
                'title',
                'full_text',
                'text_length',
                col('finished_clean_lemma').alias('preprocessed')
            )

df.show()

Before removing empty papers: 1329677 rows.
After removing empty papers: 406784 rows.

+--------------------+--------------------+--------------------+-----------+--------------------+
|            paper_id|               title|           full_text|text_length|        preprocessed|
+--------------------+--------------------+--------------------+-----------+--------------------+
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|
|dc079a2e9cf98fad0...|Zoonotic disease ...|"Veterinary profe...|       1756|[veterinary, prof...|
|75af9aa0e63889abd...|Current and Novel...|"Influenza viruse...|        919|[influenza, virus...|
|1755c4785f87bca19...|MERS: Progress on...|Since its identif...|       3942|[identification, ...|
|cc829c0f2ab2e110b...|Hepatologie Akute...|"Das akute Leberv...|       1832|[akute, lebervers...|
|ece3d68d9b996c917...|Novel approach to...|"Introduction | T...|        448|[introduction, qu...|
|9cd0f74020b0db181...|On the el

### TF-IDF

In [45]:
# Explode text
exploded = df.withColumn('token', explode(col('preprocessed')))
exploded.show()

+--------------------+--------------------+--------------------+-----------+--------------------+------------+
|            paper_id|               title|           full_text|text_length|        preprocessed|       token|
+--------------------+--------------------+--------------------+-----------+--------------------+------------+
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|   infection|
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|       newly|
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|      emerge|
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|      severe|
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...|       acute|
|1329bb2f949e74925...|Generation of pre...|"The infection of...|        723|[infection, newly...| respiratory|
|

In [None]:
# Get term frequencies
tf = exploded \
        .groupBy('paper_id', 'token') \
        .sum() \
        .alias('tf')
tf.show()

In [None]:
# Get inverse document frequencies
idf = exploded \
        .groupBy('token') \
        .agg(col('paper_id')) \
        .countDistinct(col('paper_id')) \
        .alias('df')
idf.show()

### Similarity (copied, taking inspiration from it)

In [None]:
# Similarity ----------------------------------------------------------------------
def calc_simlarity_score(question_list, text_list,threshold=None, top=None):
    if (threshold==None)  and  (top==None):
        raise ValueError("Parameter `threshold` and `top` cannot both be None")
    dic = {}
    tfidf = TfidfVectorizer()
    corpus_tfidf_matrix = tfidf.fit_transform(text_list)
    ques_tfidf_matrix = tfidf.transform(question_list)
    sim_matrix = cosine_similarity(corpus_tfidf_matrix, ques_tfidf_matrix)
    for ques_idx in range(sim_matrix.shape[1]):
        dic[ques_idx] = []
        if threshold != None:
            if (threshold>1) or (threshold <0):
                raise ValueError("Please enter a value from 0 to 1 for parameter `threshold`")
            for paper_idx in range(sim_matrix.shape[0]):
                score = sim_matrix[paper_idx, ques_idx]
                if score >= threshold:
                    dic[ques_idx].append((paper_idx, score))
            dic[ques_idx]=sorted(dic[ques_idx], key=lambda i: i[1], reverse=True)
        elif top != None:
            top_paper_idx_list = sorted(range(len(sim_matrix[:, ques_idx])), key=lambda i: sim_matrix[:,0][i], reverse=True)[:top]
            dic[ques_idx] = [(top_idx, sim_matrix[top_idx, ques_idx]) for top_idx in top_paper_idx_list]
    return dic, sim_matrix

# Retrieve relevant paper----------------------------------------------------------------------
def retrieve_paper(df, dic):
    df_dic={}
    for ques_idx in dic:
        new_df = df.iloc[[item[0] for item in dic[ques_idx]], :]
        new_df['score'] = [item[1] for item in dic[ques_idx]]
        new_df['question'] = questions[ques_idx]
        df_dic[ques_idx]=new_df.copy()
    return df_dic

### Close Spark Context when done

In [None]:
sc.stop()