# BERT for Supreme Court Predictions

This notebook fine-tunes BERT to predict how individual justices vote on a case, as well as the overall outcome of the case.

Source:

- [Fine-tune a pretrained model](https://huggingface.co/docs/transformers/training) (Hugging Face)

First, we install the required packages, mount Google Drive, and delete any model checkpoints that previous runs left behind in Colab.

In [1]:
# Install packages
!pip install datasets
!pip install transformers==4.28.0

# Mount Google Drive and remove model checkpoints from previous runs
import os
import shutil
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
try:
    shutil.rmtree('../content/bert-finetuned-sem_eval-english')
except FileNotFoundError:
    print('No previous models to remove')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Mounted at /content/gdrive


In [2]:
# Import libraries
import torch
import random
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import TrainingArguments, Trainer, AutoModelForSequenceClassification, EvalPrediction, AutoTokenizer

Next, we define a few helper functions that are used later on.

In [3]:
def preprocess_data(examples):
    """
    Takes a batch of texts and encodes them as input_ids so that BERT
    can process them
    
    Input:
        - examples: Raw data in text form
    
    Output:
        - encoding: Encoded dataset
    """
    # Take a batch of texts and encode them
    text = examples["new_text"]
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
    
    # Add labels
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    
    # Fill numpy array with the batch of labels
    labels_matrix = np.zeros((len(text), len(labels)))
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]
    encoding["labels"] = labels_matrix.tolist()
    
    return encoding



def multi_label_metrics(predictions, labels, threshold=0.5):
    """
    Computes the f1, roc, and accuracy while training.

    Inputs:
        - predictions: Predictions made by BERT (to be transformed)
        - labels: Labels against which to check the predictions
        - threshold (float): threshold to use for the prediction
    
    Output:
        - metrics (dict): f1, roc_auc, and accuracy
    """
    # Compute predictions from the model
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1

    # Compute f1, roc_auc, and accuracy (all on the thresholded predictions)
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)

    # Build dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics



def compute_metrics(p: EvalPrediction):
    """
    Wrapper to compute the metrics

    Input:
        - p (EvalPrediction): model output
    
    Output:
        - result (dict): f1, roc_auc, and accuracy
    """
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result



def select_random_chars(string, max_length=3000):
    """
    BERT cannot process more than 512 tokens so we narrow down
    the text per case to a random window of 3000 characters

    Inputs:
        - string (str): the text of an utterance
        - max_length (int): the maximal length of the new string
    
    Output:
        - selected_chars (str): new, truncated text of an utterance
    """
    string_length = len(string)

    # Check if the string is long enough to select 3000 characters
    if string_length <= max_length:
        return string

    # Generate a random starting index
    start_index = random.randint(0, string_length - max_length)

    # Select the 3000 characters from the string
    selected_chars = string[start_index:start_index + max_length]
    
    return selected_chars
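
As a rough illustration of what `multi_label_metrics` does before it scores anything, the sketch below turns raw logits into 0/1 predictions with a sigmoid and a threshold. It uses only the standard library; the helper names and the toy logits are hypothetical and are not model output.

```python
import math


def sigmoid(x):
    """Logistic function: maps a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def threshold_logits(logits, threshold=0.5):
    """Return 1 where sigmoid(logit) >= threshold, else 0."""
    return [1 if sigmoid(x) >= threshold else 0 for x in logits]


# Toy logits and labels (hypothetical values, for illustration only)
toy_logits = [-2.0, -0.1, 0.0, 0.3, 4.0]
toy_labels = [0, 1, 1, 1, 1]

toy_preds = threshold_logits(toy_logits)
toy_accuracy = sum(p == y for p, y in zip(toy_preds, toy_labels)) / len(toy_labels)
print(toy_preds, toy_accuracy)
```

With a threshold of 0.5, a prediction is 1 exactly when the logit is non-negative, since `sigmoid(0)` is 0.5.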

Before getting to the model, let's load the data and wrangle it into the right format.

In [4]:
# Use the GPU for what comes next, if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"torch.device: {device}")

# Get list of files in folder
folder_path = "gdrive/MyDrive/data/"
file_list = os.listdir(folder_path)


# Concatenate the data
df_list = []
for file in file_list:
    # Check if file is a CSV and contains the data we want
    if file.endswith('.csv') and file.startswith('utterances_clean'):
        df = pd.read_csv(os.path.join(folder_path, file))
        df_list.append(df)
df = pd.concat(df_list, axis=0, ignore_index=True)

# For each utterance, add columns specifying who is addressed
df = df.merge(
    df[['speaker', 'side']].drop_duplicates(ignore_index=True),
    how='left',
    left_on='speaker_addressed',
    right_on='speaker',
    suffixes=('', '_addressed')).drop('speaker_addressed', axis=1)

# Translate into natural text (e.g., 'J' --> 'Justice', 'j__clarence_thomas' --> 'Clarence Thomas')
for name_col, type_col in {'speaker': 'speaker_type', 'speaker_replied_to': 'speaker_type_replied_to'}.items():
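    # Remove the 'j__' prefix explicitly: str.lstrip('j__') strips any leading 'j' or '_'
    # characters, which would also eat the first letter of names such as 'j__john_...'
    df[f'{name_col}_natural'] = df[name_col].apply(lambda name: ' '.join((name[3:] if name.startswith('j__') else name).split('_')).title())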
    speaker_type_translation = {
        'J': 'Justice',
        'A': 'Attorney',
        '<Inaudile>': None
    }

    df[type_col].fillna('na', inplace=True)
    df[f'{type_col}_natural'] = df[type_col].apply(lambda s_type: speaker_type_translation[s_type] if s_type != 'na' else None)

    df[f'{name_col}_natural'] = df[f'{type_col}_natural'] + ' ' + df[f'{name_col}_natural']
    df[f'{name_col}_natural'].fillna('Unknown', inplace=True)
    df.drop(f'{type_col}_natural', axis=1, inplace=True)

# Also translate the side into natural text (e.g., '1' --> 'Petitioning' (attorney))
side_translation = {
    0: 'Responding',
    1: 'Petitioning',
    2: '',
    3: ''
}

df['side'].fillna(3, inplace=True)
df['side_natural'] = df['side'].apply(lambda side: side_translation[side])

df['side_addressed'].fillna(3, inplace=True)
df['side_addressed_natural'] = df['side_addressed'].apply(lambda side: side_translation[side])

# Enrich the text of an utterance with some context information
df["new_text"] = "<UTTERANCE_START>" + df['side_natural'] + " " + df["speaker_natural"] + " says: '" + df["text"] + "' to " + df['side_addressed_natural'] + " " + df["speaker_replied_to_natural"] + " <UTTERANCE_END>"

torch.device: cuda
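
The name-cleaning step in the cell above depends on dropping the `j__` prefix from justice identifiers. The sketch below shows one way to remove a prefix and contrasts it with `str.lstrip`, which strips a set of characters rather than a prefix. The helper names and the identifier `j__jane_doe` are hypothetical; only `j__clarence_thomas` appears in this notebook.

```python
def strip_prefix(name, prefix="j__"):
    """Remove the prefix once if it is present (hypothetical helper)."""
    return name[len(prefix):] if name.startswith(prefix) else name


def to_natural_name(name):
    """'j__clarence_thomas' -> 'Clarence Thomas'."""
    return " ".join(strip_prefix(name).split("_")).title()


raw = "j__jane_doe"  # hypothetical identifier

# lstrip removes every leading 'j' and '_' character, including the 'j' of 'jane'
stripped_chars = raw.lstrip("j__")

print(stripped_chars)
print(to_natural_name(raw))
print(to_natural_name("j__clarence_thomas"))
```

Names that start with `j` are the ones affected by character stripping; names such as `clarence_thomas` pass through either approach unchanged.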




We only process the justices with the most cases in our data (2000-2019). To do so, we sort the justices by the number of cases they voted on and keep the most frequent ones: 15 in a full run, 2 here for demonstration.
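
The selection logic can be sketched in plain Python on toy data: keep one row per case, count the non-missing votes per justice column, and take the top `k`. The rows, the column names `votes_side_j__a` and `votes_side_j__b`, and the helper `top_justices` are hypothetical and only mirror the shape of the real data.

```python
# Toy rows: one per utterance; None means the justice has no recorded vote in that case.
toy_rows = [
    {"case_id": "c1", "votes_side_j__a": 1, "votes_side_j__b": None},
    {"case_id": "c1", "votes_side_j__a": 1, "votes_side_j__b": None},
    {"case_id": "c2", "votes_side_j__a": 0, "votes_side_j__b": 1},
    {"case_id": "c3", "votes_side_j__a": 1, "votes_side_j__b": None},
]


def top_justices(rows, k):
    """Count non-missing votes per justice over unique cases and return the top k."""
    # Keep only the first row seen for each case
    first_per_case = {}
    for row in rows:
        first_per_case.setdefault(row["case_id"], row)

    # Count the cases in which each justice has a recorded vote
    counts = {}
    for row in first_per_case.values():
        for col, vote in row.items():
            if col.startswith("votes_side_j_") and vote is not None:
                counts[col] = counts.get(col, 0) + 1

    ranked = sorted(counts, key=counts.get, reverse=True)
    return ranked[:k], counts


top, counts = top_justices(toy_rows, k=1)
print(top, counts)
```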

In [5]:
NB_JUSTICES = 15

j_columns = [col for col in df.columns if col.startswith('votes_side_j_')]

# Count the number of cases by justice (one row per case)
cases = df.drop_duplicates(subset='case_id', keep='first')
nb_cases = {col: cases[col].count() for col in j_columns}

# Sort the justices by number of cases (most frequent first) and keep the top ones
frequent_justices = sorted(nb_cases, key=nb_cases.get, reverse=True)
frequent_justices = frequent_justices[:2] # [:NB_JUSTICES] usually, but smaller for demonstration

# Also predict the overall case outcome
targets = frequent_justices + ['win_side']

targets

['votes_side_j__ruth_bader_ginsburg',
 'votes_side_j__clarence_thomas',
 'win_side']
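
For each target, the training cell below first joins all utterances of a case into one comma-separated string and then keeps a random window of that string via `select_random_chars`. The following is a minimal standard-library sketch of those two steps on toy data. The utterances, the 20-character window, and the `rng` parameter are assumptions made so that the example is small and reproducible.

```python
import random


def random_window(text, max_length=3000, rng=random):
    """Return text unchanged if it is short enough, else a random window of max_length characters."""
    if len(text) <= max_length:
        return text
    start = rng.randint(0, len(text) - max_length)
    return text[start:start + max_length]


# Toy (case_id, utterance) pairs, hypothetical
toy_utterances = [
    ("case_a", "first utterance"),
    ("case_a", "second utterance"),
    ("case_b", "only utterance"),
]

# Join the utterances of each case with commas, in their original order
joined = {}
for case_id, text in toy_utterances:
    if case_id in joined:
        joined[case_id] = joined[case_id] + "," + text
    else:
        joined[case_id] = text

rng = random.Random(0)  # seeded generator so the window is reproducible
windows = {
    case_id: random_window(text, max_length=20, rng=rng)
    for case_id, text in joined.items()
}
print(windows)
```

Because the window is cut at character level, it can start or end in the middle of an utterance.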

In [6]:
EPOCHS = 5 # 15 usually, but smaller for demonstration

results = pd.DataFrame()

# Run the model for each justice (or win_side) in targets
for justice in targets:
    # Final data wrangling for the model (truncating the utterance text)
    df_j = df[['case_id', 'new_text'] + [justice]]
    grouped_df = df_j.groupby('case_id')['new_text'].apply(lambda x: ','.join(x)).reset_index()
    justices = df[["case_id"] + [justice]].drop_duplicates(keep='first')

    df1 = pd.merge(grouped_df, justices, left_on='case_id', right_on='case_id', how='left').dropna(axis='rows', how='any')
    df1 = df1.drop(df1[~df1[justice].isin([0, 1])].index)

    df1['new_text'] = df1['new_text'].apply(select_random_chars)

    # If there is not enough data for this justice, go to the next target in the loop
    if len(df1) < 10:
        print(f"{justice} does not have enough cases for the current dataset")
        continue

    dataset = Dataset.from_pandas(df1.drop('case_id', axis=1), preserve_index = False)

    # 70, 15, 15 split for train, validate, test
    dataset = dataset.train_test_split(test_size=0.3, shuffle=True)
    dataset_test_valid = dataset['test'].train_test_split(test_size=0.5, shuffle=True)
    dataset = DatasetDict({
        'train': dataset['train'],
        'test': dataset_test_valid['test'],
        'validation': dataset_test_valid['train']})

    labels = [label for label in dataset['train'].features.keys() if label not in ['case_id', 'new_text']]
    id2label = {idx:label for idx, label in enumerate(labels)}
    label2id = {label:idx for idx, label in enumerate(labels)}

    # Tokenize the text
    tokenizer = AutoTokenizer.from_pretrained(
        "bert-base-uncased",
        return_overflowing_tokens=True
    )
    encoded_dataset = dataset.map(
        preprocess_data,
        batched=True,
        remove_columns=dataset['train'].column_names
    )

    # Define the model
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", 
                                                                problem_type="multi_label_classification", 
                                                                num_labels=len(labels),
                                                                id2label=id2label,
                                                                label2id=label2id)

    BATCH_SIZE = 8
    METRIC_NAME = "f1"

    # Set arguments
    args = TrainingArguments(
        f"bert-finetuned-sem_eval-english",
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model=METRIC_NAME
    )

    os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
    trainer = Trainer(
        model,
        args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    # Train
    trainer.train()

    # Validation metrics
    val_metrics = trainer.evaluate()
    val_metrics = dict((key.replace('eval', 'val'), [value]) for (key, value) in val_metrics.items())

    # Initialize a new trainer instance with the trained model and test data
    trainer_test = Trainer(
        model=model,  
        args=args,
        eval_dataset=encoded_dataset["test"], 
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
        )

    # Test metrics
    test_metrics = trainer_test.evaluate()
    test_metrics = dict((key.replace('eval', 'test'), [value]) for (key, value) in test_metrics.items())

    metrics = val_metrics
    metrics.update(test_metrics)
    metrics['justice'] = [justice]

    # Save results
    results = pd.concat([results, pd.DataFrame(metrics)], ignore_index=True)

results = results[
    ['justice', 'epoch'] +
    [col for col in results.columns if col.startswith('val_')] +
    [col for col in results.columns if col.startswith('test_')]
]

Map:   0%|          | 0/973 [00:00<?, ? examples/s]

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

Map:   0%|          | 0/208 [00:00<?, ? examples/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

| Epoch | Training Loss | Validation Loss | F1 | Roc Auc | Accuracy |
|---|---|---|---|---|---|
| 1 | No log | 0.709605 | 0.533654 | 0.5 | 0.533654 |
| 2 | No log | 0.696107 | 0.514423 | 0.537243 | 0.514423 |
| 3 | No log | 0.715839 | 0.543269 | 0.51161 | 0.543269 |
| 4 | No log | 0.71694 | 0.533654 | 0.513653 | 0.533654 |
| 5 | 0.643900 | 0.734031 | 0.552885 | 0.538822 | 0.552885 |


Map:   0%|          | 0/970 [00:00<?, ? examples/s]

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

Map:   0%|          | 0/208 [00:00<?, ? examples/s]


| Epoch | Training Loss | Validation Loss | F1 | Roc Auc | Accuracy |
|---|---|---|---|---|---|
| 1 | No log | 0.682091 | 0.576923 | 0.5 | 0.576923 |
| 2 | No log | 0.685085 | 0.576923 | 0.5 | 0.576923 |
| 3 | No log | 0.698767 | 0.533654 | 0.489773 | 0.533654 |
| 4 | No log | 0.757383 | 0.490385 | 0.468939 | 0.490385 |
| 5 | 0.660800 | 0.837426 | 0.451923 | 0.446212 | 0.451923 |


Map:   0%|          | 0/980 [00:00<?, ? examples/s]

Map:   0%|          | 0/210 [00:00<?, ? examples/s]

Map:   0%|          | 0/210 [00:00<?, ? examples/s]


| Epoch | Training Loss | Validation Loss | F1 | Roc Auc | Accuracy |
|---|---|---|---|---|---|
| 1 | No log | 0.616417 | 0.704762 | 0.5 | 0.704762 |
| 2 | No log | 0.622736 | 0.704762 | 0.5 | 0.704762 |
| 3 | No log | 0.659019 | 0.652381 | 0.509699 | 0.652381 |
| 4 | No log | 0.701166 | 0.633333 | 0.519616 | 0.633333 |
| 5 | 0.620800 | 0.74786 | 0.595238 | 0.506648 | 0.595238 |
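
The `Map` progress bars above report the split sizes for each target in the order train, test, validation. As a sanity check on the 70/15/15 split, the sketch below recomputes those sizes from the row totals (for example 973 + 209 + 208 = 1390). It assumes that each split rounds the held-out share up and leaves the remainder to the other part; this assumption is consistent with the counts shown, but it is not taken from the library's documentation. The helper `split_sizes` is hypothetical.

```python
import math


def split_sizes(n_rows, test_fraction=0.3, holdout_fraction=0.5):
    """Sizes of (train, test, validation) for a two-stage split.

    Assumption: the held-out share of each split is rounded up.
    """
    n_holdout = math.ceil(n_rows * test_fraction)      # first split: 30% held out
    n_train = n_rows - n_holdout
    n_test = math.ceil(n_holdout * holdout_fraction)   # second split: half of the held-out rows
    n_validation = n_holdout - n_test
    return n_train, n_test, n_validation


# Row totals obtained by summing the three Map counts for each target
for n_rows in (1390, 1387, 1400):
    print(n_rows, split_sizes(n_rows))
```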


With training done, all that is left is to save and inspect the results.

In [7]:
results.to_csv("gdrive/MyDrive/data/BERT_results.csv", index=False)
results

| | justice | epoch | val_loss | val_f1 | val_roc_auc | val_accuracy | val_runtime | val_samples_per_second | val_steps_per_second | test_loss | test_f1 | test_roc_auc | test_accuracy | test_runtime | test_samples_per_second | test_steps_per_second |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | votes_side_j__ruth_bader_ginsburg | 5.0 | 0.734031 | 0.552885 | 0.538822 | 0.552885 | 1.6305 | 127.566 | 15.946 | 0.729654 | 0.593301 | 0.576366 | 0.593301 | 1.5879 | 131.618 | 17.003 |
| 1 | votes_side_j__clarence_thomas | 5.0 | 0.682091 | 0.576923 | 0.5 | 0.576923 | 1.642 | 126.678 | 15.835 | 0.672068 | 0.617225 | 0.5 | 0.617225 | 1.5907 | 131.389 | 16.974 |
| 2 | win_side | 5.0 | 0.616417 | 0.704762 | 0.5 | 0.704762 | 1.6706 | 125.701 | 16.162 | 0.627484 | 0.685714 | 0.5 | 0.685714 | 1.5989 | 131.339 | 16.887 |
