# In [None]:
import json
import re
import gc
import nltk
import jsonlines
from tqdm import tqdm
import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
from bs4 import BeautifulSoup

from nltk.tokenize import sent_tokenize
from eyecite import get_citations

# Fetch the NLTK resources at import time; presumably these back the
# `sent_tokenize` call used for the citation-density metric.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)

def load_data(file_path):
    """Load opinion records from a JSON Lines or JSON-array file.

    The file is first read as JSON Lines. If a line is not valid JSON, or if
    the parsed records are not all objects (which is what happens when a
    compact JSON array sits on a single line), the whole file is re-read as
    one JSON document that must be a list of objects.

    Args:
        file_path: Path of the input file.

    Returns:
        A DataFrame with one row per record. The ``plain_text`` column is
        converted to stripped strings; missing values become ``""``.

    Raises:
        ValueError: If the content is not a list of JSON objects, or the
            records have no ``plain_text`` field. (``json.JSONDecodeError``
            raised by the fallback parse is a ``ValueError`` subclass.)
    """
    try:
        with jsonlines.open(file_path) as reader:
            data = list(reader)
    except jsonlines.InvalidLineError:
        data = None

    # A single-line JSON array is a *valid* JSON Lines record (one list), so
    # InvalidLineError alone does not detect it; check the record types too.
    if data is None or not all(isinstance(item, dict) for item in data):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, list) or not all(isinstance(item, dict) for item in data):
            raise ValueError("Input file is not in a valid JSON or JSON Lines format.")

    df = pd.DataFrame(data)
    if 'plain_text' not in df.columns:
        raise ValueError("Column 'plain_text' not found in the data")
    # Fill missing values before casting: astype(str) alone would turn
    # None/NaN into the literal strings "None"/"nan".
    df["plain_text"] = df["plain_text"].fillna("").astype(str).str.strip()
    return df

def fast_clean_text(text):
    """Return the text content of *text* with markup removed.

    The input is parsed with BeautifulSoup using the ``lxml`` parser; the
    extracted text pieces are stripped and joined with a single space.
    """
    parsed = BeautifulSoup(text, "lxml")
    visible_text = parsed.get_text(separator=" ", strip=True)
    return visible_text

def extract_citations(text):
    """Group the citations found in *text* by category.

    Citations returned by ``get_citations`` are bucketed by their class name:
    ``FullLawCitation`` -> ``"STATUTE"``, ``FullCaseCitation`` ->
    ``"CASE_LAW"``, ``IdCitation``/``SupraCitation`` are skipped, and every
    other type goes to ``"OTHER"``. Rule references are found separately with
    a regular expression and stored as strings under ``"RULES"``.

    Returns:
        A dict mapping category name to a list of entries; categories with no
        entries are absent.
    """
    bucket_for_type = {
        "FullLawCitation": "STATUTE",
        "FullCaseCitation": "CASE_LAW",
    }
    skipped_types = ("IdCitation", "SupraCitation")

    grouped = {}
    for citation in get_citations(text):
        type_name = type(citation).__name__
        if type_name in skipped_types:
            continue
        bucket = bucket_for_type.get(type_name, "OTHER")
        if bucket not in grouped:
            grouped[bucket] = []
        grouped[bucket].append(citation)

    # Each match is a tuple of capture groups (rule number, up to two
    # parenthesised subparts, optional Federal Rules body); empty groups are
    # dropped when the label is assembled.
    rules_pattern = r"\bRule(?:s)?\s+(\d+)(?:\((\w+)\))?(?:\((\w+)\))?(?:\s+of\s+(?:the\s+)?Federal\s+Rules(?:\s+of\s+(Civil|Criminal|Appellate|Evidence))?)?"
    rule_labels = []
    for captured in re.findall(rules_pattern, text, re.IGNORECASE):
        rule_labels.append("Rule " + " ".join(filter(None, captured)))
    if rule_labels:
        grouped.setdefault("RULES", []).extend(rule_labels)

    return grouped

def get_statute_codes(citation_results):
    """Build a readable code string for each statute citation.

    For a citation whose ``groups`` mapping has non-empty ``title``,
    ``reporter`` and ``section`` values the code is
    ``"<title> <reporter> § <section>"``. Otherwise the citation's own text
    is used: ``matched_text()`` when the object provides it, else
    ``str(cit)`` truncated at the first ``"("``.

    Args:
        citation_results: Iterable of citation objects (or plain strings).

    Returns:
        A list of code strings in input order. Empty codes, and codes that
        contain the class name ``FullLawCitation`` (i.e. an object repr
        rather than citation text), are omitted.
    """
    codes = []
    for cit in citation_results:
        code = ""

        groups = getattr(cit, "groups", None)
        if hasattr(groups, "get"):
            # `or ""` guards against keys that are present but None, which
            # previously surfaced as an AttributeError on .strip().
            title, reporter, section = (
                str(groups.get(key) or "").strip()
                for key in ("title", "reporter", "section")
            )
            if title and reporter and section:
                code = f"{title} {reporter} § {section}"

        if not code:
            # Prefer the citation's matched source text: str() of a citation
            # object is its repr, whose text before "(" is just the class name.
            matched_text = getattr(cit, "matched_text", None)
            if callable(matched_text):
                code = str(matched_text()).strip()
            else:
                code = str(cit).split("(")[0].strip()

        if code and "FullLawCitation" not in code:
            codes.append(code)
    return codes

def get_act_names_from_text(text):
    """Return the distinct "... Act" phrases found in *text*, sorted.

    A phrase starts at an uppercase letter, continues over word characters
    and whitespace, and ends with the word ``Act``. The result is sorted so
    that the output is reproducible; iterating a plain ``set`` of strings
    can yield a different order on every interpreter run.
    """
    act_pattern = r"([A-Z][\w\s]+ Act)"
    matches = re.findall(act_pattern, text)
    unique_names = {match.strip() for match in matches if match.strip()}
    return sorted(unique_names)

def get_case_law_names_from_text(text):
    """Return the distinct "X v. Y" case names found in *text*, sorted.

    Each side of ``v.`` is a single capitalised token (letters, digits,
    commas, hyphens and periods). The result is sorted so that the output is
    reproducible; iterating a plain ``set`` of strings can yield a different
    order on every interpreter run.
    """
    pattern = r'\b([A-Z][A-Za-z0-9,\-\.]+\s+v\.\s+[A-Z][A-Za-z0-9,\-\.]+)\b'
    matches = re.findall(pattern, text)
    return sorted({match.strip() for match in matches})

def get_rule_citations(citation_results):
    """Return each entry of *citation_results* as a whitespace-stripped string."""
    cleaned = []
    for entry in citation_results:
        cleaned.append(str(entry).strip())
    return cleaned

def calculate_citation_density(text, citations):
    """Return the number of citations per sentence of *text*.

    Args:
        text: Text to split into sentences with ``sent_tokenize``.
        citations: Mapping of category name to a sized collection of
            citations; all collections are counted.

    Returns:
        Total citation count divided by the sentence count, or ``0.0`` when
        no sentences are found.
    """
    sentence_list = sent_tokenize(text)
    citation_total = 0
    for group in citations.values():
        citation_total += len(group)
    if not sentence_list:
        return 0.0
    return citation_total / len(sentence_list)

def calculate_lexical_complexity(text):
    """Return the type-token ratio of the purely alphabetic tokens in *text*.

    Tokens come from whitespace splitting; any token containing a
    non-alphabetic character (including attached punctuation) is ignored.
    Comparison is case-insensitive. Returns ``0.0`` when no token qualifies.
    """
    tokens = []
    for raw_token in text.split():
        if raw_token.isalpha():
            tokens.append(raw_token.lower())
    if not tokens:
        return 0.0
    distinct_count = len(set(tokens))
    return distinct_count / len(tokens)

# Hedging words and phrases counted by calculate_hedging_frequency.
_HEDGE_TERMS = (
    "might", "could", "may", "seem", "appear", "suggest", "possibly", "probably", "likely",
    "perhaps", "almost", "somewhat", "relatively", "tend to", "in some cases", "to some extent",
    "it is possible", "it appears", "it seems", "suggests that", "indicates that", "would", "should",
)

# One alternation over all terms, longest first, anchored on word boundaries.
# Inner spaces match any whitespace run so phrases broken across lines count.
_HEDGE_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(
        r"\s+".join(re.escape(part) for part in term.split())
        for term in sorted(_HEDGE_TERMS, key=len, reverse=True)
    )
    + r")\b"
)


def calculate_hedging_frequency(text):
    """Return hedging expressions per whitespace-separated token of *text*.

    Hedges are matched case-insensitively as whole words or phrases anywhere
    in the text. Matching on the text rather than on individual tokens is
    what lets multi-word hedges ("tend to", "it seems") and hedges followed
    by punctuation ("may,") be counted; a per-token set lookup can find
    neither.

    Returns:
        Number of hedge matches divided by the token count, or ``0.0`` for
        text with no tokens.
    """
    words = text.split()
    if not words:
        return 0.0
    hedge_count = len(_HEDGE_PATTERN.findall(text.lower()))
    return hedge_count / len(words)

def count_ultra_vires(text):
    """Count case-insensitive occurrences of the phrase "ultra vires" in *text*."""
    phrase = "ultra vires"
    lowered = text.lower()
    return lowered.count(phrase)

def count_chevron_deference(text):
    """Count case-insensitive mentions of Chevron-related phrases in *text*.

    The phrases counted are "chevron deference", "chevron doctrine" and
    "chevron step"; the result is the sum of their occurrence counts.
    """
    lowered = text.lower()
    total = 0
    for phrase in ("chevron deference", "chevron doctrine", "chevron step"):
        total += lowered.count(phrase)
    return total

def _citation_text(cit):
    """Return the source text of a citation entry as a stripped string.

    Uses ``matched_text()`` when the object provides it; otherwise falls back
    to ``str(cit)`` truncated at the first ``"("``. For a citation object,
    ``str()`` is its repr, so the fallback alone would yield only the class
    name.
    """
    matched_text = getattr(cit, "matched_text", None)
    if callable(matched_text):
        return str(matched_text()).strip()
    return str(cit).split("(")[0].strip()


def process_row(row):
    """Compute citation and language features for one opinion record.

    Args:
        row: Mapping with at least ``"plain_text"`` (the opinion text, which
            may contain markup) and ``"id"``.

    Returns:
        A dict of features for the record, or ``None`` when the text is
        missing, not a string, blank, or cannot be cleaned.

    Raises:
        KeyError: If ``row`` lacks ``"plain_text"`` or ``"id"``.
    """
    opinion_text = row["plain_text"]
    # Non-string values (e.g. NaN for a missing field) are treated as empty
    # rather than crashing on a missing string method.
    if not isinstance(opinion_text, str) or not opinion_text.strip():
        return None
    try:
        cleaned_text = fast_clean_text(opinion_text)
    except Exception:
        # Deliberate best-effort: a record whose markup cannot be parsed is
        # skipped so one bad row does not abort the whole batch.
        return None

    citations = extract_citations(cleaned_text)
    statute_citations = citations.get("STATUTE", [])
    case_citations = citations.get("CASE_LAW", [])
    rule_citations = citations.get("RULES", [])

    # Raw columns keep every occurrence, in document order.
    raw_statute = "; ".join(_citation_text(cit) for cit in statute_citations)
    raw_case_law = "; ".join(_citation_text(cit) for cit in case_citations)
    raw_rules = "; ".join(str(cit).strip() for cit in rule_citations)

    # The *_codes columns feed the "unique_*_count" columns, so repeated
    # citations are collapsed; dict.fromkeys keeps first-seen order.
    statute_codes = list(dict.fromkeys(get_statute_codes(statute_citations)))
    act_names = get_act_names_from_text(cleaned_text)
    case_law_names = get_case_law_names_from_text(cleaned_text)
    rule_codes = list(dict.fromkeys(get_rule_citations(rule_citations)))

    citation_counts = {
        "STATUTE": len(statute_codes),
        "ACT": len(act_names),
        "CASE_LAW": len(case_law_names),
        "RULES": len(rule_codes)
    }

    return {
        "id": row["id"],
        "unique_STATUTE_count": citation_counts["STATUTE"],
        "STATUTE_codes": "; ".join(statute_codes),
        "unique_ACT_count": citation_counts["ACT"],
        "ACT_names": "; ".join(act_names),
        "unique_CASE_LAW_count": citation_counts["CASE_LAW"],
        "CASE_LAW_names": "; ".join(case_law_names),
        "unique_RULES_count": citation_counts["RULES"],
        "RULES_codes": "; ".join(rule_codes),
        "raw_STATUTE": raw_statute,
        "raw_CASE_LAW": raw_case_law,
        "raw_RULES": raw_rules,
        "citation_density": calculate_citation_density(cleaned_text, citations),
        # Number of categories (out of the four above) with at least one hit.
        "citation_diversity": sum(1 for count in citation_counts.values() if count > 0),
        "lexical_complexity": calculate_lexical_complexity(cleaned_text),
        "hedging_frequency": calculate_hedging_frequency(cleaned_text),
        "ultra_vires_count": count_ultra_vires(cleaned_text),
        "chevron_deference_count": count_chevron_deference(cleaned_text),
    }

def process_all_rows(df):
    """Run ``process_row`` over every row of *df* in a process pool.

    Args:
        df: DataFrame whose rows are passed to ``process_row`` as dicts.

    Returns:
        A DataFrame built from the non-``None`` results, in input order.
    """
    # Leave one core free, but never ask for fewer than one worker:
    # Pool(processes=0) raises ValueError on a single-core machine.
    worker_count = max(1, cpu_count() - 1)
    with Pool(processes=worker_count) as pool:
        results = list(tqdm(pool.imap(process_row, df.to_dict('records'), chunksize=10), total=len(df)))
    return pd.DataFrame([r for r in results if r is not None])

def main(file_path="/content/drive/MyDrive/MUN/chunk_1.json", output_csv_path=None):
    """Load the input file, compute features for every row, and save a CSV.

    Args:
        file_path: Input JSON / JSON Lines file.
        output_csv_path: Destination CSV. When ``None``, it is derived from
            ``file_path`` by replacing the file extension with
            ``_output.csv``.
    """
    # Local import so this function does not depend on edits to the
    # file-level import block.
    import os

    df = load_data(file_path)
    df_results = process_all_rows(df)

    if output_csv_path is None:
        # splitext only strips an extension from the final path component;
        # rsplit(".", 1) would cut at a dot inside a directory name when the
        # file itself has no extension.
        base_name, _ = os.path.splitext(file_path)
        output_csv_path = f"{base_name}_output.csv"

    df_results.to_csv(output_csv_path, index=False)
    print("Processed results (first 5 rows):")
    print(df_results.head())
    print(f"Saved processed results to {output_csv_path}")

# Run the pipeline with its default arguments only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()