<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/02_load_hf_from_s3_updated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Install the packages imported by the next cell; %%capture hides pip's output.
!pip install boto3 transformers torch

In [None]:
import html
import json
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Tuple, Optional, Union, List, Dict, Any

import boto3
import urllib3
from botocore.config import Config
from IPython.display import display, HTML
from transformers import (
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    AutoProcessor,
    LlamaTokenizerFast,
    LlamaForCausalLM,
    SeamlessM4Tv2Model
)

class NotebookLogger:
    """Custom logger for Jupyter notebooks with colored output.

    Every message is rendered as one timestamped ``<pre>`` line via
    ``display(HTML(...))`` and coloured by level. Debug messages are shown
    only when ``enable_debug`` is true.
    """

    # CSS colour per level; levels not listed here render in black.
    COLORS = {
        'INFO': '#0066cc',
        'DEBUG': '#666666',
        'WARNING': '#ff9900',
        'ERROR': '#cc0000',
        'SUCCESS': '#009933'
    }

    def __init__(self, enable_debug=False):
        self.enable_debug = enable_debug

    def _format(self, level: str, message: str) -> str:
        """Return the HTML snippet for a single log line.

        ``level`` and ``message`` are HTML-escaped. Messages often embed
        exception text or reprs such as ``<class 'ValueError'>``. Unescaped,
        those would be parsed as tags and silently vanish from the output.
        """
        timestamp = datetime.now().strftime('%H:%M:%S')
        color = self.COLORS.get(level, '#000000')
        return (
            f'<pre style="margin:0; padding:2px 0; color: {color}">'
            f'[{timestamp}] {html.escape(str(level))}: {html.escape(str(message))}'
            '</pre>'
        )

    def _log(self, level: str, message: str):
        """Display one formatted log line in the notebook output."""
        display(HTML(self._format(level, message)))

    def info(self, message: str):
        self._log('INFO', message)

    def debug_log(self, message: str):
        """Log at DEBUG level; does nothing unless ``enable_debug`` is set."""
        if self.enable_debug:
            self._log('DEBUG', message)

    def warning(self, message: str):
        self._log('WARNING', message)

    def error(self, message: str):
        self._log('ERROR', message)

    def success(self, message: str):
        self._log('SUCCESS', message)

def get_s3_client(endpoint_url: Optional[str] = None, verify_ssl: bool = True):
    """Return a boto3 S3 client for ``endpoint_url``.

    Credentials come from the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY
    environment variables; each is ``None`` when unset. The client retries
    requests up to 3 attempts.

    When ``verify_ssl`` is falsy, urllib3's InsecureRequestWarning is added
    to the process-wide ``warnings`` ignore filter before the client is built.
    """
    if not verify_ssl:
        # Unverified HTTPS connections would otherwise surface urllib3 warnings.
        warnings.filterwarnings('ignore', category=urllib3.exceptions.InsecureRequestWarning)

    retry_config = Config(retries={'max_attempts': 3})
    env = os.environ
    client_kwargs = {
        'endpoint_url': endpoint_url,
        'aws_access_key_id': env.get('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': env.get('AWS_SECRET_ACCESS_KEY'),
        'verify': verify_ssl,
        'config': retry_config,
    }
    return boto3.client('s3', **client_kwargs)

# Whole-basename match for sharded safetensors weights,
# e.g. "model-00001-of-00004.safetensors" (5-digit shard index and total).
_SHARDED_SAFETENSORS_RE = re.compile(r'model-\d{5}-of-\d{5}\.safetensors')


def find_model_files(files: List[str], logger: "NotebookLogger") -> Tuple[List[str], bool]:
    """Find model files and determine if they're sharded.

    Args:
        files: S3 keys (or paths) to inspect.
        logger: Object providing ``debug_log(message)``.

    Returns:
        ``(matches, is_sharded)``. The matches are the keys whose basename is
        exactly ``model.safetensors`` (``is_sharded`` False) if any exist.
        Otherwise they are the sorted keys whose basename is a shard name such
        as ``model-00001-of-00002.safetensors`` (``is_sharded`` True), or
        ``([], False)`` when neither kind is present.
    """
    # Exact basename comparison. A suffix test would also accept unrelated
    # files such as "adapter_model.safetensors".
    single_file = [f for f in files if os.path.basename(f) == 'model.safetensors']
    if single_file:
        logger.debug_log("Found single safetensors file")
        return single_file, False

    # fullmatch anchors both ends, so "my_model-00001-of-00002.safetensors" is rejected.
    sharded_files = [
        f for f in files if _SHARDED_SAFETENSORS_RE.fullmatch(os.path.basename(f))
    ]
    if sharded_files:
        logger.debug_log(f"Found {len(sharded_files)} sharded safetensors files")
        return sorted(sharded_files), True

    logger.debug_log("No safetensors files found")
    return [], False

def get_model_class(model_type: str, model_class: Optional[str], logger: NotebookLogger):
    """
    Determine the appropriate model class based on model type and class name.

    Resolution order:
      1. If ``model_class`` is given, it is resolved by name. The classes
         imported in this notebook are checked first, then any attribute of
         the ``transformers`` package that has ``from_pretrained``. An
         unresolvable name logs a warning and falls back to ``AutoModel``.
      2. Otherwise ``model_type`` is looked up in a small mapping of known
         architectures.
      3. Anything else uses ``AutoModel``.

    Args:
        model_type: Lower-cased ``model_type`` read from config.json.
        model_class: Optional transformers class name, e.g. "LlamaForCausalLM".
        logger: Logger used for debug and warning messages.

    Returns:
        A model class whose ``from_pretrained`` is used to load the checkpoint.
    """
    if model_class:
        resolved = {
            "LlamaForCausalLM": LlamaForCausalLM,
            "SeamlessM4Tv2Model": SeamlessM4Tv2Model,
            "AutoModel": AutoModel,
        }.get(model_class)
        if resolved is None:
            # Accept any other public transformers model class by name
            # (previously every non-Llama name was silently replaced by AutoModel).
            import transformers
            try:
                candidate = getattr(transformers, model_class, None)
            except (ImportError, RuntimeError):
                # transformers' lazy module can fail while importing the attribute.
                candidate = None
            if callable(getattr(candidate, "from_pretrained", None)):
                resolved = candidate
        if resolved is None:
            logger.warning(f"Unknown model class '{model_class}', falling back to AutoModel")
            return AutoModel
        logger.debug_log(f"Using requested model class: {resolved.__name__}")
        return resolved

    MODEL_CLASS_MAPPING = {
        "llama": LlamaForCausalLM,
        "seamless_m4t_v2": SeamlessM4Tv2Model,
    }

    mapped_class = MODEL_CLASS_MAPPING.get(model_type)
    if mapped_class:
        logger.debug_log(f"Using specific model class: {mapped_class.__name__}")
        return mapped_class

    logger.debug_log("Using default AutoModel class")
    return AutoModel

# Suffixes, matched with str.endswith on the full S3 key, of the non-weight
# files mirrored for tokenizer, processor and generation setup. The last four
# are extra tokenizer assets that some checkpoints ship.
_AUX_FILE_SUFFIXES = (
    'tokenizer.json',
    'tokenizer_config.json',
    'special_tokens_map.json',
    'vocab.json',
    'merges.txt',
    'tokenizer.model',
    'processor_config.json',
    'preprocessor_config.json',
    'generation_config.json',
    'added_tokens.json',
    'sentencepiece.bpe.model',
    'spiece.model',
    'chat_template.json',
)

# Top-level PyTorch weight names: "pytorch_model.bin" or sharded
# "pytorch_model-00001-of-00002.bin". Other *.bin files (e.g.
# training_args.bin) are not model weights and are not downloaded.
_BIN_WEIGHTS_RE = re.compile(r'pytorch_model(?:-\d{5}-of-\d{5})?\.bin')


def _s3_prefix(path_to_model: str) -> str:
    """Return ``path_to_model`` as a directory prefix: '' or ending in '/'."""
    base = path_to_model.rstrip('/')
    return f"{base}/" if base else ""


def _download_s3_object(s3_client, bucket: str, key: str, target_path: Path) -> None:
    """Stream the S3 object ``key`` to ``target_path``, creating parent dirs."""
    target_path.parent.mkdir(parents=True, exist_ok=True)
    body = s3_client.get_object(Bucket=bucket, Key=key)['Body']
    try:
        with open(target_path, 'wb') as out:
            # Copy in 8 MiB chunks. Weight shards can be several GB, and a
            # single .read() would hold the whole object in memory.
            shutil.copyfileobj(body, out, 8 * 1024 * 1024)
    finally:
        body.close()


def _safetensors_index(shard_paths: List[Path]) -> Dict[str, Any]:
    """Build a ``model.safetensors.index.json`` payload from local shard headers.

    Each safetensors file starts with an 8-byte little-endian header length.
    That is followed by a JSON header mapping tensor names to their
    dtype/shape/data_offsets, plus an optional ``__metadata__`` entry. The
    weight map can therefore be rebuilt without reading any tensor data.
    """
    weight_map: Dict[str, str] = {}
    total_size = 0
    for shard in shard_paths:
        with open(shard, 'rb') as fh:
            header_len = int.from_bytes(fh.read(8), 'little')
            header = json.loads(fh.read(header_len))
        for tensor_name, info in header.items():
            if tensor_name == '__metadata__':
                continue
            weight_map[tensor_name] = shard.name
            begin, end = info['data_offsets']
            total_size += end - begin
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


def load_model_from_s3(
    bucket: str,
    path_to_model: str,
    endpoint_url: Optional[str] = None,
    verify_ssl: bool = True,
    force_bin: bool = False,
    enable_debug: bool = False,
    model_class: Optional[str] = None
) -> Tuple[Union[AutoModel, None], Union[Any, None]]:
    """
    Load a model and its tokenizer/processor from S3 storage.

    The model directory is mirrored into a temporary local directory. That
    copy holds the config, the chosen weight files, the shard index if
    needed, and tokenizer/processor files. It is then loaded with
    ``from_pretrained(..., local_files_only=True)``.

    Args:
        bucket: S3 bucket name
        path_to_model: Path to model directory in bucket (trailing '/' optional)
        endpoint_url: Custom S3 endpoint URL
        verify_ssl: Whether to verify SSL certificates
        force_bin: Force using .bin format even if safetensors available
        enable_debug: Enable detailed debug logging
        model_class: Specific model class to use (e.g., "LlamaForCausalLM")

    Returns:
        Tuple of (model, tokenizer_or_processor). The second item is None if
        the tokenizer/processor failed to load; a warning is logged.

    Raises:
        FileNotFoundError: If the directory has no top-level config.json or
            no usable weight files.
        Exception: Any error raised while loading the model. It is logged,
            then re-raised.
    """
    logger = NotebookLogger(enable_debug=enable_debug)
    logger.info(f"Starting model load from bucket: {bucket}, path: {path_to_model}")

    s3_client = get_s3_client(endpoint_url, verify_ssl)

    # Treat path_to_model as a directory. A bare prefix like "models/llama"
    # would also list sibling keys such as "models/llama-v2/...".
    prefix = _s3_prefix(path_to_model)

    logger.info("Listing files in bucket...")
    files = []
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        files.extend(obj['Key'] for obj in page.get('Contents', []))
    # Drop "folder" placeholder keys (ending in '/'); they are not files.
    files = [f for f in files if not f.endswith('/')]
    logger.debug_log(f"Found {len(files)} total files")

    def local_name(key: str) -> str:
        """Path of an S3 key relative to the model directory."""
        return key[len(prefix):]

    # from_pretrained() below is given the directory root with no subfolder.
    # Config, weights and index are therefore taken from top-level keys only,
    # which ignores e.g. checkpoint-*/ sub-directories.
    top_level = {local_name(f): f for f in files if '/' not in local_name(f)}

    # Exact name match: a suffix test could pick adapter_config.json etc.
    config_key = top_level.get('config.json')
    if config_key is None:
        raise FileNotFoundError("No config.json found in model directory")

    safetensors_files, _ = find_model_files(list(top_level.values()), logger)
    bin_files = sorted(
        key for name, key in top_level.items() if _BIN_WEIGHTS_RE.fullmatch(name)
    )

    # use_safetensors must describe the format actually downloaded. Passing
    # True for a .bin-only checkpoint would make from_pretrained look for
    # safetensors files that are not there.
    use_safetensors = bool(safetensors_files) and not force_bin
    if use_safetensors:
        logger.info("Using safetensors format")
        weight_files = safetensors_files
        index_name = 'model.safetensors.index.json'
    elif bin_files:
        logger.info("Using .bin format")
        weight_files = bin_files
        index_name = 'pytorch_model.bin.index.json'
    else:
        raise FileNotFoundError("No model weights files found")

    aux_files = [f for f in files if f.endswith(_AUX_FILE_SUFFIXES)]

    with TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        logger.debug_log(f"Created temporary directory: {temp_dir}")

        # Download and read config
        config_path = temp_path / 'config.json'
        _download_s3_object(s3_client, bucket, config_key, config_path)
        with open(config_path, encoding='utf-8') as f:
            config_data = json.load(f)
        # `or ''` also covers an explicit "model_type": null.
        model_type = str(config_data.get('model_type') or '').lower()
        logger.debug_log(f"Detected model type: {model_type}")

        # Handle model weights
        for key in weight_files:
            _download_s3_object(s3_client, bucket, key, temp_path / local_name(key))

        # Sharded checkpoints need an index that maps each tensor name to its shard.
        if any('-of-' in local_name(key) for key in weight_files):
            index_key = top_level.get(index_name)
            if index_key is not None:
                _download_s3_object(s3_client, bucket, index_key, temp_path / index_name)
            elif use_safetensors:
                logger.warning(f"{index_name} not found; rebuilding it from shard headers")
                index_data = _safetensors_index(
                    [temp_path / local_name(key) for key in weight_files]
                )
                with open(temp_path / index_name, 'w', encoding='utf-8') as f:
                    json.dump(index_data, f)
            else:
                logger.warning(f"{index_name} not found; sharded .bin weights may fail to load")

        # Download all auxiliary files
        for key in aux_files:
            _download_s3_object(s3_client, bucket, key, temp_path / local_name(key))

        # Load model
        try:
            logger.info("Loading model...")
            ModelClass = get_model_class(model_type, model_class, logger)

            model = ModelClass.from_pretrained(
                str(temp_path),
                local_files_only=True,
                use_safetensors=use_safetensors
            )
            logger.success("Model loaded successfully")

            # Load tokenizer or processor (best effort: failures only warn)
            tokenizer_or_processor = None
            try:
                # For Seamless models, always use AutoProcessor
                if model_type == "seamless_m4t_v2":
                    logger.info("Loading Seamless processor...")
                    tokenizer_or_processor = AutoProcessor.from_pretrained(
                        str(temp_path),
                        local_files_only=True
                    )
                    logger.success("Seamless processor loaded successfully")
                # Handle Llama models
                elif model_type == "llama":
                    logger.info("Loading LlamaTokenizerFast...")
                    tokenizer_or_processor = LlamaTokenizerFast.from_pretrained(
                        str(temp_path),
                        local_files_only=True
                    )
                    logger.success("LlamaTokenizerFast loaded successfully")
                # Default to AutoTokenizer
                else:
                    logger.info("Loading tokenizer...")
                    tokenizer_or_processor = AutoTokenizer.from_pretrained(
                        str(temp_path),
                        local_files_only=True
                    )
                    logger.success("Tokenizer loaded successfully")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer/processor: {str(e)}")

            return model, tokenizer_or_processor

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

ModuleNotFoundError: No module named 'boto3'

In [None]:
# Llama checkpoint: model_class="LlamaForCausalLM" pins the class used to load it.
model, tokenizer = load_model_from_s3(
    model_class="LlamaForCausalLM",
    bucket="my-bucket",
    path_to_model="models/llama-3.1-8b-instruct",
    endpoint_url="https://my-storage-endpoint",
    enable_debug=True,
    verify_ssl=False,
)

In [None]:
# Seamless M4T v2 checkpoint: no model_class, so the class and the
# processor are picked from config.json's model_type.
model, processor = load_model_from_s3(
    path_to_model="models/seamless-m4t-v2-large",
    bucket="my-bucket",
    enable_debug=True,
    endpoint_url="https://my-storage-endpoint",
    verify_ssl=False,
)

In [None]:
# Generic checkpoint; a failure is reported as one printed line instead of a traceback.
try:
    model, tokenizer = load_model_from_s3(
        enable_debug=True,
        verify_ssl=False,
        endpoint_url="https://my-storage-endpoint",
        path_to_model="models/my-model",
        bucket="my-bucket",
    )
except Exception as e:
    print(f"Failed to load model: {e}")