# Lyra Colab instructions

<details>
  <summary>General Usage</summary>


**Lyra runs significantly faster on a CUDA-enabled GPU runtime. In Colab, select one via "Runtime" → "Change Runtime Type" → "T4 GPU".**



## General usage

### 1. Data Format
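Each input CSV file should contain, at minimum:
1. A 'seq' column with the sequences
2. A target column with the values to predict (for classification, integer class labels from 0 to C-1, where C is the number of classes)

For DNA and protein inputs, all sequences in a file must have the same length; only the RNA encoder pads shorter sequences.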
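
For illustration, the snippet below writes a minimal file in this format with pandas. The sequences, values, the file name `toy_train.csv`, and the target column name `expression` are made-up placeholders, not part of the example dataset.

```python
import pandas as pd

# Hypothetical toy data: a 'seq' column plus one target column ('expression' is a placeholder name)
toy = pd.DataFrame({
    "seq": ["ACGTACGT", "TTGACCGA", "GGCATCAA"],
    "expression": [0.12, 0.87, 0.45],
})
toy.to_csv("toy_train.csv", index=False)

# Reading it back shows the columns the notebook expects
loaded = pd.read_csv("toy_train.csv")
print(loaded.columns.tolist())  # ['seq', 'expression']
```

With a file like this, you would enter `expression` as the Label Column Name in the form.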

### 2. Configure your model and task
Use the form to select:
1. Sequence Type (RNA/DNA/Protein)
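2. Task Type (Regression/Classification) and, for classification, the Number of Classes
3. Data Files (train.csv, val.csv, test.csv)
4. Label Column Name

Then press "Initialize Task" to load the data and initialize the model.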

By default, this notebook uses a train/validation/test split: the model is
evaluated on the validation set after every epoch, and the checkpoint that
performs best on the validation set is saved and then evaluated on the test set.
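
The selection rule above boils down to tracking the best validation score and keeping a copy of the weights from that epoch. The sketch below is a generic, hypothetical helper (`train_with_best_checkpoint` and its two callables are not notebook functions), and the notebook's own loop may track a different metric.

```python
import copy
import torch
import torch.nn as nn

def train_with_best_checkpoint(model, train_one_epoch, evaluate, num_epochs, higher_is_better=True):
    """Keep the weights from the epoch with the best validation score.

    train_one_epoch(model) and evaluate(model) -> float are placeholder
    callables you supply; they are not functions defined in the notebook.
    """
    best_score = float("-inf") if higher_is_better else float("inf")
    best_state = None
    for _ in range(num_epochs):
        train_one_epoch(model)
        score = evaluate(model)  # e.g. validation R^2 or accuracy
        improved = score > best_score if higher_is_better else score < best_score
        if improved:
            best_score = score
            # deepcopy so later updates do not overwrite the saved tensors
            best_state = copy.deepcopy(model.state_dict())
    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best weights before testing
    return best_score

# Toy usage: the "training" step just shifts the bias, and the scores are scripted
toy = nn.Linear(1, 1)
start_bias = toy.bias.item()
scripted_scores = iter([0.2, 0.9, 0.5])

def shift_bias(m):
    with torch.no_grad():
        m.bias += 1.0

best = train_with_best_checkpoint(toy, shift_bias, lambda m: next(scripted_scores), num_epochs=3)
print(best)  # 0.9
```

Copying the `state_dict` matters: `state_dict()` returns tensors that share storage with the live parameters, so without a copy later epochs would overwrite the "best" weights.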

### 3. Press "Play" to train Lyra!

### 4. (Optional) Use the saved Lyra model to make predictions on new data
Once training is done, you can use the trained Lyra model to make predictions on new data. The previous step saves the best Lyra model to "best_model.pt"; the next cell takes an input file with a 'seq' column and uses that model to make predictions on it. Enter the input file name and press "Play".
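
A minimal sketch of what such a prediction step involves, assuming the model maps a `(batch, length, d_input)` tensor to `(batch, d_output)` as `Lyra` does here. `predict`, `MeanPoolLinear`, and `toy_encoder` are hypothetical stand-ins so the example runs on its own; they are not the notebook's code.

```python
import numpy as np
import torch
import torch.nn as nn

@torch.no_grad()
def predict(model, encoder, sequences, batch_size=64, device="cpu"):
    """Run a trained model over new sequences in batches.

    Assumes encoder(list_of_str) returns a (N, L, d_input) float tensor and
    model maps (B, L, d_input) -> (B, d_output).
    """
    model.eval()
    encoded = encoder(list(sequences))
    outputs = []
    for start in range(0, len(encoded), batch_size):
        batch = encoded[start:start + batch_size].to(device)
        outputs.append(model(batch).cpu())
    return torch.cat(outputs).numpy()

class MeanPoolLinear(nn.Module):  # hypothetical stand-in for a trained Lyra model
    def __init__(self, d_input, d_output):
        super().__init__()
        self.linear = nn.Linear(d_input, d_output)

    def forward(self, x):
        return self.linear(x.mean(dim=1))  # average over sequence length, then project

def toy_encoder(seqs):  # hypothetical stand-in for one_hot_encode_dna
    eye = np.eye(4, dtype=np.float32)
    idx = {c: i for i, c in enumerate("ACGT")}
    return torch.tensor(np.array([[eye[idx[c]] for c in s] for s in seqs]))

preds = predict(MeanPoolLinear(4, 1), toy_encoder, ["ACGT", "GGCC", "TTAA"], batch_size=2)
print(preds.shape)  # (3, 1)
```

With the real objects you would pass the trained `model` and the matching encoder. If "best_model.pt" holds a state dict (an assumption, since the saving code is not shown here), it could be loaded first with `model.load_state_dict(torch.load("best_model.pt"))`. For classification, `preds.argmax(axis=1)` gives the predicted class.
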
</details>


<details>
  <summary>Advanced usage</summary>


## Advanced usage
### 1. For protein sequences with non-natural amino acids
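The default code encodes the 20 natural amino acids. To support a larger alphabet,
you must modify the encoding:
1. In `one_hot_encode_protein()`, widen every row of the `mapping` array by one column per additional amino acid (and rescale the `X` row so it stays uniform)
2. Add a new row to the `mapping` array for each additional amino acid
3. Add the new characters to the `char_to_int` string, at the positions matching their new rows
4. Update the `d_input` returned for `'PROTEIN'` by `setup_sequence_model()`, which is passed to the Lyra model, to the new encoding width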
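
As a sketch of these steps, the function below builds the mapping programmatically instead of writing rows by hand. Appending `U` (selenocysteine) and `O` (pyrrolysine) is only an example of extra letters; `one_hot_encode_protein_extended`, `EXTRA`, and the other module-level names are hypothetical.

```python
import numpy as np
import torch

# Hypothetical extension: 'U' and 'O' appended as extra one-hot channels
# after the 20 natural amino acids used by the notebook.
NATURAL = "ILVFMCAGPTSYWQNHEDKR"  # same order as one_hot_encode_protein()
EXTRA = "UO"
ALPHABET = NATURAL + EXTRA
D_INPUT = len(ALPHABET)  # 22: the value d_input would need to become

def one_hot_encode_protein_extended(sequences):
    # Identity rows for every letter in ALPHABET
    mapping = np.eye(D_INPUT, dtype=np.float32)
    # X (unknown): uniform over all letters; J: all zeros, mirroring the original mapping
    mapping = np.vstack([mapping,
                         np.full((1, D_INPUT), 1.0 / D_INPUT, dtype=np.float32),
                         np.zeros((1, D_INPUT), dtype=np.float32)])
    char_to_int = {c: i for i, c in enumerate(ALPHABET + "XJ")}
    seq_indices = [[char_to_int[c] for c in seq] for seq in sequences]
    encoded = np.array([mapping[idx] for idx in seq_indices])  # (N, L, D_INPUT)
    return torch.tensor(encoded, dtype=torch.float32)

batch = one_hot_encode_protein_extended(["ACUO", "MKXJ"])
print(batch.shape)  # torch.Size([2, 4, 22])
```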

### 2. Hyperparameter tuning
The default version of Lyra works well in most use cases. The two hyperparameters
we most often tune are (1) the number of epochs and (2) dropout. The default setup
trains for 50 epochs, which is sufficient for most tasks, with a dropout rate of 0.2.
Internal model dimensions can be changed, and different PGC and S4D dimensions and
counts can be used, but in most cases we do not see a significant improvement over
the default Lyra settings.
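
One simple way to tune these two is a small grid search. In this notebook, dropout is the `dropout` argument of `Lyra(...)` (default 0.2) and the epoch count is the `num_epochs` field of the training cell. The harness below is a hypothetical sketch: `grid_search` and the `train_and_validate` callable you pass to it are not part of the notebook.

```python
import itertools

def grid_search(train_and_validate, grid, higher_is_better=True):
    """Try every combination in `grid`; return (best_params, best_score, all_results).

    train_and_validate(**params) -> float is a placeholder you supply, e.g. a
    function that builds Lyra(..., dropout=params["dropout"]), trains it for
    params["num_epochs"] epochs, and returns a validation metric.
    """
    keys = list(grid)
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        results.append((params, train_and_validate(**params)))
    pick = max if higher_is_better else min
    best_params, best_score = pick(results, key=lambda r: r[1])
    return best_params, best_score, results

# Candidate values around the defaults (50 epochs, dropout 0.2)
grid = {"num_epochs": [30, 50, 80], "dropout": [0.1, 0.2, 0.3]}
```

Each configuration is a full training run, so the cost grows with the product of the list lengths; keeping each list short is the main lever.
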
</details>

In [4]:
#@title Load example data (optional)
import gdown
import shutil
import os
import numpy as np  # used by the optional classification block below
import pandas as pd

# Data for APA Isoform prediction task from Bogard’s dataset
# https://pubmed.ncbi.nlm.nih.gov/31178116/
# Data collected and organized in the BEACON dataset paper:
# https://arxiv.org/abs/2406.10391

folder_url = "https://drive.google.com/drive/folders/1NcEJZ_9C22R-XZT1ZMkTtYOiI0ik0jOj"
gdown.download_folder(folder_url, quiet=True, use_cookies=False)

# Move downloaded files one level up
for file_name in ['test.csv', 'train.csv', 'val.csv']:
    shutil.move(os.path.join('Isoform/', file_name), os.path.join('', file_name))


if False:  # set to True to also create 3-class classification versions of the example data
  # Load the train, validation, and test files
  train_df = pd.read_csv('train.csv')
  val_df = pd.read_csv('val.csv')
  test_df = pd.read_csv('test.csv')


  # Calculate the 33% and 66% percentiles from the training data
  percentile_33 = train_df['proximal_isoform_proportion'].quantile(0.33)
  percentile_66 = train_df['proximal_isoform_proportion'].quantile(0.66)


  # Function to add classification column based on the percentiles
  def add_classification(df, p33, p66):
      conditions = [
          (df['proximal_isoform_proportion'] < p33),
          (df['proximal_isoform_proportion'] >= p33) & (df['proximal_isoform_proportion'] < p66),
          (df['proximal_isoform_proportion'] >= p66)
      ]
      values = [0, 1, 2]
      df['class'] = np.select(conditions, values)
      return df

  # Add classification column to all datasets using the same cutoffs
  train_df = add_classification(train_df, percentile_33, percentile_66)
  val_df = add_classification(val_df, percentile_33, percentile_66)
  test_df = add_classification(test_df, percentile_33, percentile_66)

  # Save the modified datasets
  train_df.to_csv('train_classified.csv', index=False)
  val_df.to_csv('val_classified.csv', index=False)
  test_df.to_csv('test_classified.csv', index=False)
  # To train on these files, enter the *_classified.csv file names, label column 'class', and 3 classes in the configuration form



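The disabled block in the cell above turns the regression label into three classes using cutoffs computed on the training split only, so all three splits share the same thresholds. The same rule can be written more compactly with `np.digitize`; the sketch below is an illustrative alternative (the helper name `add_tertile_class` and the random data are hypothetical), not a change to the notebook.

```python
import numpy as np
import pandas as pd

def add_tertile_class(df, column, p33, p66):
    # np.digitize with bins [p33, p66] returns 0 for x < p33, 1 for p33 <= x < p66,
    # and 2 for x >= p66: the same rule as the np.select version (for non-NaN values)
    out = df.copy()
    out["class"] = np.digitize(out[column].to_numpy(), [p33, p66])
    return out

rng = np.random.default_rng(0)
train = pd.DataFrame({"proximal_isoform_proportion": rng.random(1000)})
p33 = train["proximal_isoform_proportion"].quantile(0.33)
p66 = train["proximal_isoform_proportion"].quantile(0.66)
binned = add_tertile_class(train, "proximal_isoform_proportion", p33, p66)
print(binned["class"].value_counts().sort_index())
```

One difference to keep in mind: a missing (NaN) label falls into class 0 with `np.select` but class 2 with `np.digitize`, so drop missing labels first.
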
In [6]:
#@title Configure Run (press play to the left, then fill in the form and press "Initialize Task")
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange, repeat
from scipy.stats import spearmanr
from tqdm.auto import tqdm

# nn.RMSNorm (used below) requires PyTorch 2.4 or newer
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

initialized = False

class PGC(nn.Module):
    def __init__(self, d_model, expansion_factor=1.0, dropout=0.0):
        super().__init__()

        self.d_model = d_model
        self.expansion_factor = expansion_factor
        expanded_dim = int(d_model * expansion_factor)

        # Depthwise (per-channel) convolution with kernel size 3
        self.conv = nn.Conv1d(expanded_dim,
                              expanded_dim,
                              kernel_size=3,
                              padding=1,
                              groups=expanded_dim)

        # Project to 2 * expanded_dim so the result splits evenly into x and v
        self.in_proj = nn.Linear(d_model, expanded_dim * 2)
        self.out_norm = nn.RMSNorm(d_model, eps=1e-8)
        self.in_norm = nn.RMSNorm(expanded_dim * 2, eps=1e-8)
        self.out_proj = nn.Linear(expanded_dim, d_model)
        # Note: this dropout module is not applied in forward()
        self.dropout = nn.Dropout(dropout)

    def forward(self, u):
        # Input projection and normalization
        xv = self.in_norm(self.in_proj(u))

        # Split projected input into two parts: x and v
        x, v = xv.chunk(2, dim=-1)

        # Depthwise convolution on x
        x_conv = self.conv(x.transpose(-1, -2)).transpose(-1, -2)

        # Gating mechanism
        gate = v * x_conv

        # Output projection and normalization
        x_out = self.out_norm(self.out_proj(gate))

        return x_out


class Lyra(nn.Module):
    def __init__(self, d_input, d_output, d_model, d_state=64, dropout=0.2, transposed=False, **kernel_args):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_model)
        self.pgc1 = PGC(d_model, expansion_factor=0.25, dropout=dropout)
        self.pgc2 = PGC(d_model, expansion_factor=2, dropout=dropout)
        self.s4d = S4D(d_model, d_state=d_state, dropout=dropout, transposed=transposed, **kernel_args)
        self.norm = nn.RMSNorm(d_model)
        self.decoder = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(dropout)

    def forward(self, u):
        x = self.encoder(u)
        x = self.pgc1(x)
        x = self.pgc2(x)
        z = x
        z = self.norm(z)
        x = self.dropout(self.s4d(z)) + x
        x = x.mean(dim=1)
        #x = self.dropout(x)
        x = self.decoder(x)
        return x

class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed: X = rearrange(X, 'b ... d -> b d ...')
            # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying
            mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape
            # mask = self.binomial.sample(mask_shape)
            mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
            X = X * mask * (1.0/(1-self.p))
            if not self.transposed: X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X

class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Generate dt
        H = d_model
        log_dt = torch.rand(H) * (
            math.log(dt_max) - math.log(dt_min)
        ) + math.log(dt_min)

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        log_A_real = torch.log(0.5 * torch.ones(H, N//2))
        A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (..., c, L) where c is number of channels (default 1)
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt) # (H)
        C = torch.view_as_complex(self.C) # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)
        C = C * (torch.exp(dtA)-1.) / A
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None: optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)


class S4D(nn.Module):
    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        self.D = nn.Parameter(torch.randn(self.h))
        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)
        # Pointwise
        self.activation = nn.GELU()
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2*self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
        """ Input and output shape (B, H, L) """
        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)
        # Compute SSM Kernel
        k = self.kernel(L=L) # (H L)

        # Convolution
        k_f = torch.fft.rfft(k, n=2*L)  # (H, L+1); padding to 2L avoids circular wrap-around
        u_f = torch.fft.rfft(u, n=2*L)  # (B, H, L+1)
        y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L]  # (B, H, L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed: y = y.transpose(-1, -2)
        return y

def one_hot_encode_sequences_RNA(sequences):
    """
    Vectorized one-hot encoding of sequences with padding.

    Args:
    - sequences (list of str): List of sequences (e.g., sgrna or target).

    Returns:
    - Tensor: The one-hot encoded sequences as a batch.
    """
    # Define the one-hot encoding for RNA/DNA bases
    mapping = np.array([
        [1, 0, 0, 0],  # A
        [0, 1, 0, 0],  # T
        [0, 0, 1, 0],  # C
        [0, 0, 0, 1],  # G
        [0.0, 0.0, 0.0, 0.0],  # N (unknown base or padding)
        [0.0, 0.0, 0.0, 0.0]  # X (unknown base or padding)

    ])
    char_to_int = {c: i for i, c in enumerate('ATCGNX')}  # Map each base to an index
    char_to_int['U'] = char_to_int['T']  # Treat RNA uracil (U) like T

    # Find the maximum sequence length
    max_length = max(len(seq) for seq in sequences)

    # Pad sequences with 'N' to the max length
    padded_sequences = [seq.ljust(max_length, 'N') for seq in sequences]

    # Vectorized conversion of sequences to indices
    seq_indices = [[char_to_int[char] for char in seq] for seq in padded_sequences]
    encoded = np.array([mapping[seq] for seq in seq_indices])

    return torch.tensor(encoded, dtype=torch.float32)

# One hot encoding for DNA Datasets

def one_hot_encode_dna(sequences):
    """
    Vectorized one-hot encoding of DNA sequences.

    Args:
    - sequences (list of str): List of DNA sequences.

    Returns:
    - Tensor: The one-hot encoded DNA sequences as a batch.
    """
    # Define the mapping in a vectorized form
    mapping = np.array([
        [1, 0, 0, 0],  # A
        [0, 1, 0, 0],  # C
        [0, 0, 1, 0],  # G
        [0, 0, 0, 1],  # T
        [0.25, 0.25, 0.25, 0.25],  # N (unknown base)
    ])
    char_to_int = {c: i for i, c in enumerate('ACGTN')}

    # Vectorized conversion of sequences to indices
    seq_indices = [[char_to_int.get(char.upper(), 4) for char in seq] for seq in sequences]
    encoded = np.array([mapping[seq] for seq in seq_indices])

    return torch.tensor(encoded, dtype=torch.float32)



# One hot encoding for Protein Datasets

def one_hot_encode_protein(sequences):
    """
    Vectorized one-hot encoding of protein sequences. All sequences must have
    the same length; no padding is applied ('J' encodes as all zeros).

    Args:
    - sequences (list of str): List of protein sequences.

    Returns:
    - Tensor: The one-hot encoded protein sequences as a batch.
    """
    # Define the mapping in a vectorized form
    # NB: depending on the dataset (e.g. if it contains non-natural amino acids),
    # you may have to change the one-hot encoding mapping
    mapping = np.array([
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # I
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # L
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # V
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # F
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # M
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # C
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # A
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # G
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # P
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # T
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # S
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # Y
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # W
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # Q
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],  # N
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],  # H
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # E
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # D
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],  # K
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # R
        [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05],  # X
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # J
    ])
    char_to_int = {c: i for i, c in enumerate('ILVFMCAGPTSYWQNHEDKRXJ')}

    # Vectorized conversion of sequences to indices
    seq_indices = [[char_to_int[char] for char in seq] for seq in sequences]
    encoded = np.array([mapping[seq] for seq in seq_indices])

    return torch.tensor(encoded, dtype=torch.float32)

class RNADataset(Dataset):
    def __init__(self, encoded_sequences, labels):
        """
        Initialize the dataset with preprocessed RNA data.

        Args:
        - encoded_sequences (Tensor): Encoded RNA sequences.
        - labels (Tensor): Corresponding labels.
        """
        self.encoded_sequences = encoded_sequences
        self.labels = labels

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return self.encoded_sequences[index], torch.tensor(self.labels[index], dtype=torch.float32)


# Dataset wrapper for DNA sequences
class DNADataset(Dataset):
    def __init__(self, encoded_sequences, labels):
        """
        Initialize the dataset with preprocessed data.

        Args:
        - encoded_sequences (Tensor): Encoded sequences.
        - labels (Tensor): Corresponding labels.
        """
        self.encoded_sequences = encoded_sequences
        self.labels = labels

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return self.encoded_sequences[index], torch.tensor(self.labels[index], dtype=torch.float32)


class ProteinDataset(Dataset):
    def __init__(self, encoded_sequences, labels):
        """
        Initialize the dataset with preprocessed data.

        Args:
        - encoded_sequences (Tensor): Encoded sequences.
        - labels (Tensor): Corresponding labels.
        """
        self.encoded_sequences = encoded_sequences
        self.labels = labels

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return self.encoded_sequences[index], torch.tensor(self.labels[index], dtype=torch.float32)

from ipywidgets import widgets
from IPython.display import display, clear_output

def setup_sequence_model(sequence_type):
    """
    Configure the appropriate dataset class and encoder based on sequence type.

    Args:
    - sequence_type (str): One of 'RNA', 'DNA', or 'Protein'

    Returns:
    - tuple: (dataset_class, encoder_function, d_input)
    """
    if sequence_type.upper() == 'RNA':
        return RNADataset, one_hot_encode_sequences_RNA, 4
    elif sequence_type.upper() == 'DNA':
        return DNADataset, one_hot_encode_dna, 4
    elif sequence_type.upper() == 'PROTEIN':
        return ProteinDataset, one_hot_encode_protein, 20
    else:
        raise ValueError("sequence_type must be one of: 'RNA', 'DNA', 'Protein'")

def load_and_prepare_data(train_file, val_file, test_file, label_name, Dataset, encoder):
    """
    Load and prepare data for training.

    Args:
    - train_file (str): Path to training data CSV
    - val_file (str): Path to validation data CSV
    - test_file (str): Path to test data CSV
    - label_name (str): Name of the label column in CSVs
    - Dataset (class): Dataset class to use
    - encoder (function): Encoding function to use

    Returns:
    - tuple: (train_loader, val_loader, test_loader)
    """
    # Load data
    try:
        train_df = pd.read_csv(train_file)
        val_df = pd.read_csv(val_file)
        test_df = pd.read_csv(test_file)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Could not find one or more data files. Please check the file paths: {str(e)}")
    except Exception as e:
        raise Exception(f"Error loading data files: {str(e)}")

    # Verify label column exists
    if label_name not in train_df.columns:
        raise ValueError(f"Label column '{label_name}' not found in data files. Available columns: {train_df.columns.tolist()}")

    # Encode sequences and prepare labels
    train_seqs = encoder(train_df['seq'].values)
    val_seqs = encoder(val_df['seq'].values)
    test_seqs = encoder(test_df['seq'].values)

    train_labels = train_df[label_name].values
    val_labels = val_df[label_name].values
    test_labels = test_df[label_name].values

    # Create datasets
    train_dataset = Dataset(train_seqs, train_labels)
    val_dataset = Dataset(val_seqs, val_labels)
    test_dataset = Dataset(test_seqs, test_labels)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    print(f"\nData loaded successfully:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

def create_setup_form():
    """Creates and returns the setup form widgets"""
    # Create widgets
    sequence_type_dropdown = widgets.Dropdown(
        options=['RNA', 'DNA', 'Protein'],
        value='RNA',
        description='Sequence Type:',
        style={'description_width': 'initial'}
    )

    task_type_dropdown = widgets.Dropdown(
        options=['Regression', 'Classification'],
        value='Regression',
        description='Task Type:',
        style={'description_width': 'initial'}
    )

    num_classes_input = widgets.IntText(
        value=2,
        min=2,
        description='Number of Classes:',
        style={'description_width': 'initial'},
        layout={'visibility': 'hidden'}  # Hidden by default for regression
    )

    train_file_input = widgets.Text(
        value='train.csv',
        description='Training Data File:',
        style={'description_width': 'initial'}
    )

    val_file_input = widgets.Text(
        value='val.csv',
        description='Validation Data File:',
        style={'description_width': 'initial'}
    )

    test_file_input = widgets.Text(
        value='test.csv',
        description='Testing Data File:',
        style={'description_width': 'initial'}
    )

    label_name_input = widgets.Text(
        value='proximal_isoform_proportion',
        description='Label Column Name:',
        style={'description_width': 'initial'}
    )

    # Function to handle task type changes
    def on_task_change(change):
        if change['new'] == 'Classification':
            num_classes_input.layout.visibility = 'visible'
        else:
            num_classes_input.layout.visibility = 'hidden'

    task_type_dropdown.observe(on_task_change, names='value')

    # Button to confirm selections
    confirm_button = widgets.Button(description="Initialize Task")

    def on_button_click(b):
        clear_output()
        # Re-display the form
        display(widgets.VBox([
            sequence_type_dropdown,
            task_type_dropdown,
            num_classes_input,
            train_file_input,
            val_file_input,
            test_file_input,
            label_name_input,
            confirm_button
        ]))

        try:
            # Get configuration
            global sequence_type, task_type
            sequence_type = sequence_type_dropdown.value
            task_type = task_type_dropdown.value
            train_file = train_file_input.value
            val_file = val_file_input.value
            test_file = test_file_input.value
            label_name = label_name_input.value

            # Set up model configuration
            Dataset, encoder, d_input = setup_sequence_model(sequence_type)
            global d_output
            d_output = num_classes_input.value if task_type == 'Classification' else 1

            # Print configuration
            print(f"\nSelected Configuration:")
            print(f"Sequence Type: {sequence_type} (d_input = {d_input})")
            print(f"Task Type: {task_type} (d_output = {d_output})")
            print(f"\nData Files:")
            print(f"Training: {train_file}")
            print(f"Validation: {val_file}")
            print(f"Testing: {test_file}")
            print(f"Label Column: {label_name}")

            # Load and prepare data
            global train_loader, val_loader, test_loader
            print(f"\nLoading data, this may take a while - " +
              "do not re-press \"Initialize Task\"")

            train_loader, val_loader, test_loader = load_and_prepare_data(
                train_file, val_file, test_file, label_name, Dataset, encoder
            )

            # Calculate class weights for classification tasks
            global class_weights
            if task_type == 'Classification':
                # Get all labels from the training set
                all_labels = []
                for _, labels in train_loader:
                    all_labels.append(labels)
                all_labels = torch.cat(all_labels)

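                # Count occurrences of each class; minlength keeps one entry per class
                # even if the highest-index classes are absent from the training set
                train_counts = torch.bincount(all_labels.long(), minlength=d_output)
                total_samples = len(all_labels)

                # Calculate weights (inverse of frequency); a class with no
                # training samples would get an infinite weight
                class_weights = (total_samples / train_counts.float()).to(device)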

                print(f"\nClass distribution: {train_counts.tolist()}")
                print(f"Class weights: {class_weights.tolist()}")
            else:
                class_weights = None


            # Initialize model
            global model  # Make model accessible outside this function
            model = Lyra(d_input=d_input, d_output=d_output, d_model=64).to(device)
            print("\nModel initialized successfully!")
            global initialized
            initialized = True
            # print(model)
            # print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

        except Exception as e:
            print(f"\nError: {str(e)}")
            print("\nPlease check your configuration and try again.")

    confirm_button.on_click(on_button_click)

    # Return the form
    return widgets.VBox([
        sequence_type_dropdown,
        task_type_dropdown,
        num_classes_input,
        train_file_input,
        val_file_input,
        test_file_input,
        label_name_input,
        confirm_button
    ])

form = create_setup_form()
display(form)


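In `S4D.forward`, the state-space model is applied as a long convolution computed with FFTs, using the kernel from `S4DKernel`. The sketch below re-derives that kernel for small sizes and checks that the FFT convolution matches running the diagonal state-space recurrence step by step, assuming the zero-order-hold reading of the `(exp(dtA) - 1) / A` factor. It copies the kernel formula rather than importing the notebook's classes, and leaves out the `D` skip term, activation, and output projection.

```python
import math
import torch

torch.manual_seed(0)
H, N, L = 3, 8, 16  # channels, state size (N // 2 complex modes), sequence length

# Parameters initialized as in S4DKernel, in float64 for a tight numerical check
log_dt = torch.rand(H, dtype=torch.float64) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
dt = torch.exp(log_dt)                                                     # (H,)
A_imag = math.pi * torch.arange(N // 2, dtype=torch.float64).repeat(H, 1)  # (H, N/2)
A = -0.5 + 1j * A_imag                                                     # (H, N/2), complex
C = torch.randn(H, N // 2, dtype=torch.cdouble)                            # (H, N/2)

# Kernel, following the formula in S4DKernel.forward
dtA = A * dt.unsqueeze(-1)                                   # (H, N/2)
C_scaled = C * (torch.exp(dtA) - 1.0) / A                    # C times the ZOH input term B_bar
powers = torch.exp(dtA.unsqueeze(-1) * torch.arange(L, dtype=torch.float64))  # (H, N/2, L)
K = 2 * torch.einsum("hn,hnl->hl", C_scaled, powers).real    # (H, L)

# Convolution via FFT, following S4D.forward (no D term, activation, or output layer)
u = torch.randn(2, H, L, dtype=torch.float64)                # (B, H, L)
y_fft = torch.fft.irfft(torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(K, n=2 * L), n=2 * L)[..., :L]

# The same outputs from the discretized recurrence, one step at a time
A_bar = torch.exp(dtA)
B_bar = (A_bar - 1.0) / A
x = torch.zeros(2, H, N // 2, dtype=torch.cdouble)
steps = []
for k in range(L):
    x = A_bar * x + B_bar * u[..., k].unsqueeze(-1)  # x_k = A_bar x_{k-1} + B_bar u_k
    steps.append(2 * (C * x).sum(dim=-1).real)       # y_k = 2 Re(C x_k), folding in conjugate pairs
y_rec = torch.stack(steps, dim=-1)                   # (B, H, L)

print(torch.allclose(y_fft, y_rec, atol=1e-10))  # True
```

The FFT route costs O(L log L) per channel instead of a length-L sequential loop, which is why the model uses it during training.
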
In [None]:
#@title Check if you're using a CUDA-enabled GPU (you can do this by selecting "Runtime" → "Change Runtime Type" → "T4 GPU")
import torch
if torch.cuda.is_available():
    print("✅ CUDA is available! You're using a GPU-enabled setup.")
else:
    print("❌ CUDA is NOT available. You're running on CPU - this will be much slower")

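The training cell below macro-averages one-vs-rest specificity from the confusion matrix with an explicit loop. The same bookkeeping can be written in vectorized form: for class i, FP is column i minus the diagonal, FN is row i minus the diagonal, and TN is everything else. This is a hypothetical helper (`per_class_specificity`) for illustration, using a small made-up label set.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def per_class_specificity(y_true, y_pred, num_classes):
    # Rows are true classes, columns are predicted classes
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as i but actually another class
    fn = cm.sum(axis=1) - tp  # actually i but predicted as another class
    tn = cm.sum() - tp - fp - fn
    denom = tn + fp
    # Fall back to 0 when a class has no negatives, like the loop in the training cell
    return np.divide(tn, denom, out=np.zeros(num_classes), where=denom > 0)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
spec = per_class_specificity(y_true, y_pred, num_classes=3)
print(spec.tolist())  # [0.75, 0.75, 1.0]
```

Passing `labels=np.arange(num_classes)` keeps the matrix square even when a class never appears in a batch of predictions.
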
In [None]:
#@title Train Lyra!
print_results_from_every_epoch = False #@param {type:"boolean"}
num_epochs = 50 #@param {type:"integer"}



class ModelNotInitializedError(Exception): pass
if not initialized: raise ModelNotInitializedError(
    "You never initialized the model in the previous cell")

# Initialize a list to store performance metrics
lyra_metrics = []

from scipy.stats import spearmanr
from sklearn.metrics import r2_score, accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix

def calculate_regression_metrics(y_true, y_pred):
    # Calculate R^2 score and Spearman's rank correlation coefficient
    r2 = r2_score(y_true, y_pred)
    spearman = spearmanr(y_true, y_pred).correlation
    return r2, spearman

def calculate_classification_metrics(y_true, y_pred, y_scores):
    accuracy = accuracy_score(y_true, y_pred)

    # Check if binary or multiclass classification
    num_classes = len(np.unique(y_true))

    if num_classes == 2:
        # Binary classification metrics
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
        specificity = tn / (tn + fp)
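        # roc_auc_score expects a 1-D score for the positive class in the binary case
        pos_scores = y_scores[:, 1] if np.ndim(y_scores) == 2 else y_scores
        auc_roc = roc_auc_score(y_true, pos_scores)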
        true_positive_rate = tp / (tp + fn)
        return accuracy, specificity, f1, auc_roc, recall, true_positive_rate
    else:
        # Multiclass classification metrics
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
        # For multiclass, calculate a macro-averaged specificity
        cm = confusion_matrix(y_true, y_pred)
        specificities = []
        true_positive_rates = []

        # Calculate specificity for each class
        for i in range(num_classes):
            true_neg = np.sum(cm) - np.sum(cm[i, :]) - np.sum(cm[:, i]) + cm[i, i]
            false_pos = np.sum(cm[:, i]) - cm[i, i]
            false_neg = np.sum(cm[i, :]) - cm[i, i]
            true_pos = cm[i, i]

            # Avoid division by zero
            if (true_neg + false_pos) > 0:
                specificities.append(true_neg / (true_neg + false_pos))
            else:
                specificities.append(0)

            if (true_pos + false_neg) > 0:
                true_positive_rates.append(true_pos / (true_pos + false_neg))
            else:
                true_positive_rates.append(0)

        # Macro-average the metrics
        specificity = np.mean(specificities)
        true_positive_rate = np.mean(true_positive_rates)

        # For multiclass ROC AUC, use one-vs-rest approach if y_scores has probabilities for each class
        if isinstance(y_scores, np.ndarray) and y_scores.ndim == 2 and y_scores.shape[1] == num_classes:
            try:
                auc_roc = roc_auc_score(y_true, y_scores, multi_class='ovr')
            except:
                auc_roc = 0  # Fallback if ROC AUC calculation fails
        else:
            auc_roc = 0  # Not applicable if scores aren't available for all classes

        return accuracy, specificity, f1, auc_roc, recall, true_positive_rate

# Set criterion based on task type
if task_type == 'Classification':
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
else:
    criterion = nn.MSELoss()

# Train the Lyra model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
best_lyra_loss = float('inf')
best_lyra_r2 = float('-inf')  # Track the best R^2 score
best_model_state = None

tqdm_bar = tqdm(range(num_epochs), desc='Training Progress')

for epoch in tqdm_bar:
    model.train()
    lyra_train_loss = 0
    lyra_train_true = []
    lyra_train_pred = []
    lyra_train_scores = []

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        lyra_outputs = model(inputs)
        if task_type == 'Classification':
            labels = labels.long()
            lyra_loss = criterion(lyra_outputs, labels)
        else:
            # squeeze(-1) keeps the batch dimension even for a final batch of size 1
            lyra_loss = criterion(lyra_outputs.squeeze(-1), labels)

        optimizer.zero_grad()
        lyra_loss.backward()
        optimizer.step()

        lyra_train_loss += lyra_loss.item()
        if task_type == 'Classification':
            _, predicted = torch.max(lyra_outputs.data, 1)
            lyra_train_true.extend(labels.cpu().numpy())
            lyra_train_pred.extend(predicted.cpu().numpy())
            # Store all class probabilities for multiclass
            lyra_train_scores.extend(torch.softmax(lyra_outputs, dim=1).cpu().detach().numpy())
        else:
            lyra_train_true.append(labels.detach().cpu().numpy())
            lyra_train_pred.append(lyra_outputs.squeeze(-1).detach().cpu().numpy())

    # Evaluate on validation set
    model.eval()
    lyra_loss = 0
    lyra_true = []
    lyra_pred = []
    lyra_scores = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            lyra_outputs = model(inputs)
            if task_type == 'Classification':
                labels = labels.long()
                lyra_loss = criterion(lyra_outputs, labels)
            else:
                lyra_loss = criterion(lyra_outputs.squeeze(), labels)

            lyra_loss += lyra_loss.item()
            if task_type == 'Classification':
                _, predicted = torch.max(lyra_outputs.data, 1)
                lyra_true.extend(labels.cpu().numpy())
                lyra_pred.extend(predicted.cpu().numpy())
                # Store all class probabilities for multiclass
                lyra_scores.extend(torch.softmax(lyra_outputs, dim=1).cpu().detach().numpy())
            else:
                lyra_true.append(labels.detach().cpu().numpy())
                # reshape(-1) keeps a final batch of size 1 one-dimensional, so np.concatenate works
                lyra_pred.append(lyra_outputs.detach().cpu().numpy().reshape(-1))

        if task_type == 'Classification':
            lyra_true = np.array(lyra_true)
            lyra_pred = np.array(lyra_pred)
            lyra_scores = np.array(lyra_scores)
        else:
            lyra_true = np.concatenate(lyra_true)
            lyra_pred = np.concatenate(lyra_pred)

    # Calculate and print validation statistics for this epoch
    if task_type == 'Classification':
        epoch_accuracy, epoch_specificity, epoch_f1, epoch_auc_roc, epoch_recall, epoch_true_positive_rate = calculate_classification_metrics(lyra_true, lyra_pred, lyra_scores)
        if print_results_from_every_epoch:
            tqdm.write(f"\nEpoch {epoch+1} Validation Statistics:")
            tqdm.write(f"Validation Loss: {lyra_loss/len(val_loader):.8f}")
            tqdm.write(f"Accuracy: {epoch_accuracy:.4f}, Specificity: {epoch_specificity:.4f}, F1 Score: {epoch_f1:.4f}, AUC-ROC: {epoch_auc_roc:.4f}, Recall: {epoch_recall:.4f}, True Positive Rate: {epoch_true_positive_rate:.4f}")
    else:
        epoch_r2, epoch_spearman = calculate_regression_metrics(lyra_true, lyra_pred)
        if print_results_from_every_epoch:
            tqdm.write(f"\nEpoch {epoch+1} Validation Statistics:")
            tqdm.write(f"Validation Loss: {lyra_loss/len(val_loader):.8f}")
            tqdm.write(f"R^2 Score: {epoch_r2:.4f}")
            tqdm.write(f"Spearman's Rank Correlation: {epoch_spearman:.4f}")

    # Update tqdm description with current epoch metrics
    if task_type == 'Classification':
        tqdm_bar.set_description(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {lyra_train_loss/len(train_loader):.8f}, Val Loss: {lyra_loss/len(val_loader):.8f}, Val Accuracy: {epoch_accuracy:.4f}, Val F1: {epoch_f1:.4f}')
    else:
        tqdm_bar.set_description(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {lyra_train_loss/len(train_loader):.8f}, Val Loss: {lyra_loss/len(val_loader):.8f}, Val R^2: {epoch_r2:.4f}, Val Spearman: {epoch_spearman:.4f}')

    # Save the best model based on Validation performance
    if task_type == 'Classification':
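        # Select the best model by accuracy for multiclass tasks
        # and by true positive rate (recall) for binary tasks
        num_classes = len(np.unique(lyra_true))
        if num_classes == 2:
            current_metric = epoch_true_positive_rate
        else:
            current_metric = epoch_accuracy

        if current_metric > best_lyra_r2:
            best_lyra_r2 = current_metric
            # Clone the tensors: state_dict() shares storage with the live parameters,
            # so an uncloned snapshot would be overwritten by later optimizer steps
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        if epoch_r2 > best_lyra_r2:
            best_lyra_r2 = epoch_r2
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}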

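    # Track the lowest mean validation loss seen in any epoch
    epoch_val_loss = lyra_loss / len(val_loader)
    if epoch_val_loss < best_lyra_loss:
        best_lyra_loss = epoch_val_loss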

# Load the best model and evaluate on test set
model.load_state_dict(best_model_state)
model.eval()
lyra_test_loss = 0
lyra_test_true = []
lyra_test_pred = []
lyra_test_scores = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Final Test Evaluation'):
        inputs, labels = inputs.to(device), labels.to(device)
        lyra_outputs = model(inputs)
        if task_type == 'Classification':
            labels = labels.long()
            lyra_loss = criterion(lyra_outputs, labels)
        else:
            lyra_loss = criterion(lyra_outputs.squeeze(), labels)

        lyra_test_loss += lyra_loss.item()
        if task_type == 'Classification':
            _, predicted = torch.max(lyra_outputs.data, 1)
            lyra_test_true.extend(labels.cpu().numpy())
            lyra_test_pred.extend(predicted.cpu().numpy())
            # Store all class probabilities for multiclass
            lyra_test_scores.extend(torch.softmax(lyra_outputs, dim=1).cpu().detach().numpy())
        else:
            lyra_test_true.append(labels.detach().cpu().numpy())
            # reshape(-1) keeps a final batch of size 1 one-dimensional, so np.concatenate works
            lyra_test_pred.append(lyra_outputs.detach().cpu().numpy().reshape(-1))

    if task_type == 'Classification':
        lyra_test_true = np.array(lyra_test_true)
        lyra_test_pred = np.array(lyra_test_pred)
        lyra_test_scores = np.array(lyra_test_scores)
    else:
        lyra_test_true = np.concatenate(lyra_test_true)
        lyra_test_pred = np.concatenate(lyra_test_pred)

# Calculate final test metrics
if task_type == 'Classification':
    final_test_accuracy, final_test_specificity, final_test_f1, final_test_auc_roc, final_test_recall, final_test_true_positive_rate = calculate_classification_metrics(lyra_test_true, lyra_test_pred, lyra_test_scores)
    final_test_loss = lyra_test_loss / len(test_loader)

    # Check if binary or multiclass for final metrics
    num_classes = len(np.unique(lyra_test_true))
    if num_classes == 2:
        # Binary classification metrics
        lyra_metrics.append({
            '_val_loss': best_lyra_loss,
            'best_val_metric': best_lyra_r2,
            'test_loss': final_test_loss,
            'test_accuracy': final_test_accuracy,
            'test_specificity': final_test_specificity,
            'test_f1': final_test_f1,
            'test_auc_roc': final_test_auc_roc,
            'test_recall': final_test_recall,
            'test_true_positive_rate': final_test_true_positive_rate
        })
    else:
        # Multiclass classification metrics
        # Include class-specific metrics
        class_precision, class_recall, class_f1, class_support = precision_recall_fscore_support(
            lyra_test_true, lyra_test_pred, average=None
        )

        # Create a dictionary with class-specific metrics
        class_metrics = {}
        for i in range(num_classes):
            class_metrics[f'class_{i}_precision'] = class_precision[i]
            class_metrics[f'class_{i}_recall'] = class_recall[i]
            class_metrics[f'class_{i}_f1'] = class_f1[i]
            class_metrics[f'class_{i}_support'] = class_support[i]

        # Add confusion matrix
        cm = confusion_matrix(lyra_test_true, lyra_test_pred)

        # Combine all metrics
        lyra_metrics.append({
            '_val_loss': best_lyra_loss,
            'best_val_metric': best_lyra_r2,
            'test_loss': final_test_loss,
            'test_accuracy': final_test_accuracy,
            'test_macro_f1': final_test_f1,
            'test_macro_recall': final_test_recall,
            'test_macro_specificity': final_test_specificity,
            'test_macro_true_positive_rate': final_test_true_positive_rate,
            'test_auc_roc': final_test_auc_roc,
            'confusion_matrix': cm.tolist(),
            **class_metrics  # Include all class-specific metrics
        })
else:
    final_test_r2, final_test_spearman = calculate_regression_metrics(lyra_test_true, lyra_test_pred)
    final_test_loss = lyra_test_loss / len(test_loader)
    lyra_metrics.append({
        '_val_loss': best_lyra_loss,
        'best_val_r2': best_lyra_r2,
        'test_loss': final_test_loss,
        'test_r2': final_test_r2,
        'test_spearman': final_test_spearman
    })

# Save the best model
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
global model_name
model_name = f"best_lyra_model_{task_type}_{sequence_type}_{current_time}.pt"
torch.save({'model_state_dict': model.state_dict(), 'sequence_type': sequence_type}, model_name)
print(f"Model saved as: {model_name}")
display(lyra_metrics)
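The training loop above keeps the weights from the epoch with the best validation metric. In PyTorch, `model.state_dict()` returns tensors that share storage with the live parameters. A snapshot that is stored without copying will therefore be overwritten by every later `optimizer.step()`. The sketch below is pure Python and illustrates the effect. A plain dictionary stands in for the model weights, and both `train_and_track` and the scores are hypothetical, not part of Lyra.

```python
import copy

def train_and_track(val_scores, snapshot=True):
    """Simulate keeping the best-scoring weights across epochs.

    val_scores[i] is a hypothetical validation score after epoch i
    (higher is better). `weights` is mutated in place each epoch, the way
    optimizer.step() updates the tensors that model.state_dict() refers to.
    """
    weights = {"w": 0}  # hypothetical stand-in for the model parameters
    best_score, best_state = float("-inf"), None
    for epoch, score in enumerate(val_scores):
        weights["w"] = epoch  # in-place "training" update
        if score > best_score:
            best_score = score
            # Copy the snapshot; otherwise best_state aliases `weights`
            # and silently follows every later update.
            best_state = copy.deepcopy(weights) if snapshot else weights
    return best_score, best_state

print(train_and_track([0.2, 0.9, 0.5]))                  # (0.9, {'w': 1})
print(train_and_track([0.2, 0.9, 0.5], snapshot=False))  # (0.9, {'w': 2})
```

The uncopied version reports the weights from the last epoch even though the best score came from epoch 1. In PyTorch, you can take the equivalent snapshot with `copy.deepcopy(model.state_dict())` or by cloning each tensor: `{k: v.detach().clone() for k, v in model.state_dict().items()}`.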

#@title Use trained model to predict new data
from google.colab import files

class PredictionDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

def load_model(model_path):
    """Load the saved model and get its configuration"""
    global task_type  # Use the global task_type variable

    # map_location lets a checkpoint saved on a GPU runtime load on a CPU runtime
    checkpoint = torch.load(model_path, map_location=device)
    sequence_type = checkpoint['sequence_type']

    # Get encoder and initialize model
    _, encoder, d_input = setup_sequence_model(sequence_type)
    global d_output
    # d_model must match the value used during training, or load_state_dict will fail
    model = Lyra(d_input=d_input, d_output=d_output, d_model=64).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Model loaded successfully! Sequence type: {sequence_type}, Task type: {task_type}")
    return model, encoder

def predict_sequences(model_path, input_path, output_path, batch_size=64):
    """Make predictions and save to output file"""
    global task_type  # Use the global task_type variable

    # Load model
    model, encoder = load_model(model_path)
    model.eval()

    # Load input data
    df = pd.read_csv(input_path)

    # Encode sequences
    encoded_seqs = encoder(df['seq'].values)

    # Create dataset and dataloader
    dataset = PredictionDataset(encoded_seqs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Make predictions in batches
    predictions = []
    scores = []  # For classification probabilities
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Making predictions'):
            batch = batch.to(device)
            outputs = model(batch)

            if task_type == 'Classification':
                # Handle both binary and multiclass classification
                if d_output == 1:  # Binary classification
                    # Get probabilities for binary classification
                    probs = torch.sigmoid(outputs)
                    scores.append(probs.cpu().numpy())
                    # Convert to binary predictions
                    preds = (probs > 0.5).float()
                else:  # Multiclass classification
                    # Get probabilities using softmax for multiclass
                    probs = torch.softmax(outputs, dim=1)
                    scores.append(probs.cpu().numpy())
                    # Get class with highest probability
                    preds = torch.argmax(probs, dim=1)
                predictions.append(preds.cpu().numpy())
            else:
                # For regression, use outputs directly
                predictions.append(outputs.cpu().numpy())

    # Concatenate all predictions
    predictions = np.concatenate(predictions)

    # Save predictions
    df['prediction'] = predictions

    # Add probability scores for classification
    if task_type == 'Classification':
        scores = np.concatenate(scores)
        if d_output == 1:  # Binary classification
            df['probability'] = scores
        else:  # Multiclass classification
            # Add probability for each class
            for i in range(d_output):
                df[f'probability_class_{i}'] = scores[:, i]

    df.to_csv(output_path, index=False)

    # Trigger download
    files.download(output_path)

    return predictions

# Create and display the form
model_file = widgets.Text(
    value=model_name,
    description='Model File:',
    style={'description_width': 'initial'}
)

input_file = widgets.Text(
    value='test.csv',
    description='Input File:',
    style={'description_width': 'initial'}
)

output_file = widgets.Text(
    value='predictions.csv',
    description='Output File:',
    style={'description_width': 'initial'}
)

predict_button = widgets.Button(description="Make Predictions")

def on_button_click(b):
    try:
        predictions = predict_sequences(model_file.value, input_file.value, output_file.value)
        print(f"\nPredictions completed successfully!")
        print(f"Results saved to: {output_file.value}")
        print(f"Number of predictions: {len(predictions)}")
    except Exception as e:
        print(f"\nError: {str(e)}")

predict_button.on_click(on_button_click)

form = widgets.VBox([model_file, input_file, output_file, predict_button])
display(form)

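The prediction cell converts raw model outputs into labels in one of two ways. A single output (`d_output == 1`) goes through a sigmoid and is thresholded at 0.5. Several outputs go through a softmax, and the most probable class is taken. The sketch below is dependency-free and reproduces that post-processing for a single row of outputs. The function names `sigmoid`, `softmax`, and `postprocess` are illustrative and are not defined in the notebook.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function mapping a raw score to (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow in math.exp(-x) for large negative x
    return z / (1.0 + z)

def softmax(logits):
    """Normalize raw scores into class probabilities that sum to 1."""
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def postprocess(outputs, threshold=0.5):
    """Return (predicted_label, probabilities) for one row of model outputs."""
    if len(outputs) == 1:
        # Single output: binary prediction via sigmoid + strict threshold
        p = sigmoid(outputs[0])
        return int(p > threshold), [p]
    # Several outputs: multiclass prediction via softmax + argmax
    probs = softmax(outputs)
    label = max(range(len(probs)), key=lambda i: probs[i])
    return label, probs

print(postprocess([2.0]))            # label 1 with probability ~0.881
print(postprocess([1.0, 3.0, 0.5]))  # label 1; probabilities sum to 1
```

The comparison is strict, so a raw output of exactly 0 (probability 0.5) maps to class 0. This matches `probs > 0.5` in `predict_sequences`.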