<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/02_load_hf_from_s3_updated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install boto3 transformers torch

In [2]:
import os
from google.colab import userdata
os.environ['AWS_ACCESS_KEY_ID'] = userdata.get('AWS_ACCESS_KEY_ID')
os.environ['AWS_SECRET_ACCESS_KEY'] = userdata.get('AWS_SECRET_ACCESS_KEY')

In [3]:
import boto3
from botocore.config import Config
import json
import os
import re
from contextlib import contextmanager
from io import BytesIO
from tempfile import TemporaryDirectory
from transformers import (
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    AutoProcessor,
    LlamaTokenizerFast,
    LlamaForCausalLM,
    SeamlessM4Tv2Model,
    SeamlessM4TProcessor
)
from typing import Tuple, Optional, Union, List, Dict, Any
from pathlib import Path
import urllib3
import warnings
from IPython.display import display, HTML
from datetime import datetime

class NotebookLogger:
    """Custom logger for Jupyter notebooks with colored output"""

    COLORS = {
        'INFO': '#0066cc',
        'DEBUG': '#666666',
        'WARNING': '#ff9900',
        'ERROR': '#cc0000',
        'SUCCESS': '#009933'
    }

    def __init__(self, enable_debug=False):
        self.enable_debug = enable_debug

    def _log(self, level: str, message: str):
        timestamp = datetime.now().strftime('%H:%M:%S')
        color = self.COLORS.get(level, '#000000')
        display(HTML(
            f'<pre style="margin:0; padding:2px 0; color: {color}">'
            f'[{timestamp}] {level}: {message}'
            '</pre>'
        ))

    def info(self, message: str): self._log('INFO', message)
    def debug_log(self, message: str):
        if self.enable_debug: self._log('DEBUG', message)
    def warning(self, message: str): self._log('WARNING', message)
    def error(self, message: str): self._log('ERROR', message)
    def success(self, message: str): self._log('SUCCESS', message)

def get_s3_client(endpoint_url: Optional[str] = None, verify_ssl: bool = True):
    """Create an S3 client with configurable SSL verification."""
    if not verify_ssl:
        warnings.filterwarnings('ignore', category=urllib3.exceptions.InsecureRequestWarning)

    config = Config(retries=dict(max_attempts=3))
    return boto3.client(
        's3',
        endpoint_url=endpoint_url,
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        verify=verify_ssl,
        config=config
    )

def find_model_files(files: List[str], logger: NotebookLogger) -> Tuple[List[str], bool]:
    """Find model files and determine if they're sharded."""
    # Check for single safetensors file
    single_file = [f for f in files if f.endswith('model.safetensors')]
    if single_file:
        logger.debug_log("Found single safetensors file")
        return single_file, False

    # Check for sharded files
    sharded_pattern = re.compile(r'model-\d{5}-of-\d{5}\.safetensors$')
    sharded_files = [f for f in files if sharded_pattern.search(os.path.basename(f))]

    if sharded_files:
        logger.debug_log(f"Found {len(sharded_files)} sharded safetensors files")
        return sorted(sharded_files), True

    logger.debug_log("No safetensors files found")
    return [], False

def get_model_class(model_type: str, model_class: Optional[str], logger: NotebookLogger):
    """
    Determine the appropriate model class based on model type and class name.
    """
    if model_class:
        if model_class == "LlamaForCausalLM":
            return LlamaForCausalLM
        elif model_class == "SeamlessM4Tv2Model":
            return SeamlessM4Tv2Model
        return AutoModel

    MODEL_CLASS_MAPPING = {
        "llama": LlamaForCausalLM,
        "seamless_m4t_v2": SeamlessM4Tv2Model,
    }

    model_class = MODEL_CLASS_MAPPING.get(model_type)
    if model_class:
        logger.debug_log(f"Using specific model class: {model_class.__name__}")
        return model_class

    logger.debug_log("Using default AutoModel class")
    return AutoModel

def download_file_from_s3(s3_client, bucket: str, key: str, target_path: Path, logger: NotebookLogger) -> bool:
    """Helper function to download a file from S3 and create parent directories.
    Returns True if successful, False otherwise."""
    try:
        target_path.parent.mkdir(parents=True, exist_ok=True)
        logger.debug_log(f"Downloading: {key} to {target_path}")

        # Verify the key exists in S3
        try:
            s3_client.head_object(Bucket=bucket, Key=key)
        except:
            logger.error(f"File not found in S3: {key}")
            return False

        with open(target_path, 'wb') as out:
            obj = s3_client.get_object(Bucket=bucket, Key=key)
            out.write(obj['Body'].read())

        # Verify file was downloaded
        if not target_path.exists():
            logger.error(f"File not created at: {target_path}")
            return False

        logger.debug_log(f"Successfully downloaded {key} ({target_path.stat().st_size} bytes)")
        return True
    except Exception as e:
        logger.error(f"Error downloading {key}: {str(e)}")
        return False

def verify_directory_contents(temp_path: Path, logger: NotebookLogger):
    """Debug helper to verify directory contents"""
    logger.debug_log("\nDirectory contents:")
    logger.debug_log("-" * 50)
    for path in temp_path.rglob("*"):
        if path.is_file():
            logger.debug_log(f"File: {path.relative_to(temp_path)} ({path.stat().st_size} bytes)")
    logger.debug_log("-" * 50)

def load_model_from_s3(
    bucket: str,
    path_to_model: str,
    endpoint_url: Optional[str] = None,
    verify_ssl: bool = True,
    force_bin: bool = False,
    enable_debug: bool = True,
    model_class: Optional[str] = None
) -> Tuple[Union[AutoModel, None], Union[Any, None]]:
    """
    Load a model and its tokenizer/processor from S3 storage with enhanced debugging.
    """
    logger = NotebookLogger(enable_debug=enable_debug)
    logger.info(f"Starting model load from bucket: {bucket}, path: {path_to_model}")

    s3_client = get_s3_client(endpoint_url, verify_ssl)

    # List all files in the model directory
    logger.info("Listing files in bucket...")
    try:
        files = []
        paginator = s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=bucket, Prefix=path_to_model):
            if 'Contents' in page:
                files.extend(obj['Key'] for obj in page['Contents'])

        if not files:
            logger.error(f"No files found in bucket {bucket} at path {path_to_model}")
            raise Exception("No files found in specified path")

        logger.debug_log("\nFound files in S3:")
        for file in files:
            logger.debug_log(f"- {file}")
    except Exception as e:
        logger.error(f"Error listing files: {str(e)}")
        raise

    with TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        logger.debug_log(f"\nCreated temporary directory: {temp_dir}")

        # First, download config.json
        config_file = next((f for f in files if f.endswith('config.json')), None)
        if not config_file:
            logger.error("No config.json found!")
            raise Exception("No config.json found in model directory")

        config_path = temp_path / 'config.json'
        if not download_file_from_s3(s3_client, bucket, config_file, config_path, logger):
            raise Exception("Failed to download config.json")

        # Read config to determine model type
        try:
            with open(config_path) as f:
                config_data = json.load(f)
                model_type = config_data.get('model_type', '').lower()
                logger.debug_log(f"\nDetected model type: {model_type}")
        except Exception as e:
            logger.error(f"Error reading config.json: {str(e)}")
            raise

        # Define required files for Seamless model
        seamless_required_files = [
            'processor_config.json',
            'tokenizer_config.json',
            'special_tokens_map.json',
            'tokenizer.model',
            'preprocessor_config.json'
        ]

        # Download all auxiliary files first
        aux_files = [f for f in files if any(f.endswith(ext) for ext in [
            'tokenizer.json',
            'tokenizer_config.json',
            'special_tokens_map.json',
            'vocab.json',
            'merges.txt',
            'tokenizer.model',
            'processor_config.json',
            'preprocessor_config.json',
            'generation_config.json',
            'config.json'
        ])]

        logger.debug_log("\nDownloading auxiliary files:")
        downloaded_files = []
        for file in aux_files:
            relative_path = Path(file).relative_to(path_to_model)
            target_path = temp_path / relative_path
            success = download_file_from_s3(s3_client, bucket, file, target_path, logger)
            if success:
                downloaded_files.append(relative_path.name)
            else:
                logger.warning(f"Failed to download auxiliary file: {file}")

        # Check for required files for Seamless model
        if model_type == "seamless_m4t_v2":
            missing_files = [f for f in seamless_required_files if f not in downloaded_files]
            if missing_files:
                logger.error(f"Missing required files for Seamless model: {missing_files}")
                raise Exception(f"Missing required files for Seamless model: {missing_files}")

        # Download model weights
        safetensors_files, is_sharded = find_model_files(files, logger)
        if safetensors_files and not force_bin:
            logger.debug_log("\nDownloading safetensors files:")
            for file in safetensors_files:
                relative_path = Path(file).relative_to(path_to_model)
                target_path = temp_path / relative_path
                if not download_file_from_s3(s3_client, bucket, file, target_path, logger):
                    raise Exception(f"Failed to download model file: {file}")

            if is_sharded:
                logger.debug_log("\nCreating index file for sharded safetensors")
                num_shards = len(safetensors_files)
                index_data = {
                    "metadata": {"total_size": 0},
                    "weight_map": {
                        param: f"model-{i+1:05d}-of-{num_shards:05d}.safetensors"
                        for i, param in enumerate([f"layer_{i}" for i in range(num_shards)])
                    }
                }
                index_path = temp_path / "model.safetensors.index.json"
                with open(index_path, "w") as f:
                    json.dump(index_data, f)
                logger.debug_log(f"Created index file at {index_path}")

        # Verify directory contents before loading
        verify_directory_contents(temp_path, logger)

        # Try loading model
        try:
            logger.info("\nLoading model...")
            model_path = str(temp_path.absolute())
            logger.debug_log(f"Loading model from path: {model_path}")

            ModelClass = get_model_class(model_type, model_class, logger)
            model = ModelClass.from_pretrained(
                model_path,
                local_files_only=True,
                use_safetensors=not force_bin,
            )
            logger.success("Model loaded successfully")

            # Load processor for Seamless model
            if model_type == "seamless_m4t_v2":
                try:
                    logger.info("Loading Seamless processor...")
                    # Use the specific SeamlessM4TProcessor class
                    processor = SeamlessM4TProcessor.from_pretrained(
                        model_path,
                        local_files_only=True,
                        use_safetensors=not force_bin
                    )
                    logger.success("Seamless processor loaded successfully")
                    return model, processor
                except Exception as e:
                    logger.error(f"Failed to load Seamless processor: {str(e)}")
                    logger.debug_log(f"Stack trace for processor error: {e.__class__.__name__}: {str(e)}")
                    raise

            # Load tokenizer for other models
            try:
                if model_type == "llama":
                    logger.info("Loading LlamaTokenizerFast...")
                    tokenizer = LlamaTokenizerFast.from_pretrained(
                        model_path,
                        local_files_only=True
                    )
                else:
                    logger.info("Loading AutoTokenizer...")
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_path,
                        local_files_only=True
                    )
                logger.success("Tokenizer loaded successfully")
                return model, tokenizer
            except Exception as e:
                logger.error(f"Failed to load tokenizer: {str(e)}")
                logger.debug_log(f"Stack trace for tokenizer error: {e.__class__.__name__}: {str(e)}")
                raise

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            logger.debug_log(f"Stack trace for model error: {e.__class__.__name__}: {str(e)}")
            raise

In [4]:
endpoint = "https://nyc3.digitaloceanspaces.com"

In [5]:

# import boto3
# from botocore.client import Config

# # Initialize a session using DigitalOcean Spaces.
# session = boto3.session.Session()
# client = session.client('s3',
#                         region_name='nyc3',
#                         endpoint_url='https://nyc3.digitaloceanspaces.com',
#                         aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
#                         aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY')

# # Create a new Space.
# client.create_bucket(Bucket='seamless-model')

# # List all buckets on your account.
# response = client.list_buckets()

In [6]:
# # For Llama model
# model, tokenizer = load_model_from_s3(
#     bucket="my-bucket",
#     path_to_model="models/llama-3.1-8b-instruct",
#     endpoint_url=endpoint,
#     verify_ssl=False,
#     model_class="LlamaForCausalLM",
#     enable_debug=True
# )

In [7]:
# # Load Seamless model
# model, processor = load_model_from_s3(
#     bucket="my-bucket",
#     path_to_model="models/seamless-m4t-v2-large",
#     endpoint_url=endpoint,
#     verify_ssl=False,
#     enable_debug=True
# )

In [8]:
try:
    model, tokenizer = load_model_from_s3(
        bucket="seamless-model",
        path_to_model="models/seamless-m4t-v2-large",
        endpoint_url=endpoint,
        verify_ssl=True,
        enable_debug=True
    )
except Exception as e:
    print(f"Failed to load model: {e}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at /tmp/tmpk78mj4vs were not used when initializing SeamlessM4Tv2Model: ['param_0', 'param_1']
- This IS expected if you are initializing SeamlessM4Tv2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SeamlessM4Tv2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SeamlessM4Tv2Model were not initialized from the model checkpoint at /tmp/tmpk78mj4vs and are newly initialized: ['lm_head.weight', 'shared.weight', 'speech_encoder.adapter.layers.0.ffn.intermediate_dense.bias', 'speech_encoder.adapter.layers.0.ffn.intermediate_dense.weight', 'speech_encoder.adapter.layers.0.ffn.output_dense.bias', 'speech_encoder.adapter.layers.0.ffn.output_den

Failed to load model: expected str, bytes or os.PathLike object, not NoneType
