# Nova Lite Fine-tuning Dataset Preparation

This notebook prepares the RVL-CDIP dataset for fine-tuning an Amazon Nova Lite model. We will:
1. Load the dataset
2. Sample a fixed number of images per label (`SAMPLES_PER_LABEL`; the target for a full run is 500)
3. Save the images in PNG format
4. Create `train.jsonl` and `validation.jsonl` files in the required Bedrock conversation format
5. Upload the images and both JSONL files to the S3 bucket

In [None]:
# Import necessary libraries
from datasets import load_dataset, Dataset
import numpy as np
import pandas as pd
import os
import json
import boto3
import uuid
from tqdm import tqdm  # Use standard tqdm instead of tqdm.notebook
import io
from PIL import Image
import concurrent.futures
import time
from functools import partial
import random
from dotenv import load_dotenv
load_dotenv()

# Seed both the stdlib and NumPy global random generators so that sampling and
# the train/validation shuffle start from the same state on every run.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)

In [None]:
# S3 destination: every artifact below is written under s3://<bucket_name>/<directory>/
bucket_name = "test-idp-finetuning-data-us-east-1"
directory = "rvl-cdip-sampled"

# Number of images drawn for each document class.
SAMPLES_PER_LABEL = 10

# One S3 client, shared by the upload cells and helpers below.
s3 = boto3.client(service_name='s3')

In [None]:
# Label id -> document-type name used as the assistant's answer in each training record.
# The names (including the "advertissement" spelling) must match the document-type
# table in the task prompt exactly.
# NOTE(review): assumes ids 0..15 follow the dataset's ClassLabel order — confirm against label_names.
label_mapping = dict(enumerate([
    "advertissement",
    "budget",
    "email",
    "file_folder",
    "form",
    "handwritten",
    "invoice",
    "letter",
    "memo",
    "news_article",
    "presentation",
    "questionnaire",
    "resume",
    "scientific_publication",
    "scientific_report",
    "specification",
]))

In [None]:
# Load the RVL-CDIP document-image dataset (all splits) by its dataset id.
dataset_id = "chainyo/rvl-cdip"
ds = load_dataset(dataset_id)
print(f"Dataset loaded: {ds}")

In [None]:
# Inspect the Python types of the image and label fields on the first training row.
first_train_row = ds['train'][0]
image_type, label_type = type(first_train_row['image']), type(first_train_row['label'])
print(f"Image type: {image_type}, Label type: {label_type}")

In [None]:
# Collect the distinct label ids present in the training split (sorted by np.unique).
unique_labels = np.unique(ds["train"]["label"])
label_count = len(unique_labels)
print(f"Number of unique labels: {label_count}")
print(f"Labels: {unique_labels}")

In [None]:
# Human-readable class names: taken from the label feature when it exposes `.names`,
# otherwise synthesized as "class_<id>" for each id in unique_labels.
label_feature = ds['train'].features['label']
if not hasattr(label_feature, 'names'):
    print("Label names not available")
    label_names = [f"class_{i}" for i in unique_labels]
else:
    label_names = label_feature.names
    print(f"Label names: {label_names}")

In [None]:
# Local pickle file that caches the sampled rows between notebook runs.
checkpoint_file = "sampled_data_checkpoint.pkl"

# Function to save sampled data to checkpoint file
def save_checkpoint(data):
    """Pickle ``data`` to ``checkpoint_file`` with images stored as PNG bytes.

    Each sample is shallow-copied into a plain dict. When it has an ``"image"``
    key, that image is PNG-encoded into a new ``"image_bytes"`` key and the
    ``"image"`` key is dropped, so the pickle holds only bytes and plain values.

    Args:
        data: Iterable of mapping-like samples (e.g. dataset rows).
    """
    import pickle

    print(f"Saving checkpoint to {checkpoint_file}...")

    def _encode(sample):
        # Work on a copy so the caller's samples keep their image objects.
        record = dict(sample)
        if "image" in record:
            png_buffer = io.BytesIO()
            record.pop("image").save(png_buffer, format="PNG")
            record["image_bytes"] = png_buffer.getvalue()
        return record

    payload = [_encode(sample) for sample in data]

    with open(checkpoint_file, "wb") as out_file:
        pickle.dump(payload, out_file)
    print(f"Checkpoint saved with {len(data)} samples")

# Function to load sampled data from checkpoint file
def load_checkpoint():
    """Load samples previously written by ``save_checkpoint``.

    Reads the pickle at ``checkpoint_file`` and turns every ``"image_bytes"``
    entry back into a PIL image stored under ``"image"``.

    Returns:
        The list of sample dicts, or ``None`` if reading or decoding raised. The
        error is printed instead of propagated so the caller can fall back to
        re-sampling.
    """
    print(f"Loading data from checkpoint file {checkpoint_file}...")
    try:
        import pickle

        with open(checkpoint_file, "rb") as in_file:
            stored_samples = pickle.load(in_file)

        restored = []
        for entry in stored_samples:
            if "image_bytes" in entry:
                # Swap the raw PNG bytes for a PIL image object.
                entry["image"] = Image.open(io.BytesIO(entry.pop("image_bytes")))
            restored.append(entry)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

    print(f"Successfully loaded {len(restored)} samples from checkpoint")
    return restored

# Function to process a single label
def process_label(label, samples_per_label=500):
    """Draw up to ``samples_per_label`` training rows whose label equals ``label``.

    Matching row positions are found from the label column alone. Iterating the
    split row by row would materialize every row, images included, once per label.

    Args:
        label: Label id to select (e.g. an element of ``unique_labels``).
        samples_per_label: Maximum number of rows to return. If the label has this
            many rows or fewer, all of them are returned in dataset order.

    Returns:
        ``(samples, message)``: the selected dataset rows and a human-readable
        summary line.
    """
    label_name = label_names[label]
    train_split = ds["train"]

    # Positions of all rows carrying this label, computed from the label column only.
    label_column = np.asarray(train_split["label"])
    indices = np.flatnonzero(label_column == label)

    if len(indices) <= samples_per_label:
        sampled_indices = indices
        message = f"Label {label} ({label_name}): Using all {len(indices)} samples"
    else:
        # Uses NumPy's global RNG (seeded at the top of the notebook).
        sampled_indices = np.random.choice(indices, samples_per_label, replace=False)
        message = f"Label {label} ({label_name}): Sampled {samples_per_label} out of {len(indices)} samples"

    # Convert numpy integers to plain ints before indexing the dataset.
    result_samples = [train_split[int(idx)] for idx in sampled_indices]
    return result_samples, message

# Reuse a previous sampling run when a readable checkpoint exists; otherwise sample afresh.
sampled_data = None
if os.path.exists(checkpoint_file):
    # load_checkpoint prints its own success/error message and returns None on failure.
    sampled_data = load_checkpoint()
    if sampled_data is None:
        print("Failed to load from checkpoint. Running sampling process...")
else:
    print("No checkpoint file found. Running sampling process...")

# If we couldn't load from checkpoint, run the sampling process
if sampled_data is None:
    samples_per_label = SAMPLES_PER_LABEL
    start_time = time.time()

    print("Starting parallel sampling of data...")

    # Up to one worker per CPU core, capped at 16. os.cpu_count() can return None,
    # which would make min() raise, so fall back to a single worker.
    max_workers = min(16, os.cpu_count() or 1)
    print(f"Using {max_workers} workers for parallel sampling")

    samples_by_label = {}
    failed_labels = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_label = {
            executor.submit(process_label, label, samples_per_label): label
            for label in unique_labels
        }
        for future in tqdm(
            concurrent.futures.as_completed(future_to_label),
            total=len(future_to_label),
            desc="Sampling labels",
        ):
            label = future_to_label[future]
            try:
                samples, message = future.result()
            except Exception as e:
                print(f"Error processing label {label}: {e}")
                failed_labels.append(label)
                continue
            samples_by_label[label] = samples
            print(message)

    # Concatenate per-label results in label order so sampled_data's order does not
    # depend on which thread finished first (the seeded shuffle later depends on order).
    # NOTE(review): process_label draws from NumPy's global RNG from several threads,
    # so which rows get picked can still vary between runs.
    sampled_data = [
        sample
        for label in sorted(samples_by_label)
        for sample in samples_by_label[label]
    ]

    end_time = time.time()
    print(f"Parallel sampling completed in {end_time - start_time:.2f} seconds")

    # Only checkpoint a complete sample set: a checkpoint missing some labels would be
    # reloaded on every later run, and those labels would never be resampled.
    if failed_labels:
        failed = sorted(int(lbl) for lbl in failed_labels)
        print(f"Not saving checkpoint because sampling failed for labels: {failed}")
    else:
        save_checkpoint(sampled_data)

print(f"Total sampled data: {len(sampled_data)}")

In [None]:
# Local staging directory for PNGs before upload (no error if it already exists).
os.makedirs(name="temp_images", exist_ok=True)

In [None]:
# Function to save image and upload to S3
def save_and_upload_image(image, label, index):
    """Write ``image`` as a PNG under ``temp_images/`` and upload it to S3.

    The filename combines the label name, ``index`` and a random UUID; spaces are
    replaced with underscores. The local file is removed whether or not the save
    and upload succeed.

    Args:
        image: Object with a PIL-style ``save(path, format=...)`` method.
        label: Index into ``label_names``.
        index: Sample position, embedded in the filename.

    Returns:
        The ``s3://`` URI of the uploaded object.

    Raises:
        Whatever ``image.save`` or ``s3.upload_file`` raises.
    """
    filename = f"{label_names[label]}_{index}_{uuid.uuid4()}.png".replace(' ', '_')
    local_path = os.path.join("temp_images", filename)
    s3_path = f"{directory}/images/{filename}"

    try:
        image.save(local_path, format="PNG")
        s3.upload_file(local_path, bucket_name, s3_path)
    finally:
        # Always clean up the staged file so failed uploads don't accumulate in temp_images.
        if os.path.exists(local_path):
            os.remove(local_path)

    return f"s3://{bucket_name}/{s3_path}"

In [None]:
# Define the system prompt and task prompt for document classification.
# Both strings are embedded verbatim in every training record built by process_sample:
# system_prompt as the record's "system" text, task_prompt_template as the user text.
system_prompt = "You are a document classification expert who can analyze and identify document types from images. Your task is to determine the document type based on its visual appearance, layout, and content, using the provided document type definitions. Your output must be valid JSON according to the requested format."

# Despite its name, this prompt has no placeholders and is used as-is.
# The "Document Type" column must stay in sync with the values of label_mapping,
# since those names are the assistant answers the model is trained to emit.
task_prompt_template = """The <document-types> XML tags contain a markdown table of known document types for detection.
<document-types>
| Document Type | Description |
|---------------|-------------|
| advertissement | Marketing or promotional material with graphics, product information, and calls to action |
| budget | Financial document with numerical data, calculations, and monetary figures organized in tables or lists |
| email | Electronic correspondence with header information, sender/recipient details, and message body |
| file_folder | Document with tabs, labels, or folder-like structure used for organizing other documents |
| form | Structured document with fields to be filled in, checkboxes, or data collection sections |
| handwritten | Document containing primarily handwritten text rather than typed or printed content |
| invoice | Billing document with itemized list of goods/services, costs, payment terms, and company information |
| letter | Formal correspondence with letterhead, date, recipient address, salutation, and signature |
| memo | Internal business communication with brief, direct message and minimal formatting |
| news_article | Journalistic content with headlines, columns, and reporting on events or topics |
| presentation | Slides or visual aids with bullet points, graphics, and concise information for display |
| questionnaire | Document with series of questions designed to collect information from respondents |
| resume | Professional summary of a person's work experience, skills, and qualifications |
| scientific_publication | Academic paper with abstract, methodology, results, and references in formal structure |
| scientific_report | Technical document presenting research findings, data, and analysis in structured format |
| specification | Detailed technical document outlining requirements, standards, or procedures |
</document-types>

CRITICAL: You must ONLY use document types explicitly listed in the <document-types> section. Do not create, invent, or use any document type not found in this list. If a document doesn't clearly match any listed type, assign it to the most similar listed type.

Follow these steps when classifying the document image:
1. Examine the document image carefully, noting its layout, content, and visual characteristics.
2. Identify visual cues that indicate the document type (e.g., tables for budgets, letterhead for letters).
3. Match the document with one of the document types from the provided list ONLY.
4. Before finalizing, verify that your selected document type exactly matches one from the <document-types> list.

Return your response as valid JSON according to this format:
```json
{"type": "document_type_name"}
```
where document_type_name is one of the document types listed in the <document-types> section."""

In [None]:
# Function to process a single sample
def process_sample(sample, index, account_id=None):
    """Upload one sample's image to S3 and build its Bedrock fine-tuning record.

    Args:
        sample: Row with at least ``"image"`` (passed to ``save_and_upload_image``)
            and ``"label"`` (a key of ``label_mapping``).
        index: Sample position, forwarded to ``save_and_upload_image`` for the filename.
        account_id: AWS account id used as the S3 ``bucketOwner``. When falsy, it is
            looked up via STS.

    Returns:
        ``(updated_sample, record)``. ``updated_sample`` is a shallow copy of
        ``sample`` with an added ``"s3_uri"`` key. ``record`` is a
        ``bedrock-conversation-2024`` training record whose assistant turn is a
        fenced JSON answer ``{"type": "<document type>"}``.
    """
    image = sample["image"]
    label = sample["label"]

    # Save and upload the image
    s3_uri = save_and_upload_image(image, label, index)

    # Copy so the caller's sample is not mutated.
    updated_sample = dict(sample)
    updated_sample["s3_uri"] = s3_uri

    mapped_label = label_mapping[label]

    # Only call STS when the caller did not supply the account id.
    bucket_owner = account_id if account_id else boto3.client('sts').get_caller_identity().get('Account')

    # The target must follow the format the task prompt requests: a ```json fence
    # around {"type": "<name>"}. json.dumps quotes the name so the answer is valid JSON,
    # and the string is built without any stray indentation or trailing whitespace.
    answer_json = json.dumps({"type": mapped_label})
    assistant_text = f"```json\n{answer_json}\n```"

    record = {
        "schemaVersion": "bedrock-conversation-2024",
        "system": [{
            "text": system_prompt
        }],
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "text": task_prompt_template
                    },
                    {
                        "image": {
                            "format": "png",
                            "source": {
                                "s3Location": {
                                    "uri": s3_uri,
                                    "bucketOwner": bucket_owner
                                }
                            }
                        }
                    }
                ]
            },
            {
                "role": "assistant",
                "content": [{
                    "text": assistant_text
                }]
            }
        ]
    }

    return updated_sample, record

# Process the sampled data in parallel
updated_samples = []
jsonl_records = []

# Get AWS account ID once to avoid repeated calls
account_id = boto3.client('sts').get_caller_identity().get('Account')

# Up to two workers per CPU core, capped at 32. os.cpu_count() can return None,
# which would make the multiplication raise, so fall back to a single core.
max_workers = min(32, (os.cpu_count() or 1) * 2)
print(f"Using {max_workers} workers for parallel processing")

start_time = time.time()

# Bind the account id so each task does not have to look it up.
process_sample_with_account = partial(process_sample, account_id=account_id)

# Results keyed by sample position; failed samples are reported and skipped.
results_by_idx = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    future_to_idx = {
        executor.submit(process_sample_with_account, sample, i): i
        for i, sample in enumerate(sampled_data)
    }

    for future in tqdm(concurrent.futures.as_completed(future_to_idx), total=len(future_to_idx)):
        idx = future_to_idx[future]
        try:
            results_by_idx[idx] = future.result()
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")

# Restore the original sample order (not thread completion order) so the seeded
# shuffle and train/validation split later in the notebook are reproducible.
for idx in sorted(results_by_idx):
    updated_sample, record = results_by_idx[idx]
    updated_samples.append(updated_sample)
    jsonl_records.append(record)

end_time = time.time()
print(f"Parallel processing completed in {end_time - start_time:.2f} seconds")

In [None]:
# Wrap the processed samples (original fields plus "s3_uri") in a Hugging Face Dataset.
updated_ds = Dataset.from_list(updated_samples)
n_updated = len(updated_ds)
print(f"Updated dataset created with {n_updated} samples")
print(f"Dataset features: {updated_ds.features}")

In [None]:
# Persist the dataset locally; skipped when no sample was processed successfully.
updated_ds_path = "rvl_cdip_with_s3_uris"
if len(updated_ds) == 0:
    print("No samples to save locally")
else:
    updated_ds.save_to_disk(updated_ds_path)
    print(f"Updated dataset saved locally to {updated_ds_path}")

In [None]:
# Upload the updated dataset's metadata to S3 as JSON.
# The image column is excluded: images are not JSON-serializable and each one is
# already in S3 (referenced by the "s3_uri" column).
updated_ds_json_path = "updated_dataset.json"

# Drop the image column at the dataset level so iteration does not load every
# image only to discard it.
metadata_ds = (
    updated_ds.remove_columns("image")
    if "image" in updated_ds.column_names
    else updated_ds
)
json_data = [dict(sample) for sample in metadata_ds]

with open(updated_ds_json_path, "w") as f:
    json.dump(json_data, f)

# Upload the JSON file to S3
s3.upload_file(updated_ds_json_path, bucket_name, f"{directory}/updated_dataset.json")
print(f"Updated dataset uploaded to S3: s3://{bucket_name}/{directory}/updated_dataset.json")

In [None]:
# Shuffle the records in place (NumPy global RNG), then keep the first 90% for
# training and the remaining 10% for validation.
np.random.shuffle(jsonl_records)
split_idx = int(0.9 * len(jsonl_records))
train_records, validation_records = jsonl_records[:split_idx], jsonl_records[split_idx:]

for split_name, split_records in (("Training", train_records), ("Validation", validation_records)):
    print(f"{split_name} records: {len(split_records)}")

In [None]:
# Write each split as JSON Lines: one serialized record per line.
train_jsonl_path = "train.jsonl"
validation_jsonl_path = "validation.jsonl"

for out_path, split_records in ((train_jsonl_path, train_records), (validation_jsonl_path, validation_records)):
    with open(out_path, "w") as out_file:
        out_file.writelines(json.dumps(rec) + "\n" for rec in split_records)

In [None]:
# Upload both JSONL splits under the dataset prefix, then report their S3 locations.
jsonl_uploads = [
    ("Train", train_jsonl_path, f"{directory}/train.jsonl"),
    ("Validation", validation_jsonl_path, f"{directory}/validation.jsonl"),
]
for _, jsonl_local, jsonl_key in jsonl_uploads:
    s3.upload_file(jsonl_local, bucket_name, jsonl_key)

for split_title, _, jsonl_key in jsonl_uploads:
    print(f"{split_title} JSONL uploaded to s3://{bucket_name}/{jsonl_key}")

# Clean up local files

In [None]:
import shutil

# Delete the local JSONL/JSON artifacts (already uploaded), then the image staging dir.
for local_artifact in (train_jsonl_path, validation_jsonl_path, updated_ds_json_path):
    os.remove(local_artifact)
shutil.rmtree("temp_images")

print("Local files cleaned up")

# Summary

In [None]:
# Final report of what was produced and where it lives.
summary_lines = [
    "Dataset preparation complete!",
    f"Total samples: {len(jsonl_records)}",
    f"Training samples: {len(train_records)}",
    f"Validation samples: {len(validation_records)}",
    f"Data uploaded to s3://{bucket_name}/{directory}/",
    f"Updated dataset with S3 URIs saved locally to {updated_ds_path}",
    f"Updated dataset with S3 URIs uploaded to s3://{bucket_name}/{directory}/updated_dataset.json",
]
print("\n".join(summary_lines))