This Jupyter notebook processes SCOTUS cases, generates embeddings for their opinions, and constructs a graph from them.


The following cell connects to a Neo4j database using credentials stored in environment variables.
The connection is verified by running a simple query to ensure the connection is successful.


In [1]:
import os
from dotenv import load_dotenv
from neo4j import GraphDatabase

# Load environment variables from .env file
load_dotenv()

# Read Neo4j credentials from environment variables
neo4j_username = os.getenv('NEO4J_USERNAME')
neo4j_password = os.getenv('NEO4J_PASSWORD')
neo4j_uri = os.getenv('NEO4J_URI')

# Create a Neo4j driver instance
driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_username, neo4j_password))

# Verify the connection

def verify_connection(driver):
    try:
        with driver.session() as session:
            result = session.run("RETURN 1")
            if result.single()[0] == 1:
                print("Connection to Neo4j established successfully.")
            else:
                print("Failed to establish connection to Neo4j.")
    except Exception as e:
        print(f"An error occurred: {e}")

verify_connection(driver)

Connection to Neo4j established successfully.
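
Because the credentials come from environment variables, a missing variable only shows up later as a connection error. The check below is a minimal sketch of catching that earlier; `missing_env_vars` is a hypothetical helper, not part of the notebook.

```python
import os

# Variable names taken from the connection cell above
REQUIRED_VARS = ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")

def missing_env_vars(names, environ=os.environ):
    """Return the names that are unset or empty in `environ` (hypothetical helper)."""
    return [name for name in names if not environ.get(name)]

missing = missing_env_vars(REQUIRED_VARS)
if missing:
    print(f"Missing environment variables: {', '.join(missing)}")
```

Passing `environ` as an argument keeps the helper easy to exercise with a plain dictionary.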


Next, we connect to a MongoDB database using credentials stored in environment variables.


In [2]:
from pymongo import MongoClient
MONGODB_USERNAME = os.getenv('MONGODB_USERNAME')
MONGODB_PASSWORD = os.getenv('MONGODB_PASSWORD')
MONGODB_HOST = os.getenv('MONGODB_HOST')
MONGODB_DATABASE = os.getenv('MONGODB_DATABASE')

mongo_uri = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE}?authSource=admin&replicaSet=db-mongo-graph-explorer"
client = MongoClient(mongo_uri)
db = client[MONGODB_DATABASE]
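
A username or password that contains characters such as `@`, `:` or `/` has to be percent-encoded before it is placed in a MongoDB URI. The following is a sketch of one way to do that with the standard library; `build_mongo_uri` and the sample values are hypothetical, and the `replicaSet` option from the cell above is left out for brevity.

```python
from urllib.parse import quote_plus

def build_mongo_uri(username, password, host, database):
    """Assemble a mongodb+srv URI with percent-encoded credentials (hypothetical helper)."""
    return (
        f"mongodb+srv://{quote_plus(username)}:{quote_plus(password)}"
        f"@{host}/{database}?authSource=admin"
    )

# Sample values for illustration only
uri = build_mongo_uri("app_user", "p@ss:word/1", "cluster.example.net", "scotus")
print(uri)
```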

The function `recursive_text_splitter`:
1. Takes a document and a maximum chunk size as input.
2. Uses the `textwrap` module to split the document into chunks of at most the specified size.
3. Recursively splits any chunk that is still larger than the maximum chunk size.
4. Returns a list of final chunks, each of which is no larger than the specified maximum chunk size.


In [3]:
import textwrap

def recursive_text_splitter(document, max_chunk_size):
    # Use textwrap to initially split the text into lines of max_chunk_size
    chunks = textwrap.wrap(document, width=max_chunk_size)
    
    final_chunks = []
    
    for chunk in chunks:
        if len(chunk) > max_chunk_size:
            # If a chunk is still larger than max_chunk_size, split it further
            final_chunks.extend(recursive_text_splitter(chunk, max_chunk_size))
        else:
            final_chunks.append(chunk)
    
    return final_chunks
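
The sketch below illustrates how `textwrap.wrap` behaves on its own, using a made-up input string. With the default `break_long_words=True` it already breaks over-long words, so every chunk fits within the width and the recursive branch of `recursive_text_splitter` acts only as a safeguard. Note also that `wrap` normalizes whitespace, so joining the chunks does not necessarily reproduce the original text character for character.

```python
import textwrap

# Hypothetical input: a few short words followed by one 40-character "word"
text = "certiorari " * 5 + "x" * 40

# Default behaviour: long words are broken, so no chunk exceeds the width
chunks = textwrap.wrap(text, width=16)
longest = max(len(chunk) for chunk in chunks)

# With break_long_words=False the long word is kept whole and exceeds the width
unbroken = textwrap.wrap(text, width=16, break_long_words=False)
longest_unbroken = max(len(chunk) for chunk in unbroken)

print(longest, longest_unbroken)
```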

The `get_embedding` function produces an embedding for a given `text`, using the `text-embedding-3-small` model from OpenAI.

In [4]:
from openai import OpenAI

# Use a distinct name so the MongoDB `client` created earlier is not overwritten
openai_client = OpenAI()

def get_embedding(text):
    response = openai_client.embeddings.create(
        input=text,
        model="text-embedding-3-small"
    )
    return response.data[0].embedding
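
The value returned by `get_embedding` is a plain list of floats. A vector index compares such lists with a similarity metric; the sketch below shows cosine similarity in pure Python, on the assumption that this is the metric of interest (the notebook does not say how the index is configured). The two-dimensional vectors are toy values, not real embeddings.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy vectors standing in for embeddings
same_direction = cosine_similarity([1.0, 0.0], [2.0, 0.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 3.0])
print(same_direction, orthogonal)
```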

The following cell sets up the Pinecone client and initializes the `index` object that points to the Pinecone `scotus` index.

In [5]:
from pinecone import Pinecone

# Load Pinecone API key from environment variables (do not print secrets in notebook output)
pinecone_api_key = os.getenv('PINECONE_API_KEY')

# Initialize Pinecone client
pc = Pinecone(api_key=pinecone_api_key)

# Connect to the existing Pinecone index
index_name = "scotus"
index = pc.Index(index_name)


The `save_opinion_embedding` function takes a `case_id` and `opinion_text` as input, and:
1. Splits the `opinion_text` into smaller chunks using the `recursive_text_splitter` function.
2. Generates an embedding for each chunk using the `get_embedding` function.
3. Creates metadata for each chunk, including the `case_id` and the chunk itself.
4. Replaces any metadata value that is `None` with the string "null".
5. Upserts the chunk ID, embedding, and metadata into the Pinecone index.

In [6]:
def save_opinion_embedding(case_id, opinion_text):
    chunks = recursive_text_splitter(opinion_text, 512)

    # enumerate gives each chunk its own position, even when two chunks have identical text
    for chunk_number, chunk in enumerate(chunks):
        chunk_id = f"{case_id}_{chunk_number}"
        embedding = get_embedding(chunk)
        metadata = {
            "case_id": str(case_id),
            "chunk": chunk
        }

        # Replace None values with the string "null" before upserting
        for key, value in metadata.items():
            if value is None:
                metadata[key] = "null"

        index.upsert([(chunk_id, embedding, metadata)])
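
Upserting one vector per call means one network round trip per chunk. Since `upsert` takes a list of vectors, the chunks could instead be grouped into batches first. The sketch below shows only the grouping step; `batched` is a hypothetical helper, the vectors are placeholder values, and the actual `index.upsert` call is left as a comment because it needs a live index.

```python
def batched(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Placeholder (id, embedding, metadata) tuples in the same shape used above
vectors = [
    (f"case42_{i}", [0.0, 0.1], {"case_id": "case42", "chunk": f"chunk {i}"})
    for i in range(5)
]

batches = list(batched(vectors, 2))
for batch in batches:
    # index.upsert(batch)  # one request per batch instead of one per chunk
    print(len(batch))
```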

The cell below defines two classes, `Node` and `Edge`, which are used to create and manage nodes and edges in a Neo4j graph database.

The Node class:
- The constructor (`__init__`) initializes a node with an `id`, `label`, and `properties`. The `id` is stored as a string so that nodes and edges always match on the same type.
- The `create_node_query` method generates a Cypher query to merge (match or create) a node with the given `id`, `label`, and `properties`. The `id` and the property values are passed as query parameters, so they need no manual escaping. The label cannot be parameterized and is backtick-quoted instead.
- The `create_node` method executes the query to create the node in the Neo4j database, but only if the `id` and `label` are set and no property value is `None`.

The Edge class:
- The constructor (`__init__`) initializes an edge with `from_node_id`, `to_node_id`, `relationship_type`, `properties`, and an `increment_property`. Properties whose value is `None` are dropped, because `MERGE` cannot match on a null property value.
- The `create_edge_query` method generates a Cypher query to merge (match or create) an edge with the given properties between the specified nodes. It sets the `increment_property` to 1 when the edge is created and increments it when the edge already exists.
- The `create_edge` method executes the query to create the edge in the Neo4j database, and does nothing when a node ID or the relationship type is missing.


In [7]:
def quote_identifier(value):
    # Labels and relationship types cannot be query parameters, so quote them
    # with backticks and double any backtick inside the name.
    return "`" + str(value).replace("`", "``") + "`"

class Node:
    def __init__(self, id, label, properties):
        # Store IDs as strings so nodes and edges always match on the same type
        self.id = str(id) if id is not None else None
        self.label = label
        self.properties = dict(properties)

    def create_node_query(self):
        props = ''.join(f", {key}: ${key}" for key in self.properties)
        return f"MERGE (n:{quote_identifier(self.label)} {{id: $node_id{props}}})"

    def create_node(self):
        if not self.id or not self.label or any(value is None for value in self.properties.values()):
            return
        with driver.session() as session:
            session.run(self.create_node_query(), {"node_id": self.id, **self.properties})


class Edge:
    def __init__(self, from_node_id, to_node_id, relationship_type, properties, increment_property):
        self.from_node_id = from_node_id
        self.to_node_id = to_node_id
        self.relationship_type = relationship_type
        # MERGE cannot match on null values, so drop properties that are None
        self.properties = {k: v for k, v in properties.items() if v is not None}
        self.increment_property = increment_property

    def create_edge_query(self):
        props = ', '.join(f"{key}: ${key}" for key in self.properties)
        counter = f"r.{quote_identifier(self.increment_property)}"
        return (f"MATCH (a {{id: $from_id}}), (b {{id: $to_id}}) "
                f"MERGE (a)-[r:{quote_identifier(self.relationship_type)} {{{props}}}]->(b) "
                f"ON CREATE SET {counter} = 1 "
                f"ON MATCH SET {counter} = {counter} + 1")

    def create_edge(self):
        if not self.from_node_id or not self.to_node_id or not self.relationship_type:
            return
        params = {"from_id": self.from_node_id, "to_id": self.to_node_id, **self.properties}
        with driver.session() as session:
            session.run(self.create_edge_query(), params)
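
The effect of `MERGE` combined with `ON CREATE SET` and `ON MATCH SET` can be illustrated without a database. The sketch below is a simplified in-memory stand-in, not the Neo4j behaviour itself: `InMemoryGraph` is a hypothetical class, and it identifies an edge only by its endpoints and type, whereas the real query also matches on the edge's property map.

```python
class InMemoryGraph:
    """Toy model of merging edges with a counter property (hypothetical)."""

    def __init__(self):
        self.edges = {}

    def merge_edge(self, from_id, rel_type, to_id, increment_property="count"):
        key = (from_id, rel_type, to_id)
        if key not in self.edges:
            # Corresponds to ON CREATE SET r.count = 1
            self.edges[key] = {increment_property: 1}
        else:
            # Corresponds to ON MATCH SET r.count = r.count + 1
            self.edges[key][increment_property] += 1
        return self.edges[key]

graph = InMemoryGraph()
graph.merge_edge("case-1", "mentioned_in", "Congress")
graph.merge_edge("case-1", "mentioned_in", "Congress")
graph.merge_edge("case-1", "mentioned_in", "Texas")
print(graph.edges)
```

Merging the same edge twice leaves a single edge whose counter records how often it was seen, which is what the `count` property is used for in the cells that follow.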



In the following cell, we define the classes (the "entities" in our ontology), based on the data structure found in the Oyez API.

In [8]:
from datetime import datetime

class Citation:
    def __init__(self, data):
        self.volume = data.get('volume')
        self.page = data.get('page')
        self.year = data.get('year')

    def __str__(self):
        return f"{self.volume} U.S. {self.page} ({self.year})"

class Advocate:
    def __init__(self, data):
        advocate_data = data.get('advocate', {})
        self.name = advocate_data.get('name') if advocate_data else None
        self.description = data.get('advocate_description')

    def __str__(self):
        return f"{self.name}: {self.description}"

class Decision:
    def __init__(self, data):
        self.description = data.get('description')
        self.winning_party = data.get('winning_party')
        self.decision_type = data.get('decision_type')
        self.votes = [Vote(v) for v in data.get('votes', []) if v]

    def __str__(self):
        return f"{self.description} - Winner: {self.winning_party}"

class Vote:
    def __init__(self, data):
        self.member = Justice(data.get('member', {}))
        self.vote = data.get('vote')
        self.opinion_type = data.get('opinion_type')
        self.href = data.get('href')

    def __str__(self):
        return f"{self.member.name}: {self.vote}"

class Justice:
    def __init__(self, data):
        self.id = data.get('ID')
        self.name = data.get('name')
        self.roles = [Role(r) for r in data.get('roles', []) if r]

    def __str__(self):
        return self.name

class Role:
    def __init__(self, data):
        self.type = data.get('type')
        # `or 0` also covers keys that are present but set to None
        self.date_start = datetime.fromtimestamp(data.get('date_start') or 0)
        self.date_end = datetime.fromtimestamp(data.get('date_end') or 0)
        self.role_title = data.get('role_title')

    def __str__(self):
        return f"{self.role_title} ({self.date_start.year}-{self.date_end.year})"

class DecidedBy:
    def __init__(self, data):
        self.name = data.get('name')
        self.members = [Justice(j) for j in data.get('members', []) if j]

class WrittenOpinion:
    def __init__(self, data):
        self.id = data.get('id')
        self.title = data.get('title')
        self.author = data.get('author')
        self.type_value = data.get('type', {}).get('value')
        self.type_label = data.get('type', {}).get('label')
        self.justia_opinion_id = data.get('justia_opinion_id')
        self.justia_opinion_url = data.get('justia_opinion_url')
        self.judge_full_name = data.get('judge_full_name')
        self.judge_last_name = data.get('judge_last_name')
        self.title_overwrite = data.get('title_overwrite')
        self.href = data.get('href')

    def __str__(self):
        return f"{self.title} ({self.type_label})"



class Case:
    def __init__(self, data):
        self.id = data.get('ID')
        self.name = data.get('name')
        self.href = data.get('href')
        self.docket_number = data.get('docket_number')
        self.first_party = data.get('first_party')
        self.first_party_label = data.get('first_party_label')
        self.second_party = data.get('second_party')
        self.second_party_label = data.get('second_party_label')
        self.decided_date = datetime.fromtimestamp(data.get('timeline', [{}])[0].get('dates', [0])[0])
        self.citation = Citation(data.get('citation', {}))
        self.advocates = [Advocate(a) for a in data.get('advocates', []) if a] if data.get('advocates') else []
        self.decisions = [Decision(d) for d in data.get('decisions', []) if d] if data.get('decisions') else []
        self.decided_by = DecidedBy(data.get('decided_by', {})) if data.get('decided_by') else None
        self.term = data.get('term')
        self.justia_url = data.get('justia_url')
        self.written_opinion = [WrittenOpinion(o) for o in data.get('written_opinion', []) if o] if data.get('written_opinion') else []


    def __str__(self):
        return f"{self.name} ({self.term})"

    def print_details(self):
        print(f"Case: {self.name}")
        print(f"Href: {self.href}")
        print(f"Docket: {self.docket_number}")
        print(f"Citation: {self.citation}")
        print(f"Decided: {self.decided_date.strftime('%B %d, %Y')}")
        print(f"Parties: {self.first_party} v. {self.second_party}")
        print("\nAdvocates:")
        for advocate in self.advocates:
            print(f"  {advocate}")
        print("\nDecisions:")
        for decision in self.decisions:
            print(f"  {decision}")
            for vote in decision.votes:
                print(f"    {vote}")
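
Several of these classes convert Unix timestamps with `datetime.fromtimestamp`. Called without a time zone, it interprets the value in the machine's local time zone, and it raises a `TypeError` when given `None`. The helper below is a sketch of a more defensive conversion; `to_utc_datetime` is a hypothetical name, and treating the timestamps as UTC is an assumption about the data.

```python
from datetime import datetime, timezone

def to_utc_datetime(timestamp):
    """Convert a Unix timestamp to an aware UTC datetime, or return None if it is missing."""
    if timestamp is None:
        return None
    return datetime.fromtimestamp(timestamp, tz=timezone.utc)

epoch = to_utc_datetime(0)
next_day = to_utc_datetime(86400)
print(epoch.isoformat(), next_day.isoformat())
```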

This cell imports the `Sentence` and `Classifier` classes from the `flair` library and then:
1. Loads two classifiers: one for named entity recognition (NER) and one for relation extraction.
2. Defines two classes, `Entity` and `Relation`, to represent entities and relations extracted from text.
3. Defines the `extract_entities_and_relations` function, which takes a sentence as input, uses the NER classifier to identify entities and the relation classifier to identify relationships between them, and returns a list of entities and a list of relations.


In [10]:
from flair.data import Sentence
from flair.nn import Classifier

# Load the NER and relation classifiers
tagger = Classifier.load('ner')
extractor = Classifier.load('relations')

class Entity:
    def __init__(self, text, label):
        self.text = text
        self.label = label

class Relation:
    def __init__(self, label, head, tail):
        self.label = label
        self.head = head
        self.tail = tail


def extract_entities_and_relations(sentence_text):
    if not sentence_text:
        return [], []
    
    sentence = Sentence(sentence_text)
    
    entities = []

    tagger.predict(sentence)
    for entity in sentence.get_labels('ner'):
        entities.append(Entity(entity.data_point.text, entity.value))
    
    extractor.predict(sentence)
    
    relations = []

    for relation in sentence.get_labels('relation'):
        
        head_text = relation.data_point.first.text
        head_type = relation.data_point.first.get_label('ner').value
        tail_text = relation.data_point.second.text
        tail_type = relation.data_point.second.get_label('ner').value
        
        head_entity = Entity(head_text, head_type)
        tail_entity = Entity(tail_text, tail_type)
        relations.append(Relation(relation.value, head_entity, tail_entity))
            
    return entities, relations


2024-08-23 13:44:47,416 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>


In the cell below, we have two helper functions:
1. The `entity_to_node` function takes an `Entity` object as input and converts it into a `Node` object. The `Node` object is created with the entity's text as its ID, the entity's label as its label, and a dictionary containing the entity's text as its name property.

2. The `relation_to_nodes_and_edges` function takes a `Relation` object as input and converts it into a list of two `Node` objects and an `Edge` object.
   * It first converts the head and tail entities of the relation into `Node` objects using the `entity_to_node` function.
   * Then, it creates an `Edge` object with the head node's ID, the tail node's ID, the relation's label as the relationship type, an empty dictionary for properties, and `count` as the increment property.
   * The function returns a two-element list: the pair of `Node` objects, followed by the `Edge` object.


In [11]:
def entity_to_node(entity: Entity):
    return Node(entity.text, entity.label, {"name": entity.text})
    
def relation_to_nodes_and_edges(relation: Relation):
    head_node = entity_to_node(relation.head)
    tail_node = entity_to_node(relation.tail)      
    relation_edge = Edge(head_node.id, tail_node.id, relation.label, {}, "count")    
    return [[head_node, tail_node], relation_edge]


This cell defines a function `process_scotus_opinion` that processes a Supreme Court opinion.
1. It retrieves the opinion from a MongoDB collection using the opinion's ID.
2. If the opinion is not found or does not contain content, it returns empty lists.
3. The function then saves the opinion's embedding and splits the content into sentences.
4. It initializes dictionaries to store unique entities and relations.
5. For each sentence, it extracts entities and relations using the `extract_entities_and_relations` function.
6. It calculates a hash for each entity and relation to ensure uniqueness and stores them in the dictionaries.
7. Finally, it returns lists of unique entities and relations.


In [12]:
import hashlib

# MongoDB collection holding the opinion documents
opinions_collection = db['opinions']

def calculate_hash(text):
    return hashlib.md5(text.encode()).hexdigest()

def process_scotus_opinion(writtenOpinion: WrittenOpinion, case_node: Node):
    opinion = opinions_collection.find_one({'id': writtenOpinion.id})
    if not opinion or "content" not in opinion:
        return [], []
    
    save_opinion_embedding(case_node.id, opinion["content"])
    sentences = opinion['content'].split('. ')
    
    all_entities = {}
    all_relations = {}
    
    for sentence in sentences:
        entities, relations = extract_entities_and_relations(sentence)
        
        for entity in entities:
            if entity and entity.text:
                entity_hash = calculate_hash(entity.text)
                if entity_hash not in all_entities:
                    all_entities[entity_hash] = entity
        
        for relation in relations:
            if relation and relation.label:
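                # Include head and tail so distinct relations with the same label are all kept
                relation_hash = calculate_hash(f"{relation.label}|{relation.head.text}|{relation.tail.text}")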
                if relation_hash not in all_relations:
                    all_relations[relation_hash] = relation
    
    return list(all_entities.values()), list(all_relations.values())
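
Splitting on the literal string `'. '` removes the periods and ignores sentences that end in `?` or `!`. The sketch below shows a regex-based alternative that keeps the punctuation; `split_sentences` is a hypothetical helper. Like the original split, it still breaks on abbreviations such as "v." or "U.S.", so it is an illustration rather than a robust sentence segmenter.

```python
import re

def split_sentences(text):
    """Split after '.', '!' or '?' when followed by whitespace, keeping the punctuation."""
    parts = re.split(r'(?<=[.!?])\s+', text.strip())
    return [part for part in parts if part]

sample = "The judgment is reversed. Is that all? It is so ordered."
sentences = split_sentences(sample)
print(sentences)
```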

This function iterates through written opinions, processes each opinion to extract entities and relations, and then creates nodes and edges for these entities and relations in the graph database.


In [13]:
def process_scotus_opinions(written_opinions: list[WrittenOpinion], case_node: Node): 
  if written_opinions:
    for opinion in written_opinions:
      if opinion:
        
        opinion_node = Node(opinion.id, "Opinion", {"title": opinion.title, "case_id": case_node.id})
        opinion_node.create_node()
        case_opinion_edge = Edge(case_node.id, opinion_node.id, "case_opinion", {}, "count")
        case_opinion_edge.create_edge()
        
        entities, relations = process_scotus_opinion(opinion, case_node)
        if entities:
          entities_nodes = [entity_to_node(entity) for entity in entities if entity]
          for relation in relations:
            if relation:
              relation_nodes, relation_edge = relation_to_nodes_and_edges(relation)
              head_node, tail_node = relation_nodes
              if head_node and tail_node:
                head_node.create_node()
                tail_node.create_node()
                relation_edge.create_edge()
                # print(f"head_node: {head_node}, tail_node: {tail_node}, relation_edge: {relation_edge}")

          for node in entities_nodes:
            if node:
              node.create_node()
              mentioned_in_edge = Edge(case_node.id, node.id, "mentioned_in", {}, "count")
              mentioned_in_edge.create_edge()

        print(".", end="")

This function processes SCOTUS cases and their related entities (parties, advocates, justices, decisions, opinions) and creates corresponding nodes and edges in a graph database.

In [14]:
processed_cases_collection = db.processed_cases

def process_scotus_case(case: Case):
  first_party = case.first_party
  second_party = case.second_party
  advocates = case.advocates if case.advocates else []
  decisions = case.decisions if case.decisions else []
  justices = case.decided_by.members if case.decided_by and case.decided_by.members else []
  
  case_node = Node(case.id, "Case", {"name": case.name, "docket_number": case.docket_number, "term": case.term, "decided_date": case.decided_date.strftime('%Y-%m-%d')})
  case_node.create_node()
  
  if first_party:
      first_party_node = Node(first_party,"Party", {"name": first_party})
      first_party_node.create_node()
      first_party_node_name = case.first_party_label      
      case_party_edge_1 = Edge(case_node.id, first_party_node.id, first_party_node_name, {}, "count")
      case_party_edge_1.create_edge()
  
  if second_party:
      second_party_node = Node(second_party, "Party", {"name": second_party})
      second_party_node.create_node()
      second_party_node_name = case.second_party_label
      case_party_edge_2 = Edge(case_node.id, second_party_node.id, second_party_node_name, {}, "count")
      case_party_edge_2.create_edge()  
  
  for advocate in advocates:
    advocate_node = Node(advocate.name, "Advocate", {"name": advocate.name, "description": advocate.description})
    advocate_node.create_node()
    advocate_edge = Edge(case_node.id, advocate_node.id, "advocated_by", {}, "count")
    advocate_edge.create_edge()
  
  for justice in justices:
    justice_node = Node(justice.id, "Justice", {"name": justice.name})
    justice_node.create_node()
    justice_edge = Edge(case_node.id, justice_node.id, "decided_by", {}, "count")
    justice_edge.create_edge()
    
  for decision in decisions:
    if decision.winning_party:
        decision_node = Node(decision.winning_party, "Party", { "name": decision.winning_party})
        decision_node.create_node()
        decision_edge = Edge(case_node.id, decision_node.id, "won_by", {
          "decision_type": decision.decision_type
        }, "count")
        decision_edge.create_edge()
    for vote in decision.votes:
      justice_node = Node(vote.member.id, "Justice", {"name": vote.member.name})
      justice_node.create_node()
      vote_edge = Edge(case_node.id, justice_node.id, vote.vote, {
        "opinion_type": vote.opinion_type
      }, "count")
      vote_edge.create_edge()
  

  
  process_scotus_opinions(case.written_opinion, case_node)

In [None]:
!jupyter nbextension enable --py widgetsnbextension

This cell is responsible for processing a sample of unprocessed cases from the MongoDB collection.

1. The `unprocessed_count` variable counts the number of documents in the collection that are either not processed or do not have a `processed` field.
2. The `skip_size` variable is a random number between 0 and the difference between `unprocessed_count` and `sample_size`, so that the sample is taken from different parts of the collection.
3. The `sampled_cases` variable retrieves a sample of unprocessed cases from the collection, skipping a random number of documents and limiting the result to the sample size.
4. The `for` loop iterates over each case in `sampled_cases`, processes it using the `process_scotus_case` function, and marks it as processed in the collection.

In [None]:
from tqdm.notebook import tqdm
import random

processed_cases_collection = db.processed_cases
sample_size = 50


unprocessed_count = processed_cases_collection.count_documents({"$or": [{"processed": False}, {"processed": {"$exists": False}}]})
skip_size = random.randint(0, max(0, unprocessed_count - sample_size))

sampled_cases = processed_cases_collection.find({"$or": [{"processed": False}, {"processed": {"$exists": False}}]}).skip(skip_size).limit(sample_size)

for case_data in tqdm(sampled_cases):
    case = Case(case_data)
    process_scotus_case(case)
    
    # Mark the case as processed
    processed_cases_collection.update_one(
        {"_id": case_data["_id"]},
        {"$set": {"processed": True}}
    )

# Close the MongoDB connection
client.close()
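
The skip-and-limit approach selects a contiguous window that starts at a random offset, not a uniformly random sample. The sketch below isolates the offset calculation so its bounds can be checked, and shows a `$sample` aggregation pipeline as one possible alternative. `choose_skip` is a hypothetical helper, and the pipeline is only constructed here; running it would mean passing it to the collection's `aggregate` method.

```python
import random

def choose_skip(unprocessed_count, sample_size, rng=random):
    """Pick a random start offset so that a full window of `sample_size` still fits."""
    return rng.randint(0, max(0, unprocessed_count - sample_size))

rng = random.Random(0)  # seeded only to make the illustration repeatable
offsets = [choose_skip(120, 50, rng) for _ in range(1000)]
print(min(offsets), max(offsets))

# Alternative sketch: let MongoDB draw the random documents itself
sample_pipeline = [
    {"$match": {"$or": [{"processed": False}, {"processed": {"$exists": False}}]}},
    {"$sample": {"size": 50}},
]
```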