In [None]:
# Root working directory: every input/output path in this notebook is built from it
# with os.path.join(HOME, ...). Edit this placeholder before running.
HOME = '/your/HOME/directory'

# 1 Prepare Data

Pre-processing steps for preparing the dataset of 200 papers.

## 1.1 Prepare 200 Papers

Steps to prepare the 200-paper dataset:

1. Download the arXiv dataset JSON file from <https://www.kaggle.com/datasets/Cornell-University/arxiv?resource=download> and move it to the `HOME` directory (the code below expects the file name `arxiv-metadata-oai-snapshot.json`)
2. Run the code below to generate `ARXIV_JSON`, which contains 200 papers randomly selected from the materials science (`mtrl-sci`) category; only papers that have a DOI are eligible for sampling
3. Download the 200 papers' PDF files
4. Manually extract bandgap data from the 200 papers and save results to `MANUAL_XLSX` (see example in `manual_pub.xlsx`)

In [None]:
import os
import json
from datetime import datetime
from collections import defaultdict
import pandas as pd
from glob import glob

# Input: arXiv metadata snapshot from Kaggle (read line by line by the filtering cell)
KAGGLE_JSON = os.path.join(HOME, "arxiv-metadata-oai-snapshot.json")
# Output directory for the per-month JSON files of filtered papers (created if missing)
ARXIV_DIR = os.path.join(HOME, "arXiv_mtrl-sci")
os.makedirs(ARXIV_DIR, exist_ok=True)
# Output: CSV with year / id / doi of every filtered paper
ARXIV_CSV = os.path.join(HOME, "arXiv_mtrl-sci.csv")
# Output: JSON with the 200 sampled papers
ARXIV_JSON = os.path.join(HOME, "arXiv_mtrl-sci_200.json")

# Define date range for filtering arXiv papers
# NOTE(review): both bounds are midnight datetimes and are compared inclusively
# against each paper's full v1 timestamp, so papers created later during the
# END_DATE calendar day fall outside the range — confirm this is intended.
START_DATE = datetime(2000, 1, 1)
END_DATE = datetime(2024, 10, 31)

In [None]:
# Filter arXiv papers in the materials science category
grouped_data = defaultdict(list)

# Read JSON file line by line (the snapshot holds one JSON record per line)
with open(KAGGLE_JSON, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        # Skip blank lines instead of letting json.loads('') raise
        if not line:
            continue
        entry = json.loads(line)

        # Filter: categories must contain 'mtrl-sci'
        if 'mtrl-sci' not in entry.get('categories', ''):
            continue

        # Creation date of the first version, parsed with the format string below
        first_version = entry['versions'][0]['created']
        date_v1 = datetime.strptime(first_version, '%a, %d %b %Y %H:%M:%S %Z')

        # Keep only papers whose first version lies within the configured range
        if not (START_DATE <= date_v1 <= END_DATE):
            continue

        # Group by year-month; store only the fields needed downstream
        grouped_data[date_v1.strftime('%Y-%m')].append({
            'id': entry['id'],
            'doi': entry.get('doi', None),
            'categories': entry['categories'],
            'date-v1': date_v1.strftime('%Y-%m-%d'),
        })

# Remove per-month files left over from a previous run. The entry count is part of
# the file name, so a re-run with different counts would otherwise leave stale files
# next to the new ones, and the sampling cell (which reads every *.json file in
# ARXIV_DIR) would load those papers twice. Done only after the snapshot was read
# successfully, so a failed run does not wipe earlier results.
for stale_name in os.listdir(ARXIV_DIR):
    if stale_name.endswith(').json') and '(' in stale_name:
        os.remove(os.path.join(ARXIV_DIR, stale_name))

# Write grouped data to JSON files
for year_month, items in grouped_data.items():
    # Include entry count in filename
    count = len(items)
    output_file = os.path.join(ARXIV_DIR, f"{year_month}({count}).json")
    with open(output_file, 'w', encoding='utf-8') as out_f:
        json.dump(items, out_f, indent=4, ensure_ascii=False)

In [None]:
# Randomly select 200 papers from filtered results
import matplotlib.pyplot as plt
# Bind the glob *function* to its own name: a later cell runs `import glob`, which
# rebinds the bare name `glob` to the module and would break this cell on re-run.
from glob import glob as glob_paths

# Read JSON files from ARXIV_DIR (sorted, so the row order and therefore the
# seeded sample below are reproducible)
json_files = sorted(glob_paths(os.path.join(ARXIV_DIR, '*.json')))

data = []
for json_file in json_files:
    # The per-month files were written as UTF-8, so read them the same way
    with open(json_file, 'r', encoding='utf-8') as f:
        content = json.load(f)
    for paper in content:
        # Extract year from date
        year = paper['date-v1'][:4] if paper['date-v1'] else None
        data.append({
            'year': year,
            'id': paper['id'],
            'doi': paper['doi']
        })

# Create DataFrame
df = pd.DataFrame(data)

# Count unique IDs and DOIs
unique_ids = df['id'].nunique()
unique_dois = df['doi'].dropna().nunique()  # Exclude None values

print(f"Unique article IDs: {unique_ids}")
print(f"Unique DOIs: {unique_dois}")

# Save to CSV
df.to_csv(ARXIV_CSV, index=False)
print(f"\nData saved to: {ARXIV_CSV}")

# Randomly select 200 papers (only those with both ID and DOI)
df_complete = df.dropna(subset=['id', 'doi'])
if len(df_complete) < 200:
    # DataFrame.sample would raise ValueError as well; this message names the cause
    raise ValueError(
        f"Only {len(df_complete)} papers have both an ID and a DOI; 200 are required for sampling"
    )
# Fixed random_state keeps the selection reproducible
df_sample = df_complete.sample(n=200, random_state=42)
df_sample.to_json(ARXIV_JSON, orient='records', indent=2)
print(f"200 paper sample saved to: {ARXIV_JSON}")

# Analyze year distribution
year_distribution = df_sample['year'].value_counts().sort_index()
print("\nYear distribution:")
print(year_distribution)

# Visualize year distribution
plt.figure(figsize=(12, 6))
year_distribution.plot(kind='bar')
plt.title('Sample Articles Year Distribution')
plt.xlabel('Year')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 1.2 Parse PDF to TXT

Steps to convert PDF files to sentence-segmented text files:

1. Download the 200 papers' PDF files to `PDF_DIR`
2. Run the code below to generate `TXT_DIR_2` (text files with papers segmented into sentences)

In [None]:
from langchain_community.document_loaders import PyMuPDFLoader
import os
import glob
import spacy

# Input directory with the downloaded PDFs, and the two text output directories
PDF_DIR = os.path.join(HOME, "PDF")
os.makedirs(PDF_DIR, exist_ok=True)
TXT_DIR_1 = os.path.join(HOME, "TXT(fromPDF)")
os.makedirs(TXT_DIR_1, exist_ok=True)
TXT_DIR_2 = os.path.join(HOME, "TXT(fromPDF_processed)")
os.makedirs(TXT_DIR_2, exist_ok=True)

# ========== Step 1: Parse PDF to raw text ==========
# Extract text from each PDF and save to TXT_DIR_1
pdf_files = sorted(glob.glob(os.path.join(PDF_DIR, "*.pdf")))
for pdf_file in pdf_files:
    loader = PyMuPDFLoader(pdf_file)
    docs = loader.load()
    # Pages are concatenated without a separator
    text = "".join(doc.page_content for doc in docs)
    # splitext swaps only the real extension (str.replace would touch every '.pdf'
    # occurrence in the name and would miss an upper-case '.PDF')
    stem = os.path.splitext(os.path.basename(pdf_file))[0]
    output_path = os.path.join(TXT_DIR_1, stem + ".txt")
    # Write UTF-8 explicitly: process_txt reads these files back as UTF-8
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(text)

# ========== Step 2: Process text into properly segmented sentences ==========
# spaCy pipeline used by process_txt for the initial sentence segmentation
nlp = spacy.load("en_core_web_sm")
# Upper bound on the text length nlp() accepts; process_txt passes a whole paper
# as one string
nlp.max_length = 2000000

# Constants for sentence processing
# Endings that finish with a period but do not close a sentence; should_merge
# treats a sentence ending in one of these as needing to be joined with the next
EXCLUDED_ENDINGS = ('Fig.', 'Eq.', 'Figs.', 'et al.')
# Punctuation accepted as a proper sentence end by should_merge
SENTENCE_ENDINGS = ('.', '!', '?')
# Sentences shorter than this many characters are appended to the previous
# sentence by merge_short_sentences
MIN_SENTENCE_LENGTH = 20

def should_merge(sentence: str) -> bool:
    """Tell whether ``sentence`` is incomplete and must absorb the following one.

    Args:
        sentence: Text of the sentence being assembled.

    Returns:
        False for an empty string. Otherwise True when the text ends with one of
        ``EXCLUDED_ENDINGS`` (e.g. ``Fig.``) or does not end with any of
        ``SENTENCE_ENDINGS``; False in every other case.
    """
    # Nothing to extend when there is no text at all
    if not sentence:
        return False
    # Abbreviation-like endings finish with a period yet do not close the sentence
    if sentence.endswith(EXCLUDED_ENDINGS):
        return True
    # Otherwise merge exactly when terminal punctuation is missing
    return not sentence.endswith(SENTENCE_ENDINGS)

def merge_consecutive(sentences: list[str]) -> list[str]:
    """Join each incomplete sentence with the sentences that follow it.

    A sentence keeps absorbing its successors (separated by a single space) for
    as long as ``should_merge`` reports the accumulated text as incomplete.

    Args:
        sentences: Sentence strings in document order.

    Returns:
        New list of stripped, merged sentences.
    """
    merged = []
    pending = None  # sentence currently being assembled; None before the first one

    for raw in sentences:
        piece = raw.strip()
        if pending is None:
            pending = piece
        elif should_merge(pending):
            # Accumulated text is still incomplete: glue the next piece on
            pending += " " + piece
        else:
            # Accumulated text is complete: emit it and start a new one
            merged.append(pending)
            pending = piece

    # Flush the last sentence (absent only when the input was empty)
    if pending is not None:
        merged.append(pending)

    return merged

def merge_short_sentences(sentences: list[str]) -> list[str]:
    """Attach sentences shorter than ``MIN_SENTENCE_LENGTH`` to their predecessor.

    Args:
        sentences: Sentence strings in document order.

    Returns:
        New list in which every short sentence has been appended (after a single
        space) to the sentence before it. Empty strings are dropped; a short
        sentence with no predecessor is kept as its own entry.
    """
    result = []
    for text in sentences:
        if not text:
            continue

        is_short = len(text) < MIN_SENTENCE_LENGTH
        if is_short and result:
            # Fold the fragment into the previous entry
            result[-1] = result[-1] + " " + text
        else:
            result.append(text)

    return result

def process_txt(input_path: str, output_path: str) -> None:
    """Re-segment one raw text file into one sentence per line.

    Args:
        input_path: UTF-8 text file to read.
        output_path: Destination file; receives the sentences joined by newlines.
    """
    # Flatten the text: line breaks become spaces before segmentation
    with open(input_path, "r", encoding="utf-8") as src:
        flattened = src.read().replace("\n", " ")

    # Initial segmentation with spaCy, discarding whitespace-only spans
    candidates = []
    for span in nlp(flattened).sents:
        stripped = span.text.strip()
        if stripped:
            candidates.append(stripped)

    # Merge pipeline: join incomplete sentences, fold in short fragments, then
    # join again to catch sentences that became incomplete in the second step
    sentences = merge_consecutive(candidates)
    sentences = merge_short_sentences(sentences)
    sentences = merge_consecutive(sentences)

    with open(output_path, "w", encoding="utf-8") as dst:
        dst.write("\n".join(sentences))

# Process all txt files from TXT_DIR_1 and save to TXT_DIR_2
raw_txt_pattern = os.path.join(TXT_DIR_1, "*.txt")
for txt_file in glob.glob(raw_txt_pattern):
    # Keep the file name; only the directory changes
    processed_path = os.path.join(TXT_DIR_2, os.path.basename(txt_file))
    process_txt(txt_file, processed_path)

# 2 Extract Bandgap Data

This section applies multiple extraction methods to extract bandgap data from the 200 papers.

## Prerequisites

Before running extraction methods, download the following projects and models:

### Projects
1. BandgapDatabase: <https://github.com/QingyangDong-qd220/BandgapDatabase1>
2. BERT-PSIE-TC: <https://github.com/StefanoSanvitoGroup/BERT-PSIE-TC>

### Language Models & Embeddings
1. MatSciBERT: <https://huggingface.co/m3rg-iitd/matscibert> (version: 24a4e4318dda9bc18bff5e6a45debdcb3e1780e3)
2. nomic-embed-text: <https://ollama.com/library/nomic-embed-text:latest> (version: 0a109f422b47)
   - Install via: `ollama pull nomic-embed-text:latest`
3. bge-m3: <https://ollama.com/library/bge-m3:latest> (version: 790764642607)
   - Install via: `ollama pull bge-m3:latest`
4. Llama2 13B: <https://ollama.com/library/llama2:13b> (version: d475bf4c50bc)
   - Install via: `ollama pull llama2:13b`
5. Llama 3.1 Nemotron 70B: <https://huggingface.co/bartowski/Llama-3.1-Nemotron-70B-Instruct-HF-GGUF> (version: dfc9cf0b496aea479874ddce703154f07d22ec3d)
   - Install via: `ollama create llama3.1:70b -f Modelfile`
6. Qwen2.5 14B: <https://ollama.com/library/qwen2.5:14b> (version: 7cdf5a0187d5)
   - Install via: `ollama pull qwen2.5:14b`

## 2.1 ChemDataExtractor (CDE)

**Running Environment:** Docker container

**Docker Setup Note:**

```bash
docker run --name cde \
  --mount type=bind,source='/your/HOME/directory',target='/home/chemdataextractor2' \
  -it -p 8888:8888 \
  --entrypoint bash obrink/chemdataextractor:2.1.2
```

If `chemdataextractor2` cannot be imported, rename `/usr/local/lib/python3.8/site-packages/chemdataextractor` to `chemdataextractor2`.

In [None]:
!pip install playsound openpyxl

In [None]:
import os
import joblib
from pprint import pprint
from tqdm import tqdm
from datetime import datetime
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from chemdataextractor2.relex import Snowball
from chemdataextractor2.model.units.energy import EnergyModel
from chemdataextractor2.model import BaseModel, StringType, ListType, ModelType, Compound
from chemdataextractor2.parse import R, I, W, Optional, merge, join, AutoSentenceParser
from chemdataextractor2.doc import Sentence, Document

class BandGap(EnergyModel):
    """Custom BandGap model for ChemDataExtractor.

    Extends ``EnergyModel`` with a band-gap specifier and a required compound
    field. Note that ``run`` reassigns ``BandGap.parsers`` while processing each
    sentence to switch between ``AutoSentenceParser`` and the Snowball model, so
    the class-level ``parsers`` value below is only the initial setting.
    """
    # Phrases that identify the property: "band" followed by a token matching the
    # regex "gaps?", or one of "bandgap", "band-gap", "Eg"; matched tokens are
    # combined via `join`.
    # NOTE(review): I(...) is presumably a case-insensitive match — confirm in the CDE docs.
    specifier_expression = (
        (I("band") + R("gaps?")) | I("bandgap") | I("band-gap") | I("Eg")
    ).add_action(join)
    # The specifier must be present for a record to be created
    specifier = StringType(
        parse_expression=specifier_expression, required=True, updatable=True
    )
    # The material the band gap belongs to; required.
    # NOTE(review): contextual/binding presumably control how the compound is
    # merged from surrounding context — confirm in the CDE docs.
    compound = ModelType(
        Compound, required=True, contextual=True, binding=True, updatable=False
    )
    # Initial parser list (overwritten at runtime by `run`)
    parsers = [AutoSentenceParser()]

def _snowball_records(sentence, doi, similarity):
    """Run the global Snowball model on one sentence at a given cluster similarity.

    Sets ``snowball.minimum_cluster_similarity_score`` to ``similarity``, makes
    Snowball the only parser of ``BandGap`` and returns the serialized records
    that contain a ``"BandGap"`` entry, each tagged with the sentence text and DOI.
    """
    snowball.minimum_cluster_similarity_score = similarity
    BandGap.parsers = [snowball]
    sentence.models = [BandGap]

    found = []
    for rec in sentence.records.serialize():
        if "BandGap" in rec:
            rec["BandGap"]["text"] = sentence.text
            rec["BandGap"]["doi"] = doi
            found.append(rec)
    return found


def run(file_path, article):
    """Extract bandgap data from a single article using Snowball and AutoSentenceParser.

    This function applies two extraction methods:
    1. AutoSentenceParser: Rule-based extraction requiring both value and material name
    2. Snowball: Relation extraction model with two similarity thresholds (0.85, then 0.65)

    Relies on the module-level ``snowball`` model, which must be loaded before
    this function is called.

    Args:
        file_path: Path to the text file
        article: Article identifier (filename); the DOI stored in each record is
            derived from it by replacing '_' with '/' and dropping the
            '.html' / '.xml' / '.txt' suffix

    Returns:
        List of extracted bandgap records with metadata. The list is empty when
        the document cannot be read or nothing is found.
    """
    results = []

    # Load document
    try:
        d = Document.from_file(file_path)
    except Exception as exc:
        # Report why the file was skipped; a bare `except:` would also have
        # swallowed KeyboardInterrupt / SystemExit
        print(f"Unable to read document {file_path}: {exc}")
        return results

    publisher = "arXiv"
    # Identical for every record of this article, so compute it once
    doi = article.replace('_', '/').replace('.html', '').replace('.xml', '').replace('.txt', '')

    # Process each paragraph and sentence
    for p in d.paragraphs:
        for s in p.sentences:
            # Skip overly long sentences (likely parsing errors)
            if s.end - s.start > 300:
                continue

            results_auto = []

            # Method 1: AutoSentenceParser extraction
            BandGap.parsers = [AutoSentenceParser()]
            s.models = [BandGap]
            auto = s.records.serialize()

            for i in auto:
                if "BandGap" not in i:
                    continue
                bandgap = i["BandGap"]
                # Require both bandgap value and material name
                if ("raw_value" in bandgap
                        and "compound" in bandgap
                        and "names" in bandgap["compound"]["Compound"]):
                    bandgap["text"] = s.text
                    bandgap["doi"] = doi
                    results_auto.append(i)

            # Method 2: Snowball extraction — try similarity 0.85 first and fall
            # back to 0.65 only when the stricter pass produced no record
            results_snow = _snowball_records(s, doi, 0.85) or _snowball_records(s, doi, 0.65)

            # Merge Snowball results into AutoSentenceParser results
            for i in results_auto:
                i["BandGap"]["AutoSentenceParser"] = 1
                i["BandGap"]["Snowball"] = 0
                for snow_rec in results_snow:
                    # Match by material name
                    if i['BandGap']['compound']['Compound']['names'] == snow_rec['BandGap']['compound']['Compound']['names']:
                        # The Snowball record replaces the rule-based one
                        i["BandGap"] = snow_rec["BandGap"]
                        i["BandGap"]["Snowball"] = 1
                        i["BandGap"]["AutoSentenceParser"] = 1
                        snow_rec["BandGap"]["match"] = 1  # Mark as matched
                        # NOTE(review): the original ended this branch with a
                        # `continue`, which is a no-op here, so every Snowball
                        # record with the same names is consumed. If only the
                        # first match was meant to be used, this should be
                        # `break` — confirm.

            # Add unmatched Snowball-only results
            for x in results_snow:
                if "match" not in x["BandGap"]:
                    x["BandGap"]["Snowball"] = 1
                    x["BandGap"]["AutoSentenceParser"] = 0
                    results_auto.append(x)

            # Add publisher metadata
            for i in results_auto:
                i["BandGap"]["publisher"] = publisher
                results.append(i)

    return results

# Configuration
dtime = datetime.now().strftime("%m%d-%H%M")
HOME_DOCKER = "/home/chemdataextractor2"
PROJECT_DIR = os.path.join(HOME_DOCKER, "project")
TXT_DIR = os.path.join(HOME_DOCKER, "TXT(fromPDF_processed)")
OUTPUT_DIR = os.path.join(HOME_DOCKER, "output", "1-ChemDataExtractor")
os.makedirs(OUTPUT_DIR, exist_ok=True)
TEMP_SAVE = os.path.join(OUTPUT_DIR, "records_general.joblib")
RUNTIME = os.path.join(OUTPUT_DIR, f'runtime_{dtime}.txt')

SNOWBALL_PATH = os.path.join(PROJECT_DIR, "BandgapDatabase1-main", "Snowball_model", "general.pkl")

# Load Snowball model
# NOTE(review): the model is a .pkl file, presumably unpickled on load — only use
# a model file obtained from a trusted source.
snowball = Snowball.load(SNOWBALL_PATH)
snowball.minimum_relation_confidence = 0.001
snowball.max_candidate_combinations = 100
snowball.save_file_name = "general"
snowball.set_learning_rate(0.0)

# Load existing records if available
try:
    # NOTE: joblib.load unpickles the file; only resume from a file written by this notebook
    records = joblib.load(TEMP_SAVE)
    print("Loaded existing records")
except FileNotFoundError:
    records = []
    print("No existing records found")
except Exception as exc:
    # Stay best-effort for an unreadable checkpoint, but say why it was ignored
    records = []
    print(f"No existing records found (could not load {TEMP_SAVE}: {exc})")

# DOIs that already have records in the checkpoint. The checkpoint is written only
# after a file has been processed completely, so these files can be skipped; without
# this, a resumed run would append the same records a second time.
processed_dois = {
    saved["BandGap"]["doi"]
    for saved in records
    if "doi" in saved.get("BandGap", {})
}

start_time = datetime.now()

# Process all text files (sorted for a reproducible processing order)
for file_name in tqdm(sorted(os.listdir(TXT_DIR)), desc="Processing files"):
    if not file_name.endswith('.txt'):
        continue
    # Same file-name -> DOI mapping that run() stores in each record
    file_doi = file_name.replace('_', '/').replace('.html', '').replace('.xml', '').replace('.txt', '')
    if file_doi in processed_dois:
        continue

    file_path = os.path.join(TXT_DIR, file_name)
    temp = run(file_path, file_name)
    # Save records incrementally
    if temp:
        pprint(temp)
        records.extend(temp)
        joblib.dump(records, TEMP_SAVE)

end_time = datetime.now()
run_time = end_time - start_time

# NOTE: on a resumed run this covers only the files processed in this session
with open(RUNTIME, 'w') as f:
    f.write(f'Total runtime: {run_time}')

## 2.2 BERT-PSIE (PSIE)

**Kernel:** lc

In [None]:
# Local path of the MatSciBERT snapshot; the classification stage passes it to
# BertTokenizerFast.from_pretrained. Edit this placeholder before running.
BERT_VERSION = "/your/directory/to/model_cache/huggingface/models--m3rg-iitd--matscibert/snapshots/24a4e4318dda9bc18bff5e6a45debdcb3e1780e3"

In [None]:
import os
from datetime import datetime
import time
dtime = datetime.now().strftime("%m%d-%H%M")
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split
from transformers import BertTokenizerFast
from seqeval.metrics import classification_report
import nltk
import re
from pymatgen.core import Composition
from datasets import load_dataset

# Input: sentence-segmented text files produced in Section 1.2
DATA_DIR = os.path.join(HOME, "TXT(fromPDF_processed)")

# Model paths
PROJECT_DIR = os.path.join(HOME, "project")
PSIE_DIR = os.path.join(PROJECT_DIR, "BERT-PSIE-TC-main", "workflow")
# The psie package lives inside the downloaded project, so its directory has to be
# on sys.path before the import below
import sys
sys.path.insert(1, PSIE_DIR)
import psie
MODEL_DIR = os.path.join(PSIE_DIR, "models", "Gap")
CLASSIFIER_PATH = os.path.join(MODEL_DIR, "classifier.pt")
NER_PATH = os.path.join(MODEL_DIR, "ner")
RELATION_PATH = os.path.join(MODEL_DIR, "relation")
MAX_LEN = 256  # Maximum sequence length

# Output paths
OUTPUT_DIR = os.path.join(HOME, "output", "2-BERT-PSIE")
os.makedirs(OUTPUT_DIR, exist_ok=True)
SENTENCES_JSON = os.path.join(OUTPUT_DIR, "sentences.json")
OUTPUT_JSON_1 = os.path.join(OUTPUT_DIR, "1-relevant_sentences.json")
OUTPUT_JSON_2_M = os.path.join(OUTPUT_DIR, "2-test_extraction_multiple_mentions.json")
OUTPUT_CSV_2_S = os.path.join(OUTPUT_DIR, "2-test_extraction_single_mentions.csv")
OUTPUT_CSV_3 = os.path.join(OUTPUT_DIR, "3-relations_extraction.csv")
RUNTIME = os.path.join(OUTPUT_DIR, f'runtime_{dtime}.txt')

# Always a torch.device (the original fell back to the plain string "cpu", giving
# `device` a different type depending on the machine)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

# Start of the wall-clock measurement for the whole PSIE pipeline
start_time_total = time.time()

# ========== Prepare sentences from text files ==========
def process_txt_files(data_dir):
    """Process text files and generate a list of sentences with their DOI sources.

    Files are visited in sorted name order so that the returned list (and every
    output derived from it) has the same order on every run and platform;
    ``os.listdir`` alone returns names in arbitrary order.

    Args:
        data_dir: Directory containing text files (one sentence per line)

    Returns:
        List of dictionaries with 'sentence' and 'source' keys, where 'source'
        is the file name without its '.txt' extension (used as the DOI).
        Blank lines and files without a '.txt' extension are ignored.
    """
    sentences_list = []
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith('.txt'):
            continue
        doi = filename[:-4]  # Remove .txt extension
        with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()
                if sentence:
                    sentences_list.append({"sentence": sentence, "source": doi})
    return sentences_list

# Process text files and save to JSON
sentences = process_txt_files(DATA_DIR)
with open(SENTENCES_JSON, 'w', encoding='utf-8') as json_file:
    json.dump(sentences, json_file, ensure_ascii=False, indent=2)
print(f"Saved {len(sentences)} sentences to {SENTENCES_JSON}")

# ========== Stage 1/3: Sentence Classification ==========
start_time_classifier = time.time()

# Tokenize sentences
dataset = load_dataset(path="json", data_files=SENTENCES_JSON, split="train")
tokenizer = BertTokenizerFast.from_pretrained(BERT_VERSION)

def encode(paper):
    """Tokenize a batch of sentences, truncating/padding each one to MAX_LEN tokens."""
    return tokenizer(
        paper["sentence"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )

dataset = dataset.map(encode, batched=True)
dataset.set_format(type="torch", columns=["source", "sentence", "input_ids", "attention_mask"])
# shuffle=False keeps predictions aligned with dataset row indices
dataset_loader = DataLoader(dataset, batch_size=32, shuffle=False)

# Load and run classifier
model = psie.classifier.BertClassifier()
state_dict = torch.load(CLASSIFIER_PATH, map_location=device)
state_dict.pop("bert.embeddings.position_ids", None)  # Remove unnecessary key
model.load_state_dict(state_dict)
model.to(device)

pred = model.predict(dataset_loader, device)

# Extract predictions (class with highest probability)
predictions = [np.argmax(scores.cpu().numpy()) for scores in pred]

# Filter relevant sentences (prediction == 1)
filtered_sentences = {"sentence": [], "source": []}
for row_idx, label in enumerate(predictions):
    if label == 1:
        row = dataset[row_idx]
        filtered_sentences["sentence"].append(row["sentence"])
        filtered_sentences["source"].append(row["source"])

os.makedirs(os.path.dirname(OUTPUT_JSON_1), exist_ok=True)
with open(OUTPUT_JSON_1, "w") as f:
    json.dump(filtered_sentences, f)

classifier_time = time.time() - start_time_classifier

# ========== Stage 2/3: Named Entity Recognition (NER) ==========
start_time_ner = time.time()

# Entity label mapping
id_to_BOI = {
    1: "B-CHEM",     # Chemical entity
    0: "O",          # No entity
    2: "B-BANDGAP"   # Bandgap value
}

# Sentences the classifier marked as relevant (written by Stage 1)
with open(OUTPUT_JSON_1, "r") as f:
    data = json.load(f)

tokenizer = BertTokenizerFast.from_pretrained(NER_PATH)

sentences = psie.NerUnlabeledDataset(data["sentence"], tokenizer, max_len=MAX_LEN)
# sources[i] is the DOI of sentences[i]; shuffle=False below keeps them aligned
sources = data["source"]
sentences_params = {
    'batch_size': 10,
    'shuffle': False,
    'num_workers': 0
}
sentences_loader = DataLoader(sentences, **sentences_params)

model = psie.BertForNer.from_pretrained(NER_PATH, num_labels=3)
model.to(device)

# NER predictions
# NOTE(review): predictions[n] is indexed per token position up to MAX_LEN below and
# compared with label strings, so it is presumably a length-MAX_LEN sequence of
# labels from id_to_BOI — confirm against psie.BertForNer.predict.
predictions = model.predict(sentences_loader, device, id_to_BOI)

# Extract entity labels from predictions
# Result: one dict per sentence mapping an entity label to the list of mentions
# found, each mention being its tokens joined by spaces
extr_labels = []
for n in range(len(predictions)):
    # Re-tokenize the sentence so that token positions can be paired with the
    # predicted labels
    tokens = tokenizer.tokenize(
        "[CLS]" + psie.preprocess_text(sentences[n]["plain"]) + "[SEP]",
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
    )
    extracted = {}
    i = 0
    while i < MAX_LEN:
        if predictions[n][i] != "O" and tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]:
            entity = predictions[n][i]
            entry = []
            # Collect the run of consecutive tokens carrying the same label
            while predictions[n][i] == entity:
                entry.append(tokens[i])
                i += 1
                if i >= MAX_LEN:
                    break
            if entity in extracted.keys():
                extracted[entity].append(" ".join(entry))
            else:
                extracted[entity] = [" ".join(entry)]
        # NOTE(review): after a run has been collected, i already points at the
        # first token behind it; this increment then skips that token, so an
        # entity of a different label starting immediately after the run loses
        # its first token — confirm this is intended.
        i += 1
    extr_labels.append(extracted)

# Extract sentences with multiple mentions (for relation classification)
# A sentence qualifies when exactly two entity types were found and each of them
# has more than one mention.
relational = []
for i, labels in enumerate(extr_labels):
    n_entries = [len(mentions) for mentions in labels.values()]
    if len(n_entries) == 2 and n_entries[0] > 1 and n_entries[1] > 1:
        # Copy so that the per-sentence label dict itself stays untouched here
        record = labels.copy()
        record["sentence"] = sentences[i]["plain"]
        record["source"] = sources[i]
        relational.append(record)

print("Relational/Total: ", len(relational), "/", len(predictions))
with open(OUTPUT_JSON_2_M, "w") as f:
    json.dump(relational, f)

# Extract sentences with single mentions (exactly 1 material and 1 bandgap)
relevant = []
for i, labels in enumerate(extr_labels):
    n_entries = [len(mentions) for mentions in labels.values()]
    if n_entries == [1, 1]:
        # The label dict is annotated in place and shared with extr_labels
        labels["sentence"] = sentences[i]["plain"]
        labels["source"] = sources[i]
        relevant.append(labels)

print("Relevant/Total: ", len(relevant), "/", len(predictions))

# Clean and standardize single-mention extractions
database = {"compound": [], "Gap": [], "sentence": [], "source": []}

for n in range(len(relevant)):
    chem, trgt = None, None

    try:
        # Clean chemical entity name: drop the spaces and '#' left by word-piece
        # tokenization and backslash-escape characters that are special in a
        # regex. Raw strings avoid the invalid "\(" escape sequences of the
        # original (same string values, no SyntaxWarning on recent Python).
        chem = (
            relevant[n]["B-CHEM"][0]
            .strip()
            .replace(" ", "")
            .replace("#", "")
            .replace("(", r"\(")
            .replace(")", r"\)")
            .replace("+", r"\+")
            .replace("[UNK]", "")
            .replace(".", r"\.")
        )

        # Recover the original spelling (case, surrounding symbols) of the
        # compound from the sentence; raises IndexError when nothing matches
        chem = re.findall(
            "(?i)[^a-zA-Z0-9]*" + chem + "[^a-zA-Z]",
            relevant[n]["sentence"],
        )[0].strip()

        if chem.endswith(",") or chem.endswith("."):
            chem = chem[0 : len(chem) - 1]
        if chem.startswith(",") or chem.startswith("."):
            chem = chem[1 : len(chem)]

        # Convert element names to symbols
        if chem in psie.ELEMENT_NAMES:
            chem = psie.ELEMENTS[psie.ELEMENT_NAMES.index(chem)]

        # Clean bandgap value
        trgt = relevant[n][id_to_BOI[2]][0].replace("#", "").strip()
        trgt = (
            trgt.replace("[", "")
            .replace("]", "")
            .replace("{", "")
            .replace("}", "")
            .replace("=", "")
            .replace("[UNK]", "")
        )

        trgt = trgt.replace("ev", "eV")

        if trgt.endswith(",") or trgt.endswith("."):
            trgt = trgt[0 : len(trgt) - 1]
        if trgt.startswith(",") or trgt.startswith("."):
            trgt = trgt[1 : len(trgt)]

        # Append all four columns together so they cannot get out of step
        # (both values are always set when this point is reached)
        if (chem is not None) and (trgt is not None):
            database["compound"].append(chem)
            database["Gap"].append(trgt)
            database["sentence"].append(relevant[n]["sentence"])
            database["source"].append(relevant[n]["source"])

    except Exception:
        # Cleaning failed (e.g. the compound was not found in the sentence):
        # print the raw extraction for manual inspection and move on.
        # `except Exception` leaves KeyboardInterrupt / SystemExit alone.
        comp = (
            relevant[n]["B-CHEM"][0]
            .replace("#", "")
            .replace(" ", "")
            .replace("(", r"\(")
            .replace(")", r"\)")
            .replace("+", r"\+")
            .replace("[UNK]", "")
        )
        trgt = relevant[n][id_to_BOI[2]][0].replace("#", "").strip()
        print(comp, trgt, relevant[n]["sentence"], "\n\n")

print("Database entries:", len(database["compound"]), "/", len(relevant))

# Validate chemical formulas using pymatgen
database = pd.DataFrame(database)
valid_i = []

for i, mat in enumerate(database["compound"]):
    try:
        # Only used as a validity check: raises when the formula cannot be parsed
        Composition(mat).get_reduced_formula_and_factor()[0]
        valid_i.append(i)
    except Exception:
        print(mat, "\t", database["sentence"][i], "\n\n")

print("Database entries:", len(valid_i), "/", len(relevant))
database.iloc[valid_i].to_csv(OUTPUT_CSV_2_S)

ner_time = time.time() - start_time_ner

# ========== Stage 3/3: Relation Classification ==========
start_time_relation = time.time()

# Add relation extraction tokens to BERT vocabulary
tokenizer = BertTokenizerFast.from_pretrained(RELATION_PATH)
new_tokens = ["[E1]", "[/E1]", "[E2]", "[/E2]"]
tokenizer.add_tokens(list(new_tokens))

# Load multi-mention data
with open(OUTPUT_JSON_2_M, "r") as f:
    data = json.load(f)
# Preview the first record; guarded because no sentence may have qualified
if data:
    print(data[0])

# Add entity tags around materials and bandgap values
data = psie.fromNer(data)

# Preview up to two tagged sentences (the unguarded prints raised IndexError
# when fewer than two were available)
for preview_idx in range(min(2, len(data["sentence"]))):
    print(data["sentence"][preview_idx])

# Filter sentences containing both entity tags
ner_dataset = {"sentence": [], "isrelated": [], "source": []}
for i in range(len(data["sentence"])):
    if ("[E1]" in data["sentence"][i]) and ("[E2]" in data["sentence"][i]):
        ner_dataset["sentence"].append(str(data["sentence"][i]))
        # No gold label is available at inference time
        ner_dataset["isrelated"].append(None)
        ner_dataset["source"].append(data["source"][i])
print(len(ner_dataset["sentence"]), "/", len(data["sentence"]))

ner = psie.RelationDataset(ner_dataset, tokenizer, max_len=MAX_LEN)
# shuffle=False keeps predictions aligned with ner_dataset row indices
ner_params = {"batch_size": 8, "shuffle": False, "num_workers": 0}
ner_loader = DataLoader(ner, **ner_params)

model = psie.BertForRelations(
    pretrained=RELATION_PATH, dropout=0.2, use_cls_embedding=True
)
# The tokenizer gained four tokens above, so the embedding matrix must match
model.bert.resize_token_embeddings(len(tokenizer))
model.to(device)

# Predict relations
pred = model.predict(ner_loader, device)
# Class with the highest score for each tagged sentence
predictions = [np.argmax(scores.cpu().numpy()) for scores in pred]

# Extract confirmed relations
database = {"compound": [], "Gap": [], "sentence": [], "source": []}
# Patterns for the tagged material (E1) and bandgap (E2) spans; built once
# instead of on every iteration
e1_pattern = re.escape("[E1]") + ".*" + re.escape("[/E1]")
e2_pattern = re.escape("[E2]") + ".*" + re.escape("[/E2]")
for i in range(len(predictions)):
    if predictions[i] != 1:
        continue
    tagged_sentence = ner_dataset["sentence"][i]
    # Extract material (E1) and bandgap (E2) from tagged sentence
    comp = re.findall(e1_pattern, tagged_sentence)
    temp = re.findall(e2_pattern, tagged_sentence)

    if (len(comp) > 0) and (len(temp) > 0):
        # Strip the tags and the spaces between word pieces
        comp = comp[0].replace("[E1]", "").replace("[/E1]", "").replace(" ", "")
        temp = temp[0].replace("[E2]", "").replace("[/E2]", "").replace(" ", "")
        database["compound"].append(comp)
        database["Gap"].append(temp)
        database["sentence"].append(tagged_sentence)
        database["source"].append(ner_dataset["source"][i])

# Validate chemical formulas
database = pd.DataFrame(database)
valid_i = []
for i, comp in enumerate(database["compound"]):
    try:
        # Only used as a validity check: raises when the formula cannot be parsed
        Composition(comp).get_reduced_formula_and_factor()[0]
        valid_i.append(i)
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C still stops the run
        print(comp, "\t", database["sentence"][i], "\n\n")
print("Database entries:", len(valid_i), "/", len(database["sentence"]))

database.iloc[valid_i].to_csv(OUTPUT_CSV_3)

relation_time = time.time() - start_time_relation
total_time = time.time() - start_time_total

# Save runtime statistics (one line per stage, then the overall total)
runtime_lines = [
    f"Classifier time: {classifier_time:.2f} seconds\n",
    f"NER time: {ner_time:.2f} seconds\n",
    f"Relation time: {relation_time:.2f} seconds\n",
    f"Total time: {total_time:.2f} seconds\n",
]
with open(RUNTIME, 'w') as f:
    f.writelines(runtime_lines)

## 2.3 ChatExtract (CE)

**Kernel:** lc

In [None]:
import re
import os
import glob
import pandas as pd
from tqdm import tqdm
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
import logging
from datetime import datetime
dtime = datetime.now().strftime("%m%d-%H%M")
import time
from copy import copy


def inference(MODEL, CODE, PROPERTY, TXT_DIR, OUTPUT_DIR, logger):
    """Run two-stage extraction: (1) filter relevant sentences, (2) extract data.
    
    Args:
        MODEL: LLM model name
        CODE: Extraction method code identifier
        PROPERTY: Property to extract (e.g., "band gap")
        TXT_DIR: Directory containing text files
        OUTPUT_DIR: Output directory for results
        logger: Logger instance
    """
    
    # ========== Stage 1: Filter relevant sentences ==========
    def get_positive_sentences(model_name):
        """Classify every sentence with the LLM and keep the ones about PROPERTY.

        Each ``*.txt`` file in TXT_DIR is read line by line (one sentence per
        line). Sentences answered with exactly "yes" are stored together with
        an "integrated" passage made of the preceding sentence plus the
        sentence itself. Results are appended to a CSV after every file, and
        DOIs already present in that CSV are skipped on later runs.

        Args:
            model_name: Name of the LLM model

        Returns:
            Path to CSV file containing positive sentences
        """
        safe_model_name = model_name.replace(":", "-")
        POSITIVE_CSV = os.path.join(
            OUTPUT_DIR, f"1_positive_sentences_{safe_model_name}_{CODE}.csv"
        )

        classif_q = f'Is the following sentence related to "{PROPERTY}"? Answer only "Yes" or "No" without any explanation:'

        # Resume from an earlier run when its CSV is still around
        try:
            df_positive = pd.read_csv(POSITIVE_CSV)
        except FileNotFoundError:
            df_positive = pd.DataFrame(
                columns=[
                    "original_index",
                    "positive_sentences",
                    "integrated_sentences",
                    "doi",
                ]
            )
            processed_dois = set()
        else:
            processed_dois = set(df_positive["doi"].unique())

        # Chat model plus a fixed system/human prompt pair
        messages = [
            (
                "system",
                "You are an expert extraction algorithm specialized in materials science.",
            ),
            ("human", "{question}\n{text}\n"),
        ]
        chain = ChatPromptTemplate.from_messages(messages) | ChatOllama(
            model=model_name, temperature=0
        )

        txt_files = glob.glob(os.path.join(TXT_DIR, "*.txt"))
        logger.info(
            f"Found {len(txt_files)} text files for processing with {model_name}"
        )

        for txt_path in tqdm(txt_files, desc=f"Processing {model_name}"):
            # File names carry the DOI with "/" replaced by "_"
            doi = os.path.basename(txt_path).replace(".txt", "").replace("_", "/")
            if doi in processed_dois:
                continue

            with open(txt_path, "r", encoding="utf-8") as f:
                sentences = [line.strip() for line in f]

            # One 0/1 label per sentence; a failed LLM call counts as 0
            labels = []
            for pos, text in enumerate(sentences):
                try:
                    reply = chain.invoke({"question": classif_q, "text": text}).content
                    reply = re.sub(r"[^\w\s]", "", reply).strip().lower()
                    labels.append(1 if reply == "yes" else 0)
                except Exception as e:
                    logger.error(f"Error processing sentence {pos} in {doi}: {str(e)}")
                    labels.append(0)

            # Positive sentences, each paired with "previous sentence + sentence"
            n_sentences = len(sentences)
            positive_data = [
                {
                    "original_index": f"{pos}/{n_sentences}",
                    "positive_sentences": text,
                    "integrated_sentences": (
                        f"{sentences[pos - 1]} {text}" if pos > 0 else text
                    ),
                    "doi": doi,
                }
                for pos, (text, flag) in enumerate(zip(sentences, labels))
                if flag == 1
            ]

            # Persist after every file that produced at least one hit
            if positive_data:
                df_positive = pd.concat([df_positive, pd.DataFrame(positive_data)])
                df_positive.to_csv(POSITIVE_CSV, index=False, escapechar='\\')
                logger.info(f"Saved {len(positive_data)} positive sentences from {doi}")

        return POSITIVE_CSV

    # ========== Stage 2: Extract structured data ==========
    def extract_data(model_name, csv_path):
        """Extract structured bandgap data from positive sentences.

        For every row of ``csv_path`` the LLM is first asked whether the
        positive sentence contains a value of PROPERTY. If it does, the LLM is
        asked whether the passage (``integrated_sentences``) mentions more
        than one value:

        * one value      -> value, unit and material are requested one by one
          (these answers are not re-verified);
        * several values -> a markdown table is requested, parsed, and its
          cells are checked with follow-up yes/no questions.

        Extracted rows go to ``2_extracted_*.csv``, the per-row 0/1
        classification to ``2_binclas_*.csv`` and the conversation transcript
        to ``2_dialogues_*.csv`` inside OUTPUT_DIR.

        Args:
            model_name: Name of the LLM model
            csv_path: Path to CSV containing positive sentences

        Returns:
            Path to the extracted-data CSV, or None when ``csv_path`` cannot
            be read.
        """
        TEMPERATURE = 0
        CONTEXT = 4096  # Default: 2048
        MODEL_NAME = model_name.replace(":", "-")

        # Initialize output files
        EXTRACTED_CSV = os.path.join(
            OUTPUT_DIR, f"2_extracted_{MODEL_NAME}_{CODE}.csv"
        )
        BINCLAS_CSV = os.path.join(
            OUTPUT_DIR, f"2_binclas_{MODEL_NAME}_{CODE}.csv"
        )
        DIALOGUE_CSV = os.path.join(
            OUTPUT_DIR, f"2_dialogues_{MODEL_NAME}_{CODE}.csv"
        )

        # Create empty CSV with headers.
        # NOTE(review): only EXTRACTED_CSV is reset here; BINCLAS_CSV and
        # DIALOGUE_CSV are written in append mode below, so a re-run adds to
        # whatever an earlier run left behind - confirm this is intended.
        pd.DataFrame(
            columns=[
                "passage",
                "sentence",
                "doi",
                "material",
                "value",
                "unit",
                "material_valid",
                "value_valid",
                "unit_valid",
            ]
        ).to_csv(EXTRACTED_CSV, index=False)

        # Define prompts
        classif_q = f'Answer only "Yes" or "No" without any explanation. Based on the following text, is there a value of **{PROPERTY}** mentioned in it?\n\n'
        ifmulti_q = f'Answer "Yes" or "No" only. Does the following text mention more than one value of **{PROPERTY}**?\n\n'
        single_q = [
            f'Give the number only without units, do not use a full sentence. If the value is not present in the text, type "None". What is the value of the **{PROPERTY}** in the following text?\n\n',
            f'Give the unit only, do not use a full sentence. If the unit is not present in the text, type "None". What is the unit of the **{PROPERTY}** in the following text?\n\n',
            f'Give the name of the material only, do not use a full sentence. If the name of the material is not present in the text, type "None". What is the material for which the **{PROPERTY}** is given in the following text?\n\n',
        ]

        tab_q = f'Use only data present in the text. If data is not present in the text, type "None". Summarize the values of **{PROPERTY}** in the following text in a form of a table consisting of: Material, Value, Unit. Ensure that the "Value" and "Unit" are separated into different columns.\n\n'
        # Follow-up pieces are joined as: [0] + cell text + [1] + ordinal + [2] + passage
        tabfollowup_q = [
            [
                'There is a possibility that the data you extracted is incorrect. Answer "Yes" or "No" only. Be very strict. Is ',
                " the ",
                f" material for which the value of **{PROPERTY}** is given in the following text? Make sure it is a real material.\n\n",
            ],
            [
                'There is a possibility that the data you extracted is incorrect. Answer "Yes" or "No" only. Be very strict. Is ',
                f" the value of the **{PROPERTY}** for the ",
                " material in the following text?\n\n",
            ],
            [
                'There is a possibility that the data you extracted is incorrect. Answer "Yes" or "No" only. Be very strict. Is ',
                " the unit of the ",
                f" value of **{PROPERTY}** in the following text?\n\n",
            ],
        ]

        it = [
            "first", "second", "third", "fourth", "fifth", "sixth", "seventh",
            "eighth", "ninth", "tenth", "eleventh", "twelfth", "thirteenth",
            "fourteenth", "fifteenth", "sixteenth", "seventeenth", "eighteenth",
            "nineteenth", "twentieth",
        ]
        col = ["Material", "Value", "Unit"]
        single_cols = ["value", "unit", "material"]

        def ordinal(k):
            """Ordinal for zero-based table row k; numeric ("21st") past the word list."""
            if k < len(it):
                return it[k]
            n = k + 1
            if 10 <= n % 100 <= 20:
                suffix = "th"
            else:
                suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
            return f"{n}{suffix}"

        # Initialize LLM
        llm = ChatOllama(model=model_name, temperature=TEMPERATURE, num_ctx=CONTEXT)
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are an expert extraction algorithm specialized in materials science.",
                ),
                ("placeholder", "{conversation}"),
            ]
        )
        chain = prompt | llm

        try:
            df = pd.read_csv(csv_path)
            ntot = len(df)
            logger.info(f"Starting data extraction for {len(df)} entries")
        except Exception as e:
            logger.error(f"Failed to read CSV: {str(e)}")
            return

        with tqdm(total=ntot, desc=f"Extracting {MODEL_NAME}") as pbar:
            for i in range(ntot):
                # Per-row state is reset up front so nothing leaks in from an
                # earlier row (in particular one that failed half-way).
                sst = None  # transcript incl. table rows; only set on the multi-value path
                tab = None  # parsed table; only set on the multi-value path
                try:
                    binary_classif = []
                    sss = []  # running conversation that is sent to the LLM
                    sss.append(("human", classif_q + df["positive_sentences"][i]))
                    ans = chain.invoke({"conversation": sss}).content
                    sss.append(("ai", ans))

                    if "yes" in ans.strip().lower():  # Positive classification
                        binary_classif.append(1)
                        result = {}
                        passage = df["integrated_sentences"][i]
                        sentence = df["positive_sentences"][i]
                        sss.append(("human", ifmulti_q + passage))
                        ans = chain.invoke({"conversation": sss}).content
                        sss.append(("ai", ans))

                        if "no" in ans.lower():  # Single data point
                            result["passage"] = [passage]
                            result["sentence"] = [sentence]
                            result["doi"] = [df["doi"][i]]
                            result["material"] = []
                            result["value"] = []
                            result["unit"] = []
                            result["material_valid"] = []
                            result["value_valid"] = []
                            result["unit_valid"] = []

                            for j in range(len(single_q)):
                                sss.append(("human", single_q[j] + passage))
                                ans = chain.invoke({"conversation": sss}).content
                                sss.append(("ai", ans))
                                result[single_cols[j]].append(ans)
                                # An answer containing "none" marks the field as missing
                                if "none" in ans.lower():
                                    result[single_cols[j] + "_valid"].append(0)
                                else:
                                    result[single_cols[j] + "_valid"].append(1)

                        elif "yes" in ans.lower():  # Multiple data points
                            sss.append(("human", tab_q + passage))
                            tab = chain.invoke({"conversation": sss}).content
                            sss.append(("ai", tab))
                            sst = copy(sss)

                            # Parse markdown table: keep the text between the
                            # first and the last "|"
                            start_index = tab.find("|")
                            end_index = tab.rfind("|") + 1
                            cleaned_output = tab[start_index:end_index].strip()
                            lines = cleaned_output.split("\n")

                            # Parse data rows (the first two lines are taken to
                            # be the header and the separator row)
                            data = []
                            for line in lines[2:]:
                                stripped_line = line.strip("| ")
                                parts = [p.strip() for p in stripped_line.split("|")]

                                if len(parts) > 3:
                                    # Extra "|" in the row: fold everything after
                                    # the second cell into the unit cell. This
                                    # must not be stored in `sentence`, which is
                                    # the value written to the CSV below.
                                    unit_cell = "|".join(parts[2:]).strip()
                                    parts = [parts[0], parts[1], unit_cell]
                                else:
                                    # Short rows are padded with None
                                    parts = parts + [None] * (3 - len(parts))
                                data.append(parts)

                            tab = pd.DataFrame(data, columns=col)

                            result["passage"] = []
                            result["sentence"] = []
                            result["doi"] = []
                            result["material"] = []
                            result["value"] = []
                            result["unit"] = []
                            result["material_valid"] = []
                            result["value_valid"] = []
                            result["unit_valid"] = []

                            for k in range(len(tab)):
                                # Transcript-only marker showing which row is checked
                                sst.append(
                                    (
                                        "tab",
                                        f"{tab[col[0]][k]},{tab[col[1]][k]},{tab[col[2]][k]}",
                                    )
                                )
                                result["passage"].append(passage)
                                result["sentence"].append(sentence)
                                result["doi"].append(df["doi"][i])
                                multi_valid = True

                                for l in range(3):
                                    temp_r = str(tab[col[l]][k])
                                    ss = (
                                        tabfollowup_q[l][0]
                                        + temp_r
                                        + tabfollowup_q[l][1]
                                        + ordinal(k)
                                        + tabfollowup_q[l][2]
                                        + passage
                                    )
                                    result[col[l].lower()].append(temp_r)

                                    if "none" in temp_r.lower():
                                        result[col[l].lower() + "_valid"].append(0)
                                        multi_valid = False
                                    elif multi_valid:
                                        sss.append(("human", ss))
                                        sst.append(("human", ss))
                                        ans = chain.invoke({"conversation": sss}).content
                                        sss.append(("ai", ans))
                                        sst.append(("ai", ans))

                                        if "no" in ans.lower():
                                            result[col[l].lower() + "_valid"].append(0)
                                            multi_valid = False
                                        else:
                                            result[col[l].lower() + "_valid"].append(1)
                                    else:
                                        # The row was already rejected: no further
                                        # question is asked and the cell is marked 1
                                        result[col[l].lower() + "_valid"].append(1)

                        try:
                            pd.DataFrame(result).to_csv(
                                EXTRACTED_CSV, mode="a", index=False, header=False
                            )
                        except Exception as e:
                            print("Appending extracted data error: ", i, "  ", e)
                            print("Appending extracted data error: ", result, "  ", e)
                            print("Appending extracted data error: ", tab, "  ", e)
                    else:  # Negative classification
                        binary_classif.append(0)

                    pd.DataFrame(binary_classif).to_csv(
                        BINCLAS_CSV, mode="a", index=False, header=False
                    )

                    # Log the richer transcript when the multi-value path built
                    # one, otherwise the plain conversation.
                    dialogue = sst if sst is not None else sss
                    pd.DataFrame(dialogue).to_csv(
                        DIALOGUE_CSV, mode="a", index=False, header=False
                    )
                    pbar.update(1)

                except Exception as e:
                    logger.error(f"Error processing row {i}: {str(e)}")
                    print(f"Ignoring {i+1}/{ntot} ({round(i/ntot*100,1)} %)")
                    continue

        logger.info(f"Completed data extraction for {model_name}")
        return EXTRACTED_CSV

    # Execute pipeline
    stage1_output = get_positive_sentences(MODEL)
    EXTRACTED_CSV = extract_data(MODEL, stage1_output)
    return EXTRACTED_CSV


def process_chat(MODELS, PROPERTY, TXT_DIR, OUTPUT_DIR):
    """Process multiple LLM models for bandgap extraction.

    Sets up a file + console logger, runs ``inference`` once per model and
    logs the total wall-clock time. The log file is written to
    ``OUTPUT_DIR/processing_<dtime>.log``.

    Args:
        MODELS: Dictionary mapping model names to codes
        PROPERTY: Property to extract
        TXT_DIR: Input text directory
        OUTPUT_DIR: Output directory
    """
    # Initialize logging
    start_time = time.time()
    log_file = os.path.join(OUTPUT_DIR, f"processing_{dtime}.log")

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # File handler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    own_handlers = (file_handler, console_handler)
    for handler in own_handlers:
        logger.addHandler(handler)

    try:
        logger.info("Initializing inference pipeline...")
        logger.info(f"Property: {PROPERTY}")
        logger.info(f"Input TXT directory: {TXT_DIR}")
        logger.info(f"Output directory: {OUTPUT_DIR}")

        for MODEL, CODE in MODELS.items():
            logger.info(f"\n{'='*40}")
            logger.info(f"Processing model: {MODEL}")
            inference(MODEL, CODE, PROPERTY, TXT_DIR, OUTPUT_DIR, logger)
            logger.info(f"Completed processing for {MODEL}")

        logger.info(f"\n{'='*40}")
        logger.info(f"Total processing time: {time.time()-start_time:.2f} seconds")
        logger.info(f"Output directory: {OUTPUT_DIR}")
        logger.info("Processing complete!")
    finally:
        # getLogger(__name__) hands back the same logger on every call, so the
        # handlers added above must be detached again; otherwise each re-run
        # of the cell would emit every message once more and keep another log
        # file open.
        for handler in own_handlers:
            logger.removeHandler(handler)
            handler.close()


# Run ChatExtract: sentence files in, per-model CSVs and a log file out
OUTPUT_DIR = os.path.join(HOME, "output", "3-ChatExtract")
TXT_DIR = os.path.join(HOME, "TXT(fromPDF_processed)")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model name -> method code that is embedded in the output file names
MODELS = dict(
    [
        ("llama2:13b", "CE_1"),
        ("llama3.1:70b", "CE_2"),
        ("qwen2.5:14b", "CE_3"),
    ]
)

process_chat(MODELS, "band gap", TXT_DIR, OUTPUT_DIR)

## 2.4 LangChain RAG (LC)

**Kernel:** lc

In [None]:
# Settings for the LangChain RAG run, read as globals by the functions in the next cell.
# Temperature passed to ChatOllama
TEMPERATURE = 0
# Number of chunks the retriever returns for the question (search_kwargs "k")
TOP_K = 5
# chunk_size / chunk_overlap passed to RecursiveCharacterTextSplitter
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200

In [None]:
import os
import glob
import pandas as pd
from tqdm import tqdm
import gc
import time
from datetime import datetime
dtime = datetime.now().strftime("%m%d-%H%M")
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from chromadb import Client
from chromadb.config import Settings

def process_single_pdf(pdf_file, embeddings, llm, text_splitter, prompt, question, client):
    """Process a single PDF file using RAG (Retrieval-Augmented Generation).

    This function:
    1. Loads and chunks the PDF document
    2. Creates a vector database for semantic search
    3. Retrieves relevant chunks based on the question
    4. Generates structured output using LLM

    The per-PDF collection is removed from ``client`` before and after
    processing, on the error path as well.

    Args:
        pdf_file: Path to PDF file
        embeddings: Embedding model for vectorization
        llm: Language model for generation
        text_splitter: Text chunking strategy
        prompt: Prompt template
        question: Query question
        client: ChromaDB client instance

    Returns:
        tuple: (doi, output, status). ``status`` is 1 when the output starts
        with "|" (looks like a markdown table), 0 when it does not, and -1
        when processing raised; in that case ``output`` holds the error text.
    """
    # File names carry the DOI with "/" replaced by "_"
    doi = os.path.basename(pdf_file).replace('.pdf', '').replace('_', '/')
    collection_name = f"collection_{doi.replace('/', '_')}"

    def drop_collection():
        """Best-effort removal of this PDF's collection from the shared client."""
        try:
            client.delete_collection(collection_name)
        except Exception:
            # Deliberately ignored: the usual cause is that the collection does
            # not exist. `except Exception` keeps Ctrl-C / kernel interrupt working.
            pass

    try:
        # Force delete old collection if exists (same name is reused for every
        # embedding/LLM combination)
        drop_collection()

        # Load and split document
        docs = PyMuPDFLoader(file_path=pdf_file).load()
        chunks = text_splitter.split_documents(docs)

        # Create vector database
        vector_db = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            collection_name=collection_name,
            client=client
        )

        # Build and execute RAG chain
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        retriever = vector_db.as_retriever(search_kwargs={"k": TOP_K})
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
        )
        output = rag_chain.invoke(question).content
        status = 1 if output.strip().startswith("|") else 0

        return doi, output, status

    except Exception as e:
        return doi, f"Processing error: {str(e)}", -1

    finally:
        # Clean up resources on every exit path so collections do not pile up
        # in the shared client across the whole batch.
        drop_collection()
        gc.collect()

def process_pdfs_batch(codes, embed_stream, infer_stream, output_dir, pdf_dir, template, question):
    """Batch process PDF files using multiple embedding/LLM combinations.

    For every (code, embedding model, inference model) triple, each PDF in
    ``pdf_dir`` is run through ``process_single_pdf`` and its result is
    appended to ``output_<code>_<dtime>.csv``. Timing and a per-status tally
    are appended to the module-level ``log_file``.

    Args:
        codes: Processing code identifiers
        embed_stream: Embedding model stream (paired with inference models)
        infer_stream: Inference model stream (paired with embedding models)
        output_dir: Output directory
        pdf_dir: PDF files directory
        template: Prompt template
        question: Query question
    """
    # Create global ChromaDB client (in-memory to avoid file residue)
    client_settings = Settings(persist_directory="")
    global_client = Client(client_settings)

    total_start_time = time.time()

    # These do not depend on the model pair, so build them once
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP
    )
    prompt = PromptTemplate.from_template(template)
    pdf_files = sorted(glob.glob(os.path.join(pdf_dir, "*.pdf")))

    for code, embed_model, infer_model in zip(codes, embed_stream, infer_stream):
        output_csv = os.path.join(output_dir, f"output_{code}_{dtime}.csv")

        # Initialize models
        embeddings = OllamaEmbeddings(model=embed_model)
        # NOTE(review): num_predict=80 limits how much the model may generate;
        # a table with many rows could be cut short - confirm the value.
        llm = ChatOllama(model=infer_model, temperature=TEMPERATURE, num_predict=80)

        # Tally of process_single_pdf status codes: 1 table, 0 non-table, -1 error
        status_counts = {1: 0, 0: 0, -1: 0}
        start_time = time.time()

        for pdf_file in tqdm(pdf_files, desc=f"Processing PDF files - {code}"):
            doi, output, status = process_single_pdf(
                pdf_file, embeddings, llm, text_splitter, prompt, question, global_client
            )
            status_counts[status] = status_counts.get(status, 0) + 1

            # Save result (header only when the file is created)
            df_new = pd.DataFrame([{'doi': doi, 'output': output, 'status': status}])
            if os.path.exists(output_csv):
                df_new.to_csv(output_csv, mode='a', header=False, index=False)
            else:
                df_new.to_csv(output_csv, index=False)

        end_time = time.time()
        batch_time = end_time - start_time

        # Write to log file. The status field summarises all PDFs of this
        # batch; reporting only the last PDF's status would be misleading and
        # would fail outright when pdf_dir contains no PDFs.
        with open(log_file, 'a') as lf:
            log_entry = (
                f"Code: {code} | Embed Model: {embed_model} | Infer Model: {infer_model}\n"
                f"Start: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} | "
                f"End: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))} | "
                f"Duration: {batch_time:.2f} seconds | "
                f"Status: {status_counts[1]} table / {status_counts[0]} non-table / "
                f"{status_counts[-1]} error (of {len(pdf_files)} PDFs)\n"
                f"{'='*40}\n"
            )
            lf.write(log_entry)

    total_end_time = time.time()
    total_time = total_end_time - total_start_time

    with open(log_file, 'a') as lf:
        log_entry = (
            f"Total Start: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(total_start_time))}\n"
            f"Total End: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(total_end_time))}\n"
            f"Total Duration: {total_time:.2f} seconds\n"
            f"{'='*40}\n"
        )
        lf.write(log_entry)

# Prompt template for extraction.
# It has two placeholders, {context} and {question}; the text is turned into a
# PromptTemplate inside process_pdfs_batch. The wording is part of the
# experiment, so it is kept verbatim.
TEMPLATE = """
You are an expert information extraction algorithm.
Extract all the band gap values in the CONTEXT given below.
Output the band gap values in the form of a markdown table, including: Material (name of the material), Value (band gap value), Unit (unit of value).
Do not explain, only output the table in markdown format.
The output is strictly in the following format.
| Material | Value | Unit |
|----------|-------|------|
| ... | ... | eV |
| ... | ... | meV |
If no band gap values mentioned in the article, the following table is acceptable:
| Material | Value | Unit |
|----------|-------|------|
| None | None | None |
---
CONTEXT: {context}
---
QUESTION: {question}
Answer in markdown table:
"""
# Query passed to process_pdfs_batch as `question`
QUESTION = "What are the materials' name and their band gap values?"

# Input / output locations for the LangChain RAG run
PDF_DIR = os.path.join(HOME, "PDF")
TXT_DIR = os.path.join(HOME, "TXT(fromPDF_processed)")
OUTPUT_DIR = os.path.join(HOME, "output", "4-LangChain")
os.makedirs(OUTPUT_DIR, exist_ok=True)
log_file = os.path.join(OUTPUT_DIR, f"log_{dtime}.log")

EMBEDDING_MODELS = ["nomic-embed-text", "bge-m3"]
INFERENCE_MODELS = ["llama2:13b", "llama3.1:70b", "qwen2.5:14b"]

# Every embedding model is paired with every inference model. The code is
# "LC_<embedding number><inference number>", both counted from 1, with the
# embedding model as the outer (slower-changing) loop.
combos = [
    (f"LC_{embed_no}{infer_no}", embed_name, infer_name)
    for embed_no, embed_name in enumerate(EMBEDDING_MODELS, start=1)
    for infer_no, infer_name in enumerate(INFERENCE_MODELS, start=1)
]
CODES = [combo[0] for combo in combos]
embed_stream = [combo[1] for combo in combos]
infer_stream = [combo[2] for combo in combos]

# Execute batch processing
process_pdfs_batch(
    codes=CODES,
    embed_stream=embed_stream,
    infer_stream=infer_stream,
    output_dir=OUTPUT_DIR,
    pdf_dir=PDF_DIR,
    template=TEMPLATE,
    question=QUESTION,
)

## 2.5 Kimi-1.5 (Kimi)

Use the prompt below to extract bandgap data from Kimi (<https://www.kimi.com>) and save the output to `KIMI_OUT`.

### Prompt:

````
You are an expert information extraction algorithm.
Extract all the band gap values in this article and output them in the form of a markdown table, including: Material (name of the material), Value (value with unit), Sentence (the sentence from which this data record comes).
If data is not present in the article, type "None". 
Table only, no need for explanation or any other content.
The output is strictly in the following format.
```markdown
| Material | Value | Sentence |
|----------|-------|---------|
| Material1 | 0.1 eV | ... Eg of Material1 is 0.1 eV ... |
| Material1 | 200 meV | Material1 has a band gap of 200 meV, so ... |
| Material2 | None | Material2 ... |
```

If no band gap values mentioned in the article, the following table is acceptable:
```markdown
| Material | Value | Sentence |
|----------|-------|----------|
| None | None | None |
```
````

# 3 Organize Results

Post-processing steps to organize and compare extraction results.

In [None]:
import os
import re
import pandas as pd
import joblib
from datetime import datetime
import glob
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
from my_post import clean_and_normalize, compare_with_index, parse_markdown_table, clean_illegal_chars

# Configuration: all comparison output goes to HOME/comparison_<MMDD>[_<MARK>]
MARK = "test"
dir_mark = f'_{MARK}' if MARK != "" else ""
today = datetime.now().strftime("%m%d")
manual_xlsx = os.path.join(HOME, "manual_pub.xlsx")
comparison_dir = os.path.join(HOME, f'comparison_{today}{dir_mark}')
os.makedirs(comparison_dir, exist_ok=True)
comparison_xlsx = os.path.join(comparison_dir, "comparison_pub.xlsx")

# Manual annotations, as stored in the 'manual' sheet
df_manual = pd.read_excel(manual_xlsx, sheet_name='manual')

# Keep only the rows that name a material
df_filtered = df_manual[df_manual['material'].notnull()]

# Summary view: four columns, with 'Manual' exposed as 'value'
df_summary = df_filtered[['index', 'doi', 'material', 'Manual']].rename(
    columns={'Manual': 'value'}
)

# Start the comparison workbook with both sheets
with pd.ExcelWriter(comparison_xlsx) as writer:
    for sheet, frame in (('manual', df_manual), ('summary', df_summary)):
        frame.to_excel(writer, sheet_name=sheet, index=False)

## 3.1 Process CDE Results

In [None]:
def post_cde(temp_save, xlsx_path, comparison_xlsx, code):
    """Flatten ChemDataExtractor records and export them to Excel.

    The joblib file is loaded, the fields of every record's "BandGap" entry
    are flattened into one row, and control characters are cleaned. The full
    table is written to sheet "0-raw" of ``xlsx_path``; a four-column view
    (doi, material, value, unit) is written to sheet "1-raw" of the same file
    and to sheet "<code>_raw" of ``comparison_xlsx``.

    Args:
        temp_save: Path to CDE raw data file
        xlsx_path: Output Excel file path: "FINAL_{CODE}_{dtime}.xlsx"
        comparison_xlsx: Comparison Excel file path (must already exist,
            since it is opened in append mode)
        code: Extraction method code
    """
    column_order = [
        "Publisher",
        "DOI",
        "Name",
        "Raw_value",
        "Raw_unit",
        "Value",
        "Unit",
        "specifier",
        "Text",
        "Snowball",
        "AutoSentenceParser",
    ]

    def flatten(record):
        """Return one output row for a single extracted record."""
        bg = record["BandGap"]
        return {
            "Publisher": bg["publisher"],
            "DOI": bg["doi"],
            # Only the first listed compound name is kept
            "Name": bg["compound"]["Compound"]["names"][0],
            "Raw_value": bg["raw_value"],
            "Raw_unit": bg["raw_units"],
            "Value": bg["value"],
            "Unit": bg["units"],
            "specifier": bg["specifier"],
            "Text": bg["text"],
            "Snowball": bg["Snowball"],
            "AutoSentenceParser": bg["AutoSentenceParser"],
        }

    rows = [flatten(record) for record in joblib.load(temp_save)]

    # Full table, with illegal control characters cleaned cell by cell
    raw_df = pd.DataFrame(rows, columns=column_order).map(clean_illegal_chars)

    # mode="w": the result workbook is created fresh with the full table
    with pd.ExcelWriter(xlsx_path, mode="w", engine="openpyxl") as writer:
        raw_df.to_excel(writer, sheet_name="0-raw", index=False)

    # Simplified view based on the raw value/unit columns
    simplified_df = raw_df[["DOI", "Name", "Raw_value", "Raw_unit"]].rename(
        columns={
            "DOI": "doi",
            "Name": "material",
            "Raw_value": "value",
            "Raw_unit": "unit",
        }
    )

    # Same view goes into both workbooks, replacing a sheet of that name
    targets = ((xlsx_path, "1-raw"), (comparison_xlsx, f"{code}_raw"))
    for workbook, sheet in targets:
        with pd.ExcelWriter(
            workbook, mode="a", if_sheet_exists="replace", engine="openpyxl"
        ) as writer:
            simplified_df.to_excel(writer, sheet_name=sheet, index=False)


def postprocess_cde(temp_save, xlsx_path, comparison_xlsx, code):
    """Run the complete post-processing pipeline for CDE extraction output.

    Three stages are executed in order: `post_cde`, `clean_and_normalize`
    (explicitly on the "1-raw" sheet) and `compare_with_index`.

    Args:
        temp_save: Path to the CDE raw data file.
        xlsx_path: Result Excel file for this method.
        comparison_xlsx: Comparison Excel file path.
        code: Extraction method code.
    """
    # All three stages share the same trailing arguments.
    workbook_args = (xlsx_path, comparison_xlsx, code)

    # Stage 1: convert the raw records into workbook sheets.
    post_cde(temp_save, *workbook_args)
    # Stage 2: clean and normalize the simplified "1-raw" sheet.
    clean_and_normalize(*workbook_args, sheet_name="1-raw")
    # Stage 3: run the index comparison.
    compare_with_index(*workbook_args)

# Method code: embedded in the result file name and passed through the pipeline.
code = "CDE"

# Result data paths
# NOTE(review): HOME, comparison_dir, today and comparison_xlsx are not
# defined in this cell; they are expected to come from earlier cells.
output_dir = os.path.join(HOME, "output", "1-ChemDataExtractor")
temp_save = os.path.join(output_dir, "records_general.joblib")
xlsx_path = os.path.join(comparison_dir, f"1_{code}_{today}.xlsx")

# Run the full CDE post-processing pipeline on the saved records.
postprocess_cde(temp_save, xlsx_path, comparison_xlsx, code)

## 3.2 Process PSIE Results

In [None]:
def post_psie(csv_2s, csv_3, xlsx_path, comparison_xlsx, code):
    """Merge the two BERT-PSIE CSV outputs and store them in standard format.

    The single-mention table is relabelled with the multiple-mention table's
    header, both are stacked, the first column is dropped, and the remaining
    four columns are named and reordered to sentence/doi/material/value.

    Args:
        csv_2s: Path to single-mention CSV file.
        csv_3: Path to multiple-mention CSV file.
        xlsx_path: Output Excel file; (re)created with a single "1-raw" sheet.
        comparison_xlsx: Comparison Excel file; opened in append mode and
            given a "{code}_raw" sheet (replaced if it already exists).
        code: Extraction method code.
    """
    single_mentions = pd.read_csv(csv_2s)
    multi_mentions = pd.read_csv(csv_3)

    # Give both tables the same header so they stack column-by-column.
    single_mentions.columns = multi_mentions.columns
    stacked = pd.concat([single_mentions, multi_mentions], ignore_index=True)

    # Drop the first column, then name the remaining four.
    trimmed = stacked.iloc[:, 1:]
    trimmed.columns = ['material', 'value', 'sentence', 'doi']

    # Standard column order, with illegal control characters removed.
    standard = trimmed[['sentence', 'doi', 'material', 'value']]
    standard = standard.map(clean_illegal_chars)

    with pd.ExcelWriter(xlsx_path, mode="w", engine="openpyxl") as result_writer:
        standard.to_excel(result_writer, sheet_name="1-raw", index=False)

    with pd.ExcelWriter(
        comparison_xlsx, mode="a", if_sheet_exists="replace", engine="openpyxl"
    ) as comparison_writer:
        standard.to_excel(comparison_writer, sheet_name=f"{code}_raw", index=False)

def postprocess_psie(csv_2s, csv_3, xlsx_path, comparison_xlsx, code):
    """Run the complete post-processing pipeline for PSIE extraction output.

    Executes `post_psie`, `clean_and_normalize` and `compare_with_index`
    in that order.

    Args:
        csv_2s: Path to single-mention CSV file.
        csv_3: Path to multiple-mention CSV file.
        xlsx_path: Result Excel file for this method.
        comparison_xlsx: Comparison Excel file path.
        code: Extraction method code.
    """
    # All three stages share the same trailing arguments.
    workbook_args = (xlsx_path, comparison_xlsx, code)

    # Stage 1: merge the raw CSVs into the "1-raw" sheet.
    post_psie(csv_2s, csv_3, *workbook_args)
    # Stage 2: clean and normalize the raw sheet.
    clean_and_normalize(*workbook_args)
    # Stage 3: run the index comparison.
    compare_with_index(*workbook_args)

# Method code: embedded in the result file name and passed through the pipeline.
code = "PSIE"

# Result data paths
# NOTE(review): HOME, comparison_dir, today and comparison_xlsx are not
# defined in this cell; they are expected to come from earlier cells.
output_dir = os.path.join(HOME, "output", "2-BERT-PSIE")
output_csv_2_s = os.path.join(output_dir, "2-test_extraction_single_mentions.csv")
output_csv_3 = os.path.join(output_dir, "3-relations_extraction.csv")
xlsx_path = os.path.join(comparison_dir, f"2_{code}_{today}.xlsx")

# Run the full PSIE post-processing pipeline on both CSV outputs.
postprocess_psie(output_csv_2_s, output_csv_3, xlsx_path, comparison_xlsx, code)

## 3.3 Process ChatExtract Results

In [None]:
def post_ce(extracted_csv, xlsx_path, comparison_xlsx, code):
    """Convert ChatExtract raw data to standard format.

    Keeps only rows whose `material_valid`, `value_valid` and `unit_valid`
    flags all equal 1, strips the unit text "meV"/"eV" from the `value`
    column, and writes the selected columns to both workbooks.

    Args:
        extracted_csv: Path to extracted CSV file. It must provide the three
            validity columns plus doi, material, value, unit, passage and
            sentence.
        xlsx_path: Output Excel file; (re)created with a single "1-raw" sheet.
        comparison_xlsx: Comparison Excel file; opened in append mode and
            given a "{code}_raw" sheet (replaced if it already exists).
        code: Extraction method code.
    """
    raw_df = pd.read_csv(extracted_csv)

    # Filter rows with all validity flags set to 1
    all_flags_valid = (
        (raw_df["material_valid"] == 1)
        & (raw_df["value_valid"] == 1)
        & (raw_df["unit_valid"] == 1)
    )
    # Explicit copy: the filtered frame is modified below, and assigning into
    # a filtered view of `raw_df` triggers pandas' SettingWithCopyWarning.
    df = raw_df.loc[all_flags_valid].copy()

    def strip_ev(value):
        """Remove "meV"/"eV" from string values; pass anything else through."""
        # read_csv infers a numeric dtype when every value is a bare number,
        # so only strings are touched; NaN and numbers are returned unchanged.
        if isinstance(value, str):
            # "meV" is removed first so its "eV" tail does not leave a stray "m".
            return value.replace("meV", "").replace("eV", "").strip()
        return value

    df["value"] = df["value"].apply(strip_ev)

    # Select and reorder columns
    columns_to_keep = ["doi", "material", "value", "unit", "passage", "sentence"]
    df = df[columns_to_keep]

    # Clean illegal control characters
    df = df.map(clean_illegal_chars)

    with pd.ExcelWriter(xlsx_path, mode="w", engine="openpyxl") as writer:
        df.to_excel(writer, sheet_name="1-raw", index=False)
    with pd.ExcelWriter(comparison_xlsx, mode="a", if_sheet_exists="replace", engine="openpyxl") as writer:
        df.to_excel(writer, sheet_name=f"{code}_raw", index=False)

def postprocess_ce(extracted_csv, xlsx_path, comparison_xlsx, code):
    """Run the complete post-processing pipeline for ChatExtract output.

    Executes `post_ce`, `clean_and_normalize` and `compare_with_index`
    in that order.

    Args:
        extracted_csv: Path to extracted CSV file.
        xlsx_path: Result Excel file for this method.
        comparison_xlsx: Comparison Excel file path.
        code: Extraction method code.
    """
    # All three stages share the same trailing arguments.
    workbook_args = (xlsx_path, comparison_xlsx, code)

    # Stage 1: filter and convert the raw CSV into the "1-raw" sheet.
    post_ce(extracted_csv, *workbook_args)
    # Stage 2: clean and normalize the raw sheet.
    clean_and_normalize(*workbook_args)
    # Stage 3: run the index comparison.
    compare_with_index(*workbook_args)

# ChatExtract method codes, one per evaluated model
codes = [
    "CE_1",  # llama2:13b
    "CE_2",  # llama3.1:70b
    "CE_3",  # qwen2.5:14b
]

# Result data paths
output_dir = os.path.join(HOME, "output", "3-ChatExtract")
# Extraction CSV file names, listed in the same order as `codes`
csvs = [
    '2_extracted_llama2-13b_CE_1.csv',
    '2_extracted_llama3.1-70b_CE_2.csv',
    '2_extracted_qwen2.5-14b_CE_3.csv'
]

# The loop variable is `csv_name`, not `csv`: it stays bound at notebook scope
# after the loop and would otherwise hide the standard-library `csv` module
# from any cell that uses it.
for csv_name, code in zip(csvs, codes):
    extracted_csv = os.path.join(output_dir, csv_name)
    xlsx_path = os.path.join(comparison_dir, f"3_{code}_{today}.xlsx")
    postprocess_ce(extracted_csv, xlsx_path, comparison_xlsx, code)

## 3.4 Process LangChain Results

In [None]:
def post_lc(extracted_csv, xlsx_path, comparison_xlsx, code):
    """Convert LangChain RAG raw data to standard format.

    This function:
    1. Reads the raw CSV, drops rows with a duplicate `output`, and keeps only
       rows with `status == 1`
    2. Parses the markdown table held in each row's `output`
    3. Merges the parsed tables and strips "meV"/"eV" from the values
    4. Saves the merged data ("0-raw") and the data without null/none values
       ("1-raw")

    Args:
        extracted_csv: Input CSV file path; must provide the columns
            `output`, `status` and `doi`.
        xlsx_path: Output Excel file; (re)created with the sheets "1-raw"
            and "0-raw".
        comparison_xlsx: Comparison Excel file; opened in append mode and
            given a "{code}_raw" sheet (replaced if it already exists).
        code: Extraction method code.
    """
    # Read CSV and remove duplicates
    raw_df = pd.read_csv(extracted_csv)
    raw_df = raw_df.drop_duplicates(subset=["output"], keep="first")

    # Keep only rows with status == 1
    raw_df = raw_df[raw_df['status'] == 1]

    # Fallback result when no row yields a table
    result_df = pd.DataFrame(columns=["doi", "material", "value", "unit"])

    # Parse each row; rows for which the parser returns None are skipped
    parsed_tables = []
    for _, row in raw_df.iterrows():
        table = parse_markdown_table(row["output"])
        if table is not None:
            table["doi"] = row["doi"]
            parsed_tables.append(table[["doi", "material", "value", "unit"]])

    # Merge all parsed results
    if parsed_tables:
        result_df = pd.concat(parsed_tables, ignore_index=True)

    # Basic processing: remove "meV" or "eV" from value column
    # ("meV" first so its "eV" tail does not leave a stray "m")
    result_df["value"] = result_df["value"].apply(
        lambda x: x.replace("meV", "").replace("eV", "").strip() if pd.notna(x) else x
    )

    # Clean illegal control characters
    result_df = result_df.map(clean_illegal_chars)

    # Remove null/none values
    filtered_df = result_df[
        ~result_df["value"].str.lower().str.strip().isin(["none", "nan"]) &
        result_df["value"].notna()
    ]

    # Both sheets are written through one writer. Previously the workbook was
    # reopened with mode="w" for "1-raw", which recreated the file and threw
    # away the "0-raw" sheet written just before. "1-raw" is written first so
    # it stays the first sheet of the workbook, exactly as before the fix.
    with pd.ExcelWriter(xlsx_path, mode="w", engine="openpyxl") as writer:
        filtered_df.to_excel(writer, sheet_name="1-raw", index=False)
        result_df.to_excel(writer, sheet_name="0-raw", index=False)
    with pd.ExcelWriter(comparison_xlsx, mode="a", if_sheet_exists="replace", engine="openpyxl") as writer:
        filtered_df.to_excel(writer, sheet_name=f"{code}_raw", index=False)

def postprocess_lc(extracted_csv, xlsx_path, comparison_xlsx, code):
    """Run the complete post-processing pipeline for LangChain output.

    Executes `post_lc`, `clean_and_normalize` and `compare_with_index`
    in that order.

    Args:
        extracted_csv: Path to extracted CSV file.
        xlsx_path: Result Excel file for this method.
        comparison_xlsx: Comparison Excel file path.
        code: Extraction method code.
    """
    # All three stages share the same trailing arguments.
    workbook_args = (xlsx_path, comparison_xlsx, code)

    # Stage 1: parse the markdown tables into the raw sheets.
    post_lc(extracted_csv, *workbook_args)
    # Stage 2: clean and normalize the raw sheet.
    clean_and_normalize(*workbook_args)
    # Stage 3: run the index comparison.
    compare_with_index(*workbook_args)


# Result data paths
output_dir = os.path.join(HOME, "output", "4-LangChain")

# The notebook's import cell does `from glob import glob`, which binds the name
# `glob` to the function, so `glob.glob(...)` raises AttributeError. Importing
# the module under an alias works whichever way `glob` is currently bound and
# does not rebind `glob` for the other cells.
import glob as glob_module

# Sorted so that files and derived codes are processed in a stable order
csv_files = sorted(glob_module.glob(os.path.join(output_dir, "output_LC_*.csv")))

# Method code = 2nd and 3rd underscore-separated parts of the file name
codes = [
    "_".join(os.path.basename(file).split('_')[1:3])
    for file in csv_files
]

for extracted_csv, code in zip(csv_files, codes):
    print(extracted_csv, code)
    xlsx_path = os.path.join(comparison_dir, f"4_{code}_{today}.xlsx")
    postprocess_lc(extracted_csv, xlsx_path, comparison_xlsx, code)

## 3.5 Process Kimi Results

In [None]:
def post_kimi(extracted_xlsx, xlsx_path, comparison_xlsx, code):
    """Convert Kimi extraction raw data to standard format.

    The input workbook is read without a header row as two columns: a PDF
    path and the model output. Each path is reduced to its file stem, the
    markdown table in each output is parsed, and the stem with '_' replaced
    by '/' is attached to the parsed rows as the DOI.

    Args:
        extracted_xlsx: Path to Kimi output Excel file.
        xlsx_path: Output Excel file; (re)created with a single "1-raw" sheet.
        comparison_xlsx: Comparison Excel file; opened in append mode and
            given a "{code}_raw" sheet (replaced if it already exists).
        code: Extraction method code.
    """
    kimi_df = pd.read_excel(extracted_xlsx, header=None, names=["doi", "output"])

    # Reduce each PDF path to its base name without directory or extension.
    kimi_df['doi'] = kimi_df['doi'].apply(
        lambda pdf_path: os.path.splitext(os.path.basename(pdf_path))[0]
    )

    # Parse the markdown table of every row; unparseable outputs are skipped.
    tables = []
    for file_stem, output in zip(kimi_df["doi"], kimi_df["output"]):
        parsed = parse_markdown_table(output, third="sentence")
        if parsed is None:
            continue
        parsed["doi"] = file_stem.replace('_', '/')
        tables.append(parsed[["doi", "material", "value", "sentence"]])

    # Merge the parsed tables, or fall back to an empty frame.
    if tables:
        result_df = pd.concat(tables, ignore_index=True)
    else:
        result_df = pd.DataFrame(columns=["doi", "material", "value", "sentence"])

    # Remove illegal control characters before writing.
    result_df = result_df.map(clean_illegal_chars)

    with pd.ExcelWriter(xlsx_path, mode="w", engine="openpyxl") as result_writer:
        result_df.to_excel(result_writer, sheet_name="1-raw", index=False)

    with pd.ExcelWriter(
        comparison_xlsx, mode="a", if_sheet_exists="replace", engine="openpyxl"
    ) as comparison_writer:
        result_df.to_excel(comparison_writer, sheet_name=f"{code}_raw", index=False)

def postprocess_kimi(extracted_xlsx, xlsx_path, comparison_xlsx, code):
    """Run the complete post-processing pipeline for Kimi extraction output.

    Executes `post_kimi`, `clean_and_normalize` and `compare_with_index`
    in that order.

    Args:
        extracted_xlsx: Path to Kimi output file.
        xlsx_path: Result Excel file for this method.
        comparison_xlsx: Comparison Excel file path.
        code: Extraction method code.
    """
    # All three stages share the same trailing arguments.
    workbook_args = (xlsx_path, comparison_xlsx, code)

    # Stage 1: parse the Kimi output into the "1-raw" sheet.
    post_kimi(extracted_xlsx, *workbook_args)
    # Stage 2: clean and normalize the raw sheet.
    clean_and_normalize(*workbook_args)
    # Stage 3: run the index comparison.
    compare_with_index(*workbook_args)

# Result data paths
# NOTE(review): the input file name carries a hard-coded date (2025-06-20);
# it presumably has to be updated whenever the Kimi export is regenerated.
KIMI_OUT = os.path.join(HOME, "output", "5-Kimi", 'Kimi_pub_2025-06-20.xlsx')
# Method code: embedded in the result file name and passed through the pipeline.
code = "Kimi"
# NOTE(review): comparison_dir, today and comparison_xlsx are not defined in
# this cell; they are expected to come from earlier cells.
xlsx_path = os.path.join(comparison_dir, f"5_{code}_{today}.xlsx")

# Run the full Kimi post-processing pipeline on the exported workbook.
postprocess_kimi(KIMI_OUT, xlsx_path, comparison_xlsx, code)

## 3.6 Generate Summary

In [None]:
"""Generate summary of all extraction results."""

# Output workbook for the final summary.
# NOTE(review): comparison_dir, today, dir_mark and comparison_xlsx are not
# defined in this cell; they are expected to come from earlier cells.
summary_xlsx = os.path.join(comparison_dir, f"summary_{today}{dir_mark}.xlsx")

# Read summary sheet
df = pd.read_excel(comparison_xlsx, sheet_name="summary")
# Both counts are taken before any rows are supplemented from the manual sheet.
original_rows = len(df)
material_filled = df["other_mat"].notna().sum()
print(f"Original data rows: {original_rows}")
print(f"'other_mat' column filled rows: {material_filled}")

# ========== Supplement index-DOI pairs from manual sheet ==========
# Check and add missing DOI column
if 'doi' not in df.columns:
    df['doi'] = ''

# Read manual sheet
# The whole supplementation step is best-effort: any exception raised inside
# the try block is printed and the step is skipped.
try:
    manual_df = pd.read_excel(comparison_xlsx, sheet_name="manual")
    
    # Extract index and DOI, remove duplicates
    if all(col in manual_df.columns for col in ['index', 'doi']):
        manual_pairs = manual_df[['index', 'doi']].drop_duplicates()
        
        # Create existing index-DOI pair set
        # Pairs are compared as strings, with missing values mapped to ''.
        existing_pairs = set(zip(
            df['index'].fillna('').astype(str),
            df['doi'].fillna('').astype(str)
        ))
        
        # Find index-DOI pairs that need to be added
        new_rows = []
        for _, row in manual_pairs.iterrows():
            idx_val = row['index']
            doi_val = row['doi']
            # Handle NaN values and convert to strings for comparison
            idx_str = str(idx_val) if pd.notna(idx_val) else ''
            doi_str = str(doi_val) if pd.notna(doi_val) else ''
            
            if (idx_str, doi_str) not in existing_pairs:
                # Create new row (all columns empty except index/doi)
                # The original (unconverted) index/doi values are stored.
                new_row = {col: '' for col in df.columns}
                new_row['index'] = idx_val
                new_row['doi'] = doi_val
                new_rows.append(new_row)
        
        # Add missing rows
        if new_rows:
            new_df = pd.DataFrame(new_rows)
            df = pd.concat([df, new_df], ignore_index=True)
            print(f"Added {len(new_rows)} rows from manual sheet")
        else:
            print("No index-DOI pairs need to be added")
    else:
        print("Manual sheet missing 'index' or 'doi' columns, skipping")
except Exception as e:
    # NOTE(review): this catches every exception, so unexpected failures in
    # the block above are reported only through this message.
    print(f"Error reading manual sheet: {str(e)}, skipping")
# ========== End of supplementation ==========

# Process data
# Select rows whose 'other_mat' is neither missing nor an empty string; the
# supplemented rows carry '' in 'other_mat' and are therefore not selected.
mask = df["other_mat"].notna() & (df["other_mat"] != "")
modified_rows = mask.sum()

print(f"Will modify {modified_rows} rows")

# For the selected rows, 'other_mat' replaces the material and the value is
# blanked; the helper column is then removed.
df.loc[mask, "material"] = df.loc[mask, "other_mat"]
df.loc[mask, "value"] = ""
df.drop(columns=["other_mat"], inplace=True)

# Sort by index and material
# Only the sort columns that actually exist in the frame are used.
sort_columns = ['index', 'material']
df = df.sort_values(by=[col for col in sort_columns if col in df.columns])

# Write the result to a fresh workbook with a single "summary" sheet.
with pd.ExcelWriter(summary_xlsx, engine="openpyxl", mode='w') as writer:
    df.to_excel(writer, sheet_name="summary", index=False)

print(f"Data successfully written to: {summary_xlsx}")
print(f"Final data rows: {len(df)}")
print(f"Modified rows: {modified_rows}")