# W2 Form OCR with Amazon Nova Lite - Data Preparation for Fine-tuning

## Introduction

This notebook demonstrates how to prepare data for fine-tuning Amazon Nova Lite models for OCR tasks, specifically focusing on extracting structured data from scanned W2 tax forms.

OCR on scanned tax documents is demanding for three reasons. Numerical values must be extracted exactly. Each value must be attributed to the correct form field. The structure of the information must be preserved, for example which wages belong to which state. Fine-tuning lets the model specialize in this high-precision extraction task.

In this notebook, we'll:

- Process a dataset of scanned W2 tax form images
- Upload the images to Amazon S3 for processing
- Create prompts optimized for OCR extraction of structured tax data
- Format the dataset for an Amazon Nova Lite fine-tuning job
- Prepare the training, validation, and test datasets for fine-tuning

## Prerequisites

Before starting, ensure you have:

- An AWS account with Amazon Bedrock access to the Amazon Nova Lite model
- Appropriate IAM permissions for Bedrock and S3
- A working SageMaker environment with the necessary libraries

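The identity that runs this notebook (for example, your SageMaker execution role) needs at least the following permissions. The notebook also creates the S3 bucket, calls the base model through the Converse API, and creates an IAM role for Bedrock. For those steps, that identity additionally needs `s3:CreateBucket`, `bedrock:InvokeModel`, `iam:CreateRole`, `iam:CreatePolicy`, and `iam:AttachRolePolicy`. It also needs `iam:PassRole` on the new role when you start the customization job: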

```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            "Resource": [
                "arn:aws:s3:::YOUR_BUCKET_NAME",
                "arn:aws:s3:::YOUR_BUCKET_NAME/*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "bedrock:CreateModelCustomizationJob",
                "bedrock:GetModelCustomizationJob",
                "bedrock:ListModelCustomizationJobs",
                "bedrock:StopModelCustomizationJob"
            ],
            "Resource": "arn:aws:bedrock:us-east-1:YOUR_ACCOUNT_ID:model-customization-job/*"
        }
    ]
}
```

## Environment Setup

First, let's install and import the necessary libraries for working with OCR data and AWS services:

In [None]:
# Install required libraries
%pip install --upgrade pip
%pip install boto3 datasets pillow tqdm ipywidgets deepdiff --upgrade --quiet

In [None]:
# Restart the kernel so the updated packages take effect
# (works in classic Jupyter Notebook as well as JupyterLab / SageMaker Studio)
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

In [None]:
import boto3
import os
import json
import time
import shutil
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
import io
import uuid
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Set AWS region
region = "us-east-1"

# Create AWS clients
session = boto3.session.Session(region_name=region)
s3_client = session.client('s3')
sts_client = session.client('sts')
bedrock = session.client(service_name="bedrock")

# Get account ID
account_id = sts_client.get_caller_identity()["Account"]

# Generate bucket name with account ID for uniqueness
bucket_name = f"nova-vision-ft-{account_id}-{region}"

print(f"Account ID: {account_id}")
print(f"Bucket name: {bucket_name}")

## Create S3 Storage for W2 Form Images

Let's create an S3 bucket to store our scanned W2 form images and processed OCR data:

In [None]:
try:
    if region == 'us-east-1':
        s3_client.create_bucket(
            Bucket=bucket_name
        )
    else:
        # For all other regions, specify the LocationConstraint
        s3_client.create_bucket(
            Bucket=bucket_name,
            CreateBucketConfiguration={'LocationConstraint': region}
        )
    print(f"Bucket {bucket_name} created successfully")
except s3_client.exceptions.BucketAlreadyExists:
    print(f"Bucket {bucket_name} already exists")
except s3_client.exceptions.BucketAlreadyOwnedByYou:
    print(f"Bucket {bucket_name} already owned by you")
except Exception as e:
    print(f"Error creating bucket: {e}")

## Download and Prepare the W2 Form Dataset

For this OCR fine-tuning task, we'll use a Hugging Face dataset of scanned W2 tax forms filled with synthetic data. These forms contain structured tax information that we want our model to extract accurately. The dataset provides predefined splits of 1,800 training, 100 validation, and 100 test samples. This gives the model a wide range of form layouts and field values to learn from.

<div style="background-color: #FFFFCC; color: #856404; padding: 15px; border-left: 6px solid #FFD700; margin-bottom: 15px;">
<h3 style="margin-top: 0; color: #856404;">⚠️ W2 Form Dataset Processing</h3>
<p>This cell downloads the synthetic W2 tax form dataset, which:</p>
<ul>
  <li>Contains <b>2,000 scanned W2 form images</b> with synthetic but realistic tax data</li>
  <li>May take <b>5-10 minutes</b> to download and process depending on your connection</li>
  <li>Requires sufficient disk space for storing high-resolution document scans</li>
  <li>Contains sensitive (though synthetic) information resembling real tax forms</li>
</ul>
<p>This dataset is specifically designed for training OCR models on structured tax document extraction tasks.</p>
</div>

In [None]:
import requests
import zipfile
from tqdm import tqdm

# Create directories to store images and metadata
os.makedirs('ocr_images/train', exist_ok=True)
os.makedirs('ocr_images/val', exist_ok=True)
os.makedirs('ocr_images/test', exist_ok=True)


# Load the dataset's predefined splits (1,800 train, 100 validation, 100 test)
train_data = load_dataset("singhsays/fake-w2-us-tax-form-dataset", split="train")
val_data = load_dataset("singhsays/fake-w2-us-tax-form-dataset", split="validation")
test_data = load_dataset("singhsays/fake-w2-us-tax-form-dataset", split="test")


print(f"\nNumber of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(val_data)}")
print(f"Number of test examples: {len(test_data)}")

## Upload W2 Form Images to S3

Now, let's upload the scanned W2 form images to S3, where Amazon Bedrock can read them during inference and fine-tuning:

In [None]:
def upload_images_to_s3(dataset, subset):
    """Upload images to S3 and return paths"""
    print(f"Uploading {subset} images to S3...")
    
    s3_paths = []
    
    for i, item in enumerate(tqdm(dataset)):
        try:
            # Get the decoded PIL image from the dataset
            image = item['image']

            # Re-encode every image as PNG so the S3 object always matches the
            # "format": "png" declared in the Bedrock messages built later
            if image.mode not in ("RGB", "RGBA", "L"):
                image = image.convert("RGB")
            with io.BytesIO() as buffer:
                image.save(buffer, format="PNG")
                image_bytes = buffer.getvalue()

            # Define S3 path for this image
            s3_key = f"ocr_images/{subset}/img_{i}.png"
            
            # Upload image to S3
            s3_client.put_object(Bucket=bucket_name, Key=s3_key, Body=image_bytes)
            
            # Store S3 URI for later use
            s3_uri = f"s3://{bucket_name}/{s3_key}"
            
            # Save metadata about this image
            s3_paths.append({
                'index': i,
                's3_uri': s3_uri,
                'gt': item["ground_truth"]
            })
            
        except Exception as e:
            print(f"Error uploading image {i}: {e}")
    
    return s3_paths

# Upload images to S3
train_s3_paths = upload_images_to_s3(train_data, 'train')
val_s3_paths = upload_images_to_s3(val_data, 'val')
test_s3_paths = upload_images_to_s3(test_data, 'test')

print(f"Uploaded {len(train_s3_paths)} training images")
print(f"Uploaded {len(val_s3_paths)} validation images")
print(f"Uploaded {len(test_s3_paths)} test images")
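`upload_images_to_s3` prints and skips any image that fails to upload. Because of that, the returned lists can be shorter than the source splits, and those samples are silently missing from the training data. A quick check can catch this. The sketch below compares the recorded `index` values with the expected split size. The helper name `find_missing_indices` is hypothetical.

```python
def find_missing_indices(s3_paths, expected_count):
    """Return dataset indices that have no uploaded S3 record, in ascending order."""
    uploaded = {record["index"] for record in s3_paths}
    return sorted(set(range(expected_count)) - uploaded)


# Made-up records in the shape returned by upload_images_to_s3; index 1 failed
records = [
    {"index": 0, "s3_uri": "s3://example-bucket/ocr_images/train/img_0.png", "gt": "{}"},
    {"index": 2, "s3_uri": "s3://example-bucket/ocr_images/train/img_2.png", "gt": "{}"},
]
print(find_missing_indices(records, expected_count=3))  # [1]
```

In the notebook, `find_missing_indices(train_s3_paths, len(train_data))` should return an empty list. You can re-upload any indices it reports before building the JSONL files.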

## Format W2 Data for OCR Model Fine-tuning

Let's prepare the tax form data in the format required for fine-tuning Amazon Nova Lite on Amazon Bedrock. The goal is to teach the model to extract structured information from tax forms into a consistent JSON format:

### Convert the ground truth data into the schema we need

In [None]:
def transform_schema(original):
    """Map the dataset's box-level ground truth (gt_parse) to the nested JSON schema the model should output."""
    # Employee section (boxes a and e)
    employee = {
        "name": original["box_e_employee_name"],
        "address": f"{original['box_e_employee_street_address']}, {original['box_e_employee_city_state_zip']}",
        "socialSecurityNumber": original["box_a_employee_ssn"]
    }

    # Employer section (boxes b and c)
    employer = {
        "name": original["box_c_employer_name"],
        "ein": original["box_b_employer_identification_number"],
        "address": f"{original['box_c_employer_street_address']}, {original['box_c_employer_city_state_zip']}"
    }

    # Earnings section; stateIncomeTax adds the two box 17 lines together
    earnings = {
        "wages": original["box_1_wages"],
        "socialSecurityWages": original["box_3_social_security_wages"],
        "medicareWagesAndTips": original["box_5_medicare_wages"],
        "federalIncomeTaxWithheld": original["box_2_federal_tax_withheld"],
        "stateIncomeTax": original["box_17_1_state_income_tax"] + original["box_17_2_state_income_tax"],
        "localWagesTips": original["box_18_1_local_wages"],
        "localIncomeTax": original["box_19_1_local_income_tax"]
    }

    # Benefits section (boxes 10 and 11)
    benefits = {
        "dependentCareBenefits": original["box_10_dependent_care_benefits"],
        "nonqualifiedPlans": original["box_11_nonqualified_plans"]
    }

    # Multi-state section, keyed by the state abbreviation in box 15
    multiStateEmployment = {
        original["box_15_1_state"]: {
            "localWagesTips": original["box_18_1_local_wages"],
            "localIncomeTax": original["box_19_1_local_income_tax"],
            "localityName": original["box_20_1_locality"]
        },
        original["box_15_2_state"]: {
            "localWagesTips": original["box_18_2_local_wages"],
            "localIncomeTax": original["box_19_2_local_income_tax"],
            "localityName": original["box_20_2_locality"]
        }
    }

    # Combine all sections
    return {
        "employee": employee,
        "employer": employer,
        "earnings": earnings,
        "benefits": benefits,
        "multiStateEmployment": multiStateEmployment
    }
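`multiStateEmployment` is keyed by the state abbreviation, which creates two edge cases worth guarding against:

- If the second state slot is blank or missing (for example, on a single-state form), the result gains an empty-string or `None` key.
- If both slots hold the same state, the second entry silently overwrites the first.

The sketch below is a hypothetical alternative that skips blank state codes and reports duplicates. Whether blank slots actually occur depends on the dataset, so treat this as an assumption to verify against your data.

```python
def build_multi_state(original):
    """Group W2 boxes 15-20 by state code, skipping blank state slots (sketch)."""
    result = {}
    for slot in ("1", "2"):
        state = (original.get(f"box_15_{slot}_state") or "").strip()
        if not state:
            # Assumption: a blank state code means this slot is unused
            continue
        if state in result:
            # transform_schema silently lets the later slot win; make that visible
            print(f"Note: state {state!r} appears in more than one slot; slot {slot} replaces it")
        result[state] = {
            "localWagesTips": original.get(f"box_18_{slot}_local_wages"),
            "localIncomeTax": original.get(f"box_19_{slot}_local_income_tax"),
            "localityName": original.get(f"box_20_{slot}_locality"),
        }
    return result


# Made-up record with only the first state slot filled in
sample = {
    "box_15_1_state": "NC",
    "box_18_1_local_wages": 287711.19,
    "box_19_1_local_income_tax": 46607.9,
    "box_20_1_locality": "Millier Oval",
    "box_15_2_state": "",
}
print(build_multi_state(sample))
```

The sketch reports duplicates with `print` rather than `warnings.warn` because the notebook calls `warnings.filterwarnings('ignore')` earlier, which would hide them. To use it, call `build_multi_state(original)` from `transform_schema` in place of the inline `multiStateEmployment` dictionary.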

In [None]:
sample_gt_data = transform_schema(json.loads(test_s3_paths[0].get("gt"))["gt_parse"])
sample_gt_data

## Base Model Inference

Before fine-tuning, let's try inference with the base model and get familiar with the API syntax.

In [None]:
text_prompt="""
Analyze the W2 form image, extracting all fields, and return the data as a JSON object.
Focus on capturing each field as labeled on the form, and be especially precise with multi-state information. 
For each field, ensure the following:

1. **Employee Information**: Extract 'Employee Name,' 'Employee Address,' 'Social Security Number,' etc.
2. **Employer Information**: Include 'Employer Name,' 'Employer EIN,' 'Employer Address,' and 'Zip Code.'
3. **Earnings and Tax Information**: Extract 'Wages,' 'Social Security Wages,' 'Medicare Wages and Tips,' 'Federal Income Tax Withheld,' 'State Income Tax,' 'Local Wages / Tips,' 'Local Income Tax,' etc.
4. **Benefits and Other Deductions**: Include fields like 'Dependent Care Benefits' and 'Nonqualified Plans.'
5. **Multi-state Employment Information**: Identify all states listed on the W2, capturing information for each:
    - Ensure each state's data is complete and correct, including 'Local Wages / Tips,' 'Local Income Tax,' and 'Locality Name.'
    - Each state's information should be grouped under its abbreviation (e.g., "NC", "UT").

The JSON output should precisely reflect all information, especially multiple states, with each state’s information grouped under its corresponding abbreviation. Here is a one-shot example for structure:
```json
                {
                    "employee": {
                        "name": "Ann Hill",
                        "address": "39572 Jack Trail Apt. 308, New Sarahside, MN 56848-7193",
                        "socialSecurityNumber": "192-67-3262"
                    },
                    "employer": {
                        "name": "Bryant Ltd Group",
                        "ein": "06-6105986",
                        "address": "82582 William Cape Suite 370, Scottside, ND 93090-3134"
                    },
                    "earnings": {
                        "wages": 238111.55,
                        "socialSecurityWages": 309486.28,
                        "medicareWagesAndTips": 205695.97,
                        "federalIncomeTaxWithheld": 71007.86,
                        "stateIncomeTax": 399.0,
                        "localWagesTips": 5965.18,
                        "localIncomeTax": 399.0
                    },
                    "benefits": {
                        "dependentCareBenefits": 198,
                        "nonqualifiedPlans": 7053
                    },
                    "multiStateEmployment": {
                        "NC": {
                            "localWagesTips": 287711.19,
                            "localIncomeTax": 46607.9,
                            "localityName": "Millier Oval"
                        },
                        "UT": {
                            "localWagesTips": 301013.17,
                            "localIncomeTax": 24688.05,
                            "localityName": "Gomez Covas"
                        }
                    }
                }
```
"""

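The model imitates the one-shot example inside `text_prompt`, so it is worth confirming two things: that the example parses as JSON, and that it has the same top-level sections `transform_schema` produces. The sketch below pulls the first fenced `json` block out of a prompt string with a regular expression. The helper name and the regex are assumptions, not part of any library.

```python
import json
import re

# Three backticks, built programmatically to keep the literal out of this snippet
FENCE = "`" * 3
FENCED_JSON = re.compile(re.escape(FENCE) + r"json\s*(.*?)" + re.escape(FENCE), re.DOTALL)
EXPECTED_SECTIONS = {"employee", "employer", "earnings", "benefits", "multiStateEmployment"}


def extract_fenced_json(prompt):
    """Parse the first fenced json block found in a prompt string (sketch)."""
    match = FENCED_JSON.search(prompt)
    if match is None:
        raise ValueError("No fenced json block found in prompt")
    return json.loads(match.group(1))


# In the notebook, pass text_prompt; a small stand-in prompt is used here
demo_body = json.dumps({section: {} for section in sorted(EXPECTED_SECTIONS)})
demo_prompt = "Here is a one-shot example for structure:\n" + FENCE + "json\n" + demo_body + "\n" + FENCE
example = extract_fenced_json(demo_prompt)
missing = EXPECTED_SECTIONS - set(example)
print("Missing sections:", sorted(missing) if missing else "none")
```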
In [None]:
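def process_w2(s3_uri):
    """Build Converse API messages for one W2 image stored in S3."""
    messages = [
        {
            "role": "user",
            "content": [
                # The W2 image, read by Bedrock directly from S3
                {
                    "image": {
                        "format": "png",
                        "source": {
                            "s3Location": {
                                "uri": s3_uri,
                                "bucketOwner": account_id
                            }
                        }
                    }
                },
                # The extraction instructions
                {"text": text_prompt}
            ]
        },
        {
            # Prefill the assistant turn so the reply starts inside a JSON code fence
            "role": "assistant",
            "content": [
                {"text": "```json"}
            ]
        }
    ]
    return messages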

In [None]:
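# Create a Bedrock Runtime client in the same region as the rest of the notebook
bedrock_client = session.client("bedrock-runtime")
# US cross-region inference profile for Amazon Nova Lite
nova_lite_id = "us.amazon.nova-lite-v1:0"

# Temperature 0 for stable extraction; stop at the closing code fence after the JSON
response = bedrock_client.converse(
    modelId=nova_lite_id,
    messages=process_w2(test_s3_paths[0].get("s3_uri")),
    inferenceConfig={"maxTokens": 2048, "temperature": 0.0, "topP": 0.1, "stopSequences": ["```"]},
)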

In [None]:
prediction = json.loads(response["output"]["message"]["content"][0]["text"].replace("```", ""))
prediction
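The one-liner above assumes the reply contains only the JSON that follows the `json` prefill. A slightly more defensive parser is sketched below. It still assumes a single top-level JSON object in the reply. It also tolerates a reply that repeats the language tag or adds text around the object, because it parses from the first `{` to the last `}`. The helper name is hypothetical.

```python
import json


def parse_model_json(text):
    """Parse the outermost JSON object in a model reply (sketch)."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("No JSON object found in model output")
    return json.loads(text[start:end + 1])


# Replies with and without a stray language tag both parse
print(parse_model_json('\n{"employee": {"name": "Ann Hill"}}\n'))
print(parse_model_json('json\n{"employee": {"name": "Ann Hill"}}'))
```

In the notebook this would be used as `prediction = parse_model_json(response["output"]["message"]["content"][0]["text"])`.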

In [None]:
from deepdiff import DeepDiff
diff = DeepDiff(sample_gt_data, prediction, ignore_order=True)
diff
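DeepDiff lists every mismatch, which is useful for inspection but awkward to aggregate across a test set. One common alternative, sketched here, flattens both the ground truth and the prediction into dotted field paths. It then reports the share of ground-truth fields that match exactly. The helper names are hypothetical.

Exact matching is a deliberately strict assumption. For example, `399` and `399.0` compare equal in Python, but `"399.00"` and `399.0` do not. A wrong state code also makes every field under that state miss, because the code is part of the path.

```python
def flatten(obj, prefix=""):
    """Flatten nested dicts into {"section.field": value} pairs."""
    items = {}
    for key, value in obj.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            items.update(flatten(value, path))
        else:
            items[path] = value
    return items


def field_accuracy(ground_truth, prediction):
    """Share of ground-truth leaf fields that the prediction reproduces exactly."""
    gt_fields = flatten(ground_truth)
    pred_fields = flatten(prediction)
    if not gt_fields:
        return 1.0
    correct = sum(
        1 for path, value in gt_fields.items()
        if path in pred_fields and pred_fields[path] == value
    )
    return correct / len(gt_fields)


# Made-up example: the last two SSN digits are transposed
gt = {"employee": {"name": "Ann Hill", "socialSecurityNumber": "192-67-3262"},
      "earnings": {"wages": 238111.55}}
pred = {"employee": {"name": "Ann Hill", "socialSecurityNumber": "192-67-3226"},
        "earnings": {"wages": 238111.55}}
print(round(field_accuracy(gt, pred), 3))  # 0.667
```

In the notebook, `field_accuracy(sample_gt_data, prediction)` gives a single number per form that can be averaged over the test set.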

## Prepare the Dataset for the Fine-tuning Job

In [None]:
def create_jsonl_entry(item, s3_uri):
    """Create a JSONL entry in the Bedrock conversation schema format"""
    
    # Convert the box-level ground truth into the target JSON schema
    gt = transform_schema(item["gt_parse"])
    
    # Create the entry in the Bedrock conversation format. The user turn lists
    # the image before the prompt, matching the order used in process_w2.
    return {
        "schemaVersion": "bedrock-conversation-2024",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "image": {
                            "format": "png",
                            "source": {
                                "s3Location": {
                                    "uri": s3_uri,
                                    "bucketOwner": account_id
                                }
                            }
                        }
                    },
                    {
                        "text": text_prompt
                    }
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "text": f"```json\n{json.dumps(gt)}\n```"
                    }
                ]
            }
        ]
    }

def prepare_dataset_jsonl(s3_paths, output_file):
    """Prepare dataset in JSONL format for fine-tuning"""
    
    with open(output_file, 'w') as f:
        for item in s3_paths:
            # Create JSONL entry
            entry = create_jsonl_entry(json.loads(item['gt']), item['s3_uri'])
            
            # Write to file
            f.write(json.dumps(entry) + '\n')
    
    print(f"Created {output_file} with {len(s3_paths)} samples")

# Prepare JSONL files
prepare_dataset_jsonl(train_s3_paths, 'train.jsonl')
prepare_dataset_jsonl(val_s3_paths, 'validation.jsonl')
prepare_dataset_jsonl(test_s3_paths, 'test.jsonl')
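A quick local check can catch formatting mistakes before a customization job fails on them. The sketch below checks only the structure this notebook produces:

- the `bedrock-conversation-2024` schema version
- a user turn followed by an assistant turn
- `s3://` image URIs

It is not an official validator and does not enforce Bedrock's own dataset limits. The function names are hypothetical.

```python
import json


def validate_jsonl_line(line):
    """Return a list of problems found in one JSONL record (empty list means OK)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc}"]
    problems = []
    if record.get("schemaVersion") != "bedrock-conversation-2024":
        problems.append("unexpected schemaVersion")
    messages = record.get("messages", [])
    roles = [message.get("role") for message in messages]
    if roles != ["user", "assistant"]:
        problems.append(f"expected user then assistant turns, got {roles}")
    for message in messages:
        for block in message.get("content", []):
            if "image" in block:
                uri = block["image"].get("source", {}).get("s3Location", {}).get("uri", "")
                if not uri.startswith("s3://"):
                    problems.append(f"image URI is not an S3 URI: {uri!r}")
    return problems


def validate_jsonl_file(path):
    """Map 1-based line numbers to their problems for every failing line."""
    failures = {}
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            problems = validate_jsonl_line(line)
            if problems:
                failures[line_no] = problems
    return failures
```

For the files written above, `validate_jsonl_file('train.jsonl')` should return an empty dictionary.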

## Upload OCR Training Data to S3

Let's upload our prepared JSONL files containing W2 form images and structured extraction targets to S3:

In [None]:
# Upload JSONL files to S3
s3_client.upload_file('train.jsonl', bucket_name, 'data/train.jsonl')
s3_client.upload_file('validation.jsonl', bucket_name, 'data/validation.jsonl')
s3_client.upload_file('test.jsonl', bucket_name, 'data/test.jsonl')

# Store S3 URIs for later use
train_data_uri = f"s3://{bucket_name}/data/train.jsonl"
validation_data_uri = f"s3://{bucket_name}/data/validation.jsonl"
test_data_uri = f"s3://{bucket_name}/data/test.jsonl"

print(f"Training data URI: {train_data_uri}")
print(f"Validation data URI: {validation_data_uri}")
print(f"Test data URI: {test_data_uri}")

## Create IAM Role for OCR Model Fine-tuning
Let's create an IAM role with the necessary permissions for fine-tuning our W2 form OCR model:

In [None]:
# Generate policy documents
trust_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "bedrock.amazonaws.com"
            },
            "Action": "sts:AssumeRole",
            "Condition": {
                "StringEquals": {
                    "aws:SourceAccount": account_id
                },
                "ArnLike": {
                    "aws:SourceArn": f"arn:aws:bedrock:{region}:{account_id}:model-customization-job/*"
                }
            }
        }
    ]
}

access_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket",
                "s3:GetBucketLocation"
            ],
            "Resource": [
                f"arn:aws:s3:::{bucket_name}",
                f"arn:aws:s3:::{bucket_name}/*"
            ]
        }
    ]
}


# Create IAM client
iam = session.client('iam')

# Role name for fine-tuning
role_name = f"NovaVisionFineTuningRole-{int(time.time())}"
policy_name = f"NovaVisionFineTuningPolicy-{int(time.time())}"

# Create role
try:
    response = iam.create_role(
        RoleName=role_name,
        AssumeRolePolicyDocument=json.dumps(trust_policy_doc),
        Description="Role for fine-tuning Nova vision model with Amazon Bedrock"
    )
    
    role_arn = response["Role"]["Arn"]
    print(f"Created role: {role_arn}")
    
    # Create policy
    response = iam.create_policy(
        PolicyName=policy_name,
        PolicyDocument=json.dumps(access_policy_doc)
    )
    
    policy_arn = response["Policy"]["Arn"]
    print(f"Created policy: {policy_arn}")
    
    # Attach policy to role
    iam.attach_role_policy(
        RoleName=role_name,
        PolicyArn=policy_arn
    )
    
    print(f"Attached policy to role")
    
except Exception as e:
    print(f"Error creating IAM resources: {e}")
    # Re-raise so later cells don't fail with an undefined role_arn / policy_arn
    raise

# Allow time for IAM role propagation
print("Waiting for IAM role to propagate...")
time.sleep(10)


## Save Variables for Fine-tuning
Let's save the important variables we'll need in the next notebook.

In [None]:
# Store variables for the next notebook
%store bucket_name
%store train_data_uri
%store validation_data_uri
%store test_data_uri
%store role_arn
%store role_name
%store policy_arn
%store text_prompt
%store test_s3_paths
%store account_id

print("Variables saved for use in the next notebook")

## Conclusion

In this notebook, we prepared the data needed for fine-tuning a specialized OCR model for W2 tax form extraction. We:

- Processed a dataset of scanned W2 tax forms with structured information
- Uploaded the form images to S3 for model training
- Developed a prompt specifically designed for tax document OCR extraction
- Created training, validation, and test JSONL files in the Bedrock conversation format for fine-tuning Nova Lite
- Set up the necessary IAM roles and permissions for the fine-tuning process