In [None]:
import pymupdf
from pathlib import Path
from glob import glob
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import BitsAndBytesConfig
from json import loads, dump, dumps, JSONDecodeError
from pydantic import BaseModel, Field, ConfigDict, ValidationError
from typing import List
import psutil, os, time

In [None]:
# Input PDFs. Kept as a string ending in "/" because it is used as a glob prefix.
pdf_dir = "../../data/metadata_extraction_data/demo/"

# Per-model settings: (output sub-folder, marker that opens the assistant reply,
# marker that ends it). Selecting a model via this table replaces the old
# commented-out blocks, where four lines had to be toggled by hand and could
# easily be left mismatched.
MODEL_CONFIGS = {
    "microsoft/Phi-4-mini-instruct": (
        "phi4mini_demo_metadata", "<|assistant|>", "<|end|>"),
    "Qwen/Qwen2.5-3B-Instruct": (
        "qwen3b_demo_metadata", "<|im_start|>assistant\n", "<|im_end|>"),
    "Qwen/Qwen3-4B-Base": (
        "qwen4b_demo_metadata", "<|im_start|>assistant\n", "<|endoftext|>"),
    "meta-llama/Llama-3.2-3B-Instruct": (
        "llama3b_demo_metadata", "<|start_header_id|>assistant<|end_header_id|>\n\n", "<|eot_id|>"),
}

# Pick any key of MODEL_CONFIGS.
model_id = "microsoft/Phi-4-mini-instruct"
_output_subdir, response_start_token, response_end_token = MODEL_CONFIGS[model_id]
output_dir = Path("../../data/metadata_extraction_data") / _output_subdir

output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Pydantic model with two roles in this notebook: its JSON schema is embedded
# in the system prompt, and each model reply is validated against it.
# NOTE: the class docstring, the Field descriptions and the field order are all
# emitted by model_json_schema() into the prompt text. Editing any of them
# changes what the LLM is shown.
class ExtractMetadata(BaseModel):
    """
    Structured metadata for an academic publication.
    """
    # Reject any key not declared below (e.g. extra fields invented by the model).
    model_config = ConfigDict(extra="forbid")
    # Every field is required (`...`). The fallback dicts used on parse failure
    # fill missing values with "" for strings and [] for lists.
    title: str = Field(
        ...,
        description="The full name identifying the academic publication.",
    )
    authors: List[str] = Field(
        ...,
        description="The names of individuals who wrote the publication.",
    )
    affiliations: List[str] = Field(
        ...,
        description="Institutions or organizations associated with the authors.",
    )
    email_ids: List[str] = Field(
        ...,
        description="Contact email IDs of the authors.",
    )
    # Plain string: the date formats are requested in the description only;
    # no validator enforces them.
    publication_date: str = Field(
        ...,
        description="The date when the publication was officially published in DD-MM-YYYY or MM-YYYY or YYYY format.",
    )
    publisher: str = Field(
        ...,
        description="The organization responsible for publishing the document.",
    )
    doi: str = Field(
        ...,
        description="A unique digital object identifier linking directly to the publication online.",
    )
    keywords: List[str] = Field(
        ...,
        description="Specific terms highlighting the main topics of the publication.",
    )
    abstract: str = Field(
        ...,
        description="A brief summary outlining the publication’s content, methods, and findings.",
    )

In [None]:
def get_pdf_text(file, no_pgs=1):
    """Return the plain text of the leading pages of a PDF.

    Args:
        file: Path (or path string) of the PDF, opened with ``pymupdf.open``.
        no_pgs: Used as the ``stop`` argument of ``doc.pages(0, no_pgs, 1)``,
            i.e. the number of leading pages to read.

    Returns:
        The text of each selected page followed by a newline, concatenated.
    """
    # The context manager closes the document (and its file handle) even if
    # extraction raises. The previous version never closed it, leaking one
    # open file per PDF processed by the batch loops.
    with pymupdf.open(file) as doc:
        return "".join(page.get_text() + "\n" for page in doc.pages(0, no_pgs, 1))

In [None]:
# The JSON schema of ExtractMetadata (keys, types, descriptions) is embedded
# verbatim in the system prompt so the model knows exactly which keys to emit.
response_schema = dumps(ExtractMetadata.model_json_schema(), indent=4)
# Each instruction ends with "\n": adjacent string literals are joined with no
# separator, which previously fused sentences ("...given text.The metadata...").
system_prompt = (
    "You are given the text from an academic publication. Your task is to extract metadata from the given text.\n"
    "The metadata fields you should extract are: Title, Authors, Affiliations, Email IDs, DOI, Publisher, Publication Date, Keywords, and Abstract.\n"
    "If any metadata information is missing, leave it blank (an empty string, or an empty list for list fields).\n"
    f"Return a JSON with this schema:\n{response_schema}\n"
    "Do not add any preamble or explanations."
    )
# Filled with the extracted page text of one PDF per request.
user_prompt_template = (
    "## Text:\n{text}"
)

In [None]:
# Collect the input PDFs. os.path.join works whether or not pdf_dir ends with a
# separator, and sorting fixes the processing order (glob's order depends on
# the filesystem), which keeps runs and timing lists reproducible.
pdf_files = sorted(glob(os.path.join(pdf_dir, "*.pdf")))

In [None]:
# Load the tokenizer and a 4-bit bitsandbytes-quantized model, then extract
# metadata for every PDF that has no cached JSON in output_dir yet.
tok = AutoTokenizer.from_pretrained(model_id, cache_dir="../../.hf_cache")
# Reuse EOS as the pad token; set once here instead of on every loop iteration.
tok.pad_token = tok.eos_token
bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)  # NOTE: if qwen models don't give proper structured response, try not using "bnb_conf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,
    cache_dir="../../.hf_cache",
    trust_remote_code=False,
    device_map="auto",
    quantization_config=bnb_conf,
)

process = psutil.Process(os.getpid())

torch.cuda.reset_peak_memory_stats()
t0 = time.perf_counter()
tt = []  # per-file generation time in milliseconds, measured with CUDA events

for file in pdf_files:
    text_file_path = output_dir / f"{Path(file).stem}.json"
    if text_file_path.is_file():
        print("Metadata exists")
        continue

    user_prompt = user_prompt_template.format(text=get_pdf_text(file, 1))
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", padding=True
    ).to(model.device)
    # A single sequence with no padding, so every position is attended to.
    attention_mask = torch.ones_like(inputs)
    prompt_len = inputs.shape[-1]

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.inference_mode():
        # output_scores is intentionally not requested: the scores were never
        # read, and keeping per-step vocabulary scores for up to 1024 steps
        # wastes GPU memory and inflates the peak-memory figure printed below.
        generated = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=1024,
            return_dict_in_generate=True,
            pad_token_id=tok.eos_token_id,
        )
    end.record()
    torch.cuda.synchronize()
    tt.append(start.elapsed_time(end))

    # Decode only the tokens produced after the prompt. Splitting the whole
    # transcript on response_start_token raised IndexError whenever that
    # marker was missing from the decoded text.
    text = tok.decode(generated.sequences[0][prompt_len:], skip_special_tokens=False)
    clean_text = (
        text.split(response_end_token)[0]
        .replace("```json", "")
        .replace("```", "")
        # Newlines become spaces rather than "", so words broken across lines
        # inside string values (e.g. the abstract) are not glued together.
        .replace("\n", " ")
    )
    try:
        op = loads(clean_text)
    except JSONDecodeError as e:
        # Unparsable reply: record an all-blank result for this file.
        op = {
            "title": "",
            "authors": [],
            "affiliations": [],
            "email_ids": [],
            "doi": "",
            "publisher": "",
            "publication_date": "",
            "keywords": [],
            "abstract": ""
        }
        print(file, clean_text[-50:], e)

    # model_validate reports a non-object reply (e.g. a JSON list) as a
    # ValidationError; ExtractMetadata(**op) raised TypeError and aborted the run.
    try:
        metadata = ExtractMetadata.model_validate(op)
    except ValidationError as e:
        print("ERROR: ", file, e)

    with open(text_file_path, 'w', encoding='utf-8') as f:
        dump(op, f, ensure_ascii=False, indent=4)
    print("Wrote file to: ", text_file_path)

torch.cuda.synchronize()          # wait for the GPU
elapsed_s   = time.perf_counter() - t0
gpu_alloc   = torch.cuda.max_memory_allocated()  / 1024**2   # MiB actually used
gpu_reserved= torch.cuda.max_memory_reserved()   / 1024**2   # MiB reserved by the allocator
cpu_mem     = process.memory_info().rss          / 1024**2   # MiB

print(f"elapsed {elapsed_s:,.2f}s | GPU used {gpu_alloc:.0f} MiB | "
      f"GPU reserved {gpu_reserved:.0f} MiB | CPU {cpu_mem:.0f} MiB")

In [None]:
# gpt-oss using ollama
# Same extraction as the transformers cell, served by a local Ollama model.
# The imports stay local to this cell, presumably so the transformers path
# does not require ollama to be installed.
from ollama import chat
from ollama import ChatResponse

model_id = 'gpt-oss:20b'
output_dir = Path("../../data/metadata_extraction_data/gpt_oss_demo_metadata")

output_dir.mkdir(parents=True, exist_ok=True)

tt = []  # wall-clock seconds per chat() call, rounded to 2 decimals
for file in pdf_files:
    text_file_path = output_dir / f"{Path(file).stem}.json"
    if text_file_path.is_file():
        print("Metadata exists")
        continue

    user_prompt = user_prompt_template.format(text=get_pdf_text(file, 2))
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    start = time.perf_counter()
    response: ChatResponse = chat(model=model_id, messages=messages)
    tt.append(round(time.perf_counter() - start, 2))
    text = response['message']['content']
    # Tolerate replies wrapped in a Markdown code fence, as the transformers cell does.
    clean_text = text.replace("```json", "").replace("```", "").strip()

    # Reset for every file. It used to be set once before the loop, so after
    # the first unparsable reply no later file was ever written.
    failed = False
    try:
        op = loads(clean_text)
    except JSONDecodeError as e:
        print(file, text[-50:], e)
        op = {
            "title": "",
            "authors": [],
            "affiliations": [],
            "email_ids": [],
            "doi": "",
            "publisher": "",
            "publication_date": "",
            "keywords": [],
            "abstract": ""
        }
        failed = True

    # model_validate reports a non-object reply (e.g. a JSON list) as a
    # ValidationError; ExtractMetadata(**op) raised TypeError and aborted the run.
    try:
        metadata = ExtractMetadata.model_validate(op)
    except ValidationError as e:
        print("ERROR: ", file, e)

    # Unparsable replies are not written, so a re-run retries those files.
    if not failed:
        with open(text_file_path, 'w', encoding='utf-8') as f:
            dump(op, f, ensure_ascii=False, indent=4)
        print("Wrote file to: ", text_file_path)