In [None]:
# Install required packages for real dataset processing
!pip install -q xarray netcdf4 scipy einops kaggle
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
import xarray as xr
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when a GPU is visible to PyTorch, CPU otherwise
has_gpu = torch.cuda.is_available()
device = torch.device('cuda') if has_gpu else torch.device('cpu')
print(f"Using device: {device}")

# Report the name and total memory of GPU 0 when one is available
if has_gpu:
    gpu_index = 0
    gpu_properties = torch.cuda.get_device_properties(gpu_index)
    print(f"GPU: {torch.cuda.get_device_name(gpu_index)}")
    print(f"CUDA Memory: {gpu_properties.total_memory / 1e9:.1f} GB")


In [None]:
# Folder (relative to the notebook) that the loaders below scan for *.nc / *.csv files
data_dir = Path('../data')
data_dir.mkdir(exist_ok=True)

# Kaggle dataset identifiers, kept for reference only.
# Note: this cell does not download anything; it only prints setup hints.
# Kaggle credentials (kaggle.json in ~/.kaggle/ or environment variables) are
# needed if the datasets are fetched with the Kaggle API.
shoreline_dataset = "globalshorelines/data"
marine_dataset = "atharvasoundankar/shifting-seas-ocean-climate-and-marine-life-dataset"

for setup_message in (
    "Setting up Kaggle datasets...",
    "Please ensure your Kaggle API credentials are configured.",
    "Instructions: https://github.com/Kaggle/kaggle-api#api-credentials",
    f"Dataset 1: {shoreline_dataset}",
    f"Dataset 2: {marine_dataset}",
):
    print(setup_message)


In [None]:
class HybridCNNLSTM(nn.Module):
    """
    Hybrid CNN-LSTM classifier for coastal erosion prediction.

    Data flow: stacked Conv1d blocks applied along the time axis -> adaptive
    average pooling to a fixed number of steps -> bidirectional LSTM ->
    multi-head self-attention with residual connection and layer norm ->
    mean over time -> MLP classification head.

    Args:
        n_features: Number of input features per time step (Conv1d input channels).
        n_classes: Number of output classes (size of the returned logits).
        cnn_filters: Output channels of each convolutional block, in order.
        lstm_hidden: Hidden size of the LSTM per direction.
        lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used in the CNN blocks, between LSTM
            layers (only when ``lstm_layers > 1``), in attention and in the head.
        pooled_len: Number of time steps produced by the adaptive pooling
            layer, i.e. the sequence length seen by the LSTM.
        num_heads: Number of attention heads; must divide ``2 * lstm_hidden``.

    Raises:
        ValueError: If ``cnn_filters`` is empty or ``2 * lstm_hidden`` is not
            divisible by ``num_heads``.
    """
    def __init__(self, n_features, n_classes, cnn_filters=(32, 64, 128),
                 lstm_hidden=128, lstm_layers=2, dropout=0.1,
                 pooled_len=84, num_heads=8):
        super().__init__()

        # Work on a private copy so the caller's sequence is never shared or mutated
        cnn_filters = list(cnn_filters)
        if not cnn_filters:
            raise ValueError("cnn_filters must contain at least one entry")

        lstm_output_dim = lstm_hidden * 2  # bidirectional: forward + backward states
        if lstm_output_dim % num_heads != 0:
            raise ValueError(
                f"2 * lstm_hidden ({lstm_output_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        # CNN feature extractor: one Conv/BatchNorm/ReLU/Dropout/MaxPool block per filter size
        cnn_layers = []
        in_channels = n_features
        for filters in cnn_filters:
            cnn_layers.extend([
                nn.Conv1d(in_channels, filters, kernel_size=3, padding=1),
                nn.BatchNorm1d(filters),
                nn.ReLU(),
                nn.Dropout(dropout),
                # kernel 2 / stride 1 / padding 1 smooths neighbouring steps and
                # makes the sequence one step longer; it does not downsample.
                nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
            ])
            in_channels = filters
        self.cnn_features = nn.Sequential(*cnn_layers)

        # Resample the CNN output to a fixed number of steps, whatever the input length
        self.adaptive_pool = nn.AdaptiveAvgPool1d(pooled_len)

        # Bidirectional LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=cnn_filters[-1],
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers
            dropout=dropout if lstm_layers > 1 else 0,
            bidirectional=True
        )

        # Multi-head self-attention over the LSTM outputs
        self.attention = nn.MultiheadAttention(
            embed_dim=lstm_output_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Layer normalization applied after the attention residual connection
        self.layer_norm = nn.LayerNorm(lstm_output_dim)

        # Classification head: two hidden layers that halve the width each time
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_output_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim // 2, lstm_output_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim // 4, n_classes)
        )

    def forward(self, x):
        """
        Compute class logits.

        Args:
            x: Tensor of shape [batch_size, seq_len, n_features].

        Returns:
            Tensor of shape [batch_size, n_classes] with unnormalised logits.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape [batch, seq_len, n_features], got {tuple(x.shape)}"
            )

        # Conv1d expects channels first: [batch_size, n_features, seq_len]
        cnn_out = self.cnn_features(x.transpose(1, 2))  # [batch_size, filters, seq_len']

        # Fix the number of time steps: [batch_size, filters, pooled_len]
        cnn_out = self.adaptive_pool(cnn_out)

        # Back to time-major layout for the LSTM: [batch_size, pooled_len, filters]
        cnn_out = cnn_out.transpose(1, 2)

        # LSTM processing: [batch_size, pooled_len, lstm_hidden*2]
        lstm_out, _ = self.lstm(cnn_out)

        # Self-attention; the attention weights are never used, so skip computing them
        attended_out, _ = self.attention(lstm_out, lstm_out, lstm_out, need_weights=False)

        # Residual connection and layer normalization
        attended_out = self.layer_norm(attended_out + lstm_out)

        # Average over the time dimension: [batch_size, lstm_hidden*2]
        pooled_output = attended_out.mean(dim=1)

        return self.classifier(pooled_output)

print("✅ SOTA Hybrid CNN-LSTM model implemented")


In [None]:
class KaggleDatasetProcessor:
    """
    Unified loader for the Global Shorelines (NetCDF) and Shifting Seas (CSV) datasets.

    Each loader looks for files in ``data_dir`` and falls back to a seeded
    synthetic dataset when no usable file is found or loading fails.

    Args:
        data_dir: Directory scanned for ``*.nc`` and ``*.csv`` files.
        sequence_length: Window length (in time steps) stored for downstream
            sequence construction; it is not used by the loaders themselves.
    """
    def __init__(self, data_dir, sequence_length=48):
        self.data_dir = Path(data_dir)
        self.sequence_length = sequence_length
        self.shoreline_data = None
        self.marine_data = None
        self.scaler = StandardScaler()

    def load_datasets(self):
        """
        Load both datasets (real files when present, synthetic otherwise).

        Returns:
            Tuple ``(shoreline_data, marine_data)``; the same objects are also
            stored on ``self.shoreline_data`` and ``self.marine_data``.
        """
        print("🔄 Loading datasets...")

        # Try to load real datasets, fallback to synthetic
        self.shoreline_data = self._load_shoreline_data()
        self.marine_data = self._load_marine_data()

        print("✅ Datasets loaded successfully")
        return self.shoreline_data, self.marine_data

    def _load_shoreline_data(self):
        """
        Open the first ``*.nc`` file in ``data_dir`` with xarray.

        Returns:
            The opened dataset, or the synthetic shoreline dataset when no
            NetCDF file exists or opening it raises.
        """
        # Sorted so the chosen file does not depend on directory listing order
        netcdf_files = sorted(self.data_dir.glob("*.nc"))

        if not netcdf_files:
            print("⚠️ No NetCDF files found. Creating synthetic shoreline data...")
            return self._create_synthetic_shoreline_data()

        try:
            nc_file = netcdf_files[0]
            print(f"📁 Loading {nc_file.name}...")

            data = xr.open_dataset(nc_file)
            # `.sizes` is the name -> length mapping; `Dataset.dims` is being
            # changed by xarray to no longer behave like a dict.
            print(f"✅ Loaded shoreline data: {dict(data.sizes)}")
            return data

        except Exception as e:
            # Deliberate best-effort: any load failure falls back to synthetic data
            print(f"⚠️ Error loading NetCDF: {e}")
            return self._create_synthetic_shoreline_data()

    def _create_synthetic_shoreline_data(self):
        """
        Build a seeded synthetic shoreline dataset on a (time, lat, lon) grid.

        The grid has one step per month end from 2000-01 to 2013-12, 40
        latitudes and 80 longitudes. Variables: ``shoreline_position``
        (cumulative sum of the monthly change), ``erosion_rate`` (time
        gradient of the monthly change), ``wave_energy`` and ``sea_level``.

        Note: this reseeds NumPy's global random generator with 42.
        """
        print("🔧 Creating synthetic shoreline data...")

        # 14-year monthly data (2000-2013), stamped at month end. An offset object
        # is used because the 'M' string alias is deprecated in recent pandas
        # while its replacement 'ME' is not understood by older releases.
        time_range = pd.date_range('2000-01-01', '2013-12-31', freq=pd.offsets.MonthEnd())

        # Global coastal coordinates
        lat = np.linspace(-60, 70, 40)
        lon = np.linspace(-180, 180, 80)

        np.random.seed(42)

        n_time, n_lat, n_lon = len(time_range), len(lat), len(lon)

        # Month-to-month shoreline change, filled one time step at a time
        shoreline_change = np.zeros((n_time, n_lat, n_lon))

        for i, timestamp in enumerate(time_range):
            # Seasonal patterns
            seasonal = 2 * np.sin(2 * np.pi * timestamp.month / 12)

            # Long-term sea level rise
            trend = 0.15 * (timestamp.year - 2000)

            # Storm events: cells whose random intensity is in the top 5% this month
            storm_intensity = np.random.exponential(0.1, (n_lat, n_lon))
            storm_mask = storm_intensity > np.percentile(storm_intensity, 95)

            # Combine effects
            change = (seasonal + trend +
                      5 * storm_mask * np.random.normal(1, 0.3, (n_lat, n_lon)) +
                      np.random.normal(0, 1, (n_lat, n_lon)))

            shoreline_change[i] = change

        # Assemble the dataset
        data = xr.Dataset({
            'shoreline_position': (['time', 'lat', 'lon'],
                                   np.cumsum(shoreline_change, axis=0)),
            'erosion_rate': (['time', 'lat', 'lon'],
                             np.gradient(shoreline_change, axis=0)),
            'wave_energy': (['time', 'lat', 'lon'],
                            np.abs(shoreline_change) * np.random.uniform(0.5, 2, shoreline_change.shape)),
            'sea_level': (['time', 'lat', 'lon'],
                          0.15 * np.arange(n_time).reshape(-1, 1, 1) +
                          np.random.normal(0, 0.1, shoreline_change.shape))
        }, coords={
            'time': time_range,
            'lat': lat,
            'lon': lon
        })

        print(f"✅ Created synthetic shoreline data: {dict(data.sizes)}")
        return data

    def _load_marine_data(self):
        """
        Load the first CSV in ``data_dir`` whose name mentions ocean/marine/sea/climate.

        Returns:
            The CSV as a DataFrame, or the synthetic marine DataFrame when no
            matching CSV exists or reading it raises.
        """
        # Sorted so the chosen file does not depend on directory listing order
        csv_files = sorted(self.data_dir.glob("*.csv"))

        if not csv_files:
            print("⚠️ No CSV files found. Creating synthetic marine data...")
            return self._create_synthetic_marine_data()

        try:
            # Look for marine/ocean related CSV files
            keywords = ('ocean', 'marine', 'sea', 'climate')
            marine_files = [f for f in csv_files
                            if any(keyword in f.name.lower() for keyword in keywords)]

            if not marine_files:
                print("⚠️ No marine-related CSV files found. Creating synthetic marine data...")
                return self._create_synthetic_marine_data()

            csv_file = marine_files[0]
            print(f"📁 Loading {csv_file.name}...")

            data = pd.read_csv(csv_file)
            print(f"✅ Loaded marine data: {data.shape}")
            return data

        except Exception as e:
            # Deliberate best-effort: any load failure falls back to synthetic data
            print(f"⚠️ Error loading CSV: {e}")
            return self._create_synthetic_marine_data()

    def _create_synthetic_marine_data(self):
        """
        Build a seeded synthetic daily marine/climate DataFrame for 2000-2013.

        Columns: ``date``, nine measurement columns (each a constant plus a
        365.25-day sine cycle plus random noise, except ``precipitation`` which
        is noise only) and the derived ``year``, ``month`` and ``day_of_year``.

        Note: this reseeds NumPy's global random generator with 42.
        """
        print("🔧 Creating synthetic marine climate data...")

        # 14 years of daily data
        date_range = pd.date_range('2000-01-01', '2013-12-31', freq='D')
        n_days = len(date_range)

        np.random.seed(42)

        # Shared annual cycle; computed once because it is identical for every column
        annual_cycle = np.sin(2 * np.pi * np.arange(n_days) / 365.25)

        # The column order fixes the order of the random draws, so keep it stable
        data = pd.DataFrame({
            'date': date_range,
            'sea_surface_temp': 15 + 8 * annual_cycle + np.random.normal(0, 1, n_days),
            'salinity': 35 + 2 * annual_cycle + np.random.normal(0, 0.5, n_days),
            'ph_level': 8.1 + 0.3 * annual_cycle + np.random.normal(0, 0.1, n_days),
            'dissolved_oxygen': 8 + 2 * annual_cycle + np.random.normal(0, 0.5, n_days),
            'wave_height': 1.5 + 0.8 * annual_cycle + np.random.exponential(0.5, n_days),
            'current_velocity': 0.3 + 0.2 * annual_cycle + np.random.exponential(0.2, n_days),
            'wind_speed': 8 + 5 * annual_cycle + np.random.exponential(3, n_days),
            'precipitation': np.random.exponential(2, n_days),
            'atmospheric_pressure': 1013 + 10 * annual_cycle + np.random.normal(0, 5, n_days)
        })

        # Add derived calendar features
        data['year'] = data['date'].dt.year
        data['month'] = data['date'].dt.month
        data['day_of_year'] = data['date'].dt.dayofyear

        print(f"✅ Created synthetic marine data: {data.shape}")
        return data

# Build the dataset processor for `data_dir` with 48-step windows, then load
# both datasets (each falls back to synthetic data when no file is usable)
processor = KaggleDatasetProcessor(data_dir=data_dir, sequence_length=48)
loaded_datasets = processor.load_datasets()
shoreline_data, marine_data = loaded_datasets


In [None]:
class ErosionDataset(Dataset):
    """
    PyTorch Dataset pairing feature sequences with integer class targets.

    Args:
        sequences: Array-like of feature sequences, indexed by sample along
            the first axis; stored as a float32 tensor.
        targets: Array-like of class ids, one per sequence; stored as an
            int64 tensor.

    Raises:
        ValueError: If ``sequences`` and ``targets`` differ in length.
    """
    def __init__(self, sequences, targets):
        # torch.tensor copies the input and applies the requested dtype; the legacy
        # FloatTensor/LongTensor constructors are avoided (they treat a bare int
        # as a size instead of data).
        self.sequences = torch.tensor(sequences, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

        if len(self.sequences) != len(self.targets):
            raise ValueError(
                f"sequences and targets must have the same length, "
                f"got {len(self.sequences)} and {len(self.targets)}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(sequence, target)`` pair at position ``idx``."""
        return self.sequences[idx], self.targets[idx]

# Shoreline variables read per location, in the order they appear in a feature vector
_SHORELINE_VARIABLES = ('shoreline_position', 'erosion_rate', 'wave_energy', 'sea_level')

# Marine columns appended (in this order) after the eight shoreline features
_MARINE_FEATURES = (
    'sea_surface_temp', 'salinity', 'ph_level', 'dissolved_oxygen',
    'wave_height', 'current_velocity', 'wind_speed', 'precipitation',
    'atmospheric_pressure',
)


def _monthly_marine_matrix(marine_data, shore_months):
    """
    Average the marine columns per calendar month and align them to the shoreline months.

    Args:
        marine_data: Expected to be a DataFrame with a ``date`` column and the
            columns named in ``_MARINE_FEATURES``.
        shore_months: Monthly PeriodIndex, one entry per shoreline time step.

    Returns:
        Float array of shape ``(len(shore_months), len(_MARINE_FEATURES))``, or
        ``None`` when the marine data is not a DataFrame, lacks required
        columns, or does not cover every shoreline month.
    """
    if not isinstance(marine_data, pd.DataFrame):
        return None

    missing = [name for name in ('date',) + _MARINE_FEATURES
               if name not in marine_data.columns]
    if missing:
        print(f"⚠️ Marine data is missing columns {missing}; using shoreline features only")
        return None

    # Group by calendar month (periods keep the 'M' code, unlike offset aliases)
    observation_month = pd.to_datetime(marine_data['date']).dt.to_period('M')
    monthly_means = marine_data[list(_MARINE_FEATURES)].groupby(observation_month).mean()

    # One row per shoreline time step; months without marine data become NaN
    aligned = monthly_means.reindex(shore_months).to_numpy(dtype=np.float64)
    if np.isnan(aligned).any():
        print("⚠️ Marine data does not cover every shoreline month; using shoreline features only")
        return None
    return aligned


def create_training_sequences(processor):
    """
    Build sliding-window training sequences and erosion-risk targets.

    For a sample of grid locations, windows of ``processor.sequence_length``
    steps are cut every 6 steps from the shoreline series. Each step carries
    eight shoreline features (the four shoreline variables, normalised lat/lon
    index, sine and cosine of the month). When marine data covering the same
    months is available, the nine monthly-mean marine features of the same
    time window are appended, giving 17 features per step.

    The target is derived from the mean ``erosion_rate`` over the 3 steps that
    follow the window: > 1.0 -> 2 (high), > 0.3 -> 1 (medium), else 0 (low).

    Args:
        processor: Object exposing ``shoreline_data`` (indexed as
            ``[time, lat, lon]``, with ``time``/``lat``/``lon`` coordinates),
            ``marine_data`` and ``sequence_length``.

    Returns:
        Tuple ``(sequences, targets)``: a float32 array of shape
        ``(n_windows, sequence_length, n_features)`` and an int64 array of
        shape ``(n_windows,)``.

    Raises:
        ValueError: If the shoreline data lacks a required variable or is too
            short to yield a single window.
    """
    print("🔄 Creating training sequences...")

    shore_data = processor.shoreline_data
    seq_len = processor.sequence_length
    window_stride = 6   # steps between consecutive window starts
    horizon = 3         # future steps averaged to derive the target

    absent = [name for name in _SHORELINE_VARIABLES if not hasattr(shore_data, name)]
    if absent:
        raise ValueError(f"Shoreline data is missing required variables: {absent}")

    time_index = pd.DatetimeIndex(np.asarray(shore_data.time))
    time_len = len(time_index)

    # Cyclical month encoding, computed once for the whole time axis
    month_angle = 2 * np.pi * time_index.month.to_numpy() / 12
    month_sin, month_cos = np.sin(month_angle), np.cos(month_angle)

    # Sample spatial locations; only the first 5 x 10 sampled indices are used (reduced for memory)
    n_lat, n_lon = len(shore_data.lat), len(shore_data.lon)
    lat_samples = np.linspace(0, n_lat - 1, 15, dtype=int)[:5]
    lon_samples = np.linspace(0, n_lon - 1, 30, dtype=int)[:10]

    # The stop value leaves at least one step after every window for the target
    window_starts = range(0, time_len - seq_len, window_stride)

    marine_matrix = _monthly_marine_matrix(processor.marine_data, time_index.to_period('M'))

    shoreline_sequences = []
    sequence_starts = []   # window start of every sequence, used to align marine data
    all_targets = []

    for lat_idx in lat_samples:
        for lon_idx in lon_samples:
            # Read each variable's full time series once per location instead of
            # indexing the dataset element by element.
            location_series = [
                np.asarray(getattr(shore_data, name)[:, lat_idx, lon_idx], dtype=np.float64)
                for name in _SHORELINE_VARIABLES
            ]
            erosion_rate = location_series[1]

            location_features = np.column_stack(location_series + [
                np.full(time_len, lat_idx / n_lat),   # normalized coordinates
                np.full(time_len, lon_idx / n_lon),
                month_sin,
                month_cos,
            ])

            for start_idx in window_starts:
                end_idx = start_idx + seq_len
                shoreline_sequences.append(location_features[start_idx:end_idx])
                sequence_starts.append(start_idx)

                if end_idx + horizon <= time_len:
                    # nanmean mirrors xarray's NaN-skipping mean
                    future_erosion = float(np.nanmean(erosion_rate[end_idx:end_idx + horizon]))
                    if future_erosion > 1.0:
                        target = 2  # High risk
                    elif future_erosion > 0.3:
                        target = 1  # Medium risk
                    else:
                        target = 0  # Low risk
                else:
                    # NOTE(review): windows without a full future horizon are labelled
                    # low risk, as before; consider dropping them instead.
                    target = 0

                all_targets.append(target)

    if not shoreline_sequences:
        raise ValueError(
            f"No training windows could be built: the shoreline series has "
            f"{time_len} time steps but sequence_length is {seq_len}"
        )

    sequences = np.asarray(shoreline_sequences, dtype=np.float32)
    targets = np.asarray(all_targets, dtype=np.int64)

    # Append the marine features of the same time window to every shoreline window
    n_marine_sequences = 0
    if marine_matrix is not None:
        marine_windows = np.stack(
            [marine_matrix[start:start + seq_len] for start in sequence_starts]
        ).astype(np.float32)
        sequences = np.concatenate([sequences, marine_windows], axis=-1)
        n_marine_sequences = len(marine_windows)

    print(f"Shoreline sequences: {len(shoreline_sequences)}")
    print(f"Marine sequences: {n_marine_sequences}")
    print(f"✅ Created {len(sequences)} sequences with shape {sequences.shape}")
    print(f"Target distribution: {np.bincount(targets)}")

    return sequences, targets

# Create sequences
sequences, targets = create_training_sequences(processor)
n_samples, seq_len, n_features = sequences.shape

# Decide the train/validation membership BEFORE normalising, with a fixed seed
# so the split is the same on every run.
train_size = int(0.8 * n_samples)
val_size = n_samples - train_size
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(n_samples, generator=split_generator).tolist()
train_indices = shuffled_indices[:train_size]
val_indices = shuffled_indices[train_size:]

# Data normalization: fit the scaler on training samples only, so that
# validation statistics do not leak into training; then transform everything.
scaler = StandardScaler()
scaler.fit(sequences[train_indices].reshape(-1, n_features))
sequences_reshaped = sequences.reshape(-1, n_features)
sequences_normalized = scaler.transform(sequences_reshaped)
sequences = sequences_normalized.reshape(n_samples, seq_len, n_features)

# Create datasets
dataset = ErosionDataset(sequences, targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"📊 Training set: {len(train_dataset)} samples")
print(f"📊 Validation set: {len(val_dataset)} samples")
print(f"📊 Features per timestep: {n_features}")
print(f"📊 Sequence length: {seq_len}")


In [None]:
# Display names of the erosion-risk classes, indexed by class id
_RISK_CLASS_NAMES = ['Low Risk', 'Medium Risk', 'High Risk']


def _evaluate_classifier(model, loader, criterion=None):
    """
    Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        Tuple ``(mean_loss, accuracy, predictions, labels)``; ``mean_loss`` is
        0.0 when no ``criterion`` is given.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss, n_correct, n_seen = 0.0, 0, 0
    predictions, labels = [], []

    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)

            if criterion is not None:
                # Weight by batch size so the mean is per sample, not per batch
                total_loss += criterion(outputs, targets).item() * data.size(0)

            pred = outputs.argmax(dim=1)
            n_correct += pred.eq(targets).sum().item()
            n_seen += data.size(0)

            predictions.extend(pred.cpu().numpy())
            labels.extend(targets.cpu().numpy())

    if n_seen == 0:
        raise ValueError("Evaluation loader yielded no samples")

    return total_loss / n_seen, n_correct / n_seen, predictions, labels


def _architecture_summary(model, input_shape):
    """
    Build the text of the 'Model Summary' panel from the model's own layers.

    ``input_shape`` is the per-sample input shape, expected to end in
    ``(sequence_length, n_features)``.
    """
    conv_filters = [m.out_channels for m in model.modules() if isinstance(m, nn.Conv1d)]
    lstm = next((m for m in model.modules() if isinstance(m, nn.LSTM)), None)
    attention = next((m for m in model.modules() if isinstance(m, nn.MultiheadAttention)), None)

    lines = ['Model Architecture:', '']
    if conv_filters:
        lines.append(f'• CNN Filters: {conv_filters}')
    if lstm is not None:
        direction = ' (Bidirectional)' if lstm.bidirectional else ''
        lines.append(f'• LSTM Hidden: {lstm.hidden_size}{direction}')
    if attention is not None:
        lines.append(f'• Multi-head Attention: {attention.num_heads} heads')
    lines.append(f'• Parameters: {sum(p.numel() for p in model.parameters()):,}')
    if len(input_shape) >= 2:
        lines.append(f'• Input Features: {input_shape[-1]}')
        lines.append(f'• Sequence Length: {input_shape[-2]}')
    return '\n'.join(lines)


def train_hybrid_model(model, train_loader, val_loader, epochs=50, lr=1e-3):
    """
    Train the Hybrid CNN-LSTM model and report validation metrics and plots.

    Uses cross-entropy loss, AdamW, gradient clipping to norm 1.0, a
    ReduceLROnPlateau schedule driven by validation accuracy and early
    stopping after 10 epochs without improvement. The best weights are saved
    to ``best_hybrid_cnn_lstm.pth`` and restored before the final evaluation.

    Args:
        model: Classifier to train; moved to the notebook-level ``device``.
        train_loader: Loader yielding ``(inputs, class_ids)`` training batches.
        val_loader: Loader yielding ``(inputs, class_ids)`` validation batches.
        epochs: Maximum number of epochs.
        lr: Initial learning rate.

    Returns:
        Tuple ``(model, history)`` where ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'val_accuracy'`` and ``'learning_rates'`` to
        per-epoch lists.

    Raises:
        ValueError: If either loader yields no samples.
    """
    checkpoint_path = 'best_hybrid_cnn_lstm.pth'

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # `verbose` is deprecated in recent PyTorch releases; the learning rate is
    # printed every epoch below anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=7, factor=0.5
    )

    # Training history
    history = {
        'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'learning_rates': []
    }

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a stale file from an earlier run could be restored.
    best_val_acc = float('-inf')
    checkpoint_saved = False
    patience_counter = 0
    patience = 10
    input_shape = ()  # per-sample input shape, captured from the training batches

    print("🚀 Starting Hybrid CNN-LSTM training...")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_samples = 0

        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)
            input_shape = tuple(data.shape[1:])

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item() * data.size(0)
            train_samples += data.size(0)

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}')

        if train_samples == 0:
            raise ValueError("train_loader yielded no samples")

        # Validation phase
        avg_val_loss, val_accuracy, _, _ = _evaluate_classifier(model, val_loader, criterion)
        avg_train_loss = train_loss / train_samples

        # Update learning rate (the scheduler maximises validation accuracy)
        scheduler.step(val_accuracy)
        current_lr = optimizer.param_groups[0]['lr']

        # Store history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_accuracy'].append(val_accuracy)
        history['learning_rates'].append(current_lr)

        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}')
        print(f'  Val Accuracy: {val_accuracy:.4f}')
        print(f'  Learning Rate: {current_lr:.6f}')

        # Early stopping
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f'  ✅ New best validation accuracy: {best_val_acc:.4f}')
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping after {epoch+1} epochs')
            break

        print('-' * 60)

    # Restore the best weights, but only if this run actually wrote them
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Final evaluation
    _, _, all_preds, all_labels = _evaluate_classifier(model, val_loader)

    # Generate comprehensive evaluation report
    print("\n" + "="*60)
    print("🎯 FINAL MODEL EVALUATION")
    print("="*60)

    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    print(f"Final Validation Accuracy: {accuracy_score(all_labels, all_preds):.4f}")

    # Passing the label ids explicitly keeps the report and matrix at three
    # classes even when a class is absent from the validation data.
    class_ids = list(range(len(_RISK_CLASS_NAMES)))

    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=_RISK_CLASS_NAMES, zero_division=0))

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=_RISK_CLASS_NAMES,
                yticklabels=_RISK_CLASS_NAMES)
    plt.title('Confusion Matrix - Hybrid CNN-LSTM')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # Training curves
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    ax1.plot(history['train_loss'], label='Training Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_title('Loss Curves')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # Accuracy curve
    ax2.plot(history['val_accuracy'], label='Validation Accuracy', color='green')
    ax2.set_title('Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True)

    # Learning rate
    ax3.plot(history['learning_rates'], label='Learning Rate', color='red')
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.legend()
    ax3.grid(True)

    # Text panel describing the trained model (read from the model itself)
    ax4.text(0.5, 0.5, _architecture_summary(model, input_shape),
             transform=ax4.transAxes, fontsize=12,
             verticalalignment='center', horizontalalignment='center',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    ax4.set_title('Model Summary')
    ax4.axis('off')

    plt.tight_layout()
    plt.show()

    return model, history

# Build the three-class (Low / Medium / High erosion risk) classifier on the
# feature count derived from the prepared sequences
print("🏗️ Initializing SOTA Hybrid CNN-LSTM model...")
model = HybridCNNLSTM(n_features=n_features, n_classes=3, cnn_filters=[32, 64, 128],
                      lstm_hidden=128, lstm_layers=2, dropout=0.15)

total_parameters = sum(weight.numel() for weight in model.parameters())
print(f"Model created with {total_parameters:,} parameters")

# Run training; returns the trained model and the per-epoch history
training_result = train_hybrid_model(model, train_loader, val_loader, epochs=50, lr=1e-3)
trained_model, training_history = training_result
