## Combined NER

In [1]:
import os
import re
import sys
import pdb
import json
import torch
import spacy
import random
from pathlib import Path
from datetime import datetime
from temporal_taggers.evaluation import merge_tokens, insert_tags_in_raw_text
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BertForTokenClassification

# Ask PyTorch to release any cached, unused GPU memory before models are loaded.
torch.cuda.empty_cache()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import warnings
# Suppress warnings whose message matches the "protected namespace" field-conflict text.
warnings.filterwarnings("ignore", message="Field .* has conflict with protected namespace")
# To keep tokenization consistent - we use spacy
# The parser and NER pipeline components are disabled when loading the model.
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

#### Selecting Reports to be extracted

In [3]:
# Forward slashes resolve on both Windows and POSIX; the previous back-slash
# literal was only a valid relative path on Windows.
directory_path = "../../../wamex/data"
data_path = Path(directory_path)
random.seed(42)

target_report_count = 100      # number of reports to end up with
max_report_sentences = 1000    # only reports with fewer sentences are sampled

# Start from any reports already stored as JSON files under wamex_xml/.
filepath = data_path / "wamex_xml"
reports = {}
for root, dirs, files in os.walk(filepath):
    for file in files:
        # os.walk descends into sub-directories, so the file lives under
        # `root`, which is not necessarily `filepath` itself.
        report_file = Path(root) / file
        try:
            with open(report_file, 'r') as f:
                reports[file] = json.load(f)
        except (OSError, ValueError) as err:
            # Best effort: unreadable or non-JSON files are skipped, but reported.
            print(f"Skipping {report_file}: {err}")

# Load the Raw WAMEX XML data - list of sentences
wamex_xml_path = data_path / "wamex_xml_snapshot.json"
with open(wamex_xml_path, 'r') as file:
    xml_data = json.load(file)

report_list = list(xml_data.keys())

# Count the reports that can still be picked so the sampling loop below is
# guaranteed to terminate even when fewer than the target are available.
eligible_count = sum(
    1
    for name in report_list
    if xml_data[name] and name not in reports and len(xml_data[name]) < max_report_sentences
)
reachable_target = min(target_report_count, len(reports) + eligible_count)

while len(reports) < reachable_target:
    report_chosen = random.choice(report_list)
    sentences = xml_data[report_chosen]
    # Check if the report is not already in the reports dictionary
    # Select non-empty reports with fewer than max_report_sentences sentences
    if sentences and report_chosen not in reports and len(sentences) < max_report_sentences:
        reports[report_chosen] = sentences

In [4]:
# Sanity check: show how many reports were collected by the selection step.
print("Number of reports: ", len(reports))
# Optional inspection of every report and its sentence count (disabled).
# for report in reports:
#     print(reports[report])
#     print(len(reports[report]))

Number of reports:  100


### Preprocess Sentences

In [5]:
def preprocess_text(text):
    """Tokenise *text* with the shared ``nlp`` pipeline and re-join the tokens.

    Returns the token texts joined by single spaces.
    """
    doc = nlp(text)
    return " ".join(token.text for token in doc)

def clean(text):
    """Normalise the abbreviation "Mt." in *text*.

    Every "Mt." inside the text becomes "Mt". When the text itself ends with
    "Mt.", that final occurrence becomes "Mt ." so the trailing period stays
    as a separate, space-delimited character.

    Previously, interior occurrences were left untouched whenever the text
    ended with "Mt."; both cases are now handled the same way.
    """
    ends_with_mt = text.endswith("Mt.")
    # Set the final "Mt." aside so its period is not swallowed by the replace.
    body = text[:-3] if ends_with_mt else text
    body = body.replace("Mt.", "Mt")
    if ends_with_mt:
        return body + "Mt ."
    return body

# report -> original sentence -> {"preprocess": cleaned and re-tokenised sentence}
tagged_reports = {
    report: {
        sentence: {"preprocess": preprocess_text(clean(sentence))}
        for sentence in report_sentences
    }
    for report, report_sentences in reports.items()
}
        

In [6]:
def format_date_string(text):
    """Normalise the comma spacing in "Month Day, Year" style dates.

    Handles "Month Day , Year" (stray space before the comma) and
    "Month Day,Year" (no space after the comma); both become
    "Month Day, Year". Text without such a pattern is returned unchanged.
    """
    # Group 1: a word followed by a 1-2 digit day; group 2: a four-digit year.
    # At most one whitespace character is accepted on each side of the comma.
    return re.sub(r"(\b\w+\s\d{1,2})\s?,\s?(\d{4})", r"\1, \2", text)

def find_date_pattern(text):
    """Return the first dd/mm/yy or dd/mm/yyyy date found in *text*.

    Day (1-31) and month (1-12) may carry a leading zero; the year has two or
    four digits. Returns the matched substring, or None when nothing matches.
    """
    hit = re.search(
        r"\b(0?[1-9]|[12][0-9]|3[01])/(0?[1-9]|1[0-2])/(?:[0-9]{2}|[0-9]{4})\b",
        text,
    )
    return hit.group(0) if hit else None

def ordinal(n):
    """Return the integer *n* with its English ordinal suffix, e.g. 1 -> "1st".

    The "teens" exception (11th, 12th, 13th) is tested on the last two digits,
    so it also holds for 111, 212, 1013 and so on, not only for 11-13.
    """
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return "%d%s" % (n, suffix)

def parse_date(found_date):
    """Parse a "d/m/yy" or "d/m/yyyy" string into a ``datetime``.

    The format is chosen from the length of the part after the last slash.
    Two-digit years go through ``%y``, which follows the POSIX pivot:
    69-99 become 1969-1999 and 00-68 become 2000-2068.

    Raises ValueError (from ``strptime``) for impossible dates.
    """
    year_part = found_date.rsplit('/', 1)[-1]
    date_format = "%d/%m/%y" if len(year_part) == 2 else "%d/%m/%Y"
    return datetime.strptime(found_date, date_format)

# ________________________________________________________________________  

def find_month_year_pattern(text):
    """Return the first MM/YYYY occurrence in *text*, or None.

    The month must be written with two digits (01-12) and the year with four.
    """
    hit = re.search(r"\b(0[1-9]|1[0-2])/\d{4}\b", text)
    if hit is None:
        return None
    return hit.group(0)

def format_month_year(month_year):
    """Turn an "MM/YYYY" string into "Month YYYY" (full month name)."""
    return datetime.strptime(month_year, "%m/%Y").strftime("%B %Y")

# ________________________________________________________________________  
# Example usage
# Quick manual check of the date finder on a hand-written sentence.
text = "This report was created on 14/12/2001 for the project."
found_date = find_date_pattern(text)

print(f"Found date: {found_date}" if found_date else "No date pattern found.")

# ________________________________________________________________________

# Spell out numeric dates in every preprocessed sentence:
#   dd/mm/yy(yy) -> "14th December 2001"      mm/yyyy -> "December 2001"
#
# One left-to-right substitution pass is used instead of repeated
# search + str.replace. This means that
#   * a string that matches the pattern but is not a real calendar date
#     (e.g. 31/02/2001) is left as it is instead of raising ValueError and
#     aborting the whole run, and
#   * only the matched occurrence is rewritten, never a look-alike substring
#     inside a longer number.
# The full date alternative is listed first so that the "12/2001" tail of a
# full date is never treated as a month/year on its own.
_numeric_date_re = re.compile(
    r"\b(?P<date>(?:0?[1-9]|[12][0-9]|3[01])/(?:0?[1-9]|1[0-2])/(?:[0-9]{2}|[0-9]{4}))\b"
    r"|\b(?P<month_year>(?:0[1-9]|1[0-2])/\d{4})\b"
)


def _spell_out_date(match):
    """Return the spelled-out replacement for one numeric date match."""
    try:
        if match.group("date"):
            date_obj = parse_date(match.group("date"))
            return f"{ordinal(date_obj.day)} {date_obj.strftime('%B')} {date_obj.year}"
        return format_month_year(match.group("month_year"))
    except ValueError:
        # Not a real date (e.g. 31/02/2001 or year 0000): keep the original text.
        return match.group(0)


for report in tagged_reports:
    for text in tagged_reports[report]:
        curr = tagged_reports[report][text]["preprocess"]
        tagged_reports[report][text]["preprocess"] = _numeric_date_re.sub(_spell_out_date, curr)

Found date: 14/12/2001


In [3]:
# Save the tagged reports (disabled; enable to refresh the cached file)
# with open("../Results/tagged_reports.json", 'w') as file:
#     json.dump(tagged_reports, file)

# Load the previously saved preprocessing results instead of recomputing them.
tagged_reports = json.loads(Path("../Results/tagged_reports.json").read_text())

# NER MODELS

### 6-GeoEntity NER

In [4]:
# Directory holding the saved 6-GeoEntity NER tokenizer and model.
save_directory = './Models/6-GeoEntityNER'
# Load the tokenizer and model from the saved directory
tokenizer = AutoTokenizer.from_pretrained(save_directory)
# The model is loaded as a sequence-to-sequence LM and moved to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(save_directory).to(device)

def tokenize_data(texts, tokenizer, max_length=256):
    """Tokenise *texts* into PyTorch tensors placed on the global ``device``.

    Inputs are padded to ``max_length`` and truncated beyond it.
    """
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded.to(device)

def predict_entities(texts, model, tokenizer):
    """Generate tag sequences for *texts* with the seq2seq NER model.

    Returns the list of decoded output strings (special tokens removed), as
    produced by ``tokenizer.batch_decode``.
    """
    encoded = tokenize_data(texts, tokenizer)
    # Make sure every input tensor lives on the same device as the model.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=256)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded

def NER_tagging(tagged_reports, model, tokenizer):
    """Add a "ner" prediction to every sentence entry of *tagged_reports*.

    Each entry's "preprocess" text is passed on its own to
    ``predict_entities`` and the result is stored under "ner". The input
    dict is modified in place and also returned.
    """
    for report_sentences in tagged_reports.values():
        for entry in report_sentences.values():
            entry["ner"] = predict_entities(entry["preprocess"], model, tokenizer)
    return tagged_reports


# Extract entities from the reports
# (NER_tagging stores the prediction under "ner" for every sentence and returns the same dict).
tagged_reports = NER_tagging(tagged_reports, model, tokenizer)

In [82]:
# Force one NER tag per whitespace token: surplus tags are cut off and
# missing ones are filled with "O". Sentences that already line up are left
# untouched.
for report in tagged_reports:
    for sentence in tagged_reports[report]:
        entry = tagged_reports[report][sentence]
        text = entry["preprocess"].split()
        ner_tags = entry["ner"][0].split()
        missing = len(text) - len(ner_tags)
        if missing == 0:
            continue
        if missing < 0:
            ner_tags = ner_tags[:len(text)]
        else:
            ner_tags = ner_tags + ['O'] * missing
        entry["ner"] = [" ".join(ner_tags)]

### Temporal NER (Real-time)

In [35]:
def do_nothing():
    """No-op callable; takes no arguments and returns None."""
    return None


# Replace pdb.set_trace with the no-op so calls to it no longer open the debugger.
# NOTE(review): presumably needed because imported code calls pdb.set_trace - confirm.
pdb.set_trace = do_nothing

def clean_timex_tags(text):
    """Tidy TIMEX3 markup in *text*.

    First removes stray whitespace ("< TIMEX3", "</TIMEX3 >" and trailing
    spaces inside quoted attribute values). Then, wherever a TIMEX3 element
    directly wraps another one, keeps only the inner element; this repeats
    until no directly nested pair is left.
    """
    # (pattern, replacement) pairs, applied in this order.
    whitespace_fixes = (
        (r'<\s+TIMEX3', r'<TIMEX3'),                 # "<   TIMEX3" -> "<TIMEX3"
        (r'</TIMEX3\s+>', r'</TIMEX3>'),              # "</TIMEX3  >" -> "</TIMEX3>"
        (r'(\w+)="([^"]*?)\s+"', r'\1="\2"'),         # name="value  " -> name="value"
    )
    for regex, fixed in whitespace_fixes:
        text = re.sub(regex, fixed, text)

    # An outer tag whose whole content is exactly one inner TIMEX3 element.
    nested_timex = re.compile(r'<TIMEX3[^>]*>(<TIMEX3[^>]*>[^<]+</TIMEX3>)</TIMEX3>')
    replaced = 1
    while replaced:
        text, replaced = nested_timex.subn(r'\1', text)

    return text

def run_temporal_model(tagged_reports):
    """Annotate every preprocessed sentence with TIMEX3 tags.

    Loads the BERT token-classification temporal tagger and its tokenizer,
    classifies each sentence's "preprocess" text, and inserts the predicted
    tags with ``merge_tokens`` / ``insert_tags_in_raw_text``.

    Args:
        tagged_reports: mapping report -> sentence -> dict with a
            "preprocess" string.

    Returns:
        A new mapping report -> sentence -> {"temporal_tagger": text}. When
        tagging a sentence fails, the untagged "preprocess" text is stored
        instead and the error is printed, including its cause.
    """
    model_name = "satyaalmasian/temporal_tagger_BERT_tokenclassifier"
    time_model = BertForTokenClassification.from_pretrained(model_name).to(device)
    time_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Invert label2id so predicted class ids can be mapped back to label names.
    id2label = {v: k for k, v in time_model.config.label2id.items()}

    temporal_tagged_reports = {}
    for report in tagged_reports:
        # The annotation counter restarts at 1 for every report.
        annotation_id = 1
        temporal_tagged_reports[report] = {}
        for sentence in tagged_reports[report]:
            temporal_tagged_reports[report][sentence] = {}
            text = tagged_reports[report][sentence]["preprocess"]
            try:
                processed_text = time_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

                with torch.no_grad():
                    result = time_model(**processed_text)

                # Highest-scoring class per token.
                classification = torch.argmax(result[0], dim=2)

                # Merge the tokens
                merged_tokens = merge_tokens(processed_text["input_ids"][0], classification[0], id2label, time_tokenizer)
                annotated_text, annotation_id = insert_tags_in_raw_text(text, merged_tokens, annotation_id)
                annotated_text = clean_timex_tags(annotated_text)
                temporal_tagged_reports[report][sentence]["temporal_tagger"] = annotated_text
                print(annotated_text)
            except Exception as e:
                # Best effort: keep the untagged text, but say why tagging failed.
                print(f"An error occurred while processing the text: {sentence} ({type(e).__name__}: {e})")
                temporal_tagged_reports[report][sentence]["temporal_tagger"] = text
    return temporal_tagged_reports

# Run the temporal tagger over every sentence; the result is keyed by report and sentence.
temporal_tagged = run_temporal_model(tagged_reports)

Convert TIMEX3 format into BIO format

In [62]:
def extract_timex_and_spans(text):
    """Strip TIMEX3 tags from *text* and report where their contents end up.

    Tags are removed one at a time, always the first remaining one, so each
    recorded span is already expressed in coordinates of the final tag-free
    text.

    Returns:
        A tuple ``(timex_values, cleaned_text, spans, types)`` where
        ``timex_values`` holds the tagged strings, ``cleaned_text`` is *text*
        without the tags, ``spans`` holds ``(start, end)`` character offsets
        (end exclusive) of each value in ``cleaned_text``, and ``types`` holds
        each tag's ``type`` attribute (``None`` when the tag has none).
    """
    timex_pattern = re.compile(r'<TIMEX3[^>]*>([^<]+)</TIMEX3>')
    type_pattern = re.compile(r'type="([^"]*)"')

    timex_values = []
    spans = []
    types = []
    cleaned_text = text

    match = timex_pattern.search(cleaned_text)
    while match:
        timex_value = match.group(1)
        timex_values.append(timex_value)

        # A tag without a type attribute must not abort the extraction.
        type_match = type_pattern.search(match.group(0))
        types.append(type_match.group(1) if type_match else None)

        # The value will sit exactly where the opening tag started.
        start_span = match.start()
        spans.append((start_span, start_span + len(timex_value)))

        # Remove only this tag. str.replace would also strip any identical tag
        # further along, and its span would then never be recorded.
        cleaned_text = cleaned_text[:start_span] + timex_value + cleaned_text[match.end():]
        match = timex_pattern.search(cleaned_text)

    return timex_values, cleaned_text, spans, types


# Convert the TIMEX3 annotations of every sentence into BIO labels, one label
# per whitespace-separated token of the preprocessed sentence.
#
# Token offsets are taken from the whitespace tokens of the tag-free text.
# These are the same units that size `tag_labels` and that the other taggers
# in this notebook label, so the label index cannot drift the way it can when
# the text is tokenised again with spaCy. Offsets are measured on the
# unstripped text, which is the coordinate system of the extracted spans.
for report in tagged_reports:
    for sentence in tagged_reports[report]:
        time_tagged_sentence = temporal_tagged[report][sentence]["temporal_tagger"]

        tag_labels = ["O"] * len(tagged_reports[report][sentence]["preprocess"].split())

        if "TIMEX3" in time_tagged_sentence:
            # `spans` may be empty (e.g. a malformed tag); the sentence then stays all "O".
            timex_values, cleaned_text, spans, types = extract_timex_and_spans(time_tagged_sentence)

            for i, token in enumerate(re.finditer(r"\S+", cleaned_text)):
                if i >= len(tag_labels):
                    # More tokens than labels: the text changed while being tagged.
                    break
                # Every span is checked for every token, so a span that begins
                # directly after the previous one still gets its B- label.
                for span_start, span_end in spans:
                    if token.start() >= span_start and token.end() <= span_end:
                        tag_labels[i] = "B-DATE" if token.start() == span_start else "I-DATE"
                        break

        tagged_reports[report][sentence]["time_tagged"] = " ".join(tag_labels)
        

### Geological Time NER - Rule Based

In [6]:
# Text should be already preprocessed by spaCy
def geological_timescale_ner(text):
    """Label geological ages in *text* with BIO tags, one per whitespace token.

    An age is a number or a number range ("2,800 to 3,000", "2.7 - 2.9"),
    optionally prefixed with "~", followed by one of the units Ma, ka, Ga,
    MYA or KYA (case-insensitive).

    Token offsets are read from the text itself rather than assumed to be
    single-space separated, so repeated or leading whitespace no longer
    shifts the labels.

    Returns:
        A space-joined string of "B-TIMESCALE", "I-TIMESCALE" and "O" labels
        with exactly one label per whitespace-separated token.
    """
    # Define the geological timescales pattern
    timescale_pattern = r'\b~?\d+(?:[.,]?\d+)?(?:\s*(?:to|-)\s*~?\d+(?:[.,]?\d+)?)?\s*(?:Ma|ka|Ga|MYA|KYA)\b'

    # Character spans (start, end-exclusive) of every timescale mention.
    geo_timescale_spans = []
    for match in re.finditer(timescale_pattern, text, re.IGNORECASE):
        start, end = match.span()
        # The leading \b stops the pattern from starting on "~" itself, so a
        # directly preceding "~" is pulled into the span here.
        if start != 0 and text[start - 1] == "~":
            start -= 1
        geo_timescale_spans.append((start, end))

    # A token is labelled when it lies fully inside a span; the token that
    # starts the span gets the B- label.
    token_labels = []
    for token in re.finditer(r"\S+", text):
        token_label = 'O'
        for start, end in geo_timescale_spans:
            if token.start() >= start and token.end() <= end:
                token_label = 'B-TIMESCALE' if token.start() == start else 'I-TIMESCALE'
                break
        token_labels.append(token_label)

    return " ".join(token_labels)

# Sample text from geological surveys (already tokenised: tokens are space-separated)
text = "Mapping and geochronology by the Geological Society of Australia ( Arriens , 1971 ) reveal that the granitic rocks in the western part of the Yalgoo 1:250,000 map sheet are in the order of 2,800 to 3,000 ma ."

# Format the NER output: one label per whitespace token of the sample
ner = geological_timescale_ner(text)
print(ner)

O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-TIMESCALE I-TIMESCALE I-TIMESCALE I-TIMESCALE O


In [7]:
# Store the rule-based geological timescale labels next to each sentence.
for report_sentences in tagged_reports.values():
    for entry in report_sentences.values():
        entry["geotime_tagged"] = geological_timescale_ner(entry["preprocess"])

In [8]:
# Show the three tag sequences and the text of the first sentence of the
# first report (nothing is printed when that report has no sentences).
first_report = next(iter(tagged_reports.values()), {})
first_entry = next(iter(first_report.values()), None)
if first_entry is not None:
    for field in ("geotime_tagged", "time_tagged", "ner", "preprocess"):
        print(first_entry[field])

O O O O O O O O O O O O O
O O O O O O O O O O O O O
['O O O O O O O O O B-LOCATION I-LOCATION I-LOCATION O']
Managed By : GME Resources Ltd Level 2 907 Canning Highway Mt .


In [9]:
# Save the tagged reports (disabled; enable to refresh the cached file)
# with open("../Results/tagged_reports_all.json", 'w') as file:
#     json.dump(tagged_reports, file)

# Load the saved reports that carry all three tag layers.
tagged_reports = json.loads(Path("../Results/tagged_reports_all.json").read_text())

## Combine Results

In [10]:
# Labels that may appear in the combined output; any other tag counts as "O".
entity_types = ["B-DATE", "I-DATE", "B-TIMESCALE", "I-TIMESCALE", "B-ROCK", "I-ROCK", "B-LOCATION", "I-LOCATION",
                "B-MINERAL", "I-MINERAL", "B-STRAT", "I-STRAT", "B-ORE_DEPOSIT", "I-ORE_DEPOSIT"]

tagged_reports_combined = {}

for report, report_sentences in tagged_reports.items():
    combined_report = tagged_reports_combined[report] = {}
    for sentence, entry in report_sentences.items():
        preprocessed = entry["preprocess"]
        token_count = len(preprocessed.split())
        # Tag layers in increasing priority: a recognised tag from a later
        # layer overwrites the tag chosen from an earlier one.
        tag_layers = (
            entry["ner"][0].split(),
            entry["time_tagged"].split(),
            entry["geotime_tagged"].split(),
        )
        combined_tags = ["O"] * token_count
        for i in range(token_count):
            for layer in tag_layers:
                candidate = layer[i]
                # "O" is not in entity_types, so it never overwrites anything.
                if candidate in entity_types:
                    combined_tags[i] = candidate
        combined_report[sentence] = {
            "combined": " ".join(combined_tags),
            "preprocess": preprocessed,
        }


In [11]:
# Save the combined results (disabled; enable to refresh the cached file)
# with open("../Results/tagged_reports_combined_results.json", 'w') as file:
#     json.dump(tagged_reports_combined, file)

# Load the saved combined tagging results.
tagged_reports_combined = json.loads(Path("../Results/tagged_reports_combined_results.json").read_text())

### Evaluation
- Bar Graph
- Number of entities per entity type, plus the number of unique entities

In [None]:
# Number of entity mentions per entity type across all combined reports.
num_of_entities = {'ROCK': 0, 'LOCATION': 0, 'MINERAL': 0, 'STRAT': 0, 'ORE_DEPOSIT': 0, 'DATE': 0, 'TIMESCALE': 0}

for report in tagged_reports_combined:
    for sentence in tagged_reports_combined[report]:
        ner_tags = tagged_reports_combined[report][sentence]["combined"].split()

        for tag in ner_tags:
            # Tags look like "B-ROCK" / "I-ROCK" while the counter is keyed by
            # the bare type, so the prefix is split off before the lookup.
            prefix, _, entity_type = tag.partition("-")
            # Each mention starts with exactly one B- tag, so counting B- tags
            # counts mentions rather than tokens.
            if prefix == "B" and entity_type in num_of_entities:
                num_of_entities[entity_type] += 1