In [1]:
%matplotlib inline
import os
import pandas as pd
from transformers import AutoTokenizer, TFAutoModel
import datasets
import tensorflow as tf
from sklearn.preprocessing import normalize
import faiss
import numpy as np
from functools import partial

from tensorflow.keras import mixed_precision

import numpy as np
import json

from matplotlib import pylab as plt
import seaborn as sns

In [2]:
# Set the global Keras dtype policy to 'mixed_float16' before any layers are built,
# so that layers created later in the notebook pick it up.
mixed_precision.set_global_policy('mixed_float16')

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-PCIE-40GB, compute capability 8.0


In [3]:
## Declare file paths

# dataset directory
base_data_dir = "../data/subset"

# base dataset (JSON file read by load_embeddings_dataset when no cached embeddings exist)
data_fname = f"{base_data_dir}/case_info.json"

# embeddings dataset
embeddings_model_name="allenai/specter"
# The part after the last "_" ("cls" here) is what load_embeddings_dataset passes
# to load_embeddings to choose between the "pooled" output and the first-token slice.
embedding_type = "specter_cls"
embedding_dataset_dir = f"{base_data_dir}/{embedding_type}_embeddings_dataset"

# classification dataset (cache directory plus the CSV split files it is built from)
clf_dataset_dir = f"{base_data_dir}/{embedding_type}_clf_dataset"
data_files = {
    "train": f"{base_data_dir}/train_map.csv", 
    "validation": f"{base_data_dir}/val_map.csv", 
    "test": f"{base_data_dir}/test_map.csv",}

# model: checkpoint location and CSV training log
models_dir = "../models"
model_name = "specter_cls_clf_model"
model_checkpoint = f"{models_dir}/{model_name}"
model_log_fname = f"{models_dir}/logs/{model_name}.csv"

In [4]:
def load_text(examples):
    """Build the ``text`` feature for a batch of examples.

    Each entry is the example's ``head_matter`` followed by a newline and its
    ``opinion_text``. Returns a dict with the single key ``"text"``.
    """
    texts = []
    for head_matter, opinion_text in zip(examples['head_matter'], examples['opinion_text']):
        texts.append(head_matter + "\n" + opinion_text)
    return {"text": texts}

def load_embeddings(examples, tokenizer, model, embedding_type):
    """Tokenize a batch of texts and return their normalized embeddings.

    Args:
        examples: batch with a ``"text"`` entry holding the strings to embed.
        tokenizer: callable tokenizer; invoked with padding, truncation and
            ``max_length=512``, returning TensorFlow tensors.
        model: callable model, invoked with the tokenizer output as keyword arguments.
        embedding_type: ``"pooled"`` selects the second model output; any other
            value selects position 0 along axis 1 of the first output.

    Returns:
        ``{"embeddings": array}``, where the array has been passed through ``normalize``.
    """
    encoded = tokenizer(
        examples["text"],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    outputs = model(**encoded)
    if embedding_type == "pooled":
        vectors = outputs[1]
    else:
        # first position along the sequence axis of the first model output
        vectors = outputs[0][:, 0, :]
    return {'embeddings': normalize(vectors.numpy())}


def load_embeddings_dataset(
    dataset_dir, embedding_type, embedding_model="allenai/specter",
    num_proc=15, batch_size=256, faiss_device=0, keep_in_memory=False):
    """Return the embeddings dataset with a FAISS index on its ``embeddings`` column.

    If ``{dataset_dir}/state.json`` exists, the dataset is loaded from disk and
    its saved index is loaded, or built and saved when the index file is missing.
    Otherwise the dataset is built from the module-level ``data_fname`` JSON file:
    the text column is created with ``load_text``, embeddings are computed with
    ``load_embeddings``, and both the dataset and a new index are written to
    ``dataset_dir``.

    Args:
        dataset_dir: directory holding (or receiving) the dataset and index file.
        embedding_type: its suffix after the last ``"_"`` is forwarded to
            ``load_embeddings`` as the embedding type.
        embedding_model: pretrained model name for the tokenizer and model.
        num_proc: worker processes used for the text-building map.
        batch_size: batch size used for the embedding map.
        faiss_device: device passed to ``add_faiss_index``.
        keep_in_memory: forwarded to ``load_from_disk`` when loading a cached dataset.
    """
    index_path = f"{dataset_dir}/embeddings.faiss"

    if not os.path.isfile(f"{dataset_dir}/state.json"):
        # Nothing cached: compute embeddings from the raw JSON data.
        print("No existing embeddings found. Creating and saving to disk ...")
        tokenizer = AutoTokenizer.from_pretrained(embedding_model)
        model = TFAutoModel.from_pretrained(embedding_model, from_pt=True)

        print("Loading dataset and text column ...")
        dataset = datasets.load_dataset("json", data_files=data_fname, split=datasets.splits.Split("train"))
        dropped_columns = ["jurisdiction_id","court_id","decision_date", "head_matter","opinion_text","citation_ids"]
        dataset = dataset.map(load_text, batched=True, num_proc=num_proc, remove_columns=dropped_columns)
        print("Loading embeddings ...")
        output_kind = embedding_type.split("_")[-1]
        embedder = partial(load_embeddings, tokenizer=tokenizer, model=model, embedding_type=output_kind)
        dataset = dataset.map(embedder, batched=True, batch_size=batch_size)
        print(f"Saving Dataset to disk at {dataset_dir}")
        dataset.save_to_disk(dataset_dir)
        print("Creating new fiass index and saving to disk ...")
        dataset.add_faiss_index(column="embeddings", device=faiss_device, metric_type=faiss.METRIC_INNER_PRODUCT)
        dataset.save_faiss_index("embeddings", index_path)
        return dataset

    print("Found existing embeddings. loading from disk ...")
    dataset = datasets.Dataset.load_from_disk(dataset_dir, keep_in_memory=keep_in_memory)
    if os.path.isfile(index_path):
        print("Found existing fiass index. loading from disk ...")
        dataset.load_faiss_index("embeddings", index_path)
    else:
        print("No fiass index found. creating and saving new index to disk ...")
        dataset.add_faiss_index(column="embeddings", device=faiss_device, metric_type=faiss.METRIC_INNER_PRODUCT)
        dataset.save_faiss_index("embeddings", index_path)
    return dataset

In [6]:
# Load (or build) the embeddings dataset together with its FAISS index.
# load_clf_embeddings reads this module-level `embeddings_dataset` by name,
# so this cell has to run before the classification dataset is built.
embeddings_dataset = load_embeddings_dataset(
    embedding_dataset_dir,
    embedding_type,
    embedding_model=embeddings_model_name)

Found existing embeddings. loading from disk ...
Found existing fiass index. loading from disk ...


In [5]:
def load_clf_embeddings(examples):
    """Attach case and citation embeddings to a batch of classification examples.

    The ``id`` and ``citation`` values of the batch are used to index the
    module-level ``embeddings_dataset``; the ``embeddings`` of the selected
    rows are returned as ``case_embedding`` and ``citation_embedding``.
    """
    case_rows = embeddings_dataset[examples["id"]]
    citation_rows = embeddings_dataset[examples["citation"]]
    return {
        "case_embedding": case_rows["embeddings"],
        "citation_embedding": citation_rows["embeddings"],
    }


def generator_from_dataset(dataset):
    """Wrap a Hugging Face dataset in a zero-argument generator function.

    The returned callable iterates over ``dataset`` and yields
    ``((case_embedding, citation_embedding), label)`` for every item.
    """

    def _gen():
        yield from (
            ((row["case_embedding"], row["citation_embedding"]), row['label'])
            for row in dataset
        )

    return _gen


def tf_dataset_from_dataset(dataset):
    """Build a ``tf.data.Dataset`` that streams items from a Hugging Face dataset.

    Elements are ``((case_embedding, citation_embedding), label)``: both
    embeddings are declared as float32 vectors of length 768 and the label as
    int32. The dataset's cardinality is asserted to equal ``len(dataset)``.
    """
    feature_signature = (
        tf.TensorSpec(shape=(768,), dtype=tf.float32),
        tf.TensorSpec(shape=(768,), dtype=tf.float32),
    )
    label_signature = tf.TensorSpec(shape=(None), dtype=tf.int32)
    tfdataset = tf.data.Dataset.from_generator(
        generator_from_dataset(dataset),
        output_signature=(feature_signature, label_signature),
    )
    return tfdataset.apply(tf.data.experimental.assert_cardinality(len(dataset)))

def shuffle_batch_repeat(dataset, batch_size=64, shuffle=False):
    """Repeat a tensorflow dataset indefinitely, optionally shuffle it, then batch and prefetch.

    When ``shuffle`` is true the shuffle buffer holds ``batch_size * 4`` elements.
    The order of operations is repeat -> (shuffle) -> batch -> prefetch.
    """
    repeated = dataset.repeat()
    pipeline = repeated.shuffle(batch_size*4) if shuffle else repeated
    return pipeline.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)


def load_clf_datasets(
    dataset_dir, data_files, batch_size=64,
    embeddings_dataset=None, num_proc=15):
    """Create batched tensorflow datasets of embedding pairs and labels for classification.

    If ``{dataset_dir}/dataset_dict.json`` exists, the cached dataset dict is
    loaded from disk. Otherwise the CSV ``data_files`` are loaded, embeddings are
    attached with ``load_clf_embeddings`` and the result is saved to ``dataset_dir``.

    Args:
        dataset_dir: cache directory of the classification dataset.
        data_files: mapping with "train", "validation" and "test" CSV paths.
        batch_size: batch size of the returned tensorflow datasets.
        embeddings_dataset: must be non-None to allow building a new dataset.
            NOTE(review): ``load_clf_embeddings`` reads the module-level
            ``embeddings_dataset``, not this argument; here it only gates the build.
        num_proc: worker processes used when mapping embeddings.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``; all repeat indefinitely
        and only the training dataset is shuffled.

    Raises:
        ValueError: if there is no cached dataset and ``embeddings_dataset`` is None.
    """
    has_cached_dataset = os.path.isfile(f"{dataset_dir}/dataset_dict.json")
    if not has_cached_dataset and embeddings_dataset is None:
        # Previously this fell through both branches and failed later on an
        # unbound `train_dataset`; fail early with an actionable message instead.
        raise ValueError(
            f"No cached classification dataset found at {dataset_dir!r} and no "
            "embeddings_dataset was given to build one."
        )

    if has_cached_dataset:
        print("Found existing dataset dict. loading from disk ...")
        dataset = datasets.DatasetDict.load_from_disk(dataset_dir, keep_in_memory=True)
    else:
        print("Found data files and embeddings dataset. Creating new clf dataset and saving to disk ...")
        dataset = datasets.load_dataset("csv", data_files=data_files)
        print("Mapping embeddings to clf dataset ...")
        dataset = dataset.map(load_clf_embeddings, batched=True, num_proc=num_proc)
        print(f"Saving Dataset to disk at {dataset_dir}")
        dataset.save_to_disk(dataset_dir)
    dataset.set_format(type='tensorflow', columns=['case_embedding', 'citation_embedding', 'label'])

    print("loading tensorflow datasets ...")
    train_dataset = tf_dataset_from_dataset(dataset['train'])
    val_dataset = tf_dataset_from_dataset(dataset['validation'])
    test_dataset = tf_dataset_from_dataset(dataset['test'])

    print("batching and shuffling tensorflow datasets ... ")
    train_dataset = shuffle_batch_repeat(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = shuffle_batch_repeat(val_dataset, batch_size=batch_size)
    test_dataset = shuffle_batch_repeat(test_dataset, batch_size=batch_size)

    return train_dataset, val_dataset, test_dataset


In [7]:
def load_model(model_name, hidden_dim=256, dropout=0.2):
    """Create and compile a classification model for paired embedding inputs.

    Two 768-dimensional float32 inputs (case and citation) each pass through
    their own BatchNormalization -> Dense(relu) -> Dropout -> Dense(relu) branch.
    The prediction is the average of (a) the normalized dot product of the two
    branch outputs and (b) a sigmoid unit on top of three
    BatchNormalization/Dense/Dropout stacks applied to the concatenated branches.

    Args:
        model_name: name given to the Keras model.
        hidden_dim: width of every hidden Dense layer.
        dropout: rate used by every Dropout layer.

    Returns:
        The compiled Keras model (binary cross-entropy loss, binary accuracy
        metric, "adam" optimizer). The model summary is printed as a side effect.
    """
    
    case_input = tf.keras.layers.Input(shape=(768,), dtype=tf.float32, name="case_input")
    citation_input = tf.keras.layers.Input(shape=(768,), dtype=tf.float32, name="citation_input")

    # Each branch gets its own layer instances, so no weights are shared between them.
    case_representation = tf.keras.layers.BatchNormalization()(case_input)
    citation_representation = tf.keras.layers.BatchNormalization()(citation_input)

    case_representation = tf.keras.layers.Dense(hidden_dim, activation="relu")(case_representation)

    citation_representation = tf.keras.layers.Dense(hidden_dim, activation="relu")(citation_representation)

    case_representation = tf.keras.layers.Dropout(dropout,)(case_representation)
    citation_representation = tf.keras.layers.Dropout(dropout,)(citation_representation)

    case_representation = tf.keras.layers.Dense(hidden_dim, activation="relu")(case_representation)
    citation_representation = tf.keras.layers.Dense(hidden_dim, activation="relu")(citation_representation)
    
    # Similarity head: dot product over the feature axis with normalize=True.
    sims = tf.keras.layers.Dot(axes=1, normalize=True)([case_representation, citation_representation])
    concatenated = tf.keras.layers.Concatenate()([case_representation, citation_representation])

    # NOTE: despite its name, every call builds a brand-new Sequential, so the
    # three applications below do not share weights with each other.
    def shared_stack(prev_input):
        return tf.keras.models.Sequential(
            [tf.keras.layers.BatchNormalization(),
             tf.keras.layers.Dense(hidden_dim, activation="relu"),
             tf.keras.layers.Dropout(dropout)
            ])(prev_input) 
    concatenated = shared_stack(concatenated)
    concatenated = shared_stack(concatenated)
    concatenated = shared_stack(concatenated)
    # The sigmoid unit is pinned to float32 explicitly.
    output = tf.keras.layers.Dense(1, dtype=tf.float32, activation="sigmoid")(concatenated)
    # Final prediction: element-wise average of the similarity and the sigmoid output.
    output = tf.keras.layers.Average()([sims,output])
    clf_model = tf.keras.models.Model(inputs=[case_input, citation_input], outputs=[output], name=model_name)
    loss= tf.keras.losses.BinaryCrossentropy()
    acc = tf.keras.metrics.BinaryAccuracy()
    clf_model.compile(loss=loss, metrics=[acc],  optimizer="adam")
    clf_model.summary()
    return clf_model

In [8]:
# Build and compile the classifier; load_model also prints the model summary shown below.
clf_model = load_model(model_name)

Model: "specter_cls_clf_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
case_input (InputLayer)         [(None, 768)]        0                                            
__________________________________________________________________________________________________
citation_input (InputLayer)     [(None, 768)]        0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 768)          3072        case_input[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 768)          3072        citation_input[0][0]             
______________________________________________________________________________

In [9]:
# Make sure the directory for the CSV training log exists before training starts,
# so CSVLogger can open `model_log_fname` on a fresh checkout.
os.makedirs(os.path.dirname(model_log_fname), exist_ok=True)

# No `monitor` argument is passed, so every callback uses its default monitored quantity.
callbacks = [
    # stop after 5 epochs without improvement and roll back to the best weights
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    # save the model to `model_checkpoint` only when it improves
    tf.keras.callbacks.ModelCheckpoint(model_checkpoint, save_best_only=True),
    # per-epoch metrics written to `model_log_fname`
    tf.keras.callbacks.CSVLogger(model_log_fname),
    # lower the learning rate after 3 epochs without improvement
    tf.keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1),
]
             

In [10]:

# Build the train / validation / test tf.data pipelines with batches of 128.
train_dataset, val_dataset, test_dataset = load_clf_datasets(
    clf_dataset_dir, data_files, batch_size=128, embeddings_dataset=embeddings_dataset)

Found existing dataset dict. loading from disk ...
loading tensorflow datasets ...
batching and shuffling tensorflow datasets ... 


In [11]:
# shuffle_batch_repeat makes the datasets repeat indefinitely, so the length of an
# epoch and of each validation pass is set explicitly via the *_steps arguments.
history = clf_model.fit(
    train_dataset,
    steps_per_epoch=1000,
    validation_data=val_dataset,
    validation_steps=100,
    verbose=1,
    epochs=15,
    callbacks=callbacks
)
# Saves the model held in memory after training to the same path that the
# ModelCheckpoint callback writes to.
clf_model.save(model_checkpoint)

Epoch 1/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 2/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 3/15
Epoch 4/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 5/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 6/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 7/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 8/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 9/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets
Epoch 10/15
Epoch 11/15
Epoch 12/15

Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 13/15
Epoch 14/15
INFO:tensorflow:Assets written to: ../models/specter_cls_clf_model/assets


In [None]:
# Second handle on the embeddings dataset, this time loaded with keep_in_memory=True;
# it is the lookup table and FAISS search target for the retrieval cells below.
database = load_embeddings_dataset(
    embedding_dataset_dir,
    embedding_type,
    embedding_model=embeddings_model_name, keep_in_memory=True)

# Return the "embeddings" and "id" columns as numpy values.
database.set_format("numpy", columns=["embeddings", "id"], dtype=np.float32)

In [None]:
# Read the split files through the shared `data_files` config instead of repeating
# the literal paths, so a change to `base_data_dir` is picked up here as well.
train_df = pd.read_csv(data_files["train"])
val_df = pd.read_csv(data_files["validation"])
test_df = pd.read_csv(data_files["test"])

In [None]:
# Collapse each split to one row per "id", gathering that id's "citation" values into a list.
# The frames are reassigned under the same names, so this cell is meant to run once,
# right after the read_csv cell.
train_df = train_df.groupby("id").agg({"citation": list})
val_df = val_df.groupby("id").agg({"citation": list})
test_df = test_df.groupby("id").agg({"citation": list})

In [None]:
def retrieve_top_k_preds(df, k=50, key=None):
    """Retrieve the ``k`` nearest database examples for every row of ``df``.

    The index values of ``df`` select the query embeddings from the module-level
    ``database``, which is then searched through its "embeddings" index.

    Args:
        df: frame whose index values are used to index ``database``.
        k: number of neighbours requested per query.
        key: if truthy, only this field of each retrieved sample is returned.

    Returns:
        One entry per query: the retrieved samples, or their ``key`` field.

    NOTE(review): the queries are themselves rows of ``database``, so a query
    may come back as one of its own neighbours. Confirm whether that is intended.
    """
    query_ids = df.index.tolist()
    query_embeddings = database[query_ids]['embeddings']
    _, neighbours = database.get_nearest_examples_batch('embeddings', query_embeddings, k=k)
    if not key:
        return neighbours
    return [neighbour[key] for neighbour in neighbours]

In [None]:
# "id" values of the 50 (default k) nearest neighbours for every validation / test query.
val_preds = retrieve_top_k_preds(val_df, key="id")
test_preds = retrieve_top_k_preds(test_df, key="id")

In [None]:

def apk(actual, predicted, k=10):
    """
    #https://github.com/benhamner/Metrics/blob/master/Python/ml_metrics/average_precision.py
    Computes the average precision at k.

    Only the first ``k`` predictions are considered. A prediction counts as a
    hit when it is in ``actual`` and has not already appeared earlier in the
    list. The sum of precisions at the hit ranks is divided by
    ``min(len(actual), k)``; an empty ``actual`` yields 0.0.
    """
    top_k = predicted[:k] if len(predicted) > k else predicted

    hits = 0.0
    precision_sum = 0.0
    for rank, candidate in enumerate(top_k, start=1):
        if candidate not in actual:
            continue
        if candidate in top_k[:rank - 1]:
            # repeated prediction: only its first occurrence counts
            continue
        hits += 1.0
        precision_sum += hits / rank

    return precision_sum / min(len(actual), k) if actual else 0.0

def mapk(actual, predicted, k=10):
    """
    #https://github.com/benhamner/Metrics/blob/master/Python/ml_metrics/average_precision.py
    Computes the mean average precision at k.

    ``actual`` and ``predicted`` are parallel sequences of per-query lists; the
    result is the mean of ``apk`` over the zipped pairs.
    """
    per_query = []
    for truth, ranked in zip(actual, predicted):
        per_query.append(apk(truth, ranked, k))
    return np.mean(per_query)

In [None]:
import tqdm

In [None]:
# Load the index -> citation-id mapping; JSON keys are strings, so both sides are
# cast to int. The context manager closes the file handle once it has been read.
with open("../data/subset/idx2cite.json") as idx2cite_file:
    idx2cite = {int(k): int(v) for k, v in json.load(idx2cite_file).items()}


def tocite(x):
    """Translate each value of ``x`` through ``idx2cite`` (``None`` where a value has no entry)."""
    return [idx2cite.get(y) for y in x]


val_df.loc[: , "cite_id"] = val_df["citation"].map(tocite)
test_df.loc[: , "cite_id"] = test_df["citation"].map(tocite)

def load_map(df, preds):
    """Compute MAP@k of ``preds`` against ``df.cite_id`` for k = 5, 10, 15, 20, 25.

    Returns the scores as a list, in order of increasing k.
    """
    truth = df.cite_id.tolist()
    return [mapk(truth, preds, k) for k in tqdm.tnrange(5, 30, 5)]


In [None]:
# Score each split against its own retrieved predictions. The test split was
# previously scored against `val_preds`, which compared unrelated queries.
val_scores = load_map(val_df, val_preds)
test_scores = load_map(test_df, test_preds)


In [None]:
# Display the test-split MAP@k values (k = 5, 10, 15, 20, 25).
test_scores

In [None]:
def plot_scores(val_scores, test_scores, col_names = ["val_map", "test_map"]):
    """Draw validation and test MAP scores as lines over top-k.

    The two score sequences are zipped into a two-column frame indexed by
    ``range(5, 30, 5)``, so they need one value per k in that range. The figure
    is shown immediately; nothing is returned.
    """
    score_rows = list(zip(val_scores, test_scores))
    plot_df = pd.DataFrame(score_rows, columns=col_names, index=range(5, 30, 5))

    plt.figure(figsize=(20,10))
    sns.lineplot(data=plot_df, markers=True)
    plt.xlabel("top-k")
    plt.title("mean average precision(MAP) at various k's recommended samples")
    sns.despine()
    plt.show()

In [None]:
# Plot validation vs. test MAP@k.
plot_scores(val_scores, test_scores)

In [None]:
# Load the citation-id -> index mapping; JSON keys are strings, so both sides are
# cast to int. The context manager closes the file handle once it has been read.
with open("../data/subset/cite2idx.json") as cite2idx_file:
    cite2idx = {int(k): int(v) for k, v in json.load(cite2idx_file).items()}


def toidx(x):
    """Translate each value of ``x`` through ``cite2idx`` (``None`` where a value has no entry)."""
    return [cite2idx.get(y) for y in x]


def get_sorted_preds(model, df, preds):
    """Re-rank retrieved candidates for every query with the classifier.

    Args:
        model: model whose ``predict`` takes ``(query_embeddings, result_embeddings)``
            and returns one score per pair.
        df: frame whose index values identify the query rows in ``database``.
        preds: per-query lists of retrieved ids, in the same order as ``df``;
            they are translated to database indices with ``toidx``.

    Returns:
        A list with one numpy array per query: the candidate indices ordered
        from highest to lowest model score.
    """
    ranked_preds = []
    candidate_lists = [toidx(candidates) for candidates in preds]
    for idx, candidates in tqdm.tqdm_notebook(zip(df.index, candidate_lists), total=len(df)):
        num_results = len(candidates)
        if num_results == 0:
            # nothing to score; avoids calling predict with batch_size=0
            ranked_preds.append(np.array(candidates))
            continue
        result_embeddings = database[candidates]['embeddings']
        # repeat the query embedding so it pairs up with every candidate
        query_embeddings = np.array([database[idx]['embeddings']]*num_results)
        scores = model.predict((query_embeddings, result_embeddings), batch_size=num_results).flatten()
        # np.argsort is ascending; negate the scores so the best candidate comes first,
        # which is the order apk/mapk expect.
        best_first = np.argsort(-scores)
        ranked_preds.append(np.array(candidates)[best_first])
    return ranked_preds
        
        

In [None]:
# Display the checkpoint path that the next cell loads from.
model_checkpoint

In [None]:
# Reload the classifier from the saved checkpoint, replacing the in-memory `clf_model`.
clf_model = tf.keras.models.load_model(model_checkpoint)

In [None]:
# Re-rank the retrieved candidates of the first 100 validation queries with the classifier.
# `get_top_filtered_k` is not defined in this notebook; `get_sorted_preds` is the
# re-ranking function with this (model, df, preds) signature.
sorted_preds = get_sorted_preds(clf_model, val_df.head(100), val_preds[:100])

In [None]:
# MAP@k of the re-ranked candidates for k = 1..49 on the first 100 validation queries.
# The re-ranked lists live in `sorted_preds`; `filtered_preds` was never defined.
scores = []
for i in tqdm.tnrange(1, 50):
    scores.append(mapk(val_df.head(100).citation, sorted_preds, k=i))

In [None]:
# Display the MAP@k values computed in the previous cell (k = 1..49).
scores

In [None]:
# Peek at the first ten validation rows.
val_df.head(10)

In [None]:
def precision_recall_k(y_true, y_pred, k=10):
    """Mean precision and recall at ``k`` over queries with at least one hit.

    For every query the first ``k`` predictions are de-duplicated into a set;
    precision is hits / number of distinct predictions kept and recall is
    hits / ``len(y_t)``.

    Args:
        y_true: per-query collections of relevant ids.
        y_pred: per-query ranked predictions, parallel to ``y_true``.
        k: number of leading predictions considered per query.

    Returns:
        ``(mean_precision, mean_recall)``, or ``(0.0, 0.0)`` when no query has
        a hit (previously NaN with a numpy warning).

    NOTE(review): queries without any relevant retrieved item are skipped rather
    than counted as zero, which raises both averages. Confirm this is intended.
    """
    precisions = []
    recalls = []
    for y_t, y_p in zip(y_true, y_pred):
        top_k = set(y_p[:k])
        relevant_retrieved = set(y_t).intersection(top_k)
        if relevant_retrieved:
            precisions.append(len(relevant_retrieved) / len(top_k))
            recalls.append(len(relevant_retrieved) / len(y_t))
    if not precisions:
        return 0.0, 0.0
    return np.mean(precisions), np.mean(recalls)