In [None]:
# Move the working directory two levels up; the relative "data/..." paths used below resolve from there.
# NOTE: the path is relative, so re-running this cell moves up two more levels.
cd ../..

In [None]:
import ast
import pandas as pd

from tqdm import tqdm
from typing import List, Optional, Any
from xml.etree import ElementTree as ET

# Settings

In [None]:
# Input parquet table; the extraction loop reads its "pmc" and "xml" columns.
DATA_FILE = "data/SB_publication_PMC_with_xml.parquet"
# Output path for the extracted disciplines (written pipe-separated in the "Save file" cell).
OUTPUT_FILE = "data/SB_publication_PMC_disciplines.csv"

# Read data

In [None]:
# Load the publications table and preview its first five rows.
df = pd.read_parquet(path=DATA_FILE)
df.head(n=5)

# Extract disciplines

In [None]:
def disciplines_to_dataframe(xml_bytes: bytes,
                             article_id: Optional[Any] = None,
                             include_keywords_fallback: bool = True) -> pd.DataFrame:
    """
    Extract the subject/discipline hierarchy of one article XML as a DataFrame.

    The result has the columns ``discipline``, ``level_discipline`` (1 for a
    top-level subject, 2 for its children, ...) and ``parent_discipline``
    (``None`` at level 1). Duplicate rows are removed.

    Sources are tried in order; the first one that yields at least one row wins:
      1) subj-group[@subj-group-type in {"Discipline-v2","Discipline","discipline"}]
      2) any subj-group under <article-categories> (e.g., 'heading' -> 'Review')
      3) (optional) <kwd-group><kwd> as level-1 disciplines.

    Tag matching ignores XML namespaces, and payloads stored as the textual
    repr of a bytes object (``b'<?xml ...'``) are unwrapped before parsing.

    Parameters
    ----------
    xml_bytes : bytes or str
        The article XML. ``None`` and empty/whitespace-only payloads produce
        an empty DataFrame.
    article_id : Any, optional
        Currently unused; kept so existing callers that pass it keep working.
    include_keywords_fallback : bool, default True
        Whether step 3 (keywords) may be used when steps 1 and 2 find nothing.

    Returns
    -------
    pd.DataFrame
        Columns ``discipline``, ``level_discipline``, ``parent_discipline``.

    Raises
    ------
    xml.etree.ElementTree.ParseError
        If the payload is not well-formed XML.
    """
    columns = ["discipline", "level_discipline", "parent_discipline"]

    if xml_bytes is None:
        return pd.DataFrame(columns=columns)

    # Handle files copied as Python bytes-literals: b'<?xml ...'
    is_text = isinstance(xml_bytes, str)
    wrappers = ("b'", 'b"') if is_text else (b"b'", b'b"')
    if xml_bytes.startswith(wrappers):
        literal = xml_bytes if is_text else xml_bytes.decode("utf-8", errors="ignore")
        try:
            unwrapped = ast.literal_eval(literal)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Best effort: not a valid literal, so parse the payload as it is
            # (ET.fromstring below will raise if it is not XML either).
            unwrapped = None
        if isinstance(unwrapped, (bytes, str)):
            xml_bytes = unwrapped

    if not xml_bytes.strip():
        return pd.DataFrame(columns=columns)

    rows: List[dict] = []

    def local(tag: Any) -> str:
        """Return the tag name without its ``{namespace}`` prefix."""
        if not isinstance(tag, str):
            # Comment / processing-instruction nodes carry a non-string tag.
            return ""
        return tag.split('}')[-1] if '}' in tag else tag

    def subjects(el: ET.Element) -> List[str]:
        """Non-empty texts of the direct <subject> children of ``el``."""
        texts = ("".join(s.itertext()).strip() for s in el if local(s.tag) == "subject")
        return [t for t in texts if t]

    def child_groups(el: ET.Element) -> List[ET.Element]:
        """Direct <subj-group> children of ``el``."""
        return [c for c in el if local(c.tag) == "subj-group"]

    def add_row(path: List[str]) -> None:
        """Record the last element of ``path`` with its depth and parent."""
        rows.append({
            "discipline": path[-1],
            "level_discipline": len(path),
            "parent_discipline": path[-2] if len(path) > 1 else None,
        })

    def walk(group: ET.Element, path: List[str]) -> None:
        """Depth-first walk; a group without subjects passes its path through."""
        subs = subjects(group)
        kids = child_groups(group)
        if subs:
            for s in subs:
                newp = path + [s]
                add_row(newp)
                for k in kids:
                    walk(k, newp)
        else:
            for k in kids:
                walk(k, path)

    def walk_top_level(groups: List[ET.Element]) -> None:
        """Walk only the groups that are not nested inside another listed group."""
        members = set(groups)
        nested = {d for g in groups for d in g.iter() if d is not g and d in members}
        for g in groups:
            if g not in nested:
                walk(g, [])

    root = ET.fromstring(xml_bytes)
    elements = list(root.iter())

    # 1) Strict discipline groups
    disc_types = {"Discipline-v2", "Discipline", "discipline"}
    walk_top_level([
        el for el in elements
        if local(el.tag) == "subj-group" and el.get("subj-group-type") in disc_types
    ])

    # 2) Any article-categories hierarchy (also used when step 1 matched
    #    groups that turned out to contain no <subject>)
    if not rows:
        walk_top_level([
            el
            for ac in elements if local(ac.tag) == "article-categories"
            for el in ac.iter() if local(el.tag) == "subj-group"
        ])

    # 3) Fallback to keywords if still empty
    if not rows and include_keywords_fallback:
        for kg in elements:
            if local(kg.tag) != "kwd-group":
                continue
            for kw in kg:
                if local(kw.tag) == "kwd":
                    t = "".join(kw.itertext()).strip()
                    if t:
                        add_row([t])

    result = pd.DataFrame(rows, columns=columns)
    if not result.empty:
        result = result.drop_duplicates()
    return result

In [None]:
# Extract the disciplines of every article: one small frame per article,
# tagged with its "pmc" identifier, stacked once at the end.
discipline_columns = ["pmc", "discipline", "level_discipline", "parent_discipline"]
discipline_frames = []
# Iterating the two needed columns avoids building a Series per row (iterrows).
for pmc_id, xml_payload in tqdm(zip(df["pmc"], df["xml"]), total=len(df)):
    df_pmc = disciplines_to_dataframe(xml_payload)
    df_pmc["pmc"] = pmc_id
    discipline_frames.append(df_pmc[discipline_columns])

# pd.concat raises ValueError on an empty list, so an empty input table
# yields an empty, correctly shaped result instead of a crash.
if discipline_frames:
    all_disciplines = pd.concat(discipline_frames, ignore_index=True)
else:
    all_disciplines = pd.DataFrame(columns=discipline_columns)
all_disciplines.head()

# Save file

In [None]:
# Persist the table as pipe-delimited UTF-8 text, without the index column.
all_disciplines.to_csv(OUTPUT_FILE, sep="|", encoding="utf-8", index=False)