# Day 23 – Transcript Repair for TextGrid Imports

Convert the problematic story transcripts into standard Praat TextGrid files so downstream
notebooks stop skipping them. This workflow adds missing `ooTextFile` headers, clamps
overlapping interval boundaries (e.g. caused by floating-point drift), and re-serializes the
legacy "Praat chronological" format into the long-form (`ooTextFile`) TextGrid text
representation used elsewhere.

In [None]:

import os
import sys
import warnings
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

project_root = Path('/flash/PaoU/seann/fmri-edm-ccm')
project_root.mkdir(parents=True, exist_ok=True)
os.chdir(project_root)

# Make the project package and the vendored pyEDM / MDE sources importable.
# Notebook cells are often re-run, so only append entries not already present
# (plain appends would grow sys.path with duplicates on every execution).
for _extra_path in (
    str(project_root),
    '/flash/PaoU/seann/pyEDM/src',
    '/flash/PaoU/seann/MDE-main/src',
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

try:
    from IPython.display import display
except Exception:
    def display(obj):  # type: ignore
        """Fallback outside IPython: print the object instead of rendering it."""
        print(obj)

try:
    import textgrid  # type: ignore
except Exception as exc:
    raise ImportError('textgrid package is required to repair transcripts.') from exc

from src.utils import load_yaml

np.random.seed(42)
pd.options.display.max_columns = 80

cfg = load_yaml('configs/demo.yaml')
# `paths` may be missing or null in the YAML; normalise it to a plain dict.
paths = dict(cfg.get('paths', {}) or {})
data_root = Path(paths.get('data_root', '/bucket/PaoU/seann/openneuro/ds003020'))

# Stories whose transcripts are repaired by default.
DEFAULT_TARGET_STORIES: Sequence[str] = [
    'exorcism',
    'food',
    'haveyoumethimyet',
    'legacy',
    'theshower',
]

# Possible TextGrid locations under the dataset, in priority order; only the
# directories that actually exist are kept.
candidate_roots = [
    candidate
    for candidate in (
        data_root.joinpath('derivative', 'TextGrids'),
        data_root.joinpath('derivatives', 'TextGrids'),
        data_root.joinpath('stimuli'),
        data_root.joinpath('annotations'),
    )
    if candidate.exists()
]

print(f'Data root: {data_root}')
print('TextGrid roots to inspect:')
for root in candidate_roots:
    print('  -', root)


In [None]:

# --- Controls --------------------------------------------------------------------
# Stories to repair (matched case-insensitively against TextGrid file stems);
# an empty sequence means every TextGrid in the source directory is processed.
TARGET_STORIES: Sequence[str] = DEFAULT_TARGET_STORIES  # set to [] to process everything found
DRY_RUN = False  # True -> skip writing sanitized files
EPSILON = 1e-8  # tolerance for clamping neighbouring interval boundaries
OUTPUT_DIR = project_root.joinpath('misc', 'textgrids')
OUTPUT_ORIGINAL_DIR = project_root.joinpath('misc', 'textgrids_original')

print(f'Saving sanitized transcripts under: {OUTPUT_DIR}')
print(f'Backing up original transcripts under: {OUTPUT_ORIGINAL_DIR}')
if not TARGET_STORIES:
    print('Processing every TextGrid discovered in the search roots.')
else:
    print('Restricting to stories:', ', '.join(sorted(TARGET_STORIES)))



In [None]:

import shutil

if not candidate_roots:
    raise FileNotFoundError('No TextGrid directories found. Check data_root or candidate paths.')

# The first existing candidate root is used as the source for everything below.
SOURCE_TEXTGRID_ROOT = candidate_roots[0]
print(f'Using source TextGrid directory: {SOURCE_TEXTGRID_ROOT}')

OUTPUT_ORIGINAL_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _copy_replacing(src: str, dst: str) -> str:
    """Copy the contents of ``src`` to ``dst``, replacing any existing ``dst``.

    Used as ``copy_function`` for the working-copy refresh. ``copytree``'s
    default (``shutil.copy2``) also copies permission bits, so read-only source
    files yield read-only copies that a later refresh cannot overwrite.
    Removing the old file first and copying contents only (``shutil.copyfile``)
    keeps the working copy writable across re-runs.
    """
    target = Path(dst)
    if target.is_symlink() or target.exists():
        try:
            target.unlink()
        except PermissionError:
            # Mirrors the write path of the repair loop (needed on Windows).
            target.chmod(0o666)
            target.unlink()
    return shutil.copyfile(src, dst)


# Back up the untouched originals once; a populated backup is never overwritten.
if not any(OUTPUT_ORIGINAL_DIR.iterdir()):
    print('Creating pristine backup at textgrids_original ...')
    shutil.copytree(SOURCE_TEXTGRID_ROOT, OUTPUT_ORIGINAL_DIR, symlinks=False, dirs_exist_ok=True)
else:
    print('Backup directory already populated; leaving as-is.')

print('Refreshing working copy at textgrids ...')
shutil.copytree(
    SOURCE_TEXTGRID_ROOT,
    OUTPUT_DIR,
    symlinks=False,
    dirs_exist_ok=True,
    copy_function=_copy_replacing,
)



In [None]:

import math
import re
import shlex
import tempfile
from dataclasses import dataclass
from typing import Any

# Long-format TextGrid preamble. `ensure_header` prepends it when none of the
# first few lines starts with `File type =`, and `convert_chronological` emits
# it at the top of every converted file.
HEADER_LINES = [
    'File type = "ooTextFile"\n',
    'Object class = "TextGrid"\n',
    '\n',
]

# Matches an `xmin = <number>` / `xmax = <number>` line (callers strip the
# trailing newline first). Groups: `prefix` (indent, key and `=`), `value`
# (optional; plain, decimal or exponent notation) and trailing whitespace `suffix`.
FLOAT_LINE = re.compile(r'^(?P<prefix>\s*(?:xmin|xmax)\s*=\s*)(?P<value>[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)?(?P<suffix>\s*)$')

def _format_number(original: str, value: float) -> str:
    if any(ch in original for ch in 'eE'):
        return f'{value:.12g}'
    if '.' in original:
        decimals = len(original.split('.')[-1])
        return f'{value:.{decimals}f}'
    return str(int(round(value)))

def _update_float_line(line: str, new_value: float) -> Tuple[str, bool]:
    """Replace the number on an ``xmin = ...`` / ``xmax = ...`` line.

    The indentation, key and trailing whitespace are kept, and a trailing
    newline is re-attached if the input had one. Lines that do not match
    ``FLOAT_LINE`` or carry no number are returned unchanged with ``False``.
    The flag is ``True`` only when the rendered number text differs from the
    original token.
    """
    has_newline = line.endswith('\n')
    body = line[:-1] if has_newline else line
    match = FLOAT_LINE.match(body)
    token = match.group('value') if match else None
    if token is None:
        return line, False
    rendered = _format_number(token, new_value)
    prefix, suffix = match.group('prefix', 'suffix')
    rebuilt = prefix + rendered + suffix + ('\n' if has_newline else '')
    return rebuilt, rendered != token

def ensure_header(lines: List[str]) -> Tuple[List[str], bool]:
    """Prepend ``HEADER_LINES`` unless one of the first five lines is a header.

    A leading UTF-8 byte-order mark (U+FEFF) is ignored when looking for
    ``File type =``. ``str.strip`` does not remove it, so without this a
    BOM-prefixed header went unrecognised and a second header was prepended,
    producing an unparseable file.

    Returns:
        ``(lines, False)`` when a header is present, otherwise
        ``(HEADER_LINES + lines, True)`` (a new list; ``lines`` is not mutated).
    """
    has_header = any(
        ln.lstrip('\ufeff').strip().startswith('File type =') for ln in lines[:5]
    )
    if has_header:
        return lines, False
    return HEADER_LINES + lines, True

def _assigned_float(stripped: str) -> Optional[float]:
    """Parse the first token after ``=`` on a ``key = value`` line; ``None`` if malformed."""
    _, sep, rhs = stripped.partition('=')
    tokens = rhs.split()
    if not sep or not tokens:
        return None
    try:
        return float(tokens[0])
    except ValueError:
        return None


def fix_overlaps(lines: List[str], tolerance: float = EPSILON) -> Tuple[List[str], int]:
    """Clamp interval boundaries so intervals within one tier do not overlap.

    Scans a long-format TextGrid line by line. Between an ``intervals [k]:``
    line and the next ``points [``/``item``/``name =`` line:

    - an ``xmin`` more than ``tolerance`` below the previous interval's ``xmax``
      (same tier) is rewritten to that ``xmax``;
    - an ``xmax`` more than ``tolerance`` below its own interval's ``xmin`` is
      rewritten to that ``xmin``.

    Rewrites go through ``_update_float_line``. State is tracked per tier
    *position* and reset at every ``item``/``name =`` line. Tiers that share a
    name (or have an empty one) therefore no longer clamp each other's
    intervals, which keying by tier name used to cause.

    Args:
        lines: TextGrid lines (with line endings); modified in place.
        tolerance: Overlaps up to this size are left untouched.

    Returns:
        The same ``lines`` list and the number of boundary values rewritten.
    """
    prev_xmax: Optional[float] = None  # xmax of the previous interval in this tier
    in_tier = False  # True once a tier's `name =` line has been seen
    current_xmin: Optional[float] = None
    inside_intervals = False
    fixes = 0

    for idx, line in enumerate(lines):
        stripped = line.strip()
        if stripped.startswith('name ='):
            # Start of a new tier: boundaries never carry over between tiers.
            in_tier = True
            prev_xmax = None
            current_xmin = None
            inside_intervals = False
        elif stripped.startswith('intervals ['):
            inside_intervals = True
            current_xmin = None
        elif stripped.startswith('points ['):
            inside_intervals = False
            current_xmin = None
        elif inside_intervals and stripped.startswith('xmin ='):
            value = _assigned_float(stripped)
            if value is None:
                continue
            if prev_xmax is not None and value < prev_xmax - tolerance:
                new_line, changed = _update_float_line(line, prev_xmax)
                if changed:
                    lines[idx] = new_line
                    value = prev_xmax
                    fixes += 1
            current_xmin = value
        elif inside_intervals and stripped.startswith('xmax ='):
            value = _assigned_float(stripped)
            if value is None:
                continue
            if current_xmin is not None and value < current_xmin - tolerance:
                new_line, changed = _update_float_line(line, current_xmin)
                if changed:
                    lines[idx] = new_line
                    value = current_xmin
                    fixes += 1
            if in_tier:
                prev_xmax = value
            current_xmin = None
        elif stripped.startswith('item'):
            # `item [k]:` opens another tier; end the current interval list.
            inside_intervals = False
            current_xmin = None
            prev_xmax = None
    return lines, fixes

def convert_chronological(lines: List[str]) -> Tuple[List[str], bool]:
    if not lines or not lines[0].strip().startswith('"Praat chronological TextGrid text file"'):
        return lines, False
    domain_tokens = shlex.split(lines[1])
    if len(domain_tokens) < 2:
        return lines, False
    xmin_token, xmax_token = domain_tokens[0], domain_tokens[1]
    tier_line = shlex.split(lines[2])
    if not tier_line:
        return lines, False
    tier_count = int(tier_line[0])
    cursor = 3
    tier_defs: List[Dict[str, str]] = []
    for _ in range(tier_count):
        if cursor >= len(lines):
            break
        tokens = shlex.split(lines[cursor])
        cursor += 1
        if len(tokens) < 4:
            continue
        tier_defs.append({
            'class': tokens[0],
            'name': tokens[1],
            'xmin': tokens[2],
            'xmax': tokens[3],
        })
    tier_defs = tier_defs[:tier_count]
    intervals_by_tier: List[List[Tuple[str, str, str]]] = [[] for _ in range(len(tier_defs))]
    while cursor < len(lines):
        row = lines[cursor].strip()
        cursor += 1
        if not row:
            continue
        tokens = shlex.split(row)
        if len(tokens) < 3:
            continue
        tier_idx = int(tokens[0])
        start_token, end_token = tokens[1], tokens[2]
        label = ''
        if cursor < len(lines):
            label_line = lines[cursor].strip()
            cursor += 1
            if label_line.startswith('"') and label_line.endswith('"'):
                label = label_line[1:-1]
            else:
                label = label_line
        if 1 <= tier_idx <= len(intervals_by_tier):
            intervals_by_tier[tier_idx - 1].append((start_token, end_token, label))
    out: List[str] = []
    out.extend(HEADER_LINES)
    out.append(f'xmin = {xmin_token} \n')
    out.append(f'xmax = {xmax_token} \n')
    out.append('tiers? <exists> \n')
    out.append(f'size = {len(tier_defs)} \n')
    out.append('item []: \n')
    for idx, tier in enumerate(tier_defs, start=1):
        intervals = intervals_by_tier[idx - 1]
        out.append(f'    item [{idx}]:\n')
        out.append(f'        class = "{tier["class"]}" \n')
        out.append(f'        name = "{tier["name"]}" \n')
        out.append(f'        xmin = {tier["xmin"]} \n')
        out.append(f'        xmax = {tier["xmax"]} \n')
        out.append(f'        intervals: size = {len(intervals)} \n')
        for jdx, (start, end, label) in enumerate(intervals, start=1):
            safe_label = label.replace('"', '""')
            out.append(f'        intervals [{jdx}]:\n')
            out.append(f'            xmin = {start} \n')
            out.append(f'            xmax = {end} \n')
            out.append(f'            text = "{safe_label}" \n')
    return out, True

def sanitize_textgrid_text(text: str, tolerance: float = EPSILON) -> Tuple[str, Dict[str, int]]:
    """Run the repair pipeline on raw TextGrid text.

    Steps, in order: chronological -> long-format conversion, header
    insertion, interval-overlap clamping (with ``tolerance``).

    Returns:
        The repaired text and a metrics dict: ``converted`` and
        ``header_added`` as 0/1 flags, ``overlap_fixes`` as a count.
    """
    working = text.splitlines(keepends=True)
    working, was_converted = convert_chronological(working)
    working, gained_header = ensure_header(working)
    working, n_fixes = fix_overlaps(working, tolerance=tolerance)
    return ''.join(working), {
        'converted': int(was_converted),
        'header_added': int(gained_header),
        'overlap_fixes': n_fixes,
    }

def validate_textgrid(content: str) -> Tuple[bool, Optional[str]]:
    """Check that ``content`` parses with ``textgrid.TextGrid.fromFile``.

    The text is written to a temporary ``.TextGrid`` file, which is always
    removed afterwards. It is written as UTF-8, matching how the repair loop
    reads its sources. The previous locale-default encoding could raise on
    non-ASCII labels under a non-UTF-8 locale.

    Returns:
        ``(True, None)`` if parsing succeeds, else ``(False, <error message>)``.
        Errors while writing the temporary file propagate to the caller.
    """
    tmp_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', suffix='.TextGrid', delete=False
        ) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(content)
        try:
            textgrid.TextGrid.fromFile(str(tmp_path))
        except Exception as exc:  # any parser failure means "not valid"
            return False, str(exc)
        return True, None
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)


In [None]:

from collections import defaultdict

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _read_transcript_text(path: Path) -> str:
    """Return the text of ``path``, decoding BOM-marked UTF-16 or else UTF-8.

    Praat may save TextGrids as UTF-16 with a byte-order mark. Decoding those as
    UTF-8 with ``errors='ignore'`` leaves NUL characters between all
    characters, and the sanitizer would then overwrite the working copy with
    garbage. Undecodable input still falls back to dropping bad bytes, but a
    warning is now emitted so the loss is visible.
    """
    with path.open('rb') as handle:
        head = handle.read(2)
    encoding = 'utf-16' if head in (b'\xff\xfe', b'\xfe\xff') else 'utf-8'
    try:
        return path.read_text(encoding=encoding)
    except UnicodeDecodeError:
        warnings.warn(f'{path}: not valid {encoding}; undecodable bytes were dropped.')
        return path.read_text(encoding=encoding, errors='ignore')


targets = {story.lower() for story in TARGET_STORIES} if TARGET_STORIES else None
processed: Dict[str, Dict[str, Any]] = {}
records: List[Dict[str, Any]] = []

source_root = SOURCE_TEXTGRID_ROOT
for path in sorted(source_root.glob('*.TextGrid')):
    story = path.stem.lower()
    if targets and story not in targets:
        continue
    original_text = _read_transcript_text(path)
    sanitized_text, metrics = sanitize_textgrid_text(original_text, tolerance=EPSILON)
    parse_ok, parse_error = validate_textgrid(sanitized_text)
    dest_path = OUTPUT_DIR / path.name
    changed = sanitized_text != original_text
    if not DRY_RUN:
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        if dest_path.exists():
            try:
                dest_path.unlink()
            except PermissionError:
                dest_path.chmod(0o666)
                dest_path.unlink()
        # Explicit UTF-8: the default encoding is locale-dependent.
        dest_path.write_text(sanitized_text, encoding='utf-8')
    record = {
        'story': path.stem,
        'source': str(path),
        'destination': str(dest_path),
        'converted': bool(metrics['converted']),
        'header_added': bool(metrics['header_added']),
        'overlap_fixes': metrics['overlap_fixes'],
        'changed': changed,
        'parse_ok': parse_ok,
        'parse_error': parse_error,
    }
    records.append(record)
    processed[story] = record

if targets:
    missing = sorted(targets - processed.keys())
    if missing:
        warnings.warn(f'No TextGrid found in {source_root} for: {", ".join(missing)}')

# Explicit columns keep the sort valid even when no file matched.
results_df = (
    pd.DataFrame(
        records,
        columns=[
            'story', 'source', 'destination', 'converted', 'header_added',
            'overlap_fixes', 'changed', 'parse_ok', 'parse_error',
        ],
    )
    .sort_values(['story', 'source'])
    .reset_index(drop=True)
)
print(f'Repaired {len(results_df)} transcript(s).')
display(results_df)



In [None]:

# Check original TextGrids in-place to reproduce prior failures
from src.decoding import load_transcript_words

# Load every (subject, story) pair from the raw data root (unmodified `paths`)
# and record which ones raise.
original_failures: List[Tuple[str, str, str]] = []
original_successes: List[str] = []
subjects_to_test = [f'UTS{n:02d}' for n in range(1, 10)]
stories_to_test = list(TARGET_STORIES or processed.keys())
for sub in subjects_to_test:
    for story in stories_to_test:
        try:
            events = load_transcript_words(paths, sub, story)
        except Exception as exc:
            original_failures.append((sub, story, str(exc)))
            continue
        original_successes.append(f'{sub}/{story} -> {len(events)} words')

print('Original transcript load check (raw data root)')
for line in original_successes:
    print('  OK:', line)
if not original_failures:
    print('No failures encountered using original files.')
else:
    print('Failures:')
    for sub, story, err in original_failures:
        print(f'  {sub}/{story}: {err}')



In [None]:

# Verify a few transcripts round-trip with the sanitized copies in place
from src.decoding import load_transcript_words

# Same configuration, but with transcript lookups redirected to the sanitized copies.
paths_override = {**paths, 'transcripts': str(OUTPUT_DIR)}
pairs_to_check = [('UTS01', story) for story in (TARGET_STORIES or processed.keys())]
successes: List[str] = []
failures: List[Tuple[str, str, str]] = []
for sub, story in pairs_to_check:
    try:
        events = load_transcript_words(paths_override, sub, story)
    except Exception as exc:
        failures.append((sub, story, str(exc)))
        continue
    successes.append(f'{sub}/{story} -> {len(events)} words')

print("Verification using paths['transcripts'] = OUTPUT_DIR")
for line in successes:
    print('  OK:', line)
if not failures:
    print('All requested transcripts parsed successfully.')
else:
    print('Failures:')
    for sub, story, err in failures:
        print(f'  {sub}/{story}: {err}')
