In [8]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import numpy as np
import csv

# RNA sequence vocabulary (0 reserved for padding)
seq_vocab = {
    'A': 1, 'U': 2, 'G': 3, 'C': 4,
    'R': 5, 'Y': 6, 'N': 7
}

# Structure vocabulary (0 reserved for padding)
struct_vocab = {
    "(": 1,
    ")": 2,
    "[": 3,
    "]": 4,
    ".": 5
}
# Inverse mapping for converting indices back to characters
inv_struct_vocab = {v: k for k, v in struct_vocab.items()}

class RNADataset(Dataset):
    """
    Custom Dataset for RNA data with normalized thermodynamic features.
    Each sample includes:
      - RNA sequence (from a .dbn file)
      - Dot-bracket structure (from the same .dbn file)
      - Thermodynamic features (from a CSV file)

    The CSV is assumed to have columns (check your file for exact order):
      [minimum_free_energy, ..., partition_function, ensemble_diversity, ..., rna_id, sequence_length, ...]

    Samples with sequence_length >= 400 are skipped.
    """
    def __init__(self, dbn_dir, thermo_csv, seq_vocab, struct_vocab, max_len=400):
        self.dbn_dir = dbn_dir
        self.seq_vocab = seq_vocab
        self.struct_vocab = struct_vocab
        self.max_len = max_len

        all_mfe = []
        all_pf = []
        all_ed = []
        all_lengths = []

        self.thermo_dict = {}
        print(f"Reading thermodynamic features from: {thermo_csv}")
        try:
            with open(thermo_csv, 'r', newline='') as f:
                reader = csv.reader(f)
                header = next(reader)
                print(f"CSV Header: {header}")

                try:
                    mfe_idx = header.index('minimum_free_energy')
                    pf_idx = header.index('partition_function')
                    ed_idx = header.index('ensemble_diversity')
                    id_idx = header.index('rna_id')
                    len_idx = header.index('sequence_length')
                    print("Found column indices by name.")
                except ValueError as e:
                    print(f"Warning: Could not find column names in header: {e}. Falling back to hardcoded indices.")
                 
                    mfe_idx = 0
                    pf_idx = 3
                    ed_idx = 4
                    id_idx = 7
                    len_idx = 8
                    print(f"Using hardcoded indices: MFE={mfe_idx}, PF={pf_idx}, ED={ed_idx}, ID={id_idx}, Len={len_idx}")

                processed_rows = 0
                skipped_long = 0
                skipped_parsing = 0
                for i, parts in enumerate(reader):
                    if not parts:
                        continue
                    try:
                        max_idx_needed = max(mfe_idx, pf_idx, ed_idx, id_idx, len_idx)
                        if len(parts) <= max_idx_needed:
                            skipped_parsing += 1
                            continue

                        sequence_length = int(parts[len_idx])

                        if sequence_length >= self.max_len:
                            skipped_long += 1
                            continue

                        minimum_free_energy = float(parts[mfe_idx])
                        partition_function = float(parts[pf_idx])
                        ensemble_diversity = float(parts[ed_idx])
                        rna_id = parts[id_idx]

                        all_mfe.append(minimum_free_energy)
                        all_pf.append(partition_function)
                        all_ed.append(ensemble_diversity)
                        all_lengths.append(sequence_length)

                        self.thermo_dict[rna_id] = {
                            'minimum_free_energy': minimum_free_energy,
                            'partition_function': partition_function,
                            'ensemble_diversity': ensemble_diversity,
                            'sequence_length': sequence_length
                        }
                        processed_rows += 1
                    except (ValueError, IndexError) as e:
                        skipped_parsing += 1
                        continue

            print(f"Finished reading CSV: Processed {processed_rows} rows.")
            if skipped_long > 0:
                print(f"Skipped {skipped_long} sequences with length >= {self.max_len}.")
            if skipped_parsing > 0:
                print(f"Skipped {skipped_parsing} rows due to parsing errors or insufficient columns.")

            if not self.thermo_dict:
                 raise ValueError("No valid thermodynamic data loaded. Check CSV path and format.")

        except FileNotFoundError:
            print(f"Error: Thermodynamic features file not found at {thermo_csv}")
            raise
        except Exception as e:
            print(f"An error occurred while reading the CSV file: {e}")
            raise

        self.mfe_mean = np.mean(all_mfe, dtype=np.float64)
        self.mfe_std = np.std(all_mfe, dtype=np.float64)
        self.pf_mean = np.mean(all_pf, dtype=np.float64)
        self.pf_std = np.std(all_pf, dtype=np.float64)
        self.ed_mean = np.mean(all_ed, dtype=np.float64)
        self.ed_std = np.std(all_ed, dtype=np.float64)
        self.len_mean = np.mean(all_lengths, dtype=np.float64)
        self.len_std = np.std(all_lengths, dtype=np.float64)

        self.mfe_std = self.mfe_std if self.mfe_std > 1e-9 else 1.0
        self.pf_std = self.pf_std if self.pf_std > 1e-9 else 1.0
        self.ed_std = self.ed_std if self.ed_std > 1e-9 else 1.0
        self.len_std = self.len_std if self.len_std > 1e-9 else 1.0

        print(f"\nThermodynamic feature statistics (used for normalization):")
        print(f"MFE:              mean={self.mfe_mean:.4f}, std={self.mfe_std:.4f}")
        print(f"Partition Function: mean={self.pf_mean:.4f}, std={self.pf_std:.4f}")
        print(f"Ensemble Diversity: mean={self.ed_mean:.4f}, std={self.ed_std:.4f}")
        print(f"Sequence Length:    mean={self.len_mean:.4f}, std={self.len_std:.4f}\n")

        self.samples = []
        skipped_dbn = 0
        print(f"Checking for corresponding DBN files in: {dbn_dir}")
        for rna_id in self.thermo_dict.keys():
            dbn_path = os.path.join(self.dbn_dir, f"{rna_id}.dbn")
            if os.path.isfile(dbn_path):
                self.samples.append(rna_id)
            else:
                skipped_dbn += 1

        if skipped_dbn > 0:
            print(f"Skipped {skipped_dbn} samples because corresponding DBN file was not found.")
        print(f"Total valid samples found: {len(self.samples)}")
        if not self.samples:
            raise ValueError("No valid samples found. Check DBN directory path and ensure filenames match RNA IDs in the CSV.")


    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        rna_id = self.samples[idx]
        dbn_path = os.path.join(self.dbn_dir, f"{rna_id}.dbn")

        try:
            with open(dbn_path, 'r') as f:
                lines = [line.strip() for line in f if line.strip()]

            if len(lines) < 5:
                print(f"Warning: DBN file {dbn_path} for RNA ID {rna_id} has fewer than 5 lines. Skipping sample.")
                raise ValueError(f"DBN file {dbn_path} has insufficient lines.")

            seq_str = lines[3]
            struct_str = lines[4]

            if len(seq_str) != len(struct_str):
                 raise ValueError(f"Sequence and structure length mismatch in {dbn_path} for RNA ID {rna_id}. Seq: {len(seq_str)}, Struct: {len(struct_str)}")

            seq_idx = [self.seq_vocab.get(nuc.upper(), 0) for nuc in seq_str]
            struct_idx = [self.struct_vocab.get(ch, 0) for ch in struct_str]

            seq_tensor = torch.tensor(seq_idx, dtype=torch.long)
            struct_tensor = torch.tensor(struct_idx, dtype=torch.long)

            thermo_data = self.thermo_dict[rna_id]

            normalized_thermo_tensor = torch.tensor([
                (thermo_data['minimum_free_energy'] - self.mfe_mean) / self.mfe_std,
                (thermo_data['partition_function'] - self.pf_mean) / self.pf_std,
                (thermo_data['ensemble_diversity'] - self.ed_mean) / self.ed_std,
                (thermo_data['sequence_length'] - self.len_mean) / self.len_std
            ], dtype=torch.float)

            raw_mfe = thermo_data['minimum_free_energy']

            return seq_tensor, struct_tensor, normalized_thermo_tensor, raw_mfe, rna_id

        except FileNotFoundError:
            print(f"Error: DBN file {dbn_path} not found during __getitem__ for RNA ID {rna_id}.")
            raise FileNotFoundError(f"DBN file not found: {dbn_path}")
        except Exception as e:
            print(f"Error processing sample for RNA ID {rna_id} from file {dbn_path}: {e}")
            raise RuntimeError(f"Failed to process sample {rna_id}") from e


class RNAPred(nn.Module):
    """
    Improved RNA structure prediction model combining CNN and LSTM
    to capture both local patterns and long-range dependencies.
    Integrates thermodynamic features and predicts pseudoknot probability.
    """
    def __init__(self, seq_vocab_size, struct_vocab_size, thermo_feature_size=4,
                 embed_dim=64, num_filters=64, kernel_sizes=[3,5,7],
                 lstm_hidden=128, num_lstm_layers=2, num_attn_heads=8, dropout=0.2):
        super(RNAPred, self).__init__()
        self.embed_dim = embed_dim

        self.embedding = nn.Embedding(seq_vocab_size, embed_dim, padding_idx=0)

        self.thermo_fc = nn.Sequential(
            nn.Linear(thermo_feature_size, lstm_hidden),
            nn.BatchNorm1d(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k, padding=k//2),
                nn.BatchNorm1d(num_filters),
                nn.ReLU(),
                nn.Dropout(dropout)
            ) for k in kernel_sizes
        ])
        cnn_output_dim = num_filters * len(kernel_sizes)

        self.lstm = nn.LSTM(
            input_size=cnn_output_dim,
            hidden_size=lstm_hidden,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=dropout if num_lstm_layers > 1 else 0,
            batch_first=True
        )
        lstm_output_dim = lstm_hidden * 2

        self.self_attn = nn.MultiheadAttention(
            embed_dim=lstm_output_dim,
            num_heads=num_attn_heads,
            dropout=dropout,
            batch_first=True
        )

        combined_hidden_dim = lstm_output_dim + lstm_hidden

        self.fc_structure = nn.Linear(combined_hidden_dim, struct_vocab_size)

        self.fc_energy = nn.Sequential(
            nn.Linear(combined_hidden_dim, lstm_hidden),
            nn.BatchNorm1d(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, 1)
        )

        self.pseudoknot_feature_extractor = nn.Sequential(
            nn.Linear(combined_hidden_dim, lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.pseudoknot_classifier = nn.Linear(lstm_hidden, 1)

    def forward(self, x, thermo_features):
        """
        Forward pass of the RNAPred model.

        Args:
            x (torch.Tensor): Input RNA sequence tensor, shape (batch, seq_len).
                                Contains indices from seq_vocab, with 0 for padding.
            thermo_features (torch.Tensor): Input thermodynamic features tensor,
                                            shape (batch, thermo_feature_size).

        Returns:
            tuple: Contains:
                - struct_logits (torch.Tensor): Logits for structure prediction for each base.
                                                Shape (batch, seq_len, struct_vocab_size).
                - energy_pred (torch.Tensor): Predicted Minimum Free Energy (MFE) value.
                                              Shape (batch).
                - pk_logits (torch.Tensor): Logits for pseudoknot presence prediction.
                                            Shape (batch). Use with BCEWithLogitsLoss.
        """
        batch_size, seq_len = x.size()

        mask = (x != 0)

        emb = self.embedding(x)
        emb = emb.transpose(1, 2)

        thermo_encoded = self.thermo_fc(thermo_features)

        conv_outs = [conv(emb) for conv in self.conv_layers]
        cnn_features = torch.cat(conv_outs, dim=1)
        cnn_features = cnn_features.transpose(1, 2)

        lengths = mask.sum(dim=1).cpu()
        lengths = torch.clamp(lengths, min=1)

        packed_features = nn.utils.rnn.pack_padded_sequence(
            cnn_features,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False
        )
        lstm_out_packed, _ = self.lstm(packed_features)
        lstm_out, output_lengths = nn.utils.rnn.pad_packed_sequence(
            lstm_out_packed,
            batch_first=True,
            total_length=seq_len
        )

        attn_mask = ~mask

        attended_out, _ = self.self_attn(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=attn_mask
        )
        attended_out = attended_out * mask.unsqueeze(-1).float()

        thermo_expanded = thermo_encoded.unsqueeze(1).expand(-1, seq_len, -1)

        combined_features = torch.cat([attended_out, thermo_expanded], dim=-1)

        struct_logits = self.fc_structure(combined_features)

        masked_combined_features = combined_features * mask.unsqueeze(-1).float()

        valid_lengths = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        pooled_features = masked_combined_features.sum(dim=1) / valid_lengths

        energy_pred = self.fc_energy(pooled_features).squeeze(-1)

        pk_extracted_features = self.pseudoknot_feature_extractor(pooled_features)
        pk_logits = self.pseudoknot_classifier(pk_extracted_features).squeeze(-1)

        return struct_logits, energy_pred, pk_logits


def pad_collate(batch):
    """
    Collates data samples into batches with padding.

    Filters out None samples that might result from errors in __getitem__.

    Args:
        batch (list): A list of tuples, where each tuple is the output of
                      RNADataset.__getitem__:
                      (seq_tensor, struct_tensor, thermo_tensor, raw_mfe, rna_id)
                      or None if an error occurred.

    Returns:
        tuple: A tuple containing padded batches:
               (padded_seq, padded_struct, thermo_batch, raw_mfe_batch, rna_ids)
               Returns None if the filtered batch is empty.
    """
    batch = [item for item in batch if item is not None]

    if not batch:
        return None

    seq_tensors, struct_tensors, thermo_tensors, raw_mfes, rna_ids = zip(*batch)

    max_seq_len = max(s.size(0) for s in seq_tensors)

    padded_seqs = []
    padded_structs = []
    for s, st in zip(seq_tensors, struct_tensors):
        pad_size = max_seq_len - s.size(0)
        padded_seqs.append(F.pad(s, (0, pad_size), value=0))
        padded_structs.append(F.pad(st, (0, pad_size), value=0))

    padded_seq = torch.stack(padded_seqs, dim=0)
    padded_struct = torch.stack(padded_structs, dim=0)
    thermo_batch = torch.stack(thermo_tensors, dim=0)
    raw_mfe_batch = torch.tensor(raw_mfes, dtype=torch.float)

    return padded_seq, padded_struct, thermo_batch, raw_mfe_batch, list(rna_ids)


def make_valid_structure(predicted_structure):
    """
    Attempts to convert a predicted structure string into a valid dot-bracket
    notation by matching brackets greedily. Handles one level of pseudoknots ('[]').

    Args:
        predicted_structure (str): The raw predicted structure string from the model.

    Returns:
        str: An adjusted structure string with matched brackets or '.' for unmatched positions.
    """
    n = len(predicted_structure)
    valid_structure = ['.'] * n
    used_positions = set()

    stack_round = []
    for i, char in enumerate(predicted_structure):
        if char == '(':
            stack_round.append(i)
        elif char == ')' and stack_round:
            open_pos = stack_round.pop()
            valid_structure[open_pos] = '('
            valid_structure[i] = ')'
            used_positions.add(open_pos)
            used_positions.add(i)

    stack_square = []
    for i, char in enumerate(predicted_structure):
        if i in used_positions:
            continue

        if char == '[':
            stack_square.append(i)
        elif char == ']' and stack_square:
            open_pos = stack_square.pop()
            if open_pos not in used_positions:
                 valid_structure[open_pos] = '['
                 valid_structure[i] = ']'
                 used_positions.add(open_pos)
                 used_positions.add(i)

    return "".join(valid_structure)


def parse_dot_bracket(dot_bracket):
    """
    Parses a dot-bracket string and returns a set of base pairs.
    Handles both standard `()` and pseudoknot `[]` brackets.

    Args:
        dot_bracket (str): The dot-bracket structure string.

    Returns:
        set: A set of tuples, where each tuple represents a base pair (i, j) with i < j.
    """
    pairs = set()
    stack_round = []
    stack_square = []

    for i, char in enumerate(dot_bracket):
        if char == '(':
            stack_round.append(i)
        elif char == ')':
            if stack_round:
                j = stack_round.pop()
                pairs.add(tuple(sorted((j, i))))
        elif char == '[':
            stack_square.append(i)
        elif char == ']':
            if stack_square:
                j = stack_square.pop()
                pairs.add(tuple(sorted((j, i))))

    return pairs

def evaluate_structure(predicted, ground_truth):
    """
    Evaluates the predicted RNA structure against the ground truth using
    Sensitivity, Positive Predictive Value (PPV), and F1-score based on base pairs.

    Args:
        predicted (str): Predicted RNA structure string (dot-bracket).
        ground_truth (str): Ground truth RNA structure string (dot-bracket).

    Returns:
        tuple: (sensitivity, ppv, f1)
    """
    min_len = min(len(predicted), len(ground_truth))
    predicted = predicted[:min_len]
    ground_truth = ground_truth[:min_len]

    pred_pairs = parse_dot_bracket(predicted)
    true_pairs = parse_dot_bracket(ground_truth)

    TP = len(pred_pairs.intersection(true_pairs))
    FP = len(pred_pairs - true_pairs)
    FN = len(true_pairs - pred_pairs)

    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    ppv = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    f1 = 2 * sensitivity * ppv / (sensitivity + ppv) if (sensitivity + ppv) > 0 else 0.0

    return sensitivity, ppv, f1


def train_model(train_loader, valid_loader, model, device, num_epochs=50, lr=1e-4,
                weight_decay=1e-5, grad_clip_norm=1.0,
                enable_energy_loss=True, energy_weight=0.001,
                pseudoknot_weight=0.1):
    """
    Trains the RNA structure prediction model.

    Args:
        train_loader: DataLoader for the training set.
        valid_loader: DataLoader for the validation set.
        model: The RNAPred model instance.
        device: The device (CPU or CUDA) to train on.
        num_epochs: Number of training epochs.
        lr: Learning rate.
        weight_decay: Weight decay for AdamW optimizer.
        grad_clip_norm: Maximum norm for gradient clipping.
        enable_energy_loss: Whether to include energy prediction loss.
        energy_weight: Weighting factor for the energy loss component.
        pseudoknot_weight: Weighting factor for the pseudoknot loss component.

    Returns:
        The trained model (best model loaded from checkpoint).
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )

    best_valid_loss = float('inf')
    best_model_path = './best_rna_structure_model.pt'

    print("\n--- Starting Training ---")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_struct_loss = 0.0
        total_energy_loss = 0.0
        total_pk_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
        for batch in progress_bar:
            if batch is None:
                print("Warning: Skipping an empty batch.")
                continue

            seq_batch, struct_batch, thermo_batch, raw_mfe_batch, _ = batch

            seq_batch = seq_batch.to(device)
            struct_batch = struct_batch.to(device)
            thermo_batch = thermo_batch.to(device)
            raw_mfe_batch = raw_mfe_batch.to(device)

            optimizer.zero_grad()

            struct_logits, energy_pred, pk_logits = model(seq_batch, thermo_batch)

            struct_logits_flat = struct_logits.view(-1, struct_logits.size(-1))
            struct_batch_flat = struct_batch.view(-1)
            struct_loss = F.cross_entropy(struct_logits_flat, struct_batch_flat, ignore_index=0)

            has_pseudoknot = torch.zeros_like(raw_mfe_batch)
            for i, struct_tensor in enumerate(struct_batch):
                if torch.any((struct_tensor == 3) | (struct_tensor == 4)):
                    has_pseudoknot[i] = 1.0
            pk_loss = F.binary_cross_entropy_with_logits(pk_logits, has_pseudoknot)

            if enable_energy_loss:
                energy_loss = F.mse_loss(energy_pred, raw_mfe_batch)
            else:
                energy_loss = torch.tensor(0.0, device=device)

            loss = struct_loss + (pseudoknot_weight * pk_loss)
            if enable_energy_loss:
                 loss += (energy_weight * energy_loss)

            loss.backward()

            if grad_clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_norm)

            optimizer.step()

            total_loss += loss.item()
            total_struct_loss += struct_loss.item()
            total_pk_loss += pk_loss.item()
            if enable_energy_loss:
                total_energy_loss += energy_loss.item()

            progress_bar.set_postfix({
                'Loss': f"{loss.item():.4f}",
                'Struct': f"{struct_loss.item():.4f}",
                'PK': f"{pk_loss.item():.4f}",
                'Energy': f"{energy_loss.item():.4f}" if enable_energy_loss else "N/A"
            })

        avg_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        avg_struct_loss = total_struct_loss / len(train_loader) if len(train_loader) > 0 else 0
        avg_pk_loss = total_pk_loss / len(train_loader) if len(train_loader) > 0 else 0
        avg_energy_loss = total_energy_loss / len(train_loader) if enable_energy_loss and len(train_loader) > 0 else 0

        model.eval()
        valid_loss = 0.0
        valid_struct_loss = 0.0
        valid_energy_loss = 0.0
        valid_pk_loss = 0.0
        struct_correct = 0
        struct_total = 0
        pk_correct = 0
        pk_total = 0

        with torch.no_grad():
            for batch in tqdm(valid_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]", leave=False):
                if batch is None: continue

                seq_batch, struct_batch, thermo_batch, raw_mfe_batch, _ = batch

                seq_batch = seq_batch.to(device)
                struct_batch = struct_batch.to(device)
                thermo_batch = thermo_batch.to(device)
                raw_mfe_batch = raw_mfe_batch.to(device)

                struct_logits, energy_pred, pk_logits = model(seq_batch, thermo_batch)

                struct_logits_flat = struct_logits.view(-1, struct_logits.size(-1))
                struct_batch_flat = struct_batch.view(-1)
                struct_loss = F.cross_entropy(struct_logits_flat, struct_batch_flat, ignore_index=0)

                has_pseudoknot = torch.zeros_like(raw_mfe_batch)
                for i, struct_tensor in enumerate(struct_batch):
                    if torch.any((struct_tensor == 3) | (struct_tensor == 4)):
                        has_pseudoknot[i] = 1.0
                pk_loss = F.binary_cross_entropy_with_logits(pk_logits, has_pseudoknot)

                if enable_energy_loss:
                    energy_loss = F.mse_loss(energy_pred, raw_mfe_batch)
                    loss = struct_loss + (pseudoknot_weight * pk_loss) + (energy_weight * energy_loss)
                else:
                    energy_loss = torch.tensor(0.0, device=device)
                    loss = struct_loss + (pseudoknot_weight * pk_loss)

                valid_loss += loss.item()
                valid_struct_loss += struct_loss.item()
                valid_pk_loss += pk_loss.item()
                if enable_energy_loss:
                    valid_energy_loss += energy_loss.item()

                mask = (struct_batch != 0)
                pred_struct_indices = torch.argmax(struct_logits, dim=-1)
                struct_correct += (pred_struct_indices == struct_batch)[mask].sum().item()
                struct_total += mask.sum().item()

                pk_pred = (torch.sigmoid(pk_logits) > 0.5).float()
                pk_correct += (pk_pred == has_pseudoknot).sum().item()
                pk_total += len(seq_batch)

        avg_valid_loss = valid_loss / len(valid_loader) if len(valid_loader) > 0 else 0
        avg_valid_struct_loss = valid_struct_loss / len(valid_loader) if len(valid_loader) > 0 else 0
        avg_valid_pk_loss = valid_pk_loss / len(valid_loader) if len(valid_loader) > 0 else 0
        avg_valid_energy_loss = valid_energy_loss / len(valid_loader) if enable_energy_loss and len(valid_loader) > 0 else 0
        struct_accuracy = struct_correct / struct_total if struct_total > 0 else 0
        pk_accuracy = pk_correct / pk_total if pk_total > 0 else 0

        print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train Loss: {avg_loss:.4f} | Struct: {avg_struct_loss:.4f} | PK: {avg_pk_loss:.4f} | Energy: {avg_energy_loss:.4f}")
        print(f"  Valid Loss: {avg_valid_loss:.4f} | Struct: {avg_valid_struct_loss:.4f} | PK: {avg_valid_pk_loss:.4f} | Energy: {avg_valid_energy_loss:.4f}")
        print(f"  Valid Metrics -> Struct Acc (pos): {struct_accuracy:.4f} | PK Acc (seq): {pk_accuracy:.4f}")

        scheduler.step(avg_valid_loss)

        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'valid_loss': best_valid_loss,
                'seq_vocab_size': model.embedding.num_embeddings,
                'struct_vocab_size': model.fc_structure.out_features,
                'thermo_feature_size': model.thermo_fc[0].in_features,
            }, best_model_path)
            print(f"  ** New best model saved to {best_model_path} (Valid Loss: {best_valid_loss:.4f}) **")
        print("-" * 60)

    print("\n--- Training Finished ---")
    try:
        print(f"Loading best model from {best_model_path}...")
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {checkpoint.get('epoch', 'N/A')} with validation loss {checkpoint.get('valid_loss', float('inf')):.4f}")
    except FileNotFoundError:
        print(f"Warning: Best model file {best_model_path} not found. Returning the model from the last epoch.")
    except Exception as e:
        print(f"Error loading best model checkpoint: {e}. Returning the model from the last epoch.")

    return model
    
def evaluate_predictions(model, test_loader, device, sample_indices=None, use_structure_validation=True):
    """
    Evaluates the trained model on test data, prints detailed examples,
    and calculates average metrics.

    Args:
        model: The trained RNAPred model.
        test_loader: DataLoader for the test set.
        device: The device (CPU or CUDA) to run evaluation on.
        sample_indices (list, optional): List of indices of test samples to print
                                         detailed results for. Defaults to first 5.
        use_structure_validation (bool): Whether to apply `make_valid_structure`
                                         post-processing to the predictions.

    Returns:
        list: A list of dictionaries, each containing detailed results for the
              evaluated samples (those specified by sample_indices).
    """
    model.eval() # Set model to evaluation mode

    # --- Gather all test samples first (needed for specific index selection) ---
    all_samples = []
    print("Gathering all test samples...")
    for batch in tqdm(test_loader, desc="Loading Test Samples"):
        if batch is None: continue
        seq_batch, struct_batch, thermo_batch, raw_mfe_batch, rna_ids = batch
        for i in range(len(seq_batch)):
            all_samples.append((
                seq_batch[i], struct_batch[i], thermo_batch[i], raw_mfe_batch[i], rna_ids[i]
            ))

    if not all_samples:
        print("Error: No samples found in the test loader.")
        return []

    # If no specific indices are given, use the first few samples
    if sample_indices is None:
        num_default_samples = min(5, len(all_samples)) # Show up to 5 samples
        sample_indices = list(range(num_default_samples))
        print(f"No specific sample indices provided. Evaluating first {num_default_samples} test samples.")
    else:
        print(f"Evaluating specific test sample indices: {sample_indices}")

    # Validate indices
    valid_indices = [i for i in sample_indices if 0 <= i < len(all_samples)]
    if len(valid_indices) < len(sample_indices):
        ignored_count = len(sample_indices) - len(valid_indices)
        print(f"Warning: {ignored_count} provided indices were out of range (0 to {len(all_samples)-1}) and ignored.")
    if not valid_indices:
        print("Error: No valid sample indices to evaluate.")
        return []

    results = []
    total_sensitivity = 0.0
    total_ppv = 0.0
    total_f1 = 0.0
    total_pk_correct = 0
    total_energy_mae = 0.0 

    print("\n--- Evaluating Selected Test Samples ---")
    with torch.no_grad(): # Disable gradients for evaluation
        for idx in valid_indices:
            seq_tensor, struct_tensor, thermo_tensor, raw_mfe, rna_id = all_samples[idx]

            seq_tensor_batch = seq_tensor.unsqueeze(0).to(device)  # (1, seq_len)
            thermo_tensor_batch = thermo_tensor.unsqueeze(0).to(device) # (1, thermo_dim)

            # Model prediction
            struct_logits, energy_pred, pk_logits = model(seq_tensor_batch, thermo_tensor_batch)

            predicted_indices = torch.argmax(struct_logits, dim=-1).squeeze(0).cpu() # (seq_len)
            # Get predicted pseudoknot probability
            pk_probability = torch.sigmoid(pk_logits).squeeze(0).cpu().item() # Single probability value
            # Get predicted energy
            predicted_energy = energy_pred.squeeze(0).cpu().item() # Single energy value

            true_len = (seq_tensor != 0).sum().item()

            predicted_structure_raw = "".join([inv_struct_vocab.get(token.item(), "") for token in predicted_indices[:true_len]])

            if use_structure_validation:
                predicted_structure_final = make_valid_structure(predicted_structure_raw)
            else:
                predicted_structure_final = predicted_structure_raw

            # Convert original structure indices to string
            original_structure = "".join([inv_struct_vocab.get(token.item(), "") for token in struct_tensor[:true_len]])

            # Convert sequence indices to string
            inv_seq_vocab = {v: k for k, v in seq_vocab.items()}
            rna_sequence = "".join([inv_seq_vocab.get(token.item(), "?") for token in seq_tensor[:true_len]])

            sensitivity, ppv, f1 = evaluate_structure(predicted_structure_final, original_structure)
            total_sensitivity += sensitivity
            total_ppv += ppv
            total_f1 += f1

            # Positional accuracy
            matches = sum(1 for p, o in zip(predicted_structure_final, original_structure) if p == o)
            accuracy = matches / true_len if true_len > 0 else 0

            # Pseudoknot detection correctness
            has_true_pk = '[' in original_structure or ']' in original_structure
            # Use the predicted probability with a threshold (e.g., 0.5)
            predicted_pk_presence = pk_probability > 0.5
            pk_correct = (has_true_pk == predicted_pk_presence)
            if pk_correct:
                total_pk_correct += 1

            # Energy prediction error (MAE)
            energy_mae = abs(predicted_energy - raw_mfe.item())
            total_energy_mae += energy_mae

            # Create difference markers string
            diff_marks = "".join(["✓" if p == o else "✗" for p, o in zip(predicted_structure_final, original_structure)])

            # Store results for this sample
    sample_result = {
                "rna_id": rna_id,
                "length": true_len,
                "sequence": rna_sequence,
                "predicted_structure": predicted_structure_final,
                "original_structure": original_structure,
                "difference_markers": diff_marks,
                "positional_accuracy": accuracy,
                "sensitivity": sensitivity,
                "ppv": ppv,
                "f1_score": f1,
                "predicted_energy": predicted_energy,
                "true_energy": raw_mfe.item(),
                "energy_mae": energy_mae,
                "has_true_pseudoknot": has_true_pk,
                "predicted_pseudoknot_presence": predicted_pk_presence,
                "pseudoknot_prediction_correct": pk_correct,
                "pseudoknot_probability": pk_probability,
            }
            results.append(sample_result)

            print(f"\n{'='*40} Sample Analysis: RNA ID {rna_id} {'='*40}")
            print(f"Length: {true_len}")
            print(f"Sequence:              {rna_sequence}")
            print(f"Predicted Structure:   {predicted_structure_final}")
            print(f"Original Structure:    {original_structure}")
            print(f"Difference Markers:    {diff_marks}")
            print(f"\nMetrics:")
            print(f"  Positional Accuracy: {accuracy:.4f} ({matches}/{true_len})")
            print(f"  Sensitivity (Pairs): {sensitivity:.4f}")
            print(f"  PPV (Pairs):         {ppv:.4f}")
            print(f"  F1 Score (Pairs):    {f1:.4f}")
            print(f"\nEnergy:")
            print(f"  Predicted MFE: {predicted_energy:.4f}")
            print(f"  True MFE:      {raw_mfe.item():.4f}")
            print(f"  Absolute Error:{energy_mae:.4f}")
            print(f"\nPseudoknot:")
            print(f"  Original Has PK: {has_true_pk}")
            print(f"  Predicted Has PK:{predicted_pk_presence} (Prob: {pk_probability:.4f})")
            print(f"  PK Prediction Correct: {pk_correct}")
            print(f"{'='*98}")

    num_evaluated = len(results)
    if num_evaluated > 0:
        avg_sensitivity = total_sensitivity / num_evaluated
        avg_ppv = total_ppv / num_evaluated
        avg_f1 = total_f1 / num_evaluated
        avg_pk_accuracy = total_pk_correct / num_evaluated
        avg_energy_mae = total_energy_mae / num_evaluated


        print(f"\n--- Average Metrics Over {num_evaluated} Evaluated Samples ---")
        print(f"Average Sensitivity (Pairs): {avg_sensitivity:.4f}")
        print(f"Average PPV (Pairs):         {avg_ppv:.4f}")
        print(f"Average F1 Score (Pairs):    {avg_f1:.4f}")
        print(f"Average PK Detection Acc:    {avg_pk_accuracy:.4f}")
        print(f"Average Energy MAE:          {avg_energy_mae:.4f}")
        print("-" * 60)
    else:
        print("No samples were evaluated.")

    return results 

def save_datasets(train_dataset, valid_dataset, test_dataset, output_dir="./saved_datasets"):
    """
    Saves the train, validation, and test dataset objects (torch.utils.data.Subset)
    to disk using torch.save. These objects retain the ability to access RNA IDs.

    Args:
        train_dataset: The training dataset subset.
        valid_dataset: The validation dataset subset.
        test_dataset: The test dataset subset.
        output_dir (str): The directory where dataset files will be saved.
    """
    import os # Make sure os is imported

    os.makedirs(output_dir, exist_ok=True)
    print(f"\nSaving datasets to directory: {output_dir}")

    # Define file paths
    train_path = os.path.join(output_dir, "train_dataset.pt")
    valid_path = os.path.join(output_dir, "valid_dataset.pt")
    test_path = os.path.join(output_dir, "test_dataset.pt")

    # Save datasets using torch.save
    try:
        torch.save(train_dataset, train_path)
        print(f"  Saved train dataset ({len(train_dataset)} samples) to {train_path}")

        torch.save(valid_dataset, valid_path)
        print(f"  Saved validation dataset ({len(valid_dataset)} samples) to {valid_path}")

        torch.save(test_dataset, test_path)
        print(f"  Saved test dataset ({len(test_dataset)} samples) to {test_path}")

    except Exception as e:
        print(f"Error saving datasets: {e}")
def main():
    """Main function to run the RNA structure prediction workflow."""
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    dbn_dir = "/kaggle/input/bprna-dbn/bpRNA_1m_90_DBNFILES"  # Directory containing .dbn files
    thermo_csv = "/kaggle/input/thermo-feature/rna_global_thermo_features.csv"  # Path to the CSV with thermodynamic features

    # Dataset parameters
    max_sequence_length = 400 # Sequences longer than this will be skipped

    # Data splitting ratios
    train_split = 0.8
    valid_split = 0.1

    # DataLoader parameters
    batch_size = 32 
    num_workers = 0 

    embed_dim = 64
    num_filters = 64
    kernel_sizes = [3, 5, 7]
    lstm_hidden = 128
    num_lstm_layers = 2
    num_attn_heads = 8
    dropout = 0.2

    # Training parameters
    num_epochs = 50      
    learning_rate = 1e-4
    weight_decay = 1e-5
    grad_clip_norm = 1.0 
    enable_energy_loss = True 
    energy_loss_weight = 0.001 
    pseudoknot_loss_weight = 0.1 

    # Evaluation parameters
    evaluate_on_test_set = True
    test_sample_indices = [0, 10, 20, 30, 40]
    use_structure_validation_postprocessing = True 

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load dataset
    print("\n--- Loading Dataset ---")
    try:
   dataset = RNADataset(
            dbn_dir=dbn_dir,
            thermo_csv=thermo_csv,
            seq_vocab=seq_vocab,
            struct_vocab=struct_vocab,
            max_len=max_sequence_length
        )
    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"Fatal Error: Failed to load dataset. Please check paths and file formats. Error: {e}")
        return # Exit if dataset loading fails

    # 2. Split dataset
    print("\n--- Splitting Dataset ---")
    total_samples = len(dataset)
    if total_samples == 0:
        print("Fatal Error: Dataset is empty after loading. Cannot proceed.")
        return

    train_size = int(train_split * total_samples)
    valid_size = int(valid_split * total_samples)
    test_size = total_samples - train_size - valid_size

    if train_size == 0 or valid_size == 0 or test_size == 0:
         print(f"Warning: Dataset size ({total_samples}) is too small for the specified splits, resulting in empty sets.")
    generator = torch.Generator().manual_seed(seed)
    train_dataset, valid_dataset, test_dataset = random_split(
        dataset, [train_size, valid_size, test_size],
        generator=generator
    )
    print(f"Dataset split -> Train: {len(train_dataset)}, Validation: {len(valid_dataset)}, Test: {len(test_dataset)}")

    # 3. Save the split datasets
    save_datasets(train_dataset, valid_dataset, test_dataset, output_dir="./saved_datasets")

    # 4. Create DataLoaders
    print("\n--- Creating DataLoaders ---")
    # Use pin_memory=True if using GPU for potentially faster data transfer
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=pad_collate, num_workers=num_workers, pin_memory=pin_memory
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=False, # No need to shuffle validation
        collate_fn=pad_collate, num_workers=num_workers, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, # No need to shuffle test
        collate_fn=pad_collate, num_workers=num_workers, pin_memory=pin_memory
    )
    print(f"Batch size: {batch_size}, Num workers: {num_workers}")

    # 5. Initialize model
    print("\n--- Initializing Model ---")
    seq_vocab_size = len(seq_vocab) + 1
    struct_vocab_size = len(struct_vocab) + 1

    thermo_feature_size = 4

    model = RNAPred(
        seq_vocab_size=seq_vocab_size,
        struct_vocab_size=struct_vocab_size,
        thermo_feature_size=thermo_feature_size,
        embed_dim=embed_dim,
        num_filters=num_filters,
        kernel_sizes=kernel_sizes,
        lstm_hidden=lstm_hidden,
        num_lstm_layers=num_lstm_layers,
        num_attn_heads=num_attn_heads,
        dropout=dropout
    )
    model.to(device) 
    # 6. Train model
    print("\n--- Starting Model Training ---")
    trained_model = train_model(
        train_loader=train_loader,
        valid_loader=valid_loader,
        model=model,
        device=device,
        num_epochs=num_epochs,
        lr=learning_rate,
        weight_decay=weight_decay,
        grad_clip_norm=grad_clip_norm,
        enable_energy_loss=enable_energy_loss,
        energy_weight=energy_loss_weight,
        pseudoknot_weight=pseudoknot_loss_weight
    )

    # 7. Evaluate model on the test set
    if evaluate_on_test_set and len(test_dataset) > 0:
        print("\n--- Evaluating Best Model on Test Set ---")
        test_results = evaluate_predictions(
            model=trained_model, # Use the best model returned by train_model
            test_loader=test_loader,
            device=device,
            sample_indices=None,
            use_structure_validation=use_structure_validation_postprocessing
        )
        import json
        with open("test_evaluation_results.json", "w") as f:
             json.dump(test_results, f, indent=2)
 print("Saved detailed test evaluation results to test_evaluation_results.json")
    elif not evaluate_on_test_set:
         print("\nSkipping evaluation on the test set.")
    else:
         print("\nTest set is empty. Skipping evaluation.")


    print("\n--- Workflow Complete ---")


In [35]:
def evaluate_predictions(model, test_loader, device, use_structure_validation=True, print_first_batch_details=False):
    """
    Evaluates the trained model on the entire test set.
    Calculates average metrics including positional accuracy.
    Returns average metrics AND a list of all predictions.
    """
    model.eval()
    total_TP, total_FP, total_FN = 0, 0, 0
    total_pk_correct, num_pk_samples = 0, 0
    total_energy_mae = 0.0
    total_positional_matches = 0
    total_positions_compared = 0
    num_samples = 0
    all_predictions = [] # List to store all prediction dicts

    print("\n--- Evaluating Model on Entire Test Set ---")
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc="Evaluating Test Set")):
            if batch is None: continue
            try:
                 seq_batch, struct_batch, thermo_batch, raw_mfe_batch, rna_ids = batch
            except (ValueError, TypeError) as e:
                 print(f"Error unpacking batch {batch_idx}: {e}. Skipping batch.")
                 continue

            seq_batch, struct_batch, thermo_batch, raw_mfe_batch = \
                seq_batch.to(device), struct_batch.to(device), thermo_batch.to(device), raw_mfe_batch.to(device)

            try:
                struct_logits, energy_pred, pk_logits = model(seq_batch, thermo_batch)
            except Exception as e:
                print(f"Error during model forward pass on batch {batch_idx}: {e}. Skipping batch.")
                continue

            for i in range(len(seq_batch)):
                rna_id = rna_ids[i]
                seq_tensor, struct_tensor, raw_mfe = seq_batch[i], struct_batch[i], raw_mfe_batch[i]
                s_logits, e_pred, p_logits = struct_logits[i], energy_pred[i], pk_logits[i]

                predicted_indices = torch.argmax(s_logits, dim=-1).cpu()
                pk_probability = torch.sigmoid(p_logits).cpu().item()
                predicted_energy = e_pred.cpu().item()

                true_len = (seq_tensor != 0).sum().item()
                if true_len == 0: continue

                predicted_structure_raw = "".join([inv_struct_vocab.get(token.item(), "") for token in predicted_indices[:true_len]])
                predicted_structure_final = make_valid_structure(predicted_structure_raw) if use_structure_validation else predicted_structure_raw
                original_structure = "".join([inv_struct_vocab.get(token.item(), "") for token in struct_tensor[:true_len]])
                # rna_sequence = "".join([inv_seq_vocab.get(token.item(), "?") for token in seq_tensor[:true_len]]) # Sequence if needed

                # Base Pair Metrics
                pred_pairs = parse_dot_bracket(predicted_structure_final)
                true_pairs = parse_dot_bracket(original_structure)
                batch_TP = len(pred_pairs.intersection(true_pairs))
                batch_FP = len(pred_pairs - true_pairs)
                batch_FN = len(true_pairs - pred_pairs)
                total_TP += batch_TP; total_FP += batch_FP; total_FN += batch_FN

                # Pseudoknot Metrics
                has_true_pk = '[' in original_structure or ']' in original_structure
                predicted_pk_presence = pk_probability > 0.5
                pk_correct = (has_true_pk == predicted_pk_presence)
                if pk_correct: total_pk_correct += 1
                if has_true_pk: num_pk_samples += 1

                # Energy Metrics
                energy_mae = abs(predicted_energy - raw_mfe.item())
                total_energy_mae += energy_mae

                # Positional Accuracy Metrics
                matches = sum(1 for p, o in zip(predicted_structure_final, original_structure) if p == o)
                total_positional_matches += matches
                total_positions_compared += true_len

                num_samples += 1

                # Store prediction for this sample
                prediction_data = {
                    'rna_id': rna_id,
                    'predicted_structure': predicted_structure_final
                    # Optionally add original structure or sequence if needed
                    # 'original_structure': original_structure,
                    # 'sequence': rna_sequence
                }
                all_predictions.append(prediction_data)


                # Store/Print first batch details if needed
                if batch_idx == 0 and print_first_batch_details:
                     pass 

    # Calculate Final Average Metrics
    avg_sensitivity = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
    avg_ppv = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
    avg_f1 = 2 * avg_sensitivity * avg_ppv / (avg_sensitivity + avg_ppv) if (avg_sensitivity + avg_ppv) > 0 else 0.0
    avg_pk_accuracy_overall = total_pk_correct / num_samples if num_samples > 0 else 0.0
    avg_energy_mae = total_energy_mae / num_samples if num_samples > 0 else 0.0
    avg_pos_accuracy = total_positional_matches / total_positions_compared if total_positions_compared > 0 else 0.0

    print(f"\n--- Overall Test Set Metrics ({num_samples} samples) ---")
    print(f"Average Sensitivity (Pairs): {avg_sensitivity:.4f}")
    print(f"Average PPV (Pairs):         {avg_ppv:.4f}")
    print(f"Average F1 Score (Pairs):    {avg_f1:.4f}")
    print(f"Average Positional Acc:      {avg_pos_accuracy:.4f}")
    print(f"Average PK Detection Acc:    {avg_pk_accuracy_overall:.4f} ({total_pk_correct}/{num_samples})")
    print(f"Average Energy MAE:          {avg_energy_mae:.4f}")
    print("-" * 60)

    avg_metrics = {
        'avg_sensitivity': avg_sensitivity, 'avg_ppv': avg_ppv, 'avg_f1': avg_f1,
        'avg_positional_accuracy': avg_pos_accuracy,
        'avg_pk_accuracy': avg_pk_accuracy_overall,
        'avg_energy_mae': avg_energy_mae,
        'total_samples': num_samples
    }
    # Return both average metrics and the list of all predictions
    return avg_metrics, all_predictions
# 5. Run Evaluation if model loaded successfully
if loaded_model:
        try:
            # Call evaluation function, now returns two values
            average_test_metrics, all_prediction_results = evaluate_predictions(
                model=loaded_model,
                test_loader=test_loader,
                device=device,
                use_structure_validation=use_structure_validation_postprocessing,
                print_first_batch_details=print_first_batch_eval_details
            )

            # 6. Save Average Metrics
            try:
                with open("full_model_test.json", "w") as f:
                     json.dump(average_test_metrics, f, indent=4)
                print(f"Saved average test metrics to")
            except Exception as e:
                print(f"Error saving test metrics to JSON: {e}")

            # 7. Save All Predictions
            try:
                with open("full_model_pred_structure.json", "w") as f:
                     json.dump(all_prediction_results, f, indent=2) # Use indent=2 for list of structures
                print(f"Saved all test predictions ({len(all_prediction_results)} samples) ")
            except Exception as e:
                print(f"Error saving all test predictions to JSON: {e}")

        except Exception as e:
            print(f"An error occurred during evaluation: {e}")
else:
        print("\nModel loading failed. Cannot run evaluation.")


--- Evaluating Model on Entire Test Set ---


Evaluating Test Set: 100%|██████████| 83/83 [00:08<00:00,  9.43it/s]


--- Overall Test Set Metrics (2633 samples) ---
Average Sensitivity (Pairs): 0.1605
Average PPV (Pairs):         0.1850
Average F1 Score (Pairs):    0.1719
Average Positional Acc:      0.7532
Average PK Detection Acc:    0.9886 (2603/2633)
Average Energy MAE:          1.7920
------------------------------------------------------------
Saved average test metrics to
Saved all test predictions (2633 samples) 





In [45]:
# --- Model 1: Sequence-Only LSTM Baseline ---
class RNAPred_SeqOnlyLSTM(nn.Module):
    """Simplified RNA structure prediction model using only sequence input and a Bidirectional LSTM."""
    def __init__(self, seq_vocab_size, struct_vocab_size,
                 embed_dim=64, lstm_hidden=128, num_lstm_layers=2, dropout=0.2):
        super(RNAPred_SeqOnlyLSTM, self).__init__()
        self.embed_dim = embed_dim
        self.lstm_hidden = lstm_hidden
        lstm_output_dim = lstm_hidden * 2  # Bidirectional output dim

        # Embedding layer for RNA sequence
        self.embedding = nn.Embedding(seq_vocab_size, embed_dim, padding_idx=0)

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=lstm_hidden,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=dropout if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Structure prediction head
        self.fc_structure = nn.Linear(lstm_output_dim, struct_vocab_size)

        # Energy prediction head
        self.fc_energy = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_hidden),
            nn.LayerNorm(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, 1)
        )

        # Pseudoknot feature extractor & classifier
        self.pseudoknot_feature_extractor = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.pseudoknot_classifier = nn.Linear(lstm_hidden, 1)

        # COMPATIBILITY PATCH: Dummy thermo_fc for checkpoint saving
        self.thermo_fc = nn.Sequential(
            nn.Linear(0, self.lstm_hidden)
        )

    def forward(self, x, thermo_features=None):
        """
        Args:
            x: LongTensor of shape (batch, seq_len)
            thermo_features: ignored (None or tensor)
        Returns:
            struct_logits: (batch, seq_len, struct_vocab_size)
            energy_pred:   (batch,)
            pk_logits:     (batch,)
        """
        batch_size, seq_len = x.size()
        mask = (x != 0)

        # Embedding
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)

        # Pack & LSTM
        lengths = mask.sum(dim=1).cpu().clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq_len
        )
        lstm_out = lstm_out * mask.unsqueeze(-1).float()

        # Structure logits
        struct_logits = self.fc_structure(lstm_out)

        # Global pooling for energy & pseudoknot
        summed = (lstm_out * mask.unsqueeze(-1).float()).sum(dim=1)
        lengths_f = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        pooled = summed / lengths_f  # (batch, lstm_output_dim)

        # Energy
        energy_pred = self.fc_energy(pooled).squeeze(-1)

        # Pseudoknot
        pk_feat = self.pseudoknot_feature_extractor(pooled)
        pk_logits = self.pseudoknot_classifier(pk_feat).squeeze(-1)

        return struct_logits, energy_pred, pk_logits

# --- Model 2: Simplified CNN+LSTM with Thermo Features (No Attention) ---
class RNAPred_SimpleCNN_LSTM(nn.Module):
    """RNA structure prediction using CNN, LSTM, and Thermodynamic features, but without Self-Attention."""
    def __init__(self, seq_vocab_size, struct_vocab_size, thermo_feature_size=4,
                 embed_dim=64, num_filters=64, kernel_sizes=[3,5,7],
                 lstm_hidden=128, num_lstm_layers=2, dropout=0.2):
        super(RNAPred_SimpleCNN_LSTM, self).__init__()
        self.embed_dim = embed_dim
        self.lstm_hidden = lstm_hidden
        cnn_output_dim = num_filters * len(kernel_sizes)
        lstm_output_dim = lstm_hidden * 2 # Bidirectional

        # Embedding
        self.embedding = nn.Embedding(seq_vocab_size, embed_dim, padding_idx=0)

        # Thermo FC
        self.thermo_fc = nn.Sequential(
            nn.Linear(thermo_feature_size, lstm_hidden),
            nn.BatchNorm1d(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # CNN Layers
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k, padding=k//2),
                nn.BatchNorm1d(num_filters),
                nn.ReLU(),
                nn.Dropout(dropout)
            ) for k in kernel_sizes
        ])

        # BiLSTM Layer
        self.lstm = nn.LSTM(
            input_size=cnn_output_dim,
            hidden_size=lstm_hidden,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=dropout if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Combined dimension: LSTM output + Thermo features
        combined_hidden_dim = lstm_output_dim + lstm_hidden

        # Structure prediction layer
        self.fc_structure = nn.Linear(combined_hidden_dim, struct_vocab_size)

        # Energy prediction layer
        self.fc_energy = nn.Sequential(
            nn.Linear(combined_hidden_dim, lstm_hidden),
            nn.BatchNorm1d(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, 1)
        )

        # Pseudoknot feature extraction and classification
        self.pseudoknot_feature_extractor = nn.Sequential(
            nn.Linear(combined_hidden_dim, lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.pseudoknot_classifier = nn.Linear(lstm_hidden, 1)

    def forward(self, x, thermo_features):
        """
        Forward pass for the simplified CNN+LSTM model (no Attention).

        Args:
            x (torch.Tensor): Input RNA sequence tensor, shape (batch, seq_len).
            thermo_features (torch.Tensor): Input thermodynamic features tensor,
                                            shape (batch, thermo_feature_size).

        Returns:
            tuple: (struct_logits, energy_pred, pk_logits)
        """
        batch_size, seq_len = x.size()
        mask = (x != 0) # Padding mask

        # Embedding
        emb = self.embedding(x)
        emb = emb.transpose(1, 2)

        # Process Thermodynamic Features
        thermo_encoded = self.thermo_fc(thermo_features)

        # Apply CNN Layers
        conv_outs = [conv(emb) for conv in self.conv_layers]
        cnn_features = torch.cat(conv_outs, dim=1)
        cnn_features = cnn_features.transpose(1, 2)

        # Apply Bidirectional LSTM
        lengths = mask.sum(dim=1).cpu().clamp(min=1)
        packed_cnn_features = nn.utils.rnn.pack_padded_sequence(
            cnn_features, lengths=lengths, batch_first=True, enforce_sorted=False
        )
        lstm_out_packed, _ = self.lstm(packed_cnn_features)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            lstm_out_packed, batch_first=True, total_length=seq_len
        )
        lstm_out = lstm_out * mask.unsqueeze(-1).float()

        # Combine LSTM Output and Thermodynamic Features
        thermo_expanded = thermo_encoded.unsqueeze(1).expand(-1, seq_len, -1)
        combined_features = torch.cat([lstm_out, thermo_expanded], dim=-1)

        # Structure Prediction Head
        struct_logits = self.fc_structure(combined_features)

        # Global Pooling (using combined features)
        masked_combined_features = combined_features * mask.unsqueeze(-1).float()
        valid_lengths = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        pooled_features = masked_combined_features.sum(dim=1) / valid_lengths

        # Energy Prediction Head
        energy_pred = self.fc_energy(pooled_features).squeeze(-1)

        # Pseudoknot Prediction Head
        pk_extracted_features = self.pseudoknot_feature_extractor(pooled_features)
        pk_logits = self.pseudoknot_classifier(pk_extracted_features).squeeze(-1)

        return struct_logits, energy_pred, pk_logits

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model hyperparameters (ensure consistency with RNAPred definition)
embed_dim = 64
num_filters = 64
kernel_sizes = [3, 5, 7] # Corrected to match RNAPred definition
lstm_hidden = 128
num_lstm_layers = 2
num_attn_heads = 8
dropout = 0.2

# Training parameters
num_epochs = 50      
learning_rate = 1e-4
weight_decay = 1e-5
grad_clip_norm = 1.0 # Max norm for gradient clipping (0 to disable)
enable_energy_loss = True # Include energy prediction in loss
energy_loss_weight = 0.001 # Weight for energy loss term
pseudoknot_loss_weight = 0.1 # Weight for pseudoknot loss term
print("\n\n=== TRAINING MODEL 2: SEQUENCE-ONLY LSTM ===")
    
# Initialize Sequence-Only LSTM model
print("\n--- Initializing Sequence-Only LSTM Model ---")
model2 = RNAPred_SeqOnlyLSTM(
    seq_vocab_size=seq_vocab_size,
    struct_vocab_size=struct_vocab_size,
    embed_dim=embed_dim,
    lstm_hidden=lstm_hidden,
    num_lstm_layers=num_lstm_layers,
    dropout=dropout
)
model2.to(device)
    
# Train Sequence-Only model
print("\n--- Starting Sequence-Only LSTM Model Training ---")
trained_model2 = train_model(
        train_loader=train_loader,
        valid_loader=valid_loader,
        model=model2,
        device=device,
        num_epochs=num_epochs,
        lr=learning_rate,
        weight_decay=weight_decay,
        grad_clip_norm=grad_clip_norm,
        enable_energy_loss=False,  # Disable energy loss for sequence-only model
        energy_weight=0,  # Set to 0 as energy prediction is less meaningful without thermo features
        pseudoknot_weight=pseudoknot_loss_weight
)
    
# Evaluate Sequence-Only model
if evaluate_on_test_set and len(test_dataset) > 0:
    print("\n--- Evaluating Sequence-Only LSTM Model on Test Set ---")
    test_results2, m2_pred = evaluate_predictions(
        model=trained_model2,
        test_loader=test_loader,
        device=device,
        use_structure_validation=use_structure_validation_postprocessing,
        print_first_batch_details= False
    )
     # Save sequence-only model evaluation results
    try:
                with open("model2_test.json", "w") as f:
                     json.dump(test_results2, f, indent=4)
                print(f"Saved average test metrics to")
    except Exception as e:
                print(f"Error saving test metrics to JSON: {e}")

            # 7. Save All Predictions
    try:
                with open("model2_pred_structure.json", "w") as f:
                     json.dump(m2_pred, f, indent=2) # Use indent=2 for list of structures
                print(f"Saved all test predictions ({len(all_prediction_results)} samples) ")
    except Exception as e:
                print(f"Error saving all test predictions to JSON: {e}")
    print("Saved detailed test evaluation results for Sequence-Only Model")

# Save sequence-only model
torch.save({
    'model_state_dict': trained_model2.state_dict(),
    'seq_vocab_size': seq_vocab_size,
    'struct_vocab_size': struct_vocab_size,
    'embed_dim': embed_dim,
    'lstm_hidden': lstm_hidden,
    'num_lstm_layers': num_lstm_layers,
    'dropout': dropout
}, "seq_only_model.pt")
print("Sequence-Only model saved to seq_only_model.pt")





=== TRAINING MODEL 2: SEQUENCE-ONLY LSTM ===

--- Initializing Sequence-Only LSTM Model ---

--- Starting Sequence-Only LSTM Model Training ---

--- Starting Training ---


                                                                                                                        


Epoch 1/50 Summary:
  Train Loss: 1.0161 | Struct: 0.9866 | PK: 0.2950 | Energy: 0.0000
  Valid Loss: 0.9271 | Struct: 0.9031 | PK: 0.2405 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.5776 | PK Acc (seq): 0.9304
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.9271) **
------------------------------------------------------------


                                                                                                                        


Epoch 2/50 Summary:
  Train Loss: 0.9302 | Struct: 0.9061 | PK: 0.2414 | Energy: 0.0000
  Valid Loss: 0.9050 | Struct: 0.8851 | PK: 0.1984 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.5851 | PK Acc (seq): 0.9312
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.9050) **
------------------------------------------------------------


                                                                                                                        


Epoch 3/50 Summary:
  Train Loss: 0.9069 | Struct: 0.8857 | PK: 0.2120 | Energy: 0.0000
  Valid Loss: 0.8817 | Struct: 0.8633 | PK: 0.1843 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.5956 | PK Acc (seq): 0.9502
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8817) **
------------------------------------------------------------


                                                                                                                        


Epoch 4/50 Summary:
  Train Loss: 0.8858 | Struct: 0.8662 | PK: 0.1967 | Energy: 0.0000
  Valid Loss: 0.8620 | Struct: 0.8441 | PK: 0.1789 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6083 | PK Acc (seq): 0.9456
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8620) **
------------------------------------------------------------


                                                                                                                        


Epoch 5/50 Summary:
  Train Loss: 0.8667 | Struct: 0.8480 | PK: 0.1871 | Energy: 0.0000
  Valid Loss: 0.8441 | Struct: 0.8278 | PK: 0.1626 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6166 | PK Acc (seq): 0.9536
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8441) **
------------------------------------------------------------


                                                                                                                        


Epoch 6/50 Summary:
  Train Loss: 0.8506 | Struct: 0.8325 | PK: 0.1807 | Energy: 0.0000
  Valid Loss: 0.8279 | Struct: 0.8119 | PK: 0.1606 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6241 | PK Acc (seq): 0.9529
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8279) **
------------------------------------------------------------


                                                                                                                        


Epoch 7/50 Summary:
  Train Loss: 0.8358 | Struct: 0.8182 | PK: 0.1754 | Energy: 0.0000
  Valid Loss: 0.8137 | Struct: 0.7979 | PK: 0.1582 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6305 | PK Acc (seq): 0.9548
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8137) **
------------------------------------------------------------


                                                                                                                        


Epoch 8/50 Summary:
  Train Loss: 0.8218 | Struct: 0.8047 | PK: 0.1710 | Energy: 0.0000
  Valid Loss: 0.8015 | Struct: 0.7856 | PK: 0.1587 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6361 | PK Acc (seq): 0.9517
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8015) **
------------------------------------------------------------


                                                                                                                        


Epoch 9/50 Summary:
  Train Loss: 0.8090 | Struct: 0.7923 | PK: 0.1665 | Energy: 0.0000
  Valid Loss: 0.7898 | Struct: 0.7747 | PK: 0.1502 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6432 | PK Acc (seq): 0.9532
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7898) **
------------------------------------------------------------


                                                                                                                         


Epoch 10/50 Summary:
  Train Loss: 0.7961 | Struct: 0.7802 | PK: 0.1595 | Energy: 0.0000
  Valid Loss: 0.7819 | Struct: 0.7675 | PK: 0.1443 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6455 | PK Acc (seq): 0.9552
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7819) **
------------------------------------------------------------


                                                                                                                         


Epoch 11/50 Summary:
  Train Loss: 0.7849 | Struct: 0.7695 | PK: 0.1540 | Energy: 0.0000
  Valid Loss: 0.7663 | Struct: 0.7525 | PK: 0.1384 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6536 | PK Acc (seq): 0.9571
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7663) **
------------------------------------------------------------


                                                                                                                         


Epoch 12/50 Summary:
  Train Loss: 0.7747 | Struct: 0.7597 | PK: 0.1500 | Energy: 0.0000
  Valid Loss: 0.7595 | Struct: 0.7454 | PK: 0.1410 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6561 | PK Acc (seq): 0.9532
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7595) **
------------------------------------------------------------


                                                                                                                         


Epoch 13/50 Summary:
  Train Loss: 0.7643 | Struct: 0.7499 | PK: 0.1446 | Energy: 0.0000
  Valid Loss: 0.7473 | Struct: 0.7341 | PK: 0.1321 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6618 | PK Acc (seq): 0.9559
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7473) **
------------------------------------------------------------


                                                                                                                         


Epoch 14/50 Summary:
  Train Loss: 0.7548 | Struct: 0.7410 | PK: 0.1380 | Energy: 0.0000
  Valid Loss: 0.7375 | Struct: 0.7253 | PK: 0.1219 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6661 | PK Acc (seq): 0.9593
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7375) **
------------------------------------------------------------


                                                                                                                         


Epoch 15/50 Summary:
  Train Loss: 0.7447 | Struct: 0.7315 | PK: 0.1320 | Energy: 0.0000
  Valid Loss: 0.7291 | Struct: 0.7176 | PK: 0.1153 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6694 | PK Acc (seq): 0.9620
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7291) **
------------------------------------------------------------


                                                                                                                         


Epoch 16/50 Summary:
  Train Loss: 0.7359 | Struct: 0.7232 | PK: 0.1269 | Energy: 0.0000
  Valid Loss: 0.7207 | Struct: 0.7102 | PK: 0.1051 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6712 | PK Acc (seq): 0.9639
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7207) **
------------------------------------------------------------


                                                                                                                         


Epoch 17/50 Summary:
  Train Loss: 0.7274 | Struct: 0.7154 | PK: 0.1209 | Energy: 0.0000
  Valid Loss: 0.7128 | Struct: 0.7025 | PK: 0.1032 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6756 | PK Acc (seq): 0.9647
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7128) **
------------------------------------------------------------


                                                                                                                         


Epoch 18/50 Summary:
  Train Loss: 0.7184 | Struct: 0.7069 | PK: 0.1152 | Energy: 0.0000
  Valid Loss: 0.7073 | Struct: 0.6978 | PK: 0.0956 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6796 | PK Acc (seq): 0.9707
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7073) **
------------------------------------------------------------


                                                                                                                         


Epoch 19/50 Summary:
  Train Loss: 0.7105 | Struct: 0.6996 | PK: 0.1100 | Energy: 0.0000
  Valid Loss: 0.6996 | Struct: 0.6904 | PK: 0.0929 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6810 | PK Acc (seq): 0.9719
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6996) **
------------------------------------------------------------


                                                                                                                         


Epoch 20/50 Summary:
  Train Loss: 0.7032 | Struct: 0.6925 | PK: 0.1062 | Energy: 0.0000
  Valid Loss: 0.6933 | Struct: 0.6844 | PK: 0.0890 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6856 | PK Acc (seq): 0.9711
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6933) **
------------------------------------------------------------


                                                                                                                         


Epoch 21/50 Summary:
  Train Loss: 0.6945 | Struct: 0.6841 | PK: 0.1034 | Energy: 0.0000
  Valid Loss: 0.6910 | Struct: 0.6817 | PK: 0.0929 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6850 | PK Acc (seq): 0.9685
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6910) **
------------------------------------------------------------


                                                                                                                         


Epoch 22/50 Summary:
  Train Loss: 0.6878 | Struct: 0.6781 | PK: 0.0967 | Energy: 0.0000
  Valid Loss: 0.6795 | Struct: 0.6711 | PK: 0.0839 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6910 | PK Acc (seq): 0.9734
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6795) **
------------------------------------------------------------


                                                                                                                         


Epoch 23/50 Summary:
  Train Loss: 0.6809 | Struct: 0.6715 | PK: 0.0934 | Energy: 0.0000
  Valid Loss: 0.6731 | Struct: 0.6650 | PK: 0.0807 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6942 | PK Acc (seq): 0.9742
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6731) **
------------------------------------------------------------


                                                                                                                         


Epoch 24/50 Summary:
  Train Loss: 0.6747 | Struct: 0.6657 | PK: 0.0904 | Energy: 0.0000
  Valid Loss: 0.6753 | Struct: 0.6671 | PK: 0.0821 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6943 | PK Acc (seq): 0.9753
------------------------------------------------------------


                                                                                                                         


Epoch 25/50 Summary:
  Train Loss: 0.6674 | Struct: 0.6587 | PK: 0.0866 | Energy: 0.0000
  Valid Loss: 0.6663 | Struct: 0.6584 | PK: 0.0788 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6969 | PK Acc (seq): 0.9749
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6663) **
------------------------------------------------------------


                                                                                                                         


Epoch 26/50 Summary:
  Train Loss: 0.6614 | Struct: 0.6531 | PK: 0.0835 | Energy: 0.0000
  Valid Loss: 0.6606 | Struct: 0.6533 | PK: 0.0729 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.6999 | PK Acc (seq): 0.9757
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6606) **
------------------------------------------------------------


                                                                                                                         


Epoch 27/50 Summary:
  Train Loss: 0.6549 | Struct: 0.6468 | PK: 0.0801 | Energy: 0.0000
  Valid Loss: 0.6563 | Struct: 0.6495 | PK: 0.0682 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7023 | PK Acc (seq): 0.9810
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6563) **
------------------------------------------------------------


                                                                                                                         


Epoch 28/50 Summary:
  Train Loss: 0.6493 | Struct: 0.6414 | PK: 0.0787 | Energy: 0.0000
  Valid Loss: 0.6506 | Struct: 0.6434 | PK: 0.0718 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7050 | PK Acc (seq): 0.9764
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6506) **
------------------------------------------------------------


                                                                                                                         


Epoch 29/50 Summary:
  Train Loss: 0.6434 | Struct: 0.6357 | PK: 0.0766 | Energy: 0.0000
  Valid Loss: 0.6484 | Struct: 0.6407 | PK: 0.0765 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7065 | PK Acc (seq): 0.9753
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6484) **
------------------------------------------------------------


                                                                                                                         


Epoch 30/50 Summary:
  Train Loss: 0.6376 | Struct: 0.6303 | PK: 0.0727 | Energy: 0.0000
  Valid Loss: 0.6429 | Struct: 0.6363 | PK: 0.0663 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7084 | PK Acc (seq): 0.9787
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6429) **
------------------------------------------------------------


                                                                                                                         


Epoch 31/50 Summary:
  Train Loss: 0.6318 | Struct: 0.6249 | PK: 0.0689 | Energy: 0.0000
  Valid Loss: 0.6474 | Struct: 0.6407 | PK: 0.0667 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7056 | PK Acc (seq): 0.9814
------------------------------------------------------------


                                                                                                                         


Epoch 32/50 Summary:
  Train Loss: 0.6263 | Struct: 0.6195 | PK: 0.0671 | Energy: 0.0000
  Valid Loss: 0.6430 | Struct: 0.6359 | PK: 0.0715 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7080 | PK Acc (seq): 0.9776
------------------------------------------------------------


                                                                                                                         


Epoch 33/50 Summary:
  Train Loss: 0.6211 | Struct: 0.6146 | PK: 0.0657 | Energy: 0.0000
  Valid Loss: 0.6327 | Struct: 0.6262 | PK: 0.0654 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7130 | PK Acc (seq): 0.9795
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6327) **
------------------------------------------------------------


                                                                                                                         


Epoch 34/50 Summary:
  Train Loss: 0.6161 | Struct: 0.6098 | PK: 0.0634 | Energy: 0.0000
  Valid Loss: 0.6337 | Struct: 0.6273 | PK: 0.0634 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7143 | PK Acc (seq): 0.9795
------------------------------------------------------------


                                                                                                                         


Epoch 35/50 Summary:
  Train Loss: 0.6113 | Struct: 0.6051 | PK: 0.0620 | Energy: 0.0000
  Valid Loss: 0.6268 | Struct: 0.6190 | PK: 0.0787 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7176 | PK Acc (seq): 0.9745
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6268) **
------------------------------------------------------------


                                                                                                                         


Epoch 36/50 Summary:
  Train Loss: 0.6058 | Struct: 0.6000 | PK: 0.0587 | Energy: 0.0000
  Valid Loss: 0.6275 | Struct: 0.6209 | PK: 0.0660 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7155 | PK Acc (seq): 0.9787
------------------------------------------------------------


                                                                                                                         


Epoch 37/50 Summary:
  Train Loss: 0.6015 | Struct: 0.5958 | PK: 0.0564 | Energy: 0.0000
  Valid Loss: 0.6260 | Struct: 0.6187 | PK: 0.0734 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7160 | PK Acc (seq): 0.9753
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6260) **
------------------------------------------------------------


                                                                                                                         


Epoch 38/50 Summary:
  Train Loss: 0.5974 | Struct: 0.5917 | PK: 0.0566 | Energy: 0.0000
  Valid Loss: 0.6226 | Struct: 0.6159 | PK: 0.0667 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7198 | PK Acc (seq): 0.9783
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6226) **
------------------------------------------------------------


                                                                                                                         


Epoch 39/50 Summary:
  Train Loss: 0.5921 | Struct: 0.5866 | PK: 0.0552 | Energy: 0.0000
  Valid Loss: 0.6190 | Struct: 0.6117 | PK: 0.0723 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7215 | PK Acc (seq): 0.9787
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6190) **
------------------------------------------------------------


                                                                                                                         


Epoch 40/50 Summary:
  Train Loss: 0.5876 | Struct: 0.5825 | PK: 0.0509 | Energy: 0.0000
  Valid Loss: 0.6179 | Struct: 0.6110 | PK: 0.0687 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7195 | PK Acc (seq): 0.9806
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6179) **
------------------------------------------------------------


                                                                                                                         


Epoch 41/50 Summary:
  Train Loss: 0.5830 | Struct: 0.5779 | PK: 0.0511 | Energy: 0.0000
  Valid Loss: 0.6110 | Struct: 0.6048 | PK: 0.0614 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7256 | PK Acc (seq): 0.9810
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6110) **
------------------------------------------------------------


                                                                                                                         


Epoch 42/50 Summary:
  Train Loss: 0.5792 | Struct: 0.5740 | PK: 0.0520 | Energy: 0.0000
  Valid Loss: 0.6101 | Struct: 0.6039 | PK: 0.0625 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7261 | PK Acc (seq): 0.9818
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6101) **
------------------------------------------------------------


                                                                                                                         


Epoch 43/50 Summary:
  Train Loss: 0.5750 | Struct: 0.5700 | PK: 0.0496 | Energy: 0.0000
  Valid Loss: 0.6054 | Struct: 0.5985 | PK: 0.0694 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7273 | PK Acc (seq): 0.9791
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6054) **
------------------------------------------------------------


                                                                                                                         


Epoch 44/50 Summary:
  Train Loss: 0.5705 | Struct: 0.5658 | PK: 0.0465 | Energy: 0.0000
  Valid Loss: 0.6065 | Struct: 0.6002 | PK: 0.0630 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7282 | PK Acc (seq): 0.9783
------------------------------------------------------------


                                                                                                                         


Epoch 45/50 Summary:
  Train Loss: 0.5669 | Struct: 0.5623 | PK: 0.0460 | Energy: 0.0000
  Valid Loss: 0.6001 | Struct: 0.5931 | PK: 0.0701 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7318 | PK Acc (seq): 0.9768
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6001) **
------------------------------------------------------------


                                                                                                                         


Epoch 46/50 Summary:
  Train Loss: 0.5630 | Struct: 0.5587 | PK: 0.0433 | Energy: 0.0000
  Valid Loss: 0.5988 | Struct: 0.5926 | PK: 0.0614 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7311 | PK Acc (seq): 0.9814
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5988) **
------------------------------------------------------------


                                                                                                                         


Epoch 47/50 Summary:
  Train Loss: 0.5592 | Struct: 0.5547 | PK: 0.0445 | Energy: 0.0000
  Valid Loss: 0.5981 | Struct: 0.5918 | PK: 0.0632 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7311 | PK Acc (seq): 0.9802
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5981) **
------------------------------------------------------------


                                                                                                                         


Epoch 48/50 Summary:
  Train Loss: 0.5544 | Struct: 0.5502 | PK: 0.0418 | Energy: 0.0000
  Valid Loss: 0.5949 | Struct: 0.5883 | PK: 0.0655 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7334 | PK Acc (seq): 0.9810
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5949) **
------------------------------------------------------------


                                                                                                                         


Epoch 49/50 Summary:
  Train Loss: 0.5508 | Struct: 0.5467 | PK: 0.0409 | Energy: 0.0000
  Valid Loss: 0.5923 | Struct: 0.5860 | PK: 0.0626 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7360 | PK Acc (seq): 0.9810
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5923) **
------------------------------------------------------------


                                                                                                                         


Epoch 50/50 Summary:
  Train Loss: 0.5471 | Struct: 0.5431 | PK: 0.0400 | Energy: 0.0000
  Valid Loss: 0.5926 | Struct: 0.5863 | PK: 0.0623 | Energy: 0.0000
  Valid Metrics -> Struct Acc (pos): 0.7355 | PK Acc (seq): 0.9821
------------------------------------------------------------

--- Training Finished ---
Loading best model from ./best_rna_structure_model.pt...
Loaded best model from epoch 49 with validation loss 0.5923


  checkpoint = torch.load(best_model_path, map_location=device) # Load to target device


NameError: name 'evaluate_on_test_set' is not defined

In [47]:
# Initialize CNN-LSTM without Attention model
print("\n--- Initializing CNN-LSTM without Attention Model ---")
model3 = RNAPred_SimpleCNN_LSTM(
    seq_vocab_size=seq_vocab_size,
    struct_vocab_size=struct_vocab_size,
    thermo_feature_size=thermo_feature_size,
    embed_dim=embed_dim,
    num_filters=num_filters,
    kernel_sizes=kernel_sizes,
    lstm_hidden=lstm_hidden,
    num_lstm_layers=num_lstm_layers,
    dropout=dropout
)
model3.to(device)
    
# Train CNN-LSTM without Attention model
print("\n--- Starting CNN-LSTM without Attention Model Training ---")
trained_model3 = train_model(
    train_loader=train_loader,
    valid_loader=valid_loader,
    model=model3,
    device=device,
    num_epochs=num_epochs,
    lr=learning_rate,
    weight_decay=weight_decay,
    grad_clip_norm=grad_clip_norm,
    enable_energy_loss=enable_energy_loss,
    energy_weight=energy_loss_weight,
    pseudoknot_weight=pseudoknot_loss_weight
)


--- Initializing CNN-LSTM without Attention Model ---

--- Starting CNN-LSTM without Attention Model Training ---

--- Starting Training ---


                                                                                                                              


Epoch 1/50 Summary:
  Train Loss: 3.0265 | Struct: 0.9603 | PK: 0.2437 | Energy: 2041.8182
  Valid Loss: 2.7126 | Struct: 0.8761 | PK: 0.1787 | Energy: 1818.5612
  Valid Metrics -> Struct Acc (pos): 0.5951 | PK Acc (seq): 0.9498
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 2.7126) **
------------------------------------------------------------


                                                                                                                              


Epoch 2/50 Summary:
  Train Loss: 2.4285 | Struct: 0.8824 | PK: 0.2071 | Energy: 1525.3883
  Valid Loss: 2.1684 | Struct: 0.8510 | PK: 0.1766 | Energy: 1299.6735
  Valid Metrics -> Struct Acc (pos): 0.6104 | PK Acc (seq): 0.9483
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 2.1684) **
------------------------------------------------------------


                                                                                                                              


Epoch 3/50 Summary:
  Train Loss: 1.8639 | Struct: 0.8559 | PK: 0.2033 | Energy: 987.6536
  Valid Loss: 1.6169 | Struct: 0.8226 | PK: 0.1708 | Energy: 777.2049
  Valid Metrics -> Struct Acc (pos): 0.6231 | PK Acc (seq): 0.9487
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 1.6169) **
------------------------------------------------------------


                                                                                                                              


Epoch 4/50 Summary:
  Train Loss: 1.4503 | Struct: 0.8317 | PK: 0.1994 | Energy: 598.6960
  Valid Loss: 1.2796 | Struct: 0.7993 | PK: 0.1658 | Energy: 463.7207
  Valid Metrics -> Struct Acc (pos): 0.6351 | PK Acc (seq): 0.9494
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 1.2796) **
------------------------------------------------------------


                                                                                                                              


Epoch 5/50 Summary:
  Train Loss: 1.1837 | Struct: 0.8110 | PK: 0.1910 | Energy: 353.5507
  Valid Loss: 1.0252 | Struct: 0.7783 | PK: 0.1581 | Energy: 231.0581
  Valid Metrics -> Struct Acc (pos): 0.6446 | PK Acc (seq): 0.9529
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 1.0252) **
------------------------------------------------------------


                                                                                                                             


Epoch 6/50 Summary:
  Train Loss: 1.0116 | Struct: 0.7941 | PK: 0.1866 | Energy: 198.8380
  Valid Loss: 0.8673 | Struct: 0.7624 | PK: 0.1523 | Energy: 89.6224
  Valid Metrics -> Struct Acc (pos): 0.6512 | PK Acc (seq): 0.9540
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.8673) **
------------------------------------------------------------


                                                                                                                             


Epoch 7/50 Summary:
  Train Loss: 0.9134 | Struct: 0.7800 | PK: 0.1756 | Energy: 115.7958
  Valid Loss: 0.7885 | Struct: 0.7462 | PK: 0.1467 | Energy: 27.6839
  Valid Metrics -> Struct Acc (pos): 0.6573 | PK Acc (seq): 0.9544
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7885) **
------------------------------------------------------------


                                                                                                                             


Epoch 8/50 Summary:
  Train Loss: 0.8666 | Struct: 0.7668 | PK: 0.1712 | Energy: 82.6297
  Valid Loss: 0.7725 | Struct: 0.7374 | PK: 0.1396 | Energy: 21.1613
  Valid Metrics -> Struct Acc (pos): 0.6610 | PK Acc (seq): 0.9525
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7725) **
------------------------------------------------------------


                                                                                                                             


Epoch 9/50 Summary:
  Train Loss: 0.8361 | Struct: 0.7551 | PK: 0.1588 | Energy: 65.1178
  Valid Loss: 0.7612 | Struct: 0.7305 | PK: 0.1292 | Energy: 17.7024
  Valid Metrics -> Struct Acc (pos): 0.6648 | PK Acc (seq): 0.9574
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7612) **
------------------------------------------------------------


                                                                                                                              


Epoch 10/50 Summary:
  Train Loss: 0.8212 | Struct: 0.7434 | PK: 0.1479 | Energy: 63.0113
  Valid Loss: 0.7410 | Struct: 0.7184 | PK: 0.1213 | Energy: 10.4670
  Valid Metrics -> Struct Acc (pos): 0.6701 | PK Acc (seq): 0.9571
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7410) **
------------------------------------------------------------


                                                                                                                              


Epoch 11/50 Summary:
  Train Loss: 0.8008 | Struct: 0.7323 | PK: 0.1388 | Energy: 54.6214
  Valid Loss: 0.7398 | Struct: 0.7092 | PK: 0.1125 | Energy: 19.4083
  Valid Metrics -> Struct Acc (pos): 0.6740 | PK Acc (seq): 0.9582
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7398) **
------------------------------------------------------------


                                                                                                                              


Epoch 12/50 Summary:
  Train Loss: 0.7870 | Struct: 0.7223 | PK: 0.1269 | Energy: 52.0680
  Valid Loss: 0.7217 | Struct: 0.7018 | PK: 0.1068 | Energy: 9.1687
  Valid Metrics -> Struct Acc (pos): 0.6787 | PK Acc (seq): 0.9639
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7217) **
------------------------------------------------------------


                                                                                                                              


Epoch 13/50 Summary:
  Train Loss: 0.7769 | Struct: 0.7125 | PK: 0.1213 | Energy: 52.2823
  Valid Loss: 0.7208 | Struct: 0.6941 | PK: 0.1093 | Energy: 15.7976
  Valid Metrics -> Struct Acc (pos): 0.6811 | PK Acc (seq): 0.9647
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7208) **
------------------------------------------------------------


                                                                                                                              


Epoch 14/50 Summary:
  Train Loss: 0.7643 | Struct: 0.7027 | PK: 0.1135 | Energy: 50.2589
  Valid Loss: 0.7012 | Struct: 0.6813 | PK: 0.0935 | Energy: 10.5768
  Valid Metrics -> Struct Acc (pos): 0.6875 | PK Acc (seq): 0.9669
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.7012) **
------------------------------------------------------------


                                                                                                                              


Epoch 15/50 Summary:
  Train Loss: 0.7538 | Struct: 0.6939 | PK: 0.1054 | Energy: 49.3269
  Valid Loss: 0.7019 | Struct: 0.6786 | PK: 0.0903 | Energy: 14.2911
  Valid Metrics -> Struct Acc (pos): 0.6883 | PK Acc (seq): 0.9681
------------------------------------------------------------


                                                                                                                              


Epoch 16/50 Summary:
  Train Loss: 0.7412 | Struct: 0.6849 | PK: 0.1004 | Energy: 46.2090
  Valid Loss: 0.6911 | Struct: 0.6696 | PK: 0.0859 | Energy: 12.9251
  Valid Metrics -> Struct Acc (pos): 0.6903 | PK Acc (seq): 0.9745
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6911) **
------------------------------------------------------------


                                                                                                                              


Epoch 17/50 Summary:
  Train Loss: 0.7339 | Struct: 0.6777 | PK: 0.0971 | Energy: 46.4444
  Valid Loss: 0.6799 | Struct: 0.6631 | PK: 0.0780 | Energy: 8.9975
  Valid Metrics -> Struct Acc (pos): 0.6964 | PK Acc (seq): 0.9738
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6799) **
------------------------------------------------------------


                                                                                                                              


Epoch 18/50 Summary:
  Train Loss: 0.7239 | Struct: 0.6690 | PK: 0.0909 | Energy: 45.8439
  Valid Loss: 0.6702 | Struct: 0.6571 | PK: 0.0790 | Energy: 5.2493
  Valid Metrics -> Struct Acc (pos): 0.6970 | PK Acc (seq): 0.9749
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6702) **
------------------------------------------------------------


                                                                                                                              


Epoch 19/50 Summary:
  Train Loss: 0.7104 | Struct: 0.6616 | PK: 0.0879 | Energy: 39.9424
  Valid Loss: 0.6780 | Struct: 0.6516 | PK: 0.0767 | Energy: 18.6967
  Valid Metrics -> Struct Acc (pos): 0.6997 | PK Acc (seq): 0.9757
------------------------------------------------------------


                                                                                                                              


Epoch 20/50 Summary:
  Train Loss: 0.7074 | Struct: 0.6537 | PK: 0.0837 | Energy: 45.3628
  Valid Loss: 0.6672 | Struct: 0.6470 | PK: 0.0717 | Energy: 12.9694
  Valid Metrics -> Struct Acc (pos): 0.7015 | PK Acc (seq): 0.9780
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6672) **
------------------------------------------------------------


                                                                                                                              


Epoch 21/50 Summary:
  Train Loss: 0.6969 | Struct: 0.6468 | PK: 0.0787 | Energy: 42.2305
  Valid Loss: 0.6532 | Struct: 0.6376 | PK: 0.0779 | Energy: 7.8219
  Valid Metrics -> Struct Acc (pos): 0.7084 | PK Acc (seq): 0.9757
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6532) **
------------------------------------------------------------


                                                                                                                              


Epoch 22/50 Summary:
  Train Loss: 0.6894 | Struct: 0.6397 | PK: 0.0771 | Energy: 42.0438
  Valid Loss: 0.6537 | Struct: 0.6378 | PK: 0.0745 | Energy: 8.4361
  Valid Metrics -> Struct Acc (pos): 0.7070 | PK Acc (seq): 0.9753
------------------------------------------------------------


                                                                                                                              


Epoch 23/50 Summary:
  Train Loss: 0.6811 | Struct: 0.6334 | PK: 0.0745 | Energy: 40.3319
  Valid Loss: 0.6395 | Struct: 0.6228 | PK: 0.0644 | Energy: 10.2865
  Valid Metrics -> Struct Acc (pos): 0.7162 | PK Acc (seq): 0.9818
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6395) **
------------------------------------------------------------


                                                                                                                              


Epoch 24/50 Summary:
  Train Loss: 0.6751 | Struct: 0.6268 | PK: 0.0681 | Energy: 41.4560
  Valid Loss: 0.6373 | Struct: 0.6218 | PK: 0.0616 | Energy: 9.3644
  Valid Metrics -> Struct Acc (pos): 0.7163 | PK Acc (seq): 0.9833
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6373) **
------------------------------------------------------------


                                                                                                                              


Epoch 25/50 Summary:
  Train Loss: 0.6700 | Struct: 0.6210 | PK: 0.0689 | Energy: 42.1229
  Valid Loss: 0.6401 | Struct: 0.6199 | PK: 0.0608 | Energy: 14.1365
  Valid Metrics -> Struct Acc (pos): 0.7177 | PK Acc (seq): 0.9799
------------------------------------------------------------


                                                                                                                              


Epoch 26/50 Summary:
  Train Loss: 0.6598 | Struct: 0.6156 | PK: 0.0627 | Energy: 37.8519
  Valid Loss: 0.6221 | Struct: 0.6085 | PK: 0.0672 | Energy: 6.8440
  Valid Metrics -> Struct Acc (pos): 0.7230 | PK Acc (seq): 0.9791
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6221) **
------------------------------------------------------------


                                                                                                                              


Epoch 27/50 Summary:
  Train Loss: 0.6548 | Struct: 0.6084 | PK: 0.0646 | Energy: 39.9082
  Valid Loss: 0.6176 | Struct: 0.6059 | PK: 0.0581 | Energy: 5.8504
  Valid Metrics -> Struct Acc (pos): 0.7248 | PK Acc (seq): 0.9848
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6176) **
------------------------------------------------------------


                                                                                                                              


Epoch 28/50 Summary:
  Train Loss: 0.6496 | Struct: 0.6046 | PK: 0.0614 | Energy: 38.8808
  Valid Loss: 0.6192 | Struct: 0.6045 | PK: 0.0580 | Energy: 8.8226
  Valid Metrics -> Struct Acc (pos): 0.7241 | PK Acc (seq): 0.9825
------------------------------------------------------------


                                                                                                                              


Epoch 29/50 Summary:
  Train Loss: 0.6442 | Struct: 0.5986 | PK: 0.0580 | Energy: 39.7989
  Valid Loss: 0.6171 | Struct: 0.6044 | PK: 0.0599 | Energy: 6.7650
  Valid Metrics -> Struct Acc (pos): 0.7239 | PK Acc (seq): 0.9825
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6171) **
------------------------------------------------------------


                                                                                                                              


Epoch 30/50 Summary:
  Train Loss: 0.6361 | Struct: 0.5933 | PK: 0.0574 | Energy: 37.0874
  Valid Loss: 0.6093 | Struct: 0.5957 | PK: 0.0568 | Energy: 7.9302
  Valid Metrics -> Struct Acc (pos): 0.7291 | PK Acc (seq): 0.9840
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6093) **
------------------------------------------------------------


                                                                                                                              


Epoch 31/50 Summary:
  Train Loss: 0.6309 | Struct: 0.5878 | PK: 0.0558 | Energy: 37.5265
  Valid Loss: 0.6002 | Struct: 0.5880 | PK: 0.0605 | Energy: 6.0872
  Valid Metrics -> Struct Acc (pos): 0.7325 | PK Acc (seq): 0.9833
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.6002) **
------------------------------------------------------------


                                                                                                                              


Epoch 32/50 Summary:
  Train Loss: 0.6233 | Struct: 0.5826 | PK: 0.0542 | Energy: 35.2701
  Valid Loss: 0.6032 | Struct: 0.5887 | PK: 0.0572 | Energy: 8.8317
  Valid Metrics -> Struct Acc (pos): 0.7329 | PK Acc (seq): 0.9852
------------------------------------------------------------


                                                                                                                              


Epoch 33/50 Summary:
  Train Loss: 0.6191 | Struct: 0.5773 | PK: 0.0523 | Energy: 36.5198
  Valid Loss: 0.6060 | Struct: 0.5909 | PK: 0.0620 | Energy: 8.9005
  Valid Metrics -> Struct Acc (pos): 0.7312 | PK Acc (seq): 0.9806
------------------------------------------------------------


                                                                                                                              


Epoch 34/50 Summary:
  Train Loss: 0.6154 | Struct: 0.5729 | PK: 0.0529 | Energy: 37.1388
  Valid Loss: 0.5928 | Struct: 0.5793 | PK: 0.0585 | Energy: 7.6523
  Valid Metrics -> Struct Acc (pos): 0.7375 | PK Acc (seq): 0.9848
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5928) **
------------------------------------------------------------


                                                                                                                              


Epoch 35/50 Summary:
  Train Loss: 0.6107 | Struct: 0.5685 | PK: 0.0498 | Energy: 37.2307
  Valid Loss: 0.5933 | Struct: 0.5798 | PK: 0.0566 | Energy: 7.7734
  Valid Metrics -> Struct Acc (pos): 0.7364 | PK Acc (seq): 0.9837
------------------------------------------------------------


                                                                                                                              


Epoch 36/50 Summary:
  Train Loss: 0.6084 | Struct: 0.5637 | PK: 0.0477 | Energy: 39.9541
  Valid Loss: 0.5999 | Struct: 0.5779 | PK: 0.0625 | Energy: 15.7699
  Valid Metrics -> Struct Acc (pos): 0.7384 | PK Acc (seq): 0.9810
------------------------------------------------------------


                                                                                                                              


Epoch 37/50 Summary:
  Train Loss: 0.6010 | Struct: 0.5598 | PK: 0.0466 | Energy: 36.4987
  Valid Loss: 0.5931 | Struct: 0.5753 | PK: 0.0762 | Energy: 10.1414
  Valid Metrics -> Struct Acc (pos): 0.7388 | PK Acc (seq): 0.9749
------------------------------------------------------------


                                                                                                                              


Epoch 38/50 Summary:
  Train Loss: 0.5931 | Struct: 0.5553 | PK: 0.0439 | Energy: 33.4229
  Valid Loss: 0.5872 | Struct: 0.5723 | PK: 0.0678 | Energy: 8.0853
  Valid Metrics -> Struct Acc (pos): 0.7419 | PK Acc (seq): 0.9776
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5872) **
------------------------------------------------------------


                                                                                                                              


Epoch 39/50 Summary:
  Train Loss: 0.5901 | Struct: 0.5505 | PK: 0.0418 | Energy: 35.4173
  Valid Loss: 0.5725 | Struct: 0.5621 | PK: 0.0600 | Energy: 4.3909
  Valid Metrics -> Struct Acc (pos): 0.7464 | PK Acc (seq): 0.9833
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5725) **
------------------------------------------------------------


                                                                                                                              


Epoch 40/50 Summary:
  Train Loss: 0.5839 | Struct: 0.5462 | PK: 0.0436 | Energy: 33.2864
  Valid Loss: 0.5788 | Struct: 0.5660 | PK: 0.0577 | Energy: 6.9808
  Valid Metrics -> Struct Acc (pos): 0.7447 | PK Acc (seq): 0.9844
------------------------------------------------------------


                                                                                                                              


Epoch 41/50 Summary:
  Train Loss: 0.5818 | Struct: 0.5420 | PK: 0.0409 | Energy: 35.6779
  Valid Loss: 0.5759 | Struct: 0.5606 | PK: 0.0581 | Energy: 9.5295
  Valid Metrics -> Struct Acc (pos): 0.7470 | PK Acc (seq): 0.9852
------------------------------------------------------------


                                                                                                                              


Epoch 42/50 Summary:
  Train Loss: 0.5776 | Struct: 0.5384 | PK: 0.0382 | Energy: 35.4188
  Valid Loss: 0.5691 | Struct: 0.5568 | PK: 0.0638 | Energy: 5.9659
  Valid Metrics -> Struct Acc (pos): 0.7496 | PK Acc (seq): 0.9810
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5691) **
------------------------------------------------------------


                                                                                                                              


Epoch 43/50 Summary:
  Train Loss: 0.5746 | Struct: 0.5349 | PK: 0.0410 | Energy: 35.6540
  Valid Loss: 0.5730 | Struct: 0.5593 | PK: 0.0583 | Energy: 7.8132
  Valid Metrics -> Struct Acc (pos): 0.7481 | PK Acc (seq): 0.9848
------------------------------------------------------------


                                                                                                                              


Epoch 44/50 Summary:
  Train Loss: 0.5697 | Struct: 0.5301 | PK: 0.0381 | Energy: 35.7301
  Valid Loss: 0.5636 | Struct: 0.5506 | PK: 0.0583 | Energy: 7.1828
  Valid Metrics -> Struct Acc (pos): 0.7527 | PK Acc (seq): 0.9844
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5636) **
------------------------------------------------------------


                                                                                                                              


Epoch 45/50 Summary:
  Train Loss: 0.5666 | Struct: 0.5267 | PK: 0.0375 | Energy: 36.1937
  Valid Loss: 0.5723 | Struct: 0.5508 | PK: 0.0535 | Energy: 16.1637
  Valid Metrics -> Struct Acc (pos): 0.7527 | PK Acc (seq): 0.9852
------------------------------------------------------------


                                                                                                                              


Epoch 46/50 Summary:
  Train Loss: 0.5625 | Struct: 0.5224 | PK: 0.0360 | Energy: 36.4505
  Valid Loss: 0.5672 | Struct: 0.5509 | PK: 0.0561 | Energy: 10.7452
  Valid Metrics -> Struct Acc (pos): 0.7526 | PK Acc (seq): 0.9867
------------------------------------------------------------


                                                                                                                              


Epoch 47/50 Summary:
  Train Loss: 0.5564 | Struct: 0.5197 | PK: 0.0355 | Energy: 33.1301
  Valid Loss: 0.5693 | Struct: 0.5555 | PK: 0.0754 | Energy: 6.2585
  Valid Metrics -> Struct Acc (pos): 0.7491 | PK Acc (seq): 0.9787
------------------------------------------------------------


                                                                                                                              


Epoch 48/50 Summary:
  Train Loss: 0.5534 | Struct: 0.5148 | PK: 0.0322 | Energy: 35.3554
  Valid Loss: 0.5645 | Struct: 0.5481 | PK: 0.0641 | Energy: 10.0172
  Valid Metrics -> Struct Acc (pos): 0.7556 | PK Acc (seq): 0.9844
------------------------------------------------------------


                                                                                                                              


Epoch 49/50 Summary:
  Train Loss: 0.5416 | Struct: 0.5058 | PK: 0.0311 | Energy: 32.6954
  Valid Loss: 0.5550 | Struct: 0.5418 | PK: 0.0568 | Energy: 7.5266
  Valid Metrics -> Struct Acc (pos): 0.7583 | PK Acc (seq): 0.9863
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5550) **
------------------------------------------------------------


                                                                                                                              


Epoch 50/50 Summary:
  Train Loss: 0.5398 | Struct: 0.5029 | PK: 0.0291 | Energy: 33.9481
  Valid Loss: 0.5528 | Struct: 0.5392 | PK: 0.0575 | Energy: 7.8518
  Valid Metrics -> Struct Acc (pos): 0.7585 | PK Acc (seq): 0.9859
  ** New best model saved to ./best_rna_structure_model.pt (Valid Loss: 0.5528) **
------------------------------------------------------------

--- Training Finished ---
Loading best model from ./best_rna_structure_model.pt...
Loaded best model from epoch 50 with validation loss 0.5528


  checkpoint = torch.load(best_model_path, map_location=device) # Load to target device
