# Verification of BM25 Implementation

In [20]:
from rank_bm25 import BM25Okapi
import os
import json
import logging
import re
from string import punctuation
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm

In [9]:
# Token separator used by process(): one or more punctuation or whitespace
# characters.
# NOTE(review): `punctuation` is interpolated into the character class
# unescaped, so characters such as `\`, `]`, `^` and `-` are interpreted by
# the regex engine rather than taken literally. Consider
# re.escape(punctuation) -- confirm this matches the intended tokenization.
pattern = rf"[{punctuation}\s]+"
# English stop words as a set for O(1) membership tests in process().
# NOTE(review): this rebinds the name `stopwords` imported from nltk.corpus;
# after this line the name refers to a set, so running this statement a
# second time raises AttributeError (a set has no `.words`).
stopwords = set(stopwords.words('english'))
# Lemmatizer applied to every token kept by process().
lemmatizer = WordNetLemmatizer()

In [10]:
def create_concat_text(doc_id_list, data_path):
    """Load documents and return their concatenated tokens.

    Parameters
    ----------
    doc_id_list : list
        List of document IDs to load and concatenate together.
    data_path : str
        Directory containing one ``<doc_id>.txt`` file per document.

    Returns
    -------
    doc_concat : list
        Tokens of all loaded documents, concatenated in the insertion
        order of ``docs``.
    docs : dict
        Mapping from document ID to the token list returned by
        ``process`` for that document.

    Notes
    -----
    A missing file, or a document that yields no tokens after processing,
    is logged and terminates the interpreter with exit status 1.
    """
    # Imported locally so the error path works even though the file's
    # import section does not bring in ``sys``.
    import sys

    docs = {}
    for doc_id in doc_id_list:
        doc_file = os.path.join(data_path, f"{doc_id}.txt")
        try:
            if not os.path.exists(doc_file):
                raise FileNotFoundError(f"{doc_id}.txt not found")

            with open(doc_file, 'r') as f:
                docs[doc_id] = process(f.read())

            # ``process`` returns a list of tokens, so emptiness has to be
            # tested by truthiness; comparing against '' never matches.
            if not docs[doc_id]:
                raise ValueError(f"Found empty document {doc_id}. "
                                 "Documents cannot be empty")
        except (FileNotFoundError, ValueError) as err:
            logging.error(repr(err))
            sys.exit(1)

    # Concatenating into one document
    doc_concat = [token for doc in docs.values() for token in doc]

    return doc_concat, docs


In [11]:
def process(text):
    """Tokenize, normalize and lemmatize the given text.

    The text is split on ``pattern``; each piece is lower-cased and
    stripped of every character that is not an ASCII letter or digit.
    Pieces that end up shorter than two characters, or that are stop
    words, are dropped. The remaining tokens are lemmatized and returned
    as a list in their original order.
    """
    tokens = []
    for piece in re.split(pattern, text):
        cleaned = re.sub('[^0-9a-zA-Z]+', '', piece.lower())

        # Discards empty strings as well as single-character tokens.
        if len(cleaned) <= 1:
            continue
        if cleaned in stopwords:
            continue

        tokens.append(lemmatizer.lemmatize(cleaned))

    return tokens


In [12]:
# Directory holding one "<doc_id>.txt" file per document; read by
# create_concat_text() and by the test-document loading loop.
# NOTE(review): absolute, machine-specific path -- adjust before running
# elsewhere.
data_path = "/home/workboots/Datasets/DHC/common/preprocess/fact_sentences/"
# JSON file describing the per-advocate train/test split.
split_path = "/home/workboots/Datasets/DHC/variations/new/var_1/cross_val/5_fold/fold_0/adv_split_info.json"

In [13]:
# Load the split description. Later cells read it as a mapping
# advocate -> {"train": <document IDs>, "test": <document IDs>}.
with open(split_path, 'r') as f:
    split_info = json.load(f)

In [14]:
# advocate -> concatenated token list of that advocate's training documents
adv_concat = {}
# IDs of all test documents, pooled over advocates (set removes duplicates)
test_doc_ids = set()
# training document ID -> processed token list
train_texts = {}
# test document ID -> processed token list
test_texts = {}

In [15]:
# Build one concatenated training "document" per advocate and pool the
# test document IDs of all advocates.
for adv, cases in split_info.items():
    adv_concat[adv], docs = create_concat_text(cases["train"], data_path)
    # Merge in place; rebuilding the dict with {**train_texts, **docs}
    # re-copies every previously collected entry on each iteration.
    train_texts.update(docs)
    test_doc_ids.update(cases["test"])

In [16]:
# Read every pooled test document and store its processed tokens under
# the document ID.
for idx in test_doc_ids:
    test_file = os.path.join(data_path, f"{idx}.txt")
    with open(test_file, 'r') as f:
        test_texts[idx] = process(f.read())

In [17]:
# Fit BM25 on a corpus with one entry per advocate: the concatenated token
# list of that advocate's training documents. Corpus order follows the
# insertion order of adv_concat.
bm25 = BM25Okapi(list(adv_concat.values()))

In [18]:
# test document ID -> {advocate: BM25 score}, filled by the scoring loop
scores = {}

In [21]:
# Score each test document against every advocate's concatenated corpus.
for idx, text in tqdm(test_texts.items()):
    # get_scores() iterates over its query term by term, so it must be given
    # the token list. Passing a joined string would make it iterate over
    # single characters and score those instead of the words.
    doc_scores = bm25.get_scores(text)
    # The BM25 corpus was built from adv_concat.values(), so the i-th score
    # belongs to the i-th key of adv_concat.
    scores[idx] = dict(zip(adv_concat.keys(), doc_scores))

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [03:28<00:00,  4.42it/s]


In [22]:
# Reorder each document's advocate scores from highest to lowest; the
# ranking is preserved by dict insertion order.
scores = {
    idx: dict(sorted(values.items(), key=lambda pair: pair[1], reverse=True))
    for idx, values in scores.items()
}

In [24]:
# Persist the ranked scores as indented JSON.
# NOTE(review): absolute, machine-specific output path.
with open("/home/workboots/Results/advocate_recommendation/new/exp_2/cross_val/5_fold/fold_0/results/scores.json", 'w') as f:
    json.dump(scores, f, indent=4)

In [49]:
# Inspect the concatenated training tokens of a single advocate.
adv_concat["RebeccaJohn"]

['delhi',
 'high',
 'court',
 'abhay',
 'kumar',
 'mishra',
 'v',
 'state',
 'anr',
 'date',
 'author',
 'sunil',
 'gaur',
 'high',
 'court',
 'delhi',
 'new',
 'delhi',
 'judgment',
 'reserved',
 'date',
 'judgment',
 'pronounced',
 'date',
 'crl',
 '3952',
 '2012',
 'crl',
 'advocate',
 'person',
 'advocate',
 'respondent',
 'coram',
 'honble',
 'mr',
 'justice',
 'sunil',
 'gaur',
 'judgment',
 'preferred',
 'second',
 'respondent',
 'trial',
 'court',
 'vide',
 'impugned',
 'order',
 'date',
 'annexure',
 'directed',
 'registration',
 'fir',
 'complaint',
 'annexure',
 'second',
 'respondent',
 'pursuance',
 'impugned',
 'order',
 'fir',
 '397',
 '12',
 'section',
 '195',
 '409',
 '420',
 '467',
 '468',
 '470',
 '471',
 '477a',
 '506',
 '120',
 'indian',
 'penal',
 'code',
 '1860',
 'registered',
 'dabri',
 'delhi',
 'date',
 'pendency',
 'petition',
 'respondent',
 'state',
 'placed',
 'record',
 'status',
 'report',
 'narrates',
 'factual',
 'background',
 'case',
 'also',
 'take

# TF-IDF Vectors

In [37]:
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

In [42]:
# TF-IDF vectorizer with library defaults.
# NOTE(review): the defaults presumably re-tokenize and lower-case the
# joined text on top of process() -- confirm this is the intended behavior.
vectorizer = TfidfVectorizer()

In [43]:
# Fit the vocabulary on the per-advocate concatenated documents (tokens
# re-joined with spaces) and vectorize them; one row per adv_concat entry.
train_vectors = vectorizer.fit_transform([" ".join(text) for text in adv_concat.values()])

In [44]:
# Sanity check: (number of advocates, vocabulary size).
train_vectors.shape

(103, 51571)

In [45]:
# Vectorize the test documents with the vocabulary fitted on the training
# data; row order follows the insertion order of test_texts.
test_vectors = vectorizer.transform([" ".join(text) for text in test_texts.values()])

In [46]:
 # Sanity check: (number of test documents, vocabulary size).
 test_vectors.shape

(924, 51571)

In [48]:
# Show the first training row as a dense array.
np.array(train_vectors[0].toarray())

array([[0.01411858, 0.00289932, 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [None]:
np.save(