# Notebook for the Word2Vec model

### Import packages

In [1]:
# Import packages
import numpy as np

from pyspark.sql import SparkSession
from pyspark import SQLContext
from pyspark.sql.functions import udf, size, col
from pyspark.ml import Pipeline
from pyspark.ml.feature import Word2Vec

from nltk.corpus import stopwords
from gensim.parsing.preprocessing import STOPWORDS as gensim_words
import spacy
sp = spacy.load('en_core_web_sm')

import os

from sparknlp.base import Finisher, DocumentAssembler
from sparknlp.annotator import Tokenizer, Normalizer, LemmatizerModel, StopWordsCleaner

import time

In [2]:
# Import stop words
nltk_stopwords = set(stopwords.words('english')) \
                    .union(set(stopwords.words('german'))) \
                    .union(set(stopwords.words('french')))
gensim_stopwords = set(gensim_words)
spacy_stopwords = sp.Defaults.stop_words
# https://countwordsfree.com/stopwords
cwf_stopwords = set(line.strip() for line in open('stop_words.txt'))

all_stopwords = list( nltk_stopwords \
                        .union(gensim_stopwords) \
                        .union(spacy_stopwords) \
                        .union(cwf_stopwords) )

### Create Spark Context and SQL Context

In [3]:
# Get the right paths on local machine
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-8-openjdk-amd64'
os.environ["PYSPARK_PYTHON"] = '/usr/bin/python3.7'
os.environ["PYSPARK_DRIVER_PYTHON"] = '/usr/bin/python3.7'

In [4]:
# Start spark session configured for SparkNLP
spark = SparkSession.builder \
        .master('local[*]') \
        .appName('SDDM') \
        .config('spark.driver.memory', '8g') \
        .config('spark.executor.memory', '8g') \
        .config('spark.memory.fraction', '0.8') \
        .config('spark.executor.cores', '8') \
        .config('spark.local.dir', '/home/rikz/Documents/Master/Semester2/SDDM/data/tmp') \
        .config('spark.jars.packages', 'com.johnsnowlabs.nlp:spark-nlp_2.11:2.5.0') \
        .getOrCreate()
print("Created a SparkSession")
sc = spark.sparkContext
print("Created a SparkContext")
sqlContext = SQLContext(sc)
print("Created a SQLContext")

Created a SparkSession
Created a SparkContext
Created a SQLContext


### Load the data into a SQLContext Dataframe

In [5]:
# Load data
df = sqlContext.read.format('csv').options(header='true', maxColumns=2000000) \
        .load('/home/rikz/Documents/Master/Semester2/SDDM/data/data.csv')
#       .load('/data/s1847503/SDDM/newdata/data.csv')

df.show()

+---+--------------------+--------------------+--------------------+--------------------+
|_c0|            paper_id|               title|        list_authors|           full_text|
+---+--------------------+--------------------+--------------------+--------------------+
|  0|           question0|                   -|                   -|How does temperat...|
|  1|           question1|                   -|                   -|Seasonality of tr...|
|  2|           question2|                   -|                   -|Effectiveness of ...|
|  3|           question3|                   -|                   -|Effectiveness of ...|
|  4|           question4|                   -|                   -|Effectiveness of ...|
|  5|           question5|                   -|                   -|Effectiveness of ...|
|  6|           question6|                   -|                   -|Effectiveness of ...|
|  7|           question7|                   -|                   -|Effectiveness of ...|
|  8|     

In [6]:
# Load metadata
df_metadata = sqlContext.read.format('csv').options(header='true') \
                .load('/home/rikz/Documents/Master/Semester2/SDDM/data/metadata.csv') \
                .select(col('sha').alias('paper_id'), 'publish_time', 'title', 'doi', 'journal')

df_metadata.show()

+--------------------+------------+--------------------+--------------------+--------------------+
|            paper_id|publish_time|               title|                 doi|             journal|
+--------------------+------------+--------------------+--------------------+--------------------+
|d1aafb70c066a2068...|  2001-07-04|Clinical features...|10.1186/1471-2334...|      BMC Infect Dis|
|6b0567729c2143a66...|  2000-08-15|Nitric oxide: a p...|        10.1186/rr14|          Respir Res|
|06ced00a5fc042159...|  2000-08-25|Surfactant protei...|        10.1186/rr19|          Respir Res|
|348055649b6b8cf2b...|  2001-02-22|Role of endotheli...|        10.1186/rr44|          Respir Res|
|5f48792a5fa08bed9...|  2001-05-11|Gene expression i...|        10.1186/rr61|          Respir Res|
|b2897e1277f566411...|  2001-12-17|Sequence requirem...|10.1093/emboj/20....|    The EMBO Journal|
|3bb07ea10432f7738...|  2001-03-08|Debate: Transfusi...|       10.1186/cc987|           Crit Care|
|5806726a2

### Initialize Annotators

In [7]:
# Pipeline for text
document_assembler = DocumentAssembler() \
                        .setInputCol('full_text') \
                        .setOutputCol('document')

# Tokenizer divides the text into tokens
tokenizer = Tokenizer() \
                .setInputCols(['document']) \
                .setOutputCol('tokens')

# Finisher converts tokens to human-readable output (we need the tokens for determining the text lengths)
finisher_tokens = Finisher() \
                        .setInputCols(['tokens']) \
                        .setCleanAnnotations(False)

# Normalizer removes punctuation, numbers etc.
normalizer = Normalizer() \
                .setInputCols(['tokens']) \
                .setOutputCol('normalized') \
                .setLowercase(True)

# Lemmatizer changes each word to its lemma
lemmatizer = LemmatizerModel.pretrained() \
                .setInputCols(['normalized']) \
                .setOutputCol('lemma')

# StopWordsCleaner removes stop words    
stopwords_cleaner = StopWordsCleaner() \
                        .setInputCols(['lemma']) \
                        .setOutputCol('clean_lemma') \
                        .setCaseSensitive(False).setStopWords(all_stopwords)

# Finisher converts clean tokens to human-readable output
finisher = Finisher() \
            .setInputCols(['clean_lemma']) \
            .setCleanAnnotations(False)

lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]


### Create Pipeline

In [8]:
# Pipeline for fully preprocessing the text
pipeline = Pipeline() \
            .setStages([
                document_assembler,
                tokenizer,
                normalizer,
                lemmatizer,
                stopwords_cleaner,
                finisher_tokens,
                finisher
             ])

### Preprocess questions

In [9]:
# questions = sqlContext.read.format('csv').options(header='true').load('/data/s1847503/SDDM/newdata/questions.csv')
questions = sqlContext.read.format('csv').options(header='true').load('/home/rikz/Documents/Master/Semester2/SDDM/data/questions.csv')
questions_clean = pipeline.fit(questions).transform(questions)
questions_clean = questions_clean.select('question_id', 'full_text', col('finished_clean_lemma').alias('preprocessed'))
questions_clean.show()

+-----------+--------------------+--------------------+
|question_id|           full_text|        preprocessed|
+-----------+--------------------+--------------------+
|          0|How does temperat...|[temperature, hum...|
|          1|Seasonality of tr...|[seasonality, tra...|
|          2|Effectiveness of ...|[effectiveness, i...|
|          3|Effectiveness of ...|[effectiveness, p...|
|          4|Effectiveness of ...|[effectiveness, s...|
|          5|Effectiveness of ...|[effectiveness, c...|
|          6|Effectiveness of ...|[effectiveness, m...|
|          7|Effectiveness of ...|[effectiveness, c...|
|          8|Significant chang...|[change, transmis...|
|          9|Effectiveness of ...|[effectiveness, w...|
+-----------+--------------------+--------------------+



In [10]:
# Select a question (from 0 to 9)
question_num = 2

questions_clean = questions_clean.filter(questions_clean.question_id == question_num)
q = questions_clean.first().full_text
q

'Effectiveness of inter inner travel restriction'

### Preprocess text

In [11]:
time_before = time.time()

In [12]:
# Peprocess the data
df = pipeline.fit(df).transform(df)
df = df.select('*', size('finished_tokens').alias('text_length'))

df = df.dropna(subset=['paper_id', 'full_text'])
print("Removed empty papers")
df = df.dropDuplicates(subset=['paper_id', 'full_text'])
print("Removed duplicates")
print()

df = df.select(
                col('_c0').alias('id'),
                'paper_id',
                'title',
                'full_text',
                'text_length',
                col('finished_clean_lemma').alias('preprocessed')
            )

df.show()

Removed empty papers
Removed duplicates

+---+--------------------+--------------------+--------------------+-----------+--------------------+
| id|            paper_id|               title|           full_text|text_length|        preprocessed|
+---+--------------------+--------------------+--------------------+-----------+--------------------+
|263|6cb2eced687ea9da4...|Medical recommend...|"In early 2020, t...|       1224|[early, face, glo...|
|468|9482e5881613ae262...|Oral vaccination ...|Helicobacter pylo...|       3705|[helicobacter, py...|
|491|c3ba4e042c5173d4a...|Thomas Grünewald,...|"Die Versorgung v...|        205|[versorgung, pati...|
|173|e639fc8b330785fb2...|            Vaccines|"Since vaccinatio...|        537|[vaccination, doc...|
|201|8645437ad8a6f8538...|Accessibility and...|The terminology o...|       8079|[terminology, acc...|
|224|121638b718d18f7bb...|JOURNAL OF MEDICA...|"Adenoviruses are...|        793|[adenoviruses, do...|
|200|4bc77a5504262d2f8...|A Fuzzy Model f

In [13]:
time_after = time.time()

In [14]:
print('Preprocessing time: {} sec'.format(time_after-time_before) )
# Small dataset: ~15 sec

Preprocessing time: 14.461075067520142 sec


### Word2Vec

In [15]:
# Create the model
word2Vec = Word2Vec(inputCol='preprocessed', outputCol='word_vector')

In [16]:
# Train the model for the papers and get the vectors for all papers
model = word2Vec.fit(df)
df = model.transform(df)

In [17]:
# Get the vector for the question
questions_clean = model.transform(questions_clean)
ques_vec = questions_clean.first().word_vector

In [18]:
# Calculate cosine similarity between a document vector and a question vector
def cossim(doc_vec): 
    global ques_vec
    sim = np.dot(doc_vec, ques_vec) / np.sqrt(np.dot(doc_vec, ques_vec)) / np.sqrt(np.dot(doc_vec, ques_vec)) 
    return float(sim)

cossim_udf = udf(cossim)

In [19]:
# Calculate similarity between all papers and the selected question
df_relevant = df.select('id', 'paper_id', cossim_udf('word_vector').alias('similarity'))

# Remove questions from the paper list
# Remove papers with a similarity of 'NaN'
# Sort on cosine similarity
# Take the top 10 relevant documents
df_relevant = df_relevant.filter(df_relevant.id > 9) \
                            .filter(df_relevant.similarity != 'NaN') \
                            .sort(col('similarity').desc()) \
                            .limit(10)

# Get the data of the 10 most relevant papers in order of relevance
print("Query: {}".format(q))
print()
print("Relevant Papers:")
print()
df_relevant.show()

Query: Effectiveness of inter inner travel restriction

Relevant Papers:

+---+--------------------+------------------+
| id|            paper_id|        similarity|
+---+--------------------+------------------+
|412|0fa6e26b053037098...|1.0000000000000002|
|382|a3f69a45be4bf642b...|1.0000000000000002|
| 87|872a34dc8f89a091d...|1.0000000000000002|
|185|72ab1b77e1ea96069...|1.0000000000000002|
|101|f2b8d478a21a63e3c...|1.0000000000000002|
|365|49d58cfebcb62abd1...|1.0000000000000002|
|358|e161b910c1411b6d4...|1.0000000000000002|
|144|c4213fb7b0fdd8926...|1.0000000000000002|
| 62|249562b091482dd3e...|1.0000000000000002|
|457|ce707aeb4fe129ed6...|1.0000000000000002|
+---+--------------------+------------------+



In [20]:
# Create the summary table with the relevant paper from the metadata
df_relevant = df_relevant.join(df_metadata, on=['paper_id'], how='left_outer') \
                            .select('paper_id', 'publish_time', 'title', 'doi', 'journal', 'similarity') \
                            .toPandas() \
                            .sort_values(by='similarity', ascending=False)
df_relevant.head(10)

Unnamed: 0,paper_id,publish_time,title,doi,journal,similarity
0,f2b8d478a21a63e3c986ccbd0bfafc71578252d0,2020-05-30,"COVID-19 panic, solidarity and equity—the Malt...",10.1007/s10389-020-01308-w,Z Gesundh Wiss,1.0000000000000002
1,e161b910c1411b6d44fb63dbb5534dda132d44cd,2020-06-14,Optimal size of sample pooling for RNA pool te...,10.1101/2020.06.11.20128793,,1.0000000000000002
2,ce707aeb4fe129ed6e3b011c2776f48b9cb200d6,,,,,1.0000000000000002
3,d96113a2d8691d3b1aee5fd1b5d30241f2b2a633,2020-06-08,Quantify the role of superspreaders -opinion l...,10.1371/journal.pone.0234023,PLoS One,1.0000000000000002
4,c4213fb7b0fdd8926e9e4108726524bc87483677,2020-05-08,Forecasting COVID-19 new cases in Algeria usin...,10.1101/2020.05.03.20089615,,1.0000000000000002
5,dc57cacedbbad71644464158ec9a2d61afbaeb70,2020-05-21,Surgical treatment of thoracolumbar fracture w...,10.1016/j.cjtee.2020.05.005,Chin J Traumatol,1.0000000000000002
6,a3f69a45be4bf642b2375f5df6218fe9bb087f92,2020-07-31,Managing business relationships during a pande...,10.1016/j.indmarman.2020.05.025,Industrial Marketing Management,1.0000000000000002
7,249562b091482dd3e0ca8bdc45ea1d7c52ba1616,2014-04-13,Viral Respiratory Infections Diagnosed by Mult...,10.1016/j.bbmt.2014.04.004,Biol Blood Marrow Transplant,1.0000000000000002
8,49d58cfebcb62abd1286784f0bc3142ac91e4026,2010-12-31,Chapter 45 Gender Differences in Emerging Infe...,10.1016/b978-0-12-374271-1.00045-9,Principles of Gender-Specific Medicine,1.0000000000000002
9,872a34dc8f89a091dbf8f6280c6947fca7e14e2c,2020-05-29,Older age is associated with sustained detecti...,10.1101/2020.05.28.20115378,,1.0000000000000002


In [21]:
# Send the summary table to a csv file
df_relevant.to_csv('/home/rikz/Documents/Master/Semester2/SDDM/SDDM/summary_tables/word2vec/{}.csv' \
                   .format(q.lower().replace(' ', '_')), index=False)
print("Summary table extracted and sent to csv file.")

Summary table extracted and sent to csv file.


In [22]:
sc.stop()