In [4]:
!pip install bitsandbytes -q

In [4]:
import pandas as pd
import ast

In [7]:
# %%writefile report_generation.py

import pandas as pd
import ast

def generate_missing_affiliation_report(csv_path="test_filled_21.csv",
                                        report_path="missing_affiliations_reports.txt"):
    """Write a text summary of how many papers still lack author affiliations.

    Each row's ``affiliations`` cell is parsed as a Python literal (a list with
    one entry per author, ``None`` marking a missing affiliation).

    Args:
        csv_path: CSV file to analyse; it must contain an ``affiliations``
            column. Defaults to the file this function originally read.
        report_path: Destination of the plain-text report (overwritten).

    Classification of each paper:
        * fully filled     - the parsed list contains no ``None``;
        * partially filled - the parsed list contains at least one ``None``;
        * unparseable      - the cell is not a valid literal / not iterable.
    Partially filled and unparseable papers are listed as incomplete indices.

    Returns:
        None. The result is the file written at ``report_path``.
    """
    df = pd.read_csv(csv_path)

    fully_filled_count = 0
    partially_filled_count = 0
    unparseable_count = 0
    total_null_author_slots = 0
    incomplete_indices = []
    total_papers = len(df)

    for idx, raw_affs in df['affiliations'].items():
        try:
            affs = ast.literal_eval(str(raw_affs))
            null_count = sum(1 for a in affs if a is None)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Only parsing problems are treated as "incomplete"; anything else
            # (e.g. a missing column) is a real error and should surface.
            unparseable_count += 1
            incomplete_indices.append(idx)
            continue

        total_null_author_slots += null_count
        if null_count == 0:
            fully_filled_count += 1
        else:
            partially_filled_count += 1
            incomplete_indices.append(idx)

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("=== MISSING AFFILIATIONS REPORT (V12) ===\n\n")
        f.write(f"Total papers: {total_papers}\n")
        f.write(f"Fully filled papers: {fully_filled_count}\n")
        f.write(f"Partially filled papers: {partially_filled_count}\n")
        # Reported separately so the three counts add up to the total.
        f.write(f"Unparseable papers: {unparseable_count}\n")
        f.write(f"Total remaining null author affiliations: {total_null_author_slots}\n")
        f.write(f"\nRemaining incomplete indices ({len(incomplete_indices)}):\n")
        f.write(str(incomplete_indices) + "\n")

In [8]:
# Baseline report: summarises test_filled_21.csv into missing_affiliations_reports.txt.
generate_missing_affiliation_report()

In [6]:
%%writefile model_and_extraction.py

import os
import re
import json
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model identifier passed to the tokenizer/model loaders.
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# Input CSV read by main.py (columns used there: title, authors, affiliations).
CSV_PATH = "/kaggle/input/arxiver-data/test_filled_21.csv"
# Output CSV written by the main process after merging all partial results.
OUTPUT_CSV_PATH = "test_filled_22.csv"
# Text dumps that main.py splits into per-paper sections on the "PAPER:" marker.
LATEX_FILES = ["/kaggle/input/arxiver-data/latex_affiliations_output.txt", "/kaggle/input/arxiver-data/latex_affiliations_output_2.txt"]
# Number of CSV rows sent to the model per generate() call.
BATCH_SIZE = 8

def load_model_for_batching():
    """Load the tokenizer and the 4-bit quantized causal LM used for batched generation.

    The tokenizer is configured for left padding (with EOS as the pad token
    when none is defined) so that padded prompts in a batch all end where
    generation begins.

    Device placement: when launched by a distributed launcher that sets
    ``LOCAL_RANK`` (e.g. ``torchrun``), the whole model is placed on that
    rank's GPU so every worker owns one device. Without ``LOCAL_RANK`` the
    previous ``device_map="auto"`` behaviour is kept.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    print(f"Loading {MODEL_ID} with padding support...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # With device_map="auto" each worker process would shard its own copy of
    # the model across *all* visible GPUs, so the workers contend for the same
    # devices. Pin each worker to the GPU matching its local rank instead.
    local_rank = os.environ.get("LOCAL_RANK")
    if (
        local_rank is not None
        and torch.cuda.is_available()
        and int(local_rank) < torch.cuda.device_count()
    ):
        device_map = {"": int(local_rank)}
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map=device_map,
        attn_implementation="sdpa"
    )

    return tokenizer, model


def clean_latex_input(latex_code):
    r"""Keep only the leading part of a paper's LaTeX source.

    Returns an empty string when the input is missing/empty or shorter than
    10 characters. Otherwise the text is cut 500 characters after the start
    of the first ``\begin{abstract}`` or ``\section{Intro`` marker
    (case-insensitive), or after 4000 characters when no marker is found.
    """
    if latex_code and len(latex_code) >= 10:
        marker = re.search(r'\\begin\{abstract\}|\\section\{Intro', latex_code, re.IGNORECASE)
        if marker is None:
            return latex_code[:4000]
        return latex_code[:marker.start() + 500]
    return ""


def _render_extraction_prompt(tokenizer, context, authors):
    """Build the chat-formatted extraction prompt for one paper."""
    prompt_text = f"""Context:
{context}

Task: Extract affiliations for: {authors}.
Rules:
1. JSON format only. keys=authors, values=affiliations.
2. Join multiple affiliations ONLY with a semicolon (;).
3. If missing, use null.
4. Output JSON only.

JSON:"""

    messages = [
        {"role": "system", "content": "Extract metadata to JSON. Use ';' separator."},
        {"role": "user", "content": prompt_text}
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def _fit_prompt_to_budget(tokenizer, context, authors, max_prompt_tokens):
    """Render the prompt, shortening the LaTeX context if the prompt is too long.

    The instructions and the generation header sit at the END of the prompt,
    so letting the tokenizer truncate on the right would cut them off. The
    context is trimmed instead.
    """
    prompt = _render_extraction_prompt(tokenizer, context, authors)
    overflow = len(tokenizer(prompt)["input_ids"]) - max_prompt_tokens
    if overflow <= 0:
        return prompt

    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    # Small safety margin: decoding and re-tokenizing can shift token counts slightly.
    keep = len(context_ids) - overflow - 8
    if keep <= 0:
        # The fixed part alone exceeds the budget; fall back to tokenizer truncation.
        return prompt
    return _render_extraction_prompt(tokenizer, tokenizer.decode(context_ids[:keep]), authors)


def _parse_json_object(raw_text):
    """Parse the model reply as JSON, tolerating code fences and surrounding prose."""
    cleaned = raw_text.replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(cleaned)
    except ValueError:
        embedded = re.search(r'\{.*\}', cleaned, re.DOTALL)
        if embedded is None:
            raise
        return json.loads(embedded.group(0))


def _normalize_affiliation(value):
    """Turn one JSON value into a de-duplicated '; '-joined string, or None if empty."""
    if not value:
        return None
    if isinstance(value, (list, tuple)):
        # The model sometimes returns a list despite the instructions.
        value = "; ".join(str(v) for v in value if v)
    # NOTE(review): splitting on ' and ' also splits names such as
    # "Department of Physics and Astronomy" — kept as in the original; confirm intent.
    text = str(value).replace('\n', '; ').replace('\\\\', '; ').replace(' and ', '; ')
    # dict.fromkeys keeps first-seen order while removing duplicates.
    unique_parts = dict.fromkeys(p.strip() for p in text.split(';') if p.strip())
    return "; ".join(unique_parts) or None


def batch_extract_affiliations(tokenizer, model, batch_data, max_prompt_tokens=1600, max_new_tokens=350):
    """Ask the model for the affiliations of each paper in a batch.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``,
            ``decode`` and ``batch_decode``.
        model: Causal LM providing ``generate`` and ``device``.
        batch_data: List of dicts with at least ``'authors'`` (list of author
            names) and ``'latex'`` (LaTeX text or ``None``).
        max_prompt_tokens: Token budget for each prompt. Over-long prompts get
            their LaTeX context shortened so the instructions survive.
        max_new_tokens: Generation length limit.

    Returns:
        list[list[str | None]]: One list per input item, aligned with its
        ``'authors'``. An entry is ``None`` when no affiliation was obtained
        (no usable LaTeX, unparseable reply, author absent from the reply, or
        an empty value).
    """
    prompts = []
    for item in batch_data:
        authors, latex = item['authors'], item['latex']
        # Nothing to ask when there are no authors or no LaTeX for the paper.
        short_latex = clean_latex_input(latex) if (authors and latex) else ""
        if not short_latex:
            prompts.append(None)
            continue
        prompts.append(_fit_prompt_to_budget(tokenizer, short_latex, authors, max_prompt_tokens))

    # Default every paper to "all missing"; successful extractions overwrite below.
    results = [[None] * len(item['authors']) for item in batch_data]

    valid_indices = [i for i, p in enumerate(prompts) if p is not None]
    if not valid_indices:
        return results

    inputs = tokenizer(
        [prompts[i] for i in valid_indices],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_tokens,
    ).to(model.device)

    with torch.no_grad():
        # Greedy decoding; no temperature is passed because sampling is disabled.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )

    # Left padding makes all prompts the same length, so one offset strips them all.
    input_len = inputs.input_ids.shape[1]
    responses = tokenizer.batch_decode(generated_ids[:, input_len:], skip_special_tokens=True)

    for i, raw_text in zip(valid_indices, responses):
        authors = batch_data[i]['authors']
        try:
            data = _parse_json_object(raw_text)
            results[i] = [_normalize_affiliation(data.get(auth)) for auth in authors]
        except (ValueError, TypeError, AttributeError):
            # Invalid JSON, or a reply that is not a JSON object: keep all-None.
            results[i] = [None] * len(authors)

    return results

Overwriting model_and_extraction.py


In [7]:
%%writefile main.py

import os
import re
import ast
import glob
import torch
import pandas as pd

from tqdm.auto import tqdm
from accelerate import PartialState

from model_and_extraction import (
    load_model_for_batching,
    batch_extract_affiliations,
    CSV_PATH,
    OUTPUT_CSV_PATH,
    LATEX_FILES,
    BATCH_SIZE,
)

# Process-level distributed state, created once at import time; main() reads the
# process index from it and uses it to split the batches between processes.
distributed_state = PartialState()

def _parse_literal_list(cell):
    """Parse a CSV cell holding a Python list literal; return None if it is not one."""
    try:
        value = ast.literal_eval(str(cell))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    return list(value) if isinstance(value, (list, tuple)) else None


def _load_paper_latex_map(latex_files):
    """Map paper title -> its 'PAPER: ...' section text from the given dump files."""
    paper_latex_map = {}
    for f_path in latex_files:
        if not os.path.exists(f_path):
            continue
        with open(f_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        # The lookahead keeps the "PAPER:" marker at the start of each section.
        for sec in re.split(r'(?=PAPER:)', content):
            m = re.search(r'PAPER:\s*(.+?)(?:\n|$)', sec)
            if m:
                paper_latex_map[m.group(1).strip()] = sec
    return paper_latex_map


def _needs_processing(aff_cell):
    """True when the affiliations cell is unparseable or has a None entry."""
    affs = _parse_literal_list(aff_cell)
    return affs is None or any(a is None for a in affs)


def _merge_affiliations(existing, extracted):
    """Combine stored and newly extracted affiliations without losing known data.

    A newly extracted value wins; where the model returned None the stored
    value is kept. If the two lists cannot be aligned slot by slot, the stored
    list is kept unless the new one carries some information.
    """
    if existing is None:
        return extracted
    if len(existing) != len(extracted):
        return extracted if any(a is not None for a in extracted) else existing
    return [new if new is not None else old for old, new in zip(existing, extracted)]


def main():
    """Fill missing affiliations in CSV_PATH with the LLM and write OUTPUT_CSV_PATH.

    Every process loads the model and the CSV, takes its share of the batches
    of incomplete rows, and writes its results to
    ``partial_affiliations_gpu_<rank>.csv``. After all processes have finished,
    the main process merges those files into the DataFrame and saves it.
    """
    rank = distributed_state.process_index
    tokenizer, model = load_model_for_batching()
    df = pd.read_csv(CSV_PATH)
    paper_latex_map = _load_paper_latex_map(LATEX_FILES)

    missing_indices = df[df['affiliations'].apply(_needs_processing)].index.tolist()
    batches = [
        missing_indices[i:i + BATCH_SIZE]
        for i in range(0, len(missing_indices), BATCH_SIZE)
    ]

    local_results = []
    with distributed_state.split_between_processes(batches) as split_batches:
        for idx_list in tqdm(split_batches, disable=not distributed_state.is_local_main_process):
            batch_data = []
            for idx in idx_list:
                row = df.loc[idx]
                batch_data.append({
                    "idx": idx,
                    # An unparseable authors cell means there is nobody to look up.
                    "authors": _parse_literal_list(row['authors']) or [],
                    "latex": paper_latex_map.get(row['title'])
                })

            batch_results = batch_extract_affiliations(tokenizer, model, batch_data)

            for item, extracted in zip(batch_data, batch_results):
                existing = _parse_literal_list(df.at[item["idx"], "affiliations"])
                local_results.append((item["idx"], str(_merge_affiliations(existing, extracted))))

    out_path = f"partial_affiliations_gpu_{rank}.csv"
    pd.DataFrame(local_results, columns=["index", "affiliations"]).to_csv(out_path, index=False)

    distributed_state.print(f"GPU {rank} finished.")

    # Barrier: the merge below must not start before every rank wrote its file.
    distributed_state.wait_for_everyone()

    if distributed_state.is_main_process:
        # Read exactly the files of this run; a glob would also pick up stale
        # partial files left behind by an earlier run with more processes.
        files = [
            f"partial_affiliations_gpu_{r}.csv"
            for r in range(distributed_state.num_processes)
        ]
        frames = [pd.read_csv(f) for f in files]
        frames = [frame for frame in frames if not frame.empty]

        if frames:
            merged = pd.concat(frames).set_index("index")
            for idx, affiliations in merged["affiliations"].items():
                df.at[idx, "affiliations"] = affiliations

        df.to_csv(OUTPUT_CSV_PATH, index=False)
        print(f"Final CSV saved → {OUTPUT_CSV_PATH}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Writing main.py


In [None]:
# Launch main.py under torchrun with two worker processes on this node.
!torchrun --nproc_per_node=2 main.py

W1228 00:21:00.677000 113 torch/distributed/run.py:774] 
W1228 00:21:00.677000 113 torch/distributed/run.py:774] *****************************************
W1228 00:21:00.677000 113 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1228 00:21:00.677000 113 torch/distributed/run.py:774] *****************************************
Loading Qwen/Qwen2.5-7B-Instruct with padding support...
Loading Qwen/Qwen2.5-7B-Instruct with padding support...
tokenizer_config.json: 7.30kB [00:00, 32.2MB/s]
vocab.json: 2.78MB [00:00, 45.8MB/s]
merges.txt: 1.67MB [00:00, 146MB/s]
tokenizer.json: 7.03MB [00:00, 189MB/s]
config.json: 100%|█████████████████████████████| 663/663 [00:00<00:00, 5.66MB/s]
2025-12-28 00:21:19.807418: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: 

In [None]:
# Regenerate the report and print it.
# NOTE(review): no visible cell writes report_generation.py (the %%writefile line
# above the function definition is commented out) — confirm the module exists on disk.
# NOTE(review): the in-notebook definition reads test_filled_21.csv, while main.py
# writes test_filled_22.csv — confirm which file this report is meant to cover.
from report_generation import generate_missing_affiliation_report
generate_missing_affiliation_report()

!cat /kaggle/working/missing_affiliations_reports.txt