In [16]:
import os
import re
import csv
import shutil
import random
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict, Counter
from utils import clean_text, sanitize_text

In [17]:
# Data paths
mimic_note_fpath = '../data/mimic3/NOTEEVENTS.csv'        # MIMIC-III clinical notes
mimic_tools_dpath = '../scripts/mimic-tools/'              # Directory containing pseudonymization scripts
lexicon_fpath = '../data/lexicon/lexicon.json'             # Lexicon dictionary

# Output paths and configuration
data_root = '../data/mimic_synthetic/'
num_val_examples = 10000
num_test_examples = 10000
num_examples = num_val_examples + num_test_examples
all_output_fpath = os.path.join(data_root, 'all.tsv')
val_output_fpath = os.path.join(data_root, 'val.tsv')
test_output_fpath = os.path.join(data_root, 'test.tsv')

# Word filtering/corruption parameters
min_word_len = 1  # Accept all words with at least 1 character 
no_corruption_prob = 0.1
max_corruptions = 2
do_substitution = True
do_transposition = True
DEFAULT_MAX_CHARACTER_POSITIONS = 64

# Directories for temporary pseudonymization
pseudo_in_dpath = os.path.join(data_root, 'temp')
pseudo_out_dpath = os.path.join(data_root, 'temp_pseudonym')


In [18]:
# DATA LOADING
# ================================
# Load MIMIC-III notes CSV file
print(f"Reading {os.path.basename(mimic_note_fpath)}... ", end="")
df_notes = pd.read_csv(mimic_note_fpath, low_memory=False)
df_notes = df_notes.set_index('ROW_ID')
print(f"done! {len(df_notes)} notes")

# Load lexicon for word validation
print(f"Reading {lexicon_fpath}... ", end="")
with open(lexicon_fpath, 'r', encoding='utf-8') as fd:
    vocab = json.load(fd)
vocab_set = set(vocab)
print(f"{len(vocab)} words")

Reading NOTEEVENTS.csv... done! 2083180 notes
Reading ../data/lexicon/lexicon.json... 822919 words


In [19]:
# HELPER FUNCTIONS
# ================================

# Function: select a random valid word from a clinical note and return its context
puncs = list("[]!\"#$%&'()*+,./:;<=>?@\\^_`{|}~-")
def random_word_context(text, max_trial=100):
    words = text.split()
    trial = 0
    while trial < max_trial:
        trial += 1
        w_idx = random.randint(0, len(words)-1)
        word = words[w_idx]
        left_res = []
        right_res = []
        # First try: if the word is directly acceptable
        if len(word) >= min_word_len and word.lower() in vocab_set and len(word) < DEFAULT_MAX_CHARACTER_POSITIONS - 4:
            return word, ' '.join(words[:w_idx]), ' '.join(words[w_idx+1:])
        else:
            # Remove punctuation from beginning and end
            if word and word[0] in puncs:
                left_res = [word[0]]
                word = word[1:]
            if not word:
                continue
            if word and word[-1] in puncs:
                right_res = [word[-1]]
                word = word[:-1]
            if len(word) < min_word_len or word.lower() not in vocab_set or len(word) >= DEFAULT_MAX_CHARACTER_POSITIONS - 4:
                continue
            # Check for anonymized fields in surrounding context 
            right_snip = ' '.join(words[w_idx+1:w_idx+5])
            left_snip = ' '.join(words[max(0, w_idx-4):w_idx])
            if ('**]' in right_snip and '[**' not in right_snip) or ('[**' in left_snip and '**]' not in left_snip):
                continue
            return word, ' '.join(words[:w_idx] + left_res), ' '.join(right_res + words[w_idx+1:])
    raise ValueError("Failed to choose a valid word context.")

# Functions to perform character-level corruption
alphabet = 'abcdefghijklmnopqrstuvwxyz'
def random_alphabet():
    return random.choice(alphabet)

# Build list of possible operations based on configuration
operation_list = ['ins', 'del']
if do_substitution:
    operation_list.append('sub')
if do_transposition:
    operation_list.append('tra')

def single_corruption(word):
    while True:
        oper = random.choice(operation_list)
        if oper == "del":  
            if len(word) == 1:
                continue
            cidx = random.randint(0, len(word)-1)
            ret = word[:cidx] + word[cidx+1:]
            break
        elif oper == "ins":  
            cidx = random.randint(0, len(word))
            ret = word[:cidx] + random_alphabet() + word[cidx:]
            break
        elif oper == "sub":  
            cidx = random.randint(0, len(word)-1)
            new_char = random_alphabet()
            while new_char == word[cidx]:
                new_char = random_alphabet()
            ret = word[:cidx] + new_char + word[cidx+1:]
            break
        elif oper == "tra": 
            if len(word) == 1:
                continue
            cidx = random.randint(0, len(word)-2)
            if word[cidx] == word[cidx+1]:
                continue
            ret = word[:cidx] + word[cidx+1] + word[cidx] + word[cidx+2:]
            break
        else:
            raise ValueError(f"Unknown operation: {oper}")
    return ret

def corrupt_word(word_original, max_corruptions=max_corruptions):
    # With a certain probability, leave the word unmodified.
    if no_corruption_prob > 0.0 and random.uniform(0, 1) < no_corruption_prob:
        return word_original
    num_corruptions = random.randint(1, max_corruptions)
    # Repeat corruption until the corrupted word is different from the original.
    corrupted = word_original
    while True:
        temp_word = corrupted
        for i in range(num_corruptions):
            temp_word = single_corruption(temp_word)
        if temp_word != word_original:
            corrupted = temp_word
            break
    return corrupted

# Process note text (cleaning and sanitizing)
def process_note(note):
    note = re.sub('\n', ' ', note)
    note = re.sub('\t', ' ', note)
    return sanitize_text(clean_text(note))


In [20]:
# MAIN PROCEDURE
# ================================

# 1. Randomly select note ids that satisfy the text length requirement.
random.seed(1234)
note_ids = list(df_notes.index)
random.shuffle(note_ids)
typo_noteids = set()
selected_count = 0
for nid in note_ids:
    note = str(df_notes.loc[nid].TEXT)  # Ensure note text is a string.
    if len(note.strip()) >= 2000 and nid not in typo_noteids:
        typo_noteids.add(nid)
        selected_count += 1
    if selected_count == num_examples:
        break
typo_noteids = list(typo_noteids)
print("Selected note IDs (first 10):", typo_noteids[:10])

# 2. For each selected note, pick a random word with context.
examples = []
for nid in tqdm(typo_noteids, desc="Extracting word contexts"):
    note = str(df_notes.loc[nid].TEXT)
    try:
        word, left, right = random_word_context(note)
        examples.append([word, left, right])
    except Exception as e:
        # If no valid word is found, use empty strings.
        examples.append(["", "", ""])

# Debug: Check how many chosen words contain punctuation.
words = [ex[0] for ex in examples]
words_with_punc = [w for w in words if any(not c.isalpha() for c in w)]
print(f"{len(words_with_punc)} words have punctuation")
# print(words_with_punc)  

# 3. Write out left/right contexts to disk for pseudonymization.
if os.path.exists(pseudo_in_dpath):
    shutil.rmtree(pseudo_in_dpath)
if os.path.exists(pseudo_out_dpath):
    shutil.rmtree(pseudo_out_dpath)
os.makedirs(pseudo_in_dpath, exist_ok=True)
for noteid, example in zip(typo_noteids, examples):
    left_text = example[1]
    right_text = example[2]
    with open(os.path.join(pseudo_in_dpath, f'{noteid}_left.txt'), 'w', encoding='utf-8') as fd:
        fd.write(left_text)
    with open(os.path.join(pseudo_in_dpath, f'{noteid}_right.txt'), 'w', encoding='utf-8') as fd:
        fd.write(right_text)

# 4. Run the external pseudonymization script 
cmd = f"python {os.path.join(mimic_tools_dpath, 'main.py')} REPLACE " \
      f"--input-dir {os.path.join(os.getcwd(), pseudo_in_dpath)} " \
      f"--output-dir {os.path.join(os.getcwd(), pseudo_out_dpath)} " \
      f"--list-dir {os.path.join(mimic_tools_dpath, 'lists')}"
print("Starting pseudonymization with command:")
print(cmd)
os.system(cmd)

# 5. Re-read the pseudonymized notes and update the examples.
for nid, example in tqdm(zip(typo_noteids, examples), total=len(typo_noteids), desc="Processing pseudonymized notes"):
    left_file = os.path.join(pseudo_out_dpath, f'{nid}_left.txt')
    right_file = os.path.join(pseudo_out_dpath, f'{nid}_right.txt')
    if os.path.exists(left_file):
        with open(left_file, 'r', encoding='utf-8') as fd:
            note = fd.read()
            example[1] = process_note(note)
    if os.path.exists(right_file):
        with open(right_file, 'r', encoding='utf-8') as fd:
            note = fd.read()
            example[2] = process_note(note)
    # Convert the chosen target word to lowercase.
    example[0] = example[0].lower()

# 6. Generate corrupted versions (typos) of the chosen words.
print("Generating corrupted words...")
random.seed(1234)
correct_words = [ex[0] for ex in examples]
typo_words = [corrupt_word(w) if w != "" else "" for w in correct_words]
# Debug: Print a few examples of word corruption.
for i, (cw, tw) in enumerate(zip(correct_words, typo_words)):
    print(f"\t{cw} -> {tw}")
    if i == 5:
        break
print("Done generating corrupted words!")

Selected note IDs (first 10): [655360, 393217, 524290, 393216, 393224, 16, 393244, 1572897, 34, 786477]


Extracting word contexts: 100%|██████████| 20000/20000 [00:03<00:00, 5998.75it/s]


129 words have punctuation
Starting pseudonymization with command:
python ../scripts/mimic-tools/main.py REPLACE --input-dir C:\Users\chen5\Desktop\FinalP2\cim-misspelling-main\scripts\../data/mimic_synthetic/temp --output-dir C:\Users\chen5\Desktop\FinalP2\cim-misspelling-main\scripts\../data/mimic_synthetic/temp_pseudonym --list-dir ../scripts/mimic-tools/lists


Processing pseudonymized notes: 100%|██████████| 20000/20000 [12:31<00:00, 26.61it/s]


Generating corrupted words...
	bp -> dcbp
	tracking -> ztracking
	not -> not
	much -> uwmch
	to -> ot
	patient -> patient
Done generating corrupted words!


In [21]:
 # 7. Generate final dataset for BERT.
# Format (TSV): index, note_id, word, left, right, correct
# The context is trimmed to the last 128 tokens for left and first 128 tokens for right.
random.seed(1234)
data_indices = list(range(num_examples))
random.shuffle(data_indices)
val_idx = sorted(data_indices[:num_val_examples])
test_idx = sorted(data_indices[num_val_examples:])

with open(all_output_fpath, 'w', encoding='utf-8', newline='') as fd:
    writer = csv.writer(fd, delimiter='\t')
    writer.writerow(['index', 'note_id', 'word', 'left', 'right', 'correct'])
    for i in range(num_examples):
        nid = typo_noteids[i]
        correct, left, right = examples[i]
        typo = typo_words[i]
        left_context = ' '.join(left.split()[-128:])
        right_context = ' '.join(right.split()[:128])
        writer.writerow([i, nid, typo, left_context, right_context, correct])

print("Dataset generation complete! Output saved to:", all_output_fpath)

Dataset generation complete! Output saved to: ../data/mimic_synthetic/all.tsv
