# 🌐 Czech Language Adaptation of Gemma Language Model

**Author:** Jirka Helmich  
**Last Updated:** 2024-01-06  
**License:** MIT

## 📋 Overview

This notebook demonstrates the fine-tuning process of the Gemma language model for Czech language understanding and generation. We focus on creating a robust multilingual model capable of handling various Czech-specific NLP tasks.

### 🎯 Key Objectives

1. Adapt Gemma for superior Czech language processing
2. Support translation and text generation tasks
3. Comprehensive benchmarking on Czech-specific metrics

### 📊 Data Sources

1. **ParaCrawl v9**
   - EN-CS parallel corpus (~52M pairs)
   - [Source](https://paracrawl.eu/v9)

2. **Czech Books Descriptions**
   - Book descriptions in Czech
   - [Source](https://huggingface.co/datasets/vojtam/czech_books_descriptions)

### 🛠️ Technical Requirements

```python
Python >= 3.10
polars >= 0.20.0
datasets >= 2.15.0
tqdm >= 4.66.0
fasttext >= 0.9.2
torch >= 2.0.0
transformers >= 4.36.0
```

## 1️⃣ Environment Setup

First, let's set up our environment with all required dependencies.

In [1]:
# Install core dependencies
%pip install -q datasets polars tqdm fasttext torch transformers>=4.47.1 wandb seaborn matplotlib numpy peft>=0.14.0 evaluate huggingface_hub bitsandbytes>=0.45.0

# Import common libraries
import polars as pl
from pathlib import Path
import logging
from tqdm.auto import tqdm
from typing import Optional, Dict, List, Union
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Configure plotting
plt.style.use('seaborn-v0_8-paper')
sns.set_palette('husl')

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## 2️⃣ Data Processing Pipeline

Our data processing pipeline is optimized for handling large-scale text data efficiently:

1. 📥 **Data Loading**: Streaming large files with chunked processing
2. 🧹 **Text Cleaning**: Efficient Czech text validation and normalization
3. 🔄 **Format Conversion**: Optimized Alpaca format transformation
4. 💾 **Storage**: Compressed Parquet format with optimal chunk sizes

### 2.1 Core Data Processing Classes

In [2]:
import gzip

class ParaCrawlDataLoader:
    """Optimized loader for ParaCrawl dataset with chunked processing."""
    
    def __init__(
        self,
        source_lang: str = "en",
        target_lang: str = "cs",
        chunk_size: int = 500_000,  # Increased for better throughput
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.chunk_size = chunk_size
        self.base_url = "https://web-language-models.s3.amazonaws.com/paracrawl/release9"
        
        # Setup directories
        self.data_dir = Path(data_dir or "./data")
        self.cache_dir = Path(cache_dir or "./cache")
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        
        # File paths
        self.filename = f"{source_lang}-{target_lang}.txt.gz"
        self.filepath = self.data_dir / self.filename
        self.processed_path = self.cache_dir / f"{source_lang}-{target_lang}.parquet"
        
        self.logger = logging.getLogger(__name__)
    
    def _validate_file(self, filepath: Path) -> bool:
        """Validate downloaded file integrity."""
        if not filepath.exists():
            return False
            
        try:
            with gzip.open(filepath, 'rt', encoding='utf-8') as f:
                # Try to read first few lines
                for _ in range(5):
                    line = f.readline()
                    if not line or '\t' not in line:
                        return False
            return True
        except Exception:
            return False
    
    def download_data(self) -> None:
        """Download dataset with progress tracking."""
        if self.filepath.exists() and self._validate_file(self.filepath):
            self.logger.info("Using existing valid download")
            return
            
        url = f"{self.base_url}/{self.source_lang}-{self.target_lang}/{self.filename}"
        self.logger.info(f"Downloading from {url}")
        
        try:
            import urllib.request
            response = urllib.request.urlopen(url)
            total_size = int(response.headers['Content-Length'])
            
            with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                urllib.request.urlretrieve(
                    url,
                    self.filepath,
                    reporthook=lambda count, block_size, _: pbar.update(block_size)
                )
                
            if not self._validate_file(self.filepath):
                raise ValueError("Downloaded file appears to be corrupt")
                
        except Exception as e:
            self.logger.error(f"Download failed: {e}")
            if self.filepath.exists():
                self.filepath.unlink()
            raise
    
    def process_chunk(self, chunk: List[str]) -> pl.DataFrame:
        """Process a chunk of text data efficiently."""
        if not chunk:
            return pl.DataFrame()
            
        # Split and filter in one pass
        pairs = [
            line.strip().split("\t") 
            for line in chunk 
            if "\t" in line
        ]
        
        # Filter invalid pairs
        valid_pairs = [
            p for p in pairs 
            if len(p) == 2 and all(0 < len(text) < 1000 for text in p)
        ]
        
        if not valid_pairs:
            return pl.DataFrame()
        
        # Create DataFrame efficiently
        return pl.DataFrame(
            valid_pairs,
            schema=[self.source_lang, self.target_lang],
            orient="row"
        )
    
    def load_dataframe(self) -> pl.DataFrame:
        """Load and process data in memory-efficient chunks."""
        if self.processed_path.exists():
            self.logger.info(f"Loading cached processed data from {self.processed_path}")
            return pl.read_parquet(self.processed_path)
        
        self.download_data()
        chunks = []
        total_rows = 0
        
        self.logger.info("Processing raw data file...")
        with gzip.open(self.filepath, "rt", encoding="utf-8") as f:
            with tqdm(desc="Processing chunks") as pbar:
                while True:
                    chunk = []
                    for _ in range(self.chunk_size):
                        line = f.readline()
                        if not line:
                            break
                        chunk.append(line)
                    
                    if not chunk:
                        break
                        
                    df_chunk = self.process_chunk(chunk)
                    if not df_chunk.is_empty():
                        chunks.append(df_chunk)
                        total_rows += len(df_chunk)
                    
                    pbar.update(len(chunk))
                    pbar.set_postfix({"valid_rows": total_rows})
        
        # Combine chunks and save
        self.logger.info(f"Combining {len(chunks)} chunks with {total_rows:,} total rows")
        df = pl.concat(chunks)
        
        self.logger.info(f"Saving processed data to {self.processed_path}")
        df.write_parquet(
            self.processed_path,
            compression="zstd",
            compression_level=3
        )
        
        return df

In [3]:
from concurrent.futures import ThreadPoolExecutor
import unicodedata
from typing import Optional, List
import polars as pl
from pathlib import Path
import logging
from tqdm.notebook import tqdm
import fasttext
from dataclasses import dataclass
from enum import Enum


class TextIssue(Enum):
    INVALID_CHARS = "invalid_characters"
    NON_CZECH = "non_czech_language"
    LOW_CONFIDENCE = "low_language_confidence"
    TOO_SHORT = "too_short"
    NO_ISSUES = "no_issues"


@dataclass
class TextQuality:
    original: str
    cleaned: Optional[str]
    issues: List[TextIssue]
    confidence: float

    @property
    def is_valid(self) -> bool:
        return TextIssue.NO_ISSUES in self.issues

from concurrent.futures import ThreadPoolExecutor
import unicodedata
from typing import Optional, List
import polars as pl
from pathlib import Path
import logging
from tqdm.notebook import tqdm
import fasttext


class CzechTextCleaner:
    """Efficient Czech text validation and cleaning for large datasets."""

    FASTTEXT_MODEL_URL = (
        "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"
    )

    def __init__(self, model_dir: Optional[str] = None):
        # Character sets for validation
        self.czech_chars = frozenset(
            "aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽ"
        )
        self.czech_punctuation = frozenset(',.!?-–—()[]{}/\\"\'»«„"‟"\'')
        self.czech_numbers = frozenset("0123456789")
        self.valid_chars = (
            self.czech_chars
            | self.czech_punctuation
            | self.czech_numbers
            | {" "}
        )

        # Setup model and logging
        self.model_dir = Path(model_dir or "models")
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.model_path = self.model_dir / "lid.176.ftz"
        self.logger = logging.getLogger(__name__)

        # Load FastText model
        self._setup_fasttext()

    def _setup_fasttext(self) -> None:
        """Initialize FastText model."""
        if not self.model_path.exists():
            self._download_model()
        self.model = fasttext.load_model(str(self.model_path))

    def _download_model(self) -> None:
        """Download language model with progress tracking."""
        import urllib.request

        response = urllib.request.urlopen(self.FASTTEXT_MODEL_URL)
        total_size = int(response.headers["Content-Length"])

        with tqdm(
            total=total_size,
            unit="B",
            unit_scale=True,
            desc="Downloading model",
        ) as pbar:
            urllib.request.urlretrieve(
                self.FASTTEXT_MODEL_URL,
                self.model_path,
                lambda count, block_size, _: pbar.update(block_size),
            )

    def _is_valid_czech(self, text: str) -> bool:
        """Check if text is valid Czech with good confidence."""
        if not isinstance(text, str) or len(text.strip()) < 2:
            return False

        # Check characters
        if not all(c in self.valid_chars or c.isspace() for c in text):
            return False

        # Detect language
        text = " ".join(text.split())
        pred = self.model.predict(text)
        lang, conf = pred[0][0].replace("__label__", ""), pred[1][0]

        return lang == "cs" and conf >= 0.8

    def _clean_text(self, text: str) -> Optional[str]:
        """Clean text if valid, return None if invalid."""
        if not self._is_valid_czech(text):
            return None

        # Unicode normalization
        text = unicodedata.normalize("NFKC", text)

        # Quote and dash normalization
        text = text.replace('"', "„").replace('"', '"').replace("-", "–")

        # Initial whitespace normalization
        text = " ".join(text.split())

        # Thorough punctuation cleanup
        for punct in ",.!?":
            # Remove all spaces before punctuation
            text = text.replace(f" {punct}", punct)
            # Replace any runs of spaces after punctuation with a single space
            # First add space if missing
            text = text.replace(f"{punct}", f"{punct} ")
            # Then collapse multiple spaces
            while f"{punct}  " in text:
                text = text.replace(f"{punct}  ", f"{punct} ")

        # Final whitespace cleanup
        text = " ".join(text.split())
        return text.strip()

    def clean_dataframe(
        self, df: pl.DataFrame, text_columns: List[str], num_threads: int = 8
    ) -> pl.DataFrame:
        """Clean text columns and drop rows with invalid texts."""
        pl.Config.set_streaming_chunk_size(10000)

        for col in text_columns:
            self.logger.info(f"Processing column: {col}")

            # Process texts in parallel
            texts = df[col].to_list()
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                with tqdm(total=len(texts), desc=f"Cleaning {col}") as pbar:
                    futures = []
                    for text in texts:
                        future = executor.submit(self._clean_text, text)
                        future.add_done_callback(lambda p: pbar.update(1))
                        futures.append(future)

                    cleaned_texts = [future.result() for future in futures]

            # Update column with cleaned texts
            df = df.with_columns([pl.Series(col, cleaned_texts)])

            # Drop rows where cleaning failed (null values)
            initial_rows = len(df)
            df = df.filter(~pl.col(col).is_null())
            kept_rows = len(df)

            self.logger.info(
                f"Kept {kept_rows:,} valid rows out of {initial_rows:,} "
                f"({kept_rows/initial_rows:.1%})"
            )

        return df

    def analyze_parallel_stats(self, df: pl.DataFrame) -> dict:
        """Analyze parallel corpus statistics"""
        return {
            "total_pairs": len(df),
            "unique_cs": df["cs"].n_unique(),
            "unique_en": df["en"].n_unique(),
            "avg_cs_len": df["cs"].str.len_chars().mean(),
            "avg_en_len": df["en"].str.len_chars().mean(),
            "cs_vocab_size": df["cs"].str.split(" ").explode().n_unique(),
            "en_vocab_size": df["en"].str.split(" ").explode().n_unique(),
        }

In [4]:
class AlpacaConverter:
    """Memory-efficient converter to Alpaca instruction format."""
    
    def __init__(
        self,
        instruction_templates: Optional[Dict[str, str]] = None,
        chunk_size: int = 100_000
    ):
        self.chunk_size = chunk_size
        self.iso_to_lang = {
            "en": "angličtiny",
            "cs": "češtiny",
        }
        self.instruction_templates = instruction_templates or {
            'translation': "Přelož tento text z {source_lang} do {target_lang}",
        }
        self.logger = logging.getLogger(__name__)
    
    def create_instruction(self, task_type: str, **kwargs) -> str:
        """Create instruction from template."""
        template = self.instruction_templates.get(task_type)
        if not template:
            raise ValueError(f"Unknown task type: {task_type}")
        return template.format(**kwargs)
    
    def create_translation_examples(
        self,
        df: pl.DataFrame,
        source_lang: str,
        target_lang: str,
        output_path: Union[str, Path]
    ) -> None:
        """Convert translation pairs to Alpaca format."""
        output_path = Path(output_path)
        
        # Check if file already exists
        if output_path.exists():
            self.logger.info(f"Using existing processed file: {output_path}")
            return
            
        output_path.parent.mkdir(parents=True, exist_ok=True)
        
        iso_to_lang = {
            "en": "angličtiny",
            "cs": "češtiny",
        }
        instruction = self.create_instruction(
            'translation',
            source_lang=iso_to_lang[source_lang],
            target_lang=iso_to_lang[target_lang]
        )
        
        # Process in chunks and collect all chunks
        chunks = []
        total_rows = 0
        with tqdm(total=len(df), desc="Converting translations") as pbar:
            for i in range(0, len(df), self.chunk_size):
                chunk = df.slice(i, min(self.chunk_size, len(df) - i))
                
                # Create Alpaca format efficiently
                alpaca_df = pl.DataFrame({
                    "instruction": [instruction] * len(chunk),
                    "input": chunk[source_lang],
                    "output": chunk[target_lang]
                })
                chunks.append(alpaca_df)
                total_rows += len(chunk)
                pbar.update(len(chunk))
                pbar.set_postfix({"total_rows": total_rows})
        
        # Combine all chunks and save at once
        self.logger.info(f"Combining {len(chunks)} chunks with {total_rows:,} total rows")
        final_df = pl.concat(chunks)
        
        self.logger.info(f"Saving {total_rows:,} examples to {output_path}")
        final_df.write_parquet(
            output_path,
            compression="zstd",
            compression_level=3
        )

### 2.2 Data Processing Pipeline

Now let's use our optimized classes to process the datasets:

In [5]:
from datasets import load_dataset

# Initialize components with optimized settings
loader = ParaCrawlDataLoader(chunk_size=500_000)
cleaner = CzechTextCleaner()
converter = AlpacaConverter(chunk_size=100_000)

# Setup paths
data_dir = Path("data")
processed_dir = data_dir / "processed"
processed_dir.mkdir(parents=True, exist_ok=True)

paracrawl_path = processed_dir / "paracrawl_alpaca.parquet"
books_path = processed_dir / "books_alpaca.parquet"

# Load ParaCrawl
print("Loading and analyzing ParaCrawl dataset...")
df_paracrawl = loader.load_dataframe()
df_paracrawl = df_paracrawl.slice(0, 50000)

print("Raw ParaCrawl dataset:")
print(df_paracrawl.head())

# Store raw data for comparison
raw_paracrawl = df_paracrawl.clone()

# Clean Czech texts
print("\nCleaning Czech texts...")
df_paracrawl = cleaner.clean_dataframe(
    df_paracrawl,
    text_columns=["cs"],
    num_threads=24
)

# Analyze stats
stats = cleaner.analyze_parallel_stats(df_paracrawl)
print(stats)

print("Cleaned ParaCrawl dataset:")
print(df_paracrawl.head())

# Convert to Alpaca format with progress tracking
print("\nConverting to Alpaca format...")

# Process translations
print("Processing translations...")
converter.create_translation_examples(
    df_paracrawl,
    source_lang="en",
    target_lang="cs",
    output_path=paracrawl_path,
)


INFO:datasets:PyTorch version 2.1.0+cu118 available.
INFO:datasets:Polars version 1.19.0 available.
INFO:__main__:Loading cached processed data from cache/en-cs.parquet


Loading and analyzing ParaCrawl dataset...


INFO:__main__:Processing column: cs


Raw ParaCrawl dataset:
shape: (5, 2)
┌─────────────────────────────────┬─────────────────────────────────┐
│ en                              ┆ cs                              │
│ ---                             ┆ ---                             │
│ str                             ┆ str                             │
╞═════════════════════════════════╪═════════════════════════════════╡
│ Offering various dining option… ┆ Hosté se mohou najíst v restau… │
│ As families grow in size, so t… ┆ Čím větší rodina, tím více pož… │
│ Weather in Barueri: no precipi… ┆ Počasí v Barueri: přeháňky - 0… │
│ Local Time: Sīdī Sālim, Egypt   ┆ Místní čas: Al Husayniyah, Egy… │
│ Then let him patiently wait an… ┆ Pak nechť trpělivě čeká a pečl… │
└─────────────────────────────────┴─────────────────────────────────┘

Cleaning Czech texts...


Cleaning cs:   0%|          | 0/50000 [00:00<?, ?it/s]

INFO:__main__:Kept 28,674 valid rows out of 50,000 (57.3%)
INFO:__main__:Using existing processed file: data/processed/paracrawl_alpaca.parquet


{'total_pairs': 28674, 'unique_cs': 28565, 'unique_en': 28657, 'avg_cs_len': 89.32743949222292, 'avg_en_len': 91.80592174094998, 'cs_vocab_size': 94016, 'en_vocab_size': 66755}
Cleaned ParaCrawl dataset:
shape: (5, 2)
┌─────────────────────────────────┬─────────────────────────────────┐
│ en                              ┆ cs                              │
│ ---                             ┆ ---                             │
│ str                             ┆ str                             │
╞═════════════════════════════════╪═════════════════════════════════╡
│ Offering various dining option… ┆ Hosté se mohou najíst v restau… │
│ As families grow in size, so t… ┆ Čím větší rodina, tím více pož… │
│ Brojenje - Total count of mess… ┆ Počet – Celkový počet poselstv… │
│ The entire route is about 151 … ┆ Celá trasa měří cca 151 km.     │
│ Sort by Most Subscribed         ┆ Třídit podle Nejlépe hodnocené  │
└─────────────────────────────────┴─────────────────────────────────┘

Converting 

Now let's examine the processed dataset.

In [6]:
# Load processed datasets
print("Loading processed datasets...")
translations_ds = load_dataset(
    "parquet",
    data_files=str(paracrawl_path)
)

print("Translations dataset:")
print(translations_ds["train"].to_polars().head())

Loading processed datasets...


Translations dataset:
shape: (5, 3)
┌─────────────────────┬─────────────────────────────────┬─────────────────────────────────┐
│ instruction         ┆ input                           ┆ output                          │
│ ---                 ┆ ---                             ┆ ---                             │
│ str                 ┆ str                             ┆ str                             │
╞═════════════════════╪═════════════════════════════════╪═════════════════════════════════╡
│ Přelož tento text z ┆ Offering various dining option… ┆ Hosté se mohou najíst v restau… │
│ angličtiny…         ┆                                 ┆                                 │
│ Přelož tento text z ┆ As families grow in size, so t… ┆ Čím větší rodina, tím více pož… │
│ angličtiny…         ┆                                 ┆                                 │
│ Přelož tento text z ┆ Brojenje - Total count of mess… ┆ Počet – Celkový počet poselstv… │
│ angličtiny…         ┆                     

## 3️⃣ Model Training Setup

Now that our data is prepared, we'll set up the model training pipeline.

The training pipeline will include:

1. 🤖 **Model Configuration**
   - Gemma 2B base model
   - Mixed precision (bfloat16)
   - Gradient accumulation

2. 📊 **Training Loop**
   - Custom data collation
   - Efficient batching
   - Progress tracking

3. 📈 **Evaluation**
   - Translation metrics
   - Text quality assessment
   - Error analysis

First let's configure the hardware settings.

In [7]:
import torch

# Configure hardware settings
DEVICE = "cuda"
DTYPE = torch.bfloat16

# For H100 we can enable TF32 for matrix multiplications
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [8]:
@dataclass
class TrainingConfig:
    """Configuration for Gemma fine-tuning"""

    # Model settings
    model_name: str = "google/gemma-2-2b-it"
    max_length: int = 512

    # Early stopping settings
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.01
   
    # Reduced batch sizes to fit in memory
    train_batch_size: int = 32     # Reduced from 64
    eval_batch_size: int = 64       # Reduced from 32
    gradient_accumulation_steps: int = 8  # Increased to maintain effective batch size

    # LoRA settings (unchanged)
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # Training settings
    learning_rate: float = 5e-4   # Adjusted for new effective batch size
    weight_decay: float = 0.01
    warmup_ratio: float = 0.03
    max_grad_norm: float = 1.0    # Reduced to prevent memory spikes

    # Evaluation settings (adjusted for memory)
    num_epochs: int = 2
    eval_steps: int = 100         # More frequent evaluation with smaller batches
    save_steps: int = 200

    # Paths
    output_dir: str = "models/gemma-cs-translator"


config = TrainingConfig()

Let's prepare the dataset with specific format for the model:

In [9]:
from torch.utils.data import Dataset

class GemmaChatDataset(Dataset):
    """Dataset for Gemma chat format"""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Create chat format
        chat = [
            {
                "role": "user",
                "content": f"{item['instruction']}\n\n{item['input']}",
            }
        ]

        # Apply chat template
        input_text = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Add expected output
        full_text = f"{input_text}{item['output']}<end_of_turn>"

        # Tokenize
        encodings = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Create attention mask and labels
        input_ids = encodings["input_ids"][0]
        attention_mask = encodings["attention_mask"][0]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
        }

Now we need to prepare the model and tokenizer for training.

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig, TaskType

def setup_model_and_tokenizer(config: TrainingConfig):
    """Initialize model with LoRA and quantization"""

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Quantization config
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=DTYPE,
        trust_remote_code=True,
    )

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)

    return model, tokenizer

Now we need to prepare the dataset for training.

In [11]:
def prepare_datasets(data_path: str, tokenizer, config: TrainingConfig):
    """Load and split datasets for training"""

    # Load the processed dataset
    dataset = load_dataset("parquet", data_files=data_path)["train"]

    # Split into train/val/test
    splits = dataset.train_test_split(test_size=0.2, seed=42)
    train_data = splits["train"]

    # Further split test into val/test
    test_splits = splits["test"].train_test_split(test_size=0.5, seed=42)
    val_data = test_splits["train"]
    test_data = test_splits["test"]

    print(f"Train size: {len(train_data)}")
    print(f"Val size: {len(val_data)}")
    print(f"Test size: {len(test_data)}")

    # Create custom datasets
    train_dataset = GemmaChatDataset(train_data, tokenizer, config.max_length)
    val_dataset = GemmaChatDataset(val_data, tokenizer, config.max_length)
    test_dataset = GemmaChatDataset(test_data, tokenizer, config.max_length)

    return train_dataset, val_dataset, test_dataset

Let's prepare the training loop with custom data collation and efficient batching.

In [12]:
import evaluate
import numpy as np
from transformers import TrainingArguments


def get_compute_metrics(tokenizer):
    """Create compute_metrics function with access to tokenizer"""

    def compute_metrics(eval_preds):
        """Compute BLEU and other metrics for translation evaluation"""
        bleu_metric = evaluate.load("bleu")

        predictions, labels = eval_preds

        # Decode predictions
        predictions = np.where(
            predictions != -100, predictions, tokenizer.pad_token_id
        )
        decoded_preds = tokenizer.batch_decode(
            predictions, skip_special_tokens=True
        )

        # Clean up predictions (remove template parts)
        cleaned_preds = []
        for pred in decoded_preds:
            # Extract only the translation part after the template
            try:
                translation = (
                    pred.split("<start_of_turn>model\n")[1]
                    .split("<end_of_turn>")[0]
                    .strip()
                )
            except IndexError:
                translation = pred  # Fallback if splitting fails
            cleaned_preds.append(translation)

        # Decode labels
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(
            labels,
            skip_special_tokens=True,
        )

        # Clean up references (remove template parts)
        cleaned_refs = []
        for ref in decoded_labels:
            try:
                translation = (
                    ref.split("<start_of_turn>model\n")[1]
                    .split("<end_of_turn>")[0]
                    .strip()
                )
            except IndexError:
                translation = ref  # Fallback if splitting fails
            cleaned_refs.append([translation])

        # Compute BLEU
        bleu_score = bleu_metric.compute(
            predictions=cleaned_preds, references=cleaned_refs
        )

        return {
            "bleu": bleu_score["bleu"],
        }

    return compute_metrics


def get_training_args(config: TrainingConfig):
    return TrainingArguments(
        output_dir=config.output_dir,
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.train_batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        
        # H100-specific optimizations
        bf16=True,  # H100 has excellent bfloat16 support
        tf32=True,  # Enable tensor float 32 for faster matrix multiplications
        gradient_checkpointing=True,
        
        # Optimizer settings
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        max_grad_norm=1.0,
        optim="adamw_torch_fused",  # Use fused optimizer for better performance
        
        # Evaluation and saving
        evaluation_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.save_steps,
        
        # Logging
        logging_strategy="steps",
        logging_steps=100,
        logging_first_step=True,
        report_to=["wandb"],
        
        # Model selection
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
    )

Now before proceeding with the training, let's validate the pipeline with small sample of data.

In [13]:
def validate_data_pipeline():
    """Validate the entire data pipeline"""
    # Load a small subset
    dataset = load_dataset(
        "parquet", data_files="data/processed/paracrawl_alpaca.parquet"
    )["train"].select(range(5))

    # Initialize components
    _, tokenizer = setup_model_and_tokenizer(config)
    train_dataset = GemmaChatDataset(dataset, tokenizer)

    # Check a sample
    sample = train_dataset[0]

    print("=== Data Pipeline Validation ===")
    print("\n1. Input IDs shape:", sample["input_ids"].shape)
    print("\n2. Decoded input:")
    print(tokenizer.decode(sample["input_ids"]))

    # Test compute_metrics
    dummy_preds = (
        sample["input_ids"].unsqueeze(0),
        sample["labels"].unsqueeze(0),
    )
    metrics = get_compute_metrics(tokenizer)(dummy_preds)

    print("\n3. Metrics computation test:")
    print(metrics)

    return True


validate_data_pipeline()

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

=== Data Pipeline Validation ===

1. Input IDs shape: torch.Size([512])

2. Decoded input:
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

True

Let's define the training loop.

In [14]:
import wandb
from transformers import Trainer
from transformers import EarlyStoppingCallback

def train_translation_model():
    """Main training function with full pipeline"""
    # Initialize wandb
    wandb.init(
        project="gemma-cs-translator",
        config={
            "model": config.model_name,
            "lora_r": config.lora_r,
            "num_epochs": config.num_epochs,
            "train_batch_size": config.train_batch_size,
            "eval_batch_size": config.eval_batch_size,
            "gradient_accumulation_steps": config.gradient_accumulation_steps,
            "learning_rate": config.learning_rate,
            "early_stopping_patience": config.early_stopping_patience,
            "early_stopping_threshold": config.early_stopping_threshold,
        },
    )

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(config)

    # Prepare datasets
    train_dataset, val_dataset, _ = prepare_datasets(
        "data/processed/paracrawl_alpaca.parquet", tokenizer, config
    )

    # Setup training arguments
    training_args = get_training_args(config)

    # Create early stopping callback
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=config.early_stopping_patience,
        early_stopping_threshold=config.early_stopping_threshold,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=get_compute_metrics(tokenizer),
        callbacks=[early_stopping_callback],
    )

    # Train
    train_result = trainer.train()

    # Save final model
    trainer.save_model(f"{config.output_dir}/final")

    # Log metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)

    return trainer, tokenizer


# Run training
trainer, tokenizer = train_translation_model()

[34m[1mwandb[0m: Currently logged in as: [33mjirkax[0m ([33mjirkax-individual[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Train size: 22939
Val size: 2867
Test size: 2868


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

Now post training we can evaluate the model on the test dataset.

In [15]:
class TranslationInference:
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        
        # Cache the prompt template
        self.prompt = "Přelož tento text z angličtiny do češtiny:\n\n"
       
    def translate(self, text: str) -> str:
        chat = [
            {
                "role": "user",
                "content": f"Přelož tento text z angličtiny do češtiny:\n\n{text}"
            }
        ]
        
        prompt = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
        
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=256
        ).to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                num_beams=1
            )
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract only the model's translation
        translation = response.split("model\n")[-1].split("<end_of_turn>")[0].strip()
        return translation
       
    def translate_batch(self, texts: List[str], max_length: int = 256) -> List[str]:
        # Create chat format for each text
        chats = [
            [{
                "role": "user",
                "content": f"Přelož tento text z angličtiny do češtiny:\n\n{text}"
            }] for text in texts
        ]
        
        # Apply chat template to each
        prompts = [
            self.tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True
            ) for chat in chats
        ]
        
        # Tokenize batch
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.model.device)
        
        # Generate translations
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=False,
                num_beams=1,
                early_stopping=True,
            )
        
        # Decode outputs
        responses = self.tokenizer.batch_decode(
            outputs, 
            skip_special_tokens=True,
        )
        
        # Extract only the model's translations
        translations = [
            response.split("model\n")[-1].split("<end_of_turn>")[0].strip()
            for response in responses
        ]
        
        return translations

def test_model(test_dataset, model_path: str):
    """Evaluate model on test dataset"""
    inference = TranslationInference(model_path)
    bleu_metric = evaluate.load("bleu")
    
    # Process in batches
    batch_size = 8
    predictions = []
    references = []
    
    print("Running inference...")
    for i in tqdm(range(0, min(100, len(test_dataset)), batch_size)):
        batch = test_dataset[i:i + batch_size]
        
        # Decode input_ids to get source texts
        source_texts = []
        batch_refs = []
        
        for j in range(len(batch['input_ids'])):
            # Get non-padding tokens
            mask = batch['attention_mask'][j] == 1
            input_ids = batch['input_ids'][j][mask]
            label_ids = batch['labels'][j][mask]
            
            # Decode source text
            input_text = tokenizer.decode(input_ids, skip_special_tokens=True)
            # Extract English text after the instruction
            if "Přelož tento text z angličtiny do češtiny:" in input_text:
                source = input_text.split("Přelož tento text z angličtiny do češtiny:")[-1].strip()
            else:
                source = input_text.strip()
            source_texts.append(source)
            
            # Decode reference translation
            reference = tokenizer.decode(label_ids, skip_special_tokens=True)
            if "<end_of_turn>" in reference:
                reference = reference.split("<end_of_turn>")[0].strip()
            batch_refs.append([reference])
        
        # Get translations for batch
        translations = inference.translate_batch(source_texts)
        
        predictions.extend(translations)
        references.extend(batch_refs)
    
    # Calculate BLEU score
    bleu_score = bleu_metric.compute(
        predictions=predictions,
        references=references
    )
    
    print(f"\nTest BLEU Score: {bleu_score['bleu']:.2f}")
    
    # Show examples
    print("\nExample Translations:")
    for i in range(min(3, len(predictions))):
        print(f"\nSource: {source_texts[i]}")
        print(f"Reference: {references[i][0]}")
        print(f"Prediction: {predictions[i]}")

Let's test the model on some examples.

In [17]:
# Setup model and tokenizer
_, tokenizer = setup_model_and_tokenizer(config)

# Prepare datasets
_, _, test_dataset = prepare_datasets(
    "data/processed/paracrawl_alpaca.parquet", tokenizer, config
)

# Test the trained model
model_path = f"{config.output_dir}/final"
# test_model(test_dataset, model_path)

# Interactive translation example
inference = TranslationInference(model_path)

examples = [
    "Hello, how are you today?",
    "This is a test of the translation system.",
    "Machine learning is transforming the world.",
]

print("\nInteractive Translation Examples:")
for text in examples:
    translation = inference.translate(text)
    print(f"\n🇬🇧 English: {text}")
    print(f"🇨🇿 Czech: {translation}")

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Train size: 22939
Val size: 2867
Test size: 2868


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.



Interactive Translation Examples:

🇬🇧 English: Hello, how are you today?
🇨🇿 Czech: Ahoj, jak se dnes máte?

🇬🇧 English: This is a test of the translation system.
🇨🇿 Czech: Toto je test systému překládání.

🇬🇧 English: Machine learning is transforming the world.
🇨🇿 Czech: Vývoj machine learningu mění svět.


In [None]:
def compare_models(test_dataset, num_examples: int = 5):
    print("Loading models...")
    original = TranslationInference("google/gemma-2-2b-it")
    finetuned = TranslationInference(f"{config.output_dir}/final")
    bleu_metric = evaluate.load("bleu")
    
    print("\n=== Model Comparison (5 Examples) ===")
    
    results = {
        "Original": {"predictions": [], "bleu": 0.0},
        "Finetuned": {"predictions": [], "bleu": 0.0}
    }
    references = []
    
    for i in range(num_examples):
        example = test_dataset[i]
        mask = example['attention_mask'] == 1
        conversation = tokenizer.decode(example['input_ids'][mask])
        
        try:
            # Extract English and reference Czech
            parts = conversation.split("Přelož tento text z angličtiny do češtiny\n\n")
            english = parts[1].split("<end_of_turn>")[0].strip()
            czech = parts[1].split("<start_of_turn>model\n")[1].split("<end_of_turn>")[0].strip()
            
            # Get translations from both models
            orig_translation = original.translate(english)
            fine_translation = finetuned.translate(english)
            
            # Store results
            results["Original"]["predictions"].append(orig_translation)
            results["Finetuned"]["predictions"].append(fine_translation)
            references.append([czech])
            
            # Show each example with word-by-word comparison
            print(f"\n{'='*80}")
            print(f"Example {i+1}:")
            print(f"🇬🇧 Source:     {english}")
            print(f"🇨🇿 Reference:  {czech}")
            print(f"🇨🇿 Original:   {orig_translation}")
            print(f"🇨🇿 Finetuned:  {fine_translation}")
            
            # Analysis of translations
            print("\nAnalysis:")
            print(f"- Original follows instruction: {'Přelož' in orig_translation or 'czech' in orig_translation.lower()}")
            print(f"- Original length: {len(orig_translation.split())}")
            print(f"- Finetuned length: {len(fine_translation.split())}")
            print(f"- Reference length: {len(czech.split())}")
                
        except Exception as e:
            print(f"Skipping example {i}, error: {str(e)}")
            continue
    
    # Calculate BLEU scores
    if references:
        print(f"\n{'='*80}")
        print("\nOverall Statistics:")
        for model_name in results:
            predictions = results[model_name]["predictions"]
            bleu_score = bleu_metric.compute(predictions=predictions, references=references)
            results[model_name]["bleu"] = bleu_score["bleu"]
            
            # Calculate average lengths
            avg_len = sum(len(p.split()) for p in predictions) / len(predictions)
            print(f"\n{model_name} model:")
            print(f"- BLEU Score: {results[model_name]['bleu']:.2f}")
            print(f"- Average output length: {avg_len:.1f} words")
            
        # Reference statistics
        avg_ref_len = sum(len(r[0].split()) for r in references) / len(references)
        print(f"\nReference statistics:")
        print(f"- Average length: {avg_ref_len:.1f} words")
    else:
        print("\nNo valid examples found for evaluation")

# Run comparison
compare_models(test_dataset)