<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/02_load_hf_from_s3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import boto3
from botocore.config import Config
import json
import os
import re
from contextlib import contextmanager
from io import BytesIO
from tempfile import TemporaryDirectory
from transformers import AutoModel, AutoTokenizer, PretrainedConfig
from typing import Tuple, Optional, Union, List, Dict
from pathlib import Path
import urllib3
import warnings
from IPython.display import display, HTML
from datetime import datetime

class NotebookLogger:
    """Minimal logger that renders colored, timestamped lines in a notebook.

    Each message is emitted through ``IPython.display`` as a ``<pre>`` element
    whose text color is chosen by log level (see ``COLORS``). Debug messages
    are only shown when ``enable_debug`` is true.
    """

    # Level name -> CSS color used for the rendered line.
    COLORS = {
        'INFO': '#0066cc',
        'DEBUG': '#666666',
        'WARNING': '#ff9900',
        'ERROR': '#cc0000',
        'SUCCESS': '#009933'
    }

    def __init__(self, enable_debug=False):
        """Create a logger; ``enable_debug`` turns ``debug_log`` output on."""
        self.enable_debug = enable_debug

    def _render(self, level: str, message: str) -> str:
        """Return the HTML snippet for one log line.

        The message is HTML-escaped so that text containing ``<``, ``>`` or
        ``&`` (exception messages, ``repr`` of classes, S3 keys, ...) is shown
        literally instead of being interpreted as markup.
        """
        from html import escape  # local import keeps the class self-contained

        timestamp = datetime.now().strftime('%H:%M:%S')
        # Unknown levels fall back to black.
        color = self.COLORS.get(level, '#000000')
        safe_message = escape(str(message), quote=False)
        return (
            f'<pre style="margin:0; padding:2px 0; color: {color}">'
            f'[{timestamp}] {level}: {safe_message}'
            '</pre>'
        )

    def _log(self, level: str, message: str):
        """Render ``message`` at ``level`` and display it in the notebook."""
        display(HTML(self._render(level, message)))

    def info(self, message: str):
        """Log an informational message."""
        self._log('INFO', message)

    def debug_log(self, message: str):
        """Log a debug message; a no-op unless ``enable_debug`` is set."""
        if self.enable_debug:
            self._log('DEBUG', message)

    def warning(self, message: str):
        """Log a warning message."""
        self._log('WARNING', message)

    def error(self, message: str):
        """Log an error message."""
        self._log('ERROR', message)

    def success(self, message: str):
        """Log a success message."""
        self._log('SUCCESS', message)

def get_s3_client(endpoint_url: Optional[str] = None, verify_ssl: bool = True):
    """Build a boto3 S3 client, optionally with TLS verification turned off.

    Args:
        endpoint_url: Custom S3-compatible endpoint; ``None`` is passed
            through to boto3 unchanged.
        verify_ssl: Forwarded as boto3's ``verify`` argument. When false,
            urllib3's ``InsecureRequestWarning`` is also silenced.

    Returns:
        The client object returned by ``boto3.client('s3', ...)``.
    """
    if not verify_ssl:
        # Avoid a warning per request once verification is deliberately off.
        warnings.filterwarnings('ignore', category=urllib3.exceptions.InsecureRequestWarning)

    # Credentials are read from the environment; missing variables yield None.
    client_kwargs = {
        'endpoint_url': endpoint_url,
        'aws_access_key_id': os.environ.get('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': os.environ.get('AWS_SECRET_ACCESS_KEY'),
        'verify': verify_ssl,
        'config': Config(retries={'max_attempts': 3}),
    }
    return boto3.client('s3', **client_kwargs)

def find_model_files(files: List[str], logger: "NotebookLogger") -> Tuple[List[str], bool]:
    """Pick the safetensors weight files out of a list of object keys.

    Two layouts are recognised, matched on the key's base name only:

    - ``model.safetensors`` (single file)
    - ``model-00001-of-00002.safetensors`` (sharded)

    The base name must match exactly, so look-alikes such as
    ``adapter_model.safetensors`` are not mistaken for the main weights.

    Args:
        files: Object keys (paths) to inspect.
        logger: Object exposing ``debug_log(message)``.

    Returns:
        Tuple of (matching keys, is_sharded). The single-file layout takes
        precedence; sharded keys are returned sorted. ``([], False)`` when
        nothing matches.
    """
    # First look for a single safetensors file.
    single_file = [f for f in files if os.path.basename(f) == 'model.safetensors']
    if single_file:
        logger.debug_log("Found single safetensors file")
        return single_file, False

    # Look for sharded files; fullmatch rejects prefixed names like
    # "xmodel-00001-of-00002.safetensors".
    sharded_pattern = re.compile(r'model-\d{5}-of-\d{5}\.safetensors')
    sharded_files = [f for f in files if sharded_pattern.fullmatch(os.path.basename(f))]

    if sharded_files:
        logger.debug_log(f"Found {len(sharded_files)} sharded safetensors files")
        # Sort to ensure consistent ordering
        return sorted(sharded_files), True

    logger.debug_log("No safetensors files found")
    return [], False

def load_model_from_s3(
    bucket: str,
    path_to_model: str,
    endpoint_url: Optional[str] = None,
    verify_ssl: bool = True,
    force_bin: bool = False,
    enable_debug: bool = False,
    model_class: Optional[type] = None
) -> Tuple[Union[AutoModel, None], Union[AutoTokenizer, None]]:
    """
    Load a model and tokenizer from S3 storage, supporting both single and sharded safetensors.

    The model directory is mirrored into a temporary directory (config,
    weights, weight index and tokenizer files) and then loaded with
    ``from_pretrained(..., local_files_only=True)``.

    Args:
        bucket (str): S3 bucket name
        path_to_model (str): Path to model directory in bucket
        endpoint_url (str, optional): Custom S3 endpoint URL
        verify_ssl (bool): Whether to verify SSL certificates
        force_bin (bool): Force using .bin format even if .safetensors is available
        enable_debug (bool): Enable detailed debug logging
        model_class (type, optional): Specific model class to use (e.g., AutoModelForCausalLM)

    Returns:
        Tuple of (model, tokenizer). The tokenizer is ``None`` when no
        tokenizer files exist or when loading it fails (a warning is logged).

    Raises:
        FileNotFoundError: If no ``config.json`` or no weight files are found.
        Exception: Any error raised while loading the model is logged and
            re-raised unchanged.
    """
    logger = NotebookLogger(enable_debug=enable_debug)
    logger.info(f"Starting model load from bucket: {bucket}, path: {path_to_model}")

    s3_client = get_s3_client(endpoint_url, verify_ssl)

    # Normalise to "dir/" so that sibling prefixes (e.g. "my-model-v2/") are
    # not listed along with the requested directory.
    model_root = path_to_model.rstrip('/')
    list_prefix = f"{model_root}/" if model_root else ''

    # List all files in the model directory
    logger.info("Listing files in bucket...")
    files = []
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=list_prefix):
        # Keys ending in "/" are zero-byte directory markers, not files.
        files.extend(
            obj['Key'] for obj in page.get('Contents', [])
            if not obj['Key'].endswith('/')
        )

    logger.debug_log(f"Found {len(files)} total files")
    for file in files:
        logger.debug_log(f"Found file: {file}")

    # Find model files
    safetensors_files, is_sharded = find_model_files(files, logger)
    bin_files = [f for f in files if f.endswith('.bin')]

    logger.info(f"Found {len(safetensors_files)} safetensor files (sharded: {is_sharded}) and {len(bin_files)} bin files")

    with TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        logger.debug_log(f"Created temporary directory: {temp_dir}")

        def fetch(key: str, label: str) -> Path:
            """Download ``key`` into the temp dir, keeping its relative path."""
            return _download_s3_key(s3_client, bucket, key, list_prefix, temp_path, logger, label)

        # Download config.json first. Match on the exact base name so that
        # e.g. adapter_config.json / tokenizer_config.json are never picked.
        config_file = _shallowest_key_named(files, 'config.json')
        if config_file is None:
            logger.error("No config.json found!")
            raise FileNotFoundError("No config.json found in model directory")

        config_path = temp_path / 'config.json'
        logger.info(f"Downloading config file: {config_file}")
        s3_client.download_file(bucket, config_file, str(config_path))
        logger.success("Config file downloaded successfully")

        # Read config to check model and tokenizer type
        with open(config_path, encoding='utf-8') as f:
            config_data = json.load(f)
        model_type = config_data.get('model_type', '').lower()
        logger.debug_log(f"Detected model type: {model_type}")

        # Handle model weights. `use_safetensors` records the format that was
        # actually downloaded so from_pretrained is not asked for a format
        # that is absent from the temp dir.
        if safetensors_files and not force_bin:
            logger.info("Using safetensors format for model loading")
            use_safetensors = True
            shard_paths = [fetch(key, "safetensors") for key in safetensors_files]

            if is_sharded:
                index_name = "model.safetensors.index.json"
                index_key = _shallowest_key_named(files, index_name)
                if index_key is not None:
                    # Prefer the authoritative index stored next to the shards.
                    fetch(index_key, "index")
                else:
                    logger.warning(f"No {index_name} found; rebuilding it from shard headers")
                    index_data = _build_safetensors_index(shard_paths, temp_path)
                    with open(temp_path / index_name, "w", encoding='utf-8') as f:
                        json.dump(index_data, f)
                    logger.debug_log("Created index file for sharded safetensors")

        elif bin_files:
            logger.info("Using .bin format for model loading")
            use_safetensors = False
            for key in bin_files:
                fetch(key, "bin")

            # Sharded .bin checkpoints need their index to be loadable.
            bin_index_key = _shallowest_key_named(files, 'pytorch_model.bin.index.json')
            if bin_index_key is not None:
                fetch(bin_index_key, "index")
        else:
            logger.error("No model weights files found!")
            raise FileNotFoundError("No model weights files (safetensors or bin) found")

        # Download tokenizer files (matched on exact base name)
        tokenizer_names = {
            'tokenizer.json',
            'tokenizer_config.json',
            'special_tokens_map.json',
            'added_tokens.json',
            'vocab.json',
            'vocab.txt',
            'merges.txt',
            'tokenizer.model',
            'sentencepiece.bpe.model',
        }
        tokenizer_files = [f for f in files if os.path.basename(f) in tokenizer_names]

        has_tokenizer_files = bool(tokenizer_files)
        if has_tokenizer_files:
            logger.info(f"Found {len(tokenizer_files)} tokenizer files")
            for key in tokenizer_files:
                fetch(key, "tokenizer")
        else:
            logger.warning("No tokenizer files found")

        # Debug: list all files in temp directory
        logger.debug_log("Files in temporary directory:")
        for file in temp_path.rglob('*'):
            if file.is_file():
                logger.debug_log(f"  {file.relative_to(temp_dir)}")

        # Load the model and tokenizer
        try:
            logger.info("Loading model from temporary directory")

            if model_class:
                ModelClass = model_class
            elif model_type == 'llama':
                from transformers import LlamaForCausalLM
                ModelClass = LlamaForCausalLM
            else:
                ModelClass = AutoModel

            model = ModelClass.from_pretrained(
                str(temp_path),
                local_files_only=True,
                use_safetensors=use_safetensors,
            )
            logger.success("Model loaded successfully")

            tokenizer = None
            if has_tokenizer_files:
                # A tokenizer failure is deliberately non-fatal: the model is
                # still returned and the problem is reported as a warning.
                try:
                    logger.info("Loading tokenizer")

                    # For Llama models, always use LlamaTokenizerFast
                    if model_type == 'llama':
                        from transformers import LlamaTokenizerFast
                        tokenizer = LlamaTokenizerFast.from_pretrained(
                            str(temp_path),
                            local_files_only=True
                        )
                    else:
                        tokenizer = AutoTokenizer.from_pretrained(
                            str(temp_path),
                            local_files_only=True
                        )
                    logger.success("Tokenizer loaded successfully")
                except Exception as e:
                    logger.warning(f"Failed to load tokenizer: {str(e)}")

            return model, tokenizer

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise


def _shallowest_key_named(keys: List[str], filename: str) -> Optional[str]:
    """Return the key with base name ``filename`` nearest the root, or None."""
    matches = [key for key in keys if os.path.basename(key) == filename]
    return min(matches, key=lambda key: key.count('/'), default=None)


def _download_s3_key(s3_client, bucket: str, key: str, list_prefix: str,
                     dest_dir: Path, logger, label: str) -> Path:
    """Stream one S3 object to ``dest_dir`` under its path relative to ``list_prefix``."""
    relative_key = key[len(list_prefix):]
    target_path = dest_dir / relative_key
    target_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading {label} file: {key}")
    # download_file writes straight to disk instead of buffering the whole
    # (potentially multi-GB) object in memory.
    s3_client.download_file(bucket, key, str(target_path))
    logger.success(f"Downloaded: {relative_key}")
    return target_path


def _build_safetensors_index(shard_paths: List[Path], model_dir: Path) -> Dict:
    """Build a ``model.safetensors.index.json`` payload from shard headers.

    A safetensors file starts with an 8-byte little-endian header length
    followed by a JSON header mapping tensor names to their metadata
    (including ``data_offsets``), plus an optional ``__metadata__`` entry.
    """
    import struct

    weight_map = {}
    total_size = 0
    for shard in shard_paths:
        with open(shard, 'rb') as fh:
            (header_len,) = struct.unpack('<Q', fh.read(8))
            header = json.loads(fh.read(header_len))

        shard_name = shard.relative_to(model_dir).as_posix()
        for tensor_name, info in header.items():
            if tensor_name == '__metadata__':
                continue
            weight_map[tensor_name] = shard_name
            begin, end = info['data_offsets']
            total_size += end - begin

    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}

ModuleNotFoundError: No module named 'boto3'

In [None]:

    # ... [previous code remains the same until tokenizer loading] ...
    # NOTE(review): this cell is a fragment, not a runnable definition. It
    # appears to sketch an alternative tail for load_model_from_s3 and relies
    # on names (logger, temp_path, force_bin, model_class, has_tokenizer_files,
    # tokenizer_files) that are not defined in this cell.

    # Load the model and tokenizer
    try:
        logger.info("Loading model from temporary directory")
        if model_class:
            logger.debug_log(f"Using specific model class: {model_class}")
            # NOTE(review): model_class is compared to a *string* here; a class
            # object passed as model_class would never equal it - confirm the
            # intended type.
            if model_class == "LlamaForCausalLM":
                from transformers import LlamaForCausalLM
                model = LlamaForCausalLM.from_pretrained(
                    str(temp_path),
                    local_files_only=True,
                    use_safetensors=not force_bin,
                )
            # Add other model classes as needed
            # NOTE(review): for any other truthy model_class nothing in this
            # fragment assigns `model`, yet the success message below is still
            # logged and `return model, tokenizer` would then fail.
        else:
            # No explicit class requested: fall back to AutoModel.
            model = AutoModel.from_pretrained(
                str(temp_path),
                local_files_only=True,
                use_safetensors=not force_bin,
            )
        logger.success("Model loaded successfully")

        tokenizer = None
        if has_tokenizer_files:
            # Tokenizer loading is best-effort: failures are logged as a
            # warning and the tokenizer stays None.
            try:
                logger.info("Loading tokenizer")
                # Special handling for Llama models
                # NOTE(review): `tokenizer_type` is not assigned anywhere in the
                # visible notebook - presumably it was meant to be read from
                # tokenizer_config.json; TODO confirm.
                if 'LlamaTokenizer' in tokenizer_type or any(f.endswith('tokenizer.model') for f in tokenizer_files):
                    logger.debug_log("Using LlamaTokenizerFast")
                    from transformers import LlamaTokenizerFast
                    tokenizer = LlamaTokenizerFast.from_pretrained(
                        str(temp_path),
                        local_files_only=True
                    )
                else:
                    tokenizer = AutoTokenizer.from_pretrained(
                        str(temp_path),
                        local_files_only=True
                    )
                logger.success("Tokenizer loaded successfully")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer: {str(e)}")

        return model, tokenizer

    except Exception as e:
        # Log the failure, then re-raise so the caller still sees it.
        logger.error(f"Failed to load model: {str(e)}")
        raise

In [None]:
# Demo: load a model from an S3-compatible endpoint with verbose logging.
# Any failure is reported on stdout instead of propagating out of the cell.
try:
    model, tokenizer = load_model_from_s3(
        endpoint_url="https://my-storage-endpoint",
        bucket="my-bucket",
        path_to_model="models/my-model",
        enable_debug=True,
        verify_ssl=False,
    )
except Exception as e:
    print(f"Failed to load model: {e}")