In [1]:
import pandas as pd
from transformers import AutoTokenizer, TFAutoModel
import datasets
import tensorflow as tf

In [None]:
# Path to the case records file consumed by the embedding pipeline in this notebook.
data_fname = "../data/subset/case_info.json"
# Earlier pandas-based loading, kept commented out for reference only: it reads the
# same file as line-delimited JSON records and joins head_matter and opinion_text
# with a newline into a single "text" column.
# cases_df = pd.read_json(data_fname, lines=True, orient="records")[["id", "head_matter", "opinion_text"]]
# cases_df["text"] = cases_df["head_matter"] + "\n" +  cases_df["opinion_text"]
# cases_df = cases_df[["id", "text"]]

In [None]:
# Load the tokenizer and the TensorFlow model for the "allenai/specter" checkpoint.
model_name = "allenai/specter"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# from_pt=True requests loading the TF model from PyTorch checkpoint weights.
model = TFAutoModel.from_pretrained(model_name, from_pt=True)


In [None]:
# Read the case records with the generic "json" loader, requesting only the "train"
# split (the loader's name for data passed via data_files without a split key).
dataset = datasets.load_dataset("json", data_files=data_fname, split=datasets.splits.Split("train"))

In [None]:
def load_text(examples):
    """Build the combined ``text`` column for a batch of case records.

    Args:
        examples: Batch mapping with parallel ``head_matter`` and
            ``opinion_text`` sequences of strings.

    Returns:
        ``{"text": [...]}`` where each entry is the head matter, a newline,
        and the opinion text of the same row. Rows are paired with ``zip``,
        so the output is as long as the shorter of the two inputs.
    """
    heads = examples['head_matter']
    opinions = examples['opinion_text']
    combined = []
    for head, opinion in zip(heads, opinions):
        combined.append(head + "\n" + opinion)
    return {"text": combined}

# Columns removed from the dataset once the combined "text" column has been built.
exclude_columns = ["jurisdiction_id","court_id","decision_date", "head_matter","opinion_text","citation_ids"]
# Batched map over 15 worker processes; remove_columns drops the columns listed above.
dataset=dataset.map(load_text, batched=True, num_proc=15, remove_columns=exclude_columns)

In [None]:
def load_embeddings(examples):
    """Embed a batch of texts with the notebook-level ``tokenizer`` and ``model``.

    Args:
        examples: Batch mapping whose ``text`` entry is a list of strings.

    Returns:
        ``{"embeddings": array}`` holding, for every text, the vector at
        position 0 along the second axis of the model's first output.
    """
    # Pad to the longest text in the batch and truncate anything past 512 tokens.
    encoded = tokenizer(
        examples["text"],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    model_outputs = model(**encoded)
    # presumably index 0 is the last hidden state and position 0 the [CLS] token —
    # TODO confirm against the transformers model-output documentation.
    first_position = model_outputs[0][:, 0, :]
    return {'embeddings': first_position.numpy()}

In [None]:
# Compute embeddings 256 rows at a time, then persist the result; the saved copy is
# reloaded from this same path further down in the notebook.
dataset=dataset.map(load_embeddings, batched=True, batch_size=256)
dataset.save_to_disk("../data/subset/dataset_specter_embeddings")

In [None]:
# Attach a FAISS index over the "embeddings" column and write it to
# "embeddings.faiss" (a path relative to the current working directory).
dataset.add_faiss_index(column='embeddings')
dataset.save_faiss_index('embeddings', 'embeddings.faiss')

In [None]:
# Reload the embedded dataset from the path it was saved to above, rebinding `dataset`.
dataset = datasets.Dataset.load_from_disk("../data/subset/dataset_specter_embeddings")

In [None]:
# CSV files for the train, validation and test splits of the classifier data.
train_fname = "../data/subset/train_map.csv"
val_fname = "../data/subset/val_map.csv"
test_fname = "../data/subset/test_map.csv"


def fix_nulls_and_types(fname):
    """Clean a CSV file in place and return the cleaned frame.

    Rows containing any missing value are dropped and every column is cast
    to ``int``. The result is written back to ``fname`` (overwriting the
    original file) without an index column.

    Args:
        fname: Path of the CSV file to read and overwrite.

    Returns:
        The cleaned ``pandas.DataFrame``.
    """
    cleaned = pd.read_csv(fname).dropna().astype(int)
    cleaned.to_csv(fname, index=False, index_label=False)
    return cleaned
# fix_nulls_and_types(train_fname)
# fix_nulls_and_types(val_fname)
# fix_nulls_and_types(test_fname)

In [None]:
# Map each split name to its CSV file and load all three with the "csv" loader.
data_files = dict(train=train_fname, validation=val_fname, test=test_fname)
clf_dataset = datasets.load_dataset("csv", data_files=data_files)

In [None]:
def load_clf_embeddings(examples):
    """Attach case and citation embeddings to a batch of classifier rows.

    Reads from the notebook-level ``dataset`` that holds the ``embeddings``
    column.

    Args:
        examples: Batch mapping with ``id`` and ``citation`` lists.

    Returns:
        Mapping with ``case_embedding`` and ``citation_embedding`` lists
        aligned with the input batch.
    """
    # NOTE(review): indexing `dataset` with a list of integers selects rows by
    # position, so this assumes the "id" / "citation" values are row positions in
    # `dataset` rather than values of a case-id column — confirm.
    case_rows = dataset[examples["id"]]
    citation_rows = dataset[examples["citation"]]
    return {
        "case_embedding": case_rows["embeddings"],
        "citation_embedding": citation_rows["embeddings"],
    }

In [None]:
# Look up embeddings for every split in batches across 15 worker processes, then
# persist the result; it is reloaded from this path before training.
clf_dataset = clf_dataset.map(load_clf_embeddings, batched=True, num_proc=15)
clf_dataset.save_to_disk("../data/subset/clf_dataset")

In [2]:
def generator_from_dataset(dataset):
    """Return a zero-argument generator function over ``dataset``.

    Args:
        dataset: Iterable of row mappings with ``case_embedding``,
            ``citation_embedding`` and ``label`` entries.

    Returns:
        A callable that, each time it is called, yields
        ``((case_embedding, citation_embedding), label)`` for every row.
    """
    def _gen():
        for row in dataset:
            pair = (row["case_embedding"], row["citation_embedding"])
            target = row['label']
            yield pair, target
    return _gen

def tf_dataset_from_dataset(dataset):
    """Wrap ``dataset`` in a generator-backed ``tf.data.Dataset``.

    Args:
        dataset: Sized iterable of rows accepted by ``generator_from_dataset``.

    Returns:
        A ``tf.data.Dataset`` yielding ``((case, citation), label)`` elements,
        with its cardinality asserted to be ``len(dataset)``.
    """
    # Both feature vectors are declared as 768-element float32 tensors.
    feature_signature = (
        tf.TensorSpec(shape=(768,), dtype=tf.float32),
        tf.TensorSpec(shape=(768,), dtype=tf.float32),
    )
    # shape=None leaves the label's rank unspecified; this is the same value the
    # parenthesised `(None)` evaluates to (it is not a 1-tuple).
    label_signature = tf.TensorSpec(shape=None, dtype=tf.int32)

    generated = tf.data.Dataset.from_generator(
        generator_from_dataset(dataset),
        output_signature=(feature_signature, label_signature),
    )
    # A generator-backed dataset has no known length, so declare it explicitly.
    return generated.apply(tf.data.experimental.assert_cardinality(len(dataset)))


In [3]:
def tf_dataset_from_datasetv2(dataset):
    """Build a ``tf.data.Dataset`` by slicing whole columns of ``dataset``.

    Args:
        dataset: Object supporting column access for ``case_embedding``,
            ``citation_embedding`` and ``label``.

    Returns:
        A ``tf.data.Dataset`` of ``((case, citation), label)`` elements created
        with ``from_tensor_slices``.
    """
    case_vectors = dataset['case_embedding']
    citation_vectors = dataset['citation_embedding']
    targets = dataset['label']
    return tf.data.Dataset.from_tensor_slices(((case_vectors, citation_vectors), targets))

In [4]:
# Reload the classifier dataset saved earlier, passing keep_in_memory=True.
clf_dataset = datasets.DatasetDict.load_from_disk("../data/subset/clf_dataset", keep_in_memory=True)
# Expose only these three columns, formatted for TensorFlow, when rows are read.
clf_dataset.set_format(type='tensorflow', columns=['case_embedding', 'citation_embedding', 'label'])

# Per-split handles consumed by the tf.data conversion that follows.
train_dataset = clf_dataset['train']
val_dataset = clf_dataset['validation']
test_dataset = clf_dataset['test']

In [5]:
# Wrap each split in a generator-backed tf.data.Dataset.
# NOTE(review): each name is rebound from its original split object to the converted
# dataset, so this cell is presumably not safe to re-run on its own — re-run the cell
# that assigns the splits from clf_dataset first.
train_dataset = tf_dataset_from_dataset(train_dataset)
val_dataset = tf_dataset_from_dataset(val_dataset)
test_dataset = tf_dataset_from_dataset(test_dataset)

In [11]:
# Pair classifier: two 768-dimensional float inputs are passed through one shared
# Sequential stack, the two 64-unit outputs are concatenated, and a single sigmoid
# unit produces the score.
embedding_dim = 768
dropout_rate = 0.3

case_input = tf.keras.layers.Input(
    shape=(embedding_dim,), dtype=tf.float32, name="case_input"
)
citation_input = tf.keras.layers.Input(
    shape=(embedding_dim,), dtype=tf.float32, name="citation_input"
)

# The same stack object is called on both inputs, so its weights are shared.
shared_stack = tf.keras.models.Sequential(
    [
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(64, activation="relu"),
    ],
    name="shared_stack",
)

case_representation = shared_stack(case_input)
citation_representation = shared_stack(citation_input)

# Join both representations, apply dropout, then score with one sigmoid unit.
pair_features = tf.keras.layers.Concatenate()(
    [case_representation, citation_representation]
)
concatenated = tf.keras.layers.Dropout(dropout_rate)(pair_features)
output = tf.keras.layers.Dense(1, activation="sigmoid")(concatenated)

clf_model = tf.keras.models.Model(
    inputs=[case_input, citation_input],
    outputs=[output],
    name="clf_model",
)

In [12]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections).
clf_model.summary()

Model: "clf_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
case_input (InputLayer)         [(None, 768)]        0                                            
__________________________________________________________________________________________________
citation_input (InputLayer)     [(None, 768)]        0                                            
__________________________________________________________________________________________________
shared_stack (Sequential)       (None, 64)           566208      case_input[0][0]                 
                                                                 citation_input[0][0]             
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 128)          0           shared_stack[0][0]       

In [13]:
# Configure training: "bce" loss (Keras alias for binary cross-entropy), the Adam
# optimizer, and binary accuracy as the reported metric.
clf_model.compile(loss="bce", metrics=["binary_accuracy"],  optimizer="adam")

In [14]:
# Number of examples per batch for all three pipelines.
batch_size = 1024

# Shuffle BEFORE repeat: with repeat() first, the shuffle buffer mixes elements from
# consecutive passes over the data, so one pass can contain duplicates while other
# examples are skipped. Shuffling first keeps every pass a full permutation-like
# sweep (within the buffer's reach) before the next repetition begins.
# All three datasets repeat indefinitely, so consumers must pass explicit step counts
# (as the fit call does with steps_per_epoch / validation_steps).
train_dataset = (
    train_dataset
    .shuffle(batch_size * 4)
    .repeat()
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
val_dataset = val_dataset.repeat().batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
test_dataset = test_dataset.repeat().batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

In [19]:
# Training callbacks, passed to clf_model.fit.
callbacks = [
    # Stop after 5 epochs without improvement and restore the best weights seen.
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    # Save the model to this path only when the monitored value improves.
    tf.keras.callbacks.ModelCheckpoint("../models/specter_matching_v1", save_best_only=True),
    # Append per-epoch metrics to a CSV log.
    tf.keras.callbacks.CSVLogger("../models/logs/specter_matching_v1.csv"),
    # Lower the learning rate after 3 epochs without improvement, logging the change.
    tf.keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1),
]
             

In [None]:
# Train for up to 15 epochs of 250 training steps each, validating on 100 steps per
# epoch. Explicit step counts are given because the input datasets were built with
# repeat() and therefore never end on their own.
history = clf_model.fit(
    train_dataset,
    steps_per_epoch=250,
    validation_data=val_dataset,
    validation_steps=100,
    verbose=1,
    epochs=15,
    callbacks=callbacks
)

Epoch 1/15
 15/250 [>.............................] - ETA: 21:05 - loss: 0.6827 - binary_accuracy: 0.5527