<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/02_load_hf_from_s3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import boto3
import json
import os
from contextlib import contextmanager
from io import BytesIO
from tempfile import NamedTemporaryFile, TemporaryDirectory
from transformers import AutoModel, AutoTokenizer, PretrainedConfig
from typing import Tuple, Optional, Union, List, Dict
from pathlib import Path

@contextmanager
def s3_fileobj(bucket: str, key: str, endpoint_url: Optional[str] = None):
    """
    Yield an in-memory file object holding the contents of {bucket}/{key}.

    The whole object is downloaded into memory before it is yielded, so the
    yielded ``BytesIO`` is seekable and independent of the S3 connection.

    Args:
        bucket (str): Name of the S3 bucket where your model is stored
        key (str): Path to the file within the bucket
        endpoint_url (str, optional): Custom endpoint URL for S3-compatible storage

    Yields:
        BytesIO: Buffer positioned at the start of the object's bytes.

    Raises:
        Exception: If the object cannot be fetched or read; the underlying
            error is attached as ``__cause__``.
    """
    s3 = boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        # Temporary (STS) credentials are only valid together with their
        # session token; None falls back to long-lived key behaviour.
        aws_session_token=os.environ.get('AWS_SESSION_TOKEN'),
    )
    try:
        obj = s3.get_object(Bucket=bucket, Key=key)
        body = obj["Body"]
        try:
            payload = body.read()
        finally:
            # Release the underlying HTTP connection back to the pool.
            body.close()
    except Exception as e:
        raise Exception(f"Failed to read {key} from bucket {bucket}: {str(e)}") from e

    # Yield outside the try block: exceptions raised inside the caller's
    # ``with`` body are re-raised at this yield and must propagate unchanged
    # instead of being reported as an S3 read failure.
    yield BytesIO(payload)

def list_s3_files(bucket: str, prefix: str, endpoint_url: Optional[str] = None) -> List[str]:
    """
    List the keys of all objects under a prefix in an S3 bucket.

    Args:
        bucket (str): Name of the S3 bucket to list
        prefix (str): Key prefix to match; this is a plain string prefix, so
            "models/foo" also matches keys under "models/foo-bar/"
        endpoint_url (str, optional): Custom endpoint URL for S3-compatible storage

    Returns:
        list: Every matching object key, collected across all result pages.
    """
    s3 = boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        # Temporary (STS) credentials are only valid together with their
        # session token; None falls back to long-lived key behaviour.
        aws_session_token=os.environ.get('AWS_SESSION_TOKEN'),
    )

    files = []
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Pages with no matches carry no 'Contents' entry at all.
        files.extend(obj['Key'] for obj in page.get('Contents', []))
    return files

# File names (matched against the key's basename) that belong to a tokenizer.
_TOKENIZER_FILENAMES = frozenset({
    'tokenizer.json',
    'tokenizer_config.json',
    'special_tokens_map.json',
    'vocab.json',
    'merges.txt',
    'vocab.txt',
    'added_tokens.json',
    'tokenizer.model',
    'sentencepiece.bpe.model',
})

def _download_s3_key(bucket: str, key: str, target_path: Path, endpoint_url: Optional[str] = None) -> None:
    """Copy the object at {bucket}/{key} to target_path, creating parent directories."""
    target_path.parent.mkdir(parents=True, exist_ok=True)
    with s3_fileobj(bucket, key, endpoint_url) as src, open(target_path, 'wb') as out:
        out.write(src.read())

def load_model_from_s3(
    bucket: str,
    path_to_model: str,
    endpoint_url: Optional[str] = None,
    force_bin: bool = False,
) -> Tuple[Union[AutoModel, None], Union[AutoTokenizer, None]]:
    """
    Load a model and tokenizer from S3 storage, preferring .safetensors format.

    The model directory is mirrored into a temporary local directory (weights,
    shard index files, config.json and tokenizer files), loaded from there, and
    the temporary directory is removed before returning.

    Args:
        bucket (str): S3 bucket name
        path_to_model (str): Path to model directory in bucket
        endpoint_url (str, optional): Custom S3 endpoint URL
        force_bin (bool): Force using .bin format even if .safetensors is available

    Returns:
        tuple: (model, tokenizer); tokenizer is None when no tokenizer files
        were found under the model directory.

    Raises:
        Exception: If no usable weight files exist, a download fails, or the
            model/tokenizer cannot be loaded from the downloaded files.
    """
    # List all files in the model directory
    files = list_s3_files(bucket, path_to_model, endpoint_url)

    # Map each key to its path relative to the model directory. Listing is a
    # plain string-prefix match, so keys from sibling directories such as
    # "<path_to_model>-v2/..." can appear; those are skipped.
    relative_paths: Dict[str, Path] = {}
    for key in files:
        try:
            relative_paths[key] = Path(key).relative_to(path_to_model)
        except ValueError:
            continue

    def keys_named(predicate):
        """Return the keys whose file name satisfies predicate."""
        return [key for key, rel in relative_paths.items() if predicate(rel.name)]

    safetensors_files = keys_named(lambda name: name.endswith('.safetensors'))
    bin_files = keys_named(lambda name: name.endswith('.bin'))

    # Only ask transformers for safetensors when safetensors weights are
    # actually being downloaded; otherwise the .bin fallback could never load.
    use_safetensors = bool(safetensors_files) and not force_bin
    if use_safetensors:
        print("Using safetensors format for model loading")
        # Sharded checkpoints need their index file to locate each shard.
        weight_files = safetensors_files + keys_named(
            lambda name: name.endswith('.safetensors.index.json'))
    elif bin_files:
        # Fallback to .bin format if no safetensors or force_bin=True
        print("Using .bin format for model loading")
        weight_files = bin_files + keys_named(
            lambda name: name.endswith('.bin.index.json'))
    else:
        wanted = ".bin" if force_bin else ".safetensors or .bin"
        raise Exception(
            f"Failed to load model: no {wanted} weight files found under "
            f"{path_to_model} in bucket {bucket}"
        )

    # Match config.json by exact file name (a suffix match would also pick up
    # tokenizer_config.json and friends), preferring the shallowest copy.
    config_file = min(
        keys_named(lambda name: name == 'config.json'),
        key=lambda key: len(relative_paths[key].parts),
        default=None,
    )

    tokenizer_files = keys_named(lambda name: name in _TOKENIZER_FILENAMES)

    with TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        # config.json always goes to the root, where from_pretrained looks for it.
        if config_file is not None:
            _download_s3_key(bucket, config_file, temp_path / 'config.json', endpoint_url)

        # Everything else keeps its directory structure relative to the model root.
        for key in weight_files + tokenizer_files:
            _download_s3_key(bucket, key, temp_path / relative_paths[key], endpoint_url)

        # Load the model and tokenizer
        try:
            model = AutoModel.from_pretrained(
                temp_path,
                local_files_only=True,
                use_safetensors=use_safetensors
            )

            tokenizer = AutoTokenizer.from_pretrained(
                temp_path,
                local_files_only=True
            ) if tokenizer_files else None

            return model, tokenizer

        except Exception as e:
            raise Exception(f"Failed to load model: {str(e)}") from e

In [None]:
# Default behaviour: safetensors weights are used when the bucket has them.
# The bucket, model path and endpoint below are placeholders.
model, tokenizer = load_model_from_s3(
    "my-bucket",
    "models/stable-diffusion-v1-5",
    endpoint_url="https://your-storagegrid-endpoint",
)

# Same model, but insist on the .bin weights. No endpoint_url is given here,
# so the client falls back to its default S3 endpoint.
model, tokenizer = load_model_from_s3(
    "my-bucket",
    "models/stable-diffusion-v1-5",
    force_bin=True,
)