# ECG 12-Lead Reconstruction - Google Colab Training

**Paper**: [AI-enhanced reconstruction of the 12-lead ECG via 3-leads](https://www.nature.com/articles/s41746-024-01193-7)

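This notebook runs the original codebase on the **PTB-XL dataset**. The only code change is a small adapter (`util_functions/general_ptbxl.py`, copied over `util_functions/general.py`) that loads data from local pickle files instead of MongoDB.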

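## Steps:
0. Check GPU
1. Mount Google Drive & load the PTB-XL dataset
2. Get the project code & install dependencies
3. Load or convert the PTB-XL data
4. Configure (paper settings)
5. Initialize the model
6. Train the model
7. Test
8. Visualize
9. Save results to Google Drive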

## 0. Check GPU

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    DEVICE = 'cuda:0'
else:
    print("WARNING: No GPU! Go to Runtime -> Change runtime type -> GPU")
    DEVICE = 'cpu'

## 1. Mount Google Drive & Load PTB-XL Dataset

In [None]:
# Mount Google Drive at /content/drive; later cells read the PTB-XL ZIP and the
# converted-data cache from under this mount point.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Extract PTB-XL dataset from Google Drive
import os
from pathlib import Path

# Your dataset ZIP path on Google Drive
DRIVE_ZIP = '/content/drive/MyDrive/IMLE-Net-Project/data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1.zip'

# Extract to /content/data
!mkdir -p /content/data
!unzip -q -o "{DRIVE_ZIP}" -d /content/data/

# Set PTBXL_PATH
extracted = list(Path('/content/data').glob('ptb-xl-*'))
PTBXL_PATH = str(extracted[0]) if extracted else None
print(f"PTBXL_PATH = {PTBXL_PATH}")
!ls {PTBXL_PATH}

## 2. Get Project Code & Install Dependencies

In [None]:
# Clone from GitHub and install dependencies
import os
PROJECT_DIR = '/content/ecg_reconstruction'

if not os.path.exists(PROJECT_DIR):
    print("Cloning from GitHub...")
    !git clone https://github.com/scripps-research/ecg_reconstruction.git {PROJECT_DIR}
else:
    print(f"Project already exists at {PROJECT_DIR}")

# Install wfdb with compatible pandas version
!pip install -q "pandas<3.0" wfdb tqdm

# Fix NumPy 2.0 compatibility (np.infty -> np.inf)
import numpy as np
if not hasattr(np, 'infty'):
    np.infty = np.inf
    print("Applied NumPy 2.0 compatibility fix (np.infty = np.inf)")

%cd {PROJECT_DIR}
!ls -la

In [None]:
# Create general_ptbxl.py (adapts original code to use pickle instead of MongoDB)
# The text below is written verbatim to util_functions/general_ptbxl.py and later copied
# over util_functions/general.py. It is a plain (non-raw) triple-quoted literal, so the
# embedded code must not contain backslashes or a closing triple single quote.

general_ptbxl_code = '''
import json
import shutil
import pickle
import os
import pandas as pd
import numpy as np


class PickleCollection:
    """Read-only stand-in for the pymongo collection API, backed by a list of dicts.

    Supported queries, checked in this order:
      * None or {}                           -> every document
      * {'ElementID': id} or {'_id': id}     -> exact id lookup (any other keys are ignored)
      * {'ElementID': {'$in': [id, ...]}}    -> batch id lookup, in the order of the id list
      * {'$or': [subquery, ...]}             -> union of the sub-queries, without duplicates
      * {field: value, ...}                  -> documents where every field equals value
    Returned documents are shallow copies whose 'lead' numpy arrays are wrapped in pd.Series;
    the stored documents themselves are never modified.
    """

    def __init__(self, data_list):
        self.data = data_list
        # ElementID -> stored document, for O(1) id lookups.
        self.index = {item['ElementID']: item for item in data_list}

    def _lookup(self, value):
        """Return stored documents for a plain id or a {'$in': [...]} operator; unknown ids are skipped."""
        if isinstance(value, dict):
            if '$in' not in value:
                raise NotImplementedError(
                    f"Unsupported id operator(s) {sorted(value)}: only $in is implemented")
            ids = value['$in']
        else:
            ids = [value]
        return [self.index[eid] for eid in ids if eid in self.index]

    def _match(self, query):
        """Return the stored (unconverted) documents matching query."""
        if not query:
            return list(self.data)
        # Id lookups take precedence over every other key in the query.
        for id_key in ('ElementID', '_id'):
            if id_key in query:
                return self._lookup(query[id_key])
        if '$or' in query:
            seen = set()
            merged = []
            for sub_query in query['$or']:
                for item in self._match(sub_query):
                    if id(item) not in seen:
                        seen.add(id(item))
                        merged.append(item)
            return merged
        return [item for item in self.data
                if all(item.get(key) == value for key, value in query.items())]

    def find_one(self, query=None):
        """Return the first matching document (converted), or None."""
        matches = self._match(query)
        return self._convert_item(matches[0]) if matches else None

    def find(self, query=None):
        """Return every matching document as a list of converted copies."""
        return [self._convert_item(item) for item in self._match(query)]

    def _convert_item(self, item):
        """Return a shallow copy of item with numpy lead arrays wrapped in pd.Series."""
        item_copy = dict(item)
        if 'lead' in item_copy:
            item_copy['lead'] = {
                k: pd.Series(v) if isinstance(v, np.ndarray) else v
                for k, v in item_copy['lead'].items()
            }
        return item_copy

    def count_documents(self, query=None):
        """Count matching documents without building converted copies."""
        return len(self._match(query))


_data_collection = None  # list of stored documents, loaded lazily from the pickle
_collection = None       # PickleCollection over _data_collection, built once and reused


def get_collection(database_params_file=None):
    """Return the shared PickleCollection (drop-in for the MongoDB collection).

    database_params_file is accepted for signature compatibility and ignored.
    The pickle is loaded and indexed on the first call only.
    """
    global _data_collection, _collection
    if _collection is None:
        if _data_collection is None:
            pickle_path = os.path.join(get_parent_folder(), 'Feature_map', 'Dataset', 'data_collection.pkl')
            print(f"Loading data collection from {pickle_path}...")
            with open(pickle_path, 'rb') as f:
                _data_collection = pickle.load(f)
            print(f"Loaded {len(_data_collection)} records")
        _collection = PickleCollection(_data_collection)
    return _collection


def get_parent_folder():
    """Root folder holding Feature_map/ (must match DATA_DIR in the notebook)."""
    return "/content/Data/"


def remove_dir(folder: str):
    """Delete folder recursively; best-effort, errors (e.g. a missing folder) are ignored."""
    shutil.rmtree(folder, ignore_errors=True)


def get_twelve_keys():
    return ['I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']


def get_lead_keys(leads: str):
    """Map a lead-set name to its list of lead keys; raises ValueError for unknown names."""
    if leads == 'limb':
        keys = ['I', 'II']
    elif leads == 'limb+comb(v3+v4)':
        keys = ['I', 'II', ['V3', 'V4']]
    elif leads == 'limb+v2+v4':
        keys = ['I', 'II', 'V2', 'V4']
    elif leads == 'full_limb':
        keys = ['I', 'II', 'III', 'aVL', 'aVR', 'aVF']
    elif leads == 'limb+v1':
        keys = ['I', 'II', 'V1']
    elif leads == 'limb+v2':
        keys = ['I', 'II', 'V2']
    elif leads == 'limb+v3':
        keys = ['I', 'II', 'V3']
    elif leads == 'limb+v4':
        keys = ['I', 'II', 'V4']
    elif leads == 'limb+v5':
        keys = ['I', 'II', 'V5']
    elif leads == 'limb+v6':
        keys = ['I', 'II', 'V6']
    elif leads == 'precordial':
        keys = ['V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    elif leads == 'full':
        keys = ['I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    else:
        raise ValueError(f"Unknown leads: {leads}")
    return keys


def get_data_classes(dataset: str):
    # For PTB-XL, use 'other' to load ALL data without class filtering
    if dataset == 'other' or dataset == 'all':
        data_classes = ['other']
    elif dataset == 'infarct+other':
        data_classes = ['st_elevation_or_infarct', 'other']
    elif dataset == 'infarct+noninfarct':
        data_classes = ['st_elevation_or_infarct', 'non_st_elevation_or_infarct']
    else:
        data_classes = ['other']  # Default to 'other' for PTB-XL
    return data_classes


def get_detect_classes(detect_class: str):
    detect_classes = [detect_class]
    return detect_classes


def get_value_range():
    # NOTE(review): the notebook stores each lead as p_signal * 1000 read from the
    # 'filename_hr' records; confirm these bounds and wave_sample match the amplitude
    # units and record length that the training code expects.
    min_value = -2.5
    amplitude = 5.0
    wave_sample = 2500
    return min_value, amplitude, wave_sample
'''

# Persist the adapter module next to the original util_functions/general.py.
target_file = 'util_functions/general_ptbxl.py'
with open(target_file, 'w') as out:
    out.write(general_ptbxl_code)
print(f"Created {target_file}")

## 3. Load or Convert PTB-XL Data

If converted data has already been saved to Google Drive, it is copied into the runtime and conversion is skipped. Otherwise, the PTB-XL records are converted from scratch.

In [None]:
# Check if converted data exists on Drive (or already in this runtime)
import os
import shutil

DATA_DIR = '/content/Data'
DRIVE_DATA_PATH = '/content/drive/MyDrive/ECG_Reconstruction/Data'

# Remove a leftover symlink so DATA_DIR can be (re)created as a real folder.
if os.path.islink(DATA_DIR):
    os.remove(DATA_DIR)

if os.path.exists(f'{DRIVE_DATA_PATH}/Feature_map/Dataset/data_collection.pkl'):
    print("Found converted data on Drive! Copying...")
    shutil.rmtree(DATA_DIR, ignore_errors=True)
    shutil.copytree(DRIVE_DATA_PATH, DATA_DIR)
    print("Loaded from Drive (skipping conversion)")
    SKIP_CONVERSION = True
elif os.path.exists(f'{DATA_DIR}/Feature_map/Dataset/data_collection.pkl'):
    # Converted data is already in this runtime (e.g. this cell is being re-run before
    # it was saved to Drive): keep it instead of deleting the result of a long conversion.
    print(f"Found converted data in {DATA_DIR} (not yet on Drive); keeping it and skipping conversion")
    SKIP_CONVERSION = True
else:
    print("No converted data on Drive. Will convert from scratch.")
    shutil.rmtree(DATA_DIR, ignore_errors=True)
    os.makedirs(DATA_DIR, exist_ok=True)
    SKIP_CONVERSION = False

print(f"\nData directory: {DATA_DIR}")

In [None]:
# Convert PTB-XL data (only runs if not loaded from Drive)
if not SKIP_CONVERSION:
    import ast
    import gc
    import pickle

    import numpy as np
    import pandas as pd
    import wfdb
    from sklearn.model_selection import train_test_split
    from tqdm import tqdm

    # Metadata table: one row per ECG, indexed by ecg_id.
    df = pd.read_csv(f'{PTBXL_PATH}/ptbxl_database.csv', index_col='ecg_id')
    # scp_codes holds dict literals as text; literal_eval parses them without executing code.
    df.scp_codes = df.scp_codes.apply(ast.literal_eval)
    print(f"Found {len(df)} records")

    # wfdb reports the augmented leads in upper case; the codebase uses aVR/aVL/aVF.
    lead_map = {'AVR': 'aVR', 'AVL': 'aVL', 'AVF': 'aVF'}
    dataset_path = f'{DATA_DIR}/Feature_map/Dataset'
    os.makedirs(dataset_path, exist_ok=True)

    # Process in chunks, spilling each chunk to disk to limit peak memory while reading records.
    CHUNK_SIZE = 2000
    chunks = []
    failed = []  # (ecg_id, error) for records that could not be converted

    for chunk_idx, start in enumerate(range(0, len(df), CHUNK_SIZE)):
        chunk_df = df.iloc[start:start + CHUNK_SIZE]
        chunk_data = []

        for ecg_id, row in tqdm(chunk_df.iterrows(), total=len(chunk_df), desc=f"Chunk {chunk_idx+1}"):
            try:
                # NOTE(review): 'filename_hr' selects the high-rate recordings; confirm the
                # resulting length and the x1000 scaling match get_value_range() in general.py.
                record = wfdb.rdrecord(f"{PTBXL_PATH}/{row['filename_hr']}")
                signal = record.p_signal
                lead_dict = {
                    lead_map.get(name, name): (signal[:, i] * 1000).astype(np.float32)
                    for i, name in enumerate(record.sig_name)
                }
                chunk_data.append({
                    'ElementID': f'ptbxl_{ecg_id}',
                    'lead': lead_dict,
                    'patient_id': row['patient_id'],
                })
            except Exception as err:
                # Best-effort: skip unreadable records, but remember them so they get reported.
                failed.append((ecg_id, repr(err)))

        # Save chunk
        chunk_path = f'{dataset_path}/chunk_{chunk_idx}.pkl'
        with open(chunk_path, 'wb') as f:
            pickle.dump(chunk_data, f, protocol=4)
        print(f"  Saved chunk {chunk_idx+1}: {len(chunk_data)} records")
        chunks.append(chunk_path)
        del chunk_data
        gc.collect()

    if failed:
        print(f"WARNING: skipped {len(failed)} record(s) that could not be read, e.g. {failed[:3]}")

    # Combine chunks
    print("Combining chunks...")
    all_data = []
    for chunk_path in chunks:
        with open(chunk_path, 'rb') as f:
            all_data.extend(pickle.load(f))
        os.remove(chunk_path)
        gc.collect()

    print(f"Total records: {len(all_data)}")
    if not all_data:
        raise RuntimeError("No PTB-XL records were converted - check PTBXL_PATH and the warnings above")

    # Save final file under a temporary name first, so an interrupted dump never leaves a
    # truncated data_collection.pkl that later cells would treat as valid.
    pkl_path = f'{dataset_path}/data_collection.pkl'
    print("Saving final file...")
    tmp_path = pkl_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(all_data, f, protocol=4)
    os.replace(tmp_path, pkl_path)

    size_mb = os.path.getsize(pkl_path) / 1024 / 1024
    print(f"Done! File size: {size_mb:.1f} MB")

    # Create train/valid/test splits at patient level (70/15/15) so no patient spans two splits.
    element_ids = [d['ElementID'] for d in all_data]
    patient_map_dict = {d['ElementID']: d['patient_id'] for d in all_data}
    patients = list(set(patient_map_dict.values()))

    train_p, temp_p = train_test_split(patients, test_size=0.3, random_state=42)
    valid_p, test_p = train_test_split(temp_p, test_size=0.5, random_state=42)

    train_p_set, valid_p_set, test_p_set = set(train_p), set(valid_p), set(test_p)
    train_ids = [e for e in element_ids if patient_map_dict[e] in train_p_set]
    valid_ids = [e for e in element_ids if patient_map_dict[e] in valid_p_set]
    test_ids = [e for e in element_ids if patient_map_dict[e] in test_p_set]

    print(f"Split: Train={len(train_ids)}, Valid={len(valid_ids)}, Test={len(test_ids)}")

    all_patients = list(patient_map_dict.values())
    train_patients = [patient_map_dict[e] for e in train_ids]
    valid_patients = [patient_map_dict[e] for e in valid_ids]
    test_patients = [patient_map_dict[e] for e in test_ids]

    all_maps = [
        ('map', element_ids), ('clean_map', element_ids), ('corrupted_map', []),
        ('train_map', train_ids), ('valid_map', valid_ids), ('test_map', test_ids),
        ('patient_map', all_patients), ('clean_patient_map', all_patients), ('corrupted_patient_map', []),
        ('train_patient_map', train_patients), ('valid_patient_map', valid_patients), ('test_patient_map', test_patients),
    ]

    # The same maps go to Dataset/ and to the 'other' Dataclass (contains ALL data - required by the codebase)
    other_path = f'{DATA_DIR}/Feature_map/Dataclass/other'
    for folder in (dataset_path, other_path):
        os.makedirs(folder, exist_ok=True)
        for name, data in all_maps:
            with open(f'{folder}/{name}.pkl', 'wb') as f:
                pickle.dump(data, f)

    print(f"Created 'other' Dataclass with {len(element_ids)} records")

    del all_data
    gc.collect()
    print("Conversion complete!")
else:
    print("Skipped conversion - data already loaded from Drive")

In [None]:
# Create required map files if missing (needed when loading from Drive)
import gc
import os
import pickle
import shutil
from sklearn.model_selection import train_test_split

DATA_DIR = '/content/Data'
dataset_path = f'{DATA_DIR}/Feature_map/Dataset'
other_path = f'{DATA_DIR}/Feature_map/Dataclass/other'

# Check if train_map.pkl exists in Dataset folder
if not os.path.exists(f'{dataset_path}/train_map.pkl'):
    print("Creating required map files...")

    # Load data collection
    with open(f'{dataset_path}/data_collection.pkl', 'rb') as f:
        data_list = pickle.load(f)
    print(f"Loaded {len(data_list)} records")

    # Extract IDs and patient info
    element_ids = [d['ElementID'] for d in data_list]
    patient_map_dict = {d['ElementID']: d['patient_id'] for d in data_list}
    patients = list(set(patient_map_dict.values()))

    # Patient-level train/valid/test split (70/15/15), same procedure as the conversion cell
    train_p, temp_p = train_test_split(patients, test_size=0.3, random_state=42)
    valid_p, test_p = train_test_split(temp_p, test_size=0.5, random_state=42)

    train_p_set, valid_p_set, test_p_set = set(train_p), set(valid_p), set(test_p)
    train_ids = [e for e in element_ids if patient_map_dict[e] in train_p_set]
    valid_ids = [e for e in element_ids if patient_map_dict[e] in valid_p_set]
    test_ids = [e for e in element_ids if patient_map_dict[e] in test_p_set]

    print(f"Split: Train={len(train_ids)}, Valid={len(valid_ids)}, Test={len(test_ids)}")

    # Patient lists
    all_patients = list(patient_map_dict.values())
    train_patients = [patient_map_dict[e] for e in train_ids]
    valid_patients = [patient_map_dict[e] for e in valid_ids]
    test_patients = [patient_map_dict[e] for e in test_ids]

    # All files to create
    all_maps = [
        ('map', element_ids), ('clean_map', element_ids), ('corrupted_map', []),
        ('train_map', train_ids), ('valid_map', valid_ids), ('test_map', test_ids),
        ('patient_map', all_patients), ('clean_patient_map', all_patients), ('corrupted_patient_map', []),
        ('train_patient_map', train_patients), ('valid_patient_map', valid_patients), ('test_patient_map', test_patients)
    ]

    # Save to Dataset folder
    for name, data in all_maps:
        with open(f'{dataset_path}/{name}.pkl', 'wb') as f:
            pickle.dump(data, f)
    print(f"Saved {len(all_maps)} files to Dataset/")

    # Create 'other' Dataclass folder
    os.makedirs(other_path, exist_ok=True)
    for name, data in all_maps:
        with open(f'{other_path}/{name}.pkl', 'wb') as f:
            pickle.dump(data, f)
    print("Created 'other' Dataclass")

    del data_list
    gc.collect()
else:
    print("Map files already exist")
    # Still ensure 'other' Dataclass exists
    if not os.path.exists(f'{other_path}/train_map.pkl'):
        print("Creating 'other' Dataclass...")
        os.makedirs(other_path, exist_ok=True)
        # Byte-copy only the *map.pkl files: no unpickling is needed, and this keeps
        # data_collection.pkl and any leftover chunk_*.pkl out of the Dataclass folder.
        for filename in sorted(os.listdir(dataset_path)):
            if filename.endswith('map.pkl'):
                shutil.copyfile(f'{dataset_path}/{filename}', f'{other_path}/{filename}')
        print("Created 'other' Dataclass")

In [None]:
# Save converted data to Google Drive (run after first conversion)
import os
import shutil

DRIVE_DATA_PATH = '/content/drive/MyDrive/ECG_Reconstruction/Data'
local_collection = f'{DATA_DIR}/Feature_map/Dataset/data_collection.pkl'

# Only sync when this runtime actually holds converted data: DATA_DIR can exist but be
# empty (e.g. after a failed conversion), and syncing it would wipe a good Drive copy.
has_local_data = os.path.isfile(local_collection)
needs_update = has_local_data and (
    not os.path.exists(DRIVE_DATA_PATH) or
    not os.path.exists(f'{DRIVE_DATA_PATH}/Feature_map/Dataset/data_collection.pkl') or
    not os.path.exists(f'{DRIVE_DATA_PATH}/Feature_map/Dataset/train_map.pkl') or
    not os.path.exists(f'{DRIVE_DATA_PATH}/Feature_map/Dataclass/other/train_map.pkl')
)

if needs_update:
    print(f"Saving/updating Google Drive: {DRIVE_DATA_PATH}")
    print("This may take a few minutes...")
    # The old Drive copy is deleted before copying, so an interrupted copy can leave
    # the Drive folder incomplete.
    if os.path.exists(DRIVE_DATA_PATH):
        shutil.rmtree(DRIVE_DATA_PATH)
    shutil.copytree(DATA_DIR, DRIVE_DATA_PATH)
    print("Saved to Drive!")
elif not has_local_data:
    print(f"Nothing to save: {local_collection} does not exist")
else:
    print("Data on Drive is up to date")

In [None]:
# Verify data setup
import os
import pickle

print("Dataset contents:")
!ls -la {DATA_DIR}/Feature_map/Dataset/ | head -15

print("\nDataclass 'other' contents:")
!ls -la {DATA_DIR}/Feature_map/Dataclass/other/

# Verify pickle file is valid
pkl_path = f'{DATA_DIR}/Feature_map/Dataset/data_collection.pkl'
try:
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    print(f"\ndata_collection.pkl is valid! Contains {len(data)} records")
    print(f"Sample ElementID: {data[0]['ElementID']}")
    del data
except Exception as e:
    print(f"\nERROR: {e}")
    print("You need to re-run the conversion cell")

In [None]:
# Patch original code to use PTB-XL adapter
import os
import shutil

# Back up the original general.py only once: on a re-run general.py already holds the
# adapter, and an unconditional copy would overwrite the real backup with it.
if not os.path.exists('util_functions/general_backup.py'):
    shutil.copyfile('util_functions/general.py', 'util_functions/general_backup.py')
shutil.copyfile('util_functions/general_ptbxl.py', 'util_functions/general.py')

print("Code patched to use PTB-XL data!")

## 4. Configuration (Paper Settings)

In [None]:
# Configuration based on Paper (Nature s41746-024-01193-7)
# Every key except 'dataset' is passed to ReconstructionManager under the same name.

CONFIG = dict(
    device=DEVICE,

    # INPUT/OUTPUT LEADS (Paper: I + II + V3 -> V1-V6)
    input_leads='limb+v3',      # I, II, V3 (3 leads)
    output_leads='precordial',  # V1-V6 (6 leads)

    # Dataset - use 'other' to load ALL PTB-XL data
    dataset='other',
    data_size='max',

    # Network architecture (Paper: ResCNN blocks)
    input_channel=32,
    middle_channel=32,
    output_channel=32,
    input_depth=3,
    middle_depth=2,
    output_depth=3,
    input_kernel=17,
    middle_kernel=17,
    output_kernel=17,
    use_residual='true',

    # Training parameters
    epochs=200,
    batch_size=16,
    optimizer_algorithm='adam',
    learning_rate=0.000003,
    weight_decay=0.001,
    momentum=0.9,
    nesterov=True,
    prioritize_percent=0,
    prioritize_size=0,
)

rule = "=" * 60
summary_lines = [
    rule,
    "CONFIGURATION (Paper Settings)",
    rule,
    f"Input:  {CONFIG['input_leads']} (I, II, V3)",
    f"Output: {CONFIG['output_leads']} (V1-V6)",
    f"Device: {CONFIG['device']}",
    f"Dataset: {CONFIG['dataset']} (all PTB-XL data)",
    f"Epochs: {CONFIG['epochs']}",
    rule,
]
print("\n".join(summary_lines))

## 5. Initialize Model

In [None]:
import sys
sys.path.insert(0, '.')

# The project modules imported below still use np.infty, which NumPy 2.0 removed.
import numpy as np
if not hasattr(np, 'infty'):
    np.infty = np.inf

from util_functions.general import (
    get_data_classes,
    get_lead_keys,
    get_parent_folder,
)
from training_functions.single_reconstruction_manager import ReconstructionManager

# Settings derived from CONFIG through the (patched) util_functions.general module
parent_folder = get_parent_folder()
data_classes = get_data_classes(CONFIG['dataset'])  # ['other'] with the PTB-XL adapter
sub_classes = []

input_keys, output_keys = (get_lead_keys(CONFIG[k]) for k in ('input_leads', 'output_leads'))
for label, keys in (("Input", input_keys), ("Output", output_keys)):
    print(f"{label} leads ({len(keys)}): {keys}")
print(f"Data classes: {data_classes}")
print(f"Data folder: {parent_folder}")

In [None]:
# Create the Reconstruction Manager. Apart from the folder/class arguments, every
# setting is forwarded from CONFIG under the same keyword name.
_manager_config_keys = (
    'device', 'input_leads', 'output_leads', 'data_size',
    'input_channel', 'middle_channel', 'output_channel',
    'input_depth', 'middle_depth', 'output_depth',
    'input_kernel', 'middle_kernel', 'output_kernel',
    'use_residual', 'epochs', 'batch_size',
    'prioritize_percent', 'prioritize_size',
    'optimizer_algorithm', 'learning_rate', 'weight_decay',
    'momentum', 'nesterov',
)
manager = ReconstructionManager(
    parent_folder=parent_folder,
    sub_classes=sub_classes,
    data_classes=data_classes,
    **{key: CONFIG[key] for key in _manager_config_keys},
)

print("Reconstruction Manager initialized!")

## 6. Training

In [None]:
# Ask the manager to (re)initialize its model before training
print('Initializing model...')
manager.reset_model()
print('Model initialized!')

In [None]:
# Load the training and validation splits into the manager
print('Loading training and validation datasets...')
manager.load_dataset(valid=True, train=True)
print('Datasets loaded!')

In [None]:
# Train! Report the device from CONFIG, since that is what the manager was built with
# (the global DEVICE may differ if CONFIG['device'] was edited).
print(f"Starting training for {CONFIG['epochs']} epochs on {CONFIG['device']}...")
print("="*50)
manager.train()
print("="*50)
print("Training completed!")

In [None]:
# Free the loaded datasets, then plot the training and validation statistics (in this order)
for step in (manager.release_dataset, manager.plot_train_stats, manager.plot_valid_stats):
    step()

## 7. Testing

In [None]:
# Load model and test
print("Loading model for testing...")
manager.load_model()
manager.load_dataset(test=True)

# Release the test dataset even if testing fails, so memory is not held by a dead run.
try:
    print("Running tests...")
    manager.test()
finally:
    manager.release_dataset()

manager.plot_test_stats()
print("Testing completed!")

## 8. Visualization

In [None]:
# Load the model and plot randomly selected reconstruction examples as PNG files
manager.load_model()
manager.plot_random_example(plot_format='png')
print('Random examples plotted!')

In [None]:
# Reload the stored test statistics and plot error examples as PNG files
manager.load_test_stats()
manager.plot_error_example(plot_format='png')
print('Error examples plotted!')

## 9. Save Results to Google Drive

In [None]:
# Copy the analysis outputs to Google Drive so they survive the Colab session.
import os
import shutil

drive_path = '/content/drive/MyDrive/ECG_Reconstruction_Results'
os.makedirs(drive_path, exist_ok=True)

# Copy output folder
# NOTE(review): assumes the manager writes its plots/stats under {DATA_DIR}/Analysis - confirm.
output_folder = f'{DATA_DIR}/Analysis'
if os.path.exists(output_folder):
    shutil.copytree(output_folder, f'{drive_path}/Analysis', dirs_exist_ok=True)
    print(f"Results saved to: {drive_path}")
    print("\nDone! Results saved to Google Drive.")
else:
    # Do not report success when there was nothing to copy.
    print(f"WARNING: {output_folder} not found - nothing was copied to {drive_path}")