In [None]:
pip install lexnlp

In [None]:
import re
import os
from nltk.tokenize import sent_tokenize

# Patterns are compiled once at import time rather than on every call.
# Citations of the form "AIR <4 digits> SC <3-4 digits>", e.g. "AIR 1990 SC 123".
_CITATION_RE = re.compile(r'\bAIR\s\d{4}\sSC\s\d{3,4}\b')
_DATE_RE = re.compile(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}')
_YEAR_RE = re.compile(r'\b\d{4}\b')
# Everything except ASCII letters, whitespace, the brackets used by the
# placeholders, and the punctuation sent_tokenize needs to find sentence ends.
_DISALLOWED_CHARS_RE = re.compile(r'[^a-zA-Z\s\[\].,;:?!]')
_WHITESPACE_RE = re.compile(r'\s+')

_BOILERPLATE_PHRASES = (
    'the learned counsel submitted that',
    'in light of the above discussion',
    'the facts of the case are as follows',
)

_LEGAL_TERMS = {
    'hereinabove': 'above',
    'hereinafter': 'below',
    'plaintiff': 'claimant',
    'defendant': 'respondent',
    'learned counsel': 'lawyer',
    'aforesaid': 'previously mentioned',
    'writ petition': 'legal petition',
}


def _words_pattern(phrase, whole_word_end):
    """Compile `phrase` into a case-insensitive regex anchored at a word start.

    Any run of whitespace may separate the words. When `whole_word_end` is
    true the match must also end on a word boundary; otherwise a trailing
    suffix (e.g. a plural "s") is left in place after substitution.
    """
    body = r'\s+'.join(re.escape(word) for word in phrase.split())
    suffix = r'\b' if whole_word_end else ''
    return re.compile(r'\b' + body + suffix, re.IGNORECASE)


_BOILERPLATE_RES = tuple(_words_pattern(p, True) for p in _BOILERPLATE_PHRASES)
_LEGAL_TERM_RES = tuple(
    (_words_pattern(term, False), replacement)
    for term, replacement in _LEGAL_TERMS.items()
)


def clean_text(text):
    """Normalise a judgment's text and split it into sentences.

    Steps, in order:
      1. Replace "AIR <year> SC <page>" citations with "[CASE CITATION]".
         This runs first so the citation's digits are still present.
      2. Replace numeric dates (e.g. 12/05/1998, 1-2-98) with "[DATE]" and
         standalone 4-digit numbers with "[YEAR]".
      3. Drop every character except letters, whitespace, "[" / "]" and
         sentence punctuation, then collapse whitespace.
      4. Delete boilerplate phrases. This runs before term substitution,
         which would otherwise rewrite "learned counsel" so the first phrase
         could never match.
      5. Replace legal terms with plainer equivalents.

    Phrase and term matching is case-insensitive and anchored at word starts.

    Args:
        text: raw judgment text. Empty or None yields [].

    Returns:
        list[str]: sentences from nltk's sent_tokenize. This requires the
        NLTK "punkt" tokenizer data to be installed.
    """
    if not text:
        return []
    text = _CITATION_RE.sub('[CASE CITATION]', text)
    text = _DATE_RE.sub('[DATE]', text)
    text = _YEAR_RE.sub('[YEAR]', text)
    text = _DISALLOWED_CHARS_RE.sub('', text)
    text = _WHITESPACE_RE.sub(' ', text)
    for pattern in _BOILERPLATE_RES:
        text = pattern.sub('', text)
    for pattern, replacement in _LEGAL_TERM_RES:
        text = pattern.sub(replacement, text)
    # Removing boilerplate can leave doubled spaces behind.
    text = _WHITESPACE_RE.sub(' ', text).strip()
    return sent_tokenize(text)
def read(filepath):
    """Return the entire contents of the UTF-8 text file at `filepath`."""
    with open(filepath, encoding='utf-8') as handle:
        contents = handle.read()
    return contents
def load_and_preprocess_data(file_path):
    """Read and clean one judgment file, or every file in a directory.

    Args:
        file_path: path (str, bytes or os.PathLike) to a judgment text file
            or to a directory of them. A scalar tensor exposing `.numpy()`
            (e.g. an element of a tf.data pipeline) is also accepted.

    Returns:
        For a file: the list of cleaned sentences from clean_text.
        For a directory: a dict mapping each regular file's name (in sorted
        order) to its list of cleaned sentences.
    """
    if hasattr(file_path, 'numpy'):
        file_path = file_path.numpy()
    if isinstance(file_path, bytes):
        file_path = file_path.decode('utf-8')
    file_path = os.fspath(file_path)
    if os.path.isdir(file_path):
        return {
            name: clean_text(read(os.path.join(file_path, name)))
            for name in sorted(os.listdir(file_path))
            if os.path.isfile(os.path.join(file_path, name))
        }
    return clean_text(read(file_path))


# Set IN_ABS_DATASET_DIR to point at the dataset instead of editing this path.
dataset_dir = os.environ.get(
    "IN_ABS_DATASET_DIR", "C:/Users/prasa/Downloads/7152317/dataset/dataset/IN-Abs")
train_judgement_dir = os.path.join(dataset_dir, 'train-data', 'judgement')
# clean_text uses nltk.sent_tokenize, so NLTK's punkt data must already be
# downloaded when this runs.
preprocessed_data = load_and_preprocess_data(train_judgement_dir)

In [None]:
import lexnlp
# A bare `import lexnlp` only exposes the submodules the package's own
# __init__ imports; load the two used below explicitly so the attribute
# lookups cannot fail with AttributeError.
import lexnlp.extract.en.acts
import lexnlp.extract.en.entities.nltk_re


def _statute_label(statute):
    """Reduce a LexNLP act match to a hashable label usable as a graph node id."""
    # NOTE(review): LexNLP's get_acts appears to yield dicts (with keys such
    # as 'value' / 'act_name') rather than strings. Dicts are unhashable, so
    # they cannot be networkx node ids. Confirm against the installed version.
    if isinstance(statute, dict):
        return str(statute.get('value') or statute.get('act_name') or statute)
    return statute


def text_extraction_entities_relations_lexnlp(text):
    """Extract person mentions and statute references from one judgment.

    Args:
        text: judgment text.

    Returns:
        dict with:
          - "entities": list of results from LexNLP's get_persons.
          - "statutes": list of results from get_acts, with dict matches
            reduced to a string label.
    """
    # NOTE(review): some LexNLP releases expose get_persons in
    # entities.nltk_maxent rather than entities.nltk_re; confirm which
    # module the installed version provides.
    entities = list(lexnlp.extract.en.entities.nltk_re.get_persons(text))
    statutes = [_statute_label(statute) for statute in lexnlp.extract.en.acts.get_acts(text)]
    return {
        "entities": entities,
        "statutes": statutes,
    }


Topic modeling with LDA (Latent Dirichlet Allocation)

In [None]:
pip install sklearn gensim nltk

In [None]:
import nltk

# stopwords: used by the LDA vectorizer.
# punkt / punkt_tab: tokenizer data needed by nltk.sent_tokenize (used in
# clean_text). Newer NLTK releases look for punkt_tab, older ones for punkt.
for resource in ('stopwords', 'punkt', 'punkt_tab'):
    nltk.download(resource)

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from nltk.corpus import stopwords
import numpy as np

N_TOPICS = 5
TOP_N_WORDS = 10
RANDOM_STATE = 42

# Sorted so that docs[i], and therefore the fitted topics, does not depend
# on the arbitrary order os.listdir returns.
docs = [
    read(os.path.join(train_judgement_dir, file_name))
    for file_name in sorted(os.listdir(train_judgement_dir))
    if os.path.isfile(os.path.join(train_judgement_dir, file_name))
]
# Use a new name: rebinding `stopwords` to this list shadows the nltk corpus
# module, so a re-run of this cell would fail on `.words`.
english_stopwords = stopwords.words('english')
vectorizer = CountVectorizer(max_df=0.9, min_df=2, stop_words=english_stopwords)
term_matrix = vectorizer.fit_transform(docs)
lda_model = LatentDirichletAllocation(n_components=N_TOPICS, random_state=RANDOM_STATE)
lda_model.fit(term_matrix)
terms = np.array(vectorizer.get_feature_names_out())
for idx, topic in enumerate(lda_model.components_):
    # argsort is ascending; reverse it so the heaviest word is printed first.
    top_words = terms[topic.argsort()[::-1][:TOP_N_WORDS]]
    print(f"Topic {idx}:")
    print(" ".join(top_words))

In [None]:
fact_keywords = ['contract', 'agreement', 'performance', 'plaintiff']
issue_keywords = ['issue', 'dispute', 'claim', 'damages']


def filter_topics(terms, topics, fact_keywords, issue_keywords, top_n=10):
    """Classify topics as fact-like or issue-like from their heaviest words.

    A topic is a fact if any of its `top_n` highest-weighted words equals a
    fact keyword. Otherwise it is an issue if any of them equals an issue
    keyword. Otherwise it is dropped. Matching is on whole words and is
    case-insensitive, so "issue" does not match "issued" and "claim" does
    not match "claimant".

    Args:
        terms: array of vocabulary words, indexed like the columns of `topics`.
        topics: iterable of 1-D numpy arrays of word weights
            (e.g. lda_model.components_).
        fact_keywords: single-word keywords marking a fact topic.
        issue_keywords: single-word keywords marking an issue topic.
        top_n: how many of each topic's heaviest words to inspect.

    Returns:
        (facts, issues): two lists of (topic_index, topic_words) pairs, where
        topic_words is the top words joined by spaces, heaviest first.
    """
    fact_set = {keyword.lower() for keyword in fact_keywords}
    issue_set = {keyword.lower() for keyword in issue_keywords}
    facts = []
    issues = []
    for idx, topic in enumerate(topics):
        top_words = [str(terms[i]) for i in topic.argsort()[::-1][:top_n]]
        topic_words = " ".join(top_words)
        word_set = {word.lower() for word in top_words}
        if word_set & fact_set:
            facts.append((idx, topic_words))
        elif word_set & issue_set:
            issues.append((idx, topic_words))
    return facts, issues
topics = lda_model.components_
facts, issues = filter_topics(terms, topics, fact_keywords, issue_keywords)
# Display the classified topics as the cell's output.
{"facts": facts, "issues": issues}

In [None]:
from transformers import AutoTokenizer, TFBertModel

LEGAL_BERT_CHECKPOINT = 'nlpaueb/legal-bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(LEGAL_BERT_CHECKPOINT)
# Deliberately not named `model`: the EUGAT training cell also binds
# `model`, which silently replaced this encoder inside get_embedding.
legal_bert_model = TFBertModel.from_pretrained(LEGAL_BERT_CHECKPOINT)


def get_embedding(text):
    """Embed a string with Legal-BERT, returning the first-token hidden state.

    Args:
        text: the string to encode. Input beyond 512 tokens is truncated.

    Returns:
        numpy array of shape (1, hidden), taken from position 0 of the last
        hidden layer (the [CLS] token for BERT tokenizers).
    """
    inputs = tokenizer(text, return_tensors='tf', truncation=True, max_length=512)
    outputs = legal_bert_model(inputs, training=False)
    return outputs.last_hidden_state[:, 0, :].numpy()

Graph construction

In [None]:
import networkx as nx

# (source node type, label for edges to facts, label for edges to issues)
_SOURCE_EDGE_LABELS = (
    ('entity', 'related to fact', 'related to issue'),
    ('statute', 'cited in fact', 'cited in issue'),
)


def _node_label(item):
    """Return a hashable node id for an extracted entity or statute."""
    # NOTE(review): LexNLP's get_acts appears to yield dicts. Reduce them to
    # a string so networkx can use them as node ids; confirm against the
    # installed LexNLP version.
    if isinstance(item, dict):
        return str(item.get('value') or item.get('act_name') or item)
    return item


def build_graph(facts, issues, entity_data):
    """Build a directed graph linking fact/issue topics with entities and statutes.

    Args:
        facts: list of (topic_index, topic_words) pairs, as produced by
            filter_topics.
        issues: list of (topic_index, topic_words) pairs.
        entity_data: dict with "entities" and "statutes" lists.

    Returns:
        nx.DiGraph. Each node carries:
          - `embedding`: from get_embedding.
          - `type`: 'fact', 'issue', 'entity' or 'statute'.
        Each edge carries a `relationship` label:
          - fact -> issue: 'related'
          - entity -> fact / issue: 'related to fact' / 'related to issue'
          - statute -> fact / issue: 'cited in fact' / 'cited in issue'
    """
    G = nx.DiGraph()
    # Each pair is (topic_index, topic_words); embed the topic words
    # themselves, not one character of them.
    for idx, topic_words in facts:
        G.add_node(f'fact_{idx}', embedding=get_embedding(topic_words), type='fact')
    for idx, topic_words in issues:
        G.add_node(f'issue_{idx}', embedding=get_embedding(topic_words), type='issue')
    # dict.fromkeys drops repeats while keeping order, so each label is
    # embedded only once.
    for entity in dict.fromkeys(_node_label(e) for e in entity_data['entities']):
        G.add_node(entity, embedding=get_embedding(entity), type='entity')
    for statute in dict.fromkeys(_node_label(s) for s in entity_data['statutes']):
        G.add_node(statute, embedding=get_embedding(statute), type='statute')

    # Group nodes once by their final type (a later add_node with the same id
    # overwrites `type`) instead of rescanning every node inside every loop.
    nodes_by_type = {}
    for node_id, attrs in G.nodes(data=True):
        nodes_by_type.setdefault(attrs['type'], []).append(node_id)
    fact_nodes = nodes_by_type.get('fact', [])
    issue_nodes = nodes_by_type.get('issue', [])

    for fact_node in fact_nodes:
        for issue_node in issue_nodes:
            G.add_edge(fact_node, issue_node, relationship='related')
    for source_type, fact_label, issue_label in _SOURCE_EDGE_LABELS:
        for source_node in nodes_by_type.get(source_type, []):
            for fact_node in fact_nodes:
                G.add_edge(source_node, fact_node, relationship=fact_label)
            for issue_node in issue_nodes:
                G.add_edge(source_node, issue_node, relationship=issue_label)
    return G


# NOTE(review): entities/statutes come from docs[0] only, while facts/issues
# are corpus-wide LDA topics. Confirm this pairing is intended.
entity_data = text_extraction_entities_relations_lexnlp(docs[0])
graph = build_graph(facts, issues, entity_data)

In [None]:
# nx.info() was deprecated in NetworkX 2.7 and removed in 3.0. str(graph)
# gives a one-line node/edge count summary instead.
print(graph)

Train the EUGAT model for better retrieval

In [None]:
import numpy as np
import tensorflow as tf

# One-hot code for every relationship label the graph-construction cell emits.
# Statute edges previously fell through to an all-zero vector.
relationship_types = {
    'related': [1, 0, 0, 0, 0],            # fact -> issue
    'related to fact': [0, 1, 0, 0, 0],    # entity -> fact
    'related to issue': [0, 0, 1, 0, 0],   # entity -> issue
    'cited in fact': [0, 0, 0, 1, 0],      # statute -> fact
    'cited in issue': [0, 0, 0, 0, 1],     # statute -> issue
}
EDGE_FEATURE_DIM = len(next(iter(relationship_types.values())))

node_embeddings = [attrs['embedding'] for _, attrs in graph.nodes(data=True)]
if not node_embeddings:
    raise ValueError("graph has no nodes; nothing to embed")
# get_embedding returns a (1, hidden) row per string. vstack yields a
# (num_nodes, hidden) matrix rather than (num_nodes, 1, hidden).
node_inputs = tf.convert_to_tensor(np.vstack(node_embeddings), dtype=tf.float32)

# Unknown labels keep the original best-effort behaviour: an all-zero vector.
unknown_relationship = [0] * EDGE_FEATURE_DIM
edge_features = [
    relationship_types.get(attrs['relationship'], unknown_relationship)
    for _, _, attrs in graph.edges(data=True)
]
# reshape keeps a (0, EDGE_FEATURE_DIM) shape even when the graph has no edges.
edge_inputs = tf.convert_to_tensor(
    np.asarray(edge_features, dtype=np.float32).reshape(-1, EDGE_FEATURE_DIM))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers


class EUGATLayer(layers.Layer):
    """Multi-head attention from every node over the features of every edge.

    For each head, projected node features score every projected edge
    feature vector. A softmax over edges turns the scores into weights, and
    each node receives the weighted sum of edge features. Head outputs are
    averaged.

    Args:
        output_dim: width of the projected node/edge features and of the output.
        num_heads: number of attention heads (must be >= 1).
        edge_dim: width of the edge feature vectors passed to call().
    """

    def __init__(self, output_dim, num_heads=4, edge_dim=3, **kwargs):
        super().__init__(**kwargs)
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.edge_dim = edge_dim

    def build(self, input_shape):
        # input_shape is expected to be node_inputs' shape (Keras builds from
        # the first call argument).
        node_dim = int(input_shape[-1])
        self.node_weights = self.add_weight(
            name="node_weights", shape=(node_dim, self.output_dim),
            initializer="glorot_uniform", trainable=True)
        # Sized by the edge feature width. Sizing it by node_dim made the edge
        # projection shape-incompatible with edge_inputs.
        self.edge_weights = self.add_weight(
            name="edge_weights", shape=(self.edge_dim, self.output_dim),
            initializer="glorot_uniform", trainable=True)
        self.attention_heads = self.add_weight(
            name="attention_heads",
            shape=(self.num_heads, self.output_dim, self.output_dim),
            initializer="glorot_uniform", trainable=True)

    def call(self, node_inputs, edge_inputs):
        """Return (..., num_nodes, output_dim) node representations.

        Args:
            node_inputs: (..., num_nodes, node_dim) node features.
            edge_inputs: (..., num_edges, edge_dim) edge features.
        """
        # tensordot(axes=1) contracts only the last axis, so leading dims pass through.
        node_features = tf.tensordot(node_inputs, self.node_weights, axes=1)
        edge_features = tf.tensordot(edge_inputs, self.edge_weights, axes=1)
        head_outputs = []
        for head in range(self.num_heads):
            queries = tf.tensordot(node_features, self.attention_heads[head], axes=1)
            # (num_nodes, num_edges) scores, normalised over the edges.
            attn_scores = tf.nn.softmax(
                tf.matmul(queries, edge_features, transpose_b=True), axis=-1)
            head_outputs.append(tf.matmul(attn_scores, edge_features))
        # Average across heads, so every head contributes rather than only the last.
        return tf.reduce_mean(tf.stack(head_outputs, axis=0), axis=0)


class EUGATModel(tf.keras.Model):
    """Thin Keras model wrapping one EUGATLayer.

    call() takes a (node_inputs, edge_inputs) pair and returns the layer output.
    """

    def __init__(self, output_dim, num_heads=4, edge_dim=3):
        super().__init__()
        self.gat_layer = EUGATLayer(output_dim, num_heads, edge_dim=edge_dim)

    def call(self, inputs):
        node_inputs, edge_inputs = inputs
        return self.gat_layer(node_inputs, edge_inputs)


OUTPUT_DIM = 128
EPOCHS = 10

# Named `eugat_model`, not `model`: rebinding `model` replaced the Legal-BERT
# encoder that get_embedding uses in the retrieval cell.
eugat_model = EUGATModel(output_dim=OUTPUT_DIM, edge_dim=int(edge_inputs.shape[-1]))

# The previous compile()/fit() call had no loss and no targets, which Keras
# rejects. fit() would also try to batch node_inputs and edge_inputs along
# axis 0 even though their lengths differ.
# Placeholder self-supervised objective: reconstruct the input node
# embeddings from the EUGAT output through a linear decoder.
# TODO: replace with a task-specific objective.
reconstruction_head = tf.keras.layers.Dense(int(node_inputs.shape[-1]))
optimizer = tf.keras.optimizers.Adam()
for epoch in range(EPOCHS):
    with tf.GradientTape() as tape:
        node_repr = eugat_model([node_inputs, edge_inputs])
        reconstruction = reconstruction_head(node_repr)
        loss = tf.reduce_mean(tf.square(reconstruction - node_inputs))
    trainable = eugat_model.trainable_variables + reconstruction_head.trainable_variables
    gradients = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(gradients, trainable))
    print(f"epoch {epoch + 1}/{EPOCHS}: reconstruction loss = {float(loss):.6f}")

In [None]:
import numpy as np


def cosine_similarity(a, b, eps=1e-12):
    """Cosine similarity along the last axis of `a` and `b` (numpy broadcasting).

    Zero vectors score 0 instead of producing NaN.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    denominator = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return np.sum(a * b, axis=-1) / np.maximum(denominator, eps)


TOP_K = 5

# TODO: replace this placeholder with the text of the case to search for.
# `new_case_text` was never defined earlier in the notebook.
new_case_text = docs[0]
query_embedding = get_embedding(new_case_text)
query_vector = np.asarray(query_embedding).reshape(-1)

# node_inputs was built by iterating graph.nodes, so row i is node_ids[i].
node_ids = list(graph.nodes)
node_matrix = np.asarray(node_inputs).reshape(len(node_ids), -1)
similarities = cosine_similarity(query_vector, node_matrix)

# The ranked indices are graph-node positions, not document positions:
# docs[i] is unrelated to node i, so report the matching nodes themselves.
most_similar_nodes = np.argsort(similarities)[::-1][:TOP_K]
for rank, node_index in enumerate(most_similar_nodes, start=1):
    node_id = node_ids[node_index]
    node_type = graph.nodes[node_id]['type']
    print(f"{rank}. [{node_type}] {node_id} (cosine similarity {similarities[node_index]:.3f})")
    print("\n---\n")