In [None]:
import os

# Work from the project root. This assumes the notebook is started two directories
# below it (the original `cd ../..`). The root is resolved only on the first run, so
# re-running this cell does not keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)

In [None]:
import ast
import html
import pandas as pd
import re

from tqdm import tqdm
from typing import List, Optional
from xml.etree import ElementTree as ET

# Register tqdm's pandas integration so `progress_apply` (used below when
# extracting the body text) shows a progress bar.
tqdm.pandas()

# Settings

In [None]:
# Paths are relative to the working directory set in the first cell.
# Input: a table whose `xml` column holds the article XML.
# Output: the same table with `abstract` and `text` columns added and `xml` dropped.
DATA_DIR = "data"
DATA_FILE = f"{DATA_DIR}/SB_publication_PMC_with_xml.parquet"
OUTPUT_FILE = f"{DATA_DIR}/SB_publication_PMC_texts.parquet"

# Read data

In [None]:
# Load the input table; its `xml` column is parsed by the extraction cells below.
df = pd.read_parquet(path=DATA_FILE)
df.head(n=5)

# Extract abstract

In [None]:
def get_abstract_text(
    xml_bytes: bytes,
    include_translated: bool = True,
    join_with: str = "\n\n"
) -> str | None:
    """
    Extracts abstract(s) from a JATS/PMC XML file as plain text.

    - Includes <abstract> and optionally <trans-abstract>.
    - Keeps real <title> tags (UPPERCASE), removes synthetic "ABSTRACT" headers.
    - Strips all leading/trailing newlines and spaces.
    - Ignores style tags (<italic>, <bold>, etc.) but keeps their text content.
    - Normalizes whitespace and decodes HTML entities.

    Args:
        xml_bytes (bytes): The XML content.
        include_translated (bool): Whether to include translated abstracts.
        join_with (str): Separator between multiple abstracts.

    Returns:
        str | None: Clean plain-text abstract(s), or None if not found.
    """
    # Handle weird serialized byte literals like b'<?xml ...'
    if isinstance(xml_bytes, (bytes, bytearray)):
        s = xml_bytes.decode("utf-8", errors="ignore")
        if (s.startswith("b'") or s.startswith('b"')) and s.endswith(("'", '"')):
            try:
                xml_bytes = ast.literal_eval(s)
            except Exception:
                pass
        else:
            xml_bytes = s
    if isinstance(xml_bytes, str):
        xml_str = xml_bytes
    else:
        xml_str = xml_bytes.decode("utf-8", errors="ignore")

    root = ET.fromstring(xml_str)

    def local(tag: str) -> str:
        return tag.split("}")[-1] if "}" in tag else tag

    def norm_text(t: str) -> str:
        """Normalize whitespace and decode HTML entities."""
        t = html.unescape(t or "")
        return " ".join(t.split())

    def get_all_text(el: Optional[ET.Element]) -> str:
        """Flatten element text (ignore tags, keep text)."""
        if el is None:
            return ""
        return norm_text("".join(el.itertext()))

    # Collect <abstract> and optionally <trans-abstract>
    candidates: List[ET.Element] = []
    for el in root.iter():
        lt = local(el.tag)
        if lt == "abstract" or (include_translated and lt == "trans-abstract"):
            candidates.append(el)

    if not candidates:
        return None

    out_chunks: List[str] = []

    for abs_el in candidates:
        # Real title of the abstract (if any)
        title_el = abs_el.find("./title")
        if title_el is not None and get_all_text(title_el):
            out_chunks.append(get_all_text(title_el).upper())

        # Handle structured abstracts (<sec>)
        secs = [s for s in abs_el.findall("./sec")]
        if secs:
            for sec in secs:
                sec_title = get_all_text(sec.find("./title"))
                if sec_title:
                    out_chunks.append(sec_title.upper())
                for p in sec.findall(".//p"):
                    txt = get_all_text(p)
                    if txt:
                        out_chunks.append(txt)
        else:
            # Simple abstract (just <p> or text)
            ps = abs_el.findall("./p")
            if ps:
                for p in ps:
                    txt = get_all_text(p)
                    if txt:
                        out_chunks.append(txt)
            else:
                txt = get_all_text(abs_el)
                if title_el is not None:
                    title_txt = get_all_text(title_el)
                    if txt.startswith(title_txt):
                        txt = txt[len(title_txt):].lstrip()
                if txt:
                    out_chunks.append(txt)

    # Clean up and join
    cleaned = [chunk.strip() for chunk in out_chunks if chunk.strip()]
    text = join_with.join(cleaned)

    # Remove any leading "ABSTRACT" (or "ABSTRACT 1") and stray newlines/spaces
    text = re.sub(r"^(ABSTRACT\s*\d*\s*)", "", text, flags=re.IGNORECASE).lstrip("\n").strip()

    return text if text else None

In [None]:
# Missing abstracts (None) become "" so the column holds only strings.
df["abstract"] = df["xml"].apply(get_abstract_text).fillna("")

In [None]:
# Abstract length in characters (0 where no abstract was found, because
# missing abstracts were filled with "").
df["len"] = df["abstract"].str.len()
ax = df["len"].hist()
ax.set(title="Abstract length", xlabel="Characters", ylabel="Articles");

# Extract plain text

In [None]:
from xml.etree import ElementTree as ET
from typing import List, Optional, Iterable
import ast
import html

# Elements whose content is never emitted as running text.
_BODY_SKIP_TAGS = frozenset({
    "fig", "fig-group", "table", "table-wrap", "table-wrap-foot",
    "supplementary-material", "ref-list", "media", "graphic",
})

# Block-level containers rendered by `_emit_block` rather than as a plain <p>.
_BODY_BLOCK_TAGS = frozenset({"list", "def-list", "boxed-text", "disp-quote", "speech", "statement"})


def _local_name(tag: str) -> str:
    """Return an element tag without its '{namespace}' prefix."""
    return tag.split("}")[-1] if "}" in tag else tag


def _child_elements(el: ET.Element, name: str) -> List[ET.Element]:
    """Return the direct children of `el` whose local tag name is `name`."""
    return [c for c in el if _local_name(c.tag) == name]


def _first_child(el: ET.Element, name: str) -> Optional[ET.Element]:
    """Return the first direct child of `el` with local tag name `name`, or None."""
    return next((c for c in el if _local_name(c.tag) == name), None)


def _flat_text(el: Optional[ET.Element]) -> str:
    """Return all text under `el` with markup dropped, entities decoded and whitespace collapsed."""
    if el is None:
        return ""
    return " ".join(html.unescape("".join(el.itertext())).split())


def _append_blank_line(out: List[str]) -> None:
    """Append a paragraph separator unless `out` is empty or already ends with one."""
    if out and out[-1] != "":
        out.append("")


def _append_paragraph(out: List[str], text: str) -> None:
    """Append `text` followed by a blank separator line, if `text` is non-empty."""
    if text:
        out.append(text)
        out.append("")


def _coerce_xml_payload(xml_bytes) -> Optional[bytes]:
    """Return bytes ready for ET.fromstring, or None when the value holds no XML."""
    if isinstance(xml_bytes, str):
        return xml_bytes.encode("utf-8")
    if not isinstance(xml_bytes, (bytes, bytearray)):
        return None  # e.g. None / NaN from a missing DataFrame cell
    # Some dumps store the repr of a bytes object, i.e. the text b'<?xml ...'.
    decoded = xml_bytes.decode("utf-8", errors="ignore")
    if decoded.startswith(("b'", 'b"')) and decoded.endswith(("'", '"')):
        try:
            inner = ast.literal_eval(decoded)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return bytes(xml_bytes)  # not a valid literal: parse the raw bytes
        return inner.encode("utf-8") if isinstance(inner, str) else inner
    return bytes(xml_bytes)


def _emit_block(block: ET.Element, out: List[str]) -> None:
    """Append a list, definition list or other block container to `out` as paragraphs."""
    tag = _local_name(block.tag)
    if tag == "list":
        for item in _child_elements(block, "list-item"):
            label = _flat_text(_first_child(item, "label"))
            paras = _child_elements(item, "p")
            if paras:
                # The label is a sibling of the <p>s, so prefix it to each one.
                for p in paras:
                    t = _flat_text(p)
                    if t:
                        out.append(f"{label} {t}".strip() if label else t)
                _append_blank_line(out)
            else:
                # The flattened item text already includes the label; do not repeat it.
                _append_paragraph(out, _flat_text(item))
    elif tag == "def-list":
        for item in _child_elements(block, "def-item"):
            term = _flat_text(_first_child(item, "term"))
            defs = _child_elements(item, "def")
            if defs:
                for d in defs:
                    t = _flat_text(d)
                    if t:
                        out.append(f"{term}: {t}" if term else t)
                _append_blank_line(out)
            else:
                _append_paragraph(out, _flat_text(item))
    else:
        # boxed-text, disp-quote, speech, statement: flattened into one paragraph.
        _append_paragraph(out, _flat_text(block))


def _emit_body_node(node: ET.Element, out: List[str]) -> None:
    """Recursively append the readable content of a <body> subtree to `out`."""
    tag = _local_name(node.tag)
    if tag in _BODY_SKIP_TAGS:
        return

    if tag == "sec":
        title = _flat_text(_first_child(node, "title"))
        if title:
            _append_blank_line(out)
            out.append(title.upper())
            out.append("")
        # Inside a section only paragraphs, block containers and sub-sections
        # are rendered; any other child (including skipped tags) is ignored.
        for child in node:
            ctag = _local_name(child.tag)
            if ctag == "p":
                _append_paragraph(out, _flat_text(child))
            elif ctag in _BODY_BLOCK_TAGS:
                _emit_block(child, out)
            elif ctag == "sec":
                _emit_body_node(child, out)
        return

    if tag == "p":
        _append_paragraph(out, _flat_text(node))
    elif tag in _BODY_BLOCK_TAGS:
        _emit_block(node, out)
    else:
        for child in node:
            _emit_body_node(child, out)


def _abstract_loose_text(abs_el: ET.Element) -> str:
    """Return an abstract's text excluding its <title> and skipped elements."""
    parts = [abs_el.text or ""]
    for child in abs_el:
        ctag = _local_name(child.tag)
        if ctag != "title" and ctag not in _BODY_SKIP_TAGS:
            parts.append("".join(child.itertext()))
        parts.append(child.tail or "")
    return " ".join(html.unescape("".join(parts)).split())


def _emit_abstracts(root: ET.Element, out: List[str]) -> None:
    """Append every <abstract>/<trans-abstract> under `root` to `out`, each under a heading."""
    abstracts = [el for el in root.iter() if _local_name(el.tag) in {"abstract", "trans-abstract"}]
    for idx, abs_el in enumerate(abstracts, start=1):
        title = _flat_text(_first_child(abs_el, "title"))
        if title:
            heading = title.upper()
        elif len(abstracts) > 1:
            heading = f"ABSTRACT {idx}"
        else:
            heading = "ABSTRACT"
        _append_blank_line(out)
        out.append(heading)
        out.append("")

        structured = any(_local_name(c.tag) in {"p", "sec", "list", "def-list"} for c in abs_el)
        if not structured:
            # Text-only abstract: keep its loose text instead of emitting a bare heading.
            _append_paragraph(out, _abstract_loose_text(abs_el))
            continue

        for child in abs_el:
            ctag = _local_name(child.tag)
            if ctag == "p":
                _append_paragraph(out, _flat_text(child))
            elif ctag == "sec":
                _append_paragraph(out, _flat_text(_first_child(child, "title")).upper())
                for p in child.iter():
                    if _local_name(p.tag) == "p":
                        _append_paragraph(out, _flat_text(p))
            elif ctag in {"list", "def-list"}:
                _emit_block(child, out)


def get_plain_body_text(xml_bytes: bytes, fallback_to_abstract: bool = True) -> str:
    """
    Extract the article body of a JATS/PMC XML document as plain text.

    - Section titles are emitted in UPPERCASE, each followed by a blank line.
    - Paragraphs are whitespace-normalized and separated by blank lines. Inline
      markup (<italic>, <bold>, ...) is dropped but its text is kept.
    - Figures, tables, media, supplementary material and reference lists are skipped.
    - Each list item / definition entry becomes its own paragraph, prefixed by its
      label or term. boxed-text, disp-quote, speech and statement are flattened.
    - When there is no <body>, or the body yields no text, and
      `fallback_to_abstract` is True, the <abstract>/<trans-abstract> elements are
      emitted instead, each under an uppercase heading ("ABSTRACT", "ABSTRACT <n>",
      or the abstract's own title).

    Args:
        xml_bytes: The article XML (bytes or str). Bytes whose decoded text is a
            serialized bytes literal (b'<?xml ...') are unwrapped first. Any other
            value (e.g. None/NaN for a missing row) yields "".
        fallback_to_abstract: Whether to fall back to the abstract(s) as described above.

    Returns:
        str: Plain text with headings and paragraphs, or "" if nothing was found.

    Raises:
        xml.etree.ElementTree.ParseError: If the payload is not well-formed XML.
    """
    payload = _coerce_xml_payload(xml_bytes)
    if payload is None:
        return ""
    root = ET.fromstring(payload)

    lines: List[str] = []
    # The body may carry a namespace, so match on the local tag name.
    body = next((el for el in root.iter() if _local_name(el.tag) == "body"), None)
    if body is not None:
        _emit_body_node(body, lines)

    if not any(lines) and fallback_to_abstract:
        lines = []
        _emit_abstracts(root, lines)

    # Remove trailing blank separators.
    while lines and lines[-1] == "":
        lines.pop()
    return "\n".join(lines)

In [None]:
# Body text for every article, with a tqdm progress bar via `progress_apply`
# (enabled by `tqdm.pandas()` in the imports cell). get_plain_body_text returns ""
# when nothing can be extracted, so no fillna is needed here.
df["text"] = df["xml"].progress_apply(get_plain_body_text)

In [None]:
# Spot-check the new `text` column alongside the other columns.
df.head(n=5)

In [None]:
# Body-text length in characters (0 where get_plain_body_text returned "").
df["len"] = df["text"].str.len()
ax = df["len"].hist()
ax.set(title="Body text length", xlabel="Characters", ylabel="Articles");

# Save file

In [None]:
# Drop the raw XML and the helper length column before saving. errors="ignore"
# keeps this cell re-runnable: a second run would otherwise raise KeyError
# because the columns are already gone.
df = df.drop(columns=["xml", "len"], errors="ignore")

In [None]:
# Persist the table with the extracted `abstract` and `text` columns to OUTPUT_FILE.
df.to_parquet(OUTPUT_FILE)