In [1]:
!apt-get install openjdk-21-jre-headless -qq > /dev/null
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
!update-alternatives --set java /usr/lib/jvm/java-21-openjdk-amd64/bin/java
!java -version

openjdk version "21.0.5" 2024-10-15
OpenJDK Runtime Environment (build 21.0.5+11-Ubuntu-1ubuntu122.04)
OpenJDK 64-Bit Server VM (build 21.0.5+11-Ubuntu-1ubuntu122.04, mixed mode, sharing)


In [2]:
#!pip install pyserini
#!pip install --upgrade openai

In [66]:
import pyserini
from pyserini.search.lucene import LuceneSearcher
from pyserini.index import LuceneIndexReader
from IPython.core.display import display, HTML
from pyserini.search import get_topics
import heapq
import re

In [8]:
def run_model(query, index_path):
  query = "Are there black bear attacks there?"

  searcher, reader = initialize_model("robust04")
  print(get_recs(query, 10, searcher, reader))

In [9]:
def initialize_model(index_path):
  searcher = LuceneSearcher.from_prebuilt_index(index_path)
  reader = LuceneIndexReader.from_prebuilt_index(index_path)
  return searcher, reader

In [101]:
def generate_token_mapping(docid, doc_vec, reader):
  # this takes forever but works
  doc = reader.doc(docid).raw().lower()
  mapping = {}
  for word in re.split(r'\s+', doc):
      analyzed = reader.analyze(word)
      for t in doc_vec:
        if t in analyzed:
          word = re.sub(r'\W+', '', word)
          mapping[t] = word
  return mapping

In [102]:
MAX_RECS = 5

def get_recs(query: str, n: int, searcher: LuceneSearcher, reader: LuceneIndexReader) -> list[tuple]:
  hits = searcher.search(query, n)
  for i in hits:
    doc_vec = reader.get_document_vector(i.docid)

    # to "untokenize" the data, we need to map original terms to tokens
    m = generate_token_mapping(i.docid, doc_vec, reader)

    h = []
    for t, f in doc_vec.items():
      tf = f / len(doc_vec)
      try:
        df = reader.get_term_counts(t)[0]
      except: # i have no idea why this is throwing an error, needs a closer look
        df = 0
      if df == 0:
        idf = 0
      else:
        idf = reader.stats()['documents'] / df
      tf_idf = tf * idf
      heapq.heappush(h, (tf_idf, m[t]))
      if len(h) > MAX_RECS:
        heapq.heappop(h)
  return h

In [103]:
query = "Are there black bear attacks there?"
index_path = "robust04"

run_model(query, index_path)

[(454.80620155038764, 'duerson'), (1364.418604651163, 'bruise'), (818.6511627906976, 'offensive'), (1364.418604651163, 'defensive'), (1364.418604651163, 'geezus')]
