In [21]:
# ==============================================================================
# CELL 1: Load Trained Model from Checkpoint
# ==============================================================================

import os
import sys
import torch
import torch.nn as nn
import numpy as np
import wfdb

# Make the local ESI repository importable so `model.convnextv2` resolves.
# The path is only added once, and only when the dataset is actually mounted.
esi_path = "/kaggle/input/esi-repo/ESI"  # Adjust based on your ESI input dataset
if esi_path not in sys.path and os.path.exists(esi_path):
    sys.path.append(esi_path)

from model.convnextv2 import convnextv2_base

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"üñ•Ô∏è Using device: {device}")

# ==============================================================================
# Model Architecture (MUST match training)
# ==============================================================================

class MultiScaleTemporalHead(nn.Module):
    """Classification head mapping one backbone embedding per sample to a logit.

    The embedding is (a) run through a bidirectional LSTM followed by an
    attention-style weighted sum and (b) average- and max-pooled. The three
    results are concatenated, clamped to [-10, 10] and passed through a small
    MLP that outputs a single logit.

    NOTE(review): ``forward`` feeds the LSTM a sequence of length 1, so the
    softmax attention weights are always 1.0 and both adaptive pools act as
    identities. The concatenated features are therefore effectively
    ``[lstm_out, x, x]``. This is kept as-is because trained checkpoints
    depend on this exact layer layout.

    Args:
        input_dim: Size of the incoming embedding (last dimension of ``x``).
        hidden_dim: LSTM hidden size per direction; also the first MLP width.
        lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability in the MLP (halved before the last
            layer) and between LSTM layers when ``lstm_layers > 1``.
    """

    def __init__(self, input_dim=1024, hidden_dim=256, lstm_layers=2, dropout=0.3):
        super().__init__()

        # Honour `lstm_layers` (it used to be hard-coded to 2 while the dropout
        # condition below already read `lstm_layers`). With the default of 2,
        # existing checkpoints load unchanged.
        self.lstm = nn.LSTM(
            input_dim, hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies inter-layer dropout with >1 layer.
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Scores each time step; softmax over the sequence dimension (dim=1).
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1)
        )

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        # Input: attended LSTM output (2 * hidden_dim, bidirectional) plus the
        # avg- and max-pooled embedding (input_dim each).
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2 + input_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear layers (gain 0.5), orthogonal LSTM weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0)

    def forward(self, x):
        """Return logits of shape (batch, 1) for embeddings ``x`` of shape (batch, input_dim)."""
        # Tiny offset carried over from training (presumably to avoid exact
        # zeros); kept so inference matches the trained model.
        x = x + 1e-8
        x_seq = x.unsqueeze(1)                      # (batch, 1, input_dim)
        lstm_out, _ = self.lstm(x_seq)              # (batch, 1, 2 * hidden_dim)
        attn_weights = self.attention(lstm_out)     # (batch, 1, 1)
        attended = (lstm_out * attn_weights).sum(dim=1)

        x_unsqueezed = x.unsqueeze(2)               # (batch, input_dim, 1)
        avg_pooled = self.global_pool(x_unsqueezed).squeeze(2)
        max_pooled = self.max_pool(x_unsqueezed).squeeze(2)

        combined = torch.cat([attended, avg_pooled, max_pooled], dim=1)
        # Bound feature magnitudes before the MLP.
        combined = torch.clamp(combined, -10, 10)
        logits = self.classifier(combined)

        return logits


class ESIApneaDetector(nn.Module):
    """Single-lead ECG apnea classifier built on a frozen ConvNeXtV2 ESI backbone.

    The single input lead is expanded to 12 channels by a learned 1x1
    convolution (``lead_projection``). The result goes through the frozen
    backbone in embedding mode, and ``MultiScaleTemporalHead`` turns the
    embedding into one logit per sample.

    Args:
        checkpoint_path: Path to a pretrained ESI ``.safetensors`` file. Only
            keys containing ``img_encoder.`` are loaded (with that prefix
            stripped) into the backbone. Loading is best-effort.
        config: Object exposing ``ESI_EMBEDDING_DIM``, ``HIDDEN_DIM``,
            ``LSTM_LAYERS`` and ``DROPOUT`` for the head.
    """

    def __init__(self, checkpoint_path, config):
        super().__init__()

        self.backbone = convnextv2_base(
            in_chans=12,
            num_classes=5,
            return_embedding=True
        )
        self._load_esi_weights(checkpoint_path)

        # 1x1 conv expanding the single input lead to the 12 channels the
        # backbone was built with.
        self.lead_projection = nn.Sequential(
            nn.Conv1d(1, 12, kernel_size=1),
            nn.BatchNorm1d(12)
        )

        nn.init.xavier_uniform_(self.lead_projection[0].weight, gain=0.5)

        # Freeze the backbone; only the projection and head are trainable.
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.head = MultiScaleTemporalHead(
            input_dim=config.ESI_EMBEDDING_DIM,
            hidden_dim=config.HIDDEN_DIM,
            lstm_layers=config.LSTM_LAYERS,
            dropout=config.DROPOUT
        )

    def _load_esi_weights(self, checkpoint_path):
        """Best-effort load of ``img_encoder.*`` weights into ``self.backbone``.

        Problems are printed, not raised, so the model can still be built
        (e.g. when a full trained state dict is loaded afterwards).
        """
        if not os.path.exists(checkpoint_path):
            print(f"‚ö†Ô∏è ESI checkpoint not found, backbone weights not loaded: {checkpoint_path}")
            return
        try:
            from safetensors.torch import load_file
            state_dict = load_file(checkpoint_path, device="cpu")
            clean_weights = {
                k.replace("img_encoder.", ""): v
                for k, v in state_dict.items()
                if "img_encoder." in k
            }
            # strict=False silently tolerates key mismatches, so report them.
            result = self.backbone.load_state_dict(clean_weights, strict=False)
            print("‚úÖ ESI weights loaded")
            if result.missing_keys or result.unexpected_keys:
                print(
                    f"‚ö†Ô∏è ESI weight mismatch: {len(result.missing_keys)} missing, "
                    f"{len(result.unexpected_keys)} unexpected keys"
                )
        except Exception as e:
            print(f"‚ö†Ô∏è Could not load ESI weights: {e}")

    def forward(self, x):
        """Return logits for a batch of single-lead signals.

        Args:
            x: Tensor of shape (batch, 1, length); ``Conv1d(1, ...)`` followed by
                ``permute(0, 2, 1)`` requires this layout.

        Returns:
            Head output, shape (batch, 1).
        """
        # Replace NaN samples (e.g. missing ECG values) so they cannot propagate.
        if torch.isnan(x).any():
            x = torch.nan_to_num(x, nan=0.0)

        x_multi = self.lead_projection(x)           # (batch, 12, length)
        x_multi = torch.clamp(x_multi, -10, 10)
        # NOTE(review): the backbone receives (batch, length, 12), i.e.
        # channels-last; presumably what this ConvNeXtV2 variant expects. Confirm.
        x_multi = x_multi.permute(0, 2, 1)
        features = self.backbone(x_multi)

        if torch.isnan(features).any():
            features = torch.nan_to_num(features, nan=0.0)

        logits = self.head(features)
        return logits


# ==============================================================================
# Load Model Function
# ==============================================================================

def load_trained_model(checkpoint_path, esi_checkpoint_path, device):
    """
    Load trained model from checkpoint

    Args:
        checkpoint_path: Path to best_model.pt. This is a dict holding
            ``model_state_dict`` and optionally ``config``, ``normalization``,
            ``epoch`` and ``val_loss``.
        esi_checkpoint_path: Path to ESI model.safetensors
        device: torch device

    Returns:
        On success, (model, config, normalization_params). ``model`` is on
        ``device`` in eval mode. ``normalization_params`` is
        ``{'mean': float, 'std': float}`` made of builtin floats (defaults 0.0
        and 1.0), which keeps it JSON-serializable. If any step fails, the
        traceback is printed and (None, None, None) is returned.
    """
    print("\n" + "="*70)
    print("üì• LOADING TRAINED MODEL")
    print("="*70)

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code. Only load checkpoints you produced or trust.
        # PyTorch 2.6+ defaults to weights_only=True; False keeps the
        # checkpoint's non-tensor metadata loading as before.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # `or {}` also covers keys that are present but stored as None.
        model_config = checkpoint.get('config') or {}
        raw_normalization = checkpoint.get('normalization') or {}
        # Coerce to builtin floats. Numpy/torch scalars would otherwise reach
        # downstream JSON responses and may not serialize.
        normalization = {
            'mean': float(raw_normalization.get('mean', 0.0)),
            'std': float(raw_normalization.get('std', 1.0)),
        }

        print(f"‚úÖ Checkpoint loaded from epoch {checkpoint.get('epoch', 'unknown')}")
        print(f"‚úÖ Validation loss: {checkpoint.get('val_loss', 'N/A')}")

        # Hyper-parameters must match training; fall back to the training defaults.
        class InferenceConfig:
            SEQ_LENGTH = model_config.get('seq_length', 6000)
            ESI_EMBEDDING_DIM = model_config.get('esi_embedding_dim', 1024)
            HIDDEN_DIM = model_config.get('hidden_dim', 256)
            LSTM_LAYERS = model_config.get('lstm_layers', 2)
            DROPOUT = model_config.get('dropout', 0.3)
            DEVICE = device

        config = InferenceConfig()

        # Initialize model
        print("\nüèóÔ∏è Initializing model architecture...")
        model = ESIApneaDetector(esi_checkpoint_path, config)

        # Load trained weights (strict: every key must match the architecture)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        print(f"\nüìä Model Configuration:")
        print(f"   Sequence Length:    {config.SEQ_LENGTH}")
        print(f"   ESI Embedding Dim:  {config.ESI_EMBEDDING_DIM}")
        print(f"   Hidden Dimension:   {config.HIDDEN_DIM}")
        print(f"   LSTM Layers:        {config.LSTM_LAYERS}")
        print(f"\nüî¢ Normalization Parameters:")
        print(f"   Mean: {normalization['mean']:.6f}")
        print(f"   Std:  {normalization['std']:.6f}")
        print("="*70)

        return model, config, normalization

    except Exception as e:
        print(f"‚ùå Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None


# ==============================================================================
# Load Your Trained Model
# ==============================================================================

# Checkpoint locations on the attached Kaggle input datasets.
BEST_MODEL_PATH = "/kaggle/input/model-checkpoints-apnea/outputs/checkpoints/best_model.pt"
ESI_CHECKPOINT_PATH = "/kaggle/input/esi-model/model.safetensors"

# Build the model and fetch its normalization statistics. Stop the notebook
# right away on failure so later cells never run against a missing model.
model, config, normalization_params = load_trained_model(
    checkpoint_path=BEST_MODEL_PATH,
    esi_checkpoint_path=ESI_CHECKPOINT_PATH,
    device=device,
)
if model is None:
    raise RuntimeError("Failed to load model!")

print("\n‚úÖ Model loaded successfully and ready for inference!")

🖥️ Using device: cpu

📥 LOADING TRAINED MODEL
✅ Checkpoint loaded from epoch 3
✅ Validation loss: 0.41083619382345316

🏗️ Initializing model architecture...
✅ ESI weights loaded

📊 Model Configuration:
   Sequence Length:    6000
   ESI Embedding Dim:  1024
   Hidden Dimension:   256
   LSTM Layers:        2

🔢 Normalization Parameters:
   Mean: 0.000000
   Std:  1.000000

✅ Model loaded successfully and ready for inference!


In [None]:
# ==============================================================================
# CELL 2: Flask API with Ngrok Deployment
# ==============================================================================

import os
import sys

# Install dependencies
print("üì¶ Installing dependencies...")
# Kill any ngrok process left over from a previous run of this cell
# (stderr is discarded when none is running).
os.system("killall ngrok 2>/dev/null")
# Use the kernel's own interpreter so packages land in the environment this
# notebook imports from; a bare `pip` on PATH may belong to another Python.
os.system(f'"{sys.executable}" -m pip install -q flask flask-cors pyngrok')

from flask import Flask, request, jsonify
from flask_cors import CORS
from pyngrok import ngrok, conf
from kaggle_secrets import UserSecretsClient
import numpy as np
import torch
import wfdb
from werkzeug.utils import secure_filename
from datetime import datetime
import traceback

# ==============================================================================
# Ngrok Authentication
# ==============================================================================

print("\nüîê Setting up Ngrok...")
try:
    user_secrets = UserSecretsClient()
    NGROK_AUTH_TOKEN = user_secrets.get_secret("NGROK_AUTH_TOKEN")

    # Optional: a static domain. Always defined so later code can reference it.
    NGROK_STATIC_DOMAIN = None
    try:
        NGROK_STATIC_DOMAIN = user_secrets.get_secret("NGROK_STATIC_DOMAIN")
    except Exception:
        # Secret not configured or not readable; a dynamic URL is used instead.
        # (`Exception`, not a bare except, so Ctrl-C/SystemExit still propagate.)
        pass

    # An empty secret is treated the same as a missing one.
    USE_STATIC_DOMAIN = bool(NGROK_STATIC_DOMAIN)
    if USE_STATIC_DOMAIN:
        print(f"‚úÖ Static domain found: {NGROK_STATIC_DOMAIN}")
    else:
        print("‚ÑπÔ∏è No static domain configured (using dynamic URL)")

    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
    conf.get_default().region = "us"
    print("‚úÖ Ngrok authenticated")

except Exception as e:
    print(f"‚ùå Ngrok setup failed: {e}")
    print("\nüìù To fix this:")
    print("   1. Go to https://dashboard.ngrok.com/")
    print("   2. Copy your authtoken")
    print("   3. Add it as a Kaggle Secret:")
    print("      - Key: NGROK_AUTH_TOKEN")
    print("      - Value: <your-token>")
    raise

# ==============================================================================
# Flask App Setup
# ==============================================================================

# Scratch space for uploaded WFDB record files (.hea header + .dat signal).
UPLOAD_FOLDER = '/tmp/ecg_uploads'
ALLOWED_EXTENSIONS = {'hea', 'dat'}
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Flask application with flask-cors applied using its default settings.
app = Flask(__name__)
CORS(app)

# ==============================================================================
# Helper Functions
# ==============================================================================

def allowed_file(filename):
    """Check if file extension is allowed

    The extension is the text after the last '.', compared case-insensitively
    against ``ALLOWED_EXTENSIONS``. Names without a '.' are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS


def extract_demographics(record_path):
    """Extract patient demographics from WFDB header

    Args:
        record_path: Record path without extension, as accepted by
            ``wfdb.rdheader``.

    Returns:
        dict with ``record_name``, ``sampling_frequency``, ``signal_length``,
        ``duration`` (formatted strings) and ``number_of_signals``. It may also
        contain ``age`` and/or ``gender``, each holding the raw header comment
        that mentions it. If no comment matches, the key is absent; if several
        match, the last one wins. On any failure the result is
        ``{'error': <message>}``.
    """
    import re

    try:
        header = wfdb.rdheader(record_path)
        demographics = {
            'record_name': header.record_name,
            'sampling_frequency': f"{header.fs} Hz",
            'signal_length': f"{header.sig_len} samples",
            'duration': f"{header.sig_len / header.fs:.1f} seconds",
            'number_of_signals': header.n_sig,
        }

        # Extract additional info from comments
        for comment in getattr(header, 'comments', None) or []:
            # Independent checks (not if/elif): a single comment may carry
            # both age and sex. Whole-word matching avoids false hits such as
            # "image" or "page" for "age".
            if re.search(r'\bage\b', comment, re.IGNORECASE):
                demographics['age'] = comment
            if re.search(r'\b(?:sex|gender)\b', comment, re.IGNORECASE):
                demographics['gender'] = comment

        return demographics
    except Exception as e:
        return {'error': str(e)}


def generate_waveform_points(signal, target_points=800, width=800):
    """Generate SVG polyline points for waveform visualization

    The signal is decimated by a constant stride so that at most
    ``target_points`` points remain. Values are scaled so the finite min/max
    map to y=90/y=10; the range is inverted because SVG y grows downward.
    The x coordinates run from 0 up to ``width``. Non-finite samples
    (NaN/inf) are drawn on the midline (y=50), and so is a flat signal.

    Args:
        signal: 1-D array-like of samples.
        target_points: Upper bound on the number of emitted points (values < 1
            are treated as 1).
        width: Width of the SVG coordinate space.

    Returns:
        Space-separated ``"x,y"`` pairs with two decimals. Returns a flat
        midline ``"0,50 <width>,50"`` for empty input or on any error.
    """
    fallback = f"0,50 {width},50"
    try:
        values = np.asarray(signal, dtype=float)
        if values.size == 0:
            return fallback

        # Ceil division keeps the output at most `target_points` long
        # (floor division could emit nearly twice as many).
        step = max(1, -(-len(values) // max(1, target_points)))
        signal_viz = values[::step]

        finite = np.isfinite(signal_viz)
        if not finite.any():
            signal_normalized = np.full(signal_viz.shape, 50.0)
        else:
            signal_min = signal_viz[finite].min()
            signal_max = signal_viz[finite].max()
            if abs(signal_max - signal_min) < 1e-6:
                signal_normalized = np.full(signal_viz.shape, 50.0)
            else:
                signal_normalized = 100 - ((signal_viz - signal_min) / (signal_max - signal_min) * 80 + 10)
                # NaN/inf would otherwise become "nan"/"inf" and break the SVG.
                signal_normalized = np.where(finite, signal_normalized, 50.0)

        x_scale = width / len(signal_normalized)
        return " ".join(f"{i * x_scale:.2f},{y:.2f}" for i, y in enumerate(signal_normalized))
    except Exception as e:
        print(f"‚ö†Ô∏è Waveform generation error: {e}")
        return fallback


def preprocess_signal(signal, mean, std, seq_length=6000):
    """
    Turn a raw ECG array into the model's input tensor.

    Only the first channel of a multi-channel array is kept. The trace is
    repeat-padded with its last sample (or truncated) to ``seq_length``,
    z-normalised with the supplied statistics, and wrapped as float32.

    Args:
        signal: Raw ECG signal (numpy array), 1-D or (samples, channels)
        mean: Normalization mean
        std: Normalization std (1e-8 is added to avoid division by zero)
        seq_length: Target sequence length

    Returns:
        Tensor of shape [1, 1, seq_length], or None if any step raised
        (the error is printed).
    """
    try:
        channel = signal[:, 0] if signal.ndim > 1 else signal

        shortfall = seq_length - len(channel)
        if shortfall > 0:
            channel = np.pad(channel, (0, shortfall), mode='edge')
        else:
            channel = channel[:seq_length]

        scaled = (channel - mean) / (std + 1e-8)

        # Prepend batch and channel axes -> [batch, channels, length]
        return torch.tensor(scaled, dtype=torch.float32)[None, None]

    except Exception as e:
        print(f"‚ùå Preprocessing error: {e}")
        return None


def run_inference(model, signal_tensor, device):
    """
    Score one preprocessed ECG window and describe the result.

    Args:
        model: Trained model returning a single logit
        signal_tensor: Preprocessed signal tensor
        device: torch device the input is moved to

    Returns:
        dict with ``has_apnea`` (probability > 0.5), ``probability`` (sigmoid
        of the logit), ``raw_logit``, ``risk_level``/``risk_color`` (bands
        above 0.7 and above 0.4, else low) and ``diagnosis``. Returns None if
        anything raised; the traceback is printed.
    """
    # (lower bound, label, UI colour), checked from the highest band down.
    risk_bands = (
        (0.7, "High Risk", "danger"),
        (0.4, "Moderate Risk", "warning"),
    )

    try:
        model.eval()

        with torch.no_grad():
            logits = model(signal_tensor.to(device))
            raw_logit = logits.cpu().item()
            probability = torch.sigmoid(logits).cpu().item()

        has_apnea = probability > 0.5
        risk_level, risk_color = next(
            ((label, colour) for bound, label, colour in risk_bands if probability > bound),
            ("Low Risk", "success"),
        )

        result = {
            'has_apnea': bool(has_apnea),
            'probability': float(probability),
            'raw_logit': float(raw_logit),
            'risk_level': risk_level,
            'risk_color': risk_color,
            'diagnosis': 'Apnea Detected' if has_apnea else 'Normal Breathing',
        }

        rule = '=' * 50
        print(f"\n{rule}")
        print(f"üîç INFERENCE RESULT:")
        print(rule)
        print(f"Diagnosis:    {result['diagnosis']}")
        print(f"Probability:  {probability:.4f}")
        print(f"Risk Level:   {risk_level}")
        print(f"{rule}\n")

        return result

    except Exception as e:
        print(f"‚ùå Inference error: {e}")
        traceback.print_exc()
        return None


# ==============================================================================
# API Endpoints
# ==============================================================================

@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness plus the model, device and normalization in use."""
    status = {
        'status': 'healthy',
        'model_loaded': model is not None,
        'device': str(device),
        'gpu_available': torch.cuda.is_available(),
        'normalization': normalization_params,
    }
    status['timestamp'] = datetime.utcnow().isoformat()
    return jsonify(status)


def _find_upload(files, extension):
    """Return the first upload whose name ends with *extension* (case-insensitive), else None."""
    return next(
        (f for f in files if (f.filename or '').lower().endswith(extension)),
        None,
    )


@app.route('/predict', methods=['POST'])
def predict():
    """
    Main prediction endpoint

    Expected: POST multipart request whose ``files`` field carries one .hea
    header and one .dat signal file of the same WFDB record.
    Returns: JSON with prediction results, demographics, and waveform data.
    Status 400 for missing or invalid uploads, 500 for processing failures.

    Each request works in its own temporary directory under UPLOAD_FOLDER.
    That directory is removed on every exit path, success or error, so
    same-named uploads from concurrent requests cannot overwrite each other
    and failed requests leave nothing behind.
    """
    import shutil
    import tempfile

    work_dir = None
    try:
        # Validate model is loaded
        if model is None:
            return jsonify({
                'success': False,
                'error': 'Model not loaded'
            }), 500

        # Validate files in request
        if 'files' not in request.files:
            return jsonify({
                'success': False,
                'error': 'No files uploaded'
            }), 400

        files = request.files.getlist('files')

        # Find .hea and .dat files (case-insensitive, like allowed_file)
        hea_file = _find_upload(files, '.hea')
        dat_file = _find_upload(files, '.dat')

        if not hea_file or not dat_file:
            return jsonify({
                'success': False,
                'error': 'Both .hea and .dat files required'
            }), 400

        # The header's (sanitised) base name names both saved files.
        base_name = secure_filename(hea_file.filename).rsplit('.', 1)[0]
        if not base_name:
            return jsonify({
                'success': False,
                'error': 'Invalid record file name'
            }), 400

        work_dir = tempfile.mkdtemp(prefix='req_', dir=UPLOAD_FOLDER)
        hea_path = os.path.join(work_dir, f"{base_name}.hea")
        dat_path = os.path.join(work_dir, f"{base_name}.dat")

        hea_file.save(hea_path)
        dat_file.save(dat_path)

        print(f"\nüìÅ Processing record: {base_name}")

        # Load ECG record
        record_path = os.path.join(work_dir, base_name)
        demographics = extract_demographics(record_path)

        # wfdb rejects a `sampto` past the end of the record, so clamp it.
        # Shorter records are then padded by preprocess_signal.
        header = wfdb.rdheader(record_path)
        sampto = config.SEQ_LENGTH
        if header.sig_len:
            sampto = min(sampto, header.sig_len)
        signal, _ = wfdb.rdsamp(record_path, sampfrom=0, sampto=sampto)

        # Generate waveform visualization from the first channel
        raw_signal = signal[:, 0] if signal.ndim > 1 else signal
        waveform_points = generate_waveform_points(raw_signal)

        # Preprocess signal
        signal_tensor = preprocess_signal(
            signal,
            normalization_params['mean'],
            normalization_params['std'],
            config.SEQ_LENGTH
        )

        if signal_tensor is None:
            return jsonify({
                'success': False,
                'error': 'Signal preprocessing failed'
            }), 500

        # Run inference
        prediction_result = run_inference(model, signal_tensor, device)

        if prediction_result is None:
            return jsonify({
                'success': False,
                'error': 'Model inference failed'
            }), 500

        response = {
            'success': True,
            'prediction': prediction_result,
            'demographics': demographics,
            'waveform': waveform_points,
            'timestamp': datetime.utcnow().isoformat()
        }

        print(f"‚úÖ Analysis complete: {prediction_result['diagnosis']}")

        return jsonify(response), 200

    except Exception as e:
        print(f"\n‚ùå ERROR in /predict:")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

    finally:
        # Cleanup uploaded files on every path (early returns included).
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)


@app.route('/', methods=['GET'])
def home():
    """Describe the API and list the routes it serves."""
    endpoints = {
        '/health': 'GET - Check API health',
        '/predict': 'POST - Predict sleep apnea from ECG files',
    }
    info = {
        'message': 'Sleep Apnea Detection API',
        'version': '1.0',
        'endpoints': endpoints,
        'usage': 'Send .hea and .dat files to /predict endpoint',
    }
    return jsonify(info)


# ==============================================================================
# Start Server
# ==============================================================================

def start_server():
    """Start Flask server with Ngrok tunnel

    Opens an HTTPS ngrok tunnel to local port 5000. It uses
    ``NGROK_STATIC_DOMAIN`` when ``USE_STATIC_DOMAIN`` is true and a randomly
    assigned URL otherwise. The call then blocks in ``app.run``. The tunnel
    is disconnected however the server stops (normal exit, error, or
    interrupt), so re-running the cell does not leave a stale tunnel.
    Returns early if no model is loaded.
    """
    print("\n" + "="*70)
    print("üöÄ STARTING INFERENCE SERVER")
    print("="*70)

    if model is None:
        print("‚ùå ERROR: Model not loaded!")
        return

    tunnel = None
    try:
        port = 5000

        # Create ngrok tunnel
        if USE_STATIC_DOMAIN:
            # Use static domain (requires paid ngrok plan)
            tunnel = ngrok.connect(
                port,
                bind_tls=True,
                hostname=NGROK_STATIC_DOMAIN
            )
            display_url = f"https://{NGROK_STATIC_DOMAIN}"
        else:
            # Use dynamic URL (free tier)
            tunnel = ngrok.connect(port, bind_tls=True)
            display_url = tunnel.public_url

        print("\n" + "="*70)
        print("‚úÖ SERVER READY")
        print("="*70)
        print(f"üåç Public URL:       {display_url}")
        print(f"üîç Health Check:     {display_url}/health")
        print(f"üì° Predict Endpoint: {display_url}/predict")
        print("="*70)
        print("\nüìù Example cURL command:")
        print(f"""
curl -X POST {display_url}/predict \\
  -F "files=@path/to/record.hea" \\
  -F "files=@path/to/record.dat"
        """)
        print("="*70 + "\n")

        # Blocks until the server is stopped; the reloader is disabled
        # because it does not work inside a notebook kernel.
        app.run(host='0.0.0.0', port=port, debug=False, use_reloader=False)

    except Exception as e:
        print(f"\n‚ùå Server start failed: {e}")
        traceback.print_exc()

    finally:
        # Close the tunnel even on KeyboardInterrupt (not caught above).
        if tunnel is not None:
            try:
                ngrok.disconnect(tunnel.public_url)
            except Exception as e:
                print(f"‚ö†Ô∏è Could not close ngrok tunnel: {e}")


# ==============================================================================
# Auto-start if model is loaded
# ==============================================================================

# Start serving only when the model-loading cell has produced a model.
if globals().get('model') is None:
    print("\n‚ö†Ô∏è Model not found. Please run the model loading cell first.")
else:
    print("\n‚úÖ Model detected - starting server...")
    start_server()

📦 Installing dependencies...

🔐 Setting up Ngrok...
✅ Static domain found: merry-ewe-endlessly.ngrok-free.app
✅ Ngrok authenticated

✅ Model detected - starting server...

🚀 STARTING INFERENCE SERVER

✅ SERVER READY
🌍 Public URL:       https://merry-ewe-endlessly.ngrok-free.app
🔍 Health Check:     https://merry-ewe-endlessly.ngrok-free.app/health
📡 Predict Endpoint: https://merry-ewe-endlessly.ngrok-free.app/predict

📝 Example cURL command:

curl -X POST https://merry-ewe-endlessly.ngrok-free.app/predict \
  -F "files=@path/to/record.hea" \
  -F "files=@path/to/record.dat"
        

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.19.2.2:5000
Press CTRL+C to quit



📁 Processing record: a01


127.0.0.1 - - [14/Dec/2025 16:10:35] "POST /predict HTTP/1.1" 200 -



🔍 INFERENCE RESULT:
Diagnosis:    Apnea Detected
Probability:  0.5895
Risk Level:   Moderate Risk

✅ Analysis complete: Apnea Detected

📁 Processing record: a03


127.0.0.1 - - [14/Dec/2025 16:10:51] "POST /predict HTTP/1.1" 200 -



🔍 INFERENCE RESULT:
Diagnosis:    Normal Breathing
Probability:  0.1381
Risk Level:   Low Risk

✅ Analysis complete: Normal Breathing

📁 Processing record: a07


127.0.0.1 - - [14/Dec/2025 16:11:12] "POST /predict HTTP/1.1" 200 -



🔍 INFERENCE RESULT:
Diagnosis:    Apnea Detected
Probability:  0.6918
Risk Level:   Moderate Risk

✅ Analysis complete: Apnea Detected
