In [13]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as widgets
from ipywidgets import Layout, VBox, HBox, Button, Label
from IPython.display import display, clear_output, HTML
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Matern, RationalQuadratic, DotProduct, WhiteKernel
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.linear_model import BayesianRidge, ARDRegression
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score
from scipy.stats import norm
import xgboost as xgb
import lightgbm as lgb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import warnings
import io
import os
warnings.filterwarnings('ignore')


# For animations
from matplotlib import animation
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D  # For 3D plots

# Plot styling: a five-colour palette plus one neutral dark grey used for
# every axis line, tick label and text element on white backgrounds.
color_palette = ['#6A728A', '#6083B4', '#8F92C5', '#38D23D', '#F5A33C']
_INK = '#4D4D4D'
plt.rcParams.update({
    'axes.facecolor': 'white',
    'axes.edgecolor': _INK,
    'axes.labelcolor': _INK,
    'xtick.color': _INK,
    'ytick.color': _INK,
    'text.color': _INK,
    'figure.facecolor': 'white',
})

# The 20 standard amino acids (one-letter codes); this order fixes the
# column layout of the one-hot encoding.
amino_acids = list('ACDEFGHIKLMNPQRSTVWY')

# ESM2 tokenizer/model cache, populated lazily by load_esm2_model().
esm2_model = None
esm2_tokenizer = None

def load_esm2_model():
    """Load the ESM2 tokenizer and model on first use and cache them globally.

    Once both module-level globals are populated, further calls do nothing.
    """
    global esm2_model, esm2_tokenizer
    if esm2_model is not None and esm2_tokenizer is not None:
        return
    checkpoint = "facebook/esm2_t6_8M_UR50D"
    print("Loading ESM2 model from Hugging Face...")
    esm2_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    esm2_model = AutoModel.from_pretrained(checkpoint)
    # eval() switches off dropout so repeated embeddings are deterministic.
    esm2_model.eval()
    print("ESM2 model loaded.")

def esm2_encode_sequences(sequences):
    """Embed each sequence as the mean of its ESM2 last-layer token states.

    The first and last token positions (the special tokens requested with
    ``add_special_tokens=True``) are dropped before averaging. The results are
    stacked into a numpy array with one row per input sequence.
    """
    load_esm2_model()

    def _embed(seq):
        tokens = esm2_tokenizer(seq, return_tensors="pt", add_special_tokens=True)
        with torch.no_grad():
            hidden = esm2_model(**tokens).last_hidden_state
        return hidden[:, 1:-1, :].mean(dim=1).squeeze().numpy()

    return np.array([_embed(seq) for seq in sequences])

def generate_sequences(num_sequences, sequence_length):
    """Return ``num_sequences`` random amino-acid strings of ``sequence_length``.

    Residues are drawn uniformly, with replacement, from ``amino_acids`` using
    NumPy's global RNG (so results follow ``np.random.seed``).
    """
    result = []
    for _ in range(num_sequences):
        residues = np.random.choice(amino_acids, sequence_length)
        result.append(''.join(residues))
    return result

def compute_fitness(sequences, target_sequence):
    """Score each sequence by its positional identity to ``target_sequence``.

    Fitness is the count of positions where the residues agree; comparison
    stops at the shorter of the two strings. Returns a numpy array with one
    entry per input sequence.
    """
    def _matches(seq):
        return sum(1 for a, b in zip(seq, target_sequence) if a == b)

    return np.array([_matches(seq) for seq in sequences])

def one_hot_encode_sequences(sequences):
    """Flattened one-hot encoding: ``len(seq) * len(amino_acids)`` features per row.

    Residues not in ``amino_acids`` raise ``KeyError``.
    """
    index_of = {aa: i for i, aa in enumerate(amino_acids)}
    # Build the identity matrix once instead of once per sequence.
    identity = np.eye(len(amino_acids))
    rows = [identity[[index_of[aa] for aa in seq]].flatten() for seq in sequences]
    return np.array(rows)

def aaindex_encode_sequences(sequences):
    """Encode sequences with a *simulated* AAIndex-style property table.

    Every amino acid gets 5 pseudo-random "properties" drawn from a fixed seed
    (42), so the table is the same on every call. Per-residue vectors are
    concatenated, giving ``5 * len(seq)`` features per sequence.

    A private ``RandomState`` is used rather than ``np.random.seed``: reseeding
    the global RNG here would reset the stream that the mutation/candidate
    generators draw from every time sequences are encoded. ``RandomState(42)``
    produces exactly the same values as the former seed-then-draw approach.
    """
    rng = np.random.RandomState(42)
    aa_properties = {aa: rng.rand(5) for aa in amino_acids}
    return np.array([
        np.concatenate([aa_properties[aa] for aa in seq]) for seq in sequences
    ])

class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    Encoder: input_dim -> 128 -> 64 -> (mu, logvar), each ``latent_dim`` wide.
    Decoder: latent_dim -> 64 -> 128 -> input_dim with a sigmoid output.
    """

    def __init__(self, input_dim, latent_dim=10):
        super().__init__()
        wide, narrow = 128, 64
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, wide), nn.ReLU(),
            nn.Linear(wide, narrow), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(narrow, latent_dim)
        self.fc_logvar = nn.Linear(narrow, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, narrow), nn.ReLU(),
            nn.Linear(narrow, wide), nn.ReLU(),
            nn.Linear(wide, input_dim), nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes back to the input space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent code and decode; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, x.shape[1]))
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

def train_vae(sequences, encoding_dim=10, epochs=20, batch_size=32):
    """Fit a VAE on the one-hot encodings of ``sequences`` and return it.

    Optimises ``vae_loss_function`` (summed MSE reconstruction + KL term) with
    Adam at lr=1e-3 over shuffled mini-batches. The returned model is still in
    training mode.
    """
    features = one_hot_encode_sequences(sequences)
    model = VAE(features.shape[1], latent_dim=encoding_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    reconstruction_loss = nn.MSELoss(reduction='sum')
    loader = DataLoader(torch.FloatTensor(features), batch_size=batch_size, shuffle=True)

    model.train()
    for _ in range(epochs):
        for batch in loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss = vae_loss_function(recon, batch, mu, logvar, reconstruction_loss)
            loss.backward()
            optimizer.step()
    return model

def vae_loss_function(recon_x, x, mu, logvar, criterion):
    """VAE objective: ``criterion(recon_x, x)`` plus the KL divergence.

    The KL term compares N(mu, exp(logvar)) with the standard-normal prior and
    is summed over all batch elements and latent dimensions.
    """
    reconstruction = criterion(recon_x, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence

def vae_encode_sequences(sequences, vae_model):
    """Map sequences into the VAE latent space using the posterior mean.

    ``mu`` is returned rather than a reparameterized sample. Sampling adds
    fresh noise on every call, so the same sequence would receive different
    features each time it is encoded (e.g. as a candidate and again as an
    observed point), which corrupts the surrogate's inputs.

    Puts ``vae_model`` into eval mode and leaves it there.
    """
    vae_model.eval()
    data = torch.FloatTensor(one_hot_encode_sequences(sequences))
    with torch.no_grad():
        mu, _ = vae_model.encode(data)
    return mu.numpy()

def select_surrogate_model(model_name, kernel_name='RBF', input_dim=None):
    """Instantiate an unfitted surrogate regressor from its display name.

    Parameters
    ----------
    model_name : str
        One of the names offered by ``surrogate_model_widget``. Unknown names
        fall back to a default ``GaussianProcessRegressor()``.
    kernel_name : str
        GP kernel: 'RBF', 'Matern', 'RationalQuadratic' or 'DotProduct'.
        Unknown names fall back to RBF. A ``WhiteKernel`` is always added.
    input_dim : int, optional
        Feature dimension; needed only for '1D CNN with MC Dropout'. If
        omitted, a module-level ``input_dim`` is used when one exists (the
        previous behaviour). Otherwise a ValueError is raised instead of a
        NameError.
    """
    if model_name == 'Gaussian Process':
        kernel_classes = {
            'RBF': RBF,
            'Matern': Matern,
            'RationalQuadratic': RationalQuadratic,
            'DotProduct': DotProduct,
        }
        kernel = kernel_classes.get(kernel_name, RBF)() + WhiteKernel()
        return GaussianProcessRegressor(kernel=kernel, alpha=1e-6)

    if model_name == '1D CNN with MC Dropout':
        if input_dim is None:
            input_dim = globals().get('input_dim')
        if input_dim is None:
            raise ValueError("input_dim must be provided for '1D CNN with MC Dropout'")
        return CNNRegressor(input_dim=input_dim, dropout_rate=0.1)

    factories = {
        'Random Forest': lambda: RandomForestRegressor(n_estimators=100),
        'Extra Trees': lambda: ExtraTreesRegressor(n_estimators=100),
        'Neural Network': lambda: MLPRegressor(hidden_layer_sizes=(100, 100), max_iter=1000),
        'Deep Neural Network': lambda: MLPRegressor(hidden_layer_sizes=(200, 200, 200), max_iter=1000),
        'XGBoost': lambda: xgb.XGBRegressor(n_estimators=100),
        'LightGBM': lambda: lgb.LGBMRegressor(n_estimators=100),
        'SVR': SVR,
        'Bayesian Ridge': BayesianRidge,
        'ARD Regression': ARDRegression,
        'Gradient Boosting': lambda: GradientBoostingRegressor(n_estimators=100),
    }
    return factories.get(model_name, GaussianProcessRegressor)()

def estimate_uncertainty(model, X, model_name):
    """Return ``(mu, sigma)``: the surrogate's predicted mean and an uncertainty proxy.

    How sigma is obtained depends on the model family:

    * 'Gaussian Process', 'Bayesian Ridge', 'ARD Regression': the model's own
      predictive std via ``predict(X, return_std=True)``.
    * 'Random Forest', 'Extra Trees': mean/std across the fitted trees in
      ``model.estimators_``.
    * '1D CNN with MC Dropout': mean/std over ``T`` forward passes with dropout
      active (Monte Carlo dropout). The module's previous train/eval mode is
      restored afterwards.
    * Everything else ('Neural Network', 'Deep Neural Network', 'XGBoost',
      'LightGBM', 'Gradient Boosting', 'SVR', ...): point predictors. sigma is
      a constant 1e-6, which keeps EI/PI/UCB finite but makes them rank
      essentially by mu.
    """
    if model_name in ('Gaussian Process', 'Bayesian Ridge', 'ARD Regression'):
        mu, sigma = model.predict(X, return_std=True)
        return mu, sigma

    if model_name in ('Random Forest', 'Extra Trees'):
        per_tree = np.array([tree.predict(X) for tree in model.estimators_])
        return per_tree.mean(axis=0), per_tree.std(axis=0)

    if model_name == '1D CNN with MC Dropout':
        T = 10  # number of stochastic forward passes
        was_training = model.training
        model.train()  # dropout must be active for MC sampling
        X_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        try:
            with torch.no_grad():
                # reshape(-1) keeps a 1-D result even for a single candidate;
                # squeeze() would collapse a (1, 1) output to a scalar.
                preds = np.stack([model(X_tensor).numpy().reshape(-1) for _ in range(T)])
        finally:
            model.train(was_training)
        return preds.mean(axis=0), preds.std(axis=0)

    # sklearn's MLPRegressor has no dropout, so repeated predict() calls are
    # identical. The former 10-pass "MC dropout" loop therefore always gave
    # sigma == 0, which made Expected Improvement 0 everywhere, and it needlessly
    # changed the model's ``alpha``. MLPs are now treated like the other point
    # predictors.
    mu = model.predict(X)
    return mu, np.full(np.shape(mu), 1e-6)

def expected_improvement(mu, sigma, Y_best, xi=0.01):
    """Expected Improvement (maximisation) over the incumbent ``Y_best``.

    EI = imp * Phi(Z) + sigma * phi(Z), with imp = mu - Y_best - xi and
    Z = imp / sigma. Entries with ``sigma == 0`` are set to 0.

    ``mu`` and ``sigma`` may be scalars, lists or arrays. The result is an
    ndarray of their broadcast shape. Previously scalar inputs crashed on the
    masked assignment.
    """
    mu = np.asarray(mu)
    sigma = np.asarray(sigma)
    imp = mu - Y_best - xi
    # Division by zero is expected where sigma == 0; those entries are
    # overwritten below, so the warnings are suppressed rather than emitted.
    with np.errstate(divide='ignore', invalid='ignore'):
        Z = imp / sigma
        ei = imp * norm.cdf(Z) + sigma * norm.pdf(Z)
    return np.where(sigma == 0.0, 0.0, ei)

def upper_confidence_bound(mu, sigma, kappa=2.576):
    """Optimistic score ``mu + kappa * sigma``; a larger kappa favours exploration."""
    exploration = kappa * sigma
    return mu + exploration

def probability_of_improvement(mu, sigma, Y_best, xi=0.01):
    """Probability that a candidate beats ``Y_best`` by more than ``xi``.

    PI = Phi((mu - Y_best - xi) / sigma). Where ``sigma == 0`` the prediction
    is treated as certain: PI is 1 if mu - Y_best - xi > 0 and 0 otherwise.
    Previously 0/0 produced NaN there.
    """
    mu = np.asarray(mu)
    sigma = np.asarray(sigma)
    imp = mu - Y_best - xi
    with np.errstate(divide='ignore', invalid='ignore'):
        pi = norm.cdf(imp / sigma)
    return np.where(sigma == 0.0, np.where(imp > 0, 1.0, 0.0), pi)

def thompson_sampling(mu, sigma):
    """Draw one sample per candidate from N(mu, sigma**2) and use it as the score.

    Uses NumPy's global RNG, so results are reproducible under ``np.random.seed``.
    """
    sample = np.random.normal(loc=mu, scale=sigma)
    return sample

def greedy_acquisition(mu):
    """Pure exploitation: score each candidate by its predicted mean alone.

    The input is returned as-is (same object); uncertainty is ignored.
    """
    scores = mu
    return scores

def calculate_eUCB(mu, sigma):
    """Score candidates as ``mu + 2 * sigma``.

    This is an upper-confidence-bound score with the exploration weight fixed
    at 2 (it does not use the kappa setting). Larger sigma raises the score
    and so favours exploration.
    """
    exploration_bonus = 2 * sigma
    return mu + exploration_bonus

def _acquisition_scores(name, mu, sigma, Y_best, xi, kappa):
    """Dispatch to an acquisition function by display name (default: Expected Improvement)."""
    if name == 'Probability of Improvement':
        return probability_of_improvement(mu, sigma, Y_best, xi=xi)
    if name == 'Upper Confidence Bound':
        return upper_confidence_bound(mu, sigma, kappa=kappa)
    if name == 'Thompson Sampling':
        return thompson_sampling(mu, sigma)
    if name == 'Greedy':
        return greedy_acquisition(mu)
    if name == 'eUCB':
        return calculate_eUCB(mu, sigma)
    # 'Expected Improvement' and any unrecognised name.
    return expected_improvement(mu, sigma, Y_best, xi=xi)


def _fit_projection(X):
    """Project ``X`` to 2-D with PCA; return (projected X, projector for new rows).

    If PCA cannot be fitted (sklearn raises ValueError, e.g. with fewer than two
    samples or features), the first two raw columns are used instead. The same
    fallback is returned as the projector so later iterations keep working.
    """
    pca = PCA(n_components=2)
    try:
        return pca.fit_transform(X), pca.transform
    except ValueError:
        first_two = lambda Z: np.asarray(Z)[:, :2]
        return first_two(X), first_two


def optimize_sequences(sequences, fitness, encoding_method, surrogate_model_name,
                       acquisition_function_name, kernel_name='RBF', num_mutations = 5,
                       iterations=10, batch_size=3, target_sequence=None,
                       user_sequences=None, user_fitness=None, sampling_strategy='Random',
                       train_vae_flag=False, kappa=2.576, xi=0.01, exploration_type='Fixed'):
    """Run a batched Bayesian-optimisation loop over protein sequences.

    Each iteration does the following:

    1. Fit the chosen surrogate on all observed (encoding, fitness) pairs.
    2. Score the unevaluated candidates with the chosen acquisition function.
    3. Evaluate the top ``batch_size`` candidates and add them to the
       observed set.

    Candidate pool:
        * Without user data: 10 random mutants (``num_mutations`` substitutions
          each) of every observed sequence.
        * With ``user_sequences``/``user_fitness``: the not-yet-observed
          entries of ``sequences``.

    Ground-truth fitness of a selected sequence:
        * ``compute_fitness`` against ``target_sequence`` when one is given.
        * Otherwise it is looked up *by sequence* in the
          ``sequences``/``fitness`` dataset. The old code indexed ``fitness``
          by candidate position, which returned the wrong values.

    With ``exploration_type == 'Adaptive'``, xi and kappa decay as
    ``initial / (iteration + 1)``.

    ``sampling_strategy`` is accepted for API compatibility but is not used.
    Side effects: displays plots and an animation; with the 'VAE' encoding
    and ``train_vae_flag`` set, retrains the module-level ``vae_model``.

    Returns ``(observed_sequences, observed_fitness, max_fitness_over_iterations)``.
    """
    global vae_model
    use_user_data = user_sequences is not None and user_fitness is not None

    # Work on copies so the caller's lists/arrays are never mutated in place.
    if use_user_data:
        observed_sequences = list(user_sequences)
        observed_fitness = np.asarray(user_fitness)
    else:
        observed_sequences = list(sequences)
        observed_fitness = np.array(fitness, copy=True)

    # Ground-truth oracle for dataset-driven runs (sequence -> fitness).
    known_fitness = {} if target_sequence is not None else dict(zip(sequences, fitness))

    if encoding_method == 'AAIndex':
        encode_sequences = aaindex_encode_sequences
    elif encoding_method == 'ESM2':
        encode_sequences = esm2_encode_sequences
    elif encoding_method == 'VAE':
        if train_vae_flag:
            print("Training VAE...")
            vae_model = train_vae(observed_sequences, encoding_dim=latent_dim_widget.value)
        encode_sequences = lambda seqs: vae_encode_sequences(seqs, vae_model)
    else:  # 'One-Hot' and any unrecognised method
        encode_sequences = one_hot_encode_sequences

    X_observed = encode_sequences(observed_sequences)
    Y_observed = observed_fitness
    reduced_embeddings, project = _fit_projection(X_observed)

    # Cumulative data for plotting; iteration 0 marks the starting sequences.
    cumulative_embeddings = list(reduced_embeddings)
    cumulative_fitness = list(Y_observed)
    cumulative_iterations = [0] * len(Y_observed)

    embedding_history, fitness_history, iteration_history = [], [], []
    max_fitness_over_iterations = [np.max(observed_fitness)]

    base_xi, base_kappa = xi, kappa
    regressor = None
    for iteration in range(iterations):
        regressor = select_surrogate_model(surrogate_model_name, kernel_name)
        if surrogate_model_name == '1D CNN with MC Dropout':
            regressor = train_cnn_regressor(regressor, X_observed, Y_observed)
        else:
            regressor.fit(X_observed, Y_observed)

        observed_set = set(observed_sequences)
        if use_user_data:
            pool = sequences
        else:
            pool = generate_candidate_sequences(observed_sequences, 10, num_mutations)
        # dict.fromkeys de-duplicates while keeping a deterministic order;
        # iterating a set depends on string-hash randomisation.
        candidate_sequences = [s for s in dict.fromkeys(pool) if s not in observed_set]
        if not candidate_sequences:
            print("No unevaluated candidate sequences left; stopping early.")
            break

        X_candidates = encode_sequences(candidate_sequences)
        mu, sigma = estimate_uncertainty(regressor, X_candidates, surrogate_model_name)

        # Decay from the *initial* values. Dividing the running value
        # compounded to xi0 / (t + 1)!.
        if exploration_type == 'Adaptive':
            xi = base_xi / (iteration + 1)
            kappa = base_kappa / (iteration + 1)

        acquisition_values = _acquisition_scores(
            acquisition_function_name, mu, sigma, np.max(Y_observed), xi, kappa)

        idx_top = np.argsort(-acquisition_values)[:batch_size]
        sequences_to_evaluate = [candidate_sequences[i] for i in idx_top]

        if target_sequence is not None:
            fitness_evaluated = compute_fitness(sequences_to_evaluate, target_sequence)
        else:
            missing = [s for s in sequences_to_evaluate if s not in known_fitness]
            if missing:
                raise ValueError(
                    f"No ground-truth fitness for {len(missing)} selected sequence(s); "
                    "provide target_sequence or draw candidates from the dataset.")
            fitness_evaluated = np.array([known_fitness[s] for s in sequences_to_evaluate])

        observed_sequences.extend(sequences_to_evaluate)
        observed_fitness = np.concatenate([observed_fitness, fitness_evaluated])

        X_new = encode_sequences(sequences_to_evaluate)
        cumulative_embeddings.extend(project(X_new))
        cumulative_fitness.extend(fitness_evaluated)
        cumulative_iterations.extend([iteration + 1] * len(sequences_to_evaluate))

        X_observed = np.vstack((X_observed, X_new))
        Y_observed = observed_fitness
        max_fitness_over_iterations.append(np.max(observed_fitness))

        # np.array(...) copies, so each history frame is a snapshot.
        embedding_history.append(np.array(cumulative_embeddings))
        fitness_history.append(np.array(cumulative_fitness))
        iteration_history.append(np.array(cumulative_iterations))

    if embedding_history:
        create_sequence_animation(embedding_history, fitness_history, iteration_history, cumulative_iterations)

    plot_seed_vs_acquired(cumulative_embeddings, cumulative_fitness, cumulative_iterations)

    if regressor is not None:
        plot_fitness_correlation(observed_fitness, observed_sequences, regressor, encode_sequences, surrogate_model_name)
        # Use the actual feature width instead of an undefined `input_dim`.
        if X_observed.shape[1] <= 2:
            visualize_fitness_landscape(regressor, encode_sequences)

    return observed_sequences, observed_fitness, max_fitness_over_iterations

def create_sequence_animation(embedding_history, fitness_history, iteration_history, cumulative_iterations):
    """Render and display an inline 3-D animation of the optimisation.

    Each point is plotted at (embedding dim 1, embedding dim 2, fitness). Frame
    ``i`` shows every sequence acquired up to that step, coloured by the
    iteration in which it was acquired.

    One colour normalisation spanning all of ``cumulative_iterations`` is
    shared by the colourbar and every frame's scatter. Previously each frame
    autoscaled its colours to its own range, so they did not match the bar.
    Returns without plotting if there are no frames.
    """
    if not embedding_history:
        return

    fig = plt.figure(figsize=(10,6))
    ax = fig.add_subplot(111, projection='3d')

    # Named color_norm so it does not shadow scipy.stats.norm.
    color_norm = plt.Normalize(min(cumulative_iterations), max(cumulative_iterations))
    mappable = plt.cm.ScalarMappable(cmap='viridis', norm=color_norm)
    mappable.set_array([])  # Required for ScalarMappable
    cbar = fig.colorbar(mappable, ax=ax, pad=0.1)
    cbar.set_label('Iteration Acquired')

    def animate(i):
        ax.clear()
        ax.set_title(f'Sequence Optimization Progress - Iteration {i+1}')
        ax.set_xlabel('Embedding Dimension 1')
        ax.set_ylabel('Embedding Dimension 2')
        ax.set_zlabel('Fitness')
        embeddings = embedding_history[i]
        fitness = fitness_history[i]
        acquired_at = iteration_history[i]
        sc = ax.scatter(embeddings[:,0], embeddings[:,1], fitness, c=acquired_at,
                        cmap='viridis', norm=color_norm, edgecolor='k', s=50)
        ax.set_xlim(np.min(embeddings[:,0])-1, np.max(embeddings[:,0])+1)
        ax.set_ylim(np.min(embeddings[:,1])-1, np.max(embeddings[:,1])+1)
        ax.set_zlim(np.min(fitness)-1, np.max(fitness)+1)
        return sc,

    ani = FuncAnimation(fig, animate, frames=len(embedding_history), interval=1000, blit=False)
    plt.close(fig)  # show only the animation, not the static figure
    display(HTML(ani.to_jshtml()))

def plot_seed_vs_acquired(embeddings, fitness, iterations):
    """Draw a 3-D scatter of all evaluated sequences.

    Axes are (embedding dim 1, embedding dim 2, fitness). Points are coloured
    by the iteration in which each sequence was acquired; optimize_sequences
    labels the starting sequences as iteration 0.
    """
    coords = np.array(embeddings)
    fig = plt.figure(figsize=(10,6))
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(coords[:, 0], coords[:, 1], fitness, c=iterations,
                        cmap='viridis', edgecolor='k', s=50)
    plt.colorbar(points, ax=ax, pad=0.1).set_label('Iteration Acquired')
    ax.set_title('Sequences in Embedding Space with Fitness')
    ax.set_xlabel('Embedding Dimension 1')
    ax.set_ylabel('Embedding Dimension 2')
    ax.set_zlabel('Fitness')
    plt.show()

def plot_fitness_correlation(true_fitness, observed_sequences, regressor, encode_sequences, model_name):
    """Plot true vs surrogate-predicted fitness with sigma error bars and R^2.

    The sequences are re-encoded with ``encode_sequences``. Both series are
    truncated to a common length before plotting. A dashed y = x reference
    line marks perfect prediction.
    """
    predicted, spread = estimate_uncertainty(regressor, encode_sequences(observed_sequences), model_name)
    n = min(len(true_fitness), len(predicted))
    true_fitness, predicted, spread = true_fitness[:n], predicted[:n], spread[:n]
    r2 = r2_score(true_fitness, predicted)

    lo, hi = min(true_fitness), max(true_fitness)
    plt.figure(figsize=(8,6))
    plt.errorbar(true_fitness, predicted, yerr=spread, fmt='o', ecolor='#F5A33C', color='#6083B4', alpha=0.7, capsize=3)
    plt.plot([lo, hi], [lo, hi], 'k--', lw=2)
    plt.xlabel('True Fitness')
    plt.ylabel('Predicted Fitness')
    plt.title(f'True vs Predicted Fitness with Uncertainty\n$R^2$ = {r2:.2f}')
    plt.grid(True)
    plt.show()

def visualize_fitness_landscape(regressor, encode_sequences):
    """Contour-plot the surrogate's mean prediction over a 2-D grid.

    The grid spans [0, 1] x [0, 1] with 50 points per axis. The function reads
    the module-level ``input_dim`` (and bails out above 2) and
    ``surrogate_model_widget``. ``encode_sequences`` is accepted but unused.
    """
    if input_dim > 2:
        print("Fitness landscape visualization not available for high-dimensional input.")
        return
    axis = np.linspace(0, 1, 50)
    xx, yy = np.meshgrid(axis, axis)
    grid = np.column_stack([xx.ravel(), yy.ravel()])
    mean_pred, _ = estimate_uncertainty(regressor, grid, surrogate_model_widget.value)
    surface = mean_pred.reshape(xx.shape)

    plt.figure(figsize=(8,6))
    plt.contourf(xx, yy, surface, cmap='viridis')
    plt.colorbar(label='Predicted Fitness')
    plt.title('Fitness Landscape')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.show()

def mutate_sequence(seq, num_mutations=1):
    """Return a copy of ``seq`` with ``num_mutations`` random point substitutions.

    Positions are drawn with replacement and the replacement residue may equal
    the original, so the result can differ from ``seq`` at fewer sites than
    requested (possibly none).
    """
    residues = list(seq)
    for _ in range(num_mutations):
        # Draw the position before the residue, so RNG consumption order is stable.
        position = np.random.randint(len(residues))
        replacement = np.random.choice(amino_acids)
        residues[position] = replacement
    return ''.join(residues)

def generate_candidate_sequences(sequences, num_candidates_per_seq, num_mutations=1):
    """Return ``num_candidates_per_seq`` random mutants of every parent sequence.

    Mutants are listed parent by parent, in input order; duplicates are kept.
    """
    return [
        mutate_sequence(parent, num_mutations)
        for parent in sequences
        for _ in range(num_candidates_per_seq)
    ]

def train_cnn_regressor(model, X_train, y_train, epochs=10, batch_size=32, lr=0.001):
    """Fit a torch regressor on (X_train, y_train) with Adam + MSE and return it.

    Parameters
    ----------
    model : torch.nn.Module
        Network mapping a (batch, features) float tensor to (batch, 1) or (batch,).
    X_train, y_train : array-like
        Features and targets; converted to float32 tensors.
    epochs, batch_size, lr : optional
        Training schedule (defaults match the previous hard-coded values).

    The model is trained in place, left in training mode, and returned.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    features = torch.as_tensor(np.asarray(X_train, dtype=np.float32))
    targets = torch.as_tensor(np.asarray(y_train, dtype=np.float32)).reshape(-1)
    dataset = torch.utils.data.TensorDataset(features, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for inputs, batch_targets in dataloader:
            optimizer.zero_grad()
            # reshape(-1) instead of squeeze(): squeeze() turns the (1, 1)
            # output of a one-sample batch into a 0-d tensor, which silently
            # broadcasts against the (1,) target.
            loss = criterion(model(inputs).reshape(-1), batch_targets)
            loss.backward()
            optimizer.step()
    return model

class CNNRegressor(nn.Module):
    """Small 1-D CNN regressor whose dropout layer allows MC-dropout uncertainty.

    ``forward`` takes a (batch, input_dim) tensor and treats it as a
    one-channel signal. Two unpadded kernel-3 convolutions shorten it by 4,
    hence the ``32 * (input_dim - 4)`` flattened width. Output shape is
    (batch, 1).
    """

    def __init__(self, input_dim, dropout_rate=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)
        flat_features = 32 * (input_dim - 4)
        self.fc1 = nn.Linear(flat_features, 100)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(100, 1)

    def forward(self, x):
        h = x.unsqueeze(1)  # (batch, 1, input_dim)
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h))
        h = h.view(h.size(0), -1)
        h = self.dropout(F.relu(self.fc1(h)))
        return self.fc2(h)

# Shared widget appearance: labels sized to their text, controls 70% wide.
style = {'description_width': 'initial'}
layout = Layout(width='70%')


def create_bold_label(text):
    """Return an HTML widget showing ``text`` in bold (used as a section header)."""
    markup = "<b>{}</b>".format(text)
    return widgets.HTML(value=markup)

# ---- Section 1: dataset and starting sequences ----
dataset_label = create_bold_label("Dataset and Sequences")

# Single-file CSV upload; parsing happens in the handle_upload callback.
upload_button = widgets.FileUpload(
    description='Upload CSV',
    tooltip='Upload your dataset as a CSV file with "sequence" and "fitness" columns.',
    accept='.csv',
    multiple=False,
    style=style,
    layout=layout,
)
upload_expl = Label("Upload your CSV file with 'sequence' and 'fitness' columns.")

# How many sequences seed the optimisation.
num_sequences_widget = widgets.IntSlider(
    min=1, max=100, step=1, value=10,
    description='Number of Initial Sequences:',
    tooltip='Set the number of initial sequences from the user-provided dataset.',
    style=style, layout=layout,
)

sequence_length_widget = widgets.IntSlider(
    min=5, max=50, step=1, value=10,
    description='Sequence Length:',
    tooltip='Set the length of the protein sequences.',
    style=style, layout=layout,
)

starting_sequences_widget = widgets.Dropdown(
    options=['Random', 'Low Fitness Similar', 'Diverse Across Fitness', 'User-Provided'],
    value='User-Provided',
    description='Starting Sequences:',
    tooltip='Choose the strategy for starting sequences.',
    style=style, layout=layout,
)
starting_sequences_expl = Label("Choose the strategy for starting sequences.")

# ---- Section 2: sequence encoding and surrogate model ----
encoding_label = create_bold_label("Encoding and Models")

encoding_method_widget = widgets.Dropdown(
    options=['One-Hot', 'AAIndex', 'ESM2', 'VAE'],
    value='One-Hot',
    description='Encoding Method:',
    tooltip='Select the sequence encoding method.',
    style=style, layout=layout,
)
encoding_expl = Label("Select the sequence encoding method.")

# Read only by the 'VAE' encoding path.
latent_dim_widget = widgets.IntSlider(
    min=2, max=20, step=1, value=10,
    description='VAE Latent Dimension:',
    tooltip='Set the latent dimension for VAE.',
    style=style, layout=layout,
)
latent_dim_expl = Label("Set the latent dimension for VAE.")

surrogate_model_widget = widgets.Dropdown(
    options=[
        'Gaussian Process', 'Random Forest', 'Extra Trees', 'Neural Network',
        'Deep Neural Network', '1D CNN with MC Dropout', 'XGBoost', 'LightGBM',
        'SVR', 'Bayesian Ridge', 'ARD Regression', 'Gradient Boosting',
    ],
    value='Gaussian Process',
    description='Regressor Model:',
    tooltip='Choose the surrogate regressor model.',
    style=style, layout=layout,
)
surrogate_expl = Label("Choose the surrogate regressor model.")

# Only relevant when the surrogate is a Gaussian Process.
kernel_widget = widgets.Dropdown(
    options=['RBF', 'Matern', 'RationalQuadratic', 'DotProduct'],
    value='RBF',
    description='Kernel Function:',
    tooltip='Select the kernel function for Gaussian Process.',
    style=style, layout=layout,
)
kernel_expl = Label("Select the kernel function for Gaussian Process.")

# ---- Section 3: optimisation loop settings ----
optimization_label = create_bold_label("Optimization Settings")

acquisition_function_widget = widgets.Dropdown(
    options=['Expected Improvement', 'Probability of Improvement', 'Upper Confidence Bound',
             'Thompson Sampling', 'Greedy', 'eUCB'],
    value='Expected Improvement',
    description='Acquisition Function:',
    tooltip='Select the acquisition function for optimization.',
    style=style, layout=layout,
)
acquisition_expl = Label("Select the acquisition function for optimization.")

iterations_widget = widgets.IntSlider(
    min=1, max=50, step=1, value=10,
    description='Number of Iterations:',
    tooltip='Set the number of optimization iterations.',
    style=style, layout=layout,
)
iterations_expl = Label("Set the number of optimization iterations.")

batch_size_widget = widgets.IntSlider(
    min=1, max=10, step=1, value=3,
    description='Batch Size:',
    tooltip='Number of sequences selected per iteration.',
    style=style, layout=layout,
)
batch_size_expl = Label("Number of sequences selected per iteration.")

num_mutations_widget = widgets.IntSlider(
    min=1, max=5, step=1, value=1,
    description='Number of Mutations:',
    tooltip='Number of mutations applied to generate new sequences.',
    style=style, layout=layout,
)
mutations_expl = Label("Number of mutations applied to generate new sequences.")

# ---- Explore/exploit trade-off ----
exploration_label = create_bold_label("Explore-Exploit Parameters")

kappa_widget = widgets.FloatSlider(
    min=0.0, max=10.0, step=0.1, value=2.576,
    description='Kappa (UCB):',
    tooltip='Controls exploration-exploitation trade-off in UCB.',
    style=style, layout=layout,
)
kappa_expl = Label("Controls exploration-exploitation trade-off in UCB.")

xi_widget = widgets.FloatSlider(
    min=0.0, max=1.0, step=0.01, value=0.01,
    description='Xi (EI/PI):',
    tooltip='Controls exploration in EI and PI.',
    style=style, layout=layout,
)
xi_expl = Label("Controls exploration in EI and PI.")

exploration_type_widget = widgets.Dropdown(
    options=['Fixed', 'Adaptive'],
    value='Fixed',
    description='Exploration Type:',
    tooltip='Choose whether exploration parameters are fixed or adaptive.',
    style=style, layout=layout,
)
exploration_type_expl = Label("Choose whether exploration parameters are fixed or adaptive.")

# Button that launches the optimisation, and the pane that captures its output.
run_button = Button(
    description="Run Optimization",
    button_style='success',
    layout=Layout(width='30%', height='40px'),
)

output_area = widgets.Output()

# Most recent successfully parsed upload; populated by handle_upload().
uploaded_sequences = None
uploaded_fitness = None


def _uploaded_files(value):
    """Normalise a FileUpload ``value`` to a list of per-file mappings.

    ipywidgets 8 stores a tuple of file mappings; ipywidgets 7 stored a dict
    keyed by file name.
    """
    if isinstance(value, dict):
        return list(value.values())
    return list(value)


def handle_upload(change):
    """Widget callback: parse the uploaded CSV into the module-level upload globals.

    The CSV must have 'sequence' and 'fitness' columns. On success it sets
    ``uploaded_sequences`` (list) and ``uploaded_fitness`` (numpy array).
    Problems are reported with print() rather than raised, because this runs
    inside a widget event handler.
    """
    global uploaded_sequences, uploaded_fitness
    if not change.new:
        print("No file uploaded.")
        return
    try:
        uploaded_file = _uploaded_files(change.new)[0]
        # ipywidgets 8 delivers the payload as a memoryview, which has no
        # .decode(); bytes() accepts both memoryview and bytes.
        raw = bytes(uploaded_file['content'])
        df = pd.read_csv(io.StringIO(raw.decode('utf-8')))
    except Exception as e:  # UI boundary: report instead of crashing the callback
        print(f"Error loading CSV file: {str(e)}")
        return

    if 'sequence' in df.columns and 'fitness' in df.columns:
        uploaded_sequences = df['sequence'].tolist()
        uploaded_fitness = df['fitness'].values
        print("CSV file loaded successfully.")
    else:
        print("CSV file must contain 'sequence' and 'fitness' columns.")
# Re-parse the CSV whenever the upload widget's value changes.
# (The former `print(uploaded_sequences)` was debug output that always printed None.)
upload_button.observe(handle_upload, names='value')

# Function to run optimization
def run_optimization(b):
    """Run one optimisation session using the current UI settings.

    Click handler for ``run_button``. Everything printed or plotted goes into
    ``output_area``, which is cleared first. Steps:

    1. Build the starting sequences and their fitness according to
       ``starting_sequences_widget``: a simulated random target, or the
       uploaded CSV.
    2. Set the module-level ``input_dim`` for the selected encoding.
    3. Call ``optimize_sequences`` with all widget settings.
    4. Plot the maximum fitness per iteration and print the top five
       sequences. For simulated runs, also plot a histogram of the fitness
       of 500 random sequences scored against the same target.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    # Declared global so the parsed upload is shared with handle_upload's
    # globals instead of silently becoming function locals.
    global uploaded_sequences, uploaded_fitness, input_dim
    with output_area:
        clear_output()
        strategy = starting_sequences_widget.value
        seq_len = sequence_length_widget.value
        target_sequence = None
        user_sequences = None
        user_fitness = None

        # --- 1. Starting population ---------------------------------------
        if strategy == 'Random':
            sequences = generate_sequences(num_sequences_widget.value, seq_len)
        elif strategy == 'Low Fitness Similar':
            # Keep this target: the starting sequences are mutants of it, so
            # fitness below must be scored against this same sequence.
            target_sequence = generate_sequences(1, seq_len)[0]
            sequences = generate_similar_sequences(
                target_sequence, num_sequences_widget.value, low_fitness=True)
        elif strategy == 'Diverse Across Fitness':
            sequences = generate_diverse_sequences(num_sequences_widget.value, seq_len)
        elif strategy == 'User-Provided':
            try:
                parsed = _read_upload_widget()
            except ValueError as err:  # pandas parse errors and UnicodeDecodeError subclass ValueError
                print(f"Error loading CSV file: {err}")
                return
            if parsed is None or len(parsed[0]) == 0:
                print("Please upload a valid CSV file.")
                return
            uploaded_sequences, uploaded_fitness = parsed
            sequences = uploaded_sequences
            fitness = uploaded_fitness
            # np.random.choice(..., replace=False) raises if asked for more
            # rows than exist, so cap the sample at the dataset size.
            n_start = min(num_sequences_widget.value, len(uploaded_sequences))
            if n_start < num_sequences_widget.value:
                print(f"Dataset has only {len(uploaded_sequences)} rows; "
                      f"using all of them as starting sequences.")
            indices = np.random.choice(len(uploaded_sequences), n_start, replace=False)
            user_sequences = [uploaded_sequences[i] for i in indices]
            user_fitness = [uploaded_fitness[i] for i in indices]
            print(f"Max fitness found in the starting sequences {np.max(user_fitness)}")
            print(f"Max fitness in the dataset {np.max(uploaded_fitness)}")
        else:
            print("Invalid starting sequence option.")
            return

        if strategy != 'User-Provided':
            # Only draw a new target if the chosen strategy did not already fix one.
            if target_sequence is None:
                target_sequence = generate_sequences(1, seq_len)[0]
            fitness = compute_fitness(sequences, target_sequence)
            target_fitness = np.asarray(compute_fitness([target_sequence], target_sequence))[0]
            print(f"Target sequence (for simulation): {target_sequence}")
            print(f"Fitness of target sequence: {target_fitness}")

        # --- 2. Encoding dimension (read by the CNN / VAE code) ------------
        encoding = encoding_method_widget.value
        if encoding in ('One-Hot', 'VAE'):
            # The VAE input uses the same flattened one-hot size.
            input_dim = len(amino_acids) * seq_len
        elif encoding == 'AAIndex':
            input_dim = 5 * seq_len  # 5 properties per amino acid
        elif encoding == 'ESM2':
            load_esm2_model()
            input_dim = esm2_model.config.hidden_size

        # --- 3. Optimisation loop ------------------------------------------
        observed_sequences, observed_fitness, max_fitness_over_iterations = optimize_sequences(
            sequences, fitness, encoding, surrogate_model_widget.value,
            acquisition_function_widget.value, kernel_name=kernel_widget.value,
            num_mutations=num_mutations_widget.value,
            iterations=iterations_widget.value, batch_size=batch_size_widget.value,
            target_sequence=target_sequence,
            user_sequences=user_sequences, user_fitness=user_fitness,
            sampling_strategy=strategy,
            train_vae_flag=(encoding == 'VAE'),
            kappa=kappa_widget.value, xi=xi_widget.value,
            exploration_type=exploration_type_widget.value)

        # --- 4. Reporting ----------------------------------------------------
        plt.figure(figsize=(10, 6))
        plt.plot(range(len(max_fitness_over_iterations)), max_fitness_over_iterations,
                 marker='o', color=color_palette[1])
        plt.xlabel('Iteration')
        plt.ylabel('Maximum Fitness Observed')
        plt.title('Optimization Progress')
        plt.grid(True)
        plt.show()

        # Unary minus needs an ndarray; guard against a list being returned.
        observed_fitness = np.asarray(observed_fitness)
        top_indices = np.argsort(-observed_fitness)[:5]
        print("Top sequences found:")
        for idx in top_indices:
            print(f"Sequence: {observed_sequences[idx]}, Fitness: {observed_fitness[idx]}")

        # Ground-truth fitness distribution exists only for simulated targets.
        if target_sequence is not None:
            sequence_space = generate_sequences(500, seq_len)
            fitness_space = compute_fitness(sequence_space, target_sequence)
            plt.figure(figsize=(10, 6))
            plt.hist(fitness_space, bins=20, color=color_palette[0], alpha=0.7)
            plt.xlabel('Fitness')
            plt.ylabel('Frequency')
            plt.title('Ground-Truth Fitness Distribution')
            plt.grid(True)
            plt.show()


def _read_upload_widget():
    """Parse the CSV currently held by ``upload_button``, entirely in memory.

    Returns:
        ``(sequences, fitness)``: the ``sequence`` column as a list and the
        ``fitness`` column as an array. Returns ``None`` when no file is
        uploaded or either column is missing.

    Raises:
        ValueError: if the bytes are not valid UTF-8 or pandas cannot parse them.
    """
    files = upload_button.value
    if not files:
        return None
    # ipywidgets 7 keeps uploads in a {filename: info} dict; ipywidgets 8 uses a tuple of dicts.
    info = next(iter(files.values())) if isinstance(files, dict) else files[0]
    # Parse in memory. Writing the upload to disk under its client-supplied
    # name is unnecessary and would let the uploader choose the write path.
    # bytes() also converts the memoryview that ipywidgets 8 provides.
    text = bytes(info['content']).decode('utf-8-sig')
    df = pd.read_csv(io.StringIO(text))
    if 'sequence' not in df.columns or 'fitness' not in df.columns:
        return None
    return df['sequence'].tolist(), df['fitness'].values

# Functions for generating starting sequences
def generate_similar_sequences(target_sequence, num_sequences, low_fitness=False):
    """Return ``num_sequences`` mutants of ``target_sequence``.

    Args:
        target_sequence: sequence each mutant is derived from.
        num_sequences: number of mutants to generate.
        low_fitness: if True, request ``len(target_sequence) // 2`` mutations
            per mutant from ``mutate_sequence``; otherwise request 1.

    Returns:
        list of the sequences returned by ``mutate_sequence``.
    """
    # Derive the mutation count from the target itself, not from the
    # sequence-length slider, so the function is correct for any target
    # length (e.g. if the slider changed after the target was generated).
    n_mutations = len(target_sequence) // 2 if low_fitness else 1
    return [mutate_sequence(target_sequence, num_mutations=n_mutations)
            for _ in range(num_sequences)]

def generate_diverse_sequences(num_sequences, sequence_length, alphabet=None):
    """Return ``num_sequences`` distinct random sequences of ``sequence_length``.

    Each position is drawn uniformly (with replacement) from ``alphabet``.
    Drawing repeats until enough unique sequences have been collected.

    Args:
        num_sequences: number of distinct sequences to return.
        sequence_length: length of every sequence.
        alphabet: optional iterable of single-character symbols; defaults to
            the module-level ``amino_acids``.

    Returns:
        list[str] of unique sequences, in the order they were first drawn.

    Raises:
        ValueError: if more distinct sequences are requested than the alphabet
            can form. Without this check the loop would never terminate.
    """
    pool = list(amino_acids if alphabet is None else alphabet)
    n_possible = len(set(pool)) ** sequence_length
    if num_sequences > n_possible:
        raise ValueError(
            f"Cannot draw {num_sequences} distinct sequences of length "
            f"{sequence_length}: only {n_possible} exist.")
    # A dict keeps first-draw order. Iterating a set of strings depends on
    # per-process hash randomisation, which breaks seeded reproducibility.
    seen = {}
    while len(seen) < num_sequences:
        seen[''.join(np.random.choice(pool, sequence_length))] = None
    return list(seen)

run_button.on_click(run_optimization)


def _row(*items):
    """Place the given widgets side by side in a horizontal box."""
    return HBox(list(items))


# Control panel, top to bottom: dataset options, encoding/surrogate options,
# optimisation settings, exploration settings, then the run button and output.
_panel_children = [
    dataset_label,
    _row(upload_button, upload_expl),
    _row(starting_sequences_widget, starting_sequences_expl),
    _row(num_sequences_widget, sequence_length_widget),
    encoding_label,
    _row(encoding_method_widget, encoding_expl),
    _row(latent_dim_widget, latent_dim_expl),
    _row(surrogate_model_widget, surrogate_expl),
    _row(kernel_widget, kernel_expl),
    optimization_label,
    _row(acquisition_function_widget, acquisition_expl),
    _row(iterations_widget, iterations_expl),
    _row(batch_size_widget, batch_size_expl),
    _row(num_mutations_widget, mutations_expl),
    exploration_label,
    _row(exploration_type_widget, exploration_type_expl),
    _row(kappa_widget, kappa_expl),
    _row(xi_widget, xi_expl),
    run_button,
    output_area,
]
ui = VBox(_panel_children)

display(ui)


None


VBox(children=(HTML(value='<b>Dataset and Sequences</b>'), HBox(children=(FileUpload(value=(), accept='.csv', …