# Czech Language Adaptation of Gemma Language Model

**Author:** Jirka Helmich

**Last Updated:** 2025-01-06

**License:** MIT

## Overview

This notebook demonstrates the fine-tuning process of the Gemma language model for Czech language understanding and generation. We focus on creating a robust multilingual model capable of handling various Czech-specific NLP tasks.

### Key Objectives

1. 🎯 **Primary Goal**: Adapt Gemma for superior Czech language processing
2. 🔄 **Tasks**: Translation, sentiment analysis, text generation
3. 📊 **Evaluation**: Comprehensive benchmarking on Czech-specific metrics

### Technical Requirements

```
Python >= 3.10
polars >= 0.20.0
datasets >= 2.15.0
tqdm >= 4.66.0
```

### Dataset Sources

We utilize multiple high-quality Czech datasets:

1. **ParaCrawl v9**
   - Parallel corpus for EN-CS translation
   - ~52M sentence pairs
   - [Source](https://paracrawl.eu/v9)

2. **Czech Books Descriptions**
   - Book descriptions in Czech
   - [Source](https://huggingface.co/datasets/vojtam/czech_books_descriptions)

## Environment Setup

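First, install the required dependencies. The command does not pin versions, so it installs the latest releases; the minimum supported versions are listed under Technical Requirements. Pin exact versions if you need a reproducible environment.

In [None]:
%pip install datasets polars tqdm fasttext matplotlib seaborn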

## Data Processing Components

### 1. ParaCrawl Dataset Loader

The `ParaCrawlDataLoader` class handles downloading and processing of the ParaCrawl translation dataset. Key features:

- Automatic download and decompression
- Progress tracking
- Data cleaning and validation

## Implementation

This section implements a robust data loader for the ParaCrawl dataset with the following features:

- ✨ Automatic download with progress tracking
- 🔍 Data validation and integrity checks
- 📊 Efficient processing using Polars
- 💾 Caching of processed data

### Dependencies

In [2]:
import polars as pl
from pathlib import Path
import urllib.request
import gzip
import logging
from tqdm import tqdm
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)

### ParaCrawl Data Loader Class

The main class implementation with detailed documentation:

In [3]:
class ParaCrawlDataLoader:
    """Handles downloading and processing of ParaCrawl translation datasets."""

    def __init__(
        self,
        source_lang: str = "en",
        target_lang: str = "cs",
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        """Initialize the ParaCrawl data loader."""
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.base_url = "https://web-language-models.s3.amazonaws.com/paracrawl/release9"

        # Setup directories
        self.data_dir = Path(data_dir or "./data")
        self.cache_dir = Path(cache_dir or "./cache")
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.logger = logging.getLogger(__name__)

        # Construct file paths
        self.filename = f"{source_lang}-{target_lang}.txt.gz"
        self.filepath = self.data_dir / self.filename
        self.processed_path = self.cache_dir / f"{source_lang}-{target_lang}.parquet"

### Download and Validation Methods

Methods for downloading data with progress tracking and validation:

In [4]:
def _download_with_progress(self, url: str, filepath: Path) -> None:
    """Download a file with a progress bar, streaming it to disk in chunks."""
    try:
        # Open the connection once; the same response supplies the size and the body
        with urllib.request.urlopen(url) as response, open(filepath, "wb") as out:
            # Content-Length may be absent; tqdm accepts total=None
            content_length = response.headers.get("Content-Length")
            total_size = int(content_length) if content_length else None
            with tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {filepath.name}") as pbar:
                while chunk := response.read(1024 * 1024):
                    out.write(chunk)
                    pbar.update(len(chunk))
    except Exception as e:
        self.logger.error(f"Error downloading file: {e}")
        raise

def _validate_file(self, filepath: Path) -> bool:
    """Validate downloaded file integrity."""
    if not filepath.exists():
        return False
        
    try:
        with gzip.open(filepath, 'rt', encoding='utf-8') as f:
            for _ in range(5):
                line = f.readline()
                if '\t' not in line:
                    return False
        return True
    except Exception:
        return False

ParaCrawlDataLoader._download_with_progress = _download_with_progress
ParaCrawlDataLoader._validate_file = _validate_file
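
The validation idea can be tried in isolation on a tiny file. The following is a minimal sketch, not the loader's own code: the helper name `looks_like_parallel_tsv` and the sample sentences are assumptions made for this example. Unlike `_validate_file`, it accepts files shorter than five lines as long as every inspected line contains a tab.

```python
import gzip
import tempfile
from pathlib import Path


def looks_like_parallel_tsv(filepath: Path, n_lines: int = 5) -> bool:
    """Return True if each of the first n_lines lines contains a tab separator."""
    try:
        with gzip.open(filepath, "rt", encoding="utf-8") as f:
            checked = 0
            for line in f:
                if "\t" not in line:
                    return False
                checked += 1
                if checked >= n_lines:
                    break
        # An empty file is not a usable corpus
        return checked > 0
    except (OSError, EOFError, UnicodeDecodeError):
        # Not a gzip file, a truncated archive, or not UTF-8 text
        return False


with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "good.txt.gz"
    with gzip.open(good, "wt", encoding="utf-8") as f:
        f.write("Hello world.\tAhoj světe.\n")
        f.write("Good morning.\tDobré ráno.\n")

    bad = Path(tmp) / "bad.txt.gz"
    bad.write_bytes(b"this is not gzip data")

    good_ok = looks_like_parallel_tsv(good)
    bad_ok = looks_like_parallel_tsv(bad)

print(good_ok, bad_ok)
```

Note that this kind of check only inspects the head of the file, so it cannot detect an archive that is truncated further in.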

### Data Processing Methods

Methods for processing and loading the data:

In [5]:
def _process_raw_file(self) -> None:
    """Process raw gzipped file into Parquet format."""
    if self.processed_path.exists():
        self.logger.info("Using cached processed data")
        return

    self.logger.info("Processing raw data file...")

    chunk_size = 100_000
    chunks = []

    with gzip.open(self.filepath, "rt", encoding="utf-8") as f:
        with tqdm(desc="Processing chunks") as pbar:
            while True:
                lines = [next(f, None) for _ in range(chunk_size)]
                lines = [line for line in lines if line is not None]

                if not lines:
                    break

                pairs = [line.strip().split("\t") for line in lines]
                # Filter out invalid pairs
                pairs = [p for p in pairs if len(p) == 2]

                if not pairs:
                    continue

                # Pre-filter by length before creating DataFrame
                pairs = [
                    p
                    for p in pairs
                    if (0 < len(p[0]) < 1000 and 0 < len(p[1]) < 1000)
                ]

                if not pairs:
                    continue

                chunk_df = pl.DataFrame(
                    pairs,
                    schema=[self.source_lang, self.target_lang],
                    orient="row",  # Explicitly specify orientation
                )

                if len(chunk_df) > 0:
                    chunks.append(chunk_df)
                pbar.update(1)

    if not chunks:
        raise ValueError("No valid data found in the input file")

    df = pl.concat(chunks)
    df.write_parquet(self.processed_path)
    self.logger.info(f"Processed data saved to {self.processed_path}")


ParaCrawlDataLoader._process_raw_file = _process_raw_file
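
The chunked read and pair filtering above can also be expressed with `itertools.islice`. This is a sketch under stated assumptions: `iter_pair_chunks` is a hypothetical helper, the sample lines are invented, and plain tuples stand in for the Polars DataFrames used in the notebook.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple


def iter_pair_chunks(
    lines: Iterable[str], chunk_size: int, max_chars: int = 1000
) -> Iterator[List[Tuple[str, str]]]:
    """Yield lists of (source, target) pairs, reading at most chunk_size lines at a time."""
    iterator = iter(lines)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            break
        pairs = []
        for line in chunk:
            parts = line.strip().split("\t")
            # Keep only well-formed pairs whose sides are non-empty and short enough
            if len(parts) == 2 and all(0 < len(p) < max_chars for p in parts):
                pairs.append((parts[0], parts[1]))
        if pairs:
            yield pairs


sample_lines = [
    "Hello.\tAhoj.\n",
    "malformed line without a tab\n",
    "Thank you.\tDěkuji.\n",
    "\tchybí zdroj\n",
    "Good night.\tDobrou noc.\n",
]
chunks = list(iter_pair_chunks(sample_lines, chunk_size=2))
print(chunks)
```

Because the function is a generator, a caller could write each chunk to disk as it arrives instead of collecting all chunks in memory first.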

### Public Interface Methods

Methods for downloading and loading the dataset:

In [6]:
def download_data(self) -> None:
    """Download ParaCrawl dataset if not already present."""
    if self.filepath.exists() and self._validate_file(self.filepath):
        self.logger.info("Using existing download")
        return
        
    url = f"{self.base_url}/{self.source_lang}-{self.target_lang}/{self.filename}"
    self.logger.info(f"Downloading from {url}")
    
    self._download_with_progress(url, self.filepath)
    
    if not self._validate_file(self.filepath):
        raise ValueError("Downloaded file appears to be corrupt")

def load_dataframe(self) -> pl.DataFrame:
    """Load the processed ParaCrawl dataset."""
    self.download_data()
    self._process_raw_file()
    
    df = pl.read_parquet(self.processed_path)
    self.logger.info(f"Loaded {len(df):,} translation pairs")
    
    return df

def get_sample(self, n: int = 5) -> pl.DataFrame:
    """Get a sample of n translation pairs."""
    df = self.load_dataframe()
    return df.sample(n)

ParaCrawlDataLoader.download_data = download_data
ParaCrawlDataLoader.load_dataframe = load_dataframe
ParaCrawlDataLoader.get_sample = get_sample

### 2. Alpaca Format Converter Implementation

This section implements a robust converter for transforming datasets into the Alpaca instruction format, which is optimized for fine-tuning language models. Key features:

- 🔄 Flexible input handling
- 📝 Customizable instruction templates
- 💾 Efficient Parquet output
- ✨ Data validation and cleaning

#### Dependencies

In [7]:
import polars as pl
from pathlib import Path
from typing import Optional, Dict, List, Union
from tqdm.auto import tqdm
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

### Alpaca Data Converter Class

Main class for converting datasets to Alpaca instruction format:

In [8]:
class AlpacaConverter:
    """Converts datasets to Alpaca instruction format for fine-tuning.
    
    This class handles the conversion of various dataset formats into the
    Alpaca instruction format, which is suitable for fine-tuning language models.
    """
    
    def __init__(
        self,
        instruction_templates: Optional[Dict[str, str]] = None,
        max_length: int = 2048,
        min_length: int = 3
    ):
        """Initialize the Alpaca converter.
        
        Args:
            instruction_templates: Dictionary of task types to instruction templates
            max_length: Maximum length of input/output text
            min_length: Minimum length of input/output text
        """
        # Genitive forms used after "z" and "do"; Czech writes language names in lowercase
        self.iso_to_lang = {
            'en': 'angličtiny',
            'cs': 'češtiny',
        }
        self.instruction_templates = instruction_templates or {
            'translation': "Přelož tento text z {iso_to_lang[source_lang]} do {iso_to_lang[target_lang]}",
            'book_description': "Popiš tuto knihu",
        }
        self.max_length = max_length
        self.min_length = min_length
        self.logger = logging.getLogger(__name__)

### Data Validation Methods

Methods for validating and cleaning input data:

In [9]:
def _validate_text(self, text: str) -> bool:
    """Validate text length and content.
    
    Args:
        text: Input text to validate
        
    Returns:
        bool: True if text is valid
    """
    if not isinstance(text, str):
        return False
        
    text = text.strip()
    length = len(text)
    
    return (length >= self.min_length and 
            length <= self.max_length and
            not text.isspace())

def _clean_text(self, text: str) -> str:
    """Clean and normalize text.
    
    Args:
        text: Input text to clean
        
    Returns:
        str: Cleaned text
    """
    return " ".join(text.strip().split())

AlpacaConverter._validate_text = _validate_text
AlpacaConverter._clean_text = _clean_text
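
To make the behaviour of these two helpers concrete, here is a small standalone sketch that mirrors their logic. The function names and sample strings are assumptions for illustration; the default limits match the converter's `min_length=3` and `max_length=2048`.

```python
def clean_text(text: str) -> str:
    """Collapse every run of whitespace (including newlines and tabs) into one space."""
    return " ".join(text.strip().split())


def is_valid_text(text: object, min_length: int = 3, max_length: int = 2048) -> bool:
    """Accept only strings whose stripped length lies within [min_length, max_length]."""
    if not isinstance(text, str):
        return False
    return min_length <= len(text.strip()) <= max_length


samples = ["  Dobrý   den,\n světe!  ", "ok", None, "   "]
for s in samples:
    cleaned = clean_text(s) if isinstance(s, str) else s
    print(repr(cleaned), is_valid_text(cleaned))
```

Length is measured in characters, not tokens, so the limits are only a rough proxy for what fits into the model's context.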

### Format Conversion Methods

Core methods for converting data to Alpaca format:

In [10]:
def _create_instruction(self, task_type: str, **kwargs) -> str:
    """Create instruction from template.
    
    Args:
        task_type: Type of task (e.g., 'translation')
        **kwargs: Format parameters for instruction template
        
    Returns:
        str: Formatted instruction
    """
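    template = self.instruction_templates.get(task_type)
    if not template:
        raise ValueError(f"Unknown task type: {task_type}")
    # str.format cannot evaluate a nested lookup such as {iso_to_lang[source_lang]};
    # it reads "source_lang" as a literal key. Pass a dict keyed by parameter name
    # that already holds the resolved language names so the template resolves.
    resolved = {name: self.iso_to_lang.get(code, code) for name, code in kwargs.items()}
    return template.format(iso_to_lang=resolved, **kwargs)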

def _create_example(self,
    instruction: str,
    output: str,
    input_text: Optional[str] = None
) -> Dict[str, str]:
    """Create a single Alpaca format example.
    
    Args:
        instruction: Task instruction
        output: Expected output text
        input_text: Optional input text
        
    Returns:
        Dict[str, str]: Alpaca format example
    """
    example = {
        "instruction": instruction,
        "output": self._clean_text(output)
    }
    
    if input_text:
        example["input"] = self._clean_text(input_text)
        
    return example

AlpacaConverter._create_instruction = _create_instruction
AlpacaConverter._create_example = _create_example
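
The shape of a finished record is easiest to see end to end. The sketch below assumes a simplified template with flat `{source}` and `{target}` fields; the names `ISO_TO_LANG`, `TRANSLATION_TEMPLATE`, and `build_example` are hypothetical and chosen for this example only.

```python
import json
from typing import Dict, Optional

# Genitive forms, as required after the Czech prepositions "z" and "do"
ISO_TO_LANG = {"en": "angličtiny", "cs": "češtiny"}
TRANSLATION_TEMPLATE = "Přelož tento text z {source} do {target}"


def build_example(
    instruction: str, output: str, input_text: Optional[str] = None
) -> Dict[str, str]:
    """Assemble one Alpaca-style record; "input" is included only when provided."""
    example = {"instruction": instruction, "output": output}
    if input_text:
        example["input"] = input_text
    return example


instruction = TRANSLATION_TEMPLATE.format(
    source=ISO_TO_LANG["en"], target=ISO_TO_LANG["cs"]
)
record = build_example(instruction, output="Dobré ráno.", input_text="Good morning.")
print(json.dumps(record, ensure_ascii=False))
```

Resolving the language names before calling `str.format` keeps the template free of nested lookups, which `str.format` does not evaluate.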

### Public Interface Methods

Methods for converting different types of datasets:

In [11]:
def convert_translations(
    self,
    df: pl.DataFrame,
    source_lang: str,
    target_lang: str, 
    output_path: Union[str, Path]
) -> None:
    """Convert translation pairs to Alpaca format and save as Parquet.
    
    This optimized version processes data in chunks and saves to Parquet format
    for better performance and compression.
    """
    instruction = self._create_instruction(
        'translation',
        source_lang=source_lang,
        target_lang=target_lang
    )
    
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Convert output path to parquet
    output_path = output_path.with_suffix('.parquet')
    
    # Process in chunks for memory efficiency
    chunk_size = 50000  # Increased chunk size for Parquet
    
    # Initialize empty list to store processed chunks
    processed_chunks = []
    
    # Create a separate progress bar for translations
    trans_pbar = tqdm(total=len(df), desc="Processing translation rows", position=0)
    
    for i in range(0, len(df), chunk_size):
        # Get chunk
        chunk = df.slice(i, chunk_size)
        
        # Extract the source and target columns of this chunk as Series
        source_texts = chunk.select(pl.col(source_lang)).to_series()
        target_texts = chunk.select(pl.col(target_lang)).to_series()
        
        # Create DataFrame with processed rows
        processed_df = pl.DataFrame({
            "instruction": [instruction] * len(chunk),
            "input": source_texts.map_elements(self._clean_text, return_dtype=pl.Utf8),
            "output": target_texts.map_elements(self._clean_text, return_dtype=pl.Utf8)
        })
        
        # Filter valid rows (map_elements calls the Python validator once per row)
        mask = (processed_df["input"].map_elements(self._validate_text, return_dtype=pl.Boolean) & 
               processed_df["output"].map_elements(self._validate_text, return_dtype=pl.Boolean))
        processed_df = processed_df.filter(mask)
        
        # Append to list
        processed_chunks.append(processed_df)
        
        # Update progress
        trans_pbar.update(len(chunk))
        trans_pbar.set_postfix({'valid_rows': len(processed_df)})
    
    trans_pbar.close()
    
    # Combine all chunks
    final_df = pl.concat(processed_chunks)
    
    # Write to Parquet with compression
    final_df.write_parquet(
        output_path,
        compression="zstd",  # Use ZSTD compression for better ratio/speed balance
        statistics=True,     # Include statistics for better query performance
        row_group_size=100000  # Optimize row groups for typical query patterns
    )

def convert_descriptions(
    self,
    df: pl.DataFrame,
    title_col: str,
    desc_col: str,
    output_path: Union[str, Path]
) -> None:
    """Convert title-description pairs to Alpaca format and save as Parquet."""
    instruction = self._create_instruction('book_description')
    
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Convert output path to parquet
    output_path = output_path.with_suffix('.parquet')
    
    # Process in chunks for memory efficiency
    chunk_size = 50000
    
    # Initialize empty list to store processed chunks
    processed_chunks = []
    
    # Create a separate progress bar for descriptions
    desc_pbar = tqdm(total=len(df), desc="Processing book description rows", position=1)
    
    for i in range(0, len(df), chunk_size):
        # Get chunk
        chunk = df.slice(i, chunk_size)
        
        # Extract the title and description columns of this chunk as Series
        titles = chunk.select(pl.col(title_col)).to_series()
        descriptions = chunk.select(pl.col(desc_col)).to_series()
        
        # Create DataFrame with processed rows
        processed_df = pl.DataFrame({
            "instruction": [instruction] * len(chunk),
            "input": titles.map_elements(self._clean_text, return_dtype=pl.Utf8),
            "output": descriptions.map_elements(self._clean_text, return_dtype=pl.Utf8)
        })
        
        # Filter valid rows (map_elements calls the Python validator once per row)
        mask = (processed_df["input"].map_elements(self._validate_text, return_dtype=pl.Boolean) & 
               processed_df["output"].map_elements(self._validate_text, return_dtype=pl.Boolean))
        processed_df = processed_df.filter(mask)
        
        # Append to list
        processed_chunks.append(processed_df)
        
        # Update progress
        desc_pbar.update(len(chunk))
        desc_pbar.set_postfix({'valid_rows': len(processed_df)})
    
    desc_pbar.close()
    
    # Combine all chunks
    final_df = pl.concat(processed_chunks)
    
    # Write to Parquet with compression
    final_df.write_parquet(
        output_path,
        compression="zstd",
        statistics=True,
        row_group_size=100000
    )

AlpacaConverter.convert_translations = convert_translations
AlpacaConverter.convert_descriptions = convert_descriptions

## Data Processing Pipeline 🔄

This section implements the main data processing pipeline for preparing our training data. We'll walk through each step to ensure high-quality training data.

### Pipeline Overview 📋

1. 📥 **Load ParaCrawl Corpus**
   - Download EN-CS parallel data
   - Clean and validate entries
   - Remove low-quality pairs

2. 📚 **Process Book Descriptions**
   - Load Czech book dataset
   - Extract titles and descriptions
   - Filter and clean text

3. 🔄 **Format Conversion**
   - Transform to Alpaca format
   - Add instruction templates
   - Validate final structure

4. 💾 **Save Training Data**
   - Export to Parquet format
   - Create data splits
   - Verify data integrity

### Key Features ✨

- 🧹 Robust data cleaning
- ⚡ Efficient Polars processing
- 🔍 Quality validation steps
- 📊 Progress tracking
- 💪 Scalable pipeline

### 1. Load and Process ParaCrawl Dataset 🌐

In [None]:
# Initialize loader
loader = ParaCrawlDataLoader(source_lang="en", target_lang="cs")

# Load ParaCrawl EN-CS dataset
df_paracrawl = loader.load_dataframe()
print(f"Loaded {len(df_paracrawl):,} translation pairs")
df_paracrawl.head()

### 2. Process Book Descriptions Dataset

In [None]:
from datasets import load_dataset

# Load Czech book descriptions
ds = load_dataset("vojtam/czech_books_descriptions")
books_df = ds['train'].to_polars()
print(f"Loaded {len(books_df):,} book descriptions")
books_df.head()

## Convert to Training Format

Convert our processed datasets to the Alpaca instruction format for fine-tuning.

In [26]:
import os

alpaca_converter = AlpacaConverter()

if not os.path.exists("data/translation/dataset/paracrawl.parquet"):
    alpaca_converter.convert_translations(
        df_paracrawl,
        source_lang="en",
        target_lang="cs",
        output_path="data/translation/dataset/paracrawl.parquet"
    )

if not os.path.exists("data/translation/dataset/czech_books.parquet"):
    alpaca_converter.convert_descriptions(
        books_df,
        title_col="title",
        desc_col="text",
        output_path="data/translation/dataset/czech_books.parquet"
    )

## Load the Dataset

In this phase we load the converted Parquet files back as Hugging Face datasets. Cleaning and the train/validation/test split follow in the next sections.


In [None]:
import polars as pl
from datasets import load_dataset

print("Loading datasets...")
translations_ds = load_dataset(
    "parquet", data_files="data/translation/dataset/paracrawl.parquet"
)
books_ds = load_dataset(
    "parquet", data_files="data/translation/dataset/czech_books.parquet"
)

Now let's print the dataset samples to verify that the data is in the correct format.

In [None]:
print("\n=== Dataset Samples ===")
print("\nFirst translation sample:")
print(translations_ds["train"][0])
print("\nFirst book description sample:")
print(books_ds["train"][0])
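
Beyond eyeballing the first record, the format check can be automated. The following is a minimal sketch: `alpaca_record_problems` is a hypothetical helper, the two records are invented, and the rule that `instruction` and `output` are required while `input` is optional is an assumption based on the converter above.

```python
from typing import Any, Dict, List

REQUIRED_FIELDS = ("instruction", "output")


def alpaca_record_problems(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the record looks fine."""
    problems = []
    for field in REQUIRED_FIELDS:
        value = record.get(field)
        if not isinstance(value, str) or not value.strip():
            problems.append(f"missing or empty '{field}'")
    # "input" is optional, but when present it should be a string
    if "input" in record and not isinstance(record["input"], str):
        problems.append("'input' is not a string")
    return problems


ok = {"instruction": "Popiš tuto knihu", "input": "Ukázková kniha", "output": "Ukázkový popis knihy."}
broken = {"instruction": "Popiš tuto knihu", "output": "   "}
print(alpaca_record_problems(ok))
print(alpaca_record_problems(broken))
```

Running such a check over a random sample of rows is usually enough to catch a systematic conversion error.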

Next, check the size of each dataset to catch obvious problems, such as an empty or truncated conversion.

In [None]:
print("\n=== Dataset Sizes ===")
print(f"Translation dataset size: {len(translations_ds['train'])}")
print(f"Books dataset size: {len(books_ds['train'])}")


## Data Cleaning and Quality Analysis 🧹

Before splitting the datasets, we'll perform thorough cleaning and quality analysis to ensure high-quality Czech language data. We'll focus on:

1. **Text Quality Checks** 📊
   - Remove non-Czech characters
   - Fix common encoding issues
   - Validate linguistic patterns

2. **Statistical Analysis** 📈
   - Length distributions
   - Character frequency analysis
   - Quality metrics visualization

3. **Cleaning Pipeline** 🔄
   - Remove invalid entries
   - Fix common errors
   - Normalize text format

In [None]:
import re
import unicodedata
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List
import logging

class CzechTextCleaner:
    """Handles cleaning and validation of Czech text data with permissive validation but strict correction."""

    FASTTEXT_MODEL_URL = (
        "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"
    )

    def __init__(self):
        # Core Czech characters (strict)
        self.czech_chars = set('aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽ')
        self.czech_punctuation = set(',.!?-–—()[]{}/\\"\'»«„"‟"\'')
        self.czech_numbers = set('0123456789')
        self.stats = {}
        self.corrections_made = 0
        self.logger = logging.getLogger(__name__)

        # Critical error patterns that require dropping
        self.critical_patterns = [
            r'[^a-zA-ZáéíóúůýěščřžďťňÁÉÍÓÚŮÝĚŠČŘŽĎŤŇ\s,.!?0-9-–—()\[\]{}/\\"\'»«„"‟"\']',  # Non-Czech characters (brackets escaped so the class is not closed early)
            r'[áéíóúůýě]{3,}',  # Three or more consecutive diacritics
            r'\s{3,}'  # Excessive whitespace
        ]

        # Setup FastText model
        self.model_dir = Path("models")
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.model_path = self.model_dir / "lid.176.ftz"

        # Download and load model
        self._setup_fasttext()

    def _download_with_progress(self, url: str, filepath: Path) -> None:
        """Download a file with a progress bar, streaming it to disk in chunks."""
        try:
            # Open the connection once; the same response supplies the size and the body
            with urllib.request.urlopen(url) as response, open(filepath, "wb") as out:
                # Content-Length may be absent; tqdm accepts total=None
                content_length = response.headers.get("Content-Length")
                with tqdm(
                    total=int(content_length) if content_length else None,
                    unit="B",
                    unit_scale=True,
                    desc="Downloading FastText model",
                ) as pbar:
                    while chunk := response.read(1024 * 1024):
                        out.write(chunk)
                        pbar.update(len(chunk))
        except Exception as e:
            self.logger.error(f"Error downloading FastText model: {e}")
            raise

    def _setup_fasttext(self) -> None:
        """Download and load FastText model if needed."""
        try:
            import fasttext
        except ImportError:
            raise ImportError(
                "FastText is not installed. Run `%pip install fasttext`, restart the kernel, and try again."
            )

        if not self.model_path.exists():
            self.logger.info(
                "Downloading FastText language detection model..."
            )
            self._download_with_progress(
                self.FASTTEXT_MODEL_URL, self.model_path
            )
            self.logger.info("Download complete!")

        self.logger.info("Loading FastText model...")
        self.fasttext_model = fasttext.load_model(str(self.model_path))
        self.logger.info("FastText model loaded successfully!")

    def is_valid_czech(self, text: str) -> bool:
        """
        Validate Czech text using language detection and basic checks.
        """
        if not text or len(text.strip()) == 0:
            return False

        # Basic length check
        if len(text) < 2 or len(text) > 10000:
            return False

        try:
            # Use fasttext for language detection
            prediction = self.fasttext_model.predict(text.replace("\n", " "))
            lang = prediction[0][0].replace("__label__", "")
            confidence = prediction[1][0]

            # Accept text if it's confidently detected as Czech
            if lang == "cs" and confidence > 0.8:
                return True
            
            # NOTE: We could possibly accept Slovak as it's very similar and might be mixed in and majority of Czech speakers are also fluent in Slovak
            
            return False
        except Exception as e:
            self.logger.warning(f"Language detection failed: {e}")
            return False

    def clean_text(self, text: str) -> str:
        """
        Very strictly clean Czech text - drop text that can't be confidently corrected.
        """
        # Check for critical issues that require dropping
        for pattern in self.critical_patterns:
            if re.search(pattern, text):
                return ""

        original = text
        # Normalize unicode characters
        text = unicodedata.normalize("NFKC", text)

        # Strict quote and dash normalization
        # Paired straight quotes become Czech „…“ (opening low, closing high)
        text = re.sub(r'"([^"]*)"', r'„\1“', text)
        text = text.replace('-', '–')

        # Strict whitespace and punctuation cleanup
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[\s,.!?]+([,.!?])', r'\1', text)
        text = re.sub(r'([,.!?])([^\s])', r'\1 \2', text)

        # Remove control characters
        text = ''.join(char for char in text if unicodedata.category(char)[0] != 'C')

        # Additional strict cleanups
        text = re.sub(r'\.{2,}', '...', text)  # Normalize ellipsis
        text = re.sub(r'[\u200b\ufeff\u200e\u200f]', '', text)  # Remove zero-width chars

        cleaned = text.strip()

        # Drop if significant changes were needed
        if abs(len(cleaned) - len(original)) > len(original) * 0.2:
            return ""

        if cleaned != original:
            self.corrections_made += 1

        return cleaned

    def clean_text_batch(self, texts: List[str], pbar: tqdm) -> List[str]:
        """Process a batch of texts with progress tracking."""
        cleaned = []
        for text in texts:
            if self.is_valid_czech(text):  # Permissive validation
                cleaned_text = self.clean_text(text)  # Strict cleaning
                if cleaned_text:  # Only keep if cleaning succeeded
                    cleaned.append(cleaned_text)
            pbar.update(1)
        return cleaned

    def analyze_dataset(self, dataset, field: str) -> Dict:
        """Perform detailed linguistic analysis of dataset."""
        lengths = [len(item[field]) for item in dataset]
        valid_count = sum(1 for item in dataset if self.is_valid_czech(item[field]))
        word_counts = [len(item[field].split()) for item in dataset]

        return {
            'total': len(dataset),
            'valid': valid_count,
            'invalid': len(dataset) - valid_count,
            'corrected': self.corrections_made,
            'avg_length': sum(lengths) / len(lengths),
            'max_length': max(lengths),
            'min_length': min(lengths),
            'avg_words': sum(word_counts) / len(word_counts),
            'max_words': max(word_counts),
            'min_words': min(word_counts)
        }

    def plot_statistics(self, stats: Dict, title: str) -> None:
        """Visualize dataset quality metrics."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # First subplot - Valid vs Invalid
        ax1.bar(['Valid Czech', 'Invalid/Non-Czech'], 
                [stats['valid'], stats['invalid']], 
                color=['green', 'red'])
        ax1.set_title(f'{title} - Czech Language Quality Analysis')
        ax1.set_ylabel('Number of Entries')

        total = stats['total']
        for i, v in enumerate([stats['valid'], stats['invalid']]):
            ax1.text(i, v, f'{v:,}\n({v/total*100:.1f}%)', 
                    ha='center', va='bottom')

        # Second subplot - Corrections Made
        ax2.bar(['Original', 'Corrected'], 
                [stats['total'] - stats['corrected'], stats['corrected']], 
                color=['blue', 'orange'])
        ax2.set_title(f'{title} - Text Corrections')
        ax2.set_ylabel('Number of Entries')

        for i, v in enumerate([stats['total'] - stats['corrected'], stats['corrected']]):
            ax2.text(i, v, f'{v:,}\n({v/total*100:.1f}%)', 
                    ha='center', va='bottom')

        plt.tight_layout()
        plt.show()


# Initialize cleaner
cleaner = CzechTextCleaner()
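
The acceptance rule inside `is_valid_czech` can be separated from the model call, which also makes the Slovak option mentioned in the code comment easy to try. This is a sketch under assumptions: `accept_language` is a hypothetical helper, and the prediction tuples are hand-written stand-ins with the same `(labels, probabilities)` shape that `is_valid_czech` unpacks. No model is loaded here.

```python
from typing import Collection, Sequence, Tuple


def accept_language(
    prediction: Tuple[Sequence[str], Sequence[float]],
    accepted: Collection[str] = ("cs",),
    threshold: float = 0.8,
) -> bool:
    """Decide whether a top-1 (labels, probabilities) prediction passes the filter."""
    labels, probabilities = prediction
    if len(labels) == 0:
        return False
    lang = labels[0].replace("__label__", "")
    return lang in accepted and float(probabilities[0]) > threshold


print(accept_language((("__label__cs",), [0.97])))
print(accept_language((("__label__sk",), [0.91])))
print(accept_language((("__label__sk",), [0.91]), accepted=("cs", "sk")))
print(accept_language((("__label__cs",), [0.55])))
```

Keeping the threshold and the accepted languages as parameters lets you measure how many rows each setting would keep before committing to one.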

### Clean and Analyze Books Dataset

In [None]:
# Analyze original books dataset
books_stats_before = cleaner.analyze_dataset(books_ds['train'], 'output')
print("\n=== Books Dataset Statistics (Before Cleaning) ===")
print(f"Total entries: {books_stats_before['total']:,}")
print(f"Valid entries: {books_stats_before['valid']:,}")
print(f"Average length: {books_stats_before['avg_length']:.1f}")

# Clean books dataset
def clean_books(example):
    example['output'] = cleaner.clean_text(example['output'])
    return example

books_ds['train'] = books_ds['train'].filter(
    lambda x: cleaner.is_valid_czech(x['output'])
).map(clean_books).filter(
    lambda x: len(x['output']) > 0  # clean_text returns "" for rejected texts
)

# Analyze cleaned dataset
books_stats_after = cleaner.analyze_dataset(books_ds['train'], 'output')
cleaner.plot_statistics(books_stats_after, 'Books Dataset')
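
The Unicode-related steps of `clean_text` can be examined on their own. The sketch below reproduces only three of them (NFKC normalization, whitespace collapsing, removal of control and format characters); `normalize_text` and the sample string are assumptions for illustration.

```python
import re
import unicodedata


def normalize_text(text: str) -> str:
    """Apply NFKC, collapse whitespace, then drop control and format characters."""
    text = unicodedata.normalize("NFKC", text)
    # Collapse whitespace first so that newlines and tabs become spaces
    # instead of being deleted as control characters in the next step
    text = re.sub(r"\s+", " ", text)
    # Categories starting with "C" include control (Cc) and format (Cf) characters,
    # which covers zero-width spaces and byte-order marks
    text = "".join(ch for ch in text if unicodedata.category(ch)[0] != "C")
    return text.strip()


# BOM, no-break space, zero-width space, "fi" ligatures, newline and tab
raw = "\ufeffKniha\u00a0o\u200b \ufb01lozo\ufb01i\n\tpro  každého"
print(repr(normalize_text(raw)))
```

The order of the steps matters: removing category "C" characters before collapsing whitespace would delete newlines outright and glue neighbouring words together.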

### Length Distribution Analysis

In [None]:
# Plot length distributions
plt.figure(figsize=(15, 5))

# Translation dataset lengths
trans_lengths = [len(item['output']) for item in translations_ds['train']]
plt.subplot(1, 2, 1)
sns.histplot(trans_lengths, bins=50)
plt.title('Translation Text Lengths')
plt.xlabel('Length (characters)')
plt.ylabel('Count')

# Books dataset lengths
book_lengths = [len(item['output']) for item in books_ds['train']]
plt.subplot(1, 2, 2)
sns.histplot(book_lengths, bins=50)
plt.title('Book Description Lengths')
plt.xlabel('Length (characters)')
plt.ylabel('Count')

plt.tight_layout()
plt.show()
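
Histograms are easier to compare when accompanied by a few summary numbers. A minimal sketch using only the standard library follows; `summarize_lengths` is a hypothetical helper and the sample strings are invented.

```python
import statistics
from typing import Dict, Sequence


def summarize_lengths(texts: Sequence[str]) -> Dict[str, float]:
    """Summarize the character lengths of a non-empty collection of texts."""
    lengths = sorted(len(t) for t in texts)
    return {
        "count": len(lengths),
        "min": lengths[0],
        "median": statistics.median(lengths),
        "mean": statistics.fmean(lengths),
        "max": lengths[-1],
    }


sample_outputs = ["Ahoj.", "Dobrý den.", "Krátký popis knihy o historii Prahy."]
print(summarize_lengths(sample_outputs))
```

A large gap between the median and the maximum is a hint that a few very long entries dominate the tail and may deserve a stricter length limit.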

## Split the datasets

With cleaning done, we split each dataset into train, validation, and test sets (80/10/10).

Combining the splits afterwards completes the data processing phase; model fine-tuning follows.


In [None]:
print("\nSplitting datasets...")

# Split translations dataset
translations_ds = translations_ds["train"].train_test_split(
    test_size=0.2, shuffle=True, seed=42
)
translations_test_val = translations_ds["test"].train_test_split(
    test_size=0.5, shuffle=True, seed=42
)

translations_ds = {
    "train": translations_ds["train"],
    "validation": translations_test_val["train"],
    "test": translations_test_val["test"],
}

# Split books dataset
books_ds = books_ds["train"].train_test_split(
    test_size=0.2, shuffle=True, seed=42
)
books_test_val = books_ds["test"].train_test_split(
    test_size=0.5, shuffle=True, seed=42
)

books_ds = {
    "train": books_ds["train"],
    "validation": books_test_val["train"],
    "test": books_test_val["test"],
}

# Print split sizes
print("\n=== Split Sizes ===")
for split in ["train", "validation", "test"]:
    print(f"\nTranslations {split}:")
    print(f"Size: {len(translations_ds[split])}")
    print(f"\nBooks {split}:")
    print(f"Size: {len(books_ds[split])}")
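
The two-stage split (hold out 20%, then halve the held-out part) can be illustrated without any library. This is a sketch of the idea only: `split_80_10_10` is a hypothetical helper, it uses integer arithmetic for the sizes, and its rounding and shuffling will not match `train_test_split` row for row.

```python
import random
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def split_80_10_10(items: Sequence[T], seed: int = 42) -> Dict[str, List[T]]:
    """Shuffle once, hold out one fifth, then halve the held-out part."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)

    n_holdout = len(items) // 5
    n_test = n_holdout // 2

    holdout, train = indices[:n_holdout], indices[n_holdout:]
    test, validation = holdout[:n_test], holdout[n_test:]
    return {
        "train": [items[i] for i in train],
        "validation": [items[i] for i in validation],
        "test": [items[i] for i in test],
    }


splits = split_80_10_10(list(range(100)))
print({name: len(part) for name, part in splits.items()})
```

Fixing the seed makes the split reproducible, which matters when results from different training runs are compared on the same test set.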

Finally, we combine the translation and book description splits into one dataset per split, which is what the fine-tuning phase consumes.

In [None]:
from datasets import concatenate_datasets

combined_datasets = {
    split: concatenate_datasets([translations_ds[split], books_ds[split]])
    for split in ["train", "validation", "test"]
}

print("\n=== Combined Dataset Sizes ===")
for split in ["train", "validation", "test"]:
    print(f"\n{split}:")
    print(f"Size: {len(combined_datasets[split])}")

# 🚀 Model Fine-Tuning

In this phase, we will fine-tune the model on the combined dataset using **PyTorch** and the **Hugging Face Transformers** library. The `gemma-2-2b-it` model will serve as the base model. 

The datasets are already split into train/validation/test sets (80/10/10 ratio) for proper evaluation. Let's break down the process into key phases:

---

## 🏗️ Model Architecture

We will use **Google's Gemma 2 2B** (`gemma-2-2b-it`) as the base model, fine-tuned with **PyTorch Lightning** for efficient training. The architecture includes:

- 📂 **Custom Dataset Class** for handling the specific data format
- 📊 **Lightning DataModule** for data management
- 🛠️ **Lightning Module** for training logic
- ⚡ **Mixed Precision Training** (bfloat16) for improved performance
- 🧹 **Gradient Accumulation and Clipping** for stability during training
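
Gradient accumulation can be checked numerically on a toy problem. The sketch below is an illustration, not the training code: it uses a one-parameter least-squares loss with made-up data and shows that summing micro-batch gradients, each scaled by the number of accumulation steps, reproduces the full-batch gradient when the micro-batches have equal size.

```python
from typing import List, Sequence, Tuple


def grad_mse(w: float, batch: Sequence[Tuple[float, float]]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: List[Sequence[Tuple[float, float]]]) -> float:
    """Sum per-micro-batch gradients, each divided by the number of micro-batches."""
    steps = len(micro_batches)
    total = 0.0
    for micro_batch in micro_batches:
        total += grad_mse(w, micro_batch) / steps
    return total


data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]
micro_batches = [data[0:2], data[2:4]]  # two equally sized micro-batches

full = grad_mse(0.5, data)
accumulated = accumulated_grad(0.5, micro_batches)
print(full, accumulated)
```

The effective batch size is therefore the micro-batch size multiplied by the number of accumulation steps, at the memory cost of a single micro-batch.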

---

## 📈 Training Process

The training will proceed through several stages:

1. **Data Batching and Tokenization** 🗃️:
   - Efficiently preprocess and batch the input data.
   
2. **Forward Pass** 🔄:
   - Pass the tokenized data through the `gemma-2-2b-it` model.
   
3. **Loss Calculation** 🎯:
   - Use **Cross Entropy** as the loss function.

4. **Backpropagation** 🔙:
   - Perform backpropagation with **Gradient Accumulation** to stabilize updates.

5. **Optimization** 🛠️:
   - Use **AdamW** optimizer with **Cosine Learning Rate Scheduling** for smooth convergence.
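
The cosine schedule mentioned in the last step follows a simple closed form. Here is a minimal sketch without warmup; the function name `cosine_lr` and the values `base_lr=2e-5`, `min_lr=0.0`, and `total_steps=1000` are assumptions for illustration, not settings taken from this notebook.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 2e-5, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate, decaying from base_lr to min_lr over total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 250, 500, 750, 1000):
    print(step, f"{cosine_lr(step, total_steps=1000):.2e}")
```

The rate falls slowly at first, fastest around the midpoint, and slowly again near the end, which avoids the abrupt drops of a step schedule.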

---

## 🧪 Evaluation

We will evaluate the model's performance using:

- 📉 **Validation Loss** during training to track progress
- ✅ **Test Set Performance** to assess generalization
- 🔍 **Practical Examples** from both tasks to verify real-world applicability

---

## 📊 Monitoring

Training progress will be tracked using **Weights & Biases (W&B)** 📈. Model checkpoints will be saved based on **validation loss improvements** 💾.
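
Saving checkpoints on validation-loss improvement reduces to tracking the best value seen so far. The following sketch shows that rule in isolation; `BestLossTracker`, its `min_delta` parameter, and the loss values are hypothetical and only illustrate the decision logic.

```python
from typing import List, Optional


class BestLossTracker:
    """Remember the lowest validation loss seen and report when it improves."""

    def __init__(self, min_delta: float = 0.0) -> None:
        self.best: Optional[float] = None
        self.min_delta = min_delta

    def should_save(self, val_loss: float) -> bool:
        """Return True (and update the best value) if val_loss improved by more than min_delta."""
        if self.best is None or val_loss < self.best - self.min_delta:
            self.best = val_loss
            return True
        return False


tracker = BestLossTracker()
val_losses = [2.31, 1.98, 2.05, 1.97, 1.97]  # made-up values for illustration
saved_epochs: List[int] = [
    epoch for epoch, loss in enumerate(val_losses) if tracker.should_save(loss)
]
print(saved_epochs)
```

A non-zero `min_delta` would suppress checkpoints for improvements that are within noise.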

---

## 🛠️ Next Steps

After successful training, the following steps will be taken:

1. **Evaluate Model Performance** 🏆:
   - Assess results on both translation and book description tasks.
   
2. **Fine-Tune Hyperparameters** 🎛️:
   - Adjust as necessary to optimize performance.
   
3. **Test Real-World Examples** 🌍:
   - Validate the model with practical scenarios.
   
4. **Deploy the Model** 🚀:
   - Make the model available for use.

---

## 📘 Documentation

Each phase will be thoroughly documented with:

- 📊 **Results**
- 📝 **Observations**
- 💡 **Insights for Improvement**

This ensures progress is clearly tracked and potential areas for enhancement are identified.


## Next Steps

1. **Evaluation**
   - Benchmark on Czech NLP tasks
   - Compare with baseline models

## References

1. ParaCrawl (2023). ParaCrawl v9.0. https://paracrawl.eu/v9
2. Gemma (2024). Google AI. https://blog.google/technology/ai/gemma-open-models/
3. Czech Books Descriptions Dataset. https://huggingface.co/datasets/vojtam/czech_books_descriptions