In [1]:
import csv
import requests
import json
import numpy as np
import pandas as pd
import re
from txtai.embeddings import Embeddings


import geopy.distance
from geopy.geocoders import Nominatim

2023-04-24 11:17:48.562685: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import numpy as np

import os
# Point the Hugging Face transformers cache at a local `cache` directory.
# NOTE(review): txtai is already imported in the first cell, before this
# variable is set -- confirm the setting still takes effect for model loads.
os.environ['TRANSFORMERS_CACHE'] = 'cache'

import pandas as pd
import numpy as np 
import json
# Progress bars are optional: when tqdm is not installed, fall back to an
# identity function so `tqdm(iterable)` call sites keep working.
try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = lambda x: x
    
from txtai.embeddings import Embeddings
from txtai.pipeline import Similarity
from txtai.pipeline import Tabular
from txtai.workflow import Task
from txtai.workflow import Workflow
from huggingface_hub import snapshot_download

In [3]:
import json
import pandas as pd
import gradio as gr
import numpy as np
import os
# Same cache setting as the previous cell: use the local `cache` directory
# for Hugging Face transformers downloads.
os.environ['TRANSFORMERS_CACHE'] = 'cache'
# Progress bars are optional: when tqdm is not installed, define a
# pass-through `tqdm` so `tqdm(iterable)` simply returns its argument.
try:
    from tqdm.auto import tqdm
except ImportError:
    def tqdm(x): return x


class SemanticSearch(object):
    """Semantic search over a CSV export, backed by txtai embeddings indexes.

    One txtai ``Embeddings`` index is built (or loaded from disk) per model
    checkpoint in ``ckptlist``. Rows are read from ``<filename>.csv``, keyed by
    the ``nct_id`` column, and the text of ``columns`` is what gets indexed.

    Attributes:
        filename: CSV base name (without the ``.csv`` extension).
        columns: Text columns that are indexed and returned by searches.
        ckptlist: Model checkpoint ids for which an index is available.
        embeddings_by_ckpt: Mapping of checkpoint id -> its ``Embeddings`` index.
        embeddings: Index of the last checkpoint in ``ckptlist`` (kept for
            backward compatibility), or ``None`` when ``ckptlist`` is empty.
    """

    # Defaults are immutable tuples so instances never share a mutable default.
    DEFAULT_COLUMNS = (
        "brief_title",
        "official_title",
        "brief_summaries",
        "detailed_descriptions",
        "criteria",
        "city", "state", "zip", "country",
    )
    DEFAULT_CKPTLIST = (
        "sentence-transformers/multi-qa-mpnet-base-dot-v1",
    )
    # Result keys that search_cond treats as carrying a location.
    LOCATION_COLUMNS = ("city", "state", "zip", "country")

    def __init__(
        self,
        filename="ctgov_34983_20230417",
        columns=None,
        ckptlist=None,
        rerun=True,
    ):
        """Download each checkpoint and load or build its index.

        Args:
            filename: CSV base name; ``<filename>.csv`` is read when an index
                has to be built, and index files are named after it.
            columns: Text columns to index. ``None`` uses ``DEFAULT_COLUMNS``.
            ckptlist: Checkpoint ids to index with. ``None`` uses
                ``DEFAULT_CKPTLIST``.
            rerun: When exactly ``False`` and the index file already exists,
                the saved index is loaded; otherwise the index is rebuilt from
                the CSV and saved.
        """
        self.filename = filename
        # Copy the inputs so later mutation by the caller cannot affect us.
        self.columns = list(self.DEFAULT_COLUMNS if columns is None else columns)
        self.ckptlist = list(self.DEFAULT_CKPTLIST if ckptlist is None else ckptlist)

        # Keep every checkpoint's index; previously each loop iteration
        # overwrote self.embeddings so only the last index was reachable.
        self.embeddings_by_ckpt = {}
        self.embeddings = None

        for ckptpath in self.ckptlist:
            snapshot_download(repo_id=ckptpath,
                              repo_type="model",
                              cache_dir="cache")
            embeddings = Embeddings({
                "method": "transformers",
                "path": ckptpath,
                "content": True,
                "object": True
            })
            indexfile = f'{filename}_{ckptpath.replace("/", "-")}.index'
            if os.path.exists(indexfile) and rerun is False:
                print("Indexed and Cached!")
                embeddings.load(indexfile)
            else:
                print("Need to rerun or Indices and Caches dont exist, run them!")
                self._build_index(embeddings, indexfile)
                print("Indexing and Caching finished for the 1st time!")

            self.embeddings_by_ckpt[ckptpath] = embeddings
            # Backward compatibility: self.embeddings ends up as the last index.
            self.embeddings = embeddings

    def _build_index(self, embeddings, indexfile):
        """Index ``<filename>.csv`` into ``embeddings`` and save it to ``indexfile``."""
        # Read the data from CSV
        frame = pd.read_csv(f'{self.filename}.csv')

        # Create tabular instance mapping input.csv fields
        tabular = Tabular(idcolumn="nct_id",
                          textcolumns=self.columns, content=True)

        # Run the tabular task as a one-step workflow, then index its output
        workflow = Workflow([Task(tabular)])
        rows = list(workflow([frame]))
        embeddings.index(rows)
        embeddings.save(indexfile)

    def search_func(self,
                    prompttext,
                    pretrained="sentence-transformers/multi-qa-mpnet-base-dot-v1",
                    limit=10):
        """Run a similarity search for ``prompttext``.

        Args:
            prompttext: Free-text query.
            pretrained: Checkpoint id whose index is searched; must be one of
                ``self.ckptlist``.
            limit: Maximum number of results requested from the index.

        Returns:
            Whatever ``Embeddings.search`` returns for a query selecting
            ``nct_id`` plus ``self.columns``.

        Raises:
            AssertionError: If ``pretrained`` is not in ``self.ckptlist``.
        """
        assert pretrained in self.ckptlist, f"Unknown checkpoint: {pretrained}"

        selected = ", ".join(["nct_id"] + list(self.columns))
        # Pass the prompt as a quoted SQL string literal, doubling embedded
        # single quotes (standard SQL escaping) so multi-word prompts and
        # apostrophes do not break the query.
        # NOTE(review): confirm txtai's SQL parser honors doubled quotes.
        escaped = str(prompttext).replace("'", "''")
        query = f"select {selected} from txtai where similar('{escaped}')"

        # Search the index that belongs to the requested checkpoint.
        return self.embeddings_by_ckpt[pretrained].search(query, limit)

    @staticmethod
    def _location_text(result):
        """Return the first non-empty location value of ``result`` as text.

        Keys are visited in the result's own order; only keys listed in
        ``LOCATION_COLUMNS`` are considered. ``None``, NaN and blank values
        are skipped. Returns ``None`` when nothing usable is found.
        """
        for key, value in result.items():
            if key not in SemanticSearch.LOCATION_COLUMNS or value is None:
                continue
            # NaN is the only float that is not equal to itself.
            if isinstance(value, float) and value != value:
                continue
            text = str(value).strip()
            if text:
                return text
        return None

    def search_cond(self,
                    results,
                    location,
                    distance):
        """Keep only the results located within ``distance`` km of ``location``.

        Args:
            results: Iterable of result dicts (as produced by ``search_func``).
            location: Free-text place name, geocoded with Nominatim.
            distance: Maximum distance in kilometres (inclusive).

        Returns:
            List of the results whose geocoded location is within range.
            Results without a usable location value, or whose location cannot
            be geocoded, are skipped.

        Raises:
            ValueError: If ``location`` itself cannot be geocoded.
        """
        # Parse location into latitude and longitude using geopy
        geolocator = Nominatim(user_agent="my-app")
        origin = geolocator.geocode(location)
        if origin is None:
            raise ValueError(f"Could not find location: {location}")
        origin_coords = (origin.latitude, origin.longitude)

        # Many results share the same location text; geocode each distinct
        # string once per call instead of once per result.
        geocoded = {}

        # Filter results based on distance from the location
        filtered_results = []
        for result in results:
            location_str = self._location_text(result)
            if location_str is None:
                # No usable location value found, skip this result
                continue
            if location_str not in geocoded:
                geocoded[location_str] = geolocator.geocode(location_str)
            site = geocoded[location_str]
            if site is None:
                # Could not parse location, skip this result
                continue
            site_coords = (site.latitude, site.longitude)
            dist = geopy.distance.distance(origin_coords, site_coords).km
            if dist <= distance:
                filtered_results.append(result)

        return filtered_results


In [4]:
# Set up the searcher over the trial export. With rerun=False the saved index
# is loaded when its file already exists; otherwise it is built and saved.
trial_search = SemanticSearch(
    rerun=False,
    ckptlist=["sentence-transformers/multi-qa-mpnet-base-dot-v1"],
    columns=[
        "brief_title", "official_title",
        "brief_summaries", "detailed_descriptions",
        "criteria",
        "city", "state", "zip", "country",
    ],
    filename="ctgov_34983_20230417",
)

# Ask for up to 100 matches for "diabetes" and show how many came back.
results = trial_search.search_func(limit=100, prompttext="diabetes")
display(len(results))

Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

Indexed and Cached!


100

In [5]:
# Narrow the search hits to those within 100 km of each reference location.
# ("Seattle" was previously misspelled "Seatle", which the geocoder may fail
# to resolve or may resolve to the wrong place.)
locations = ["Kendall MIT", "Seattle", "San Francisco"]
for location in locations:
    filtered_results = trial_search.search_cond(
        results,
        location=location,
        distance=100)
    display(filtered_results)
