Retrieve and rank:
    1. Retrieve candidates; the TF-IDF baseline model is replaced by BERT embeddings indexed with FAISS or NMSLIB.
    2. Rank and filter the results with a Siamese network that uses Legal-BERT embeddings.
    3. 

Features:
    1. Head matter vs Opinion texts
    2. Global vs Local Context
    3. By type of Opinion - But not available during test time
    4. Date, Jurisdiction, Court ID

Models:
    1. SPECTER
    2. LEGAL BERT

In [1]:
import pandas as pd
import torch
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
import datasets
from functools import partial

In [2]:
# train_df = pd.read_json("../data/subset/train_data.json", lines=True, orient="records")
# val_df = pd.read_json("./data/subset/val_data.json", lines=True, orient="records")
# test_df = pd.read_json("sample_data.json", lines=True, orient="records")

In [3]:
class DatasetEmbeddingIndex:
    """Embed the text of a multi-split dataset and search it with FAISS.

    Each split named in ``data_files`` gets an ``embeddings`` column produced
    by a transformer model, plus a FAISS index over that column that can be
    saved, reloaded and queried.
    """

    def __init__(self, model_name="allenai/specter", exclude_columns=None):
        """Load the tokenizer and model and remember which columns to drop.

        Args:
            model_name: Checkpoint name handed to ``AutoTokenizer`` and
                ``TFAutoModel`` (loaded with ``from_pt=True``).
            exclude_columns: Columns removed from the dataset when the
                combined ``text`` column is built. ``None`` means no columns
                are removed.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModel.from_pretrained(model_name, from_pt=True)
        # Pre-bind the tokenization options so every call site pads and
        # truncates the same way.
        self.tokenizer = partial(
            self._tokenizer,
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # A fresh list per instance: a literal ``[]`` default would be one
        # object shared by every instance constructed without this argument.
        self.exclude_columns = [] if exclude_columns is None else exclude_columns
        self.dataset = None

    def load_embeddings(self, batch, inference=False):
        """Compute one embedding per input text.

        Args:
            batch: With ``inference=False``, a mapping whose ``"text"`` entry
                holds the texts to embed. With ``inference=True``, the value
                passed straight to the tokenizer.
            inference: Selects the input format and the return format.

        Returns:
            A NumPy array when ``inference`` is true, otherwise a dict
            ``{"embeddings": array}`` suitable for ``Dataset.map``.
        """
        texts = batch if inference else batch["text"]
        model_output = self.model(**self.tokenizer(texts))
        # First model output, position 0 along the second axis (presumably the
        # [CLS] token of the last hidden state; confirm against the model).
        vectors = model_output[0][:, 0, :].numpy()
        if inference:
            return vectors
        return {"embeddings": vectors}

    def load_text(self, batch):
        """Join ``head_matter`` and ``opinion_text`` row by row with a newline.

        Returns:
            A dict ``{"text": [...]}`` with one combined string per row.
        """
        pairs = zip(batch["head_matter"], batch["opinion_text"])
        return {"text": [head + "\n" + opinion for head, opinion in pairs]}

    def load_dataset(self, data_files):
        """Load JSON files and add the combined ``text`` column.

        Args:
            data_files: Passed to ``datasets.load_dataset("json", ...)``.

        Returns:
            The loaded dataset with ``text`` added and ``exclude_columns``
            removed.
        """
        dataset = datasets.load_dataset("json", data_files=data_files)
        return dataset.map(
            self.load_text,
            remove_columns=self.exclude_columns,
            batched=True,
        )

    def fit(self, data_files, batch_size=64):
        """Load the data, embed every row and index each split with FAISS.

        Args:
            data_files: Mapping of split name to JSON file; its keys are the
                splits that receive an index.
            batch_size: Number of rows embedded per model call.

        Returns:
            ``self``, so calls can be chained.
        """
        self.dataset = self.load_dataset(data_files)
        self.dataset = self.dataset.map(
            self.load_embeddings, batched=True, batch_size=batch_size
        )
        for split in data_files:
            self.dataset[split].add_faiss_index(column="embeddings")
        return self

    def save(self, dataset_fname, index_fnames):
        """Write each split's FAISS index to a file, then save the dataset.

        Each index is dropped from the in-memory dataset right after it is
        written, so ``predict`` needs a ``load`` (or a new ``fit``) afterwards.
        NOTE(review): the drop is presumably required because
        ``save_to_disk`` rejects datasets with attached indexes; confirm.

        Args:
            dataset_fname: Destination passed to ``save_to_disk``.
            index_fnames: Mapping of split name to index file path.

        Returns:
            ``self``, so calls can be chained.
        """
        for split, index_path in index_fnames.items():
            self.dataset[split].save_faiss_index("embeddings", index_path)
            self.dataset[split].drop_index("embeddings")
        self.dataset.save_to_disk(dataset_fname)
        return self

    def load(self, dataset_fname, index_fnames):
        """Load a saved dataset and attach the saved FAISS indices.

        Args:
            dataset_fname: Location passed to ``datasets.load_from_disk``.
            index_fnames: Mapping of split name to index file path.

        Returns:
            ``self``, so calls can be chained.
        """
        self.dataset = datasets.load_from_disk(dataset_fname)
        for split, index_path in index_fnames.items():
            self.dataset[split].load_faiss_index("embeddings", index_path)
        return self

    def predict(self, queries, dataset_name, top_k=10):
        """Return the ``top_k`` nearest indexed examples for each query.

        Args:
            queries: Query text(s), embedded via
                ``load_embeddings(..., inference=True)``.
            dataset_name: Split whose index is searched.
            top_k: Number of neighbours requested per query.

        Returns:
            ``(scores, examples)`` from ``get_nearest_examples_batch``, or
            ``None`` when no dataset has been fitted or loaded yet.
        """
        if self.dataset is None:
            # Kept for backward compatibility: callers get None, not an error.
            return None
        query_embeddings = self.load_embeddings(queries, inference=True)
        scores, examples = self.dataset[dataset_name].get_nearest_examples_batch(
            "embeddings", query_embeddings, k=top_k
        )
        return scores, examples

In [4]:
# Build the index with the default model; the listed columns are removed
# from the dataset when the combined "text" column is created.
embedding_index = DatasetEmbeddingIndex(
    exclude_columns=[
        "jurisdiction_id",
        "court_id",
        "decision_date",
        "head_matter",
        "opinion_text",
        "citation_ids",
    ],
)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['embeddings.position_ids']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [None]:
# JSON source file for each dataset split.
data_files = dict(
    train="data/subset/train_data.json",
    validation="data/subset/val_data.json",
    test="data/subset/test_data.json",
)

# File that holds each split's FAISS index.
index_fnames = dict(
    train="./indices/train_index.fiass",
    validation="./indices/val_index.fiass",
    test="./indices/test_index.fiass",
)

dataset_fname = "caselaw_dataset"

# Load the splits, embed every row and attach a FAISS index to each split.
embedding_index = embedding_index.fit(data_files)

Using custom data configuration default-ead940c088e772ef


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/61460004/.cache/huggingface/datasets/json/default-ead940c088e772ef/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset json downloaded and prepared to /home/61460004/.cache/huggingface/datasets/json/default-ead940c088e772ef/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.




HBox(children=(FloatProgress(value=0.0, max=213.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=27.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=27.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=3323.0), HTML(value='')))

In [None]:
# Write the per-split FAISS indices, then the dataset itself.
embedding_index.save(dataset_fname="sample_embeddings", index_fnames=index_fnames)

In [None]:
# Reload the saved dataset and attach the saved FAISS indices to its splits.
embedding_index.load(dataset_fname="sample_embeddings", index_fnames=index_fnames)

In [None]:
# Display the "train" split of the currently loaded dataset.
embedding_index.dataset['train']