In [4]:
import json
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from taln.taln_aln import norm_text, tokenize


def nd(arr):
    """Flatten any array-like `arr` into a 1-D NumPy array."""
    flat = np.asarray(arr)
    return flat.reshape((-1,))


def yex(ax):
    """Draw a y = x reference line and give `ax` equal, shared x/y limits.

    The shared limits span the union of the current x and y ranges, so the
    black diagonal (drawn behind other artists, zorder 0) runs from corner to
    corner of a square plot. Returns `ax` for chaining.
    """
    bounds = np.array([ax.get_xlim(), ax.get_ylim()])
    diag = [bounds.min(), bounds.max()]

    ax.plot(diag, diag, c="k", alpha=0.75, zorder=0)
    ax.set(aspect="equal", xlim=diag, ylim=diag)
    return ax


# Global plotting defaults applied to every figure created in this notebook.
fsize = 15
plt.rcParams.update({"font.size": fsize})
# IPython magic (only valid inside Jupyter/IPython): render inline figures as
# high-DPI ("retina") PNGs.
%config InlineBackend.figure_format = "retina"

In [5]:
# SQuAD v2.0 source files
# - explorer: https://rajpurkar.github.io/SQuAD-explorer/
# - train: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
# - dev: https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json

# Resolved relative to the current working directory: data/ is expected to be a
# sibling of the directory the notebook is launched from (presumably notebooks/).
DATA_DIR = Path("..").resolve().joinpath("data")

# Both split files live under the same base URL and share the "<split>-v2.0.json" name.
_SQUAD_BASE_URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset"
SQUAD_URLS = {name: f"{_SQUAD_BASE_URL}/{name}-v2.0.json" for name in ("train", "dev")}


def load_squad_json(split, data_dir=DATA_DIR):
    """Return the parsed SQuAD v2.0 JSON payload for one split.

    ``split`` is lower-cased and must be a key of ``SQUAD_URLS``. A local file
    ``<data_dir>/<split>-v2.0.json`` is used when it exists; otherwise the split
    is downloaded from ``SQUAD_URLS`` (the download is not written to disk).

    Raises
    ------
    ValueError
        If ``split`` is not a known split name.
    URLError
        If the download fails; the underlying error is chained as ``__cause__``.
    """
    split = split.lower()
    if split not in SQUAD_URLS:
        raise ValueError(f"Unknown split: {split}")

    local_path = Path(data_dir) / f"{split}-v2.0.json"
    if local_path.exists():
        return json.loads(local_path.read_text(encoding="utf-8"))

    # Network fallback; HTTPError is a subclass of URLError, so it is caught too.
    try:
        with urlopen(SQUAD_URLS[split]) as resp:
            raw_bytes = resp.read()
    except URLError as e:
        raise URLError(
            f"Could not download SQuAD split '{split}'. "
            f"Either place it at {local_path} or ensure network access."
        ) from e
    return json.loads(raw_bytes.decode("utf-8"))


def squad_to_records(squad_json):
    """Flatten a SQuAD v2.0 payload into one dict per (question, answer) pair.

    Questions flagged ``is_impossible`` are skipped. Every remaining answer
    becomes a record with keys, in this order: ``title``, ``source`` (the
    paragraph context), ``target`` (answer text), ``idx_start``
    (``answer_start``), ``question``, ``question_id``, ``is_impossible``
    (always False) and ``answer_type`` (always "answer"). Missing
    title/question/id fields default to "".
    """
    def iter_records():
        for article in squad_json["data"]:
            title = article.get("title", "")
            for paragraph in article["paragraphs"]:
                context = paragraph["context"]
                answerable = (qa for qa in paragraph["qas"] if not qa.get("is_impossible", False))
                for qa in answerable:
                    q_text = qa.get("question", "")
                    q_id = qa.get("id", "")
                    for answer in qa.get("answers", []):
                        yield {
                            "title": title,
                            "source": context,
                            "target": answer["text"],
                            "idx_start": answer["answer_start"],
                            "question": q_text,
                            "question_id": q_id,
                            "is_impossible": False,
                            "answer_type": "answer",
                        }

    return list(iter_records())


def keep_naive_matches(records):
    """Keep records whose answer is located by a plain first-occurrence search.

    A record survives only if ``source`` and ``target`` are strings,
    ``idx_start`` is a non-negative int, and ``source.find(target)`` equals
    ``idx_start``; i.e. the annotated span is the first occurrence of the
    answer text in the context. Input order is preserved, and the original
    dicts (not copies) are returned.
    """
    def is_naive_match(rec):
        src, tgt, start = rec["source"], rec["target"], rec["idx_start"]
        if not (isinstance(src, str) and isinstance(tgt, str)):
            return False
        if not isinstance(start, int) or start < 0:
            return False
        return src.find(tgt) == start

    return [rec for rec in records if is_naive_match(rec)]


def dedupe_records(records):
    """Drop repeated spans, keeping the first record seen for each one.

    Records are duplicates when they share ``(source, target, idx_start)``;
    other fields (question, question_id, ...) are ignored. The surviving
    records keep their original relative order.
    """
    first_by_span = {}
    for rec in records:
        first_by_span.setdefault((rec["source"], rec["target"], rec["idx_start"]), rec)
    return list(first_by_span.values())


def add_row_id(records):
    """Return shallow copies of `records`, each tagged with its list position as ``row_id``.

    The input dicts are not modified; an existing ``row_id`` is overwritten in the copy.
    """
    return [dict(rec, row_id=pos) for pos, rec in enumerate(records)]


def shift_records(records):
    """Extend each answer span one character to the left over a preceding space.

    Every record is shallow-copied and the copy remembers the original span in
    ``idx_start_orig`` / ``target_orig``. When the character immediately
    before the span is a space (" "), the copy's ``target`` gains that leading
    space and its ``idx_start`` moves back by one; otherwise the span is
    unchanged. The input records are not modified.
    """
    shifted = []
    for rec in records:
        new = dict(rec)
        start, context = new["idx_start"], new["source"]
        new["idx_start_orig"], new["target_orig"] = start, new["target"]

        preceded_by_space = start > 0 and context[start - 1] == " "
        if preceded_by_space:
            new["target"] = " " + new["target"]
            new["idx_start"] = start - 1

        shifted.append(new)
    return shifted


def build_boat_frames(clean_records, shifted_records, min_target_ws_tokens=2):
    """Build span-set frames for the clean and space-shifted record variants.

    Rows of `shifted_records` whose target has ``min_target_ws_tokens`` or
    fewer whitespace-separated tokens are dropped. The comparison is strict:
    with the default of 2, only targets of 3+ tokens survive. The same
    ``row_id`` values are then kept in `clean_records`, so both variants cover
    identical rows.

    Parameters
    ----------
    clean_records, shifted_records : list[dict]
        Records with at least ``row_id``, ``source``, ``target`` and
        ``idx_start``. `shifted_records` is presumably ``shift_records`` applied
        to `clean_records` (same ``row_id`` values); verify at the call site.
    min_target_ws_tokens : int, default 2
        Threshold for the strict ``>`` token-count filter described above.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(wdf, tdf, clean_df, shifted_df)``. ``wdf``/``tdf`` have columns
        ``source``, ``target`` and ``idx_start``, where ``idx_start`` is the
        set of start offsets of that (source, target) pair in the clean /
        shifted variant. ``clean_df``/``shifted_df`` are the filtered record
        frames with a fresh RangeIndex. If either input list is empty, all
        four frames are empty (the span frames still carry their columns).
    """
    span_cols = ["source", "target", "idx_start"]
    clean_df = pd.DataFrame(clean_records)
    shifted_df = pd.DataFrame(shifted_records)

    # pd.DataFrame([]) has no columns at all, so the column lookups below would
    # raise KeyError; return empty frames instead.
    if clean_df.empty or shifted_df.empty:
        empty_spans = pd.DataFrame(columns=span_cols)
        return (
            empty_spans,
            empty_spans.copy(),
            clean_df.iloc[0:0].reset_index(drop=True),
            shifted_df.iloc[0:0].reset_index(drop=True),
        )

    # Leading spaces added by the shift do not change str.split() token counts.
    n_ws_tokens = shifted_df["target"].str.split().str.len()
    shifted_df = shifted_df[n_ws_tokens > min_target_ws_tokens]
    clean_df = clean_df[clean_df["row_id"].isin(set(shifted_df["row_id"]))]

    def span_sets(df):
        # One row per (source, target) pair; its idx_start values collected into a set.
        return df.groupby(["source", "target"])["idx_start"].apply(set).reset_index()

    tdf = span_sets(shifted_df)
    wdf = span_sets(clean_df)

    return wdf, tdf, clean_df.reset_index(drop=True), shifted_df.reset_index(drop=True)

In [6]:
def build_split(split, min_target_ws_tokens=2, data_dir=None):
    """Load one SQuAD v2.0 split and build its BOAT frames plus per-stage counts.

    Pipeline:
    1. Load the JSON.
    2. Flatten the answerable QA pairs.
    3. Keep records whose answer is the first occurrence in the context.
    4. Dedupe on (source, target, idx_start).
    5. Drop records whose length ``norm_text`` changes.
    6. Add ``row_id``.
    7. Shift spans over a preceding space.
    8. Build frames via ``build_boat_frames``.

    Parameters
    ----------
    split : str
        Split name, e.g. "train" or "dev" (validated by ``load_squad_json``).
    min_target_ws_tokens : int, default 2
        Forwarded to ``build_boat_frames``; targets must have strictly more
        whitespace tokens than this to be kept.
    data_dir : path-like, optional
        Directory searched for a local copy of the split. ``None`` uses
        ``load_squad_json``'s own default.

    Returns
    -------
    dict
        Counts after each filter stage (``n_raw``, ``n_naive``, ``n_dedup``,
        ``n_len_ok``, ``n_shifted``; all taken before the token-count filter)
        and the frames ``clean_df``, ``shifted_df``, ``wdf`` and ``tdf``.
    """
    if data_dir is None:
        squad_json = load_squad_json(split)
    else:
        squad_json = load_squad_json(split, data_dir=data_dir)
    raw_records = squad_to_records(squad_json)

    n_raw = len(raw_records)

    clean_records = keep_naive_matches(raw_records)
    n_naive = len(clean_records)

    clean_records = dedupe_records(clean_records)
    n_dedup = len(clean_records)

    # token offsets are computed on normalized text; if normalization changes length,
    # character indices may no longer be comparable.
    def len_preserved(r):
        return len(norm_text(r["source"])) == len(r["source"]) and len(norm_text(r["target"])) == len(r["target"])

    clean_records = [r for r in clean_records if len_preserved(r)]
    n_len_ok = len(clean_records)

    if n_len_ok < n_dedup:
        print(f"{split}: dropped {n_dedup - n_len_ok} records where norm_text changes string length")

    clean_records = add_row_id(clean_records)
    shifted_records = shift_records(clean_records)

    # Counted before build_boat_frames applies its token-count filter.
    n_shifted = sum(r["idx_start"] != r["idx_start_orig"] for r in shifted_records)

    wdf, tdf, clean_df, shifted_df = build_boat_frames(
        clean_records,
        shifted_records,
        min_target_ws_tokens=min_target_ws_tokens,
    )

    return {
        "n_raw": n_raw,
        "n_naive": n_naive,
        "n_dedup": n_dedup,
        "n_len_ok": n_len_ok,
        "n_shifted": n_shifted,
        "clean_df": clean_df,
        "shifted_df": shifted_df,
        "wdf": wdf,
        "tdf": tdf,
    }


# Build both splits (train first, then dev) and pull out their span frames.
split_results = {name: build_split(name) for name in ("train", "dev")}
train, dev = split_results["train"], split_results["dev"]

wdf_train, tdf_train, wdf_dev, tdf_dev = (
    split_results[name][kind] for name in ("train", "dev") for kind in ("wdf", "tdf")
)

tuple(frame.shape for frame in (wdf_train, tdf_train, wdf_dev, tdf_dev))

train: dropped 7705 records where norm_text changes string length
dev: dropped 794 records where norm_text changes string length


((30767, 3), (30767, 3), (4313, 3), (4313, 3))

In [7]:
# Preview: one row per (source, target) pair with its set of shifted start offsets.
tdf_train.head(n=5)

Unnamed: 0,source,target,idx_start
0,\nA gene is a locus (or region) of DNA that en...,The transmission of genes to an organism's of...,{135}
1,\nA gene is a locus (or region) of DNA that en...,a locus (or region) of DNA that encodes a fun...,{10}
2,\nA gene is a locus (or region) of DNA that en...,"blood type, risk for specific diseases, or th...",{479}
3,\nA gene is a locus (or region) of DNA that en...,eye colour or number of limbs,{422}
4,\nA gene is a locus (or region) of DNA that en...,polygenes (many different genes),{292}


In [8]:
def find_span_mismatches(df, source_col="source", target_col="target", idx_col="idx_start"):
    """Return index labels of rows whose stored span does not reproduce the target.

    A row mismatches when ``source[idx_start : idx_start + len(target)]``
    differs from ``target``. An empty list means every span is consistent.
    A frame with no rows yields an empty list.
    """
    if len(df) == 0:
        return []

    rows = zip(df.index, df[source_col], df[target_col], df[idx_col])
    return [
        label
        for label, src, tgt, start in rows
        if src[start : start + len(tgt)] != tgt
    ]


def check_shift_rule(shifted_df):
    """Verify that ``shift_records`` was applied exactly where its rule says.

    A row must be shifted (``idx_start == idx_start_orig - 1`` and
    ``target == " " + target_orig``) if and only if ``idx_start_orig > 0`` and
    the source character just before ``idx_start_orig`` is a space. Every other
    row must be untouched (``idx_start == idx_start_orig`` and
    ``target == target_orig``). The previous version only inspected rows that
    had been shifted, so missed or spurious modifications went undetected.

    Parameters
    ----------
    shifted_df : pandas.DataFrame
        Needs columns ``source``, ``target``, ``target_orig``, ``idx_start``
        and ``idx_start_orig``.

    Returns
    -------
    bool
        True when every row obeys the rule (vacuously True for an empty frame).
    """
    def row_ok(r):
        idx0 = r["idx_start_orig"]
        should_shift = idx0 > 0 and r["source"][idx0 - 1] == " "
        if should_shift:
            return r["idx_start"] == idx0 - 1 and r["target"] == " " + r["target_orig"]
        return r["idx_start"] == idx0 and r["target"] == r["target_orig"]

    # apply(axis=1) on a zero-row frame does not yield a boolean Series; short-circuit.
    if shifted_df.shape[0] == 0:
        return True
    return bool(shifted_df.apply(row_ok, axis=1).all())


def summarize_split(split_obj):
    """Collect filter counts and integrity checks for one ``build_split`` result.

    Copies the stage counts from `split_obj`, then adds the row counts of its
    ``clean_df`` / ``shifted_df``, the number of rows in each whose stored span
    does not reproduce the target (via ``find_span_mismatches``), and whether
    ``check_shift_rule`` holds for ``shifted_df`` (as a plain bool).
    """
    clean_df, shifted_df = split_obj["clean_df"], split_obj["shifted_df"]
    n_clean_bad = len(find_span_mismatches(clean_df))
    n_shift_bad = len(find_span_mismatches(shifted_df))

    summary = {
        key: split_obj[key]
        for key in ("n_raw", "n_naive", "n_dedup", "n_len_ok", "n_shifted")
    }
    summary["n_clean_rows"] = clean_df.shape[0]
    summary["n_shift_rows"] = shifted_df.shape[0]
    summary["n_clean_span_bad"] = n_clean_bad
    summary["n_shift_span_bad"] = n_shift_bad
    summary["shift_rule_ok"] = bool(check_shift_rule(shifted_df))
    return summary


# One row per split: stage counts plus span / shift-rule integrity checks.
per_split = {"train": summarize_split(train), "dev": summarize_split(dev)}
checks = pd.DataFrame(per_split).transpose()
checks

Unnamed: 0,n_raw,n_naive,n_dedup,n_len_ok,n_shifted,n_clean_rows,n_shift_rows,n_clean_span_bad,n_shift_span_bad,shift_rule_ok
train,86821,83916,81260,73555,68528,30767,30767,0,0,True
dev,20302,19291,9451,8657,8172,4313,4313,0,0,True


In [9]:
# Unique sources / targets and total (source, target) pairs per split.
pd.Series(
    {
        f"{name}_{kind}": value
        for name, frame in (("train", tdf_train), ("dev", tdf_dev))
        for kind, value in (
            ("sources", frame["source"].nunique()),
            ("targets", frame["target"].nunique()),
            ("pairs", frame.shape[0]),
        )
    }
)

train_sources    13680
train_targets    29751
train_pairs      30767
dev_sources       1021
dev_targets       4259
dev_pairs         4313
dtype: int64

In [10]:
def summarize_texts(texts):
    """Average length statistics for a pandas Series of strings.

    Returns a dict with:
    - "Char/Sample": mean character count.
    - "WS/Sample": mean whitespace-split word count.
    - "Tok/Sample": mean length of the first element returned by ``tokenize``
      (presumably its token list; confirm against taln).
    - "N Samples": the number of texts.
    """
    def n_ws(text):
        return len(text.split())

    def n_tok(text):
        return len(tokenize(text)[0])

    return {
        "Char/Sample": texts.str.len().mean(),
        "WS/Sample": texts.apply(n_ws).mean(),
        "Tok/Sample": texts.apply(n_tok).mean(),
        "N Samples": texts.shape[0],
    }


def boat_summary_table(tdf):
    """Tabulate length statistics of the unique sources and unique targets in `tdf`.

    Returns a DataFrame with rows "source" and "target" and columns
    "N Samples", "Char/Sample", "WS/Sample", "Tok/Sample" (see ``summarize_texts``).
    """
    column_order = ["N Samples", "Char/Sample", "WS/Sample", "Tok/Sample"]
    rows = {
        col: summarize_texts(pd.Series(tdf[col].unique()))
        for col in ("source", "target")
    }
    return pd.DataFrame(rows).T[column_order]

In [12]:
# Pool train and dev span frames, then summarise their unique sources / targets.
tdf_all = pd.concat([tdf_train, tdf_dev])
summary = boat_summary_table(tdf_all)

In [13]:
# Raw LaTeX for the paper table; print() keeps the backslashes/newlines readable.
summary_latex = summary.to_latex(float_format="%.0f")
print(summary_latex)

\begin{tabular}{lrrrr}
\toprule
 & N Samples & Char/Sample & WS/Sample & Tok/Sample \\
\midrule
source & 14701 & 744 & 118 & 154 \\
target & 33897 & 39 & 6 & 8 \\
\bottomrule
\end{tabular}



In [39]:
# Per-split summary tables. They were previously never defined in this notebook
# (leftover kernel state), so compute them here to survive Restart & Run All.
tbl_train = boat_summary_table(tdf_train)
tbl_dev = boat_summary_table(tdf_dev)
tbl = pd.concat({"train": tbl_train, "dev": tbl_dev}, names=["split", "text"]).round(3)
tbl

Unnamed: 0_level_0,Unnamed: 1_level_0,N Samples,Char/Sample,WS/Sample,Tok/Sample
split,text,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
train,source,13680.0,739.211,117.053,153.64
train,target,29751.0,38.337,5.917,7.694
dev,source,1021.0,809.574,127.489,164.746
dev,target,4259.0,39.706,6.069,7.557


In [34]:
# Recompute the train-only table so this cell does not rely on `tbl_train`
# having been left behind by an earlier (possibly skipped or deleted) cell.
tbl_train = boat_summary_table(tdf_train)
print(tbl_train.to_latex(float_format="%.0f"))

\begin{tabular}{lrrrr}
\toprule
 & N Samples & Char/Sample & WS/Sample & Tok/Sample \\
\midrule
source & 13680 & 739 & 117 & 154 \\
target & 29751 & 38 & 6 & 8 \\
\bottomrule
\end{tabular}



In [35]:
def to_jsonable_starts(df):
    """Return a copy of `df` whose ``idx_start`` sets become sorted lists.

    JSON has no set type, so the sets are converted before serialisation; the
    input frame is left unchanged.
    """
    return df.assign(idx_start=df["idx_start"].map(sorted))


def write_boat_files(data_dir=DATA_DIR, frames=None):
    """Write BOAT span frames to ``<data_dir>/<name>`` as JSON lists of records.

    Parameters
    ----------
    data_dir : path-like, default DATA_DIR
        Output directory; created (with parents) if it does not exist.
    frames : dict[str, pandas.DataFrame], optional
        Mapping of output file name to frame (each needs an ``idx_start``
        column of sets). Defaults to the notebook globals ``tdf_train``,
        ``wdf_train``, ``tdf_dev`` and ``wdf_dev`` under the
        ``boat_<split>.<tdf|wdf>.json`` names, looked up at call time.

    Returns
    -------
    dict[str, str]
        File name -> path of the file written.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    if frames is None:
        frames = {
            "boat_train.tdf.json": tdf_train,
            "boat_train.wdf.json": wdf_train,
            "boat_dev.tdf.json": tdf_dev,
            "boat_dev.wdf.json": wdf_dev,
        }

    # Convert every frame before writing anything, so a bad frame fails the
    # whole export up front instead of leaving a partial set of files.
    exports = {name: to_jsonable_starts(df) for name, df in frames.items()}

    written = {}
    for name, df in exports.items():
        path = data_dir / name
        with open(path, "w", encoding="utf-8") as f:
            json.dump(df.to_dict(orient="records"), f)
        written[name] = str(path)

    return written


# Export the four BOAT frames; `written` maps file name -> output path.
written = write_boat_files(DATA_DIR)
written


{'boat_train.tdf.json': '/Users/sinabooeshaghi/projects/taln/data/boat_train.tdf.json',
 'boat_train.wdf.json': '/Users/sinabooeshaghi/projects/taln/data/boat_train.wdf.json',
 'boat_dev.tdf.json': '/Users/sinabooeshaghi/projects/taln/data/boat_dev.tdf.json',
 'boat_dev.wdf.json': '/Users/sinabooeshaghi/projects/taln/data/boat_dev.wdf.json'}

In [16]:
trn =  keep_naive_matches(squad_to_records(load_squad_json("train")))
dev =  keep_naive_matches(squad_to_records(load_squad_json("dev")))



In [None]:
df = pd.concat([pd.DataFrame(trn), pd.DataFrame(dev)])

Unnamed: 0,title,source,target,idx_start,question,question_id,is_impossible,answer_type
49370,Alsace,"""Alsatia"", the Latin form of Alsace's name, ha...","a lawless place"" or ""a place under no jurisdic...",119,What is the meaning of the name Aslatia in Eng...,5727acde3acd2414000de95d,False,answer


In [23]:
# First record whose answer mentions "a place under no jurisdiction".
alsatia_row = df.loc[df["target"].str.contains("a place under no jurisdiction")].iloc[0]
tgt, src = alsatia_row["target"], alsatia_row["source"]

In [24]:
# Check that the answer text is a literal substring of its context.
tgt in src

True

In [25]:
# Answer text of the example; note the embedded quote characters.
tgt

'a lawless place" or "a place under no jurisdiction'

In [26]:
# Full context paragraph for the same record.
src

'"Alsatia", the Latin form of Alsace\'s name, has long ago entered the English language with the specialized meaning of "a lawless place" or "a place under no jurisdiction" - since Alsace was conceived by English people to be such. It was used into the 20th century as a term for a ramshackle marketplace, "protected by ancient custom and the independence of their patrons". As of 2007, the word is still in use among the English and Australian judiciaries with the meaning of a place where the law cannot reach: "In setting up the Serious Organised Crime Agency, the state has set out to create an Alsatia - a region of executive action free of judicial oversight," Lord Justice Sedley in UMBS v SOCA 2007.'

In [27]:
# All fields of the first record whose answer mentions "a place under no jurisdiction".
df.loc[df["target"].str.contains("a place under no jurisdiction")].iloc[0]

title                                                       Alsace
source           "Alsatia", the Latin form of Alsace's name, ha...
target           a lawless place" or "a place under no jurisdic...
idx_start                                                      119
question         What is the meaning of the name Aslatia in Eng...
question_id                               5727acde3acd2414000de95d
is_impossible                                                False
answer_type                                                 answer
Name: 49370, dtype: object