In [1]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load "processedCases.pickle" if it comes from a trusted source.
# The context manager closes the file handle as soon as loading finishes
# (a bare open() call leaves it open until garbage collection).
with open("processedCases.pickle", "rb") as cases_file:
    cases = pickle.load(cases_file)

# Key under which this justice's vote is stored in each case's "votes" mapping.
justiceId = "ruth_bader_ginsburg"

# Keep only the cases that record a vote for this justice.
justiceCases = [case for case in cases if justiceId in case["votes"]]

In [2]:
len(justiceCases)

1770

In [3]:
len(cases)

2938

In [4]:
justiceCases[0]

{'plantiff': 'Calderone',
 'defendant': 'Thompson',
 'question': "Did the Court of Appeal's order recalling its mandate denying Thomas M. Thompson all habeas relief violate 28 USC section 2244(b) as amended by the Antiterrorism and Effective Death Penalty Act of 1996? Was the order an abuse of the appellate court's discretion?",
 'facts': "In 1983, Thomas M. Thompson was convicted of the rape and murder of Ginger Fleischli in California state court. The special circumstance found by the jury of murder during the commission of rape made Thompson eligible for the death penalty. In 1995, a federal District Court invalidated Thompson's death sentence by granted relief on his rape conviction and the rape special circumstance. In reversing, the Court of Appeals reinstated Thompson's death sentence, noting that the State presented strong evidence of rape at trial. The Court of Appeals then issued a mandate denying all habeas relief. Two days before Thompson's execution, the Court of Appeals r

In [5]:
# One record per case: the two text fields plus this justice's vote.
data = [
    {
        "question": case["question"],
        "facts": case["facts"],
        "vote": case["votes"][justiceId],
    }
    for case in justiceCases
]

In [6]:
import pandas as pd

# Tabulate the per-case records; the columns come from the record keys
# ("question", "facts", "vote").
df = pd.DataFrame(data)

In [7]:
df.head()

Unnamed: 0,question,facts,vote
0,Did the Court of Appeal's order recalling its ...,"In 1983, Thomas M. Thompson was convicted of t...",defendant
1,"Should a court's ""substantial burden"" analysis...","In 2013, the Texas Legislature passed House Bi...",defendant
2,When a person is convicted of conspiracy to co...,Manoj Nijhawan was convicted of conspiracy to ...,defendant
3,ol><li>Does the Eighth Amendment require an in...,Russell Bucklew was convicted by a state court...,plantiff
4,Did Maine's tax exemption statute violate the ...,Camps Newfound/Owatonna Inc. (Camps) operates ...,defendant


In [8]:
# Model input: the question followed by the facts, joined by a single space.
df["text"] = df["question"].str.cat(df["facts"], sep=" ")

In [9]:
# Binary target: 1 when the vote is "defendant", 0 for any other value.
df["label"] = (df["vote"] == "defendant").astype("int64")

In [10]:
df.head()

Unnamed: 0,question,facts,vote,text,label
0,Did the Court of Appeal's order recalling its ...,"In 1983, Thomas M. Thompson was convicted of t...",defendant,Did the Court of Appeal's order recalling its ...,1
1,"Should a court's ""substantial burden"" analysis...","In 2013, the Texas Legislature passed House Bi...",defendant,"Should a court's ""substantial burden"" analysis...",1
2,When a person is convicted of conspiracy to co...,Manoj Nijhawan was convicted of conspiracy to ...,defendant,When a person is convicted of conspiracy to co...,1
3,ol><li>Does the Eighth Amendment require an in...,Russell Bucklew was convicted by a state court...,plantiff,ol><li>Does the Eighth Amendment require an in...,0
4,Did Maine's tax exemption statute violate the ...,Camps Newfound/Owatonna Inc. (Camps) operates ...,defendant,Did Maine's tax exemption statute violate the ...,1


In [11]:
df.describe()

Unnamed: 0,label
count,1770.0
mean,0.458757
std,0.498437
min,0.0
25%,0.0
50%,0.0
75%,1.0
max,1.0


In [12]:
from transformers import AutoTokenizer

# Tokenizer for the "bert-base-cased" checkpoint, the same checkpoint name the
# classification model is loaded from later in the notebook.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [13]:
from datasets import Dataset

# Build the dataset from only the "text" and "label" columns of the frame.
dataset = Dataset.from_pandas(df[["text", "label"]])

In [14]:
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 1770
})

In [15]:
# Hold out 20% of the rows for evaluation. The fixed seed makes the shuffle,
# and therefore the train/test membership, identical on every run, so the
# reported metrics are comparable across kernel restarts.
# NOTE: this rebinds `dataset` from the full dataset to the split result
# (indexed by "train"/"test" further down), so re-running this cell alone will
# likely fail; re-run the cell that builds `dataset` first.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

In [16]:
def tokenize_function(examples):
    """Tokenize the "text" field of `examples` with the notebook's `tokenizer`.

    The tokenizer is called with `padding="max_length"` and `truncation=True`,
    so sequences are padded to a fixed maximum length and over-long inputs are
    cut off.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, padding="max_length", truncation=True)
    return encoded

# Run tokenize_function over the split dataset; batched=True hands the function
# a batch of rows at a time rather than a single row.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

100%|██████████| 2/2 [00:00<00:00,  5.74ba/s]
100%|██████████| 1/1 [00:00<00:00, 12.15ba/s]


In [17]:
from transformers import AutoModelForSequenceClassification

# Sequence-classification model initialised from "bert-base-cased";
# num_labels=2 matches the 0/1 "label" column built earlier in the notebook.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [18]:
from transformers import TrainingArguments

# All training settings are left at their defaults; only the output directory
# is specified.
training_args = TrainingArguments(output_dir="test_trainer")

In [19]:
from transformers import Trainer

# Splits produced by the tokenization step.
train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["test"]

# No `compute_metrics` callback is passed here; accuracy is computed by hand
# from `trainer.predict` output later in the notebook.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running training *****
  Num examples = 1416
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 531
 37%|███▋      | 197/531 [18:07:08<35:28:38, 382.39s/it]

KeyboardInterrupt: 

In [20]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy from model outputs.

    Args:
        eval_pred: Indexable object whose item 0 holds the logits (one row of
            class scores per example) and whose item 1 holds the reference
            labels. Any further items are ignored.

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is the fraction of
        examples whose highest-scoring class equals the reference label.

    Raises:
        ValueError: If the number of predictions and labels differ.
    """
    logits = np.asarray(eval_pred[0])
    labels = np.asarray(eval_pred[1])
    # The predicted class is the index of the largest logit in each row.
    predictions = np.argmax(logits, axis=-1)
    if predictions.shape != labels.shape:
        raise ValueError(
            f"predictions shape {predictions.shape} does not match labels shape {labels.shape}"
        )
    # Accuracy is computed directly with NumPy so this function does not depend
    # on the module-level `metric` object (created with `datasets.load_metric`,
    # which is deprecated in recent `datasets` releases).
    return {"accuracy": float(np.mean(predictions == labels))}

In [21]:
# Run the model held by `trainer` over the "test" split. compute_metrics reads
# the result as [0] = logits and [1] = labels.
# NOTE(review): in the saved run the training cell was interrupted
# (KeyboardInterrupt at step 197 of 531), so these predictions come from a
# partially trained model.
preds = trainer.predict(tokenized_datasets["test"])

The following columns in the test set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Prediction *****
  Num examples = 354
  Batch size = 8


In [22]:
compute_metrics(preds)

{'accuracy': 0.5254237288135594}