# Preparing Data for Fine-tuning a NER Model
started Oct 17th, last edit: Oct 31st

In [None]:
import pandas as pd
import re
from bs4 import BeautifulSoup
import numpy as np
import os
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize


## 1. **Data/Label Collection**
- Reading in the HTML files and extracting the text and tables from them
- Reading in the labels extracted by NED and figuring out where in the text/tables they come from
- Add the location in text/table as an extra column to the label file (to be used for annotation).
- If labeled dataset is insufficient, we should consider augmenting it with more representative examples.


In [None]:
def clean_string(s):
    """Return *s* lowercased, with every character other than a-z, A-Z and 0-9 removed."""
    alnum_only = re.sub(r'[^a-zA-Z0-9]', '', s)
    return alnum_only.lower()

def generate_variants(s):
    """Return the distinct, non-empty spellings of *s* to search for.

    *s* is converted with ``str`` and stripped, then the following forms are
    produced, in this order:

    1. the stripped string itself
    2. ``clean_string`` of it (alphanumerics only, lowercase)
    3. spaces replaced by ``-``
    4. spaces replaced by ``_``
    5. spaces removed
    6. ``_`` replaced by a space
    7. ``-`` replaced by a space

    Duplicates are dropped (first occurrence wins) so callers do not report
    the same match several times, and empty strings are dropped because an
    empty search term matches at every position of a text.

    Returns:
        list of str, possibly empty (e.g. for ``''``).
    """
    s = str(s).strip()
    candidates = [
        s,
        clean_string(s),
        s.replace(' ', '-'),
        s.replace(' ', '_'),
        s.replace(' ', ''),
        s.replace('_', ' '),
        s.replace('-', ' '),
    ]
    # dict.fromkeys de-duplicates while preserving the order above
    return [variant for variant in dict.fromkeys(candidates) if variant]

def cleantable(df):
    '''Flatten the column labels of *df* and reduce every cell to ASCII letters and digits.

    The DataFrame is modified in place and also returned.

    - MultiIndex columns are collapsed to a single level by joining the level
      values with a space. Flat columns are left as they are (string labels
      are only stripped); joining a flat string label would split it into
      characters, and a non-string label cannot be joined at all.
    - Every cell is converted with ``str`` and all characters other than
      a-z, A-Z and 0-9 are removed, so e.g. ``0.5`` becomes ``'05'`` and a
      missing value becomes ``'nan'``.

    NOTE(review): column labels are assumed to be unique after flattening.
    '''
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [' '.join(str(level) for level in col).strip()
                      for col in df.columns.values]
    else:
        df.columns = [col.strip() if isinstance(col, str) else col
                      for col in df.columns]

    # Replace whole columns instead of writing single cells: this avoids
    # setting strings into numeric columns one value at a time.
    for col in df.columns:
        df[col] = df[col].map(lambda value: re.sub('[^a-zA-Z0-9]', '', str(value)))
    return df

def clean_tables_list(tables_list):
    '''Run cleantable on each table of tables_list, in order, and return the results as a list.'''
    cleaned = []
    for table in tables_list:
        cleaned.append(cleantable(table))
    return cleaned

def get_html_tables(f):
    """Return the cleaned tables of the HTML file *f*, or None if there are none.

    The file is parsed with BeautifulSoup, the resulting markup is handed to
    ``pandas.read_html`` and every table is cleaned in place with
    ``clean_tables_list``.

    Args:
        f: path of the HTML file.

    Returns:
        list of DataFrame, or None when the page has no table or the tables
        could not be extracted (extraction stays best-effort; unexpected
        failures are printed instead of being silently discarded).
    """
    from io import StringIO  # read_html expects a file-like object, not literal HTML

    with open(f) as file:
        soup = BeautifulSoup(file, 'html.parser')
    try:
        tables = pd.read_html(StringIO(str(soup)))
        clean_tables_list(tables)
    except ValueError:
        # pandas signals "no tables found" with a ValueError
        return None
    except Exception as e:
        print(f"Could not extract tables from {f}: {e}")
        return None
    return tables
    
def get_html_text(f, plength=100, sep=''):
    """Return the running text of the HTML file *f* as one string.

    All ``<p>`` elements of the page are considered, regardless of the
    container they sit in. A paragraph is kept only if it is longer than
    *plength* characters and starts with an alphabetic character.

    Args:
        f: path of the HTML file.
        plength: paragraphs with ``len(text) <= plength`` are dropped.
        sep: string placed between kept paragraphs. The default ``''``
            glues paragraphs together directly; pass e.g. ``' '`` so that
            the last sentence of one paragraph and the first of the next
            stay separable.

    Returns:
        str: the kept paragraphs joined by *sep* ('' if none qualify).
    """
    with open(f) as file:
        soup = BeautifulSoup(file, 'html.parser')

    kept = []
    for para in soup.find_all('p'):
        p = para.get_text()
        # p[:1] instead of p[0]: stays safe for an empty paragraph when plength < 0
        if len(p) > plength and p[:1].isalpha():
            kept.append(p)
    return sep.join(kept)

def find_all(text, substring):
    """Return the start offsets of all non-overlapping, case-insensitive occurrences of *substring* in *text*.

    *substring* is matched literally: it is escaped before being handed to
    the regex engine, so characters such as ``+``, ``.`` or ``(`` have no
    special meaning. An empty *substring* yields ``[]`` rather than a match
    at every position.
    """
    if not substring:
        return []
    pattern = re.escape(substring)
    return [match.start() for match in re.finditer(pattern, text, re.IGNORECASE)]

def find_variant_in_text(sub, text):
    """Find every occurrence of any variant of *sub* in *text*.

    Each variant returned by ``generate_variants(sub)`` is searched with
    ``find_all``.

    Returns:
        list of ``(start, end)`` character offsets (``end`` exclusive,
        i.e. ``start + len(variant)``), each span listed once in order of
        first discovery, or None when nothing is found.
    """
    spans = []
    for variant in generate_variants(sub):
        spans.extend((start, start + len(variant)) for start in find_all(text, variant))
    # Different variants can hit the same place; report each span only once.
    unique_spans = list(dict.fromkeys(spans))
    return unique_spans if unique_spans else None

def find_variant_in_tables(sub, tables_list):
    """Find the table cells whose value equals a variant of *sub*.

    Args:
        sub: the labelled value; converted with ``str`` before variants are
            generated.
        tables_list: list of DataFrames (``None`` or an empty list is
            treated as "no tables").

    Returns:
        list of ``(table_num, row, col)`` tuples, where ``table_num`` counts
        from 1 and ``row`` / ``col`` are the DataFrame index and column
        labels, or None when nothing matches. Each matching cell is listed
        once, grouped by table.
    """
    variants = generate_variants(str(sub))
    locations = []

    for table_num, table in enumerate(tables_list or [], 1):
        # One isin() call tests all variants at once, so a cell is reported
        # a single time even if several variants are identical.
        hits = table.isin(variants).stack()
        for row, col in hits[hits.astype(bool)].index.tolist():
            locations.append((table_num, row, col))

    return locations if locations else None
    
def add_location_columns(df, text,tables_list,colname='ap_name1',text_only=True):
    """Attach '<colname>_Loc_in_text' and '<colname>_Loc_in_table' columns to df.

    For every non-missing value of df[colname] the text column receives the
    result of find_variant_in_text. The table column receives the result of
    find_variant_in_tables when tables_list is truthy and text_only is
    False; otherwise it is filled with None. If colname is not a column of
    df nothing is added. df is modified in place and returned.
    """
    if colname not in df.columns:
        return df

    text_column = colname+'_Loc_in_text'
    table_column = colname+'_Loc_in_table'

    def locate_in_text(cell):
        # missing label values have no location
        if pd.notna(cell):
            return find_variant_in_text(cell, text)
        return None

    def locate_in_tables(cell):
        if pd.notna(cell):
            return find_variant_in_tables(cell, tables_list)
        return None

    df[text_column] = df[colname].apply(locate_in_text)
    if (tables_list) and (text_only==False):
        df[table_column] = df[colname].apply(locate_in_tables)
    else:
        df[table_column] = None
    return df

def find_token_number_by_index(tokens, index, text=None):
    """Return the position in *tokens* of the token that covers character offset *index*.

    The tokens are located in *text* one after another, each search starting
    where the previous token ended.

    Args:
        tokens: sequence of token strings, in text order.
        index: character offset into *text*.
        text: the string the tokens were taken from. When omitted, the
            module-level variable ``texts`` is used, which is what this
            function read before the parameter existed.

    Returns:
        int token number, or None if *index* falls on no token (e.g. on
        whitespace or past the end).
    """
    if text is None:
        text = texts  # backward compatibility with the two-argument call

    cursor = 0
    for token_num, token in enumerate(tokens):
        start_index = text.find(token, cursor)
        if start_index == -1:
            # Token does not occur verbatim (tokenizers may rewrite e.g. quotes);
            # skip it without moving the cursor.
            continue
        end_index = start_index + len(token)

        # Half-open interval: the offset right after a token belongs to
        # whatever follows it, not to the token itself.
        if start_index <= index < end_index:
            return token_num

        cursor = end_index

    return None
    
def find_variant_in_text_pro(sub, text):
    """Locate variants of *sub* in *text* at sentence/token level.

    *text* is split with ``sent_tokenize`` and each sentence with
    ``word_tokenize``. A match is either a single token equal to a variant
    of *sub*, or two consecutive tokens that, joined by one space, equal a
    variant.

    Returns:
        sorted list of ``(sentence_num, token_index, n_tokens)`` tuples,
        where ``n_tokens`` is 1 or 2 and both indices count from 0, or None
        when nothing matches.
    """
    # Variants do not depend on the sentence: build them once, as a set for O(1) lookup.
    variants = set(generate_variants(sub))
    results = set()

    for sentence_num, sentence in enumerate(sent_tokenize(text)):
        tokens = word_tokenize(sentence)

        for token_index, token in enumerate(tokens):
            # Every token is checked on its own, including the last one of the sentence.
            if token in variants:
                results.add((sentence_num, token_index, 1))

            # A pair needs a following token.
            if token_index + 1 < len(tokens):
                token_pair = token + ' ' + tokens[token_index + 1]
                if token_pair in variants:
                    results.add((sentence_num, token_index, 2))

    # sorted() gives a stable order; iterating a set does not
    return sorted(results) if results else None

In [None]:
directory = 'data/2022ApJ_PREPFILES/'
outdir = 'data/lable_locations2/'

# For each prep (label) file: load the labels, load the matching article HTML,
# record where each labelled value occurs in the article text and save the
# augmented label table as CSV in outdir.
count = 0
for prepfilename in tqdm(os.listdir(directory)):
    # stop once more than 1000 files have been written
    if count>1000:
        break
    prepfilepath = os.path.join(directory, prepfilename)
    if not os.path.isfile(prepfilepath):
        continue

    # Label file: pipe-delimited, header names padded with blanks
    labeldf = pd.read_csv(prepfilepath, delimiter="|", skipinitialspace=True, low_memory=False)
    labeldf.columns = labeldf.columns.str.strip()

    try:
        # The HTML folder and file name are derived from the prep file name
        s = prepfilename.split('.')
        htmldir = f"data/{s[0][0:4]}-{s[0][4:]}-Vol{s[3][0:3]}/HTML/"
        htmlfilepath = os.path.join(htmldir, f"{prepfilename[0:19]}.html")
        texts = get_html_text(htmlfilepath)
        tables = get_html_tables(htmlfilepath)

        # Columns located in the text; 'name1' and 'type1' are currently not processed.
        for located_column in ('ap_name1', 'vz1', 'coordx1', 'coordy1'):
            df = add_location_columns(labeldf,texts,tables,colname=located_column,text_only=True)

        # Overwrites an existing file of the same name
        df.to_csv(f"{outdir}{prepfilename[:-3]}csv", index=False)
        count+=1
    except Exception as e:
        print(f'---in file:{prepfilename}---')
        print(f"An error occurred: {e}")
        print('---------------------------')
    


## 2. **Annotation**
- Mark and label entities within your text.
- Entities to start with: `Object Name`, `RA`, `DEC`, `Redshift`, `Type`. We may add more later.
#### 2.1 **Figure out annotation formats**
Data can be represented in various formats:
- **BIO (or IOB) Format** (Beginning/Inside/Outside)
- **CoNLL Format**: Columns-based, used in datasets like CoNLL-2003. (BIO seems to be under this?)
- **Spacy Format**: JSON format (for Spacy users) with entities represented by start/end character positions.

Manual annotation can be time-consuming. If NED had not already done some part of this we could have considered: [Doccano](https://doccano.herokuapp.com/), [Prodigy](https://prodi.gy/) (by Spacy creators, paid), [Labelbox](https://www.labelbox.com/), or [Brat](http://brat.nlplab.org/). 

In [None]:
outdir = 'data/lable_locations/'

labelcolumns = ['ap_name1_Loc_in_text', 'vz1_Loc_in_text','coordx1_Loc_in_text','coordy1_Loc_in_text','type1_Loc_in_text']
valuecolumns = ['ap_name1','vz1','coordx1','coordy1','type1']
outpcolumns = ['name','redshift','RA','DEC','Type']

# For every label CSV without a .conll file yet: tokenize the article text and
# write one "token tag" line per token, sentences separated by a blank line.
# Tags are 'O', 'B-<column>' or 'I-<column>', with <column> taken from
# valuecolumns.
count = 0
for labelfilename in tqdm(os.listdir(outdir)):

    filename = "data/conlls2/" + labelfilename[:-4] + ".conll"
    labelfilepath = os.path.join(outdir, labelfilename)
    if not os.path.isfile(labelfilepath) or os.path.exists(filename):
        continue

    try:
        # Read in csv and html text
        labeldf = pd.read_csv(labelfilepath, encoding=None)
        s = labelfilename.split('.')
        htmldir = 'data/'+s[0][0:4]+'-'+s[0][4:]+'-Vol'+s[3][0:3]+'/HTML/'
        htmlfilepath = os.path.join(htmldir, labelfilename[0:19]+'.html')
        texts = get_html_text(htmlfilepath)

        # Map every spelling of every labelled value to its column, once per
        # file rather than once per token. Missing values are skipped, since
        # str(NaN) would otherwise tag the word "nan". When two columns share
        # a spelling, the later one in valuecolumns wins.
        variant_labels = {}
        for val in valuecolumns:
            if val not in labeldf.columns:
                continue
            for sub in labeldf[val].dropna():
                for variant in generate_variants(sub):
                    variant_labels[variant] = val

        # Tokenize: split the text into sentences, then each sentence into tokens
        sentences = sent_tokenize(texts)

        conll_chunks = []
        for sentence in sentences:
            tokens = word_tokenize(sentence)
            labels = ['O'] * len(tokens)  # Initialize with 'O' tags

            for token_index, token in enumerate(tokens):
                # Single-token entity; also applies to the last token of the sentence
                single = variant_labels.get(token)
                if single is not None:
                    labels[token_index] = "B-" + single

                # Two-token entity; takes precedence over a single-token
                # match starting at the same token
                if token_index + 1 < len(tokens):
                    pair = variant_labels.get(token + ' ' + tokens[token_index + 1])
                    if pair is not None:
                        labels[token_index] = "B-" + pair
                        labels[token_index + 1] = "I-" + pair

            sentence_block = "\n".join(f"{token} {label}" for token, label in zip(tokens, labels))
            conll_chunks.append(sentence_block + "\n\n")  # blank line separates sentences

        conll_data = "".join(conll_chunks)
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(conll_data)
    except Exception as e:
        # Keep going with the next file, but say which one failed and why
        print('---in file:'+labelfilename+'---')
        print(f"An error occurred: {e}")
        print('---------------------------')

### 3. **Train/Test Split**
- Consider an 80% training, 10% validation, and 10% test split.
- Respect document boundaries to avoid overlap between sets.

In [None]:
import os

# Path to the folder containing the .conll files
folder_path = 'data/conlls2/'

# Read every .conll file; each list element is the full content of one document.
# The names are sorted because os.listdir returns them in arbitrary order,
# and a seeded shuffle is only repeatable if its input order is.
all_data = []
for filename in sorted(os.listdir(folder_path)):
    if filename.endswith('.conll'):
        with open(os.path.join(folder_path, filename), 'r', encoding='utf-8') as f:
            all_data.append(f.read())

In [None]:
from sklearn.model_selection import train_test_split

# 80/10/10 split over whole documents: first hold out 20%, then cut the
# hold-out in half to obtain the validation and test sets.
SPLIT_SEED = 42
train_data, temp_data = train_test_split(all_data, random_state=SPLIT_SEED, test_size=0.2)
val_data, test_data = train_test_split(temp_data, random_state=SPLIT_SEED, test_size=0.5)


In [None]:
# Write each split to its own file in the working directory; the documents of
# a split are joined with a blank line between them.
for split_path, split_docs in (('train.txt', train_data),
                               ('val.txt', val_data),
                               ('test.txt', test_data)):
    with open(split_path, 'w', encoding='utf-8') as f:
        f.write("\n\n".join(split_docs))


In [None]:
# Keep only the sentences of the training split that contain at least one
# entity tag (B-... or I-...), and write them to filtered_train.txt.
# NOTE(review): the split files are written as 'train.txt' in the working
# directory, but this cell reads 'data/train.txt' - confirm which location is intended.
with open("data/train.txt", "r", encoding="utf-8") as file:
    lines = file.readlines()

filtered_lines = []
current_sentence = []

# The extra "" acts as a final blank line, so the last sentence goes through
# exactly the same check as all the others.
for line in lines + [""]:
    line = line.strip()
    if line:  # token line: still inside a sentence
        current_sentence.append(line)
        continue
    # blank line: the sentence is complete; the tag is the last field of a line
    if any(l.split()[-1].startswith(("B-", "I-")) for l in current_sentence):
        filtered_lines.extend(current_sentence)
        filtered_lines.append("")  # Empty line to denote end of sentence
    current_sentence = []

# The split files are UTF-8, so read and write them as such
with open("filtered_train.txt", "w", encoding="utf-8") as file:
    file.writelines(kept_line + "\n" for kept_line in filtered_lines)


In [None]:
# Keep only the sentences of the test split that contain at least one
# entity tag (B-... or I-...), and write them to filtered_test.txt.
# NOTE(review): the split files are written as 'test.txt' in the working
# directory, but this cell reads 'data/test.txt' - confirm which location is intended.
with open("data/test.txt", "r", encoding="utf-8") as file:
    lines = file.readlines()

filtered_lines = []
current_sentence = []

# The extra "" acts as a final blank line, so the last sentence is judged by
# the same tag test as every other sentence (it used to have its own,
# narrower test).
for line in lines + [""]:
    line = line.strip()
    if line:  # token line: still inside a sentence
        current_sentence.append(line)
        continue
    # blank line: the sentence is complete; the tag is the last field of a line
    if any(l.split()[-1].startswith(("B-", "I-")) for l in current_sentence):
        filtered_lines.extend(current_sentence)
        filtered_lines.append("")  # Empty line to denote end of sentence
    current_sentence = []

# The split files are UTF-8, so read and write them as such
with open("filtered_test.txt", "w", encoding="utf-8") as file:
    file.writelines(kept_line + "\n" for kept_line in filtered_lines)


In [None]:
# Keep only the sentences of the validation split that contain at least one
# entity tag (B-... or I-...), and write them to filtered_val.txt.
# NOTE(review): the split files are written as 'val.txt' in the working
# directory, but this cell reads 'data/val.txt' - confirm which location is intended.
with open("data/val.txt", "r", encoding="utf-8") as file:
    lines = file.readlines()

filtered_lines = []
current_sentence = []

# The extra "" acts as a final blank line, so the last sentence is judged by
# the same tag test as every other sentence (it used to have its own,
# narrower test).
for line in lines + [""]:
    line = line.strip()
    if line:  # token line: still inside a sentence
        current_sentence.append(line)
        continue
    # blank line: the sentence is complete; the tag is the last field of a line
    if any(l.split()[-1].startswith(("B-", "I-")) for l in current_sentence):
        filtered_lines.extend(current_sentence)
        filtered_lines.append("")  # Empty line to denote end of sentence
    current_sentence = []

# The split files are UTF-8, so read and write them as such
with open("filtered_val.txt", "w", encoding="utf-8") as file:
    file.writelines(kept_line + "\n" for kept_line in filtered_lines)

### 4. **Preprocessing**
- Tokenize consistently with the pre-trained model's tokenization.
- Other steps might include converting to lowercase, handling punctuation, etc.

### 5. **Model-Specific Formatting**
- Convert data to be compatible with your chosen framework.
- For HuggingFace Transformers, use their `TokenClassification` model format.


In [None]:
from datasets import load_dataset

# One plain-text file per split, read line by line by the generic 'text' loader
data_files = dict(train='train.txt', validation='val.txt', test='test.txt')

# No dataset script involved: every line of a file becomes one example
dataset = load_dataset('text', data_files=data_files)

In [None]:
def process_sample(batch):
    """Turn a batch of "token tag" lines into per-sentence token and tag lists.

    Args:
        batch: mapping with a 'text' entry holding a list of lines. A
            non-blank line is "<token> <tag>"; a blank line ends the current
            sentence.

    Returns:
        dict with 'tokens' and 'tags', two parallel lists holding one list
        per sentence. Runs of blank lines do not create empty sentences.

    Raises:
        ValueError: if a non-blank line has no whitespace-separated tag.
    """
    tokens_list = []
    tags_list = []
    tokens = []
    tags = []

    for line in batch['text']:
        if line.strip():
            # The tag is the last whitespace-separated field
            token, tag = line.rsplit(None, 1)
            tokens.append(token)
            tags.append(tag)
        elif tokens:
            # Blank line after at least one token: close the sentence.
            # Documents are separated by several blank lines, which must not
            # turn into empty sentences.
            tokens_list.append(tokens)
            tags_list.append(tags)
            tokens = []
            tags = []

    # Last sentence of the batch, if it was not closed by a blank line
    if tokens:
        tokens_list.append(tokens)
        tags_list.append(tags)

    return {'tokens': tokens_list, 'tags': tags_list}

# Apply the processing function on all splits of the dataset.
# batch_size=None passes each split to process_sample as a single batch; with
# the default batch size a sentence that straddles two batches would be cut
# into two sentences.
pdataset = dataset.map(process_sample, batched=True, batch_size=None, remove_columns=['text'])


In [None]:
# Display the processed dataset object as the cell output
pdataset

In [None]:
# Display the first example of the processed training split
pdataset['train'][0]

### 6. **Augmentation (Optional)**
For smaller datasets, consider:
- Back translation
- Synonym replacement
- Sentence shuffling

### 7. **Data Quality Checks**
- Ensure annotation consistency.
- Address issues like overlapping annotations or mislabeled entities.

After data preparation, proceed with fine-tuning your NER model, evaluating on the validation set and tuning hyperparameters as needed.


In [None]:
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
def count_rows_in_column(folder_path: str, column_name: str) -> int:
    """
    Sum the number of non-missing values of one column over all CSV files of a folder.

    Args:
    - folder_path (str): Folder whose '*.csv' files are read.
    - column_name (str): Column to count; files without it contribute nothing.

    Returns:
    - int: Number of non-NaN entries of the column, added up over all files.

    A file that cannot be processed is reported with a printed message and skipped.
    """
    total_count = 0

    # Only names ending in '.csv' are considered
    csv_names = [name for name in os.listdir(folder_path) if name.endswith('.csv')]

    for filename in csv_names:
        filepath = os.path.join(folder_path, filename)
        try:
            df = pd.read_csv(filepath)
            if column_name not in df.columns:
                continue
            # Series.count() leaves out NaN, i.e. empty cells
            total_count += df[column_name].count()
        except Exception as e:
            print(f"Error reading {filename}: {e}")

    return total_count

labdir = 'data/lable_locations/'
# For each quantity: the value column, its text-location column and its table-location column
coln = ['ap_name1', 'ap_name1_Loc_in_text','ap_name1_Loc_in_table']
colz = ['vz1', 'vz1_Loc_in_text','vz1_Loc_in_table']
colx = ['coordx1', 'coordx1_Loc_in_text','coordx1_Loc_in_table']

# Non-empty counts per column, in the column order given above
ncounts,zcounts,xcounts = [],[],[]
for name_col, coord_col, z_col in zip(coln, colx, colz):
    ncounts.append(count_rows_in_column(labdir, name_col))
    xcounts.append(count_rows_in_column(labdir, coord_col))
    zcounts.append(count_rows_in_column(labdir, z_col))

# Three bars side by side per group, each shifted by one bar width
barWidth = 0.25
r1 = np.arange(len(ncounts))
r2 = list(r1 + barWidth)
r3 = list(np.asarray(r2) + barWidth)

# Draw the bars: names, coordinates, redshifts
for positions, heights, series_label in ((r1, ncounts, 'Names'),
                                         (r2, xcounts, 'coords'),
                                         (r3, zcounts, 'redshift')):
    plt.bar(positions, heights, width=barWidth, edgecolor='grey', label=series_label)

labels = ['in prep file','found in text','found in tables']
# Tick labels sit under the middle bar of each group
plt.xticks([r + barWidth for r in range(len(ncounts))], labels)
plt.legend()
plt.ylabel('number')