In [None]:
import os
import random
import shutil
from pathlib import Path

import arxiv
import pandas as pd
import pymupdf
from tqdm.contrib.concurrent import thread_map

In [None]:
# Where downloaded arXiv PDFs are stored, one sub-directory per category.
PAPERS_OUTPUT_PATH = Path("papers")

# Target number of PDFs to keep per category.
PAPERS_PER_CATEGORY = 20

# Where first-page renders (PNG) are stored, mirroring the per-category layout.
IMAGES_OUTPUT_PATH = Path("images")

# Root of the train/test image dataset produced at the end of the notebook.
DATASET_OUTPUT_PATH = Path("dataset")

# Seed the global RNG used by `random.shuffle` when picking papers, so the
# selection is repeatable for a given set of arXiv search results.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

In [None]:
with open("categories.txt", "r", encoding="utf-8") as f:
    # One arXiv category per line. Strip whitespace (including a Windows "\r"),
    # skip blank lines so an empty/trailing line never becomes a bogus `cat:`
    # query, and drop duplicates (keeping first-seen order) so the per-category
    # count checks further down stay consistent.
    categories = list(dict.fromkeys(line.strip() for line in f if line.strip()))

print(f"Found {len(categories)} categories: {categories}")

In [None]:
def download_paper(args):
    """Download a single arXiv paper's PDF into its category directory.

    Args:
        args: A ``(paper, category)`` pair. It is packed into one tuple so the
            function can be mapped directly over a list by ``thread_map``.

    Returns:
        The value returned by ``paper.download_pdf`` (presumably the saved
        file's path), or ``None`` if anything raised. Failures are reported
        with a print and deliberately not re-raised, so one bad paper does not
        abort the whole batch.
    """
    paper, category = args
    try:
        saved = paper.download_pdf(dirpath=PAPERS_OUTPUT_PATH / category)
    except Exception as exc:
        print(f"Error downloading paper {paper.title}: {exc}")
        saved = None
    return saved


def process_category(category):
    """Top up ``PAPERS_OUTPUT_PATH/<category>`` to ``PAPERS_PER_CATEGORY`` PDFs.

    The function works in four steps:

    - Search arXiv for papers in ``category``.
    - Drop papers that already appear to be on disk.
    - Shuffle the remaining candidates.
    - Download candidates until enough have succeeded or none are left.

    Args:
        category: arXiv category name, used as the ``cat:`` search term and as
            the sub-directory name.
    """
    print("Downloading papers for", category)

    # Create category directory
    category_dir = PAPERS_OUTPUT_PATH / category
    category_dir.mkdir(parents=True, exist_ok=True)

    # Count the papers already downloaded (kept by name for the dedup check below)
    existing_names = [pdf.name for pdf in category_dir.glob("*.pdf")]
    existing_files = len(existing_names)

    # Calculate number of papers to download
    papers_needed = PAPERS_PER_CATEGORY - existing_files

    # Skip if we already have enough papers
    if papers_needed <= 0:
        print(f"Skipping {category} as we already have {existing_files} papers")
        return

    arxiv_client = arxiv.Client()

    search = arxiv.Search(
        query=f"cat:{category}",
        max_results=max(100, papers_needed * 2),
        sort_by=arxiv.SortCriterion.Relevance,
    )

    # Re-downloading a paper we already have only overwrites the same file, so
    # on a resumed run the category could never reach its target. Skip papers
    # whose file is already present.
    # NOTE(review): relies on the arxiv library's default download name, which
    # starts with "<short_id>." ("/" in old-style IDs replaced by "_"). If that
    # scheme changes, nothing matches, and this degrades to the old behaviour.
    def already_downloaded(paper):
        prefix = paper.get_short_id().replace("/", "_") + "."
        return any(name.startswith(prefix) for name in existing_names)

    candidates = [
        paper for paper in arxiv_client.results(search) if not already_downloaded(paper)
    ]

    # Shuffle so we don't always take the top-relevance papers
    random.shuffle(candidates)

    # download_paper returns None on failure. Keep pulling fresh candidates
    # until the target is met or there is nothing left to try.
    downloaded = 0
    next_index = 0
    while downloaded < papers_needed and next_index < len(candidates):
        batch = candidates[next_index : next_index + papers_needed - downloaded]
        next_index += len(batch)

        download_results = thread_map(
            download_paper,
            [(paper, category) for paper in batch],
            max_workers=1,
            chunksize=1,
            desc=f"Downloading papers for {category}",
        )
        downloaded += sum(result is not None for result in download_results)

    print(f"Downloaded {downloaded} papers for {category}")
    if downloaded < papers_needed:
        print(f"Warning: {category} is still {papers_needed - downloaded} papers short")


for category in categories:
    process_category(category)

In [None]:
# Each category must hold exactly PAPERS_PER_CATEGORY PDFs. Checking per
# category (not just the total) catches a surplus in one category masking a
# shortfall in another, and the message names the offending categories.
pdf_counts = {
    category: len(list((PAPERS_OUTPUT_PATH / category).glob("*.pdf")))
    for category in categories
}
mismatched_pdf_counts = {
    category: count
    for category, count in pdf_counts.items()
    if count != PAPERS_PER_CATEGORY
}
assert not mismatched_pdf_counts, (
    f"Expected {PAPERS_PER_CATEGORY} PDFs per category, got {mismatched_pdf_counts}; "
    "re-run the download cell."
)

# Also guard against stray PDFs outside the listed category directories.
total_pdfs = len(list(PAPERS_OUTPUT_PATH.glob("**/*.pdf")))
assert total_pdfs == PAPERS_PER_CATEGORY * len(categories), (
    f"Found {total_pdfs} PDFs in total, expected {PAPERS_PER_CATEGORY * len(categories)}"
)

In [None]:
def process_pdf(pdf):
    """Render the first page of ``pdf`` to ``IMAGES_OUTPUT_PATH/<parent>/<stem>.png``.

    ``<parent>`` is the name of the PDF's parent directory, i.e. its category
    folder. If the PNG already exists, nothing is done, so re-runs are cheap.

    A PDF that PyMuPDF cannot open, or that has no pages, is deleted. Re-running
    the download step can then fetch a replacement.

    Args:
        pdf: ``Path`` of a downloaded PDF.
    """
    image_output_dir = IMAGES_OUTPUT_PATH / pdf.parent.name
    image_output_dir.mkdir(parents=True, exist_ok=True)
    image_output_path = image_output_dir / f"{pdf.stem}.png"

    if image_output_path.exists():
        return

    try:
        doc = pymupdf.open(pdf)
    except pymupdf.FileDataError:
        doc = None

    if doc is not None:
        # The context manager closes the document (and its file handle) even if
        # rendering or saving raises; the original code never closed it.
        with doc:
            if doc.page_count > 0:
                page = doc.load_page(0)
                pixmap = page.get_pixmap(dpi=300, colorspace=pymupdf.csRGB)
                pixmap.save(image_output_path)
                return

    # The PDF is unreadable, or readable but page-less (load_page(0) would raise).
    # By this point the document is closed, so deleting it also works on Windows.
    print(f"Deleting corrupted PDF: {pdf}")
    os.remove(pdf)


def process_pdfs():
    """Render the first page of every PDF under ``PAPERS_OUTPUT_PATH`` to a PNG.

    The per-file work is done by ``process_pdf``. ``thread_map`` is used mainly
    for its progress bar.
    """
    pdfs = sorted(PAPERS_OUTPUT_PATH.glob("**/*.pdf"))

    print(f"Found {len(pdfs)} PDFs")

    # PyMuPDF is documented as not thread-safe: calling it from several threads
    # can produce wrong results or crash the interpreter. Render one PDF at a
    # time; process-based parallelism is unreliable for notebook-defined
    # functions under the "spawn" start method.
    thread_map(process_pdf, pdfs, max_workers=1, chunksize=1, desc="Processing PDFs")


process_pdfs()

In [None]:
# Every category should have one rendered PNG per PDF. Checking per category
# pinpoints where images are missing (e.g. PDFs deleted as corrupted while
# rendering).
image_counts = {
    category: len(list((IMAGES_OUTPUT_PATH / category).glob("*.png")))
    for category in categories
}
mismatched_image_counts = {
    category: count
    for category, count in image_counts.items()
    if count != PAPERS_PER_CATEGORY
}
assert not mismatched_image_counts, (
    f"Expected {PAPERS_PER_CATEGORY} images per category, got {mismatched_image_counts}. "
    "Corrupted PDFs are deleted during rendering; re-run the download and rendering cells."
)

# Also guard against stray PNGs outside the listed category directories.
total_images = len(list(IMAGES_OUTPUT_PATH.glob("**/*.png")))
assert total_images == PAPERS_PER_CATEGORY * len(categories), (
    f"Found {total_images} images in total, expected {PAPERS_PER_CATEGORY * len(categories)}"
)

In [None]:
def generate_dataset(train_per_category=5, test_per_category=5):
    """Copy a train/test split of the rendered images and write ``labels.csv``.

    For each category directory under ``IMAGES_OUTPUT_PATH`` (in sorted order),
    the images are sorted by file name. The first ``train_per_category`` go to
    ``DATASET_OUTPUT_PATH/train`` and the next ``test_per_category`` go to
    ``DATASET_OUTPUT_PATH/test``.

    ``labels.csv`` is written to the current working directory. It has the
    columns ``document`` (path of the copied image), ``label`` (category) and
    ``is_train`` (1/0).

    Args:
        train_per_category: Number of training images taken per category.
        test_per_category: Number of test images taken per category.
    """
    train_dir = DATASET_OUTPUT_PATH / "train"
    test_dir = DATASET_OUTPUT_PATH / "test"

    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    csv_data = []
    # A paper cross-listed in several categories renders to the same file name
    # in each. Since train/ and test/ are flat, reusing the name would overwrite
    # the earlier copy and could put the same image in both splits. So each
    # name is used at most once.
    used_names = set()

    # sorted() makes category processing order (and thus tie-breaking for
    # cross-listed papers) deterministic across filesystems.
    for category_dir in sorted(IMAGES_OUTPUT_PATH.glob("*")):
        if not category_dir.is_dir():
            continue

        category = category_dir.name
        image_files = [
            img_path
            for img_path in sorted(category_dir.glob("*.png"))
            if img_path.name not in used_names
        ]
        selected = image_files[: train_per_category + test_per_category]
        used_names.update(img_path.name for img_path in selected)

        # Train rows come first, then test rows, per category.
        for index, img_path in enumerate(selected):
            is_train = index < train_per_category
            dest = (train_dir if is_train else test_dir) / img_path.name
            shutil.copy2(img_path, dest)
            csv_data.append(
                {"document": str(dest), "label": category, "is_train": int(is_train)}
            )

    # Write the CSV once, after all categories are processed. Explicit columns
    # keep the header even when no images were found.
    df = pd.DataFrame(csv_data, columns=["document", "label", "is_train"])
    df.to_csv("labels.csv", index=False)


generate_dataset()

## Optional: Compress Images

- Install [`pngquant`](https://pngquant.org/).
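- Run the following commands from this notebook's directory to compress the dataset images in place (`--force --ext=.png` overwrites each original PNG):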
    ```bash
    cd dataset
    pngquant --quality=65-80 --force --ext=.png **/*.png
    ```

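## Optional: Build a tarball

Run this from the parent of the `data/` directory. The paths below assume this notebook and its outputs (`dataset/` and `labels.csv`) live in `data/`: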

```bash
tar -czf multimodal-vision-finetuning.tar.gz data/dataset/ data/labels.csv
```