<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/01_sacrebleu_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!uv pip install --system transformers sacrebleu tqdm torch

[2mUsing Python 3.10.12 environment at /usr[0m
[2mAudited [1m4 packages[0m [2min 71ms[0m[0m


In [2]:
# Required imports
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SeamlessM4Tv2Model, AutoProcessor
from sacrebleu.metrics import BLEU
from typing import List, Union, Optional, Dict, Any
from dataclasses import dataclass
from enum import Enum
import json
import logging
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from abc import ABC, abstractmethod

## Experiment Setup

In [3]:
@dataclass
class ModelConfig:
    """Base configuration for translation models"""
    model_name: str
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 8
    src_lang: str = "eng"
    tgt_lang: str = "spa"

@dataclass
class APIConfig:
    """Configuration for API-based models"""
    base_url: str
    api_key: str
    timeout: int = 30
    max_retries: int = 3

@dataclass
class DecodingConfig:
    """Configuration for model generation parameters"""
    max_new_tokens: int = 256
    num_beams: int = 4
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.2

class ModelType(Enum):
    """Supported model types"""
    SEAMLESS = "seamless"
    LLAMA_BASE = "llama_base"
    LLAMA_31_INSTRUCT_API = "llama_31_8b_instruct_api"
    LLAMA_31_INSTRUCT = "llama_31_8b_instruct"
    LLAMA_32_INSTRUCT = "llama_32_3b_instruct"

class DeploymentType(Enum):
    """Deployment options"""
    LOCAL = "local"
    API = "api"

## Model Initialization

In [4]:
from abc import ABC, abstractmethod


class BaseTranslator(ABC):
    """Abstract base class for all translators"""

    @abstractmethod
    def initialize_model(self):
        pass

    @abstractmethod
    def translate_batch(self, texts: List[str]) -> List[str]:
        pass

class SeamlessTranslator(BaseTranslator):
    def __init__(self, config: ModelConfig):
        self.config = config
        self.model = None
        self.processor = None

    def initialize_model(self):
        self.processor = AutoProcessor.from_pretrained(self.config.model_name)
        self.model = SeamlessM4Tv2Model.from_pretrained(
            self.config.model_name
        ).to(self.config.device)

    def translate_batch(self, texts: List[str]) -> List[str]:
        # Process input
        text_inputs = self.processor(
            text=texts,
            src_lang=self.config.src_lang,
            return_tensors="pt"
        ).to(self.config.device)

        # Generate translations
        with torch.no_grad():
            output_tokens = self.model.generate(
                **text_inputs,
                tgt_lang=self.config.tgt_lang,
                generate_speech=False
            )

        # Decode translations
        translations = [
            self.processor.decode(tokens, skip_special_tokens=True)
            for tokens in output_tokens[0].tolist()
        ]

        return translations

class PromptStrategy(ABC):
    """Abstract base class for prompt strategies"""

    @abstractmethod
    def create_prompt(self, text: str, src_lang: str, tgt_lang: str) -> str:
        pass

    @abstractmethod
    def extract_translation(self, full_output: str, prompt: str) -> str:
        pass

class DirectPromptStrategy(PromptStrategy):
    """Strategy for direct translation without system prompts"""

    def create_prompt(self, text: str, src_lang: str, tgt_lang: str) -> str:
        return f"{text}\nTranslate to {tgt_lang}:"

    def extract_translation(self, full_output: str, prompt: str) -> str:
        # For base models, simply return everything after the prompt
        return full_output[len(prompt):].strip()

class InstructPromptStrategy(PromptStrategy):
    """Strategy for instruction-tuned models using system prompts"""

    def __init__(self):
        self.system_prompt = """You are a professional translator with expertise in multiple languages.
        Your task is to translate the provided text accurately while preserving meaning, tone, and context.
        Only provide the direct translation without any explanations or notes."""

    def create_prompt(self, text: str, src_lang: str, tgt_lang: str) -> str:
        return (
            "<|begin_of_text|>"
            "<|start_header_id|>system<|end_header_id|>\n"
            f"{self.system_prompt}"
            "<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n"
            f"Translate from {src_lang} to {tgt_lang}:\n{text}\n"
            "<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n"
        )

    def extract_translation(self, full_output: str, prompt: str) -> str:
        # Look for assistant's response after the prompt
        assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
        eot_marker = "<|eot_id|>"

        start_idx = full_output.find(assistant_marker)
        if start_idx != -1:
            start_idx += len(assistant_marker)
            end_idx = full_output.find(eot_marker, start_idx)
            if end_idx != -1:
                return full_output[start_idx:end_idx].strip()
        return full_output[len(prompt):].strip()

class LlamaTranslator(BaseTranslator):
    """Base Llama translator that can work with different prompt strategies"""

    def __init__(self,
                 config: ModelConfig,
                 decoding_config: DecodingConfig,
                 prompt_strategy: PromptStrategy):
        self.config = config
        self.decoding_config = decoding_config
        self.prompt_strategy = prompt_strategy
        self.model = None
        self.tokenizer = None

    def initialize_model(self):
        """Initialize the model and tokenizer with proper settings"""
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)

        # Handle tokenizer settings
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.float16 if self.config.device == "cuda" else torch.float32,
            device_map="auto"
        )

    def translate_batch(self, texts: List[str]) -> List[str]:
        # Create prompts using the strategy
        prompts = [
            self.prompt_strategy.create_prompt(
                text,
                self.config.src_lang,
                self.config.tgt_lang
            ) for text in texts
        ]

        # Tokenize inputs
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.config.device)

        # Generate translations
        outputs = self.model.generate(**inputs, **self._get_generation_config())

        # Extract translations using the strategy
        translations = []
        for output, prompt in zip(outputs, prompts):
            full_output = self.tokenizer.decode(output, skip_special_tokens=False)
            translation = self.prompt_strategy.extract_translation(full_output, prompt)
            translations.append(translation)

        return translations

    def _get_generation_config(self) -> Dict[str, Any]:
        return {
            "max_new_tokens": self.decoding_config.max_new_tokens,
            "num_beams": self.decoding_config.num_beams,
            "temperature": self.decoding_config.temperature,
            "top_p": self.decoding_config.top_p,
            "top_k": self.decoding_config.top_k,
            "repetition_penalty": self.decoding_config.repetition_penalty,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "do_sample": self.decoding_config.temperature > 0,
        }

In [5]:
from dataclasses import dataclass
import requests
from tenacity import retry, stop_after_attempt, wait_exponential

@dataclass
class APIConfig:
    """Configuration for API-based models"""
    base_url: str  # Should end with /v1/
    api_key: str
    model_name: str  # Full model path
    timeout: int = 30
    max_retries: int = 3

class APIClient:
    """Client for making API calls with retry logic"""
    def __init__(self, api_config: APIConfig):
        self.api_config = api_config
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        session = requests.Session()
        return session

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    def generate_translation(
        self,
        messages: List[Dict[str, str]],
        parameters: Dict[str, Any]
    ) -> str:
        """Make API call with retry logic"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_config.api_key}"
        }

        data = {
            "model": self.api_config.model_name,
            "messages": messages,
            **parameters
        }

        response = self.session.post(
            f"{self.api_config.base_url}chat/completions",
            headers=headers,
            json=data,
            timeout=self.api_config.timeout
        )
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

class APITranslator(BaseTranslator):
    """Translator using API endpoints"""
    def __init__(self,
                 config: ModelConfig,
                 api_config: APIConfig,
                 decoding_config: DecodingConfig,
                 prompt_strategy: PromptStrategy):
        self.config = config
        self.api_config = api_config
        self.decoding_config = decoding_config
        self.prompt_strategy = prompt_strategy
        self.client = None

    def initialize_model(self):
        """Initialize API client"""
        self.client = APIClient(self.api_config)

    def _get_api_parameters(self) -> Dict[str, Any]:
        """Convert decoding config to API parameters"""
        return {
            "max_tokens": self.decoding_config.max_new_tokens,
            "temperature": self.decoding_config.temperature,
            "top_p": self.decoding_config.top_p,
            "presence_penalty": self.decoding_config.repetition_penalty - 1.0,  # API format
        }

    def translate_batch(self, texts: List[str]) -> List[str]:
        translations = []

        for text in texts:
            # Format messages for API
            messages = [
                {"role": "system", "content": self.prompt_strategy.system_prompt},
                {"role": "user", "content": f"Translate from {self.config.src_lang} to {self.config.tgt_lang}:\n{text}"}
            ]

            # Get API parameters
            parameters = self._get_api_parameters()

            # Make API call
            translation = self.client.generate_translation(messages, parameters)
            translations.append(translation.strip())

        return translations

## Test case validation

In [6]:
class TestCase:
    """Handles test case management and validation"""
    def __init__(self, source_texts: List[str], references: List[List[str]]):
        self.source_texts = source_texts
        self.references = references

    @classmethod
    def from_jsonl(cls, file_path: str) -> 'TestCase':
        """Load test cases from JSONL file"""
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                data.append(json.loads(line.strip()))

        source_texts = [item["source_text"] for item in data]
        references = [item["references"] for item in data]
        return cls(source_texts, references)

    def validate(self) -> bool:
        """Validate test case format and content"""
        if len(self.source_texts) != len(self.references):
            return False
        return all(isinstance(ref, list) for ref in self.references)

## Batch Translation

In [7]:
class TranslationManager:
    """Manages the translation process"""
    def __init__(
        self,
        translator: BaseTranslator,
        test_case: TestCase,
        batch_size: int = 8
    ):
        self.translator = translator
        self.test_case = test_case
        self.batch_size = batch_size
        self.translations = []

    def process_batches(self) -> List[str]:
        """Process text in batches"""
        for i in range(0, len(self.test_case.source_texts), self.batch_size):
            batch = self.test_case.source_texts[i:i + self.batch_size]
            translations = self.translator.translate_batch(batch)
            self.translations.extend(translations)
        return self.translations

## Evaluation

In [8]:
class EvaluationMetrics:
    """Handles evaluation metrics computation"""
    def __init__(self):
        self.bleu = BLEU()

    def compute_bleu(
        self,
        hypotheses: List[str],
        references: List[List[str]],
    ) -> Dict[str, float]:
        """Compute BLEU score and additional metrics"""
        bleu_score = self.bleu.corpus_score(hypotheses, references)

        return {
            "bleu": bleu_score.score,
            "precisions": bleu_score.precisions,
            "bp": bleu_score.bp,
            "ratio": bleu_score.ratio,
            "sys_len": bleu_score.sys_len,
            "ref_len": bleu_score.ref_len
        }

## Results

In [9]:
class ResultsVisualizer:
    """Handles visualization of translation results"""
    def __init__(self, results: Dict[str, Any]):
        self.results = results

    def plot_bleu_comparison(self):
        """Plot BLEU scores comparison across models"""
        df = pd.DataFrame(self.results)

        plt.figure(figsize=(10, 6))
        sns.barplot(data=df)
        plt.title("BLEU Score Comparison")
        plt.xlabel("Model")
        plt.ylabel("BLEU Score")
        plt.xticks(rotation=45)
        plt.tight_layout()

    def generate_error_analysis(self, source_texts: List[str],
                              translations: List[str],
                              references: List[List[str]]):
        """Generate detailed error analysis"""
        analysis = []
        for src, hyp, ref in zip(source_texts, translations, references):
            analysis.append({
                "source": src,
                "hypothesis": hyp,
                "reference": ref[0],
                "match": hyp.strip() == ref[0].strip()
            })
        return pd.DataFrame(analysis)

## Login or Env Variables

In [10]:
from google.colab import userdata
BASE_URL = userdata.get('MODAL_BASE_URL') # should end in /v1/
API_KEY = userdata.get('DSBA_LLAMA3_KEY')
model_name = "/models/NousResearch/Meta-Llama-3.1-8B-Instruct"

In [11]:
# optional
import os
from google.colab import userdata

os.environ['HF_TOKEN'] = userdata.get('huggingface')

In [12]:
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
The token `colab` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## Experiment

In [13]:
!wget https://github.com/wesslen/seamless_sacrebleu_evaluation/raw/main/data/01-english-spanish-mapping.jsonl
!wget https://github.com/wesslen/seamless_sacrebleu_evaluation/raw/main/data/02-english-spanish-mapping.jsonl

--2024-12-02 03:17:18--  https://github.com/wesslen/seamless_sacrebleu_evaluation/raw/main/data/01-english-spanish-mapping.jsonl
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/wesslen/seamless_sacrebleu_evaluation/main/data/01-english-spanish-mapping.jsonl [following]
--2024-12-02 03:17:19--  https://raw.githubusercontent.com/wesslen/seamless_sacrebleu_evaluation/main/data/01-english-spanish-mapping.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2829 (2.8K) [text/plain]
Saving to: ‘01-english-spanish-mapping.jsonl.3’


2024-12-02 03:17:19 (47.9 MB/s) - ‘01-english-spanish-mapping.jso

In [16]:
def run_experiment(model_type: ModelType,
                  deployment_type: DeploymentType,
                  test_files: List[str],
                  api_config: Optional[APIConfig] = None):
    """Run complete translation experiment"""

    # Define a dictionary to map model types to model names
    model_name_mapping = {
        ModelType.LLAMA_BASE: "meta-llama/Llama-2-7b-chat-hf",  # Replace with your desired base model name
        ModelType.LLAMA_31_INSTRUCT: "meta-llama/Llama-3.1-8B-Instruct",
        ModelType.LLAMA_32_INSTRUCT: "meta-llama/Llama-3.2-3B-Instruct",
        ModelType.SEAMLESS: "facebook/seamless-m4t-v2-large",
        ModelType.LLAMA_31_INSTRUCT_API: "/models/NousResearch/Meta-Llama-3.1-8B-Instruct"
    }

    # Get the model name from the mapping
    model_name = model_name_mapping.get(model_type)

    # Create the ModelConfig instance
    model_config = ModelConfig(model_name=model_name)

    decoding_config = DecodingConfig()

    # 2. Choose appropriate prompt strategy
    prompt_strategy = (
        DirectPromptStrategy()
        if model_type == ModelType.LLAMA_BASE
        else  InstructPromptStrategy()
    )

    # 3. Initialize translator based on deployment type
    if deployment_type == DeploymentType.API:
        if not api_config:
            raise ValueError("API config required for API deployment")
        translator = APITranslator(
            model_config,
            api_config,
            decoding_config,
            prompt_strategy
        )
    else:
        translator = (
            SeamlessTranslator(model_config)
            if model_type == ModelType.SEAMLESS
            else LlamaTranslator(model_config, decoding_config, prompt_strategy)
        )

    # 2. Initialize translator
    translator.initialize_model()

    results = {}
    for test_file in test_files:
        # 3. Load and validate test cases
        test_case = TestCase.from_jsonl(test_file)
        if not test_case.validate():
            raise ValueError(f"Invalid test case format in {test_file}")

        # 4. Run translation
        translation_manager = TranslationManager(translator, test_case)
        translations = translation_manager.process_batches()

        # 5. Compute metrics
        evaluator = EvaluationMetrics()
        metrics = evaluator.compute_bleu(translations, test_case.references)
        results[test_file] = metrics

    # 6. Visualize results
    # visualizer = ResultsVisualizer(results)
    # visualizer.plot_bleu_comparison()

    return results

# Run experiments
test_files = [
    "01-english-spanish-mapping.jsonl",
    "02-english-spanish-mapping.jsonl"
]


In [17]:
results = run_experiment(
    model_type=ModelType.LLAMA_32_INSTRUCT,
    deployment_type=DeploymentType.LOCAL,
    test_files=test_files
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
results

{'01-english-spanish-mapping.jsonl': {'bleu': 46.713797772819994,
  'precisions': [80.0, 55.55555555555556, 37.5, 28.571428571428573],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 10,
  'ref_len': 10},
 '02-english-spanish-mapping.jsonl': {'bleu': 12.703318703865365,
  'precisions': [40.0, 12.5, 8.333333333333334, 6.25],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 5,
  'ref_len': 5}}

In [21]:
# Example usage:
api_config = APIConfig(
    base_url=BASE_URL,  # should end with /v1/
    api_key=API_KEY,
    model_name=model_name  # "/models/NousResearch/Meta-Llama-3.1-8B-Instruct"
)

results = run_experiment(
    model_type=ModelType.LLAMA_31_INSTRUCT_API,
    deployment_type=DeploymentType.API,
    test_files=test_files,
    api_config=api_config
)

In [22]:
results

{'01-english-spanish-mapping.jsonl': {'bleu': 46.713797772819994,
  'precisions': [80.0, 55.55555555555556, 37.5, 28.571428571428573],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 10,
  'ref_len': 10},
 '02-english-spanish-mapping.jsonl': {'bleu': 10.682175159905853,
  'precisions': [50.0, 10.0, 6.25, 4.166666666666667],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 6,
  'ref_len': 6}}

In [23]:
results = run_experiment(
    model_type=ModelType.SEAMLESS,
    deployment_type=DeploymentType.LOCAL,
    test_files=test_files
)

preprocessor_config.json:   0%|          | 0.00/1.78k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/19.7k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.17M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.34k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.72k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/211k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.24G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/9.91M [00:00<?, ?B/s]

In [24]:
results

{'01-english-spanish-mapping.jsonl': {'bleu': 42.7287006396234,
  'precisions': [70.0, 44.44444444444444, 37.5, 28.571428571428573],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 10,
  'ref_len': 10},
 '02-english-spanish-mapping.jsonl': {'bleu': 10.682175159905853,
  'precisions': [50.0, 10.0, 6.25, 4.166666666666667],
  'bp': 1.0,
  'ratio': 1.0,
  'sys_len': 6,
  'ref_len': 6}}

## Unit test

In [27]:
import unittest
from typing import List, Dict
import json

# First create helper test data
TEST_DATA = {
    "base_prompt": {
        "text": "Hello, my dog is cute",
        "src_lang": "English",
        "tgt_lang": "Spanish",
        "expected_direct": "Hello, my dog is cute\nTranslate to Spanish:",
        "expected_instruct": (
            "<|begin_of_text|>"
            "<|start_header_id|>system<|end_header_id|>\n"
            # system prompt will be here
            "<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n"
            "Translate from English to Spanish:\nHello, my dog is cute\n"
            "<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n"
        )
    },
    "test_extraction": {
        "full_output": (
            "<|start_header_id|>assistant<|end_header_id|>\n"
            "Hola, mi perro es lindo"
            "<|eot_id|>"
        ),
        "prompt": "Original prompt here",
        "expected": "Hola, mi perro es lindo"
    }
}

# Test Prompt Strategies
class TestPromptStrategies(unittest.TestCase):
    def setUp(self):
        self.direct_strategy = DirectPromptStrategy()
        self.instruct_strategy = InstructPromptStrategy()
        self.test_data = TEST_DATA["base_prompt"]

    def test_direct_prompt_creation(self):
        """Test that DirectPromptStrategy creates correct prompt format"""
        prompt = self.direct_strategy.create_prompt(
            self.test_data["text"],
            self.test_data["src_lang"],
            self.test_data["tgt_lang"]
        )
        self.assertEqual(prompt, self.test_data["expected_direct"])

    def test_instruct_prompt_creation(self):
        """Test that InstructPromptStrategy creates correct prompt format"""
        prompt = self.instruct_strategy.create_prompt(
            self.test_data["text"],
            self.test_data["src_lang"],
            self.test_data["tgt_lang"]
        )
        self.assertTrue(prompt.startswith("<|begin_of_text|>"))
        self.assertIn("system", prompt)
        self.assertIn("user", prompt)
        self.assertIn("assistant", prompt)

    def test_instruct_extraction(self):
        """Test that InstructPromptStrategy correctly extracts translation"""
        test_data = TEST_DATA["test_extraction"]
        result = self.instruct_strategy.extract_translation(
            test_data["full_output"],
            test_data["prompt"]
        )
        self.assertEqual(result, test_data["expected"])

# Test Configurations
class TestConfigurations(unittest.TestCase):
    def test_model_config(self):
        """Test ModelConfig initialization and validation"""
        config = ModelConfig(
            model_name="test-model",
            device="cpu",
            batch_size=8,
            src_lang="eng",
            tgt_lang="spa"
        )
        self.assertEqual(config.model_name, "test-model")
        self.assertEqual(config.device, "cpu")
        self.assertEqual(config.batch_size, 8)

    def test_api_config(self):
        """Test APIConfig initialization and validation"""
        config = APIConfig(
            base_url="http://test.com/v1/",
            api_key="test-key",
            model_name="test-model"
        )
        self.assertTrue(config.base_url.endswith("/v1/"))
        self.assertGreater(config.timeout, 0)
        self.assertGreater(config.max_retries, 0)

# Test DataLoading
class TestDataLoading(unittest.TestCase):
    def setUp(self):
        # Create a temporary test JSONL file
        self.test_data = [
            {"source_text": "Hello", "references": ["Hola"]},
            {"source_text": "World", "references": ["Mundo"]}
        ]
        self.test_file = "test.jsonl"
        with open(self.test_file, 'w') as f:
            for item in self.test_data:
                f.write(json.dumps(item) + '\n')

    def test_jsonl_loading(self):
        """Test that JSONL files are loaded correctly"""
        test_case = TestCase.from_jsonl(self.test_file)
        self.assertEqual(len(test_case.source_texts), 2)
        self.assertEqual(test_case.source_texts[0], "Hello")
        self.assertEqual(test_case.references[0], ["Hola"])

    def test_test_case_validation(self):
        """Test that TestCase validation works"""
        test_case = TestCase(
            source_texts=["Hello", "World"],
            references=[["Hola"], ["Mundo"]]
        )
        self.assertTrue(test_case.validate())

    def tearDown(self):
        # Clean up test file
        import os
        if os.path.exists(self.test_file):
            os.remove(self.test_file)

# Run the tests in Jupyter
def run_tests():
    # Create test suite
    suite = unittest.TestLoader().loadTestsFromTestCase(TestPromptStrategies)
    suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestConfigurations))
    suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestDataLoading))

    # Run tests
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)

# Run tests
run_tests()

test_direct_prompt_creation (__main__.TestPromptStrategies)
Test that DirectPromptStrategy creates correct prompt format ... ok
test_instruct_extraction (__main__.TestPromptStrategies)
Test that InstructPromptStrategy correctly extracts translation ... ok
test_instruct_prompt_creation (__main__.TestPromptStrategies)
Test that InstructPromptStrategy creates correct prompt format ... ok
test_api_config (__main__.TestConfigurations)
Test APIConfig initialization and validation ... ok
test_model_config (__main__.TestConfigurations)
Test ModelConfig initialization and validation ... ok
test_jsonl_loading (__main__.TestDataLoading)
Test that JSONL files are loaded correctly ... ok
test_test_case_validation (__main__.TestDataLoading)
Test that TestCase validation works ... ok

----------------------------------------------------------------------
Ran 7 tests in 0.018s

OK


## Libraries

In [26]:
# Now check versions
import pkg_resources
import sys

def get_package_details():
    """Print details of specific packages and Python version"""
    packages_to_check = [
        'torch',
        'transformers',
        'sacrebleu',
        'tqdm',
        'numpy',
        'sentencepiece'  # Often used by transformers
    ]

    print("Python version:", sys.version.split()[0])
    print("\nPackage versions:")
    print("-" * 50)

    for package in packages_to_check:
        try:
            version = pkg_resources.get_distribution(package).version
            print(f"{package:<15} {version}")
        except pkg_resources.DistributionNotFound:
            print(f"{package:<15} Not installed")

# Check CUDA availability for PyTorch
import torch
print("\nCUDA Status:")
print("-" * 50)
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Current GPU: {torch.cuda.get_device_name()}")

# Run the check
get_package_details()


CUDA Status:
--------------------------------------------------
CUDA available: True
CUDA version: 12.1
Current GPU: NVIDIA A100-SXM4-40GB
Python version: 3.10.12

Package versions:
--------------------------------------------------
torch           2.5.1+cu121
transformers    4.46.2
sacrebleu       2.4.3
tqdm            4.66.6
numpy           1.26.4
sentencepiece   0.2.0
