# 📈 Inference Pipeline (No GROBID)

This notebook:
- Loads your trained classifier (`saved_model`)
- Converts each PDF to text with marker (`PdfConverter`); the earlier PyPDF2 reader is left commented out
- Predicts `type`
- Extracts `dataset_id` using regex
- Produces `submission.csv`

**Note:** No GROBID needed.

In [1]:
# 📦 Install PyPDF2 if needed
# !pip install /kaggle/input/pypdf2-mdc/pypdf2-3.0.1-py3-none-any.whl

In [2]:
import os

In [3]:
! mkdir cache data

In [4]:
# Point the XDG cache/data locations at ./cache and ./data under the
# current working directory.
_working_dir = os.getcwd()
os.environ.update({
    "XDG_CACHE_HOME": os.path.join(_working_dir, "cache"),
    "XDG_DATA_HOME": os.path.join(_working_dir, "data"),
})

In [5]:
# Copy the `datalab` directory from the attached Kaggle input into ./cache.
# Presumably these are the model checkpoints referenced later through the
# *_MODEL_CHECKPOINT environment variables — TODO confirm.
! cp -r /kaggle/input/datalab-marker/kaggle/working/datalab ./cache

In [6]:
# Copy the `pip` cache directory from the attached Kaggle input into ./cache;
# a wheel under cache/pip/wheels is installed by a later cell.
! cp -r /kaggle/input/pip-datalab/kaggle/working/cache/pip ./cache

In [7]:
# Create the fonts directory (and any missing parents) that the font file is
# copied into by the following cell.
# NOTE(review): the path hard-codes python3.11 — confirm it matches the kernel.
! mkdir -p /usr/local/lib/python3.11/dist-packages/static/fonts

In [8]:
# Copy GoNotoCurrent-Regular.ttf from the Kaggle input into the
# dist-packages/static/fonts directory.
# Presumably this is where marker looks for its font when offline — TODO confirm.
! cp /kaggle/input/marker-font/GoNotoCurrent-Regular.ttf /usr/local/lib/python3.11/dist-packages/static/fonts/GoNotoCurrent-Regular.ttf

In [None]:
! pip install /kaggle/working/cache/pip/wheels/8d/ac/a9/4e6dd2d86235ea3da1c286279118c49e931f77cfb33e9b1af5/EbookLib-0.18-py3-none-any.whl

In [None]:
# Run the install script shipped with the mdc-marker-pdf-reqs Kaggle input.
# Presumably it installs marker and its dependencies offline — TODO confirm
# against the script's contents.
! sh /kaggle/input/mdc-marker-pdf-reqs/install_requirements.sh

In [None]:
# 📚 Imports
import re
# import PyPDF2
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
import pandas as pd
from glob import glob
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

In [None]:
# 🧠 Load model and tokenizer
# Tokenizer and classifier weights are read from the same saved directory.
model_path = "/kaggle/input/makedatacount-mixed-train/saved_model_dual_text"
token_path = model_path

# Use the GPU when one is visible, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(token_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()  # switch the module to evaluation mode for inference
model.to(device)

In [59]:
# 🧩 Utility functions
# Substrings that mark a DOI/URL as pointing at a data repository.
REPO_KEYWORDS = ['zenodo', 'dryad', 'figshare', 'pangaea', 'tcia']

def is_repository_doi(doi_url):
    """Return True when ``doi_url`` contains one of ``REPO_KEYWORDS``.

    The comparison is case-insensitive, so ``.../ZENODO.123`` matches just
    like ``.../zenodo.123``.

    Args:
        doi_url: DOI or URL string to inspect.

    Returns:
        bool: True if any repository keyword occurs in ``doi_url``.
    """
    lowered = doi_url.lower()
    return any(keyword in lowered for keyword in REPO_KEYWORDS)

def clean_doi(doi):
    """Strip trailing closing brackets and punctuation from ``doi``."""
    trailing_junk = ')]>.,;'
    stripped = doi.rstrip(trailing_junk)
    return stripped

def clean_pdf_text(text):
    """Separate markdown link labels from their targets.

    Every ``](`` becomes ``] (``; afterwards any ``)] (`` is reduced to
    ``) (``.
    """
    # Step 1: put a space between "]" and "(".
    spaced = text.replace('](', '] (')
    # Step 2: drop the "]" that directly follows a ")" before " (".
    return spaced.replace(')] (', ') (')

def standardize_doi(doi):
    """Normalise a DOI reference to a lower-case string.

    Prefix detection is case-insensitive, and whitespace after a ``doi:``
    prefix is removed so ``"doi: 10.1/x"`` does not yield a URL with an
    embedded space.

    Args:
        doi: DOI in URL form, ``doi:``-prefixed form, or bare ``10.x/...``
            form. Any other value is converted with ``str``.

    Returns:
        str: the input lower-cased; ``doi:`` and bare ``10.`` forms are
        rewritten as ``https://doi.org/<doi>``. Strings already starting
        with ``http`` and unrecognised strings are only lower-cased.
    """
    doi = str(doi).strip().lower()
    if doi.startswith("http"):
        return doi
    if doi.startswith("doi:"):
        # Drop the prefix and any whitespace between it and the DOI.
        return "https://doi.org/" + doi[4:].strip()
    if doi.startswith("10."):
        return "https://doi.org/" + doi
    return doi
import re

def extract_candidate_dois(text):
    """Return the unique http(s) URLs found in ``text``, sorted.

    Trailing punctuation and closing brackets are stripped from each match.
    The result is sorted so the order is the same on every run; a plain
    ``list(set(...))`` would vary with string hashing.

    Args:
        text: text to scan.

    Returns:
        list[str]: de-duplicated URLs in ascending order.
    """
    pattern_url = r'https?://[^\s\)<>\]]+'
    unique_urls = {url.rstrip('.,;)]>') for url in re.findall(pattern_url, text)}
    return sorted(unique_urls)

def extract_all_doi_candidates(text):
    """Collect every DOI-like match in ``text`` and standardise each one.

    Four case-insensitive patterns are applied in order (doi.org URLs,
    dx.doi.org URLs, ``doi:`` prefixed DOIs, bare ``10.x/...`` DOIs). Matches
    are not de-duplicated here, so one reference may appear more than once.

    Returns:
        list: ``standardize_doi`` applied to each match, grouped by pattern.
    """
    # Split markdown-style "](" so a link target is not glued to its label.
    spaced = text.replace('](', '] (')

    doi_regexes = (
        r'https?://doi\.org/[^\s\)<>\]]+',
        r'http://dx\.doi\.org/[^\s\)<>\]]+',
        r'doi:\s*10\.\d{4,9}/[^\s\)<>\]]+',
        r'\b10\.\d{4,9}/[^\s\)<>\]]+',
    )

    return [
        standardize_doi(hit)
        for regex in doi_regexes
        for hit in re.findall(regex, spaced, flags=re.IGNORECASE)
    ]

def clean_and_deduplicate_dois(dois, min_suffix_len=8):
    """Normalise DOI strings to ``https://doi.org/<prefix>/<suffix>`` and de-duplicate.

    NOTE: a function with this same name is defined again further down in
    this cell; the later definition is the one in effect at call time.

    Accepted input forms (prefix matching is case-insensitive):
    ``doi:10.x/...``, bare ``10.x/...``, and http/https URLs on ``doi.org``
    or ``dx.doi.org``. Entries in any other form are skipped.

    Args:
        dois: iterable of DOI strings.
        min_suffix_len: entries whose suffix (the part after the first "/")
            is shorter than this are discarded.

    Returns:
        list[str]: sorted unique canonical DOI URLs.
    """
    canonical = "https://doi.org/"
    resolver_prefixes = (
        "https://doi.org/",
        "http://doi.org/",
        "https://dx.doi.org/",
        "http://dx.doi.org/",
    )
    cleaned = set()
    for doi in dois:
        doi = doi.strip().rstrip('.,;)]>\\')
        if doi.lower().startswith("doi:"):
            doi = doi[4:].strip()
        # Reduce any resolver URL to the bare "10.x/..." form.
        for resolver in resolver_prefixes:
            if doi.lower().startswith(resolver):
                doi = doi[len(resolver):]
                break
        if not doi.startswith("10."):
            continue
        prefix, _, suffix = doi.partition("/")
        if not suffix or len(suffix) < min_suffix_len:
            continue
        cleaned.add(canonical + prefix + "/" + suffix)
    return sorted(cleaned)
    
def extract_dataset_dois(text):
    """Return sorted, unique repository DOI URLs found in ``text``.

    Candidates come from two sources: http(s) URLs containing
    ``doi.org/10.`` and bare ``10.x/...`` DOIs (passed through
    ``standardize_doi``). Both are stripped of trailing punctuation and
    lower-cased before filtering, so the same DOI is not reported twice with
    different case or a trailing period, and the repository keyword test
    does not depend on the source text's casing.

    Args:
        text: document text to scan.

    Returns:
        list[str]: sorted unique ``https://doi.org/10...`` URLs that contain
        one of ``REPO_KEYWORDS``.
    """
    # Split markdown-style "](" so a link target is not glued to its label.
    text = text.replace('](', '] (')

    trailing_junk = '.,;)]>'
    urls = re.findall(r'https?://[^\s\)<>\]]+', text)
    bare_dois = re.findall(r'\b10\.\d{4,9}/[^\s\)<>\]]+', text)

    candidates = []

    for url in urls:
        clean_url = url.rstrip(trailing_junk).lower()
        if 'doi.org/10.' in clean_url:
            candidates.append(clean_url)

    for bare in bare_dois:
        candidates.append(standardize_doi(bare.rstrip(trailing_junk)))

    def is_valid_doi(doi_url):
        """True for ``https://doi.org/10.<registrant>/<non-empty suffix>``."""
        if not doi_url.startswith("https://doi.org/10."):
            return False
        parts = doi_url.replace("https://doi.org/", "").split("/", 1)
        if len(parts) != 2:
            return False
        suffix = parts[1]
        return suffix not in ("", ".")

    candidates = [c for c in candidates if is_valid_doi(c)]
    # Keep only DOIs that name one of the known data repositories.
    candidates = [c for c in candidates if any(k in c for k in REPO_KEYWORDS)]

    return sorted(set(candidates))

def find_dataset_dois(text):
    """Return repository DOIs from ``text`` that appear near dataset wording.

    Candidates come from ``extract_dataset_dois`` and are canonicalised by
    ``clean_and_deduplicate_dois``. Each one is located in the lower-cased
    text by its bare ``10.x/...`` part rather than the canonical URL, so a
    DOI cited bare, as ``doi:10...`` or via ``dx.doi.org`` is still found.
    A DOI is kept when any keyword occurs within 200 characters of the
    position where that match starts.

    Args:
        text: document text to scan.

    Returns:
        list[str]: canonical DOI URLs, in the order produced by
        ``clean_and_deduplicate_dois``.
    """
    # Split markdown-style "](" so a link target is not glued to its label.
    text = text.replace('](', '] (')

    raw_dois = extract_dataset_dois(text)
    clean_dois = clean_and_deduplicate_dois(raw_dois)

    text_lc = text.lower()

    keywords = (
        "data availability", "data are available", "data is available",
        "data can be found", "archived in", "deposited in", "repository",
        "zenodo", "figshare", "dryad", "pangaea", "dataset", "available at", "tcia",
    )
    # NOTE(review): the window below includes the DOI itself, and candidates
    # were already filtered on repository names that are also keywords, so
    # this check looks like it passes for nearly every located DOI — confirm
    # whether a stricter context test was intended.

    doi_url_prefix = "https://doi.org/"
    dataset_dois = []
    for doi in clean_dois:
        needle = doi.lower()
        if needle.startswith(doi_url_prefix):
            # Search for the bare DOI so every citation style is matched.
            needle = needle[len(doi_url_prefix):]
        idx = text_lc.find(needle)
        if idx == -1:
            continue
        window = text_lc[max(0, idx - 200): idx + 200]
        if any(kw in window for kw in keywords):
            dataset_dois.append(doi)

    return dataset_dois

def clean_and_deduplicate_dois(dois, min_suffix_len=8):
    """Normalise DOI strings to ``https://doi.org/<prefix>/<suffix>`` and de-duplicate.

    NOTE: this re-defines a function of the same name that appears earlier
    in this cell; being the later definition, this is the one in effect.

    Accepted input forms (prefix matching is case-insensitive):
    ``doi:10.x/...``, bare ``10.x/...``, and http/https URLs on ``doi.org``
    or ``dx.doi.org``. Entries in any other form are skipped.

    Args:
        dois: iterable of DOI strings.
        min_suffix_len: entries whose suffix (the part after the first "/")
            is shorter than this are discarded.

    Returns:
        list[str]: sorted unique canonical DOI URLs.
    """
    canonical = "https://doi.org/"
    resolver_prefixes = (
        "https://doi.org/",
        "http://doi.org/",
        "https://dx.doi.org/",
        "http://dx.doi.org/",
    )
    cleaned = set()
    for doi in dois:
        doi = doi.strip().rstrip('.,;)]>\\')
        if doi.lower().startswith("doi:"):
            doi = doi[4:].strip()
        # Reduce any resolver URL to the bare "10.x/..." form.
        for resolver in resolver_prefixes:
            if doi.lower().startswith(resolver):
                doi = doi[len(resolver):]
                break
        if not doi.startswith("10."):
            continue
        prefix, _, suffix = doi.partition("/")
        if not suffix or len(suffix) < min_suffix_len:
            continue
        cleaned.add(canonical + prefix + "/" + suffix)
    return sorted(cleaned)


def find_accession_ids_in_text(text):
    """Find accession-style identifiers in ``text``.

    Patterns are applied case-insensitively in a fixed order (GSE,
    BioProject ``PRJxx``, SRP, ``E-XXXX-n``, ``pdb <id>``, CHEMBL). Matches
    are returned lower-cased with spaces removed and are not de-duplicated.

    The BioProject pattern requires two letters after ``PRJ`` (for example
    PRJNA, PRJEB, PRJDB); the previous single-letter form could not match
    identifiers such as ``PRJNA123456``.

    Args:
        text: document text to scan.

    Returns:
        list[str]: matches grouped by pattern, in order of appearance within
        each pattern.
    """
    patterns = [
        r'\b(GSE\d+)\b',
        r'\b(PRJ[EDN][A-Z]\d+)\b',
        r'\b(SRP\d+)\b',
        r'\b(E-[A-Z]+-\d+)\b',
        # NOTE(review): this also matches phrases like "pdb file" — confirm
        # whether the token after "pdb" should be restricted to ID shape.
        r'\b(pdb\s[\d\w]+)\b',
        r'\b(CHEMBL\d+)\b',
    ]
    matches = []
    for pat in patterns:
        matches.extend(re.findall(pat, text, re.IGNORECASE))
    return [m.lower().replace(" ", "") for m in matches]

def extract_article_id(filename):
    """Turn a PDF file name into an article id.

    The last four characters (the extension, e.g. ``.pdf``) are dropped and
    every underscore is replaced by a slash.
    """
    stem = filename[:-4]
    return "/".join(stem.split("_"))

In [14]:
import os


# Local checkpoint directories for each model, exported as environment
# variables.
# NOTE(review): the "modelstext_detection"-style segments have no "/" after
# "models" — confirm this matches the directory layout copied into ./cache.
_checkpoint_dirs = (
    ("DETECTOR_MODEL_CHECKPOINT", "/kaggle/working/cache/datalab/modelstext_detection/2025_05_07"),
    ("RECOGNITION_MODEL_CHECKPOINT", "/kaggle/working/cache/datalab/modelstext_recognition/2025_05_16"),
    ("LAYOUT_MODEL_CHECKPOINT", "/kaggle/working/cache/datalab/modelslayout/2025_02_18"),
    ("OCR_ERROR_MODEL_CHECKPOINT", "/kaggle/working/cache/datalab/modelsocr_error_detection/2025_02_18"),
    ("TABLE_REC_MODEL_CHECKPOINT", "/kaggle/working/cache/datalab/modelstable_recognition/2025_02_18"),
)
for _env_name, _checkpoint_dir in _checkpoint_dirs:
    os.environ[_env_name] = _checkpoint_dir

In [17]:
# 📂 Process PDFs
# Input PDFs, processed in sorted file-name order.
pdf_dir = "/kaggle/input/make-data-count-finding-data-references/test/PDF"
pdf_files = sorted(glob(os.path.join(pdf_dir, "*.pdf")))

# Fall back to CPU when no GPU is visible, in line with how the classifier's
# device is selected, instead of failing on a hard-coded 'cuda:0'.
marker_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
converter = PdfConverter(
    artifact_dict=create_model_dict(device=marker_device),
)
rows = []
def extract_text_from_pdf(file_path):
    """Convert the PDF at ``file_path`` and return its text.

    The module-level ``converter`` renders the file; only the first element
    of the ``text_from_rendered`` result is returned, the other two are
    discarded.
    """
    text, _, _ = text_from_rendered(converter(file_path))
    return text

In [None]:
# Start from an empty list so re-running this cell rebuilds `rows` instead of
# appending every article a second time.
rows = []
for pdf_path in tqdm(pdf_files):
    filename = os.path.basename(pdf_path)
    rows.append({
        "article_id": extract_article_id(filename),
        "text": extract_text_from_pdf(pdf_path),
    })

In [19]:
# 🔮 Predict type
# Class index -> label name used in the submission.
label_map = {0: "Primary", 1: "Secondary", 2: "Missing"}
batch_size = 8
preds = []

for start in tqdm(range(0, len(rows), batch_size)):
    batch_texts = [row["text"] for row in rows[start:start + batch_size]]
    # Texts are truncated to 512 tokens and padded to a common length.
    enc = tokenizer(
        batch_texts,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**enc).logits
    # Highest-scoring class index per document.
    preds.extend(torch.argmax(logits, dim=1).cpu().tolist())

100%|██████████| 4/4 [00:02<00:00,  1.77it/s]


In [68]:
# 🏷️ Extract dataset IDs and build results
def infer_dataset_ids(text):
    """Return cleaned, de-duplicated DOI URLs for every DOI-like match in ``text``."""
    return clean_and_deduplicate_dois(extract_all_doi_candidates(text))

# One result record per article: the first dataset DOI if any was found,
# otherwise the first accession id, otherwise an empty string.
results = []
for idx, row in enumerate(rows):
    body = row["text"]
    doi_hits = find_dataset_dois(body)
    accession_hits = find_accession_ids_in_text(body)

    if doi_hits:
        dataset_id = doi_hits[0]
    elif accession_hits:
        dataset_id = accession_hits[0]
    else:
        dataset_id = ""

    results.append({
        "article_id": row["article_id"],
        "dataset_id": dataset_id,
        "type": label_map[preds[idx]],
    })

In [62]:
# 📝 Build submission DataFrame
# Name the columns explicitly so the frame has the expected schema even when
# `results` is empty (otherwise later cells fail with KeyError).
df = pd.DataFrame(results, columns=["article_id", "dataset_id", "type"])

In [63]:
# Convert article ids back to their underscore (file-name) form.
# regex=False makes the "/" a literal on every pandas version.
df['article_id'] = df['article_id'].str.replace('/', '_', regex=False)

In [65]:
# Select the three payload columns first so a `row_id` left over from a
# previous run of this cell is dropped; without this, `insert` raises
# because the column already exists.
submission = df[["article_id", "dataset_id", "type"]]
submission = submission[submission["type"] != "Missing"]
submission = (
    submission
    .drop_duplicates(subset=["article_id", "dataset_id"])
    .reset_index(drop=True)
)
submission.insert(0, "row_id", range(len(submission)))
submission.to_csv("submission.csv", index=False)
# Keep `df` bound to the written frame, as before.
df = submission