# Sequence Model Research

The scope of this notebook is to assess and train different sequence models using the generated training data.

Training data is generated based on financial time series data labeled with potential profits using a buy-sell system.

The goal is to create a sequence model that can choose favourable stock charts equal to or better than a human can via traditional technical analysis.

# Import Libraries and Data

In [None]:
import os
import numpy as np

# Define the data directory. This is a relative path, so it resolves against
# the current working directory, not the notebook's own location.
data_dir = 'data'

# Define the file paths
sequences_path = os.path.join(data_dir, 'sequences.npy')
labels_path = os.path.join(data_dir, 'labels.npy')
metadata_path = os.path.join(data_dir, 'metadata.npy')

# List of feature names, in column order along the last axis of the sequence array
feature_names = [
    'Consol_Len_Bars', 'Consol_Depth_Percent',
    'Distance_to_21EMA', 'Distance_to_50SMA', 'Distance_to_200SMA',
    'RSL_NH_Count', 'RSL_Slope', 'Up_Down_Days',
    'Stage 2', 'UpDownVolumeRatio', 'ATR', '%B'
]

# Number of examples to randomly sample from the full dataset
num_examples = 115000

# Seed for the subset sampling. A dedicated seeded generator makes the sampled
# subset identical across kernel restarts, so results are comparable run to run.
SUBSET_SEED = 42

# Load the data
try:
    data_sequences = np.load(sequences_path)
    data_labels = np.load(labels_path)
    data_metadata = np.load(metadata_path)

    # All three arrays are indexed with the same row indices below, so they
    # must agree on their first dimension.
    if not (len(data_sequences) == len(data_labels) == len(data_metadata)):
        raise ValueError(
            f"Row counts differ: sequences={len(data_sequences)}, "
            f"labels={len(data_labels)}, metadata={len(data_metadata)}"
        )

    # Draw a reproducible random subset without replacement. If the dataset has
    # fewer than `num_examples` rows, every row is kept (in shuffled order).
    rng = np.random.default_rng(SUBSET_SEED)
    selected_indices = rng.permutation(len(data_sequences))[:num_examples]

    # Use the selected indices to create the random subset
    data_sequences = data_sequences[selected_indices, :, :]
    data_labels = data_labels[selected_indices]
    data_metadata = data_metadata[selected_indices]

    # Inspect the shape and size of the sampled data
    print(f'Loaded sequences shape: {data_sequences.shape}')
    print(f'Loaded sequences size: {data_sequences.size}')
    print(f'Loaded labels shape: {data_labels.shape}')
    print(f'Loaded metadata shape: {data_metadata.shape}')

except FileNotFoundError as e:
    print(f"Error loading files: {e}")
except ValueError as e:
    print(f"Value error: {e}")


## Data Preprocessing

### NaN and Inf Removal

In [None]:
import numpy as np

# Arrays to scan for non-finite values, keyed by the name used in the report.
data_dict = {
    'data_sequences': data_sequences,
    'data_labels': data_labels,
}

for array_name, arr in data_dict.items():
    nan_count = np.sum(np.isnan(arr))
    inf_count = np.sum(np.isinf(arr))
    print(f"NaNs in {array_name}: {nan_count}")
    print(f"Infs in {array_name}: {inf_count}")

    # Already clean: nothing to replace for this array.
    if nan_count == 0 and inf_count == 0:
        continue

    # Slice assignment overwrites NaN/+Inf/-Inf with 0.0 in place, so the very
    # arrays bound to data_sequences / data_labels are cleaned as well.
    arr[:] = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    remaining_nans = np.sum(np.isnan(arr))
    remaining_infs = np.sum(np.isinf(arr))
    print(f"NaNs remaining in {array_name} after removal: {remaining_nans}")
    print(f"Infs remaining in {array_name} after removal: {remaining_infs}")

print("NaN and Inf removal completed.")

### Corrupted sequence removal

99% of the stocks I buy trade below 1000; only a few trade above 1000, although those few are important.

I also noticed that quite a few training examples contain anomalous price data; these are filtered out below.

With thresholds above 3e3, the observed maximum equals the threshold itself, which is highly suspect.

The number of training examples lost is insignificant, and the result is better-normalized data with no corrupted sequences.

#### Feature Stats

In [None]:
import numpy as np
import pandas as pd

def print_feature_stats(data_sequences, feature_names):
    """
    Print summary statistics for each named feature of a 3-D sequence array.

    Each feature column is flattened across all sequences and timesteps before
    its statistics are computed.

    Args:
        data_sequences (numpy array): Array indexed as [sequence, timestep, feature].
        feature_names (list of str): Names for feature columns 0..len-1; only
            these columns are reported.
    """
    print("Feature Statistics:")
    print("-" * 50)

    for i, feature_name in enumerate(feature_names):
        feature_data = data_sequences[:, :, i].flatten()

        print(f"Feature: {feature_name}")
        if feature_data.size == 0:
            # np.min / np.max / np.percentile raise on empty input.
            print("  (no data)")
            print("-" * 50)
            continue

        # Each of these is a full pass over the data, so compute once and reuse.
        series = pd.Series(feature_data)
        is_zero = feature_data == 0
        p25, p75 = np.percentile(feature_data, [25, 75])

        stats = {
            "Mean": np.mean(feature_data),
            "Median": np.median(feature_data),
            "Std Dev": np.std(feature_data),
            "Min": np.min(feature_data),
            "Max": np.max(feature_data),
            "25th Percentile": p25,
            "75th Percentile": p75,
            "Skewness": series.skew(),
            "Kurtosis": series.kurtosis(),
            "Zero Count": np.count_nonzero(is_zero),
            "Zero Percentage": np.mean(is_zero) * 100,
        }

        for stat_name, stat_value in stats.items():
            print(f"  {stat_name}: {stat_value:.4f}")
        print("-" * 50)

# Per-feature summary first, then a few dataset-wide figures.
print_feature_stats(data_sequences, feature_names)

seq_shape = data_sequences.shape
print("Overall Dataset Statistics:")
print(f"Total number of sequences: {seq_shape[0]}")
print(f"Sequence length: {seq_shape[1]}")
print(f"Number of features: {seq_shape[2]}")
print(f"Total number of data points: {data_sequences.size}")
# nbytes converted to mebibytes (1024 ** 2 bytes).
print(f"Memory usage: {data_sequences.nbytes / 1024 ** 2:.2f} MB")

### Normalization of Training Data

In [None]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Canonical feature order: a name's position in this list is its column index
# along the last axis of the sequence arrays.
feature_names = [
    'Consol_Len_Bars',
    'Consol_Depth_Percent',
    'Distance_to_21EMA',
    'Distance_to_50SMA',
    'Distance_to_200SMA',
    'RSL_NH_Count',
    'RSL_Slope',
    'Up_Down_Days',
    'Stage 2',
    'UpDownVolumeRatio',
    'ATR',
    '%B',
]

# Reverse lookup: feature name -> column index.
feature_indices = dict(zip(feature_names, range(len(feature_names))))

# Function to drop corrupted sequences, cap outliers and z-score normalize
def preprocess_data(sequences, labels, *, max_dist_21ema=100, max_dist_50sma=200,
                    max_dist_200sma=500, volume_ratio_cap=10, feature_idx=None,
                    return_scaler=False):
    """
    Drop corrupted sequences, cap outliers and z-score normalize every feature.

    A sequence is kept only if *every* timestep satisfies
    Distance_to_21EMA <= max_dist_21ema, Distance_to_50SMA <= max_dist_50sma and
    Distance_to_200SMA <= max_dist_200sma. NaN fails these comparisons, so
    sequences with NaN in those columns are dropped too. UpDownVolumeRatio is
    then capped at volume_ratio_cap. All features are standardized with one
    StandardScaler fitted over every (sequence, timestep) row.

    NOTE(review): the scaler is fitted on the whole dataset before the
    train/val/test split, so validation/test statistics leak into the
    normalization. Consider fitting on the training split only.

    Args:
        sequences (numpy array): 3-D array (sequences, timesteps, features).
        labels (numpy array): One label per sequence, aligned with axis 0.
        max_dist_21ema (float): Upper bound for Distance_to_21EMA.
        max_dist_50sma (float): Upper bound for Distance_to_50SMA.
        max_dist_200sma (float): Upper bound for Distance_to_200SMA.
        volume_ratio_cap (float): Maximum value kept for UpDownVolumeRatio.
        feature_idx (dict, optional): Feature name -> column index. Defaults to
            the module-level ``feature_indices`` mapping.
        return_scaler (bool): If True, also return the fitted StandardScaler so
            that new data can be normalized identically.

    Returns:
        tuple: (normalized_data, filtered_labels), or
        (normalized_data, filtered_labels, scaler) when return_scaler is True.

    Raises:
        ValueError: If ``sequences`` is not 3-D or no sequence passes the filter.
    """
    if feature_idx is None:
        feature_idx = feature_indices

    # Unpacking enforces the 3-D (sequences, timesteps, features) layout.
    _, _, num_features = sequences.shape

    # Per-timestep validity with shape (sequences, timesteps).
    valid_steps = (
        (sequences[:, :, feature_idx['Distance_to_21EMA']] <= max_dist_21ema)
        & (sequences[:, :, feature_idx['Distance_to_50SMA']] <= max_dist_50sma)
        & (sequences[:, :, feature_idx['Distance_to_200SMA']] <= max_dist_200sma)
    )
    # A sequence survives only if every one of its timesteps is valid.
    valid_sequences_mask = valid_steps.all(axis=1)

    # Boolean indexing returns copies, so the caller's arrays are not modified.
    filtered_sequences = sequences[valid_sequences_mask]
    filtered_labels = labels[valid_sequences_mask]

    if filtered_sequences.shape[0] == 0:
        raise ValueError("preprocess_data: no sequences passed the validity filter.")

    ratio_col = feature_idx['UpDownVolumeRatio']
    filtered_sequences[:, :, ratio_col] = np.minimum(
        filtered_sequences[:, :, ratio_col], volume_ratio_cap
    )

    # Standardize each feature over all (sequence, timestep) rows, then restore 3-D.
    scaler = StandardScaler()
    normalized_data = scaler.fit_transform(
        filtered_sequences.reshape(-1, num_features)
    ).reshape(filtered_sequences.shape)

    if return_scaler:
        return normalized_data, filtered_labels, scaler
    return normalized_data, filtered_labels

# Function to print a statistical summary of every named feature column
def print_feature_stats(data_sequences, feature_names):
    """
    Print descriptive statistics for each feature of a sequence array.

    For feature column ``k`` the values ``data_sequences[:, :, k]`` are pooled
    across all sequences and timesteps, and each statistic is printed with four
    decimals between dashed separators.

    Args:
        data_sequences (numpy array): Array indexed as [sequence, timestep, feature].
        feature_names (list of str): Display names for columns 0..len-1.
    """
    divider = "-" * 50
    print("Feature Statistics:")
    print(divider)

    for column, label in enumerate(feature_names):
        values = data_sequences[:, :, column].flatten()
        series = pd.Series(values)
        zero_mask = values == 0

        summary = [
            ("Mean", np.mean(values)),
            ("Median", np.median(values)),
            ("Std Dev", np.std(values)),
            ("Min", np.min(values)),
            ("Max", np.max(values)),
            ("25th Percentile", np.percentile(values, 25)),
            ("75th Percentile", np.percentile(values, 75)),
            ("Skewness", series.skew()),
            ("Kurtosis", series.kurtosis()),
            ("Zero Count", np.sum(zero_mask)),
            ("Zero Percentage", np.mean(zero_mask) * 100),
        ]

        print(f"Feature: {label}")
        for title, number in summary:
            print(f"  {title}: {number:.4f}")
        print(divider)

# Filter corrupted sequences, cap outliers and z-score normalize the features.
# NOTE(review): data_sequences is expected to be 3-D (sequences, timesteps,
# features); an earlier note suggested (115000, 63, 12) — confirm against the
# loaded arrays.
normalized_data, processed_labels = preprocess_data(sequences=data_sequences, labels=data_labels)




# Train Candidate Model

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import mlflow
import mlflow.pytorch
from tqdm.notebook import tqdm
from IPython.display import display
import ipywidgets as widgets
import psutil
import GPUtil
import time
import threading
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Seed the NumPy and PyTorch global RNGs so training runs are repeatable.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def get_system_metrics():
    """
    Take a snapshot of host resource usage.

    CPU usage comes from ``psutil.cpu_percent(interval=1)``, which is a blocking
    sample. Memory usage comes from ``psutil.virtual_memory().percent``. When
    GPUtil reports at least one GPU, the first GPU's load (as a percentage) and
    its used/total memory are added.

    Returns:
        dict: Keys 'cpu_percent', 'memory_percent', and optionally
        'gpu_utilization', 'gpu_memory_used', 'gpu_memory_total'.
    """
    snapshot = {
        'cpu_percent': psutil.cpu_percent(interval=1),
        'memory_percent': psutil.virtual_memory().percent,
    }

    gpu_list = GPUtil.getGPUs()
    if not gpu_list:
        return snapshot

    primary_gpu = gpu_list[0]
    snapshot['gpu_utilization'] = primary_gpu.load * 100
    snapshot['gpu_memory_used'] = primary_gpu.memoryUsed
    snapshot['gpu_memory_total'] = primary_gpu.memoryTotal
    return snapshot

def log_system_metrics(stop_event=None, interval=5):
    """
    Repeatedly log system metrics to MLflow.

    Args:
        stop_event (threading.Event, optional): When provided, the loop exits as
            soon as the event is set. This lets a caller running the logger in a
            background thread shut it down cleanly. When None, the loop runs
            forever, as before.
        interval (float): Seconds to wait between snapshots (default 5).
    """
    while stop_event is None or not stop_event.is_set():
        metrics = get_system_metrics()
        for key, value in metrics.items():
            mlflow.log_metric(key, value)

        if stop_event is None:
            time.sleep(interval)
        elif stop_event.wait(interval):
            # Event set while waiting: stop without taking another snapshot.
            break

def display_model_details(model, params):
    """
    Print a model's architecture followed by its hyperparameters.

    Args:
        model (torch.nn.Module): Model whose ``str()`` form is printed.
        params (dict): Hyperparameters, printed one ``name: value`` per line.
    """
    hyperparameter_lines = [f"{name}: {setting}" for name, setting in params.items()]

    print("\nModel Architecture:")
    print(model)

    print("\nModel Hyperparameters:")
    for line in hyperparameter_lines:
        print(line)

def process_labels(labels, threshold):
    """
    Binarize labels against a threshold.

    Args:
        labels (numpy array): Raw label values.
        threshold (float): Cut-off; a label must be strictly greater to count.

    Returns:
        numpy array: Integer array with 1 where ``labels > threshold`` and 0
        elsewhere (including values equal to the threshold).
    """
    is_positive = labels > threshold
    return is_positive.astype(int)

def split_data(normalized_data, processed_labels):
    """
    Split data into stratified train / validation / test sets (60/20/20).

    Both splits use ``RANDOM_SEED`` and stratify on the labels.

    Args:
        normalized_data (numpy array): Normalized input data.
        processed_labels (numpy array): Labels aligned with ``normalized_data``.

    Returns:
        tuple: (X_train, X_val, X_test, y_train, y_val, y_test).
    """
    print("Splitting data into train, validation, and test sets...")

    # Hold out 20% of everything as the final test set.
    X_rest, X_test, y_rest, y_test = train_test_split(
        normalized_data,
        processed_labels,
        test_size=0.2,
        random_state=RANDOM_SEED,
        stratify=processed_labels,
    )

    # Take the validation set from the remainder: 0.25 * 0.8 = 0.2 of the original data.
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest,
        y_rest,
        test_size=0.25,
        random_state=RANDOM_SEED,
        stratify=y_rest,
    )

    print(f"Train set size: {len(X_train)}")
    print(f"Validation set size: {len(X_val)}")
    print(f"Test set size: {len(X_test)}")

    return X_train, X_val, X_test, y_train, y_val, y_test

def list_mlflow_runs(experiment_name):
    """
    Return the MLflow runs recorded under an experiment.

    As a side effect this also makes ``experiment_name`` the active MLflow
    experiment (via ``mlflow.set_experiment``).

    Args:
        experiment_name (str): Name of the MLflow experiment.

    Returns:
        list: Runs returned by ``MlflowClient.search_runs`` for that experiment.
    """
    mlflow.set_experiment(experiment_name)
    tracking_client = mlflow.tracking.MlflowClient()
    experiment_id = tracking_client.get_experiment_by_name(experiment_name).experiment_id
    return tracking_client.search_runs(experiment_ids=[experiment_id])

def select_model_ui(runs):
    """
    Display a dropdown listing MLflow runs by name and test F1 score.

    Selecting an entry prints that run's name, ID and F1 score.

    Args:
        runs (list): MLflow runs; dropdown position i corresponds to runs[i].

    Returns:
        widgets.Dropdown | None: The displayed dropdown, or None (after printing
        a message) when ``runs`` is empty.
    """
    def describe(run):
        # Dropdown label for one run; a numeric F1 is shown to 4 decimals.
        name = run.data.tags.get('mlflow.runName', 'N/A')
        f1_label = run.data.metrics.get('test_f1', 'N/A')
        try:
            f1_label = f"{float(f1_label):.4f}"
        except (ValueError, TypeError):
            pass  # keep 'N/A' (or any other non-numeric value) as-is
        return f"{name} (F1: {f1_label})"

    labels = [describe(run) for run in runs]
    if not labels:
        print("No available models to select.")
        return None

    dropdown = widgets.Dropdown(
        options=labels,
        description='Select Model:',
        disabled=False,
    )
    display(dropdown)

    def on_change(change):
        # Echo details of the newly selected run.
        if not (change['type'] == 'change' and change['name'] == 'value'):
            return
        chosen = runs[dropdown.index]
        print("\nYou selected:")
        print(f"Run Name: {chosen.data.tags.get('mlflow.runName', 'N/A')}")
        print(f"Run ID: {chosen.info.run_id}")
        print(f"F1 Score: {chosen.data.metrics.get('test_f1', 'N/A')}")

    dropdown.observe(on_change, names='value')
    return dropdown

def load_model(run):
    """
    Load the PyTorch model logged under the ``model`` artifact path of a run.

    Args:
        run (mlflow.entities.Run): The MLflow run containing the model.

    Returns:
        torch.nn.Module: The loaded PyTorch model.

    Raises:
        Exception: Any loading error is printed and then re-raised unchanged.
    """
    try:
        return mlflow.pytorch.load_model(f"runs:/{run.info.run_id}/model")
    except Exception as err:
        print(f"Error loading model: {err}")
        raise

class EarlyStopping:
    """
    Stop training once validation precision (higher is better) stops improving.

    Every call that counts as an improvement saves the model's state_dict to
    ``path``. After ``patience`` consecutive non-improving calls,
    ``early_stop`` becomes True.
    """
    def __init__(self, patience=5, verbose=False, path='checkpoints/checkpoint.pt', min_delta=0.0):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before
                ``early_stop`` is set.
            verbose (bool): Print progress messages when True.
            path (str): File the best model's state_dict is written to; missing
                parent directories are created.
            min_delta (float): A score counts as non-improving when it is below
                ``best_score + min_delta``. The default 0.0 means ties count as
                improvements and reset the counter.
        """
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_precision_max = 0  # last saved precision, used in verbose messages

    def __call__(self, val_precision, model, epoch):
        """Update the stopping state with this epoch's validation precision."""
        score = val_precision

        if self.best_score is None:
            # First observation always becomes the baseline checkpoint.
            self.best_score = score
            self.save_checkpoint(val_precision, model)
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            mlflow.log_metric("early_stopping_counter", self.counter, step=epoch)
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_precision, model)
            self.counter = 0
            mlflow.log_metric("early_stopping_counter", self.counter, step=epoch)

    def save_checkpoint(self, val_precision, model):
        """Save the model's state_dict to ``self.path`` and record the precision."""
        if self.verbose:
            print(f"Validation precision increased ({self.val_precision_max:.6f} --> {val_precision:.6f}).  Saving model ...")
        # torch.save fails if the target directory does not exist yet.
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_precision_max = val_precision

def _positive_class_weight(train_loader):
    """Return BCEWithLogitsLoss ``pos_weight`` (n_negative / n_positive) for 0/1 labels, or None."""
    # One full pass over the loader to gather every training label.
    all_labels = torch.cat([y for _, y in train_loader])
    n_pos = int((all_labels == 1).sum())
    n_neg = int((all_labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        # The ratio is undefined with a single class; fall back to an unweighted loss.
        print("Warning: training labels contain only one class; using an unweighted loss.")
        return None
    # PyTorch documents pos_weight as #negatives / #positives. This equals the
    # ratio w1 / w0 of sklearn's 'balanced' class weights; passing w1 alone
    # would leave negatives at weight 1 and under-weight the positive class.
    return torch.tensor(n_neg / n_pos, dtype=torch.float32, device=device)


def _validation_loss(model, val_loader, criterion):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``val_loader``."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for val_X, val_y in val_loader:
            val_X, val_y = val_X.to(device), val_y.to(device)
            val_outputs = model(val_X).view(-1)
            total += criterion(val_outputs, val_y.float()).item()
    return total / len(val_loader)


def train_model(model, train_loader, val_loader, params, n_epochs=10):
    """
    Train a binary classifier and keep the weights with the best validation precision.

    Args:
        model (torch.nn.Module): The model to train. Its outputs are treated as
            logits (loss is BCEWithLogitsLoss).
        train_loader (DataLoader): DataLoader yielding (X, y) training batches.
        val_loader (DataLoader): DataLoader yielding (X, y) validation batches.
        params (dict): Training parameters. LEARNING_RATE (default 0.001) and
            L2_REGULARIZATION (default 0.0) are read here; the whole dict is
            also passed to evaluate_model.
        n_epochs (int): Maximum number of epochs to train.

    Returns:
        torch.nn.Module: The model restored to the epoch with the highest
        validation precision. If precision never exceeded 0, the final weights
        are kept.
    """
    import copy

    print("\nPreparing for training...")

    criterion = nn.BCEWithLogitsLoss(pos_weight=_positive_class_weight(train_loader))

    lr = float(params.get('LEARNING_RATE', 0.001))
    weight_decay = float(params.get('L2_REGULARIZATION', 0.0))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_val_precision = 0
    best_model = None

    print(f"\nStarting training for {n_epochs} epochs...")

    # Create a progress bar for epochs
    epoch_pbar = tqdm(range(n_epochs), desc="Training Progress")

    early_stopping = EarlyStopping(patience=250, verbose=True)

    for epoch in epoch_pbar:
        model.train()
        total_loss = 0.0
        batch_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}", leave=False)
        for n_batches, (batch_X, batch_y) in enumerate(batch_pbar, start=1):
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)

            optimizer.zero_grad()
            outputs = model(batch_X).view(-1)
            loss = criterion(outputs, batch_y.float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Running mean over the batches processed so far in this epoch.
            batch_pbar.set_postfix({'loss': total_loss / n_batches})

        train_loss = total_loss / len(train_loader)

        # Evaluate model after each epoch
        val_precision, val_recall, val_f1, _, _ = evaluate_model(model, val_loader, params)
        val_loss = _validation_loss(model, val_loader, criterion)

        # Update epoch progress bar with validation precision
        epoch_pbar.set_postfix({'Val Precision': f'{val_precision:.4f}'})

        # Log metrics to MLflow
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_precision", val_precision, step=epoch)
        mlflow.log_metric("val_recall", val_recall, step=epoch)
        mlflow.log_metric("val_f1", val_f1, step=epoch)

        print(f"Epoch {epoch+1}/{n_epochs}: Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")

        if val_precision > best_val_precision:
            best_val_precision = val_precision
            # state_dict() holds live references to the parameters; deep-copy it
            # so later optimizer steps cannot overwrite the "best" snapshot.
            best_model = copy.deepcopy(model.state_dict())
            print(f"New best model found at epoch {epoch+1}")

        early_stopping(val_precision, model, epoch)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_model is None:
        print("\nTraining complete. Validation precision never exceeded 0; keeping the final weights.")
    else:
        print("\nTraining complete. Loading best model...")
        model.load_state_dict(best_model)
    return model

def evaluate_model(model, data_loader, params, apply_sigmoid=True):
    """
    Evaluate a binary classifier and return precision, recall and F1.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (DataLoader): DataLoader yielding (X, y) batches.
        params (dict): Evaluation parameters; PREDICTION_THRESHOLD (default 0.5)
            is the probability cut-off for predicting class 1.
        apply_sigmoid (bool): train_model trains with BCEWithLogitsLoss, so
            model outputs are logits and are passed through a sigmoid before
            thresholding. Pass False for a model whose outputs are already
            probabilities.

    Returns:
        tuple: (precision, recall, f1, all_labels, all_preds), where the last
        two are flat lists with one entry per sample.
    """
    model.eval()
    all_preds = []
    all_labels = []

    prediction_threshold = float(params.get('PREDICTION_THRESHOLD', 0.5))

    with torch.no_grad():
        for batch_X, batch_y in data_loader:
            batch_X = batch_X.to(device)
            # Flatten to one score per sample, matching the labels (as in training).
            outputs = model(batch_X).reshape(-1)
            if apply_sigmoid:
                outputs = torch.sigmoid(outputs)
            preds = (outputs > prediction_threshold).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_y.reshape(-1).cpu().numpy())

    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    return precision, recall, f1, all_labels, all_preds

def plot_confusion_matrix(cm, class_names):
    """
    Build an annotated heatmap figure for a confusion matrix.

    The figure is not saved here. It is closed with ``plt.close`` before being
    returned, so it is no longer tracked by pyplot, but callers can still
    write it out with ``fig.savefig(...)``.

    Args:
        cm (ndarray): Integer confusion matrix (rows: true, columns: predicted).
        class_names (list): Tick labels used for both axes.

    Returns:
        matplotlib.figure.Figure: The confusion-matrix figure.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    # Draw on this figure's axes explicitly rather than pyplot's "current" axes.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.close(fig)
    return fig

def main(normalized_data, processed_labels):
    """
    Interactive pipeline: pick an MLflow run, continue training its model, evaluate it.

    Displays a model dropdown, an epoch-count box (100-10000) and a
    "Start Training" button. Clicking the button:
    - loads the selected run's model;
    - binarizes ``processed_labels`` with that run's PROFIT_THRESH param;
    - splits the data;
    - continues training inside a new MLflow run;
    - logs test metrics, a confusion-matrix image and the model.

    Args:
        normalized_data (numpy array): Normalized input sequences.
        processed_labels (numpy array): Continuous labels, thresholded per run
            by process_labels.
    """
    import traceback  # prints real tracebacks from the error handlers below

    try:
        print("Listing available models...")
        runs = list_mlflow_runs("Stock Analysis Model Tuning 2")

        print("Please select a model:")
        model_dropdown = select_model_ui(runs)
        if model_dropdown is None:
            print("No models available for selection.")
            return

        # Number of epochs widget. IntText has no min/max traits (they were
        # silently ignored); BoundedIntText actually enforces the range.
        n_epochs = widgets.BoundedIntText(value=1000, min=100, max=10000, description='Epochs:')
        display(n_epochs)

        # Start training button
        start_button = widgets.Button(description="Start Training")
        display(start_button)

        def on_button_click(b):
            try:
                if model_dropdown.value is None:
                    print("No model selected.")
                    return

                selected_index = model_dropdown.index
                selected_run = runs[selected_index]
                continued_run_name = f"continued_{selected_run.data.tags.get('mlflow.runName', 'unnamed')}"

                print("\nLoading the selected model...")
                model = load_model(selected_run)
                model = model.to(device)
                print(f"Loaded model from run: {selected_run.info.run_id}")
                print(f"Run name: {selected_run.data.tags.get('mlflow.runName', 'N/A')}")

                print("\nRetrieving model parameters...")
                params = selected_run.data.params

                # Display model architecture and hyperparameters
                display_model_details(model, params)

                # Process labels using the PROFIT_THRESH from the selected model
                profit_thresh = float(params['PROFIT_THRESH'])
                processed_labels_thresholded = process_labels(processed_labels, profit_thresh)

                print(f"Processing labels with PROFIT_THRESH: {profit_thresh}")
                print(f"Processed labels shape: {processed_labels_thresholded.shape}")
                print(f"Unique values in processed labels: {np.unique(processed_labels_thresholded)}")

                print("Splitting data...")
                X_train, X_val, X_test, y_train, y_val, y_test = split_data(normalized_data, processed_labels_thresholded)

                # Create data loaders
                batch_size = 64  # You can adjust this
                train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
                val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(y_val))
                test_dataset = TensorDataset(torch.FloatTensor(X_test), torch.FloatTensor(y_test))

                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                val_loader = DataLoader(val_dataset, batch_size=batch_size)
                test_loader = DataLoader(test_dataset, batch_size=batch_size)

                # Sanity-check one batch from the train_loader
                sample_batch = next(iter(train_loader))
                print(f"Sample batch X shape: {sample_batch[0].shape}")
                print(f"Sample batch y shape: {sample_batch[1].shape}")
                print(f"Sample batch X type: {sample_batch[0].dtype}")
                print(f"Sample batch y type: {sample_batch[1].dtype}")
                print(f"Unique values in sample batch y: {torch.unique(sample_batch[1])}")

                print("\nStarting MLflow run for continued training...")
                with mlflow.start_run(run_name=continued_run_name):
                    # Log parameters
                    mlflow.log_params(params)
                    mlflow.log_param("continued_epochs", n_epochs.value)

                    print(f"\nStarting training for {n_epochs.value} epochs...")
                    trained_model = train_model(model, train_loader, val_loader, params, n_epochs.value)
                    print("Training completed.")

                    print("\nEvaluating model on test set...")
                    test_precision, test_recall, test_f1, all_labels, all_preds = evaluate_model(trained_model, test_loader, params)
                    print(f"Test set performance - Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

                    # Log final metrics
                    mlflow.log_metric("test_precision", test_precision)
                    mlflow.log_metric("test_recall", test_recall)
                    mlflow.log_metric("test_f1", test_f1)

                    # Confusion matrix
                    cm = confusion_matrix(all_labels, all_preds)
                    class_names = ['Class 0', 'Class 1']  # Adjust based on your classes
                    cm_fig = plot_confusion_matrix(cm, class_names)
                    cm_img_path = "confusion_matrix.png"
                    cm_fig.savefig(cm_img_path)
                    mlflow.log_artifact(cm_img_path)

                    # Log the model
                    print("Logging model...")
                    mlflow.pytorch.log_model(trained_model, "continued_model")
                    print("Saved continued model in MLflow")

                print("\nTraining and evaluation complete!")
                print("Final test set performance:")
                print(f"Precision: {test_precision:.4f}")
                print(f"Recall: {test_recall:.4f}")
                print(f"F1 Score: {test_f1:.4f}")
                print(f"Model saved in MLflow with run name: {continued_run_name}")

            except Exception as e:
                print(f"An error occurred in on_button_click: {str(e)}")
                traceback.print_exc()

        start_button.on_click(on_button_click)

    except Exception as e:
        print(f"An error occurred in main function: {str(e)}")
        traceback.print_exc()

# Entry point: run the interactive training pipeline on the arrays produced by
# the earlier preprocess_data call (normalized_data, processed_labels).
if __name__ == "__main__":
    main(normalized_data=normalized_data, processed_labels=processed_labels)