In [1]:
import json
import re
import time
import boto3
import requests
import pandas as pd
import numpy as np
import unicodedata
import torch
pd.set_option('max_colwidth', None)
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Sanity check: how many CUDA devices PyTorch reports for this kernel.
torch.cuda.device_count()

1

In [3]:
def name_to_keep_ind(groups):
    """
    Flag whether a text is usable, given the character groups found in it.

    Input:
    groups: list of character groups

    Output:
    0: if at least one group is a script that does not perform well
    1: if none of the groups is such a script
    """
    # Scripts that do not perform well
    unsupported_scripts = (
        'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI',
        'DEVANAGARI', 'BENGALI', 'THAANA', 'GUJARATI', 'CYRILLIC',
    )

    for group in groups:
        if group in unsupported_scripts:
            return 0
    return 1

def remove_non_latin_characters(text):
    """
    Function to remove non-latin characters.

    Every character whose Unicode name starts with one of the skipped script
    names is dropped. Characters that have no Unicode name at all (for example
    control characters such as tab or newline) are dropped as well.

    Input:
    text: string of characters

    Output:
    final_char: string of characters with non-latin characters removed
    """
    groups_to_skip = ['HIRAGANA', 'CJK', 'KATAKANA','ARABIC', 'HANGUL', 'THAI','DEVANAGARI','BENGALI',
                      'THAANA','GUJARATI','CYRILLIC']
    final_char = []
    for char in text:
        try:
            # The first word of the Unicode name identifies the script (e.g. "LATIN", "CJK")
            script = unicodedata.name(char).split(" ")[0]
        except (ValueError, TypeError):
            # ValueError: the code point has no name; TypeError: the item is not a
            # single character. Either way the item is skipped, without also hiding
            # unrelated errors such as KeyboardInterrupt.
            continue
        if script not in groups_to_skip:
            final_char.append(char)
    return "".join(final_char)
    
def group_non_latin_characters(text):
    """
    Function to group non-latin characters and return the number of latin characters.

    Periods and spaces are ignored. Each remaining character is classified by the
    first word of its Unicode name; characters without a Unicode name are
    reported under the group "UNK".

    Input:
    text: string of characters

    Output:
    groups: list of distinct non-latin character groups, in order of first appearance
    latin_chars: number of latin characters
    """
    groups = []
    latin_count = 0
    for char in text.replace(".", "").replace(" ", ""):
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # unicodedata.name raises ValueError for code points without a name
            script = "UNK"

        if script == 'LATIN':
            latin_count += 1
        elif script not in groups:
            groups.append(script)
    return groups, latin_count

def check_for_non_latin_characters(text):
    """
    Decide whether a text is usable given the scripts it contains.

    The text is kept when none of its character groups is a skipped script,
    or when it still contains more than 20 latin characters.

    Input:
    text: string of characters (converted with str() before inspection)

    Output:
    0: if text should be not used
    1: if text should be used
    """
    groups, latin_chars = group_non_latin_characters(str(text))
    usable = name_to_keep_ind(groups) == 1 or latin_chars > 20
    return 1 if usable else 0

In [4]:
def clean_title(old_title):
    """
    Decide whether a title should be kept and, if so, strip non-latin characters
    and a fixed set of HTML/XML markup from it.

    Input:
    old_title: string of title

    Output:
    new_title: string of title with non-latin characters and HTML tags removed,
               or '' when the title is rejected or is not a string
    """
    keep_title = check_for_non_latin_characters(old_title)
    if not ((keep_title == 1) & isinstance(old_title, str)):
        return ''

    new_title = remove_non_latin_characters(old_title)

    if '<' in new_title:
        # Literal markup removed one entry after another. The order matters
        # (e.g. "<![CDATA[<B>" must go before "<B>"), and the repeated entries
        # are applied again on purpose to mirror the exact removal sequence.
        literal_tags = (
            "<i>", "</i>",
            "<sub>", "</sub>",
            "<sup>", "</sup>",
            "<em>", "</em>",
            "<b>", "</b>",
            "<I>", "</I>",
            "<SUB>", "</SUB>",
            "<scp>", "</scp>",
            "<font>", "</font>",
            "<inf>", "</inf>",
            "<i /> ",
            "<p>", "</p>",
            "<![CDATA[<B>", "</B>]]>",
            "<italic>", "</italic>",
            "<title>", "</title>",
            "<br>", "</br>", "<br/>",
            "<B>", "</B>",
            "<em>", "</em>",
            "<BR>", "</BR>",
            "<title>", "</title>",
            "<strong>", "</strong>",
            "<formula>", "</formula>",
            "<roman>", "</roman>",
            "<SUP>", "</SUP>",
            "<SSUP>", "</SSUP>",
            "<sc>", "</sc>",
            "<subtitle>", "</subtitle>",
            "<emph/>", "<emph>", "</emph>",
            """<p class="Body">""",
            "<TITLE>", "</TITLE>",
            "<sub />", "<sub/>",
            "<mi>", "</mi>",
            "<bold>", "</bold>",
            "<mtext>", "</mtext>",
            "<msub>", "</msub>",
            "<mrow>", "</mrow>",
            "</mfenced>", "</math>",
        )
        for tag in literal_tags:
            new_title = new_title.replace(tag, "")

        if '<mml' in new_title:
            # Split the title around MathML blocks, then keep only the text that
            # sits between tags inside each block.
            between_tags = r"\>[$%#!^*\w.,/()+-]*\<"
            segments = []
            for outer in new_title.split("<mml:math"):
                segments.extend(inner for inner in outer.split("mml:math>") if inner)

            rebuilt = []
            for segment in segments:
                matches = re.findall(between_tags, segment)
                if matches:
                    inner_text = "".join(m.replace(">", "").replace("<", "") for m in matches)
                    rebuilt.append(" " + inner_text + " ")
                else:
                    rebuilt.append(segment)

            new_title = "".join(rebuilt).strip()

        # (marker, pattern): the regex is only applied when the marker is present
        guarded_patterns = (
            ('<xref', r"\<xref[^/]*\/xref\>"),
            ('<inline-formula', r"\<inline-formula[^/]*\/inline-formula\>"),
            ('<title', r"\<title[^/]*\/title\>"),
            ('<p class=', r"\<p class=[^>]*\>"),
            ('<span class=', r"\<span class=[^>]*\>"),
            ('mfenced open', r"\<mfenced open=[^>]*\>"),
            ('math xmlns', r"\<math xmlns=[^>]*\>"),
        )
        for marker, pattern in guarded_patterns:
            if marker in new_title:
                new_title = re.sub(pattern, "", new_title)

    if '<' in new_title:
        # Second pass over leftovers, including reversed-bracket fragments
        leftovers = (">i<", ">/i<", ">b<", ">/b<", "<inline-formula>", "</inline-formula>")
        for fragment in leftovers:
            new_title = new_title.replace(fragment, "")

    # All-caps titles are converted to title case
    if new_title.isupper():
        new_title = new_title.title()

    return new_title
    
def clean_abstract(raw_abstract, inverted=False):
    """
    Function to clean abstract and return it in a format for the model.

    Input:
    raw_abstract: abstract as a plain string, or (when inverted=True) an inverted
                  index given as a dict or as a JSON string. Two inverted layouts
                  are handled: {'IndexLength': n, 'InvertedIndex': {word: [positions]}}
                  and a bare {word: [positions]} mapping.
    inverted: boolean to determine if abstract is inverted index or not

    Output:
    final_abstract: abstract text truncated to 700 characters, or None when the
                    abstract is missing, too short, of an unsupported type, or
                    rejected by check_for_non_latin_characters
    """
    max_ab_len = 700
    # Word slots allocated when the inverted index does not state its own length
    fallback_slots = 1200

    if inverted:
        if isinstance(raw_abstract, dict):
            invert_abstract = raw_abstract
        elif isinstance(raw_abstract, str):
            invert_abstract = json.loads(raw_abstract)
        else:
            return None

        if invert_abstract.get('IndexLength'):
            ab_len = invert_abstract['IndexLength']
            if ab_len <= 20:
                return None
            abstract = [" "] * ab_len
            for key, value in invert_abstract['InvertedIndex'].items():
                for i in value:
                    abstract[i] = key
            final_abstract = " ".join(abstract)[:max_ab_len]
        else:
            if len(invert_abstract) <= 20:
                return None
            abstract = [" "] * fallback_slots
            for key, value in invert_abstract.items():
                for i in value:
                    try:
                        abstract[i] = key
                    except (IndexError, TypeError):
                        # Best effort: positions past the allocated slots (or
                        # malformed positions) are ignored rather than failing
                        # the whole abstract.
                        continue
            final_abstract = " ".join(abstract)[:max_ab_len].strip()
    else:
        # Missing values read through pandas arrive as NaN (a truthy float), which
        # has no len(); any non-string is treated as "no abstract".
        if not isinstance(raw_abstract, str) or len(raw_abstract) <= 30:
            return None
        final_abstract = raw_abstract[:max_ab_len]

    if check_for_non_latin_characters(final_abstract) == 1:
        return final_abstract
    return None

In [5]:
def get_title_abstract_for_df(title, abstract):
    """
    Function to combine a title and an abstract into the text that gets embedded.

    Input:
    title: string of title (an all-uppercase title is converted to title case)
    abstract: string of abstract, or a falsy value when there is none

    Output:
    title and abstract joined by a newline and a space, or the title alone when
    there is no abstract, truncated to 900 characters
    """
    max_title_ab_len = 900

    if title.isupper():
        title = title.title()

    if abstract:
        combined = f"{title}\n {abstract}"
    elif title:
        combined = f"{title}"
    else:
        combined = ""

    return combined[:max_title_ab_len]

In [8]:
def get_candidate_keywords_df(title_abs_emb, title_and_abstract, candidate_topics):
    """
    Function to get keywords based on the topics

    Candidates are the keywords in the global all_keywords_data whose topic_id is
    in candidate_topics. Each candidate is scored with the dot product of its
    embedding and the title/abstract embedding.

    Input:
    title_abs_emb: embedding of the title/abstract text
    title_and_abstract: title/abstract text; when empty, no candidate is scored
    candidate_topics: topics of paper

    Output:
    list of {"keyword", "score"} dicts: the top-5 candidates scoring above 0.50;
    if there are none, the single best candidate when it scores above 0.40;
    otherwise an empty list
    """
    in_topics = all_keywords_data['topic_id'].isin(candidate_topics)
    cand_embs_df = (
        all_keywords_data[in_topics]
        .drop_duplicates(subset=['keywords'])[['keywords', 'embedding']]
        .copy()
    )

    if title_and_abstract:
        # Get scores for each candidate keyword
        query_vec = np.array(title_abs_emb)
        cand_embs_df['cand_scores'] = cand_embs_df['embedding'].apply(lambda emb: np.dot(query_vec, emb))
    else:
        cand_embs_df['cand_scores'] = -1

    scored = cand_embs_df[cand_embs_df['cand_scores'] >= 0]
    if scored.shape[0] == 0:
        return []

    top_k = scored.sort_values('cand_scores', ascending=False).head(5).copy()
    top_k['keywords'] = top_k['keywords'].apply(lambda kw: kw.lower())
    top_k = top_k.drop_duplicates(subset=['keywords'])
    keywords = top_k['keywords'].tolist()
    scores = top_k['cand_scores'].tolist()

    final_keywords = [
        {"keyword": keyword, "score": score}
        for keyword, score in zip(keywords, scores)
        if score > 0.50
    ]
    if final_keywords:
        return final_keywords

    # Nothing cleared the main threshold: fall back to the best candidate alone
    if scores[0] > 0.40:
        return [{"keyword": keywords[0], "score": scores[0]}]
    return []

In [10]:
def get_all_keywords_df(old_df):
    """
    Function to get keywords that match title/abstract for every row of a dataframe

    Input:
    old_df: dataframe with 'original_title', 'abstract' and 'topics' columns
            (the input dataframe is copied, not modified)

    Output:
    list with one entry per row, each entry being the keyword list returned by
    get_candidate_keywords_df for that row
    """
    df = old_df.copy()

    # Process title and abstract
    # NOTE(review): the abstract is cleaned with clean_title, not clean_abstract —
    # confirm this is intended.
    for text_col in ('original_title', 'abstract'):
        df[text_col] = df[text_col].apply(clean_title)

    df['title_abstract'] = df.apply(
        lambda row: get_title_abstract_for_df(row.original_title, row.abstract),
        axis=1,
    )

    # Embed the whole batch in one call (uses the global emb_model)
    with torch.no_grad():
        title_abs_embs = emb_model.encode(df['title_abstract'].tolist())
    df['embs'] = title_abs_embs.tolist()

    # Get candidate keywords
    df['keywords'] = df.apply(
        lambda row: get_candidate_keywords_df(row.embs, row.title_abstract, row.topics),
        axis=1,
    )
    return df['keywords'].tolist()

In [14]:
# GPU memory snapshot; PyTorch documents the result as (free, total) in bytes.
torch.cuda.mem_get_info()

(16611934208, 16935682048)

In [16]:
# Embedding model; get_all_keywords_df reads it as a global to encode title/abstract text.
emb_model = SentenceTransformer('baai/BGE-M3')
# Keyword table; get_candidate_keywords_df reads it as a global and uses its
# 'topic_id', 'keywords' and 'embedding' columns.
all_keywords_data = pd.read_parquet('s3://openalex-keywords-matcher/v1/keywords_files/')

  return self.fget.__get__(instance, owner)()


In [17]:
# Sample of papers used by the %%time cell further down to time batched keyword extraction.
test_df = pd.read_parquet("data_sample_to_test")

In [19]:
# List every parquet file under the input prefix, sorted by key.
s3 = boto3.client('s3')
prefix = "keywords/v2/running_data_through_model/data_to_score/"

# list_objects_v2 returns at most 1000 keys per call, so walk all pages instead
# of reading a single response; a page without 'Contents' (empty prefix) is
# treated as having no objects.
files = []
for response in s3.get_paginator('list_objects_v2').paginate(Bucket="bucket", Prefix=prefix):
    for obj in response.get('Contents', []):
        file_key = obj['Key']
        if file_key.endswith('parquet'):
            files.append(file_key)

files.sort()

In [20]:
# Slice of the sorted file list processed by this run (positions 160-179).
keys_to_score = files[160:180]

In [21]:
# Approximate rows per batch: each dataframe is split into int(rows / batch_num) chunks.
batch_num = 350

In [19]:
%%time
final_keys = [get_all_keywords_df(pd.DataFrame(i, columns=list(test_df.columns))) for i in np.array_split(test_df, int(test_df.shape[0]/batch_num))]

  return bound(*args, **kwds)


CPU times: user 33.8 s, sys: 169 ms, total: 34 s
Wall time: 32.6 s


In [None]:
# Score every selected input file and write one output parquet per file.
for full_file_name in keys_to_score:
    s3_file = f"s3://bucket/{full_file_name}"
    # Part identifier taken from the file key ("...part-<id>-...")
    file_name = full_file_name.split("part-")[1].split('-')[0]

    print(file_name)
    data_to_score = pd.read_parquet(s3_file)
    n_rows = data_to_score.shape[0]
    print(n_rows)

    # At least one chunk: files with fewer than batch_num rows would otherwise
    # give a chunk count of 0 and make np.array_split raise. Row positions are
    # split (not the DataFrame itself), which keeps the same chunk boundaries
    # while avoiding np.array_split's deprecated DataFrame code path.
    n_chunks = max(1, n_rows // batch_num)
    final_keys = [get_all_keywords_df(data_to_score.iloc[rows])
                  for rows in np.array_split(np.arange(n_rows), n_chunks)]

    # Flatten the per-chunk results once; the order matches the row order
    all_keywords = [x for y in final_keys for x in y]
    data_to_score['keywords'] = all_keywords
    print(len(all_keywords))
    data_to_score[['paper_id','keywords']].to_parquet(f"s3://bucket/keywords/v2/running_data_through_model/data_scored_gpu/part_{file_name}.parquet")
    print("")
