In [None]:
import sys
sys.path.append('../code')
import os
from pathlib import Path
import json
import random
import numpy as np
import spacy


from tqdm import tqdm
import spacy
import json
import random
import re
import pandas as pd
import numpy as np
from copy import deepcopy
from sklearn import model_selection
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
from spacy.symbols import ORTH
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import classification_report, confusion_matrix
from sklearn import tree
import matplotlib.pyplot as plt
from spacy.language import Language
from luima_sbd import sbd_utils as luima

# Phase 2 - Deciding on a Sentence Segmenter

### Defining some functions

In [None]:
# Headers that spaCy yields as one token. set_custom_boundaries marks each of
# them, and the first non-whitespace token after it, as a sentence start.
single_word_headers = [
    "REPRESENTATION",
    "____________________________________________",
    "ORDER",
    "INTRODUCTION",
]


# Multi-word section headings found in the decisions.
# NOTE(review): this list is not referenced by the code shown here;
# set_custom_boundaries matches these headings with hard-coded token checks.
other_headers = [
    "THE ISSUE",
    "WITNESS AT HEARING ON APPEAL",
    "ATTORNEY FOR THE BOARD",
    "FINDINGS OF FACT",
    "CONCLUSION OF LAW",
    "REASONS AND BASES FOR FINDING AND CONCLUSION",
]

In [None]:
# get all sentences assuming every annotation is a sentence
def make_span_data(documents_by_id, types_by_id, annotations):
    """Flatten raw annotations into span records, one per annotation.

    Each record has the annotated text ('txt'), the document id ('document'),
    the annotation type's name ('type'), the character offsets ('start',
    'end') and 'start' divided by the document length ('start_normalized').

    Args:
        documents_by_id: maps document id -> document dict with 'plainText'.
        types_by_id: maps type id -> type dict with 'name'.
        annotations: iterable of dicts with 'start', 'end', 'document', 'type'.

    Returns:
        list of span record dicts, in annotation order.
    """
    def _span_record(ann):
        begin, finish = ann['start'], ann['end']
        doc_text = documents_by_id[ann['document']]['plainText']
        return {
            'txt': doc_text[begin:finish],
            'document': ann['document'],
            'type': types_by_id[ann['type']]['name'],
            'start': begin,
            'start_normalized': begin / len(doc_text),
            'end': finish,
        }

    return [_span_record(ann) for ann in annotations]

In [None]:
def prepare_ann_span_by_doc(spans, doc_id):
    """Return the start and end offsets of every span that belongs to ``doc_id``.

    Args:
        spans: iterable of span dicts with 'document', 'start' and 'end'.
        doc_id: document id to select.

    Returns:
        (starts, ends): two lists in the order the spans appear in ``spans``.
    """
    own_spans = [span for span in spans if span['document'] == doc_id]
    return [span['start'] for span in own_spans], [span['end'] for span in own_spans]

In [None]:
def generate_ann_span_by_doc_with_spacy(train_doc_ids, nlp):
    """Segment each document with a spaCy pipeline and collect sentence offsets.

    Args:
        train_doc_ids: iterable of document ids. Their text is read from the
            notebook-level ``documents_by_id`` mapping ('plainText').
        nlp: a loaded spaCy pipeline whose output exposes ``.sents``.

    Returns:
        {doc_id: {'start': [...], 'end': [...]}} with the character offsets of
        every sentence spaCy produced.
    """
    spans_by_doc = {}
    for doc_key in tqdm(train_doc_ids, disable=True):
        sentences = list(nlp(documents_by_id[doc_key]['plainText']).sents)
        spans_by_doc[doc_key] = {
            'start': [sentence.start_char for sentence in sentences],
            'end': [sentence.end_char for sentence in sentences],
        }
    return spans_by_doc

In [None]:
def generate_ann_span_by_doc_with_luima(train_doc_ids, nlp=None):
    """Segment each document with LUIMA and collect sentence offsets.

    Args:
        train_doc_ids: iterable of document ids. Their text is read from the
            notebook-level ``documents_by_id`` mapping ('plainText').
        nlp: unused. Accepted so this function can be called with the same
            arguments as ``generate_ann_span_by_doc_with_spacy``.

    Returns:
        {doc_id: {'start': [...], 'end': [...]}} with sentence offsets expressed
        in the coordinates of the *unstripped* 'plainText'. These are the same
        coordinates as the gold annotation offsets.
    """
    gen_ann_span_by_doc = {}
    for train_id in tqdm(train_doc_ids, disable=True):
        raw_text = documents_by_id[train_id]['plainText']
        text = raw_text.strip()
        # LUIMA is given the stripped text, so its offsets index into `text`.
        # Add back the leading whitespace that strip() removed, so predicted
        # starts are comparable with gold offsets into `raw_text`.
        shift = len(raw_text) - len(raw_text.lstrip())
        indices = luima.text2sentences(text, offsets=True)

        gen_ann_span_by_doc[train_id] = {
            'start': [ind[0] + shift for ind in indices],
            'end': [ind[1] + shift for ind in indices],
        }
    return gen_ann_span_by_doc

In [None]:
def find_closest_start_point(train_doc_ids, true_ann_span_by_doc, gen_ann_span_by_doc):
    """Pair every gold span with the generated span whose start offset is nearest.

    Args:
        train_doc_ids: document ids to process.
        true_ann_span_by_doc: {doc_id: {'start': [...], 'end': [...]}} gold offsets.
        gen_ann_span_by_doc: same layout, holding the segmenter's output.

    Returns:
        {doc_id: {'true': [{'start', 'end'}, ...], 'pred': [{'start', 'end'}, ...]}}
        where 'pred'[k] is the generated span nearest to 'true'[k]. When two
        generated spans are equally near, the one listed first wins. There is
        no maximum distance.

    Raises:
        ValueError: if a document has gold spans but no generated spans.
    """
    closest_by_id = {}
    for train_id in train_doc_ids:
        true_doc = true_ann_span_by_doc[train_id]
        true_spans = [{'start': s, 'end': e} for s, e in zip(true_doc['start'], true_doc['end'])]

        closest_neighbors = []
        if true_spans:
            gen_doc = gen_ann_span_by_doc[train_id]
            gen_spans = [{'start': s, 'end': e} for s, e in zip(gen_doc['start'], gen_doc['end'])]
            if not gen_spans:
                # Nothing to match against. Fail loudly rather than reusing a
                # neighbour left over from another document.
                raise ValueError(
                    f"No generated spans for document {train_id!r}; "
                    f"cannot match its {len(true_spans)} true spans."
                )
            for true_span in true_spans:
                target = true_span['start']
                # min() returns the first of several equally near candidates.
                nearest = min(gen_spans, key=lambda g: abs(target - g['start']))
                # Copy, so gold spans that share a neighbour do not share one dict.
                closest_neighbors.append(dict(nearest))

        closest_by_id[train_id] = {'true': true_spans, 'pred': closest_neighbors}
    return closest_by_id

In [None]:
def calculate_error_metrics(train_doc_ids, true_ann_span_by_doc, gen_ann_span_by_doc, closest_by_id,
                            tolerance=3):
    """Compute boundary-level TP/FP/FN and precision/recall/F1 over documents.

    A gold sentence start counts as found when the generated start paired with
    it (see ``find_closest_start_point``) lies within ``tolerance`` characters.
    Each generated start can be credited at most once, so several gold starts
    that collapse onto the same prediction do not inflate TP. Without this,
    TP could exceed the number of predictions and FP could go negative.

    Args:
        train_doc_ids: document ids to score.
        true_ann_span_by_doc: {doc_id: {'start': [...], 'end': [...]}} gold offsets.
        gen_ann_span_by_doc: same layout for the generated offsets.
        closest_by_id: output of ``find_closest_start_point`` for these documents.
        tolerance: maximum start-offset distance (in characters) for a match.

    Returns:
        (tot_true_splits, tot_gen_splits, TP, FP, FN, precision, recall, f1_score).
        Undefined ratios (no predictions, no gold spans, or no true positives)
        are reported as 0.0 instead of NaN, so the worst documents still sort
        first.
    """
    def _ratio(numerator, denominator):
        return numerator / denominator if denominator else 0.0

    TP = FP = FN = 0
    tot_true_splits = tot_gen_splits = 0
    for train_id in train_doc_ids:
        true_split_len = len(true_ann_span_by_doc[train_id]['start'])
        gen_split_len = len(gen_ann_span_by_doc[train_id]['start'])

        tot_true_splits += true_split_len
        tot_gen_splits += gen_split_len

        pairs = zip(closest_by_id[train_id]['true'], closest_by_id[train_id]['pred'])
        # Distinct predicted starts that matched some gold start.
        matched_pred_starts = {
            pred_span['start']
            for true_span, pred_span in pairs
            if abs(true_span['start'] - pred_span['start']) <= tolerance
        }
        tp_doc = len(matched_pred_starts)

        TP += tp_doc
        FP += gen_split_len - tp_doc
        FN += true_split_len - tp_doc

    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    f1_score = _ratio(2 * (precision * recall), precision + recall)
    return tot_true_splits, tot_gen_splits, TP, FP, FN, precision, recall, f1_score

In [None]:
def analyze_segmenter(train_doc_ids, segmenter, nlp=None):
    """Score a sentence segmenter document by document against the gold spans.

    Args:
        train_doc_ids: iterable of document ids to evaluate.
        segmenter: 'luima' or 'spacy'.
        nlp: spaCy pipeline; required for 'spacy', ignored for 'luima'.

    Returns:
        list with one dict per document: 'doc_id', 'true_split_count',
        'gen_split_count', 'tp', 'fp', 'fn', plus 'precision', 'recall' and
        'f1_score' rounded to two decimals.

    Raises:
        ValueError: for an unknown segmenter name, or 'spacy' without ``nlp``.

    Gold spans are read from the notebook-level ``true_ann_span_by_doc``.
    """
    # Validate up front; otherwise a bad name only surfaces later as an unbound
    # ann_span_generator, and nlp=None as a "NoneType is not callable" deep inside.
    if segmenter == 'luima':
        ann_span_generator = generate_ann_span_by_doc_with_luima
    elif segmenter == 'spacy':
        if nlp is None:
            raise ValueError("segmenter='spacy' requires a spaCy pipeline passed as nlp")
        ann_span_generator = generate_ann_span_by_doc_with_spacy
    else:
        raise ValueError(f"Unknown segmenter {segmenter!r}; expected 'luima' or 'spacy'")

    error_metrics = []
    for doc_id in tqdm(train_doc_ids):
        gen_ann_span = ann_span_generator([doc_id], nlp)
        closest_by_id = find_closest_start_point(
            [doc_id],
            true_ann_span_by_doc,
            gen_ann_span
        )

        total_true_splits, total_gen_splits, tp, fp, fn, precision, recall, f1_score = calculate_error_metrics(
            [doc_id],
            true_ann_span_by_doc,
            gen_ann_span,
            closest_by_id
        )

        error_metrics.append({
            'doc_id': doc_id,
            'true_split_count': total_true_splits,
            'gen_split_count': total_gen_splits,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'precision': round(precision, 2),
            'recall': round(recall, 2),
            'f1_score': round(f1_score, 2)
        })
    return error_metrics

In [None]:
def print_error_metrics(error_metrics, segmenter):
    """Print micro-averaged precision, recall and F1 over per-document results.

    Per-document counts are summed first, then the ratios are computed once.

    Args:
        error_metrics: list of dicts with 'tp', 'fp' and 'fn' counts, as
            produced by ``analyze_segmenter``.
        segmenter: label used in the printed heading.

    Undefined ratios (empty input, no predictions, or no true positives) are
    printed as 0.00 instead of raising ZeroDivisionError.
    """
    TP = sum(em['tp'] for em in error_metrics)
    FP = sum(em['fp'] for em in error_metrics)
    FN = sum(em['fn'] for em in error_metrics)

    precision = TP / (TP + FP) if TP + FP else 0.0
    recall = TP / (TP + FN) if TP + FN else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0

    print(f"Error metrics using the {segmenter} segmenter:")
    print(f'Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1_score: {f1_score:.2f}')

##### Extending spaCy's standard segmenter

In [None]:
# Extend spaCy's standard segmenter with rules for the decisions' layout.
@Language.component("set_custom_boundaries")
def set_custom_boundaries(doc):
    """Pre-set ``is_sent_start`` on tokens according to document-layout rules.

    Each token is checked against the rules below in order; the first match wins.
    The scan visits every token, including tokens inside a header that was just
    matched.

    * "'s" / "’s", and tokens exactly "\\n", "\\t", "\\r", "DC.", "Archive" or
      "NO.", never start a sentence.
    * "DOCKET" followed by text "NO.": it and the next three tokens never start one.
    * A token in ``single_word_headers``: it and the first non-whitespace token
      after it start sentences.
    * "THE" "ISSUE": "THE" and the first non-whitespace token after "ISSUE"
      start sentences.
    * WITNESS AT HEARING ON APPEAL, ATTORNEY FOR THE BOARD, FINDING(S) OF FACT,
      CONCLUSION OF LAW, REASONS AND BASES FOR FINDING(S) AND CONCLUSION:
      - the first word starts a sentence;
      - the rest of the header, plus the token right after it, do not;
      - the first non-whitespace token from that following token onward does.
    * "on appeal from the" (first word case-insensitive): "on" starts a
      sentence; the next four tokens do not.

    Rules that look past the end of the doc are clipped to the doc instead of
    raising IndexError.

    Returns:
        The same ``doc``, with ``is_sent_start`` set on the matched tokens.
    """
    n_tokens = len(doc)

    def _clear_starts(first, last):
        # Tokens first..last (inclusive) may not start a sentence; indices past
        # the end of the doc are skipped.
        for k in range(first, min(last, n_tokens - 1) + 1):
            doc[k].is_sent_start = False

    def _start_at_next_non_space(j):
        # Mark the first non-whitespace token at or after j as a sentence start;
        # no-op when only whitespace (or nothing) remains.
        while j < n_tokens and doc[j].text.isspace():
            j += 1
        if j < n_tokens:
            doc[j].is_sent_start = True

    for i in range(n_tokens):
        text = doc[i].text
        next_text = doc[i + 1].text if i + 1 < n_tokens else None

        if text in ("’s", "'s"):
            doc[i].is_sent_start = False
        elif text in ("\n", "\t", "\r", "DC.", "Archive", "NO."):
            doc[i].is_sent_start = False
        elif text == "DOCKET" and doc[i+1:i+3].text == "NO.":
            doc[i].is_sent_start = False
            _clear_starts(i + 1, i + 3)
        elif text in single_word_headers:
            doc[i].is_sent_start = True
            _start_at_next_non_space(i + 1)
        elif text == "THE" and next_text == "ISSUE":
            doc[i].is_sent_start = True
            _start_at_next_non_space(i + 2)
        elif text == "WITNESS" and doc[i+1:i+5].text == "AT HEARING ON APPEAL":
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 5)
            _start_at_next_non_space(i + 5)
        elif text == "ATTORNEY" and doc[i+1:i+4].text == "FOR THE BOARD":
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 4)
            _start_at_next_non_space(i + 4)
        elif text in ("FINDINGS", "FINDING") and doc[i+1:i+3].text == "OF FACT":
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 3)
            _start_at_next_non_space(i + 3)
        elif text == "CONCLUSION" and doc[i+1:i+3].text == "OF LAW":
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 3)
            _start_at_next_non_space(i + 3)
        elif text == "REASONS" and doc[i+1:i+7].text in (
            "AND BASES FOR FINDING AND CONCLUSION",
            "AND BASES FOR FINDINGS AND CONCLUSION",
        ):
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 7)
            _start_at_next_non_space(i + 7)
        elif text.lower() == "on" and doc[i+1:i+4].text == "appeal from the":
            doc[i].is_sent_start = True
            _clear_starts(i + 1, i + 4)

    return doc

### Load Data 

In [None]:
CURATED_ANN_PATH = "../Data/ldsi_w21_curated_annotations_v2.json"
# Read explicitly as UTF-8: the decisions contain non-ASCII characters (e.g. "’")
# that the platform default encoding may not decode correctly.
with open(CURATED_ANN_PATH, 'r', encoding='utf-8') as ann_file:
    data = json.load(ann_file)

annotations = data['annotations']
documents_by_id = {d['_id']: d for d in data['documents']}
types_by_id = {t['_id']: t for t in data['types']}
type_ids_by_name = {t['name']: t['_id'] for t in data['types']}
type_names_by_id = {t['_id']: t['name'] for t in data['types']}
doc_id_by_name = {d['name']: d['_id'] for d in data['documents']}
doc_name_by_id = {d['_id']: d['name'] for d in data['documents']}


granted_doc_ids = {doc['_id'] for doc in data['documents'] if doc['outcome'] == 'granted'}
denied_doc_ids = {doc['_id'] for doc in data['documents'] if doc['outcome'] == 'denied'}
print(len(granted_doc_ids), len(denied_doc_ids))

# Keep only documents that carry at least one annotation
# (the author reports 141 annotated documents out of 540).
ids_annotated_docs = {ann['document'] for ann in data['annotations']}
print(len(ids_annotated_docs))

# Sorting makes the order independent of set iteration order before shuffling.
granted_ids = sorted(granted_doc_ids & ids_annotated_docs)
denied_ids = sorted(denied_doc_ids & ids_annotated_docs)
print(len(granted_ids), len(denied_ids))

# Seed the RNGs used for the split. Setting PYTHONHASHSEED at runtime only
# affects child processes, not this interpreter; the ids are sorted above, so
# the split does not depend on hash order anyway.
os.environ['PYTHONHASHSEED'] = str(42)
random.seed(42)
np.random.seed(42)

random.shuffle(granted_ids)
random.shuffle(denied_ids)
# 80% / 10% / 10% split of the granted decisions.
granted_train, granted_val, granted_test = np.split(
    granted_ids, [int(len(granted_ids) * 0.8), int(len(granted_ids) * 0.9)]
)
# NOTE(review): the denied split uses hard-coded cut points (57, 64) rather than
# the 0.8 / 0.9 fractions used above; confirm these are the intended sizes.
denied_train, denied_val, denied_test = np.split(denied_ids, [57, 64])

train_set = np.concatenate((granted_train, denied_train), axis=0)
dev_set = np.concatenate((granted_val, denied_val), axis=0)
test_set = np.concatenate((granted_test, denied_test), axis=0)

print(train_set.shape, dev_set.shape, test_set.shape)

In [None]:
# # loading the training doc ids
# train_doc_ids = np.load('../Data/train.npy')
# train_doc_ids.shape

In [None]:
# One flat record per annotation (text, document, type name, offsets).
spans = make_span_data(documents_by_id=documents_by_id, types_by_id=types_by_id, annotations=annotations)
span_labels = [record['type'] for record in spans]

In [None]:
# Gold (annotated) sentence offsets for every training document:
# {doc_id: {'start': [...], 'end': [...]}}
train_doc_ids = train_set
true_ann_span_by_doc = {
    doc_key: dict(zip(('start', 'end'), prepare_ann_span_by_doc(spans, doc_key)))
    for doc_key in train_doc_ids
}

len(true_ann_span_by_doc)

## Step 2.1: Standard Spacy segmenter

In [None]:
# Evaluate the stock spaCy sentence segmenter on every training document.
nlp = spacy.load("en_core_web_sm")

error_maetrics_spacy_std = analyze_segmenter(train_doc_ids, segmenter='spacy', nlp=nlp)
print_error_metrics(error_maetrics_spacy_std, segmenter='standard spacy')

In [None]:
# The three documents with the lowest precision under the standard segmenter.
# (Despite its name, low_f1_doc_ids_df is ranked by precision, not F1.)
df_std = pd.DataFrame(error_maetrics_spacy_std)
low_f1_doc_ids_df = df_std.sort_values('precision').head(3)

low_scoring_docs = [
    {
        'doc_id': record.doc_id,
        'precision': record.precision,
        'recall': record.recall,
        'f1_score': record.f1_score,
    }
    for _, record in low_f1_doc_ids_df.iterrows()
]
low_scoring_docs

In [None]:
# Show all per-document metrics for the three lowest-precision documents.
low_f1_doc_ids_df

## Step 2.2: Improved segmenter

In [None]:
# Evaluate an extended spaCy pipeline on every training document. The custom
# boundary component runs before the parser, and two citation abbreviations
# get tokenizer special cases (intended to keep each as one token).
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("set_custom_boundaries", before="parser")
for abbreviation in ('Vet. App.', 'Fed. Cir.'):
    nlp.tokenizer.add_special_case(abbreviation, [{ORTH: abbreviation}])

error_maetrics_spacy_ext = analyze_segmenter(train_doc_ids, 'spacy', nlp=nlp)
print_error_metrics(error_maetrics_spacy_ext, segmenter='extended spacy')

In [None]:
# Per-document metrics for the extended segmenter; show the five lowest-F1 documents.
df_ext = pd.DataFrame(error_maetrics_spacy_ext)
df_ext.sort_values(by='f1_score').head(5)

In [None]:
# Did the extended spaCy segmenter improve the three documents that had the
# lowest precision under the standard segmenter?
for doc in low_scoring_docs:
    ext_row = df_ext[df_ext['doc_id'] == doc['doc_id']]
    improved_precision = ext_row['precision'].values[0]
    improved_recall = ext_row['recall'].values[0]
    improved_f1 = ext_row['f1_score'].values[0]
    print(
        f"Comparison between standard and improved segmenter for document {doc['doc_id']}:",
        f"Old Precision: {doc['precision']}, improved precision: {improved_precision}",
        f"Recall: {doc['recall']}, improved_recall: {improved_recall}",
        f"F1 score: {doc['f1_score']}, improved_f1: {improved_f1}\n\n",
        sep="\n",
    )

## Step 2.3: Law-specific sentence segmenter

In [None]:
# Evaluate the law-specific LUIMA segmenter on every training document
# (LUIMA needs no spaCy pipeline, so nlp keeps its default).
error_metrics_luima = analyze_segmenter(train_doc_ids, segmenter='luima')
print_error_metrics(error_metrics=error_metrics_luima, segmenter='luima')

In [None]:
# Per-document metrics for LUIMA; show the five lowest-F1 documents.
df_luima = pd.DataFrame(error_metrics_luima)
df_luima.sort_values(by='f1_score').head(5)

In [None]:
# Did LUIMA improve the three documents that had the lowest precision under the
# standard spaCy segmenter (low_scoring_docs)?
for doc in low_scoring_docs:
    luima_row = df_luima[df_luima['doc_id'] == doc['doc_id']]
    improved_precision = luima_row['precision'].values[0]
    improved_recall = luima_row['recall'].values[0]
    improved_f1 = luima_row['f1_score'].values[0]
    print(
        f"Comparison between standard and improved segmenter for document {doc['doc_id']}:",
        f"Old Precision: {doc['precision']}, improved precision: {improved_precision}",
        f"Recall: {doc['recall']}, improved_recall: {improved_recall}",
        f"F1 score: {doc['f1_score']}, improved_f1: {improved_f1}\n\n",
        sep="\n",
    )

## Error analysis of an individual document

In [None]:
# Re-display the lowest-precision documents (from the standard spaCy run) before picking one.
low_scoring_docs

In [None]:
# Choose a low scoring doc for analysis. Index 2 is the third entry of
# low_scoring_docs, i.e. the third-lowest precision under standard spaCy
# (assumes at least three training documents).
doc_id = low_scoring_docs[2]['doc_id']
doc_id

### Error analysis with standard Spacy

In [None]:
# Segment the chosen document with the stock spaCy pipeline.
nlp = spacy.load("en_core_web_sm")
gen_ann_span_standard = generate_ann_span_by_doc_with_spacy([doc_id], nlp)

# Pair every gold span of the document with its nearest standard-spaCy span.
closest_by_id_std = find_closest_start_point([doc_id], true_ann_span_by_doc, gen_ann_span_standard)

print_str = f"Error analysis for the document {doc_id} with standard segmenter"
print_pattern = "-" * len(print_str)
print(print_str, print_pattern, sep="\n")
_, _, _, _, _, precision, recall, f1_score = calculate_error_metrics(
    [doc_id], true_ann_span_by_doc, gen_ann_span_standard, closest_by_id_std
)

print(f'Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1_score: {f1_score:.2f}')

In [None]:
# Compare true and generated splits of the chosen document. Only gold segments
# whose nearest standard-spaCy start is more than 3 characters away are shown.
i = 1
print(f"Comparing true and generated segments for document {doc_id} with standard segmenter")
for true_span, pred_span in zip(closest_by_id_std[doc_id]['true'], closest_by_id_std[doc_id]['pred']):
    test_doc = documents_by_id[doc_id]['plainText']
    true_start, pred_start = true_span['start'], pred_span['start']
    GT = test_doc[true_start:true_span['end']]
    pred = test_doc[pred_start:pred_span['end']]

    dist = abs(true_start - pred_start)
    if dist <= 3:
        continue

    print(f"true start: {true_start}, true end: {true_span['end']} pred start: {pred_start}, distance: {abs(true_start - pred_start)}")
    print_str = GT
    print_pattern_out, print_pattern_in = "=" * 80, "-" * 80
    print(print_pattern_out, f"Segment {i}", print_pattern_in,
          "GT".center(50), GT, print_pattern_in,
          "PRED".center(50), sep="\n")
    print(pred, "\n")
    i += 1

### Error analysis with Extended Spacy

In [None]:
# Segment the chosen document with the extended spaCy pipeline
# (custom boundary component plus citation special cases).
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("set_custom_boundaries", before="parser")
for abbreviation in ('Vet. App.', 'Fed. Cir.'):
    nlp.tokenizer.add_special_case(abbreviation, [{ORTH: abbreviation}])

gen_ann_span_ext = generate_ann_span_by_doc_with_spacy([doc_id], nlp)

# Pair every gold span with its nearest extended-spaCy span.
closest_by_id_ext = find_closest_start_point([doc_id], true_ann_span_by_doc, gen_ann_span_ext)

print(f"Error analysis for the document {doc_id} with extended segmenter")
print("----------------------------------------------------------------------------------")
_, _, _, _, _, precision, recall, f1_score = calculate_error_metrics(
    [doc_id], true_ann_span_by_doc, gen_ann_span_ext, closest_by_id_ext
)

print(f'\nPrecision: {precision:.2f}\nRecall: {recall:.2f}\nF1_score: {f1_score:.2f}')

In [None]:
# Compare true and generated splits of the chosen document: every gold segment
# is printed next to its nearest extended-spaCy segment.
print(f"Comparing true and generated segments for document {doc_id} with extended segmenter")
i = 1
for true_span, pred_span in zip(closest_by_id_ext[doc_id]['true'], closest_by_id_ext[doc_id]['pred']):
    test_doc = documents_by_id[doc_id]['plainText']
    GT = test_doc[true_span['start']:true_span['end']]
    pred = test_doc[pred_span['start']:pred_span['end']]

    print_str = GT
    print_pattern_out, print_pattern_in = "=" * 80, "-" * 80
    print(print_pattern_out, f"Segment {i}", print_pattern_in,
          "GT".center(50), GT, print_pattern_in,
          "PRED".center(50), sep="\n")
    print(pred, "\n")
    i += 1

### Error analysis with LUIMA

In [None]:
# Segment the chosen document with LUIMA and pair every gold span with its
# nearest LUIMA span.
gen_ann_span_luima = generate_ann_span_by_doc_with_luima([doc_id])
closest_by_id_luima = find_closest_start_point([doc_id], true_ann_span_by_doc, gen_ann_span_luima)

print(f"Error analysis for the document {doc_id} with LUIMA segmenter")
print("----------------------------------------------------------------------------------")

_, _, _, _, _, precision, recall, f1_score = calculate_error_metrics(
    [doc_id], true_ann_span_by_doc, gen_ann_span_luima, closest_by_id_luima
)

print(f'\nPrecision: {precision:.2f}\nRecall: {recall:.2f}\nF1_score: {f1_score:.2f}')

In [None]:
# List the sentences LUIMA produces for the chosen document (doc_id), labelled
# as predictions. The offsets are recomputed here on the same stripped text
# LUIMA segments, so the printed slices always line up with LUIMA's own
# coordinates.
print(f"LUIMA-generated segments for document {doc_id}")
luima_text = documents_by_id[doc_id]['plainText'].strip()
luima_offsets = luima.text2sentences(luima_text, offsets=True)

print_pattern_out = "=" * 80
print_pattern_in = "-" * 80
for i, ind in enumerate(luima_offsets, start=1):
    start, end = ind[0], ind[1]
    print(print_pattern_out)
    print(f"Segment {i}")
    print(print_pattern_in)
    print("PRED".center(50))
    print(luima_text[start:end])
    print(print_pattern_in)