In [None]:
import torch
import re
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from GRU_collection import BasicWeatherGRU as WeatherGRU
from tqdm.notebook import tqdm
import numpy as np
import json
import os
import random

In [4]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-german-cased')

# Special tokens for text substitution
special_tokens = {
    'additional_special_tokens': ['<city>','<temp>','<date>','<velocity>','<percentile>','<rainfall>', '<ne>']
}

# Add special tokens into tokenizer
tokenizer.add_special_tokens(special_tokens)

7

In [5]:
class WeatherDataset(Dataset):
    def __init__(self, weather_data, max_length=100):
        self.data = [weather_data] if isinstance(weather_data, dict) else weather_data
        self.max_length = max_length
        
        # Initialize wind directions from all data points
        self.wind_directions = sorted(list(set([d for data in self.data for d in data['windrichtung']])))
        self.wind_dir_to_idx = {d: i for i, d in enumerate(self.wind_directions)}
        
        # Calculate expected feature dimension
        self.feature_dim = (
            1 +  # temperature
            1 +  # rain risk
            1 +  # rain amount
            1 +  # wind speed
            1 +  # pressure
            1 +  # humidity
            1 +  # cloudiness
            len(self.wind_directions) +  # one-hot wind directions
            2 +  # time encoding (sin, cos)
            3 +  # sun features
            1    # sun hours
        )
        
        # Initialize spaCy NER model for German text
        try:
            import spacy
            print("Loading German NER model...")
            self.nlp = spacy.load("de_core_news_lg")
            self.has_ner = True
            print("NER model loaded successfully")
        except:
            print("Warning: Could not load spaCy model. Make sure to install it with:")
            print("pip install spacy")
            print("python -m spacy download de_core_news_lg")
            self.has_ner = False
    
    def replace_dates(self, text: str) -> str:
        text = re.sub(r"\b\d{1,2}\.\d{1,2}\.\d{4}\b", "<date>", text)
        return text
    
    def apply_ner_replacement(self, text: str) -> str:
        """
        Apply Named Entity Recognition to replace geographic entities with <ne> tag
        
        Args:
            text (str): Input text
            
        Returns:
            str: Text with geographic entities replaced by <ne> tag
        """
        if not self.has_ner:
            return text
            
        doc = self.nlp(text)
        
        # Collect entities and their positions
        entities = []
        for ent in doc.ents:
            if ent.label_ in ["LOC", "GPE"]:  # Locations and Geopolitical entities
                entities.append((ent.start_char, ent.end_char, ent.text))
        
        # Sort entities by start position in reverse order to avoid offset issues
        entities.sort(reverse=True, key=lambda x: x[0])
        
        # Replace each entity with <ne> tag
        for start, end, entity in entities:
            text = text[:start] + "<ne>" + text[end:]
        
        return text
        
    def replace_city_and_units(self, text: str, city: str) -> str:
        # Replace the city name with <city> tag
        # Make sure city isn't empty before replacing
        if city and len(city) > 0:
            text = re.sub(r'\b' + re.escape(city) + r'\b', '<city>', text)

        unit_patterns = [
        # TEMPERATURE
        (r'(°[ ]*C|Grad)', ' <temp>'),
        # VELOCITY
        (r'[ ]*km/h', ' <velocity>'),
        # PERCENTILE
        (r'[ ]*%', ' <percentile>'),
        # RAINFALL DELTA
        # (r'\d+\.\d+ bis \d+\.\d+[ ]*l\/m²', '<rainfall> bis <rainfall>'),
        (r'[ ]*l\/m²', ' <rainfall>')
        ]
        for pattern, replacement in unit_patterns:
            try:
                text = re.sub(pattern, replacement, text)
            except Exception:
                continue

        # REMOVE MARKUP
        text = re.sub(r'\**', '', text)

        # REMOVE WEIRD PUNCUATION
        text = re.sub(r' \.', '.', text)

        # REMOVE UNNECESSARY NEWLINES
        text = re.sub(r'\n\n', '\n', text)

        # REMOVE SPACE AFTER NEWLINE
        text = re.sub(r'\n ', '\n', text)

        # REPLACE MULTIPLE WHITESPACES WITH ONE
        text = re.sub(r' +', ' ', text)
        
        return text

    def _parse_time(self, time_str):
        """Parse time string and handle missing data"""
        if time_str == '-' or not time_str:
            return None
        try:
            # Handle "HH:MM Uhr" format
            if ':' in time_str:
                hour, minute = map(int, time_str.split(' ')[0].split(':'))
                return hour + minute/60
            return None
        except (ValueError, IndexError):
            return None

    def _encode_time(self, time_str):
        # Convert "HH - HH Uhr" to cyclic features
        try:
            start_hour = int(time_str.split(' - ')[0])
            hour_sin = torch.sin(torch.tensor(2 * math.pi * start_hour / 24))
            hour_cos = torch.cos(torch.tensor(2 * math.pi * start_hour / 24))
            return torch.tensor([hour_sin, hour_cos])
        except (ValueError, IndexError):
            # Return neutral values for invalid time
            return torch.tensor([0.0, 1.0])
        
    def __len__(self):
        return len(self.data)
    
    def _encode_sun_info(self, sunrise, sunset, current_time):
        # Parse times, handling missing data
        sunrise_hour = self._parse_time(sunrise)
        sunset_hour = self._parse_time(sunset)
        
        try:
            current_hour = float(current_time.split(' - ')[0])
        except (ValueError, IndexError):
            # Return default values if current time is invalid
            return torch.tensor([0.0, 0.0, 0.0])
        
        # If sunrise or sunset is missing, use approximate values based on season
        if sunrise_hour is None or sunset_hour is None:
            # Return default encoding indicating uncertainty
            return torch.tensor([
                0.5,  # Unknown daylight status
                0.0,  # Neutral time since sunrise
                0.0   # Neutral time until sunset
            ])
        
        # Calculate daylight features
        is_daylight = (current_hour >= sunrise_hour) and (current_hour <= sunset_hour)
        
        if is_daylight:
            time_since_sunrise = (current_hour - sunrise_hour) / (sunset_hour - sunrise_hour)
            time_until_sunset = (sunset_hour - current_hour) / (sunset_hour - sunrise_hour)
        else:
            if current_hour < sunrise_hour:
                time_since_sunrise = -1 * (sunrise_hour - current_hour) / (24 - sunset_hour + sunrise_hour)
                time_until_sunset = -1
            else:
                time_since_sunrise = -1
                time_until_sunset = -1 * (current_hour - sunset_hour) / (24 - sunset_hour + sunrise_hour)
        
        return torch.tensor([float(is_daylight), time_since_sunrise, time_until_sunset])

    def one_hot_wind(self, wind_dir):
        encoding = torch.zeros(len(self.wind_directions))
        encoding[self.wind_dir_to_idx[wind_dir]] = 1
        return encoding
    
    def process_text(self, text, city):
        """Process text with all replacements in the correct order"""
        # First replace city and units 
        processed_text = self.replace_city_and_units(text, city)
        
        # Then replace dates
        processed_text = self.replace_dates(processed_text)
        
        # CRITICAL: Handle NER replacement properly
        # We need to ensure that the model can properly recognize <ne> tags
        # First, check if we have NER capability
        if self.has_ner:
            # Get all current special token positions
            special_tokens = ['<city>', '<temp>', '<date>', '<velocity>', '<percentile>', '<rainfall>']
            protected_regions = []
            
            for token in special_tokens:
                start_idx = 0
                while start_idx < len(processed_text):
                    pos = processed_text.find(token, start_idx)
                    if pos == -1:
                        break
                    protected_regions.append((pos, pos + len(token)))
                    start_idx = pos + len(token)
            
            # Sort protected regions
            protected_regions.sort()
            
            # Create a list of unprotected text regions that can be processed with NER
            unprotected_regions = []
            last_end = 0
            
            for start, end in protected_regions:
                if start > last_end:
                    unprotected_regions.append((last_end, start))
                last_end = end
            
            # Add the final region
            if last_end < len(processed_text):
                unprotected_regions.append((last_end, len(processed_text)))
            
            # Process each unprotected region with NER
            result_text = ""
            last_pos = 0
            
            for start, end in unprotected_regions:
                # Add any special tokens before this region
                result_text += processed_text[last_pos:start]
                
                # Process this region with NER
                region_text = processed_text[start:end]
                doc = self.nlp(region_text)
                
                # Collect entities in this region
                entities = []
                for ent in doc.ents:
                    if ent.label_ in ["LOC", "GPE"]:  # Locations and Geopolitical entities
                        entities.append((ent.start_char, ent.end_char, ent.text))
                
                # Sort entities by start position in reverse order
                entities.sort(reverse=True, key=lambda x: x[0])
                
                # Replace entities with <ne> tag
                region_result = region_text
                for ent_start, ent_end, _ in entities:
                    region_result = region_result[:ent_start] + "<ne>" + region_result[ent_end:]
                
                result_text += region_result
                last_pos = end
            
            # Add any remaining text
            result_text += processed_text[last_pos:]
            processed_text = result_text
        
        return processed_text
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Get sequence length from the data
        seq_len = len(item['temperatur_in_deg_C'])
        
        # Initialize features tensor with correct shape
        features = torch.zeros((seq_len, self.feature_dim))
        
        # Fill features one by one, maintaining consistent shapes
        current_idx = 0
        
        # Numerical features - all should be shape [seq_len, 1]
        features[:, current_idx] = torch.tensor([float(t) for t in item['temperatur_in_deg_C']])
        current_idx += 1
        
        features[:, current_idx] = torch.tensor([float(r) for r in item['niederschlagsrisiko_in_perc']])
        current_idx += 1
        
        # Handle rain amount with forward filling for NaN values
        rain_values = []
        last_valid = 0.0
        for r in item['niederschlagsmenge_in_l_per_sqm']:
            try:
                val = float(r)
                if not torch.isnan(torch.tensor(val)):
                    last_valid = val
                rain_values.append(last_valid)
            except ValueError:
                rain_values.append(last_valid)
        features[:, current_idx] = torch.tensor(rain_values)
        current_idx += 1
        
        features[:, current_idx] = torch.tensor([float(w) for w in item['windgeschwindigkeit_in_km_per_s']])
        current_idx += 1
        
        features[:, current_idx] = torch.tensor([float(p) for p in item['luftdruck_in_hpa']])
        current_idx += 1
        
        features[:, current_idx] = torch.tensor([float(h) for h in item['relative_feuchte_in_perc']])
        current_idx += 1
        
        features[:, current_idx] = torch.tensor([float(c.split('/')[0]) / 8 for c in item['bewölkungsgrad']])
        current_idx += 1
        
        # Wind directions (one-hot encoded)
        wind_features = torch.stack([self.one_hot_wind(w) for w in item['windrichtung']])
        features[:, current_idx:current_idx + len(self.wind_directions)] = wind_features
        current_idx += len(self.wind_directions)
        
        # Time features
        time_features = torch.stack([self._encode_time(t) for t in item['times']])
        features[:, current_idx:current_idx + 2] = time_features
        current_idx += 2
        
        # Sun features
        sun_features = torch.stack([
            self._encode_sun_info(
                item.get('sunrise', '-'), 
                item.get('sundown', '-'), 
                t
            ) for t in item['times']
        ])
        features[:, current_idx:current_idx + 3] = sun_features
        current_idx += 3
        
        # Sun hours feature
        sun_hours = torch.tensor([1.0 if "fast nicht zu sehen" in item.get('sunhours', '') else 0.0])
        features[:, current_idx] = sun_hours.expand(seq_len)
        
        # Process the text using the NER-based method
        processed_text = self.process_text(item['report_long'], item['city'])
        
        return {
            'features': features,
            'text': processed_text
        }
    
    def scan_dataset_for_named_entities(self, sample_size=None):
        """
        Scan the dataset to identify common named entities
        
        Args:
            sample_size: Number of samples to analyze, None for all
            
        Returns:
            Counter of named entities and their frequencies
        """
        if not self.has_ner:
            print("NER model not available. Cannot scan for named entities.")
            return None
            
        from collections import Counter
        
        # Get samples to analyze
        if sample_size is None:
            samples = self.data
        else:
            import random
            samples = random.sample(self.data, min(sample_size, len(self.data)))
            
        # Collect all entities
        all_entities = []
        
        for i, item in enumerate(samples):
            if i % 100 == 0:
                print(f"Processing sample {i}/{len(samples)}...")
                
            doc = self.nlp(item['report_long'])
            
            for ent in doc.ents:
                if ent.label_ in ["LOC", "GPE"]:
                    all_entities.append(ent.text)
        
        # Return counter of entities
        return Counter(all_entities)

In [6]:
if __name__=='__main__':
    
    # change directory if not on root
    if str(os.getcwd()).endswith('LLamas') == False:
        os.chdir('../..')
    
    # check if dict contains correct key
    # def check_file(path):
    #     with open(path, 'r', encoding='utf-8') as f:
    #         data: dict = json.load(f)
    #         if 'gpt_rewritten_v2' in data.keys():
    #             return True
    #         else:
    #             return False

    def check_file(path):
        """
        Checks if the file contains valid keys and values.
        Returns True if the file should be loaded, False otherwise.
        """
        with open(path, 'r', encoding='utf-8') as f:
            try:
                data: dict = json.load(f)
            except json.JSONDecodeError:
                print(f"Invalid JSON in file: {path}")
                return False

            # Ensure required keys are present
            if not {'report_long', 'city'}.issubset(data.keys()):
                return False

            # Check if 'city' and 'report_long' have valid non-empty values
            if not isinstance(data['city'], str) or not data['city'].strip():
                return False
            if not isinstance(data['report_long'], str) or not data['report_long'].strip():
                return False

            return True
                
    # transform data into weatherDataSet class format
    def load_data(path):
        with open(path, 'r', encoding='utf-8') as f:
            data: dict = json.load(f)
            return data
    
    # files for reading
    files = os.listdir(os.path.join(os.getcwd(), 'data', 'files_for_chatGPT', '2024-12-12'))
    files = {(file.split('-')[-1]).split('_')[0]:load_data(os.path.join(os.getcwd(), 'data', 'files_for_chatGPT', '2024-12-12', file)) for file in tqdm(files) if check_file(os.path.join(os.getcwd(), 'data', 'files_for_chatGPT', '2024-12-12', file))}

  0%|          | 0/28812 [00:00<?, ?it/s]

In [7]:
weather_data = list(files.values())
dataset = WeatherDataset(weather_data, max_length=100)
# Optionally analyze entities in the dataset
entity_counts = dataset.scan_dataset_for_named_entities(sample_size=1000)
print(entity_counts)
def validate_and_clean_weather_data(weather_data, dataset_class):
    """
    Validates the dataset and returns cleaned weather_data with problematic samples removed.
    
    Args:
        weather_data: List of weather data samples
        dataset_class: The dataset class constructor to use for validation
    
    Returns:
        tuple: (cleaned_weather_data, removed_indices, validation_summary)
    """
    # Create temporary dataset for validation
    temp_dataset = dataset_class(weather_data, max_length=100)
    
    summary = {
        'total_samples': len(temp_dataset),
        'invalid_samples': [],
        'statistics': {
            'nan_count': 0,
            'inf_count': 0,
            'extreme_values': 0
        }
    }
    
    invalid_indices = set()
    
    # Validate each sample
    for idx in range(len(temp_dataset)):
        try:
            sample = temp_dataset[idx]
            features = sample['features']
            
            has_issue = False
            
            # Check for NaN values
            nan_mask = torch.isnan(features)
            if nan_mask.any():
                summary['statistics']['nan_count'] += nan_mask.sum().item()
                has_issue = True
                
            # Check for infinity
            inf_mask = torch.isinf(features)
            if inf_mask.any():
                summary['statistics']['inf_count'] += inf_mask.sum().item()
                has_issue = True
                
            # Check for extreme values
            extreme_mask = (features.abs() > 1e6)
            if extreme_mask.any():
                summary['statistics']['extreme_values'] += extreme_mask.sum().item()
                has_issue = True
            
            if has_issue:
                invalid_indices.add(idx)
                summary['invalid_samples'].append({
                    'index': idx,
                    'text': sample['text']
                })
                
        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
            invalid_indices.add(idx)
            summary['invalid_samples'].append({
                'index': idx,
                'error_message': str(e)
            })
    
    # Create cleaned weather_data list
    cleaned_weather_data = [
        data for idx, data in enumerate(weather_data) 
        if idx not in invalid_indices
    ]
    
    # Print summary
    print("\nValidation Summary:")
    print(f"Total samples: {summary['total_samples']}")
    print(f"Samples with issues: {len(summary['invalid_samples'])}")
    print(f"Total NaN values: {summary['statistics']['nan_count']}")
    print(f"Total infinite values: {summary['statistics']['inf_count']}")
    print(f"Total extreme values: {summary['statistics']['extreme_values']}")
    print(f"\nRemoved {len(invalid_indices)} samples")
    print(f"Remaining samples: {len(cleaned_weather_data)}")
    
    return cleaned_weather_data, list(invalid_indices), summary

# Example usage function
def clean_and_create_dataset(weather_data, dataset_class):
    """
    Cleans the weather data and creates a new dataset.
    
    Args:
        weather_data: Original weather data list
        dataset_class: Dataset class to use
    
    Returns:
        tuple: (new_dataset, removed_indices, validation_summary)
    """
    cleaned_data, removed_indices, summary = validate_and_clean_weather_data(
        weather_data, dataset_class
    )
    
    # Create new dataset with cleaned data
    clean_dataset = dataset_class(cleaned_data, max_length=100)
    
    return clean_dataset, removed_indices, summary

# Clean the data and create new dataset
clean_dataset, removed_indices, summary = clean_and_create_dataset(weather_data, WeatherDataset)

Loading German NER model...
NER model loaded successfully
Processing sample 0/1000...
Processing sample 100/1000...
Processing sample 200/1000...
Processing sample 300/1000...
Processing sample 400/1000...
Processing sample 500/1000...
Processing sample 600/1000...
Processing sample 700/1000...
Processing sample 800/1000...
Processing sample 900/1000...
Counter({'Karibik': 153, 'Werten': 52, 'Temperaturen': 52, 'Nachts': 18, 'Kapverden': 7, 'Santa Fe': 6, 'California': 6, 'Catalina': 5, 'Trinidad': 5, 'Las Mercedes': 5, 'San Pedro': 4, 'La Rinconada': 4, 'Morgen Wolken': 4, 'San Pablo': 4, 'Italienische Adria': 4, 'San Agustín': 4, 'San Antonio': 4, 'Florida': 4, 'San Lorenzo': 4, 'Pueblo Nuevo': 4, 'Tierra Blanca': 4, 'La Esperanza': 4, 'Stetten SH': 3, 'Santa Bárbara': 3, '°C': 3, 'Montão': 3, 'Malveira': 3, 'Grande Tête Rouge': 3, 'Los Caballos': 3, 'Au Parc': 3, 'Baliares': 3, 'Preveza': 3, 'Golden Grove': 3, 'El Salado': 3, 'Mata de Felipito': 3, 'Saint Lucia': 3, 'La Source': 3, 

In [8]:
def reduce_vocabulary(tokenizer, dataset, batch_size=64):
    """Identify used tokens and create a reduced vocabulary mapping"""
    print("Analyzing vocabulary usage to reduce model size...")
    
    from collections import Counter
    from tqdm.notebook import tqdm
    import torch
    
    # Count tokens
    token_counter = Counter()
    
    # Process the entire dataset
    for idx in tqdm(range(len(dataset)), desc="Scanning token usage"):
        # Get the raw text directly
        if isinstance(dataset, torch.utils.data.Subset):
            sample = dataset.dataset[dataset.indices[idx]]
        else:
            sample = dataset[idx]
        
        text = sample['text']
        
        # Tokenize directly
        tokens = tokenizer.encode(text, add_special_tokens=True)
        token_counter.update(tokens)
    
    # Always keep special tokens
    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, str):
            token_id = tokenizer.convert_tokens_to_ids(special_token)
            if token_id not in token_counter:
                token_counter[token_id] = 1
        elif isinstance(special_token, list):
            for token in special_token:
                token_id = tokenizer.convert_tokens_to_ids(token)
                if token_id not in token_counter:
                    token_counter[token_id] = 1
    
    # Sort by frequency for efficient token ID assignment
    used_token_ids = sorted(token_counter.keys())
    
    # Create token ID mapping (old ID -> new ID)
    token_id_map = {old_id: new_id for new_id, old_id in enumerate(used_token_ids)}
    
    # Create reverse mapping for inference
    reverse_token_id_map = {new_id: old_id for old_id, new_id in token_id_map.items()}
    
    # Store the mappings for later use
    token_mappings = {
        'token_id_map': token_id_map,
        'reverse_token_id_map': reverse_token_id_map,
        'used_token_ids': used_token_ids
    }
    
    # Update vocabulary size to reduced size
    reduced_vocab_size = len(used_token_ids)
    original_vocab_size = len(tokenizer.vocab)
    print(f"Reduced vocabulary from {original_vocab_size:,} to {reduced_vocab_size:,} tokens " 
          f"({reduced_vocab_size/original_vocab_size*100:.1f}%)")
    
    return token_mappings, reduced_vocab_size

# Map tokens function for data loaders
def map_tokens_fn(batch, token_id_map):
    """Map token IDs to new reduced IDs"""
    import torch
    
    # Map the token IDs to new IDs
    old_tokens = batch['text']
    new_tokens = torch.zeros_like(old_tokens)
    
    # Apply mapping
    for i in range(old_tokens.size(0)):
        for j in range(old_tokens.size(1)):
            old_id = old_tokens[i, j].item()
            new_tokens[i, j] = token_id_map.get(old_id, 0)  # Default to 0 if token not found
    
    batch['text'] = new_tokens
    return batch

# Modified prepare_batch function to use token mapping
def prepare_batch_with_mapping(batch_list, tokenizer, token_id_map=None):
    """Prepare batch with optional token mapping for reduced vocabulary"""
    import torch
    
    features = torch.stack([item['features'] for item in batch_list])
    texts = [item['text'] for item in batch_list]
    
    # Normalize features
    features = (features - features.mean()) / (features.std() + 1e-8)
    
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    
    batch = {
        'features': features,
        'text': encoded['input_ids']
    }
    
    # Apply token mapping if provided
    if token_id_map is not None:
        batch = map_tokens_fn(batch, token_id_map)
    
    return batch

# Create DataLoader with token mapping support
def create_dataloader_with_mapping(dataset, batch_size, tokenizer, token_id_map=None):
    """Create a DataLoader with optional token mapping"""
    from torch.utils.data import DataLoader
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: prepare_batch_with_mapping(batch, tokenizer, token_id_map)
    )

def count_model_parameters(model):
    """Count the number of trainable parameters in the model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def adjust_model_size(feature_dim, vocab_size, embedding_dim, hidden_dim, max_params=30_000_000):
    """Adjust model dimensions to stay under parameter limit"""
    import math
    
    # Calculate WeatherGRU model parameter count
    # Main parameters:
    # - Feature encoder: feature_dim * embedding_dim * 2 (two linear layers)
    # - Embedding: vocab_size * embedding_dim
    # - GRU: embedding_dim * hidden_dim * 3 (input projection) + hidden_dim * hidden_dim * 3 (recurrent)
    # - Output layer: hidden_dim * vocab_size
    
    param_count = (
        feature_dim * embedding_dim * 2 +  # Feature encoder (two layers)
        vocab_size * embedding_dim +       # Embedding layer
        (embedding_dim * hidden_dim * 3) + # GRU input projection (3 gates)
        (hidden_dim * hidden_dim * 3) +    # GRU recurrent connections (3 gates)
        hidden_dim * vocab_size            # Output layer
    )
    
    print(f"Initial parameter count: {param_count:,}")
    
    if param_count <= max_params:
        return embedding_dim, hidden_dim
    
    # If we need to reduce params, try to find optimal dimensions
    # Strategy: reduce both embedding_dim and hidden_dim proportionally
    
    # Calculate reduction factor needed
    reduction_factor = math.sqrt(max_params / param_count)
    
    # Apply reduction
    new_embedding_dim = max(32, int(embedding_dim * reduction_factor))
    new_hidden_dim = max(64, int(hidden_dim * reduction_factor))
    
    # Recalculate to verify
    new_param_count = (
        feature_dim * new_embedding_dim * 2 +  
        vocab_size * new_embedding_dim +       
        (new_embedding_dim * new_hidden_dim * 3) + 
        (new_hidden_dim * new_hidden_dim * 3) +    
        new_hidden_dim * vocab_size            
    )
    
    print(f"Adjusted dimensions: embedding_dim={new_embedding_dim}, hidden_dim={new_hidden_dim}")
    print(f"Adjusted parameter count: {new_param_count:,}")
    
    return new_embedding_dim, new_hidden_dim

def save_model_with_metadata(model, path, token_mappings=None, config=None):
    """Save model with all required metadata for later use"""
    import torch
    import os
    
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(path), exist_ok=True)
    
    # Prepare save data
    save_data = {
        'model_state_dict': model.state_dict(),
    }
    
    # Add token mappings if available
    if token_mappings is not None:
        save_data['token_mappings'] = token_mappings
    
    # Add model configuration if available
    if config is not None:
        save_data['config'] = config
    
    # Save to file
    torch.save(save_data, path)
    print(f"Model saved to {path}")
    
    return path

def load_model_with_metadata(path, device='cpu'):
    """Load model with all metadata"""
    import torch
    
    # Load saved data
    save_data = torch.load(path, map_location=device)
    
    return save_data

# ANOMALY DETECTION RUN

In [None]:
# Create DataLoader with token mapping support
def create_dataloader_with_mapping(dataset, batch_size, tokenizer, token_id_map=None):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: prepare_batch_with_mapping(batch, tokenizer, token_id_map)
    )

def train_epoch(model, dataloader, criterion, optimizer, device, teacher_forcing_ratio=1.0, epoch=0, total_epochs=1):
    model.train()
    total_loss = 0
    num_batches = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{total_epochs}")
    
    for batch_idx, batch in enumerate(pbar):
        try:
            features = batch['features'].to(device)
            text = batch['text'].to(device)
            
            # Check for invalid values in inputs
            if torch.isnan(features).any() or torch.isinf(features).any():
                print(f"Warning: Invalid values in features at batch {batch_idx}")
                continue
                
            optimizer.zero_grad()
            
            # Forward pass with gradient checking
            with torch.autograd.detect_anomaly():
                outputs = model(features, text, teacher_forcing_ratio)
                outputs = outputs.view(-1, outputs.size(-1))
                targets = text[:, 1:].contiguous().view(-1)
                
                loss = criterion(outputs, targets)
                
                # Check if loss is valid
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"Warning: Invalid loss value {loss.item()} at batch {batch_idx}")
                    print("Last output values:", outputs[-5:])
                    print("Last target values:", targets[-5:])
                    raise ValueError("Invalid loss detected")
                
                loss.backward()
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                # Check gradients
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        grad_norm = param.grad.norm()
                        if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                            print(f"Warning: Invalid gradient for {name}")
                            raise ValueError(f"Invalid gradient detected in {name}")
                
                optimizer.step()
                
                total_loss += loss.item()
                num_batches += 1
                avg_loss = total_loss / num_batches
                pbar.set_postfix({"Loss": f"{avg_loss:.4f}"})
                
        except ValueError as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            continue
            
    return total_loss / num_batches if num_batches > 0 else float('inf')

# Updated train_model function with vocabulary reduction and parameter management
def train_model(dataset, tokenizer, max_params=3_000_000, num_epochs=10, batch_size=64, learning_rate=1e-3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Set seed for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    
    # First, reduce vocabulary
    token_mappings, reduced_vocab_size = reduce_vocabulary(tokenizer, dataset, batch_size)
    
    # Get feature dimension
    feature_dim = dataset.feature_dim
    
    # Determine optimal model dimensions based on parameter constraints
    embedding_dim, hidden_dim = adjust_model_size(
        feature_dim=feature_dim,
        vocab_size=reduced_vocab_size,
        embedding_dim=256,  # Starting point
        hidden_dim=512,    # Starting point
        max_params=max_params
    )
    
    # Create model with adjusted dimensions
    model = WeatherGRU(
        feature_dim=feature_dim,
        vocab_size=reduced_vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        n_layers=1,
        dropout=0.2
    ).to(device)
    
    # Count parameters
    param_count = count_model_parameters(model)
    print(f"Model has {param_count:,} trainable parameters")
    
    # Create dataloader with token mapping
    dataloader = create_dataloader_with_mapping(
        dataset, 
        batch_size=batch_size, 
        tokenizer=tokenizer,
        token_id_map=token_mappings['token_id_map']
    )
    
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='mean')
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    
    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )
    
    losses = []
    best_loss = float('inf')
    
    # Select a few examples for generation testing
    test_indices = [0, len(dataset)//2, len(dataset)-1]  # Beginning, middle, and end
    test_samples = [dataset[i] for i in test_indices]
    test_features = torch.stack([sample['features'] for sample in test_samples]).to(device)
    
    print("\nOriginal texts for test samples:")
    for idx, sample in zip(test_indices, test_samples):
        print(f"Sample {idx}: {sample['text']}")
    
    # Create directory for model checkpoints
    os.makedirs('models', exist_ok=True)
    
    for epoch in range(num_epochs):
        try:
            loss = train_epoch(
                model, dataloader, criterion, optimizer, device,
                teacher_forcing_ratio=0.9,
                epoch=epoch, total_epochs=num_epochs
            )
            losses.append(loss)
            
            scheduler.step(loss)
            print(f"\nEpoch {epoch + 1}, Loss: {loss:.4f}")
            
            # Generate examples
            print("\nGenerated examples:")
            model.eval()
            with torch.no_grad():
                generated_tokens = model.generate(
                    test_features,
                    token_mappings=token_mappings
                )
                for idx, tokens in enumerate(generated_tokens):
                    # Map tokens back to original vocabulary
                    original_tokens = [token_mappings['reverse_token_id_map'][t.item()] for t in tokens]
                    generated_text = tokenizer.decode(original_tokens, skip_special_tokens=False)
                    print(f"Sample {test_indices[idx]}: {generated_text}")
            model.train()
            
            # Save checkpoint if it's the best model so far
            if loss < best_loss:
                best_loss = loss
                config = {
                    'feature_dim': feature_dim,
                    'vocab_size': reduced_vocab_size,
                    'embedding_dim': embedding_dim,
                    'hidden_dim': hidden_dim,
                    'n_layers': 1,
                    'dropout': 0.2,
                    'epoch': epoch,
                    'loss': loss
                }
                
                save_model_with_metadata(
                    model=model,
                    path=os.path.join('models', 'best_gru_model.pt'),
                    token_mappings=token_mappings,
                    config=config
                )
                
        except Exception as e:
            print(f"Error during epoch {epoch + 1}: {str(e)}")
            continue
    
    return model, token_mappings, losses

# Usage example:
# Note that we now call train_model differently, passing dataset and tokenizer directly
model, token_mappings, losses = train_model(
    dataset=clean_dataset,
    tokenizer=tokenizer,
    max_params=30_000_000,  # 3 million parameter limit
    num_epochs=1, # tends to overfit after 1 epoch
    batch_size=128,
    learning_rate=1e-3
)

Using device: cuda
Analyzing vocabulary usage to reduce model size...


Scanning token usage:   0%|          | 0/28768 [00:00<?, ?it/s]

Reduced vocabulary from 30,007 to 444 tokens (1.5%)
Initial parameter count: 1,531,904
Model has 1,609,236 trainable parameters

Original texts for test samples:
Sample 0: Wetter heute, <date> In <city> stören am Morgen nur einzelne Wolken den sonst blauen Himmel bei Temperaturen von 23 <temp>. Mittags stören nur einzelne Wolken den sonst blauen Himmel und die Temperaturen erreichen 30 <temp>. Abends gibt es in <city> lockere Bewölkung bei Temperaturen von 24 bis 26 <temp>. In der Nacht bedecken einzelne Wolken den Himmel und die Werte gehen auf 22 <temp> zurück. Die gefühlten Temperaturen liegen bei 23 bis 33 <temp>. <city> liegt in der Region <ne>. Dort finden Sie eine Wettervorhersage für die gesamte Region.
Sample 14384: Wetter heute, <date> In <city> ist es morgens sonnig und die Temperatur liegt bei 21 <temp>. Im Laufe des Mittags scheint die Sonne bei blauem Himmel und das Thermometer klettert auf 27 <temp>. Abends gibt es in <city> einen wolkenlosen Himmel bei Temperaturen von 

Epoch 1/1:   0%|          | 0/225 [00:00<?, ?it/s]

  with torch.autograd.detect_anomaly():



Epoch 1, Loss: 1.1792

Generated examples:
Sample 0: Wetter heute, <date> In <city> regnet es Höchst erreicht 28 <temp>. Am Abend regnet es in <city> bei Werten von 26 bis zu 27 <temp>. Nachts gibt es Regen bei Tiefsttemperaturen von 22 <temp>. Mit einer Wahrscheinlichkeit von 15 <percentile>, ist über den Tag verteilt mit Niederschlagsmengen von 0. 01 bis 0. 07 <rainfall> zu rechnen. Gefühlt liegen die Temperaturen bei 26 bis 29 <temp>. <city> liegt in der Region <ne>. Dort finden Sie eine Wettervorhersage
Sample 14384: Wetter heute, <date> In <city> regnet es Höchst erreicht 37 <temp>. Am Abend gibt es in <city> Regen bei Werten von 21 bis zu 23 <temp>. Nachts gibt es Regen bei Tiefstwerten von 22 <temp>. Mit einer Wahrscheinlichkeit von 15 <percentile>, ist über den Tag verteilt mit Niederschlagsmengen von 0. 01 bis 0. 29 <rainfall> zu rechnen. Gefühlt liegen die Temperaturen bei 23 bis 30 <temp>. <city> liegt in der Region <ne>. Dort finden Sie eine Wettervorhersage für
Sample 287

### First Iteration
    embedding_dim=64,
    hidden_dim=128,

ca 9,806,800 parameter laut Claude

Results:


In [13]:
import torch
import re
from transformers import AutoTokenizer

def load_and_run_weather_model(model_path, dataset, sample_indices=None, num_samples=3, tokenizer=None, max_length=100):
    """
    Load the trained GRU weather model and generate text based on features from the provided dataset.
    
    Args:
        model_path (str): Path to the saved model file
        dataset: The cleaned dataset containing features
        sample_indices (list, optional): Specific indices to use from the dataset. If None, random samples are selected.
        num_samples (int): Number of random samples to generate if sample_indices is None
        tokenizer (AutoTokenizer, optional): Tokenizer to use for decoding. If None, will load BERT German tokenizer.
        max_length (int): Maximum length of the generated sequence
        
    Returns:
        dict: Dictionary mapping sample indices to original and generated text
    """
    # Load tokenizer if not provided
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('bert-base-german-cased')
        # Add special tokens that were used during training
        special_tokens = {
            'additional_special_tokens': ['<city>','<temp>','<date>','<velocity>','<percentile>','<rainfall>', '<ne>']
        }
        tokenizer.add_special_tokens(special_tokens)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load model and metadata
    save_data = torch.load(model_path, map_location=device)
    token_mappings = save_data['token_mappings']
    config = save_data['config']
    
    # Create model with saved configuration
    
    # Create model with saved configuration
    model = WeatherGRU(
        feature_dim=config['feature_dim'],
        vocab_size=config['vocab_size'],
        embedding_dim=config['embedding_dim'],
        hidden_dim=config['hidden_dim'],
        n_layers=config['n_layers'],
        dropout=config['dropout']
    ).to(device)
    
    # Load state dict
    model.load_state_dict(save_data['model_state_dict'])
    model.eval()
    
    # Select samples from the dataset
    if sample_indices is None:
        import random
        # Choose random indices
        sample_indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
    
    # Extract features and original texts
    samples = [dataset[i] for i in sample_indices]
    features = torch.stack([sample['features'] for sample in samples]).to(device)
    original_texts = [sample['text'] for sample in samples]
    
    # Generate text
    with torch.no_grad():
        generated_tokens = model.generate(
            features,
            max_length=max_length,
            token_mappings=token_mappings
        )
        
        # Convert tokens back to text
        results = {}
        for idx, (tokens, original) in enumerate(zip(generated_tokens, original_texts)):
            # Map tokens back to original vocabulary
            original_tokens = [token_mappings['reverse_token_id_map'][t.item()] for t in tokens]
            text = tokenizer.decode(original_tokens, skip_special_tokens=False)
            # Clean up text (remove extra tokens or artifacts if needed)
            text = re.sub(r'\[CLS\]|\[SEP\]', '', text).strip()
            
            sample_idx = sample_indices[idx]
            results[sample_idx] = {
                'original': original,
                'generated': text
            }
    
    return results

# Example usage:
results = load_and_run_weather_model("models/best_gru_model.pt", clean_dataset)
for idx, data in results.items():
    print(f"Sample {idx}:")
    print(f"Original: {data['original']}")
    print(f"Generated: {data['generated']}")
    print()

Using device: cuda


  save_data = torch.load(model_path, map_location=device)


Sample 24299:
Original: Wetter heute, <date> In <city> kann sich morgens die Sonne nicht durchsetzen und es bleibt bedeckt bei Werten von -1 <temp>. Mittags bleibt die Wolkendecke geschlossen und die Höchstwerte liegen bei 1 <temp>. Abends ist es in <city> bedeckt bei Temperaturen von -1 <temp>. Nachts ist es überwiegend dicht bewölkt bei Werten von -1 <temp>. Böen können Geschwindigkeiten zwischen 3 und 18 <velocity> erreichen. Die gefühlten Temperaturen liegen bei -3 bis 2 <temp>.
Generated: [PAD] Wetter heute, <date> In <city> kann sich morgens die Sonne nicht durchsetzen und es bleibt bedeckt bei Werten von 5 <temp>. Mittags gibt sich leichte Wolken und das Thermometer klettert auf - <temp>. Am Abend gibt es in <city> lockere Bewölkung bei Werten von 21 bis zu 27 <temp>. Nachts gibt es einen wolkenlosen Himmel bei Tiefstwerten von 23 <temp>. Die Wahrscheinlichkeit für Niederschläge liegt bei 10 <percentile> und es ist mit einer maximalen Niederschlagsmenge

Sample 9012:
Original: W