In [None]:
############################################################################
#            COPYRIGHT (C) YASSIN KORTAM - ALL RIGHTS RESERVED             #
# UNAUTHORIZED COPYING OF THIS FILE, VIA ANY MEDIUM IS STRICTLY PROHIBITED #
#                       PROPRIETARY AND CONFIDENTIAL                       #
#    WRITTEN BY YASSIN KORTAM <YASSINKORTAM@G.UCLA.EDU>, MARCH 2023        #
############################################################################

import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Filter by Keywords
- Given a dataframe, a column of interest, and a list of keywords, filter the dataframe so that only the rows matching at least one keyword are kept.
- A row matches when the cosine similarity between the embedding of any word in the given column's text and the embedding of any keyword exceeds 0.8. All other rows are removed from the result and returned separately.

In [None]:
def filter_by_keywords(df, column, keywords, threshold=0.8,
                       model_name='paraphrase-MiniLM-L6-v2'):
    '''
    Split a dataframe into rows that match any keyword and rows that do not.

    Every whitespace-separated word of the text in ``column`` is embedded and
    compared with the embedding of every keyword. A row is kept when the
    largest cosine similarity found is strictly greater than ``threshold``.
    Rows whose value is not a string, or is an empty / whitespace-only
    string, never match.

    Args:
        - df (DataFrame): data to filter; it is not modified.
        - column (str): name of the column holding the text to inspect.
        - keywords (list): keywords to look for. A single string is treated
          as a one-element list.
        - threshold (float): minimum (exclusive) cosine similarity for a
          word to count as a keyword match. Defaults to 0.8.
        - model_name (str): name passed to SentenceTransformer to select the
          embedding model.

    Returns:
        - (DataFrame, DataFrame): the matching rows and the removed rows,
          each keeping the original index labels and row order.
    '''
    #Get the text in the given column
    text = df[column].tolist()

    #A bare string would otherwise be consumed character by character
    keywords = [keywords] if isinstance(keywords, str) else list(keywords)

    #Split every row into words; anything that is not a string has no words
    tokenised = [value.split() if isinstance(value, str) else []
                 for value in text]

    #One flag per row, by position: True means the row is kept
    mask = [False] * len(tokenised)

    #Only load the model when there is something to compare
    if keywords and any(tokenised):
        model = SentenceTransformer(model_name)
        keywords_embeddings = model.encode(keywords)

        for position, words in enumerate(tokenised):
            if not words:
                #Nothing to embed, the row cannot match
                continue
            word_embeddings = model.encode(words)
            similarities = cosine_similarity(word_embeddings,
                                             keywords_embeddings)
            #Keep the row if any word is close enough to any keyword
            mask[position] = bool(similarities.max() > threshold)

    #Filter by position so the result does not depend on the index labels
    kept = df.iloc[mask]
    removed = df.iloc[[not flag for flag in mask]]
    return kept, removed

# Filtering a dataset

In [None]:
# Column to inspect and the keywords a row must resemble to be kept
column = 'History'
keywords = ['HCC', 'HCC/cholangioCA', 'Hepatocellular', 'liver', 'hepatic']

# Load the spreadsheet and split it into kept rows and removed rows
data, removed = filter_by_keywords(
    pd.read_excel('grouped_data.xlsx'), column, keywords
)

# Report how many rows were dropped, then display them as the cell output
print("%s rows removed:" % len(removed))
removed

In [None]:
# Save the filtered data: `data` holds the rows kept by filter_by_keywords.
# NOTE(review): to_excel is called with its defaults, which also write the
# DataFrame index as the first spreadsheet column - confirm this is wanted.
data.to_excel('filtered_data.xlsx')