In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Subset, Dataset
import torch.nn.functional as F
from transformers import AutoTokenizer
from tqdm.notebook import tqdm
import random
import re
import math
import os
import json
import types
import time
from sklearn.model_selection import train_test_split

# Import the standardized models
# You can choose which model to use by commenting out one of these imports
# from weather_gru_models import AdvancedWeatherGRU as WeatherTextGRU
from weather_gru_models import AttentionWeatherGRU as WeatherTextGRU

# Import the standardized datasets
from weather_datasets import SimpleWeatherDataset, validate_and_clean_data_multithreaded

def skip_only_model_special_tokens(tokens, tokenizer):
    """
    Filter out model-specific special tokens from a sequence of token IDs.

    Only the tokenizer's CLS, SEP and PAD IDs are removed; every other ID is
    kept in its original order.

    Args:
        tokens (torch.Tensor): 1-D tensor (or sequence) of token IDs
        tokenizer: Object exposing ``cls_token_id``, ``sep_token_id`` and
            ``pad_token_id``

    Returns:
        torch.Tensor: 1-D tensor with the same dtype as ``tokens`` that holds
            the remaining IDs
    """
    tokens = torch.as_tensor(tokens)

    # IDs of tokens to skip
    tokens_to_skip = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }

    # One bulk conversion to Python ints instead of a .item() call per element
    filtered_tokens = [t for t in tokens.tolist() if t not in tokens_to_skip]

    # Pass the dtype explicitly: torch.tensor([]) would otherwise come back as
    # float32 when every token was filtered out
    return torch.tensor(filtered_tokens, dtype=tokens.dtype)

def create_improved_dataloader(dataset, batch_size, tokenizer):
    """
    Build a shuffling DataLoader whose batches are normalized and tokenized on the fly.

    Args:
        dataset: Indexable dataset whose items are dicts with 'features' and 'text'
        batch_size: Number of items per batch
        tokenizer: Callable that encodes a list of texts into 'input_ids'
            and 'attention_mask' tensors

    Returns:
        DataLoader: Loader yielding dicts with 'features', 'text',
            'attention_mask' and 'seq_lengths'
    """
    def collate_fn(batch_list):
        # Split the items into their feature tensors and raw texts
        feature_rows = []
        raw_texts = []
        for entry in batch_list:
            feature_rows.append(entry['features'])
            raw_texts.append(entry['text'])
        stacked = torch.stack(feature_rows)

        # Standardize with statistics taken over the first two dimensions of the batch
        batch_mean = stacked.mean(dim=(0, 1), keepdim=True)
        batch_std = stacked.std(dim=(0, 1), keepdim=True)
        normalized = (stacked - batch_mean) / (batch_std + 1e-8)

        # Pad every text to the longest one in this batch
        tokenized = tokenizer(
            raw_texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        mask = tokenized['attention_mask']

        return {
            'features': normalized,
            'text': tokenized['input_ids'],
            'attention_mask': mask,
            # Number of positions flagged with 1 in the mask, per sequence
            'seq_lengths': (mask == 1).sum(dim=1)
        }

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    return loader

def count_model_parameters(model):
    """Return the total element count of all parameters that require gradients"""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN determinism flags"""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch off benchmark mode
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def prepare_for_improved_training(dataset, test_size=0.1, random_state=42):
    """
    Split a dataset into train and validation subsets.

    The split is done on index positions with scikit-learn's
    ``train_test_split`` and wrapped in ``Subset`` views, so the underlying
    dataset is not copied.

    Args:
        dataset: The dataset to split (must support ``len()``)
        test_size: Share (or absolute number) of samples placed in the
            validation subset; forwarded to ``train_test_split``.
            Defaults to 0.1, the previously hard-coded value.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            repeatable. Defaults to 42, the previously hard-coded value.

    Returns:
        tuple: (train_dataset, val_dataset)
    """
    # Split into train/validation
    indices = list(range(len(dataset)))
    train_indices, val_indices = train_test_split(
        indices, test_size=test_size, random_state=random_state
    )

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    return train_dataset, val_dataset

def load_and_validate_data_multithreaded(data_path='data/files_for_chatGPT/2024-12-12/', num_workers=None):
    """
    Load and validate weather data from JSON files using multiple threads.

    A file is accepted when it parses as JSON and contains non-blank string
    values for both 'city' and 'gpt_rewritten_apokalyptisch_v2'. Accepted
    files are keyed by a fragment of their file name; if two files produce
    the same key, the one that sorts later by name replaces the earlier one.

    Files are processed in sorted name order and results are collected in
    that same order, so the returned list is identical from run to run.

    Args:
        data_path (str): Path to directory containing JSON files. If it does
            not exist as given, it is also tried relative to '.', '..' and
            '../..'.
        num_workers (int, optional): Number of worker threads to use.
            Defaults to the CPU count minus one (at least 1).

    Returns:
        list: List of loaded data samples (the parsed JSON objects)

    Raises:
        FileNotFoundError: If the directory cannot be located.
    """
    import os
    import json
    import concurrent.futures
    import multiprocessing

    required_fields = ('gpt_rewritten_apokalyptisch_v2', 'city')

    # If num_workers is not specified, use CPU count
    if num_workers is None:
        num_workers = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
    # ThreadPoolExecutor rejects max_workers < 1
    num_workers = max(1, num_workers)

    # Check if we're in the right directory, navigate if needed
    if not os.path.exists(data_path):
        for base in ('.', '..', '../..'):
            candidate = os.path.join(base, data_path)
            if os.path.exists(candidate):
                data_path = candidate
                break
        else:
            raise FileNotFoundError(f"Data directory not found: {data_path}")

    # List JSON files; os.listdir order is arbitrary, so sort for reproducibility
    files = sorted(f for f in os.listdir(data_path) if f.endswith('.json'))
    print(f"Found {len(files)} JSON files")

    # Define function to process a single file
    def process_file(file):
        file_path = os.path.join(data_path, file)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Validate file - we're using SimpleWeatherDataset, so check for required fields
            if not set(required_fields).issubset(data.keys()):
                return None

            for field in required_fields:
                value = data[field]
                if not isinstance(value, str) or not value.strip():
                    return None

            # File is valid, return with a key taken from the file name:
            # the text after the last '-' up to the first '_'
            key = (file.split('-')[-1]).split('_')[0]
            return (key, data)
        except json.JSONDecodeError:
            # Unparseable files are skipped silently
            return None
        except Exception as e:
            print(f"Error processing {file}: {e}")
            return None

    data_dict = {}

    # Process files in parallel
    print(f"Loading files using {num_workers} workers...")
    with tqdm(total=len(files), desc="Loading files") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            # executor.map yields results in submission order, regardless of
            # which thread finishes first, keeping the output deterministic
            for result in executor.map(process_file, files):
                if result is not None:
                    key, data = result
                    data_dict[key] = data
                pbar.update(1)

    print(f"Loaded {len(data_dict)} weather data points")

    # Convert to list
    return list(data_dict.values())

# Combined function to load, create, and clean dataset in parallel
def prepare_dataset_multithreaded(num_workers=None):
    """
    Run the whole data pipeline: load the samples, wrap them in a dataset,
    clean it, and split it into train and validation parts.

    Args:
        num_workers (int, optional): Worker count handed to both the loading
            and the cleaning step

    Returns:
        tuple: (clean_dataset, train_dataset, val_dataset)
    """
    print("Starting multithreaded dataset preparation...")
    started_at = time.time()

    # Load the raw samples and wrap them in the standardized dataset class
    raw_samples = load_and_validate_data_multithreaded(num_workers=num_workers)
    wrapped = SimpleWeatherDataset(raw_samples)

    # The cleaning helper returns a sequence; only its first element is used here
    cleaning_result = validate_and_clean_data_multithreaded(wrapped, num_workers=num_workers)
    clean_dataset = cleaning_result[0]

    # Train/validation split
    train_dataset, val_dataset = prepare_for_improved_training(clean_dataset)
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")

    elapsed = time.time() - started_at
    print(f"Dataset preparation completed in {elapsed:.2f} seconds")

    return clean_dataset, train_dataset, val_dataset

# Seed all RNGs before any data is touched
set_seed(42)

# Time the multithreaded dataset preparation from the outside as well;
# the three resulting datasets stay available as module-level names
dataset_start = time.time()
clean_dataset, train_dataset, val_dataset = prepare_dataset_multithreaded()
dataset_end = time.time()

print(f"Dataset preparation completed in {dataset_end - dataset_start:.2f} seconds")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, BertTokenizer
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import random
import numpy as np
import time
import math
import os
import json
from collections import Counter

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def weather_collate_fn(batch_list, tokenizer):
    """
    Turn a list of dataset items into one normalized, tokenized batch.

    Args:
        batch_list (list): Items, each a dict with 'features' and 'text'
        tokenizer: Callable that encodes a list of texts into 'input_ids'
            and 'attention_mask' tensors

    Returns:
        dict: 'features' (normalized), 'text' (token IDs) and 'attention_mask'
    """
    # Split the items into their feature tensors and raw texts
    feature_rows = []
    raw_texts = []
    for entry in batch_list:
        feature_rows.append(entry['features'])
        raw_texts.append(entry['text'])
    stacked = torch.stack(feature_rows)

    # Standardize with statistics taken over the first two dimensions of the batch
    batch_mean = stacked.mean(dim=(0, 1), keepdim=True)
    batch_std = stacked.std(dim=(0, 1), keepdim=True)
    normalized = (stacked - batch_mean) / (batch_std + 1e-8)

    # Pad to the longest text in the batch, truncating at 128 tokens
    tokenized = tokenizer(
        raw_texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )

    collated = {'features': normalized}
    collated['text'] = tokenized['input_ids']
    collated['attention_mask'] = tokenized['attention_mask']
    return collated

def reduce_vocabulary(tokenizer, full_dataset, batch_size=64):
    """
    Identify used tokens and create a reduced vocabulary mapping.

    Every text in the dataset is encoded (with special tokens), the token IDs
    that occur are collected, and the tokenizer's special tokens are added
    even if they never occurred. The surviving IDs are renumbered densely
    from 0 in ascending order of their original ID.

    Args:
        tokenizer: Tokenizer with full vocabulary
        full_dataset: Dataset whose items are dicts with a 'text' entry
        batch_size: Unused; kept so existing callers keep working

    Returns:
        tuple: (token_mappings, reduced_vocab_size) where token_mappings has
            'token_id_map' (old ID -> new ID), 'reverse_token_id_map'
            (new ID -> old ID) and 'used_token_ids' (sorted old IDs)
    """
    print("Analyzing vocabulary usage to reduce model size...")

    # Count tokens
    token_counter = Counter()

    # Process the entire dataset directly without DataLoader.
    # Plain indexing also works for Subset objects, which resolve the
    # index through their own index list.
    for idx in tqdm(range(len(full_dataset)), desc="Scanning token usage"):
        text = full_dataset[idx]['text']

        # Tokenize directly
        token_counter.update(tokenizer.encode(text, add_special_tokens=True))

    # Always keep special tokens; map values are either one token string or
    # a list of token strings
    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, str):
            special_names = [special_token]
        elif isinstance(special_token, list):
            special_names = special_token
        else:
            special_names = []
        for name in special_names:
            token_counter.setdefault(tokenizer.convert_tokens_to_ids(name), 1)

    # Sort by original token ID so the new numbering is deterministic
    used_token_ids = sorted(token_counter)

    # Create token ID mapping (old ID -> new ID)
    token_id_map = {old_id: new_id for new_id, old_id in enumerate(used_token_ids)}

    # Create reverse mapping for inference
    reverse_token_id_map = {new_id: old_id for old_id, new_id in token_id_map.items()}

    # Store the mappings for later use
    token_mappings = {
        'token_id_map': token_id_map,
        'reverse_token_id_map': reverse_token_id_map,
        'used_token_ids': used_token_ids
    }

    # Update vocabulary size to reduced size
    reduced_vocab_size = len(used_token_ids)
    # len(tokenizer) counts added special tokens too, which are part of the
    # reduced vocabulary as well
    original_vocab_size = len(tokenizer)
    print(f"Reduced vocabulary from {original_vocab_size:,} to {reduced_vocab_size:,} tokens "
          f"({reduced_vocab_size/original_vocab_size*100:.1f}%)")

    return token_mappings, reduced_vocab_size

# Map tokens function for data loaders
def map_tokens_fn(batch, token_id_map):
    """
    Map token IDs from original vocabulary to reduced vocabulary.

    The batch is modified in place: its 'text' tensor is replaced by a tensor
    of the same shape, dtype and device holding the mapped IDs.

    Args:
        batch: Dict containing a 'text' tensor of original token IDs
        token_id_map: Mapping from original to reduced token IDs

    Returns:
        dict: The same batch object with mapped token IDs
    """
    old_tokens = batch['text']

    # Convert to Python ints once; indexing the tensor element by element
    # with .item() is far slower and this runs for every batch
    flat_old_ids = old_tokens.reshape(-1).tolist()

    # Default to 0 if token not found
    flat_new_ids = [token_id_map.get(old_id, 0) for old_id in flat_old_ids]

    batch['text'] = torch.tensor(
        flat_new_ids, dtype=old_tokens.dtype, device=old_tokens.device
    ).reshape(old_tokens.shape)
    return batch

def _masked_token_loss(model, batch, criterion, teacher_forcing_ratio):
    """Run one forward pass on a batch and return the mean loss over non-padding tokens."""
    # Move data to device
    features = batch['features'].to(device)
    tokens = batch['text'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    outputs = model(features, tokens, teacher_forcing_ratio=teacher_forcing_ratio)

    # Position 0 (the tokenizer's leading CLS token) is excluded from the loss
    output_flat = outputs[:, 1:].reshape(-1, outputs.shape[-1])
    target_flat = tokens[:, 1:].reshape(-1)
    mask_flat = attention_mask[:, 1:].reshape(-1)

    # Token-wise loss, zeroed at padding positions
    masked_losses = criterion(output_flat, target_flat) * mask_flat

    # Average over non-padding tokens; epsilon guards an all-padding batch
    return masked_losses.sum() / (mask_flat.sum() + 1e-8)


def _evaluate(model, data_loader, criterion, desc):
    """Return the mean per-batch masked loss over a loader, without teacher forcing."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        pbar = tqdm(data_loader, desc=desc)
        for batch in pbar:
            batch_loss = _masked_token_loss(model, batch, criterion, teacher_forcing_ratio=0.0)
            total_loss += batch_loss.item()
            pbar.set_postfix({"loss": f"{batch_loss.item():.4f}"})

    return total_loss / len(data_loader)


def train_model(args):
    """
    Train the weather text generation model.

    Reads the module-level ``clean_dataset``, ``train_dataset`` and
    ``val_dataset`` prepared earlier, builds a reduced vocabulary, trains
    with AdamW and a one-cycle learning-rate schedule, writes a checkpoint to
    ``args.save_dir`` and prints generated samples along the way.

    Args:
        args: Namespace with ``seed``, ``batch_size``, ``hidden_size``,
            ``dropout``, ``lr``, ``epochs``, ``patience``, ``save_dir`` and
            ``sample_count``. ``args.hidden_size`` is lowered in place if the
            model exceeds the parameter limit.

    Returns:
        tuple: (model, tokenizer, token_mappings)
    """
    global tokenizer  # Make tokenizer accessible to model

    print(f"Using device: {device}")

    # Set seed for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Load tokenizer - using German BERT for German text
    print("Loading tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')

    # Add special tokens that appear in our texts
    special_tokens = {'additional_special_tokens': ['<city>', '<temp>']}
    tokenizer.add_special_tokens(special_tokens)

    # The datasets themselves are module-level globals built before this
    # function is called, so nothing is actually prepared between these prints
    print("Preparing datasets...")
    start_time = time.time()

    print(f"Dataset preparation completed in {time.time() - start_time:.2f}s")

    # Reduce vocabulary
    token_mappings, reduced_vocab_size = reduce_vocabulary(tokenizer, clean_dataset, args.batch_size)

    # Collate: normalize + tokenize, then remap token IDs to the reduced vocabulary
    def collate(batch):
        return map_tokens_fn(weather_collate_fn(batch, tokenizer), token_mappings['token_id_map'])

    # Create dataloaders with token mapping
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate
    )

    # Get feature dimension (clean_dataset is expected to wrap the real
    # dataset in a ``.dataset`` attribute, as a Subset does)
    feature_dim = clean_dataset.dataset.feature_dim

    print(f"Feature dimension: {feature_dim}")
    print(f"Reduced vocabulary size: {reduced_vocab_size}")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    def build_model():
        return WeatherTextGRU(
            feature_dim=feature_dim,
            hidden_size=args.hidden_size,
            vocab_size=reduced_vocab_size,
            dropout=args.dropout
        ).to(device)

    def count_trainable(module):
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    # Create model
    print("Creating model...")
    model = build_model()

    # Count parameters
    num_params = count_trainable(model)
    print(f"Model has {num_params:,} trainable parameters")

    # Keep the model under the parameter limit. The messages are derived from
    # the limit itself so they cannot drift from the value being enforced.
    max_params = 30_000_000
    min_hidden_size = 128
    if num_params > max_params:
        print(f"WARNING: Model exceeds {max_params:,} parameters ({num_params:,}), reducing hidden size...")

        # Shrink the hidden size in steps of 32 until the limit or the floor is reached
        while num_params > max_params and args.hidden_size > min_hidden_size:
            args.hidden_size -= 32

            model = build_model()

            num_params = count_trainable(model)
            print(f"Adjusted model: {num_params:,} parameters with hidden_size={args.hidden_size}")

    # Setup aggressive optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    # Loss function (with 'none' reduction for masking)
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Learning rate scheduler for aggressive learning
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr * 10,  # Peak at 10x the base learning rate
        total_steps=args.epochs * len(train_loader),
        pct_start=0.1,  # Aggressive warm-up
        div_factor=25.0,  # Determines initial lr
        final_div_factor=10000.0,  # For steep decay at the end
        anneal_strategy='cos'
    )

    # Training parameters
    best_val_loss = float('inf')
    best_train_loss = float('inf')
    patience_counter = 0
    clip_value = 1.0
    saved_checkpoint = False  # whether THIS run has written a checkpoint

    # Prepare save directory
    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, "best_weather_text_model.pt")

    training_start = time.time()

    # Training loop
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        # Training
        model.train()
        train_loss = 0.0

        train_pbar = tqdm(train_loader, desc="Training")
        for batch in train_pbar:
            # Forward pass with very aggressive teacher forcing
            batch_loss = _masked_token_loss(model, batch, criterion, teacher_forcing_ratio=0.8)

            # Backward pass
            optimizer.zero_grad()
            batch_loss.backward()

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

            # Update weights
            optimizer.step()

            # Update learning rate (the one-cycle schedule advances per batch)
            scheduler.step()

            # Update metrics
            train_loss += batch_loss.item()
            train_pbar.set_postfix({"loss": f"{batch_loss.item():.4f}"})

        avg_train_loss = train_loss / len(train_loader)

        # Validation (no teacher forcing)
        avg_val_loss = _evaluate(model, val_loader, criterion, desc="Validating")

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Track the best validation loss seen so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

        # NOTE(review): checkpointing and early stopping are driven by the
        # TRAINING loss, not the validation loss - confirm this is intended.
        if avg_train_loss < best_train_loss:
            best_train_loss = avg_train_loss
            # Save model and token mappings
            torch.save({
                'model_state_dict': model.state_dict(),
                'token_mappings': token_mappings,
                'model_config': {
                    'feature_dim': feature_dim,
                    'hidden_size': args.hidden_size,
                    'vocab_size': reduced_vocab_size,
                    'dropout': args.dropout
                },
                'train_args': vars(args),
                'epoch': epoch,
                # Losses of the epoch whose weights are stored here
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'best_val_loss': best_val_loss
            }, model_path)
            saved_checkpoint = True

            print(f"✓ Model saved to {model_path}!")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

        # Generate sample text periodically
        if (epoch + 1) % 3 == 0 or epoch == 0 or epoch == args.epochs - 1:
            generate_samples(model, tokenizer, val_loader, token_mappings, 1)

    total_time = time.time() - training_start
    print(f"\nTraining completed in {total_time:.2f}s")

    # Load best model for final evaluation. Only reload what this run wrote:
    # a file left by an earlier run may not match this model.
    if saved_checkpoint:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("WARNING: No checkpoint was saved during this run; evaluating the current weights")

    # Final evaluation
    final_val_loss = _evaluate(model, val_loader, criterion, desc="Final Evaluation")
    print(f"Final validation loss: {final_val_loss:.4f}")

    # Generate final samples
    print("\nFinal generated samples:")
    generate_samples(model, tokenizer, val_loader, token_mappings, args.sample_count)

    return model, tokenizer, token_mappings

def generate_samples(model, tokenizer, data_loader, token_mappings, num_samples=5):
    """
    Generate text samples from the model and print them next to the reference text.

    One batch is drawn from the loader per sample and only the first item of
    each batch is used. The loader is restarted if it runs out of batches.
    The model is switched to eval mode and left in it.

    Args:
        model: Trained model; called with a single-item feature batch
        tokenizer: Tokenizer for decoding
        data_loader: DataLoader for getting example features
        token_mappings: Mapping between original and reduced vocabulary
        num_samples: Number of samples to generate
    """
    model.eval()

    # Get mapping for converting back to original token IDs
    reverse_map = token_mappings['reverse_token_id_map']

    # IDs of tokens to drop before decoding ([PAD], [CLS], [SEP])
    tokens_to_exclude = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id
    }

    def decode_reduced(reduced_ids):
        # Reduced IDs -> original IDs, minus the excluded special tokens
        original_ids = [reverse_map[token_id] for token_id in reduced_ids]
        kept_ids = [token_id for token_id in original_ids if token_id not in tokens_to_exclude]
        return tokenizer.decode(kept_ids, skip_special_tokens=False)

    # Get samples
    samples = []
    data_iter = iter(data_loader)

    for _ in range(num_samples):
        sample_batch = next(data_iter, None)
        if sample_batch is None:
            # Reset iterator if we run out of samples
            data_iter = iter(data_loader)
            sample_batch = next(data_iter, None)
            if sample_batch is None:
                # An empty loader can never yield a batch
                print("No batches available for sample generation")
                return
        samples.append(sample_batch)

    # Generate text for each sample
    for i, batch in enumerate(samples):
        sample_idx = 0  # Just use the first item in the batch

        # Get features and original tokens
        sample_features = batch['features'][sample_idx].unsqueeze(0).to(device)
        sample_tokens_mapped = batch['text'][sample_idx]

        # Reference values printed alongside the text. Assumes feature
        # channels 0/1/2 are temperature/humidity/cloudiness - TODO confirm
        # against the dataset. If the loader's collate function normalizes
        # features, these are normalized values, not physical units.
        temp_values = sample_features[0, :, 0].cpu().numpy()
        humidity_values = sample_features[0, :, 1].cpu().numpy()
        cloud_values = sample_features[0, :, 2].cpu().numpy()

        # Generate text
        with torch.no_grad():
            generated_output = model(sample_features)

        # Greedy decoding: most likely token at every position
        generated_tokens_mapped = generated_output[0].argmax(dim=1)

        # Decode both texts the same way so they are directly comparable
        # (the reference used to keep its [CLS]/[SEP]/[PAD] tokens)
        original_text = decode_reduced(sample_tokens_mapped.tolist())
        generated_text = decode_reduced(generated_tokens_mapped.tolist())

        print(f"\nSample {i+1}:")
        print(f"Temperature: {[round(float(t), 1) for t in temp_values]}")
        print(f"Humidity: {[round(float(h), 1) for h in humidity_values]}")
        print(f"Cloudiness: {[round(float(c), 2) for c in cloud_values]}")
        print(f"Original: {original_text}")
        print(f"Generated: {generated_text}")
        print("-" * 80)

def generate_weather_text(model, tokenizer, features, token_mappings, temperature=1.0):
    """
    Generate a weather text description from features.

    One token is sampled per output position from the model's softmax
    distribution, mapped back to the original vocabulary, stripped of
    [PAD]/[CLS]/[SEP] and decoded. If the text does not already start with
    the token "In", that token is prepended.

    Args:
        model: Trained weather model
        tokenizer: Tokenizer for decoding
        features: Weather features tensor for a single sample (no batch
            dimension; one is added here)
        token_mappings: Mapping between original and reduced vocabulary
        temperature: Softmax temperature used for sampling; values below 1
            sharpen the distribution, values above 1 flatten it.
            Defaults to 1.0, the previously hard-coded value.

    Returns:
        str: Generated weather text

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.eval()

    # IDs of tokens to exclude
    tokens_to_exclude = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id
    }

    with torch.no_grad():
        # Normalize features along the first dimension of this one sample
        features = (features - features.mean(dim=0, keepdim=True)) / (
            features.std(dim=0, keepdim=True) + 1e-8)

        # Add batch dimension and move to device
        features = features.unsqueeze(0).to(device)

        # Generate text
        outputs = model(features)

        # Using temperature sampling for more diverse outputs
        probs = F.softmax(outputs[0] / temperature, dim=1)
        generated_tokens_mapped = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Map tokens back to original vocabulary
        reverse_map = token_mappings['reverse_token_id_map']
        generated_tokens = [reverse_map[token_id] for token_id in generated_tokens_mapped.tolist()]

        # Filter out unwanted special tokens
        filtered_tokens = [token for token in generated_tokens
                           if token not in tokens_to_exclude]

        # Start text with "In" if it doesn't already
        if filtered_tokens and tokenizer.decode([filtered_tokens[0]]) != "In":
            in_token_id = tokenizer.convert_tokens_to_ids("In")
            filtered_tokens = [in_token_id] + filtered_tokens

        # Decode to text
        generated_text = tokenizer.decode(filtered_tokens, skip_special_tokens=False)

        return generated_text

if __name__ == "__main__":
    import argparse
    torch.cuda.empty_cache()

    # Create arguments object manually for notebook environment
    from types import SimpleNamespace
    args = SimpleNamespace(
        epochs=10,
        batch_size=128,
        hidden_size=512,
        lr=1e-3,
        dropout=0.3,
        patience=6,
        save_dir='./models',
        seed=42,
        test_only=False,   # True: skip training and load a saved checkpoint
        sample_count=5,
        model_path=None    # None: fall back to the default file in save_dir
    )

    # Create save directory if it doesn't exist
    os.makedirs(args.save_dir, exist_ok=True)

    # Train or load model
    if not args.test_only:
        print(f"Training model with {args.epochs} epochs, batch size {args.batch_size}, hidden size {args.hidden_size}")
        model, tokenizer, token_mappings = train_model(args)
    else:
        # Load model for testing
        print("Loading pre-trained model for inference...")

        # Determine model path
        model_path = args.model_path or os.path.join(args.save_dir, "best_weather_text_model.pt")

        if not os.path.exists(model_path):
            print(f"Error: Model file not found at {model_path}")
            # exit() is a helper meant for the interactive shell; raise
            # SystemExit directly so this behaves the same everywhere
            raise SystemExit(1)

        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=device)

        # Load tokenizer (must be set up exactly as during training so the
        # stored token mappings refer to the same IDs)
        tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')
        special_tokens = {'additional_special_tokens': ['<city>', '<temp>']}
        tokenizer.add_special_tokens(special_tokens)

        # Get token mappings
        token_mappings = checkpoint['token_mappings']

        # Get model config
        model_config = checkpoint['model_config']

        # Create model
        model = WeatherTextGRU(
            feature_dim=model_config['feature_dim'],
            hidden_size=model_config['hidden_size'],
            vocab_size=model_config['vocab_size'],
            dropout=model_config.get('dropout', 0.2)
        ).to(device)

        # Load weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        print(f"Model loaded from {model_path}")
        print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} parameters")

        # Prepare datasets for testing
        set_seed(args.seed)

        # Create dataloader with token mapping
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=lambda batch: map_tokens_fn(
                weather_collate_fn(batch, tokenizer),
                token_mappings['token_id_map']
            )
        )

        # Generate samples
        print(f"\nGenerating {args.sample_count} weather text samples:")
        generate_samples(model, tokenizer, val_loader, token_mappings, args.sample_count)

    print("Done!")