# 📈 Inference Pipeline (No GROBID)

This notebook:
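- Loads the trained classifier and its tokenizer (`saved_model_dual_text`)
- Reads PDFs directly via Marker-PDF
- Predicts `type` (`Primary`, `Secondary`, or `Missing`) for each article
- Extracts `dataset_id` (DOIs and accession IDs) using regex
- Produces `submission.csv` (rows predicted as `Missing` are dropped)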

**Note:** No GROBID needed.

In [ ]:
# 📚 Imports
import re
import os
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
import pandas as pd
from glob import glob
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

In [ ]:
# 🧠 Load model and tokenizer
model_path = "/kaggle/input/makedatacount-mixed-train/saved_model_dual_text"
token_path = model_path  # the tokenizer is loaded from the same directory as the model

# Pick the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(token_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.to(device)
model.eval()

In [ ]:
# 🧩 Utility functions
# Lowercase repository names.
# NOTE(review): not referenced by any of the visible cells — confirm whether
# it is still needed or was meant to filter DOI candidates.
REPO_KEYWORDS = ['zenodo', 'dryad', 'figshare', 'pangaea', 'tcia']

def standardize_doi(doi):
    """Normalize a DOI-like value to a lowercase ``https://doi.org/...`` URL.

    The input is converted to ``str``, stripped, and lowercased before any
    prefix is inspected, so ``DOI:10.1/X`` and ``doi:10.1/x`` normalize to
    the same result.

    Args:
        doi: A DOI in one of these forms: a full URL (``http...``), a
            ``doi:``-prefixed identifier, or a bare ``10.xxxx/...`` identifier.

    Returns:
        The lowercase URL form. Values that match none of the recognized
        prefixes are returned stripped and lowercased, without a URL prefix.
    """
    normalized = str(doi).strip().lower()
    if normalized.startswith("http"):
        return normalized
    if normalized.startswith("doi:"):
        # Tolerate "doi: 10.x/y" (whitespace after the colon).
        return "https://doi.org/" + normalized[4:].lstrip()
    if normalized.startswith("10."):
        return "https://doi.org/" + normalized
    return normalized

def extract_dataset_dois(text):
    """Return the sorted, de-duplicated DOI URLs found in *text*.

    Two sources are scanned: ``http(s)`` URLs containing ``doi.org/10.`` and
    bare ``10.xxxx/...`` identifiers. Both are stripped of trailing
    punctuation and passed through ``standardize_doi``, so a DOI that appears
    as a URL and as a bare identifier (or with different casing, or followed
    by a period) yields a single entry.

    Args:
        text: Document text to scan.

    Returns:
        A sorted list of unique standardized DOI strings.
    """
    trailing = '.,;)]>'
    # Separate markdown link text from its target: "[label](url)" -> "[label] (url)".
    text = text.replace('](', '] (')

    candidates = set()
    for url in re.findall(r'https?://[^\s\)<>\]]+', text):
        marker = url.find('doi.org/10.')
        if marker == -1:
            continue
        # Keep only the DOI part so every resolver/scheme variant of the
        # same DOI collapses to one canonical form.
        doi = url[marker + len('doi.org/'):]
        candidates.add(standardize_doi(doi.rstrip(trailing)))
    for doi in re.findall(r'\b10\.\d{4,9}/[^\s\)<>\]]+', text):
        # Sentence punctuation right after a DOI is not part of it.
        candidates.add(standardize_doi(doi.rstrip(trailing)))

    return sorted(candidates)

def find_accession_ids_in_text(text):
    """Find repository accession IDs in *text*.

    Matching is case-insensitive. Results are lowercased, have spaces
    removed (so ``PDB 1ABC`` becomes ``pdb1abc``), and are de-duplicated.

    Args:
        text: Document text to scan.

    Returns:
        A list of unique accession IDs, grouped in pattern order (GSE, PRJ,
        SRP, E-, pdb, CHEMBL) and in order of first appearance within each
        group.
    """
    patterns = (
        r'\b(GSE\d+)\b',
        # BioProject IDs are PRJ + two letters + digits (PRJNA, PRJEB, PRJDB).
        r'\b(PRJ[EDN][A-Z]\d+)\b',
        r'\b(SRP\d+)\b',
        r'\b(E-[A-Z]+-\d+)\b',
        # PDB IDs are a non-zero digit plus three alphanumerics; requiring
        # that shape avoids matching phrases such as "pdb file".
        r'\b(pdb\s[1-9][A-Z\d]{3})\b',
        r'\b(CHEMBL\d+)\b',
    )
    found = []
    for pattern in patterns:
        found.extend(re.findall(pattern, text, re.IGNORECASE))
    cleaned = (hit.lower().replace(" ", "") for hit in found)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(cleaned))

def extract_article_id(filename):
    """Turn a PDF file name into an article ID.

    The last four characters (the ``.pdf`` extension) are dropped and every
    underscore is replaced with a slash.
    """
    stem = filename[:-4]
    return "/".join(stem.split("_"))

In [ ]:
# 📂 Process PDFs
pdf_dir = "/kaggle/input/make-data-count-finding-data-references/test/PDF"
pdf_files = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
# Use the GPU when present and fall back to CPU otherwise, mirroring the
# device selection used for the classifier, instead of requiring CUDA.
marker_device = "cuda:0" if torch.cuda.is_available() else "cpu"
converter = PdfConverter(artifact_dict=create_model_dict(device=marker_device))

# One {"article_id", "text"} dict per PDF; filled by the conversion loop.
rows = []
def extract_text_from_pdf(file_path):
    """Convert one PDF with the shared Marker converter and return its text.

    Args:
        file_path: Path of the PDF file to convert.

    Returns:
        The first element of the tuple returned by ``text_from_rendered``
        (the extracted text); the other two elements are discarded.
    """
    rendered = converter(file_path)
    text, _, _ = text_from_rendered(rendered)
    return text

In [ ]:
# Convert every PDF and record its article ID together with the extracted text.
for pdf_path in tqdm(pdf_files):
    rows.append({
        "article_id": extract_article_id(os.path.basename(pdf_path)),
        "text": extract_text_from_pdf(pdf_path),
    })

In [ ]:
# 🔮 Predict type
label_map = {0: "Primary", 1: "Secondary", 2: "Missing"}
batch_size = 8
preds = []

# Classify the article texts batch by batch; each text is truncated to
# 512 tokens and the highest-scoring class index is kept per article.
for start in tqdm(range(0, len(rows), batch_size)):
    chunk = rows[start:start + batch_size]
    encoded = tokenizer(
        [row["text"] for row in chunk],
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        batch_logits = model(**encoded).logits
    preds.extend(batch_logits.argmax(dim=1).cpu().tolist())

In [ ]:
# 🏷️ Extract dataset IDs and build results
results = []
for idx, row in enumerate(rows):
    doc_text = row["text"]
    doi_hits = extract_dataset_dois(doc_text)
    accession_hits = find_accession_ids_in_text(doc_text)

    # Prefer the first DOI, then the first accession ID, else an empty string.
    if doi_hits:
        dataset_id = doi_hits[0]
    elif accession_hits:
        dataset_id = accession_hits[0]
    else:
        dataset_id = ""

    record = {
        "article_id": row["article_id"],
        "dataset_id": dataset_id,
        "type": label_map[preds[idx]],
    }
    results.append(record)

In [ ]:
# 📝 Build submission DataFrame
# Explicit columns keep the frame well-formed (and the column lookups below
# valid) even when `results` is empty.
df = pd.DataFrame(results, columns=["article_id", "dataset_id", "type"])
# Literal replacement; do not rely on the pandas-version-dependent regex default.
df['article_id'] = df['article_id'].str.replace('/', '_', regex=False)
# Keep only rows that name a dataset: drop "Missing" predictions and
# articles for which no identifier was extracted (empty dataset_id).
df = df[(df['type'] != 'Missing') & (df['dataset_id'] != '')]
df = df.drop_duplicates(subset=['article_id', 'dataset_id']).reset_index(drop=True)
df.insert(0, "row_id", range(len(df)))
df.to_csv("submission.csv", index=False)