In [21]:
"""
Table Classifier Training Dataset Builder
Creates a dataset for SCT vs non-SCT table classification.
"""
import json, random
from pathlib import Path
from collections import Counter
from tqdm.auto import tqdm
from datasets import Dataset, Image, ClassLabel

# Config
SEED = 42424242
SAMPLE_SIZE_NEGATIVE = 3000
BASE_PATH = Path("/home/pdipasquale/MIIA/stuff")
OUTPUT_PATH = BASE_PATH / "output"
ALL_TABLES_PATH = BASE_PATH / "all_tables.json"
HF_REPO = "pierjoe/sec-table-classifier"

random.seed(SEED)

In [22]:
# Step 1: Categorize documents
sct_docs, non_sct_docs, funds, multi_sct_docs = [], [], [], []

for doc_dir in tqdm(list(OUTPUT_PATH.iterdir()), desc="Scanning docs"):
    if not doc_dir.is_dir(): continue
    metadata_path = doc_dir / "metadata.json"
    if not metadata_path.exists(): continue
    
    with open(metadata_path) as f:
        meta = json.load(f)
    
    if meta.get("sic") in ("NULL", None):
        funds.append(doc_dir.name)
        continue
    
    classification_path = doc_dir / "classification_results.json"
    if classification_path.exists():
        with open(classification_path) as f:
            classification = json.load(f)
        num_sct = classification.get("total_tables_found", 0)
        if num_sct == 1:
            sct_docs.append({"doc_id": doc_dir.name, "meta": meta, "classification": classification})
        else:
            multi_sct_docs.append(doc_dir.name)
    elif (doc_dir / "no_sct_found.json").exists():
        non_sct_docs.append(doc_dir.name)

print(f"Funds: {len(funds)} | SCT (1 table): {len(sct_docs)} | SCT (multi): {len(multi_sct_docs)} | No SCT: {len(non_sct_docs)}")

Scanning docs: 100%|██████████| 2500/2500 [00:01<00:00, 1750.45it/s]

Funds: 463 | SCT (1 table): 1549 | SCT (multi): 363 | No SCT: 125





In [23]:
# Step 2: Build positive samples (SCT tables)
positive_samples = []
for doc in tqdm(sct_docs, desc="Positive samples"):
    doc_id = doc["doc_id"]
    for table_entry in doc["classification"].get("tables", []):
        table_data = table_entry["table"]
        img_path = OUTPUT_PATH / doc_id / doc_id / "vlm" / table_data.get("img_path", "")
        table_body = table_data.get("table_body", "")
        if img_path.exists() and table_body:
            positive_samples.append({
                "doc_id": doc_id, "image_path": str(img_path), "table_html": table_body,
                "label": 1, "year": doc["meta"].get("year"), "company": doc["meta"].get("company")
            })
print(f"Positive samples: {len(positive_samples)}")

Positive samples: 100%|██████████| 1549/1549 [00:01<00:00, 907.13it/s]

Positive samples: 1549





In [24]:
# Step 3: Load all tables and build negative samples
with open(ALL_TABLES_PATH) as f:
    all_tables = json.load(f)

# Build SCT table keys to exclude
sct_table_keys = {(s["doc_id"], "/".join(s["image_path"].split("/")[-2:])) for s in positive_samples}
sct_doc_ids = {doc["doc_id"] for doc in sct_docs}

# Get negative samples - tables from SCT docs that are NOT the SCT table
all_negative_candidates = []
for table in tqdm(all_tables, desc="Negative candidates"):
    doc_id = table.get("source_doc")
    if doc_id not in sct_doc_ids: continue
    
    img_rel_path = table.get("img_path", "")
    if (doc_id, img_rel_path) in sct_table_keys: continue
    
    img_path = OUTPUT_PATH / doc_id / doc_id / "vlm" / img_rel_path
    table_body = table.get("table_body", "")
    if not img_path.exists() or not table_body: continue
    
    meta_path = OUTPUT_PATH / doc_id / "metadata.json"
    meta = json.load(open(meta_path)) if meta_path.exists() else {}
    
    all_negative_candidates.append({
        "doc_id": doc_id, "image_path": str(img_path), "table_html": table_body,
        "label": 0, "year": meta.get("year"), "company": meta.get("company")
    })

negative_samples = random.sample(all_negative_candidates, min(SAMPLE_SIZE_NEGATIVE, len(all_negative_candidates)))
print(f"Negative samples: {len(negative_samples)} (from {len(all_negative_candidates)} candidates)")

Negative candidates: 100%|██████████| 52564/52564 [00:15<00:00, 3474.10it/s]

Negative samples: 3000 (from 25414 candidates)





In [25]:
# Step 4: Create HuggingFace dataset
all_samples = positive_samples + negative_samples
random.shuffle(all_samples)
print(f"Total: {len(all_samples)} | Positive: {len(positive_samples)} | Negative: {len(negative_samples)}")

dataset = Dataset.from_dict({
    "image": [s["image_path"] for s in all_samples],
    "text": [s["table_html"] for s in all_samples],
    "label": [s["label"] for s in all_samples],
    "doc_id": [s["doc_id"] for s in all_samples],
    "year": [s["year"] for s in all_samples],
    "company": [s["company"] for s in all_samples],
})
dataset = dataset.cast_column("image", Image())
dataset = dataset.cast_column("label", ClassLabel(names=["non_sct", "sct"]))
print(dataset)

Total: 4549 | Positive: 1549 | Negative: 3000


Casting the dataset: 100%|██████████| 4549/4549 [00:00<00:00, 958066.23 examples/s]

Dataset({
    features: ['image', 'text', 'label', 'doc_id', 'year', 'company'],
    num_rows: 4549
})





In [26]:
# Step 5: Split and push to HuggingFace
dataset_split = dataset.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="label")
print(f"Train: {len(dataset_split['train'])} | Test: {len(dataset_split['test'])}")

dataset_split.push_to_hub(HF_REPO, private=False)
print(f"✓ Pushed to: https://huggingface.co/datasets/{HF_REPO}")

Train: 3639 | Test: 910


Map: 100%|██████████| 3639/3639 [00:02<00:00, 1675.62 examples/s]ards/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00,  9.06ba/s]
Processing Files (1 / 1): 100%|██████████|  303MB /  303MB, 50.4MB/s  
New Data Upload: 100%|██████████|  298MB /  298MB, 49.7MB/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:12<00:00, 12.31s/ shards]
Map: 100%|██████████| 910/910 [00:00<00:00, 1933.73 examples/s]shards/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  7.86ba/s]
Processing Files (1 / 1): 100%|██████████| 77.6MB / 77.6MB, 19.1MB/s  
New Data Upload: 100%|██████████| 76.3MB / 76.3MB, 19.1MB/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.15s/ shards]


✓ Pushed to: https://huggingface.co/datasets/pierjoe/sec-table-classifier


# Training loop

In [28]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image as PILImage
from tqdm.auto import tqdm

MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
BATCH_SIZE = 2
EPOCHS = 3
LR = 1e-5
DEVICE = "cuda"

print(f"Using GPU: {torch.cuda.get_device_name(0)}")

Using GPU: NVIDIA H100 80GB HBM3


In [30]:
# Load base model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME, 
    torch_dtype=torch.bfloat16,
    device_map="cuda:0"
)
print(f"Model loaded: {MODEL_NAME}")

`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 2 files: 100%|██████████| 2/2 [00:08<00:00,  4.27s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.30s/it]


Model loaded: Qwen/Qwen2.5-VL-3B-Instruct


In [31]:
# Classifier model with classification head on top of VLM
class VLMClassifier(nn.Module):
    def __init__(self, base_model, hidden_size=None, num_labels=2):
        super().__init__()
        self.base_model = base_model
        # Get hidden size from config
        hidden_size = hidden_size or base_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels)
        )
        
    def forward(self, input_ids, attention_mask, pixel_values, image_grid_thw, labels=None):
        # Get model outputs (last hidden state)
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=True,
            return_dict=True
        )
        # Use last hidden state, take mean over sequence
        hidden_states = outputs.hidden_states[-1]  # (batch, seq, hidden)
        pooled = hidden_states.mean(dim=1)  # (batch, hidden)
        logits = self.classifier(pooled.float())  # (batch, num_labels)
        
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        
        return {"loss": loss, "logits": logits}

# Freeze base model, only train classifier head
for param in base_model.parameters():
    param.requires_grad = False

model = VLMClassifier(base_model, num_labels=2).to(DEVICE)
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Trainable params: 1,050,114


In [32]:
# Custom collate function for VLM
def collate_fn(batch):
    images = []
    labels = []
    
    for item in batch:
        img = item["image"]
        if isinstance(img, dict) and "path" in img:
            img = PILImage.open(img["path"]).convert("RGB")
        images.append(img)
        labels.append(item["label"])
    
    # Create simple prompt for classification
    messages_batch = []
    for img in images:
        messages_batch.append([{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": "Classify this table."}
            ]
        }])
    
    # Process with Qwen processor
    texts = [processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages_batch]
    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        return_tensors="pt"
    )
    
    inputs["labels"] = torch.tensor(labels)
    return inputs

# Create dataloaders
train_loader = DataLoader(dataset_split["train"], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset_split["test"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
print(f"Train batches: {len(train_loader)} | Test batches: {len(test_loader)}")

Train batches: 1820 | Test batches: 455


In [None]:
# Training loop
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        # Move to device
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        pixel_values = batch["pixel_values"].to(DEVICE, dtype=torch.bfloat16)
        image_grid_thw = batch["image_grid_thw"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, pixel_values, image_grid_thw, labels)
        loss = outputs["loss"]
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs["logits"].argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        
        pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{correct/total:.2%}")
    
    print(f"Epoch {epoch+1} - Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.2%}")

Epoch 1/3:   3%|▎         | 55/1820 [00:12<06:56,  4.24it/s, acc=64.55%, loss=0.4473]

In [None]:
# Evaluation
model.eval()
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        pixel_values = batch["pixel_values"].to(DEVICE, dtype=torch.bfloat16)
        image_grid_thw = batch["image_grid_thw"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        
        outputs = model(input_ids, attention_mask, pixel_values, image_grid_thw)
        preds = outputs["logits"].argmax(dim=-1)
        
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

print(f"\nTest Accuracy: {correct/total:.2%}")

# Confusion matrix
from collections import Counter
print(f"Predictions: {Counter(all_preds)}")
print(f"Labels: {Counter(all_labels)}")

In [None]:
# Save the classifier head
SAVE_PATH = BASE_PATH / "models" / "sct_classifier"
SAVE_PATH.mkdir(parents=True, exist_ok=True)

torch.save(model.classifier.state_dict(), SAVE_PATH / "classifier_head.pt")
print(f"✓ Classifier saved to: {SAVE_PATH}")