# Document Classification with Bedrock Models

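This notebook evaluates document classification models on Amazon Bedrock using the RVL-CDIP dataset. The primary comparison is between:
1. Amazon Nova Lite (us.amazon.nova-lite-v1:0)
2. A fine-tuned Nova Lite served as a provisioned Bedrock model (arn:aws:bedrock:us-east-1:195275636621:provisioned-model/qsr1bg9tbf1v)

The evaluation loop also runs every other model listed in the `MODELS` dictionary (Nova Pro, Nova Premier, a second fine-tuned variant, and the Claude models).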

In [None]:
# Import necessary libraries
import os
import json
import time
import uuid
import base64
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import io
import boto3
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix

from dotenv import load_dotenv
load_dotenv()

# Set random seed for reproducibility
# `random` drives the index shuffle used when sampling the dataset below;
# NumPy's global generator is seeded as well for any NumPy-based randomness.
random.seed(42)
np.random.seed(42)

In [None]:
# Configure AWS credentials and region
# No credentials are set here; boto3 resolves them itself (load_dotenv() in the
# import cell may populate the environment from a local .env file).
region = "us-east-1"

# Initialize AWS clients
# bedrock_runtime is used for model inference, s3 for uploading sampled images.
bedrock_runtime = boto3.client(service_name="bedrock-runtime", region_name=region)
s3 = boto3.client("s3", region_name=region)

# Define models with their IDs and providers
# Keys are the short names used for result/checkpoint file names and plots.
# "id" is passed to Bedrock as the modelId; "provider" selects the request
# layout in invoke_model ("amazon" vs. everything else).
# NOTE(review): the provisioned-model ARNs embed an AWS account ID and only work
# while that provisioned throughput exists — confirm before re-running.
MODELS = {
    "nova_lite": {
        "id": "us.amazon.nova-lite-v1:0",
        "provider": "amazon"
    },
    "nova_pro": {
        "id": "us.amazon.nova-pro-v1:0",
        "provider": "amazon"
    },
    "nova_premier": {
        "id": "us.amazon.nova-premier-v1:0",
        "provider": "amazon"
    },
    "ft_nova_lite": {
        "id": "arn:aws:bedrock:us-east-1:195275636621:provisioned-model/qsr1bg9tbf1v",
        "provider": "amazon"
    },
    "ft_nova_lite_lr_e4_wr_10": {
        "id": "arn:aws:bedrock:us-east-1:195275636621:provisioned-model/6q5nfhuh1pnq",
        "provider": "amazon"
    },
    "claude_3_5_haiku": {
        "id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
        "provider": "anthropic"
    },
    "claude_3_5_sonnet_v2": {
        "id": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        "provider": "anthropic"
    },
    "claude_3_7": {
        "id": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        "provider": "anthropic"
    }
}

## 1. Load and Sample Dataset

In [None]:
# Integer class id -> class name for the RVL-CDIP dataset.
# The names are listed in class-id order, so the list position is the label id.
# NOTE(review): "advertissement" is spelled the same way in task_prompt; the two
# must stay in sync because predictions are compared to these names verbatim.
label_mapping = dict(enumerate([
    "advertissement",
    "budget",
    "email",
    "file_folder",
    "form",
    "handwritten",
    "invoice",
    "letter",
    "memo",
    "news_article",
    "presentation",
    "questionnaire",
    "resume",
    "scientific_publication",
    "scientific_report",
    "specification",
]))

# Class name -> integer class id (inverse lookup for evaluation)
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))

In [None]:
# Load the RVL-CDIP test dataset
# Only the "test" split of the Hugging Face dataset "chainyo/rvl-cdip" is loaded.
ds = load_dataset("chainyo/rvl-cdip", split="test")
print(f"Dataset loaded: {ds}")

# Get the unique labels
# np.unique returns a sorted NumPy array of the distinct label ids; it is
# reused below to initialise the per-class sample counters.
unique_labels = np.unique(ds["label"])
print(f"Number of unique labels: {len(unique_labels)}")
print(f"Labels: {unique_labels}")

In [None]:
# Build a class-balanced random sample: at most `samples_per_label` rows per class
samples_per_label = 100
sampled_data = []

# Visit every dataset row in a random order (seeded in the import cell)
all_indices = list(range(len(ds)))
random.shuffle(all_indices)

# Per-class quota counters and a tally of rows that could not be read
samples_count = dict.fromkeys(unique_labels, 0)
errors_count = 0

for idx in tqdm(all_indices):
    # Finished once no class is still below its quota
    if not any(count < samples_per_label for count in samples_count.values()):
        break

    try:
        # Reading the row may raise (the failure is counted and skipped)
        sample = ds[idx]
        label = sample["label"]

        # Class already full -> ignore this row
        if samples_count[label] >= samples_per_label:
            continue

        sampled_data.append(sample)
        samples_count[label] += 1
    except Exception as e:
        errors_count += 1
        # Only every 100th error is printed to keep the output readable
        if errors_count % 100 == 0:
            print(f"Error processing sample at index {idx}: {str(e)}")

# Report how many rows were collected for each class
for label in unique_labels:
    print(f"Label {label} ({label_mapping[label]}): Sampled {samples_count[label]} samples")
print(f"Total sampled data: {len(sampled_data)}")
print(f"Total errors encountered: {errors_count}")

## 2. Save Sampled Dataset to S3

In [None]:
# Destination for the sampled images: a fixed (not generated) bucket name and
# the key prefix under which save_to_s3 stores one sub-folder per class.
bucket_name = "idp-evaluation-data-us-east-1"
directory = "rvl-cdip-test-data"
# Local scratch folder where save_to_s3 writes each PNG before uploading it
os.makedirs("temp_images", exist_ok=True)

In [None]:
# Function to save image to S3
def save_to_s3(sample, index):
    """Upload one sampled image to S3 as a PNG and return its S3 URI.

    The image is written to ``temp_images/<label>_<index>.png``, uploaded to
    ``<directory>/<label>/<label>_<index>.png`` in ``bucket_name`` and the
    local copy is removed again.

    Args:
        sample: Mapping with an ``"image"`` entry (object exposing
            ``save(path, format=...)``) and an integer ``"label"`` entry that
            is a key of ``label_mapping``.
        index: Value used to make the file name unique within its class.

    Returns:
        str: ``s3://<bucket_name>/<key>`` of the uploaded object.

    Raises:
        KeyError: If the label is not present in ``label_mapping``.
        Exception: Any error raised while saving or uploading is propagated;
            the temporary file is still cleaned up.
    """
    label_name = label_mapping[sample["label"]]
    filename = f"{label_name}_{index}.png"
    s3_path = f"{directory}/{label_name}/{filename}"

    # Do not rely on another cell having created the scratch folder
    os.makedirs("temp_images", exist_ok=True)
    local_path = os.path.join("temp_images", filename)

    try:
        # Save locally first, then upload the file
        sample["image"].save(local_path, format="PNG")
        s3.upload_file(local_path, bucket_name, s3_path)
    finally:
        # Always remove the scratch file, even when the upload failed
        if os.path.exists(local_path):
            os.remove(local_path)

    return f"s3://{bucket_name}/{s3_path}"

In [None]:
# Upload a small test batch
# test_batch = sampled_data
# for i, sample in enumerate(test_batch):
#     s3_uri = save_to_s3(sample, i)
#     print(f"Uploaded sample {i} to {s3_uri}")

## 3. Model Evaluation Framework

In [None]:
# Define the system prompt and task prompt for document classification
# system_prompt is only sent on the "amazon" branch of invoke_model.
system_prompt = "You are a document classification expert who can analyze and identify document types from images."

# The class names listed here must match the values of label_mapping exactly
# (including the "advertissement" spelling): a prediction only counts as correct
# when it equals the mapped class name. parse_response looks for the
# {"type": "..."} JSON object requested in the last line.
task_prompt = """Classify this document into one of these types:
- advertissement
- budget
- email
- file_folder
- form
- handwritten
- invoice
- letter
- memo
- news_article
- presentation
- questionnaire
- resume
- scientific_publication
- scientific_report
- specification

Return your response as JSON: {"type": "document_type_name"}"""

In [None]:
def invoke_model(image, model_id, model_provider):
    """Send one document image to a Bedrock model through the Converse API.

    The image is encoded as JPEG and sent together with ``task_prompt``.
    For ``model_provider == "amazon"`` the request also carries
    ``system_prompt`` and a fixed inference configuration, with the image
    placed before the text. Any other provider gets the text first, the image
    second, and no system prompt or inference configuration.

    NOTE(review): because only the "amazon" branch pins temperature/topP and
    sends the system prompt, the providers are not evaluated under identical
    settings — confirm this is intended before comparing scores.

    Args:
        image: PIL-style image exposing ``mode``, ``convert`` and ``save``.
        model_id: Bedrock model ID or provisioned-model ARN.
        model_provider: ``"amazon"`` selects the Nova request layout; every
            other value selects the second layout.

    Returns:
        dict: The raw ``converse`` response on success, or
        ``{"error": "<message>"}`` after the final failed attempt.
    """
    # JPEG can only store some image modes; anything else (e.g. RGBA or
    # palette images) would make save() raise, so convert those to RGB first.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_bytes = buffered.getvalue()

    image_block = {"image": {"format": "jpeg", "source": {"bytes": image_bytes}}}
    text_block = {"text": task_prompt}

    # Build the request once; it does not change between retries
    if model_provider == "amazon":
        request = {
            "modelId": model_id,
            "messages": [{"role": "user", "content": [image_block, text_block]}],
            "system": [{"text": system_prompt}],
            "inferenceConfig": {"maxTokens": 1000, "topP": 0.1, "temperature": 0.0},
        }
    else:
        request = {
            "modelId": model_id,
            "messages": [{"role": "user", "content": [text_block, image_block]}],
        }

    # Invoke model with retry logic
    max_retries = 5
    for attempt in range(max_retries):
        try:
            return bedrock_runtime.converse(**request)
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff: 1, 2, 4, 8 seconds
            else:
                # Out of attempts: report the failure as data instead of raising
                return {"error": str(e)}

In [None]:
def parse_response(response, model_provider):
    """Extract the predicted document type from a model response.

    Args:
        response: Either the dict returned by the Converse API or the
            ``{"error": ...}`` dict produced when invocation failed.
        model_provider: Accepted for symmetry with ``invoke_model``; the
            parsing itself does not depend on it.

    Returns:
        tuple[str, str]: ``(status, value)`` where status is

        * ``"error"``   – invocation failed or no text could be extracted;
          value is the error message.
        * ``"success"`` – a ``"type"`` field was found; value is the type,
          lower-cased and stripped.
        * ``"unknown"`` – text was returned but no ``"type"`` field was found;
          value is the raw text.
    """
    import re  # `re` is not imported at notebook level

    extraction_error = "Failed to extract content from Nova response"

    # Anything that is not a dict (e.g. None) cannot be inspected further
    if not isinstance(response, dict):
        return "error", extraction_error

    if "error" in response:
        return "error", response["error"]

    # Use the first content block that actually carries text; the first block
    # of a response is not guaranteed to be a text block.
    try:
        blocks = response["output"]["message"]["content"]
        content = next(
            block["text"]
            for block in blocks
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        )
    except (KeyError, IndexError, TypeError, StopIteration):
        return "error", extraction_error

    # Try to parse as JSON: find a flat {...} object containing a "type" string
    json_match = re.search(r'\{[^\{\}]*"type"\s*:\s*"[^"]+"[^\{\}]*\}', content)
    if json_match:
        try:
            json_content = json.loads(json_match.group(0))
        except ValueError:  # includes json.JSONDecodeError
            json_content = None
        if isinstance(json_content, dict) and isinstance(json_content.get("type"), str):
            return "success", json_content["type"].lower().strip()

    # If JSON parsing fails, try to extract the document type using regex
    match = re.search(r'"type"\s*:\s*"([^"]+)"', content)
    if match:
        return "success", match.group(1).lower().strip()

    return "unknown", content

In [None]:
# Test with a single sample
# if len(sampled_data) > 0:
#     test_sample = sampled_data[0]
#     print(f"Testing with sample of label: {label_mapping[test_sample['label']]}")
    
#     # Test Nova Lite model
#     model = MODELS["claude_3_7"]
#     response = invoke_model(test_sample["image"], model["id"], model["provider"])
#     status, prediction = parse_response(response, model["provider"])
#     print(f"Prediction: {prediction} (status: {status})")

## 4. Run Evaluation

In [None]:
# Import concurrent.futures for parallel processing
import concurrent.futures

# Classify one sample with one model and package the outcome
def evaluate_single_sample(args):
    """Run a single sample through a model and return a result record.

    Args:
        args: Tuple ``(sample_idx, sample, model_info)``. ``sample`` provides
            ``"label"`` and ``"image"``; ``model_info`` provides ``"id"`` and
            ``"provider"``. A single tuple is used so the function can be
            submitted directly to an executor.

    Returns:
        dict: Sample index, true label (id and name), parsed prediction,
        parse status, whether the prediction equals the true label name, and
        the unmodified model response.
    """
    sample_idx, sample, model_info = args

    expected_id = sample["label"]
    expected_name = label_mapping[expected_id]

    # Call the model, then turn its raw answer into (status, prediction)
    image, model_id, provider = sample["image"], model_info["id"], model_info["provider"]
    raw_response = invoke_model(image, model_id, provider)
    status, prediction = parse_response(raw_response, model_info["provider"])

    record = {}
    record["sample_idx"] = sample_idx
    record["true_label"] = expected_id
    record["true_label_name"] = expected_name
    record["prediction"] = prediction
    record["status"] = status
    record["correct"] = prediction == expected_name
    record["full_response"] = raw_response  # kept for later inspection
    return record

# Function to evaluate a model on a subset of samples with parallel processing
def evaluate_model(model_info, samples, model_name, max_workers=4, start_idx=0):
    """Evaluate one model on ``samples`` using a thread pool.

    Args:
        model_info: Model description passed through to
            ``evaluate_single_sample`` for every sample.
        samples: Sequence of samples; positions are used as sample indices.
        model_name: Short name used in progress output and for the
            ``<model_name>_partial_results.json`` file.
        max_workers: Number of worker threads.
        start_idx: Samples with a position below this value are skipped, which
            allows a run to be resumed from a specific index.

    Returns:
        list[dict]: One result record per successfully processed sample, in
        completion order (not sample order). Samples whose evaluation raised
        are reported on stdout and omitted.
    """
    results = []

    # Prepare arguments for processing
    args_list = [(i, sample, model_info) for i, sample in enumerate(samples) if i >= start_idx]

    if not args_list:
        print(f"No samples to process for {model_name}")
        return results

    print(f"Processing {len(args_list)} samples for {model_name}")

    partial_path = f"{model_name}_partial_results.json"

    # Use ThreadPoolExecutor for parallel processing
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which sample each future belongs to so failures can be traced
        future_to_idx = {
            executor.submit(evaluate_single_sample, args): args[0] for args in args_list
        }

        for future in tqdm(
            concurrent.futures.as_completed(future_to_idx),
            total=len(args_list),
            desc=f"Evaluating {model_name}"
        ):
            sample_idx = future_to_idx[future]
            try:
                results.append(future.result())
            except Exception as e:
                print(f"Error processing sample {sample_idx}: {str(e)}")
                continue

            # Save incremental results after every 10 samples. A failure to
            # write the file must not be mistaken for a failed sample.
            if len(results) % 10 == 0:
                try:
                    with open(partial_path, "w") as f:
                        json.dump(results, f)
                except (OSError, TypeError, ValueError) as e:
                    print(f"Could not write partial results to {partial_path}: {str(e)}")

    return results

In [None]:
def save_checkpoint(model_name, results, metrics=None):
    """Save checkpoint for a model evaluation.

    The checkpoint is written to ``<model_name>_checkpoint.json`` in the
    current working directory. It is first written to a temporary file and
    then moved into place, so an interrupted or failed write never destroys
    an existing checkpoint.

    Args:
        model_name: Short model name; determines the file name.
        results: JSON-serialisable list of per-sample result records.
        metrics: Optional JSON-serialisable metrics dict. ``load_checkpoints``
            treats a checkpoint with truthy metrics as a completed evaluation.

    Raises:
        TypeError: If ``results`` or ``metrics`` cannot be serialised to JSON.
        OSError: If the file cannot be written.
    """
    checkpoint = {
        "model_name": model_name,
        "results": results,
        "metrics": metrics,
        "timestamp": time.time()
    }

    checkpoint_path = f"{model_name}_checkpoint.json"
    tmp_path = f"{checkpoint_path}.tmp"

    try:
        with open(tmp_path, "w") as f:
            json.dump(checkpoint, f)
        # Replace the old checkpoint only once the new one is complete
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Left over only when writing failed before the replace
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Checkpoint saved for {model_name}")

def load_checkpoints():
    """Read the checkpoint file of every model in ``MODELS`` that has one.

    Returns:
        tuple[dict, list]: ``(checkpoints, completed_models)``. ``checkpoints``
        maps model name to the parsed checkpoint; ``completed_models`` lists
        the models whose checkpoint holds truthy ``"metrics"``. Files that
        cannot be read or parsed are reported on stdout and skipped.
    """
    checkpoints, completed_models = {}, []

    for model_name in MODELS:
        checkpoint_file = f"{model_name}_checkpoint.json"

        # No checkpoint on disk for this model
        if not os.path.exists(checkpoint_file):
            continue

        try:
            with open(checkpoint_file, "r") as handle:
                loaded = json.load(handle)
            checkpoints[model_name] = loaded

            # Metrics are only stored once an evaluation ran to the end
            if loaded.get("metrics"):
                completed_models.append(model_name)

            print(f"Found checkpoint for {model_name}")
        except Exception as e:
            print(f"Error loading checkpoint for {model_name}: {str(e)}")

    return checkpoints, completed_models

In [None]:
# Function to calculate metrics
def calculate_metrics(results):
    """Compute accuracy, weighted F1, weighted recall and a confusion matrix.

    Args:
        results: Result records, each with ``"true_label_name"`` and
            ``"prediction"`` entries.

    Returns:
        dict: ``accuracy``, ``f1`` and ``recall`` as plain floats,
        ``confusion_matrix`` as a nested list, and ``labels`` giving the
        row/column order of that matrix. The labels are the distinct true
        label names in sorted order, so the layout is the same on every run.
        The matrix is restricted to these labels, so predictions that are not
        one of them do not appear in it.
    """
    # Extract true labels and predictions
    y_true = [result["true_label_name"] for result in results]
    y_pred = [result["prediction"] for result in results]

    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")

    # Create confusion matrix; sorting replaces the arbitrary set ordering
    labels = sorted(set(y_true))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    return {
        "accuracy": float(accuracy),
        "f1": float(f1),
        "recall": float(recall),
        "confusion_matrix": cm.tolist(),
        "labels": labels
    }

In [None]:
# Run the evaluation on the full balanced sample.
# For a quick smoke test, slice the list instead, e.g. sampled_data[:10].
test_subset = sampled_data

# Dictionary to store results and metrics
all_results = {}
all_metrics = {}

# Load existing checkpoints
checkpoints, completed_models = load_checkpoints()

# Restore results and metrics from checkpoints
for model_name, checkpoint in checkpoints.items():
    if "results" in checkpoint:
        all_results[model_name] = checkpoint["results"]
    # save_checkpoint always writes a "metrics" key (None until the metrics
    # were computed), so test the value rather than the key: a None entry in
    # all_metrics would break the plotting cells further down.
    if checkpoint.get("metrics"):
        all_metrics[model_name] = checkpoint["metrics"]

# Evaluate models that haven't been completed
for model_name, model_info in MODELS.items():
    # Skip if model evaluation is already complete
    if model_name in completed_models:
        print(f"Skipping {model_name} (already completed)")
        continue

    print(f"\nEvaluating {model_name} model...")

    # A checkpoint without metrics still holds the results of a finished
    # inference pass (it is written right after evaluate_model returns), so
    # those results are reused and only the metrics are recomputed.
    if model_name in all_results:
        print(f"Resuming evaluation for {model_name} from checkpoint")
        results = all_results[model_name]
    else:
        # Start fresh evaluation
        results = evaluate_model(model_info, test_subset, model_name)

    # Save results
    all_results[model_name] = results
    with open(f"{model_name}_results.json", "w") as f:
        json.dump(results, f)

    # Save checkpoint with results
    save_checkpoint(model_name, results)

    # Calculate metrics
    metrics = calculate_metrics(results)
    all_metrics[model_name] = metrics

    # Save metrics
    with open(f"{model_name}_metrics.json", "w") as f:
        json.dump(metrics, f)

    # Update checkpoint with metrics (marks the model as completed)
    save_checkpoint(model_name, results, metrics)

    # Print metrics
    print(f"{model_name.capitalize()} Model Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")

## 5. Visualize Results

In [None]:
# Draw a confusion matrix as an annotated heatmap
def plot_confusion_matrix(cm, labels, title, fixed_labels=None):
    """
    Render a confusion matrix, optionally re-arranged into a fixed label order.

    Parameters:
    -----------
    cm : numpy.ndarray
        Square confusion matrix whose rows/columns follow ``labels``
    labels : list
        Label names in the order used by ``cm``
    title : str
        Figure title
    fixed_labels : list, optional
        Desired row/column order. Labels listed here but absent from
        ``labels`` get all-zero rows and columns; labels absent from
        ``fixed_labels`` are left out. When None, ``cm`` is drawn as given.
    """
    plt.figure(figsize=(10, 8))

    if fixed_labels is not None:
        # Where each known label currently sits in cm
        position_of = {name: pos for pos, name in enumerate(labels)}

        size = len(fixed_labels)
        reordered_cm = np.zeros((size, size), dtype=int)

        # (target position, source position) for every label present in cm
        known = [
            (target, position_of[name])
            for target, name in enumerate(fixed_labels)
            if name in position_of
        ]
        for target_row, source_row in known:
            for target_col, source_col in known:
                reordered_cm[target_row, target_col] = cm[source_row, source_col]

        cm, labels = reordered_cm, fixed_labels

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()

In [None]:
# Quick inspection: which metric fields are stored for the fine-tuned Nova Lite model
all_metrics["ft_nova_lite"].keys()

In [None]:
# Shared row/column order for every confusion matrix: the class names in label-id order
fixed_label_order = list(label_mapping.values())

# Plot the confusion matrix of the fine-tuned Nova Lite model ("ft_nova_lite")
plot_confusion_matrix(
    np.array(all_metrics["ft_nova_lite"]["confusion_matrix"]),
    all_metrics["ft_nova_lite"]["labels"],
    "Confusion Matrix: Finetune Nova Lite Model on RVL-CDIP Dataset",
    fixed_labels=fixed_label_order
)

In [None]:
# Confusion matrix of the base Nova Lite model, drawn in the shared fixed_label_order
plot_confusion_matrix(
    np.array(all_metrics["nova_lite"]["confusion_matrix"]),
    all_metrics["nova_lite"]["labels"],
    "Confusion Matrix: Nova Lite Model on RVL-CDIP Dataset",
    fixed_labels=fixed_label_order
)

In [None]:
# Confusion matrix of the Nova Pro model, drawn in the shared fixed_label_order
plot_confusion_matrix(
    np.array(all_metrics["nova_pro"]["confusion_matrix"]),
    all_metrics["nova_pro"]["labels"],
    "Confusion Matrix: Nova Pro Model on RVL-CDIP Dataset",
    fixed_labels=fixed_label_order
)

In [None]:
# Confusion matrix of the Nova Premier model, drawn in the shared fixed_label_order
plot_confusion_matrix(
    np.array(all_metrics["nova_premier"]["confusion_matrix"]),
    all_metrics["nova_premier"]["labels"],
    "Confusion Matrix: Nova Premier Model on RVL-CDIP Dataset",
    fixed_labels=fixed_label_order
)

In [None]:
def plot_model_comparison(all_metrics, palette_name='colorblind'):
    """
    Plot model performance comparison with publication-quality color palette.
    Displays metrics as percentages with 2 decimal places (XX.XX%).

    Parameters:
    -----------
    all_metrics : dict
        Maps a model name to a dict with 'accuracy', 'f1' and 'recall'
        values on a 0-1 scale. Must contain at least one model.
    palette_name : str
        Name of the color palette to use:
        - 'colorblind': Seaborn's colorblind palette (default)
        - 'tableau': Tableau's color palette
        - 'deep': Seaborn's deep palette
        - 'muted': Seaborn's muted palette
        - 'pastel': Seaborn's pastel palette
        - 'bright': Seaborn's bright palette
        - 'viridis', 'plasma', 'inferno', 'magma': Matplotlib's perceptually uniform colormaps
        Any other value falls back to 'colorblind'.

    Returns:
    --------
    matplotlib.figure.Figure
        The figure holding the grouped bar chart.

    Raises:
    -------
    ValueError
        If ``all_metrics`` is empty.
    """
    # Set up metrics and model count
    metrics = ['Accuracy', 'F1 Score', 'Recall']
    num_models = len(all_metrics)
    num_metrics = len(metrics)

    # Without models the bar width below would divide by zero
    if num_models == 0:
        raise ValueError("all_metrics is empty; nothing to plot")

    # Set up the figure
    fig, ax = plt.subplots(figsize=(max(12, num_models * 2), 8))

    # Choose color palette
    if palette_name in ['viridis', 'plasma', 'inferno', 'magma']:
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed
        # from recent Matplotlib releases
        cmap = plt.get_cmap(palette_name)
        colors = [cmap(i / num_models) for i in range(num_models)]
    else:
        if palette_name == 'tableau':
            seaborn_palette = "tab10"
        elif palette_name in ['colorblind', 'deep', 'muted', 'pastel', 'bright']:
            seaborn_palette = palette_name
        else:
            seaborn_palette = "colorblind"  # Default to colorblind
        colors = sns.color_palette(seaborn_palette, n_colors=num_models)

    # Set width of bars based on number of models
    total_width = 0.8
    width = total_width / num_models
    x = np.arange(num_metrics)

    # Plot bars for each model
    for i, (model_name, metrics_data) in enumerate(all_metrics.items()):
        # Convert values to percentages (0-100 scale)
        values = [
            metrics_data['accuracy'] * 100,
            metrics_data['f1'] * 100,
            metrics_data['recall'] * 100
        ]
        # Centre the group of bars on each metric tick
        offset = (i - num_models / 2 + 0.5) * width
        rects = ax.bar(x + offset, values, width, label=model_name, color=colors[i],
                      edgecolor='black', linewidth=0.5)  # Add black edge for better definition

        # Format labels with exactly 2 decimal places (XX.XX%)
        percentage_labels = [f"{v:.2f}%" for v in values]
        ax.bar_label(rects, padding=3, labels=percentage_labels, fontsize=8)

    # Add labels and legend
    ax.set_ylabel('Score (%)', fontsize=12, fontweight='bold')
    ax.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, fontsize=11)

    # Set y-axis to 0-100 range with some extra space for labels
    ax.set_ylim(0, 115)

    # Set y-axis ticks to show percentages
    ax.set_yticks([0, 20, 40, 60, 80, 100])
    ax.set_yticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])

    ax.spines['top'].set_visible(False)  # Remove top spine for cleaner look
    ax.spines['right'].set_visible(False)  # Remove right spine for cleaner look
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
             ncol=min(5, num_models), frameon=True, fontsize=10)

    # Add grid lines for better readability
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()
    return fig

In [None]:
# Display names for the models that get a friendlier label in the chart
display_names = {
    'nova_lite': 'Nova Lite',
    'nova_pro': 'Nova Pro',
    'nova_premier': 'Nova Premier',
    'ft_nova_lite': 'Nova Lite Finetuned',
}

# Build a relabelled copy instead of popping keys out of all_metrics: the cell
# can then be re-run, and cells that look up all_metrics["nova_lite"] etc.
# keep working afterwards. Models without a display name keep their key and
# come first, followed by the renamed ones.
comparison_metrics = {
    name: model_metrics
    for name, model_metrics in all_metrics.items()
    if name not in display_names
}
for model_key, display_name in display_names.items():
    comparison_metrics[display_name] = all_metrics[model_key]

# Trailing semicolon suppresses the repr of the returned figure
plot_model_comparison(comparison_metrics, palette_name='pastel');

## 6. Plot Fine-Tuning Losses

In [None]:
import boto3
import pandas as pd
import matplotlib.pyplot as plt
import io

from dotenv import load_dotenv
load_dotenv()

# Read a CSV object from S3 straight into a DataFrame
def load_csv_from_s3(s3_uri):
    """Download the object at ``s3_uri`` and parse it as CSV.

    Args:
        s3_uri: Location in the form ``s3://<bucket>/<key>``.

    Returns:
        pandas.DataFrame: The parsed CSV content.
    """
    # First path segment is the bucket, the remainder is the object key
    path_parts = s3_uri.replace('s3://', '').split('/')
    bucket_name, key = path_parts[0], '/'.join(path_parts[1:])

    body = boto3.client('s3').get_object(Bucket=bucket_name, Key=key)['Body']
    return pd.read_csv(io.BytesIO(body.read()))

# Extended function to plot losses vs epoch for multiple files
def plot_losses_vs_epoch(file_uris, legends=None):
    """
    Plot the mean loss per epoch for several metrics files on one log-scale chart.

    Args:
        file_uris (list): List of S3 URIs for the CSV files
        legends (list, optional): List of legend labels. If None, will use default labels.

    Raises:
        ValueError: If the number of legends differs from the number of URIs,
            or a file has no usable loss or epoch column.
    """
    plt.figure(figsize=(12, 6))

    if legends is None:
        legends = [f"Data {i+1}" for i in range(len(file_uris))]

    # Ensure we have the same number of legends as files
    if len(legends) != len(file_uris):
        raise ValueError("Number of legends must match number of file URIs")

    colors = ['b', 'r', 'g', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']
    markers = ['o', 's', '^', 'v', '<', '>', 'p', '*', 'h', 'D']

    for i, (uri, legend) in enumerate(zip(file_uris, legends)):
        # Cycle through the styles when there are more files than entries
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]

        df = load_csv_from_s3(uri)

        # Pick the preferred loss column from hints in the URI
        uri_lower = uri.lower()
        if 'train' in uri_lower:
            candidates = ['training_loss', 'loss']
        elif 'val' in uri_lower:
            candidates = ['validation_loss', 'loss']
        else:
            candidates = []
        # Fall back to the first column whose name mentions "loss"
        candidates += [col for col in df.columns if 'loss' in col.lower()]
        loss_col = next((col for col in candidates if col in df.columns), None)
        if loss_col is None:
            raise ValueError(f"Could not find a loss column in the file: {uri}")

        # Ensure epoch column exists
        if 'epoch_number' in df.columns:
            epoch_col = 'epoch_number'
        else:
            epoch_cols = [col for col in df.columns if 'epoch' in col.lower()]
            if not epoch_cols:
                raise ValueError(f"Could not find an epoch column in the file: {uri}")
            epoch_col = epoch_cols[0]

        # Aggregate by epoch
        df_by_epoch = df.groupby(epoch_col)[loss_col].mean().reset_index()

        # Plot the line and the markers separately. Keyword arguments are used
        # because a format string such as 'orange-' is rejected by Matplotlib
        # (only single-letter colours can be combined with a line style).
        plt.plot(df_by_epoch[epoch_col], df_by_epoch[loss_col],
                 color=color, linestyle='-', label=legend, linewidth=2)
        plt.plot(df_by_epoch[epoch_col], df_by_epoch[loss_col],
                 color=color, marker=marker, linestyle='none', alpha=0.7)

    # Add labels and legend
    plt.xlabel('Epoch Number', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Losses vs Epoch Number', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.yscale('log')

    plt.tight_layout()
    plt.show()


# Compare training and validation loss per epoch for three model-customization jobs.
# The URIs and the legend labels are paired by position: each job contributes a
# training-metrics file followed by its validation-metrics file.
# NOTE(review): the legend text (epochs / warmup) is the only place recording which
# job ID used which hyper-parameters — confirm the pairing against the job configs.
plot_losses_vs_epoch(
    ['s3://idp-model-finetune-output-us-east-1/model-customization-job-vc1pkgkffpjp/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-vc1pkgkffpjp/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-joe0bopdzf25/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-joe0bopdzf25/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-b394czjczgj4/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-b394czjczgj4/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv'],
    ['Training Loss (Epoch-3)', 'Validation Loss (Epoch-3)', 'Training Loss (Epoch-5, Warmup-10)', 'Validation Loss (Epoch-5, Warmup-10)', 'Training Loss (Epoch-5, Warmup-0)', 'Validation Loss (Epoch-5, Warmup-0)']
)

In [None]:
import boto3
import pandas as pd
import matplotlib.pyplot as plt
import io

# Fetch a CSV file from S3 and return it as a DataFrame
# (same helper as in the epoch-plot cell, repeated so this cell runs on its own)
def load_csv_from_s3(s3_uri):
    """Load the CSV stored at ``s3://<bucket>/<key>`` into a pandas DataFrame."""
    # Everything before the first "/" is the bucket, everything after is the key
    bucket_name, _, key = s3_uri.replace('s3://', '').partition('/')

    s3_client = boto3.client('s3')
    raw_bytes = s3_client.get_object(Bucket=bucket_name, Key=key)['Body'].read()

    return pd.read_csv(io.BytesIO(raw_bytes))

# Plot raw (un-aggregated) loss curves against the training step for several files
def plot_losses_vs_step(file_uris, legends=None):
    """
    Draw one loss-vs-step curve per metrics file on a shared log-scale chart.

    Args:
        file_uris (list): S3 URIs of the CSV metrics files
        legends (list, optional): One legend label per URI; defaults to
            "Data 1", "Data 2", ...

    Raises:
        ValueError: If the number of legends differs from the number of URIs,
            if a file has no step column, or if a file whose URI mentions
            neither training nor validation has no loss column.
    """
    plt.figure(figsize=(12, 6))

    if legends is None:
        legends = [f"Data {i+1}" for i in range(len(file_uris))]

    if len(legends) != len(file_uris):
        raise ValueError("Number of legends must match number of file URIs")

    palette = ['b', 'r', 'g', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']
    dash_patterns = ['-', '--', '-.', ':', '-', '--', '-.', ':', '-', '--']

    for i, (uri, legend) in enumerate(zip(file_uris, legends)):
        # Styles repeat from the start when there are more files than entries
        color = palette[i % len(palette)]
        dash = dash_patterns[i % len(dash_patterns)]

        df = load_csv_from_s3(uri)
        columns = df.columns

        # Step column: exact name first, otherwise the first name mentioning "step"
        if 'step_number' in columns:
            step_col = 'step_number'
        else:
            step_col = next((col for col in columns if 'step' in col.lower()), None)
            if step_col is None:
                raise ValueError(f"Could not find a step column in the file: {uri}")

        # Loss column: guided by whether the URI looks like training or validation data
        source = uri.lower()
        if 'train' in source:
            loss_col = 'training_loss' if 'training_loss' in columns else 'loss'
        elif 'val' in source:
            loss_col = 'validation_loss' if 'validation_loss' in columns else 'loss'
        else:
            loss_col = next((col for col in columns if 'loss' in col.lower()), None)
            if loss_col is None:
                raise ValueError(f"Could not find a loss column in the file: {uri}")

        steps, losses = df[step_col], df[loss_col]

        # Full curve
        plt.plot(steps, losses, color=color,
                 linestyle=dash, label=legend, linewidth=2)

        # Roughly 20 evenly spaced markers so dense curves stay readable
        stride = max(1, len(df) // 20)
        plt.plot(steps[::stride], losses[::stride],
                 marker='o', color=color, linestyle='none', alpha=0.5, markersize=4)

    # Axis labels, legend and log-scaled y axis
    plt.xlabel('Step Number', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Losses vs Step Number', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.yscale('log')

    plt.tight_layout()
    plt.show()

# Same six files and legend labels as the per-epoch plot, drawn against the step
# number instead. URIs and legends are paired by position.
# NOTE(review): plot_losses_vs_step raises ValueError for any file without a column
# containing "step" — confirm the validation_metrics.csv files provide one.
plot_losses_vs_step(
    ['s3://idp-model-finetune-output-us-east-1/model-customization-job-vc1pkgkffpjp/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-vc1pkgkffpjp/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-joe0bopdzf25/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-joe0bopdzf25/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-b394czjczgj4/training_artifacts/step_wise_training_metrics.csv',
     's3://idp-model-finetune-output-us-east-1/model-customization-job-b394czjczgj4/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv'],
    ['Training Loss (Epoch-3)', 'Validation Loss (Epoch-3)', 'Training Loss (Epoch-5, Warmup-10)', 'Validation Loss (Epoch-5, Warmup-10)', 'Training Loss (Epoch-5, Warmup-0)', 'Validation Loss (Epoch-5, Warmup-0)']
)