<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/02_load_hf_from_s3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import html
import json
import os
import sys
import warnings
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Tuple, Optional, Union, List, Dict

import boto3
import urllib3
from botocore.config import Config
from IPython.display import display, HTML
from transformers import AutoModel, AutoTokenizer, PretrainedConfig

class NotebookLogger:
    """Minimal logger that renders coloured, timestamped lines in a Jupyter notebook.

    Each message is displayed as a ``<pre>`` HTML element coloured by level.
    DEBUG messages are shown only when the logger was created with ``debug=True``.
    """

    # Level name -> CSS colour of the rendered line; unknown levels render black.
    COLORS = {
        'INFO': '#0066cc',
        'DEBUG': '#666666',
        'WARNING': '#ff9900',
        'ERROR': '#cc0000',
        'SUCCESS': '#009933'
    }

    def __init__(self, debug: bool = False):
        # Stored under a different name than the ``debug`` method: assigning
        # ``self.debug`` would shadow the method with a bool, and every
        # ``logger.debug(...)`` call would then raise ``TypeError``.
        self.debug_enabled = debug

    def _render(self, level: str, message: str) -> str:
        """Return the HTML markup for one log line.

        The message is HTML-escaped so text such as exception messages containing
        ``<`` or ``>`` is shown literally instead of being parsed as markup.
        """
        timestamp = datetime.now().strftime('%H:%M:%S')
        color = self.COLORS.get(level, '#000000')
        return (
            f'<pre style="margin:0; padding:2px 0; color: {color}">'
            f'[{timestamp}] {level}: {html.escape(str(message))}'
            '</pre>'
        )

    def _log(self, level: str, message: str):
        """Display one log line in the notebook output area."""
        display(HTML(self._render(level, message)))

    def info(self, message: str):
        self._log('INFO', message)

    def debug(self, message: str):
        """Log at DEBUG level; a no-op unless the logger was created with ``debug=True``."""
        if self.debug_enabled:
            self._log('DEBUG', message)

    def warning(self, message: str):
        self._log('WARNING', message)

    def error(self, message: str):
        self._log('ERROR', message)

    def success(self, message: str):
        self._log('SUCCESS', message)

def get_s3_client(endpoint_url: Optional[str] = None, verify_ssl: bool = True,
                  logger: Optional["NotebookLogger"] = None):
    """Create a boto3 S3 client with configurable SSL verification.

    Args:
        endpoint_url: Custom endpoint for S3-compatible storage; ``None`` uses
            boto3's default endpoint.
        verify_ssl: Verify TLS certificates. When ``False``, urllib3's
            ``InsecureRequestWarning`` is silenced process-wide.
        logger: Optional logger for DEBUG output.

    Returns:
        A boto3 S3 client configured with up to 3 retry attempts.

    Credentials come from ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` and,
    for temporary credentials, ``AWS_SESSION_TOKEN``. Unset variables are passed
    as ``None``.
    """
    if logger:
        logger.debug(f"Creating S3 client with endpoint: {endpoint_url}, verify_ssl: {verify_ssl}")

    if not verify_ssl:
        # Unverified HTTPS requests would otherwise trigger urllib3's
        # InsecureRequestWarning. This filter is global and stays installed.
        warnings.filterwarnings('ignore', category=urllib3.exceptions.InsecureRequestWarning)

    config = Config(retries={'max_attempts': 3})

    return boto3.client(
        's3',
        endpoint_url=endpoint_url,
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        # Needed for temporary (STS) credentials: once explicit keys are passed,
        # boto3 no longer reads the session token from the environment itself.
        aws_session_token=os.environ.get('AWS_SESSION_TOKEN'),
        verify=verify_ssl,
        config=config,
    )

def list_s3_files(bucket: str, prefix: str, endpoint_url: Optional[str] = None,
                  verify_ssl: bool = True, logger: NotebookLogger = None) -> List[str]:
    """Return every object key stored under ``prefix`` in ``bucket``.

    Keys from all pages of ``list_objects_v2`` are concatenated in the order
    the paginator yields them. When ``logger`` is given, the total count and
    each key are logged at DEBUG level.
    """
    client = get_s3_client(endpoint_url, verify_ssl, logger)

    if logger:
        logger.debug(f"Listing files in bucket: {bucket}, prefix: {prefix}")

    pages = client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix)
    # A page may lack a 'Contents' entry (no matching keys); treat it as empty.
    keys = [entry['Key'] for page in pages for entry in page.get('Contents', [])]

    if logger:
        logger.debug(f"Found {len(keys)} files")
        for key in keys:
            logger.debug(f"Found file: {key}")

    return keys

# Basenames of files AutoTokenizer may read from a model directory: fast-tokenizer
# JSON, tokenizer metadata, BPE vocab/merges, WordPiece vocab, and the common
# SentencePiece model file names.
_TOKENIZER_FILENAMES = frozenset({
    'tokenizer.json',
    'tokenizer_config.json',
    'special_tokens_map.json',
    'added_tokens.json',
    'vocab.json',
    'merges.txt',
    'vocab.txt',
    'tokenizer.model',
    'spiece.model',
    'sentencepiece.bpe.model',
})


def _model_prefix(path_to_model: str) -> str:
    """Return ``path_to_model`` as an S3 key prefix ending in a single ``/``.

    The trailing slash stops a listing of ``models/my-model`` from also
    returning keys under sibling prefixes such as ``models/my-model-v2/``.
    An empty path (bucket root) yields the empty prefix.
    """
    stripped = path_to_model.rstrip('/')
    return f"{stripped}/" if stripped else ""


def _download_s3_files(s3_client, bucket: str, keys: List[str], prefix: str,
                       target_dir: Path, label: str, logger: "NotebookLogger") -> None:
    """Download each key into ``target_dir``, preserving its path relative to ``prefix``."""
    for key in keys:
        relative_path = key[len(prefix):]
        target_path = target_dir / relative_path
        target_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Downloading {label} file: {key}")
        # download_file writes to disk in chunks via boto3's transfer manager
        # rather than holding the whole object (possibly GBs of weights) in memory.
        s3_client.download_file(bucket, key, str(target_path))
        logger.success(f"Downloaded: {relative_path}")


def load_model_from_s3(
    bucket: str,
    path_to_model: str,
    endpoint_url: Optional[str] = None,
    verify_ssl: bool = True,
    force_bin: bool = False,
    debug: bool = False
) -> Tuple[Union[AutoModel, None], Union[AutoTokenizer, None]]:
    """
    Load a model and tokenizer from S3 storage, preferring .safetensors format.

    The config, the weights of one format, and any tokenizer files are downloaded
    into a temporary directory that mirrors the layout under ``path_to_model``.
    The model is then loaded from it with ``AutoModel.from_pretrained``, and the
    tokenizer with ``AutoTokenizer.from_pretrained`` when tokenizer files exist.

    Args:
        bucket: Name of the S3 bucket.
        path_to_model: Key prefix ("directory") holding the model files; a
            trailing slash is optional.
        endpoint_url: Custom endpoint for S3-compatible storage.
        verify_ssl: Verify TLS certificates of the endpoint.
        force_bin: Load ``.bin`` weights even when ``.safetensors`` weights exist.
        debug: Emit DEBUG-level log lines.

    Returns:
        ``(model, tokenizer)``. ``tokenizer`` is ``None`` when no tokenizer files
        were found under the prefix.

    Raises:
        FileNotFoundError: ``config.json`` is missing from the model root, or no
            weight files of the usable/requested format exist.
        Exception: any error from ``from_pretrained`` is logged and re-raised.
    """
    logger = NotebookLogger(debug=debug)
    logger.info(f"Starting model load from bucket: {bucket}, path: {path_to_model}")

    prefix = _model_prefix(path_to_model)
    # Keys ending in '/' are zero-byte "folder" placeholders, not files.
    files = [key for key in list_s3_files(bucket, prefix, endpoint_url, verify_ssl, logger)
             if not key.endswith('/')]

    safetensors_files = [f for f in files if f.endswith('.safetensors')]
    bin_files = [f for f in files if f.endswith('.bin')]
    logger.info(f"Found {len(safetensors_files)} safetensor files and {len(bin_files)} bin files")

    # Only the top-level config.json describes the model. A bare suffix match
    # could select 'tokenizer_config.json' or a nested 'sub/config.json' instead.
    config_key = f"{prefix}config.json"
    if config_key not in files:
        logger.error("No config.json found!")
        raise FileNotFoundError("No config.json found in model directory")

    if safetensors_files and not force_bin:
        weight_format, weight_files = 'safetensors', safetensors_files
        logger.info("Using safetensors format for model loading")
    elif bin_files:
        weight_format, weight_files = 'bin', bin_files
        logger.info("Using .bin format for model loading")
    else:
        wanted = '.bin' if force_bin else '.safetensors or .bin'
        logger.error(f"No {wanted} weight files found!")
        raise FileNotFoundError(f"No {wanted} weight files found in model directory")

    # Sharded checkpoints also need their index file (e.g.
    # 'model.safetensors.index.json') mapping parameters to shard files.
    index_suffix = f".{weight_format}.index.json"
    weight_files = weight_files + [f for f in files if f.endswith(index_suffix)]

    tokenizer_files = [f for f in files if f.rsplit('/', 1)[-1] in _TOKENIZER_FILENAMES]

    s3_client = get_s3_client(endpoint_url, verify_ssl, logger)

    with TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        logger.debug(f"Created temporary directory: {temp_dir}")

        _download_s3_files(s3_client, bucket, [config_key], prefix, temp_path, 'config', logger)
        logger.success("Config file downloaded successfully")
        _download_s3_files(s3_client, bucket, weight_files, prefix, temp_path, weight_format, logger)
        _download_s3_files(s3_client, bucket, tokenizer_files, prefix, temp_path, 'tokenizer', logger)

        logger.debug("Files in temporary directory:")
        for file in temp_path.rglob('*'):
            if file.is_file():
                logger.debug(f"  {file.relative_to(temp_path)}")

        try:
            logger.info("Loading model from temporary directory")
            model = AutoModel.from_pretrained(
                temp_path,
                local_files_only=True,
                # Must match the downloaded format: requesting safetensors when
                # only .bin weights were downloaded makes loading fail.
                use_safetensors=(weight_format == 'safetensors'),
            )
            logger.success("Model loaded successfully")

            if tokenizer_files:
                logger.info("Loading tokenizer")
                tokenizer = AutoTokenizer.from_pretrained(
                    temp_path,
                    local_files_only=True
                )
                logger.success("Tokenizer loaded successfully")
            else:
                logger.warning("No tokenizer files found")
                tokenizer = None

            return model, tokenizer

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

In [None]:
# Example: load a model and tokenizer with DEBUG-level logging switched on.
# verify_ssl=False disables TLS certificate checks; use it only against
# endpoints you trust.
try:
    model, tokenizer = load_model_from_s3(
        "my-bucket",
        "models/my-model",
        endpoint_url="https://my-storage-endpoint",
        debug=True,
        verify_ssl=False,
    )
except Exception as err:
    # Notebook boundary: report the failure instead of aborting the cell.
    print(f"Failed to load model: {err}")