# LLM Baseline for Data Reference Classification

This notebook implements a baseline approach using a large language model (LLM) to classify data references in academic papers.

## Overview

The approach consists of:
1. **Text extraction**: Extract text from PDF files and identify potential data references
2. **Context extraction**: Extract relevant context around identified references
3. **LLM classification**: Use the LLM to classify each reference as Primary, Secondary, or None

## Data Processing Pipeline

In [None]:
# Install required packages
!pip install pymupdf --no-cache-dir

# Import libraries
import os
import re
import fitz
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor
import pickle
import vllm
import torch

# Configure environment
os.environ["VLLM_USE_V1"] = "0"  # vLLM V1 does not accept logits processor

# Set data paths: the test PDFs during a competition rerun, the training PDFs otherwise
pdf_directory = "dataset/test/PDF" \
                if os.getenv('KAGGLE_IS_COMPETITION_RERUN') \
                else "dataset/train/PDF"

# Initialize storage
chunks = []      # (article_id, context) pairs for DOI matches
chunks2 = []     # (article_id, context) pairs for other ID matches
text_span_len = 300  # characters of context kept on each side of a match

# Define regex patterns for different types of identifiers.
# All patterns are case-insensitive; only re_pdb is anchored on word boundaries.
re_doi = re.compile(r"10\.\d{4,9}/[-._;()/:A-Z0-9]+", re.IGNORECASE)
re_gsr = re.compile(r"GSE\d+|SR[APRX]\d+|PRJ[NAED][A-Z]?\d+|E-[A-Z]+-\d+", re.IGNORECASE)
re_ipe = re.compile(r"IPR\d{6}|PF\d{5}|EMPIAR-\d{5}|EMD-\d{4,5}", re.IGNORECASE)
re_c = re.compile(r"CHEMBL\d+|CVCL_[A-Z0-9]{4}|CID:\d+", re.IGNORECASE)
re_e = re.compile(r"ENS[A-Z]{0,6}[GT]\d{11}|ENSG\d{11}", re.IGNORECASE)
re_r = re.compile(r"N[MC]_\d+(?:\.\d+)?|rs\d+|XM_\d+|XP_\d+", re.IGNORECASE)
re_u = re.compile(r"(?:uniprot:)?(?:[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9][A-Z][A-Z0-9]{2}[0-9])", re.IGNORECASE)
re_g = re.compile(r"EPI(?:_ISL_)?\d+|GISAID", re.IGNORECASE)
re_p = re.compile(r"PXD\d{6}|SAM[ND]\d+|ERR\d+|DRR\d+|MSV\d+", re.IGNORECASE)
re_pdb = re.compile(r"\b[0-9][A-Z0-9]{3}\b", re.IGNORECASE)
re_geo = re.compile(r"GDS\d+|GPL\d+|GSM\d+", re.IGNORECASE)
re_arrayexpress = re.compile(r"E-[A-Z]+-\d+", re.IGNORECASE)

# Patterns scanned for non-DOI identifiers.
# re_arrayexpress is deliberately left out: re_gsr already contains the identical
# `E-[A-Z]+-\d+` alternative, so scanning both produced every such hit twice
# (duplicate chunk, duplicate LLM request) only for it to be de-duplicated later.
# re_u and re_pdb are defined above but were never part of this list.
# NOTE(review): presumably they are excluded because they are very permissive - confirm.
relist = [re_gsr, re_ipe, re_c, re_e, re_r, re_g, re_p, re_geo]
ids = []  # matched identifier strings, kept index-aligned with chunks2

def remove_references_section(text):
    """Drop the trailing back-matter (references and similar) from paper text.

    The lines are scanned from the end of the document towards the front,
    never entering the first 20% of the lines. The scan stops at the first
    line that is a back-matter heading (References, Bibliography,
    Acknowledgments, Funding, ...) and is either followed within four lines
    by something that looks like a citation, or sits in the last five lines
    of the document. Everything from that heading onwards is removed.

    Returns the (stripped) truncated text, or the stripped input when no
    such heading is found.
    """
    heading_patterns = tuple(
        re.compile(source, re.IGNORECASE)
        for source in (
            r'^REFERENCES?$', r'^\d+\.?\s+REFERENCES?$', r'^\d+\.?\s+References?$',
            r'^References?:?$', r'^BIBLIOGRAPHY$', r'^\d+\.?\s+BIBLIOGRAPHY$',
            r'^\d+\.?\s+Bibliography$', r'^Bibliography:?$', r'^Literature\s+Cited$',
            r'^Works\s+Cited$', r'^ACKNOWLEDGMENTS?$', r'^Acknowledgments?$',
            r'^FUNDING$', r'^CONFLICTS?\s+OF\s+INTEREST$',
        )
    )

    def looks_like_citation(raw_line):
        """Return True when a non-blank line carries a typical citation marker."""
        trimmed = raw_line.strip()
        if not trimmed:
            return False
        lowered = raw_line.lower()
        return bool(
            re.search(r'\(\d{4}\)', raw_line)      # "(2020)"
            or re.search(r'\d{4}\.', raw_line)     # "2020."
            or 'doi:' in lowered
            or ' et al' in lowered
            or re.search(r'^\[\d+\]', trimmed)     # "[12] ..."
            or re.search(r'^\d+\.', trimmed)       # "12. ..."
        )

    lines = text.split('\n')
    line_count = len(lines)
    scan_floor = max(0, int(line_count * 0.2))  # exclusive lower bound of the scan
    tail_start = line_count - 5

    for pos in reversed(range(scan_floor + 1, line_count)):
        candidate = lines[pos].strip()
        if not any(pattern.match(candidate) for pattern in heading_patterns):
            continue

        # A heading near the very end is trusted as-is; elsewhere it must be
        # followed by citation-like lines.
        near_end = pos >= tail_start
        if near_end or any(looks_like_citation(nxt) for nxt in lines[pos + 1:pos + 5]):
            return '\n'.join(lines[:pos]).strip()

    return text.strip()

def extract_context_with_keywords(text, match_start, match_end, span_len=300):
    """Extract context around matches with keyword-based scoring"""
    keyword_scores = {
        "data are available": 5, "datasets are available": 5, "deposited in": 5, 
        "submitted to": 5, "accession number": 5, "accession code": 5, 
        "accession id": 5, "archived in": 4, "uploaded to": 4, "source code": 4, 
        "raw data": 4, "sequencing data": 4, "retrieved from": 3, "downloaded from": 3, 
        "obtained from": 3, "supplementary data": 3, "supporting information": 3, 
        'deposited': 3, 'submitted': 3, 'accession': 3, "available in the": 2, 
        "publicly available": 2, "freely available": 2, "supplementary material": 2, 
        'dataset': 2, 'datasets': 2, 'database': 2, 'repository': 2, 'code': 2, 
        'scripts': 2, 'available': 1, 'download': 1, 'supplementary': 1, 
        'supporting': 1, 'software': 1, 'protocol': 1, 'data': 0.5
    }
    
    contexts = {
        'standard': text[max(0, match_start - span_len):min(len(text), match_end + span_len)],
        'extended': text[max(0, match_start - span_len * 2):min(len(text), match_end + span_len * 2)]
    }
    
    def score_context(context):
        return sum(
            context.lower().count(k) * v if ' ' in k 
            else len(re.findall(r'\b' + re.escape(k) + r'\b', context.lower())) * v
            for k, v in keyword_scores.items()
        )
    
    scores = {k: score_context(v) for k, v in contexts.items()}
    return contexts['extended'] if scores['extended'] > scores['standard'] and scores['extended'] > 4 else contexts['standard']

# Process PDF files: extract text, drop the back-matter, then collect a context
# chunk for every DOI match (chunks) and every other identifier match (chunks2/ids).
print("Processing PDF files...")
for filename in tqdm(os.listdir(pdf_directory), desc="Processing PDFs"):
    if not filename.endswith(".pdf"):
        continue

    pdf_path = os.path.join(pdf_directory, filename)
    # Strip only the trailing extension; splitting on ".pdf" would truncate a
    # name that happens to contain ".pdf" somewhere in the middle.
    article_id = filename[:-len(".pdf")]

    # Best-effort extraction: an unreadable PDF is reported and skipped.
    try:
        with fitz.open(pdf_path) as doc:
            text = "\n".join(page.get_text() for page in doc)
    except Exception as e:
        print(f"Could not process {filename}: {e}")
        continue

    text = remove_references_section(text)

    # Extract DOI matches
    article_key = article_id.lower()
    for match in re_doi.finditer(text):
        # Skip the article's own DOI. A raw DOI always contains "/" while a
        # file name never can, so the two are compared in a filename-safe form.
        # NOTE(review): assumes file names encode the DOI with "/" replaced by
        # "_" - confirm against the dataset. If they do not, nothing is skipped,
        # which matches the previous behaviour.
        doi_key = match.group().rstrip(".,;:").replace("/", "_").lower()
        if doi_key in article_key:
            continue
        chunk = extract_context_with_keywords(text, match.start(), match.end(), text_span_len)
        chunks.append((article_id, chunk))

    # Extract other ID matches; ids and chunks2 are appended together so they
    # stay index-aligned.
    for rr in relist:
        for match in rr.finditer(text):
            ids.append(match.group())
            chunk = extract_context_with_keywords(text, match.start(), match.end(), text_span_len)
            chunks2.append((article_id, chunk))

print(f"DOI chunks: {len(chunks)}")
print(f"Other ID chunks: {len(chunks2)}")

## Model Initialization

Load the LLM used for the extraction and classification tasks.

In [None]:
model_path = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"

# Engine options, gathered in one place so they can be inspected or tweaked
# before the model is instantiated.
engine_options = dict(
    quantization='awq',
    tensor_parallel_size=torch.cuda.device_count(),  # one shard per visible GPU
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
    dtype="half",
    enforce_eager=True,
    max_model_len=2048,
    disable_log_stats=True,
    enable_prefix_caching=True,
)

# Instantiate the model and fetch the tokenizer used to render chat prompts
llm = vllm.LLM(model_path, **engine_options)

tokenizer = llm.get_tokenizer()
print("Model loaded successfully!")

## System Prompts

Define system prompts for different classification tasks.

In [None]:
# System prompt for the DOI extraction step: the model answers with a
# "https://doi.org/..." URL or the single word "Irrelevant".
SYS_PROMPT_DOI = """
You are an expert at identifying research data citations in academic papers. 

Your task is to determine if a DOI citation in the given text refers specifically to research data, datasets, or data repositories.

Only respond with either a full normalized DOI URL starting with "https://doi.org/" or the word "Irrelevant" (without quotes).

Do NOT include any other text or explanation.

If there is no DOI related to research data, respond with exactly "Irrelevant".
If multiple DOIs refer to research data, return any one of them.
"""

# System prompt for classifying non-DOI accession IDs. The A/B/C legend is
# spelled out explicitly: the answer is restricted to a single letter, so the
# prompt must say which letter stands for which class (it previously listed
# the classes and clues without tying them to the letters).
SYS_PROMPT_ACCESSION = """
You are an expert at analyzing research data usage in academic papers.

Classify the data as:
A) Primary: if the data was generated specifically for this study
B) Secondary: if the data was reused or derived from prior work
C) None: if the accession ID is in references, doesn't refer to research data, or is unrelated

Look for contextual clues:
- For PRIMARY data: "we deposited", "data generated in this study", "our data", "submitted to", "newly generated"
- For SECONDARY data: "downloaded from", "obtained from", "previously published", "publicly available", "existing dataset"
- For NONE: mentioned in references, methodology descriptions without actual usage, or unrelated contexts

Respond with only one letter: A, B, or C.
"""

# System prompt for classifying DOIs that the extraction step kept.
SYS_PROMPT_CLASSIFY_DOI = """
You are an expert at analyzing research data citations in academic papers.

Classify the data as:
A) Primary: if the data was generated specifically for this study
B) Secondary: if the data was reused or derived from prior work  
C) None: if the DOI is in references, doesn't refer to research data, or is unrelated

Respond with only one letter: A, B, or C.
"""

## DOI Extraction

Extract DOI links from the text chunks using LLM.

In [None]:
# Render one chat prompt per DOI chunk (system instructions + the chunk text)
prompts = [
    tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYS_PROMPT_DOI},
            {"role": "user", "content": chunk_text},
        ],
        add_generation_prompt=True,
        tokenize=False,
    )
    for _, chunk_text in chunks
]

# Generate responses (greedy decoding, short answers)
print("Extracting DOI links...")
extraction_params = vllm.SamplingParams(
    seed=0,
    skip_special_tokens=True,
    max_tokens=64,
    temperature=0
)
outputs = llm.generate(prompts, extraction_params, use_tqdm=True)

responses = [generation.outputs[0].text.strip() for generation in outputs]

# Normalise every reply to either "https://doi.org/<doi>" or "Irrelevant"
doi_pattern = re.compile(r'(10\.\d{4,9}/[-._;()/:A-Z0-9]+)', re.I)
doi_urls = []

for reply in responses:
    # An explicit "irrelevant" answer is never searched for a DOI
    found = None if reply.lower() == "irrelevant" else doi_pattern.search(reply)
    doi_urls.append(("https://doi.org/" + found.group(1)) if found else "Irrelevant")

relevant_count = sum(1 for url in doi_urls if url != 'Irrelevant')
print(f"Found {relevant_count} relevant DOIs")

## DOI Classification

Classify the extracted DOI links as Primary, Secondary, or None using the LLM with a logits processor that restricts the answer to A, B, or C.

In [None]:
# Prepare prompts for DOI classification.
# Only chunks whose extraction step produced a DOI are classified; valid_indices
# records each prompt's position in `chunks` so the answer can be written back.
prompts = []
valid_indices = []

for i, (chunk, url) in enumerate(zip(chunks, doi_urls)):
    if url == "Irrelevant":
        continue

    article_id, academic_text = chunk
    messages = [
        {"role": "system", "content": SYS_PROMPT_CLASSIFY_DOI},
        {"role": "user", "content": f"DOI: {url}\n\nAcademic text:\n{academic_text}"}
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    prompts.append(prompt)
    valid_indices.append(i)

# Initialize logits processor for multiple choice
mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["A", "B", "C"])

# Generate classifications (a single token per prompt). Generation is skipped
# entirely when there is nothing to classify.
print("Classifying DOI references...")
outputs = llm.generate(
    prompts,
    vllm.SamplingParams(
        seed=777,
        temperature=0.05,
        skip_special_tokens=True,
        max_tokens=1,
        logits_processors=[mclp],
        logprobs=len(mclp.choices)
    ),
    use_tqdm=True
) if prompts else []

# Collect, per prompt, the log-probability reported for each decoded token
logprobs = [
    {lp.decoded_token: lp.logprob for lp in output.outputs[0].logprobs[0].values()}
    for output in outputs
]

# One row per prompt, one column per choice. reindex() (rather than plain column
# selection) keeps this working when there are no rows at all or when a choice
# never appears among the returned tokens; a choice that was not reported for a
# row is treated as having log-probability -inf.
# Despite the name, the values are log-probabilities, not raw logits.
logit_matrix = (
    pd.DataFrame(logprobs)
    .reindex(columns=["A", "B", "C"])
    .astype(float)
    .fillna(-np.inf)
    .values
)
choices = ["Primary", "Secondary", None]
answers = [None] * len(chunks)

# Apply confidence-based selection: keep the best choice only when its
# log-probability clears the threshold, otherwise leave the answer as None.
for idx, logit_row in zip(valid_indices, logit_matrix):
    max_idx = int(np.argmax(logit_row))

    if logit_row[max_idx] > -2.0:  # Confidence threshold
        answers[idx] = choices[max_idx]

print(f"Classified {sum(1 for a in answers if a is not None)} DOI references")

In [None]:
# Prepare prompts for other ID classification.
# chunks2 and ids are index-aligned: entry k of each describes the same match.
prompts = []
for chunk, acc_id in zip(chunks2, ids):
    article_id, academic_text = chunk
    messages = [
        {"role": "system", "content": SYS_PROMPT_ACCESSION},
        {"role": "user", "content": f"Accession ID: {acc_id}\n\nAcademic text:\n{academic_text}"}
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    prompts.append(prompt)

# Generate classifications for other IDs (a single token per prompt), reusing
# the multiple-choice logits processor `mclp`. Generation is skipped entirely
# when there is nothing to classify.
print("Classifying other ID references...")
outputs = llm.generate(
    prompts,
    vllm.SamplingParams(
        seed=777,
        temperature=0.05,
        skip_special_tokens=True,
        max_tokens=1,
        logits_processors=[mclp],
        logprobs=len(mclp.choices)
    ),
    use_tqdm=True
) if prompts else []

# Collect, per prompt, the log-probability reported for each decoded token
logprobs2 = [
    {lp.decoded_token: lp.logprob for lp in output.outputs[0].logprobs[0].values()}
    for output in outputs
]

# One row per prompt, one column per choice. reindex() (rather than plain column
# selection) keeps this working when there are no rows at all or when a choice
# never appears among the returned tokens; a choice that was not reported for a
# row is treated as having log-probability -inf.
# Despite the name, the values are log-probabilities, not raw logits.
logit_matrix2 = (
    pd.DataFrame(logprobs2)
    .reindex(columns=["A", "B", "C"])
    .astype(float)
    .fillna(-np.inf)
    .values
)
choices2 = ["Primary", "Secondary", None]

# Apply confidence-based selection for other IDs: keep the best choice only
# when its log-probability clears the threshold, otherwise record None.
answers2 = []
for logit_row in logit_matrix2:
    max_idx = int(np.argmax(logit_row))

    if logit_row[max_idx] > -2.0:  # Confidence threshold
        answers2.append(choices2[max_idx])
    else:
        answers2.append(None)

print(f"Classified {sum(1 for a in answers2 if a is not None)} other ID references")

## Submission Preparation

Combine results and prepare the final submission file.

In [None]:
# Build one frame per identifier family: DOI rows (ids lower-cased) and
# accession rows (ids as matched in the text)
sub_df = pd.DataFrame({
    "article_id": [article for article, _ in chunks],
    "dataset_id": [doi_url.lower() for doi_url in doi_urls],
    "type": answers,
})

sub_df2 = pd.DataFrame({
    "article_id": [article for article, _ in chunks2],
    "dataset_id": ids,
    "type": answers2,
})

# Drop rows that received no label
sub_df = sub_df.loc[sub_df["type"].notnull()].reset_index(drop=True)
sub_df2 = sub_df2.loc[sub_df2["type"].notnull()].reset_index(drop=True)

# Stack both families and keep only the two submittable classes
sub_df = pd.concat([sub_df, sub_df2], ignore_index=True)
submittable = sub_df["type"].isin(["Primary", "Secondary"])
sub_df = sub_df[submittable].reset_index(drop=True)


def _rank_type_column(column):
    """Sort key: order the 'type' column Primary-first, leave other columns as-is."""
    if column.name == "type":
        return column.map({"Primary": 0, "Secondary": 1})
    return column


# De-duplicate per (article, dataset); sorting Primary first means a Primary
# label wins whenever the same pair was labelled both ways
ordered = sub_df.sort_values(
    by=["article_id", "dataset_id", "type"],
    key=_rank_type_column,
)
sub_df = ordered.drop_duplicates(
    subset=['article_id', 'dataset_id'],
    keep="first",
).reset_index(drop=True)

# Number the rows and write the submission file
sub_df['row_id'] = range(len(sub_df))
sub_df.to_csv("submission.csv", index=False, columns=["row_id", "article_id", "dataset_id", "type"])

print("Final submission statistics:")
print(sub_df["type"].value_counts())
print(f"Total entries: {len(sub_df)}")
print(f"Unique articles: {sub_df['article_id'].nunique()}")
print(f"Unique datasets: {sub_df['dataset_id'].nunique()}")

## Model Evaluation

Evaluate the model performance on the training set.

In [None]:
def f1_score(tp, fp, fn):
    """Return the F1 score for the given true-positive, false-positive and
    false-negative counts, or 0.0 when all three counts make the denominator zero."""
    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 0.0
    return 2 * tp / denominator

# Score the written submission against the training labels, unless this is a
# competition rerun (where no labels are read)
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("Competition mode - skipping evaluation")
else:
    print("Evaluating model performance...")

    # Predictions are read back from the file written by the submission step
    pred_df = pd.read_csv("submission.csv")
    label_df = pd.read_csv("dataset/train_labels.csv")

    # Labels of type 'Missing' are excluded from scoring
    label_df = label_df[label_df['type'] != 'Missing'].reset_index(drop=True)

    # A hit is a prediction matching a label on article, dataset and type
    hits_df = label_df.merge(pred_df, on=["article_id", "dataset_id", "type"])

    # Confusion counts
    tp = hits_df.shape[0]
    fp = pred_df.shape[0] - tp
    fn = label_df.shape[0] - tp

    # Precision, recall and F1 (each falls back to 0 on an empty denominator)
    predicted_total = tp + fp
    labelled_total = tp + fn
    precision = tp / predicted_total if predicted_total > 0 else 0
    recall = tp / labelled_total if labelled_total > 0 else 0
    f1 = f1_score(tp, fp, fn)

    banner = "=" * 50
    print("\n" + banner)
    print("VALIDATION RESULTS")
    print(banner)
    print(f"True Positives (TP):  {tp}")
    print(f"False Positives (FP): {fp}")
    print(f"False Negatives (FN): {fn}")
    print(f"Precision: {precision:.3f}")
    print(f"Recall:    {recall:.3f}")
    print(f"F1 Score:  {f1:.3f}")
    print(banner)

    # Class distribution of predictions versus labels
    print("\nPredictions by type:")
    print(pred_df["type"].value_counts())
    print("\nLabels by type:")
    print(label_df["type"].value_counts())