# HELLO DEAR READER!
This notebook is for analysis of different rewards. Read the original paper at LINK for a quick understanding of what's going on. 

### Original configuration file

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import json
import subprocess
from collections import deque
from typing import Dict, List
from Bio.Align import substitution_matrices
import itertools, numpy as np

class Config:
    # INDELible Configuration (Section 5.1)
    INDELIBLE_PATH = "indelible_1.03_Windows.exe"
    SYNTHETIC_DIR = "synthetic_data"
    
    # RL Configuration (Section 5.3.1)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MEMORY_CAPACITY = 50000  # Prioritized replay buffer size
    BATCH_SIZE = 512        
    GAMMA = 0.99            # Discount factor
    LR = 0.0001             # Learning rate
    TARGET_UPDATE = 500     # Target network update frequency
    MAX_SEQ_LENGTH = 100    # Maximum sequence length (level 50)
    
    # Evaluation Parameters
    EVAL_EPISODES = 20    # Increased for robust evaluation
    
    # Curriculum Parameters (Section 5.1)
    CURRICULUM_LEVELS = 50  # Full curriculum range
    MIN_SEQ_LENGTH = 10     # Level 1 starting length
    MAX_SEQ_LENGTH = 100    # Level 50 maximum length
    MIN_SEQUENCES = 2       # Starting number of sequences
    MAX_SEQUENCES = 9       # Maximum number of sequences
    NUM_OF_SEQUENCES = 100  # Replicates per level
    
    # Biological Parameters (Section 5.1)
    AMINO_ACIDS = ['-'] + list("ACDEFGHIKLMNPQRSTVWY")
    SUBSTITUTION_MODEL = "WAG"
    INDEL_MODEL_PARAMS = {
        'pow_shape': 1.7,    # Power law shape parameter
        'pow_scale': 100,    # Power law scale parameter
        'min_indel_rate': 0.03,  # Starting indel rate
        'max_indel_rate': 0.05   # Maximum indel rate
    }
    
    # Scoring Parameters (Section 5.2)
    BLOSUM_MATRIX = substitution_matrices.load("BLOSUM62")
    GAP_PENALTY = -6         # For contiguous gaps
    GAP_OPEN_PENALTY = -10   # For opening new gaps
    GAP_EXTEND_PENALTY = -1  # For extending existing gaps


### Curriculum creation

In [2]:
from typing import Dict, List, Optional, Tuple
class INDELibleWrapper:
    """Wrapper for INDELible sequence simulation (Section 5.1)"""
    @staticmethod
    def generate_control_file(output_dir: str, params: Dict):
        """Generate control file with evolutionary parameters"""
        # Calculate level-dependent indel rate
        level = int(params['name'].split('_')[1])
        indel_rate = Config.INDEL_MODEL_PARAMS['min_indel_rate'] + \
                   (level/Config.CURRICULUM_LEVELS) * \
                   (Config.INDEL_MODEL_PARAMS['max_indel_rate'] - Config.INDEL_MODEL_PARAMS['min_indel_rate'])
        
        control_content = f"""[TYPE] AMINOACID 2
        
[MODEL] model_{params['name']}
[submodel] {Config.SUBSTITUTION_MODEL}
[statefreq] {' '.join(map(str, params['state_freqs']))}
[indelmodel] POW {Config.INDEL_MODEL_PARAMS['pow_shape']} {Config.INDEL_MODEL_PARAMS['pow_scale']}
[indelrate] {indel_rate}

[TREE] tree_{params['name']} {params['newick_tree']}

[PARTITIONS] partition_{params['name']}
[tree_{params['name']} model_{params['name']} {params['length']}]

[EVOLVE] partition_{params['name']} {params['num_simulations']} {params['name']}
"""
        with open(os.path.join(output_dir, "control.txt"), 'w') as f:
            f.write(control_content)

    @staticmethod
    def run(output_dir: str) -> bool:
        """Execute INDELible simulation"""
        try:
            result = subprocess.run(
                [Config.INDELIBLE_PATH],
                cwd=output_dir,
                capture_output=True,
                text=True,
                check=True
            )
            
            if "ERROR" in result.stdout:
                error_lines = [line for line in result.stdout.split('\n') if "ERROR" in line]
                raise RuntimeError("\n".join(error_lines))
            
            return True
        except subprocess.CalledProcessError as e:
            print(f"INDELible failed with exit code {e.returncode}")
            print(f"Output:\n{e.stdout}")
            print(f"Error:\n{e.stderr}")
            return False
        except Exception as e:
            print(f"Error running INDELible: {str(e)}")
            return False

def parse_phy_file(file_path: str) -> List[List[str]]:
    """Parse INDELible output file to extract alignments"""
    alignments = []
    current_block = []
    
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                if current_block:
                    alignments.append(current_block)
                    current_block = []
                continue
                
            if line[0].isdigit():
                continue
                
            parts = line.split()
            if len(parts) >= 2:
                current_block.append(parts[1])
    
    if current_block:
        alignments.append(current_block)
    
    return alignments

def generate_evolutionary_parameters(level: int):
    """Generate parameters for evolutionary simulation (Section 5.1)"""
    # Calculate number of sequences for this level
    num_seqs = Config.MIN_SEQUENCES + \
              (level % (Config.MAX_SEQUENCES - Config.MIN_SEQUENCES + 1))
    
    # Calculate sequence length with curriculum progression
    seq_length = Config.MIN_SEQ_LENGTH + \
               int((level / Config.CURRICULUM_LEVELS) * \
               (Config.MAX_SEQ_LENGTH - Config.MIN_SEQ_LENGTH))
    
    # Create Newick tree with equal branch lengths
    branches = ",".join(f"Seq{i+1}:0.1" for i in range(num_seqs))
    newick_tree = f"({branches});"
    
    return {
        'name': f"level_{level}",
        'matrix': Config.SUBSTITUTION_MODEL,
        'state_freqs': [0.05]*20,  # Uniform state frequencies
        'pow_shape': Config.INDEL_MODEL_PARAMS['pow_shape'],
        'pow_scale': Config.INDEL_MODEL_PARAMS['pow_scale'],
        'newick_tree': newick_tree,
        'length': seq_length,
        'num_simulations': Config.NUM_OF_SEQUENCES
    }

def process_indelible_output(output_dir: str, params: Dict) -> Optional[List[Dict]]:
    """Process INDELible output into multiple training samples (one per alignment block)"""
    try:
        phy_file = os.path.join(output_dir, f"{params['name']}_TRUE.phy")
        if not os.path.exists(phy_file):
            return None
            
        alignment_blocks = parse_phy_file(phy_file)
        if not alignment_blocks:
            return None
            
        samples = []
        for block_idx, true_alignment in enumerate(alignment_blocks):
            sequences = [seq.replace('-', '') for seq in true_alignment]
            samples.append({
                'sequences': sequences,
                'true_alignment': true_alignment,
                'length': params['length'],
                'num_sequences': len(sequences),
                'block_idx': block_idx
            })
            
        return samples
    except Exception as e:
        print(f"Error processing output: {str(e)}")
        return None

def generate_synthetic_curriculum():
    """Generate synthetic alignment curriculum using ALL alignment blocks"""
    os.makedirs(Config.SYNTHETIC_DIR, exist_ok=True)
    curriculum = []
    
    for level in range(1, Config.CURRICULUM_LEVELS + 1):
        level_dir = os.path.join(Config.SYNTHETIC_DIR, f"level_{level}")
        os.makedirs(level_dir, exist_ok=True)
        
        params = generate_evolutionary_parameters(level)
        INDELibleWrapper.generate_control_file(level_dir, params)
        
        if INDELibleWrapper.run(level_dir):
            all_samples = process_indelible_output(level_dir, params)
            if all_samples:
                for sample in all_samples:
                    curriculum.append({
                        'level': level,
                        'sample': sample,
                        'params': params,
                        'name': f"Level {level}-{sample['block_idx']}",
                        'episodes': 300 + level * 150,
                        'epsilon_start': max(0.2, 1.0 - (level * 0.15)),
                        'epsilon_end': max(0.05, 0.15 - (level * 0.02))
                    })
                print(f"Generated level {level} with {len(all_samples)} alignment blocks")
            else:
                print(f"Failed to process output for level {level}")
        else:
            print(f"INDELible failed for level {level}")
    
    return curriculum

curriculum = generate_synthetic_curriculum()


Generated level 1 with 100 alignment blocks
Generated level 2 with 100 alignment blocks
Generated level 3 with 100 alignment blocks
Generated level 4 with 100 alignment blocks
Generated level 5 with 100 alignment blocks
Generated level 6 with 100 alignment blocks
Generated level 7 with 100 alignment blocks
Generated level 8 with 100 alignment blocks
Generated level 9 with 100 alignment blocks
Generated level 10 with 100 alignment blocks
Generated level 11 with 100 alignment blocks
Generated level 12 with 100 alignment blocks
Generated level 13 with 100 alignment blocks
Generated level 14 with 100 alignment blocks
Generated level 15 with 100 alignment blocks
Generated level 16 with 100 alignment blocks
Generated level 17 with 100 alignment blocks
Generated level 18 with 100 alignment blocks
Generated level 19 with 100 alignment blocks
Generated level 20 with 100 alignment blocks
Generated level 21 with 100 alignment blocks
Generated level 22 with 100 alignment blocks
Generated level 23 

### The Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import json
import subprocess
from collections import deque
from typing import Dict, List
from Bio.Align import substitution_matrices

class BaseProteinAlignmentEnv:
    """Shared base class implementing the MDP formulation (Section 5.2)"""
    def __init__(self, sequences, true_alignment=None):
        # Validate input sequences
        if not sequences:
            raise ValueError("Empty sequence list provided")
        if any(len(seq) == 0 for seq in sequences):
            raise ValueError("One or more sequences are empty")
            
        self.sequences = sequences
        self.true_alignment = true_alignment
        self.n_seqs = len(sequences)
        self.max_len = max(len(s) for s in sequences)
        
        # Action space construction (2^n + n possible actions)
        self.action_space = self._build_action_space()
        self.action_descriptions = self._build_action_descriptions()
        self.aa_to_idx = {aa: i for i, aa in enumerate(Config.AMINO_ACIDS)}
        
        # Hybrid state representation (Section 5.2)
        self.state_size = self._calculate_state_size()
        self.reset()

    def _build_action_space(self):
        """Construct action space with all valid combinations (Section 5.2)"""
        actions = []
        # All possible advancement combinations (2^n - 1)
        for i in range(1, 2**self.n_seqs):
            action = tuple(int(b) for b in f"{i:0{self.n_seqs}b}")
            actions.append(action)
        
        # Gap insertion actions (n additional actions)
        for seq_idx in range(self.n_seqs):
            gap_action = [0] * self.n_seqs
            gap_action[seq_idx] = 1
            actions.append(tuple(gap_action))
            
        return actions

    def _build_action_descriptions(self):
        """Create human-readable action descriptions"""
        desc = []
        for action in self.action_space:
            parts = []
            for i, a in enumerate(action):
                if a:
                    parts.append(f"Advance seq{i+1}" if a == 1 else f"Gap seq{i+1}")
            desc.append(" + ".join(parts) if parts else "No-op")
        return desc

    def _calculate_state_size(self):
        """Calculate state vector size (Section 5.2)"""
        # Local features: current column AA composition
        local_features = len(Config.AMINO_ACIDS)
        
        # Lookahead features: next 3 AAs for each sequence
        lookahead_features = 3 * len(Config.AMINO_ACIDS) * self.n_seqs
        
        # Global features: alignment progress
        progress_features = 2 * self.n_seqs
        
        # Gap context features
        gap_features = 4
        
        return local_features + lookahead_features + progress_features + gap_features

    def reset(self):
        """Initialize a new alignment episode"""
        self.aligned = [[] for _ in range(self.n_seqs)]
        self.remaining = [list(seq) for seq in self.sequences]
        self.original_lengths = [len(seq) for seq in self.sequences]
        self.gap_status = [False] * self.n_seqs
        return self._get_state()

    def _get_gap_context(self):
        """Calculate gap-related features (Section 5.2)"""
        if not self.aligned[0]:
            return np.zeros(4)

        gap_features = np.zeros(4)
        last_col = [seq[-1] if seq else '-' for seq in self.aligned]

        # 1. Current gap ratio in last column
        gap_features[0] = last_col.count('-') / self.n_seqs

        # 2. Contiguous gap ratio
        if len(self.aligned[0]) > 1:
            prev_col = [seq[-2] if len(seq)>1 else '-' for seq in self.aligned]
            gap_features[1] = sum(1 for a,b in zip(last_col, prev_col) if a==b=='-') / self.n_seqs

        # 3. True alignment gap match (if available)
        if self.true_alignment and len(self.aligned[0]) <= len(self.true_alignment[0]):
            target_col = [seq[len(self.aligned[0])-1] for seq in self.true_alignment]
            gap_features[2] = sum(1 for a,b in zip(last_col, target_col) if a==b=='-') / self.n_seqs

        # 4. Total gap ratio in alignment
        total_gaps = sum(seq.count('-') for seq in self.aligned)
        gap_features[3] = total_gaps / (self.n_seqs * len(self.aligned[0])) if self.aligned[0] else 0

        return gap_features

    def _get_state(self):
        """Construct the hybrid state vector (Section 5.2)"""
        state = np.zeros(self.state_size, dtype=np.float32)
        offset = 0
        
        # 1. Current column amino acid composition
        if self.aligned and len(self.aligned[0]) > 0:
            last_col = [seq[-1] for seq in self.aligned]
            for aa in last_col:
                if aa in self.aa_to_idx:
                    state[offset + self.aa_to_idx[aa]] = 1
        offset += len(Config.AMINO_ACIDS)
        
        # 2. Lookahead features (next 3 amino acids)
        for i in range(self.n_seqs):
            if self.remaining[i]:
                for lookahead in range(3):
                    if len(self.remaining[i]) > lookahead:
                        aa = self.remaining[i][lookahead]
                        if aa in self.aa_to_idx:
                            state[offset + lookahead * len(Config.AMINO_ACIDS) + self.aa_to_idx[aa]] = 1.0 / (lookahead + 1)
            offset += 3 * len(Config.AMINO_ACIDS)
        
        # 3. Alignment progress features
        for i in range(self.n_seqs):
            # Safe division for alignment progress
            state[offset + i] = len(self.aligned[i]) / max(1, self.max_len)
            
            # Safe division for remaining ratio
            seq_len = len(self.sequences[i])
            remaining_len = len(self.remaining[i])
            if seq_len > 0:
                state[offset + self.n_seqs + i] = remaining_len / seq_len
            else:
                state[offset + self.n_seqs + i] = 0
        offset += 2 * self.n_seqs
        
        # 4. Gap context features
        state[offset:offset+4] = self._get_gap_context()
        
        return state

    def _verify_alignment_integrity(self):
        """Validate that no amino acids were lost during alignment"""
        for i in range(self.n_seqs):
            aligned_chars = len([c for c in self.aligned[i] if c != '-'])
            remaining_chars = len(self.remaining[i])
            assert aligned_chars + remaining_chars == self.original_lengths[i], \
                f"Sequence {i} lost amino acids! Original: {self.original_lengths[i]}, Aligned: {aligned_chars}, Remaining: {remaining_chars}"

    def step(self, action_idx):
        """Execute one alignment step (Section 5.2)"""
        action = self.action_space[action_idx]
        new_column = []
        new_gap_status = [False] * self.n_seqs
        
        # Process each sequence according to action
        for i in range(self.n_seqs):
            if action[i] and self.remaining[i]:
                new_column.append(self.remaining[i].pop(0))
                new_gap_status[i] = False
            else:
                new_column.append('-')
                new_gap_status[i] = True
        
        prev_gap_status = self.gap_status
        self.gap_status = new_gap_status
        
        # Add the new column to the alignment
        for i in range(self.n_seqs):
            self.aligned[i].append(new_column[i])
        
        # Check termination conditions
        done = all(len(seq) == 0 for seq in self.remaining)
        if len(self.aligned[0]) >= self.max_len * 2:
            if not done:
                reward = -100 * sum(len(seq) for seq in self.remaining)
                self._verify_alignment_integrity()
                return reward, self._get_state(), True
            done = True

        current_alignment = [''.join(seq) for seq in self.aligned]
        self._verify_alignment_integrity()
        
        reward = self._calculate_reward(current_alignment, done)
        return reward, self._get_state(), done

    def _calculate_reward(self, current_alignment, done):
        raise NotImplementedError("Implemented in child classes")

class AlignmentQNetwork(nn.Module):
    """Q-network architecture (Section 5.3.3)"""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # 4-layer MLP with layer normalization
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, output_dim)
        
        self.layer_norm1 = nn.LayerNorm(256)
        self.layer_norm2 = nn.LayerNorm(128)
        self.layer_norm3 = nn.LayerNorm(64)
        self.dropout = nn.Dropout(0.2)  # 20% dropout

    def forward(self, x):
        x = F.relu(self.layer_norm1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.layer_norm2(self.fc2(x)))
        x = self.dropout(x)
        x = F.relu(self.layer_norm3(self.fc3(x)))
        return self.fc4(x)

class DQNAgent:
    """DQN agent with prioritized experience replay (Section 5.3)"""
    def __init__(self, state_dim, action_dim, action_space):
        self.policy_net = AlignmentQNetwork(state_dim, action_dim).to(Config.DEVICE)
        self.target_net = AlignmentQNetwork(state_dim, action_dim).to(Config.DEVICE)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        # Optimizer with weight decay and learning rate scheduling
        self.optimizer = torch.optim.AdamW(self.policy_net.parameters(), 
                                          lr=Config.LR, 
                                          weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=50, factor=0.5)
        
        self.memory = deque(maxlen=Config.MEMORY_CAPACITY)
        self.steps = 0
        self.action_space = action_space
        self.action_dim = action_dim
        self.state_dim = state_dim

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection"""
        if random.random() < epsilon:
            return random.randint(0, self.action_dim-1)
        
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(Config.DEVICE)
            return self.policy_net(state_tensor).argmax().item()

    def store_transition(self, state, action, next_state, reward, done):
        """Store transition in replay buffer"""
        self.memory.append((state, action, next_state, reward, done))

    def update(self):
        """Update policy network with prioritized experience replay"""
        if len(self.memory) < Config.BATCH_SIZE:
            return None
        
        batch = random.sample(self.memory, Config.BATCH_SIZE)
        states, actions, next_states, rewards, dones = zip(*batch)
        
        states = torch.FloatTensor(np.array(states)).to(Config.DEVICE)
        actions = torch.LongTensor(actions).to(Config.DEVICE)
        next_states = torch.FloatTensor(np.array(next_states)).to(Config.DEVICE)
        rewards = torch.FloatTensor(rewards).to(Config.DEVICE)
        dones = torch.FloatTensor(dones).to(Config.DEVICE)
        
        # Current Q values
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))
        
        # Double DQN target calculation
        with torch.no_grad():
            next_actions = self.policy_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1))
            target_q = rewards + (1 - dones) * Config.GAMMA * next_q.squeeze()
        
        # Huber loss for more stable training
        loss = F.smooth_l1_loss(current_q.squeeze(), target_q)
        
        # Optimize with gradient clipping
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step(loss)

        # Update target network periodically
        if self.steps % Config.TARGET_UPDATE == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.steps += 1
        return loss.item()

def evaluate_agent(agent, sequences, true_alignment=None, scoring_method='sp'):
    """Comprehensive agent evaluation with intermediate state tracking"""
    # Skip evaluation if we have invalid sequences
    if any(len(seq) == 0 for seq in sequences):
        print("Skipping evaluation due to empty sequences")
        return {
            "avg_reward": 0,
            "avg_sp_score": 0,
            "perfect_match_rate": 0,
            "column_accuracy": 0,
            "avg_alignment_length": 0,
            "reward_std": 0,
            "sp_score_std": 0,
            "all_rewards": [],
            "all_lengths": [],
            "alignments": [],
            "state_sp_correlations": [],
            "state_feature_importance": None
        }
    
    # Initialize environment based on scoring method
    if scoring_method == 'true':
        env = TrueAlignmentProteinEnv(sequences, true_alignment)
    elif scoring_method == 'blosum':
        env = TrueAlignmentBlosumProteinEnv(sequences, true_alignment)
    elif scoring_method == 'rlalign':
        env = RLALIGNProteinEnv(sequences)
    elif scoring_method == 'msadrl':
        env = MSADRLProteinEnv(sequences)
    elif scoring_method == 'dqnalign':
        env = DQNAlignProteinEnv(sequences)
    elif scoring_method == 'dnpmsa':
        env = DNPMSAProteinEnv(sequences)
    elif scoring_method == 'edgealign':
        env = EdgeAlignProteinEnv(sequences)
    elif scoring_method == 'dpamsa':
        env = DPAMSAProteinEnv(sequences)
    elif scoring_method == 'intellialign':
        env = IntelliAlignProteinEnv(sequences)

    
    # Evaluation metrics
    total_rewards = []
    perfect_matches = 0
    alignment_lengths = []
    produced_alignments = []
    sp_scores = []
    column_accuracies = []
    state_sp_correlations = []
    all_state_features = []
    all_sp_scores = []

    for _ in range(Config.EVAL_EPISODES):
        state = env.reset()
        episode_reward = 0
        done = False
        episode_states = []
        episode_sp_scores = []

        while not done:
            action = agent.select_action(state, epsilon=0.01)  # Minimal exploration
            reward, state, done = env.step(action)
            episode_reward += reward
            
            # Track intermediate state and SP score
            current_alignment = [''.join(seq) for seq in env.aligned]
            sp_score = SPProteinAlignmentEnv._calculate_sp_score(env, current_alignment)
            
            episode_states.append(state)
            episode_sp_scores.append(sp_score)

        # Store episode results
        total_rewards.append(episode_reward)
        current_alignment = [''.join(seq) for seq in env.aligned]
        produced_alignments.append(current_alignment)
        
        # Calculate final SP score
        final_sp_score = SPProteinAlignmentEnv._calculate_sp_score(env, current_alignment)
        sp_scores.append(final_sp_score)
        
        # Calculate column accuracy if true alignment available
        if true_alignment:
            min_len = min(len(current_alignment[0]), len(true_alignment[0]))
            correct_cols = 0
            for col in range(min_len):
                if tuple(seq[col] for seq in current_alignment) == tuple(seq[col] for seq in true_alignment):
                    correct_cols += 1
            column_accuracies.append(correct_cols / min_len if min_len > 0 else 0)

        if true_alignment and current_alignment == true_alignment:
            perfect_matches += 1
        alignment_lengths.append(len(env.aligned[0]))
        
        # Calculate state-SP correlation for this episode
        if len(episode_states) > 1:
            # Convert to numpy arrays
            states_array = np.array(episode_states)
            sp_array = np.array(episode_sp_scores)
            
            # Calculate correlation for each state feature
            feature_correlations = []
            for i in range(states_array.shape[1]):
                feature_correlations.append(np.corrcoef(states_array[:, i], sp_array)[0,1])
            
            state_sp_correlations.append(feature_correlations)
            all_state_features.extend(episode_states)
            all_sp_scores.extend(episode_sp_scores)

    # Calculate overall feature importance
    feature_importance = None
    if len(all_state_features) > 0:
        states_array = np.array(all_state_features)
        sp_array = np.array(all_sp_scores)
        
        # Calculate correlation for each feature across all episodes
        feature_importance = []
        for i in range(states_array.shape[1]):
            corr = np.corrcoef(states_array[:, i], sp_array)[0,1]
            feature_importance.append(abs(corr))  # Use absolute value of correlation

    # Print example alignments for inspection
    print("\nExample produced alignments:")
    for i in range(min(3, len(produced_alignments))):
        print(f"  {produced_alignments[i]}")

    return {
        "avg_reward": np.mean(total_rewards),
        "avg_sp_score": np.mean(sp_scores),
        "perfect_match_rate": perfect_matches / Config.EVAL_EPISODES if true_alignment else 0,
        "column_accuracy": np.mean(column_accuracies) if column_accuracies else 0,
        "avg_alignment_length": np.mean(alignment_lengths),
        "reward_std": np.std(total_rewards),
        "sp_score_std": np.std(sp_scores),
        "all_rewards": total_rewards,
        "all_lengths": alignment_lengths,
        "alignments": produced_alignments,
        "state_sp_correlations": np.mean(state_sp_correlations, axis=0) if state_sp_correlations else [],
        "state_feature_importance": feature_importance
    }

def train_agent(curriculum, scoring_method='sp'):
    """Training procedure (Section 5.3)"""
    agent = None
    results = {}
    prev_state_size = None

    for stage_idx, stage in enumerate(curriculum):
        print(f"\n=== Starting {stage['name']} ({scoring_method.upper()} scoring) ===")
        
        try:
            # Skip if we have invalid sequences
            if any(len(seq) == 0 for seq in stage['sample']['sequences']):
                print(f"Skipping {stage['name']} due to empty sequences")
                continue
                
            # Create appropriate environment based on pathway
            if scoring_method == 'true':
                env = TrueAlignmentProteinEnv(stage['sample']['sequences'], stage['sample']['true_alignment'])
            elif scoring_method == 'blosum':
                env = TrueAlignmentBlosumProteinEnv(stage['sample']['sequences'], stage['sample']['true_alignment'])
            elif scoring_method == 'rlalign':
                env = RLALIGNProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'msadrl':
                env = MSADRLProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'dqnalign':
                env = DQNAlignProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'dnpmsa':
                env = DNPMSAProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'edgealign':
                env = EdgeAlignProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'dpamsa':
                env = DPAMSAProteinEnv(stage['sample']['sequences'])
            elif scoring_method == 'intellialign':
                env = IntelliAlignProteinEnv(stage['sample']['sequences'])
              
            # Reinitialize agent if state size changed (different number of sequences)
            if agent is None or env.state_size != prev_state_size:
                if agent is not None:
                    print(f"Reinitializing agent due to state size change ({prev_state_size} -> {env.state_size})")
                agent = DQNAgent(env.state_size, len(env.action_space), env.action_space)
                prev_state_size = env.state_size

            epsilon = stage['epsilon_start']
            epsilon_decay = (stage['epsilon_end'] / stage['epsilon_start']) ** (1/stage['episodes'])

            training_rewards = []
            losses = []
            best_reward = -float('inf')
            no_improvement = 0

            for episode in range(stage['episodes']):
                state = env.reset()
                total_reward = 0
                done = False

                while not done:
                    action = agent.select_action(state, epsilon)
                    reward, next_state, done = env.step(action)
                    agent.store_transition(state, action, next_state, reward, done)
                    total_reward += reward
                    state = next_state

                loss = agent.update()
                if loss:
                    losses.append(loss)

                # Decay epsilon
                epsilon = max(stage['epsilon_end'], epsilon * epsilon_decay)
                training_rewards.append(total_reward)

                # Early stopping check (Section 5.3.2)
                if total_reward > best_reward:
                    best_reward = total_reward
                    no_improvement = 0
                else:
                    no_improvement += 1

                if no_improvement > 100:
                    print(f"Early stopping at episode {episode} due to no improvement")
                    break

                # Progress reporting
                if episode % 50 == 0 or episode == stage['episodes']-1:
                    avg_reward = np.mean(training_rewards[-50:]) if len(training_rewards) >= 50 else np.mean(training_rewards)
                    avg_loss = np.mean(losses[-50:]) if losses else 0
                    print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}, Avg Loss: {avg_loss:.4f}, Epsilon: {epsilon:.3f}")

            # Evaluate with both metrics
            eval_results = evaluate_agent(
                agent, 
                stage['sample']['sequences'], 
                stage['sample'].get('true_alignment'),
                scoring_method=scoring_method
            )
            
            results[stage['name']] = {
                "training_rewards": training_rewards,
                "losses": losses,
                "eval_results": eval_results,
                "agent": agent,
                "level": stage['level'],
                "scoring_method": scoring_method
            }

            print(f"\nEvaluation for {stage['name']}:")
            print(f"  Average Reward: {eval_results['avg_reward']:.2f} ± {eval_results['reward_std']:.2f}")
            print(f"  Average SP Score: {eval_results['avg_sp_score']:.2f} ± {eval_results['sp_score_std']:.2f}")
            if 'true_alignment' in stage['sample']:
                print(f"  Perfect Match Rate: {eval_results['perfect_match_rate']:.2f}")
                print(f"  Column Accuracy: {eval_results['column_accuracy']:.2f}")
            print(f"  Avg Alignment Length: {eval_results['avg_alignment_length']:.1f}")

        except Exception as e:
            print(f"Error processing {stage['name']}: {str(e)}")
            continue
            
    return agent, results

def save_results(results, filename):
    """Save training results (without agent objects)"""
    serializable_results = {}
    for name, data in results.items():
        serializable_results[name] = {
            'level': data['level'],
            'scoring_method': data['scoring_method'],
            'training_rewards': data['training_rewards'],
            'losses': data['losses'],
            'eval_results': data['eval_results'],
        }
    
    with open(filename, 'w') as f:
        json.dump(serializable_results, f)

def load_results(filename):
    """Load saved results"""
    with open(filename, 'r') as f:
        return json.load(f)

### True reward function

In [4]:

class TrueAlignmentBlosumProteinEnv(BaseProteinAlignmentEnv):
    def _calculate_reward(self, current_alignment, done):
        reward = 0
        
        # 1. Perfect alignment bonus
        if self.true_alignment and current_alignment == self.true_alignment:
            return 100.0  # Standardized perfect alignment reward
        
        # 2. Column-wise matching (normalized by sequence count)
        if self.true_alignment:
            min_len = min(len(current_alignment[0]), len(self.true_alignment[0]))
            for col in range(min_len):
                current_col = tuple(seq[col] for seq in current_alignment)
                true_col = tuple(seq[col] for seq in self.true_alignment)
                
                if current_col == true_col:
                    reward += 2.0  # Perfect column match
                else:
                    for a, b in zip(current_col, true_col):
                        if a == b and a != '-':
                            reward += 0.5  # AA match
                        elif a == '-' and b == '-':
                            reward += 0.1  # Correct gap
                        elif a != b and (a == '-' or b == '-'):
                            reward -= 0.3  # Incorrect gap
        
        # 3. BLOSUM score (normalized)
        if len(self.aligned[0]) > 0:
            last_col = [seq[-1] for seq in self.aligned]
            pairs = 0
            for i in range(self.n_seqs):
                for j in range(i+1, self.n_seqs):
                    a, b = last_col[i], last_col[j]
                    if a != '-' and b != '-':
                        try:
                            reward += Config.BLOSUM_MATRIX[a, b] * 0.05
                        except KeyError:
                            pass
                    elif a == '-' or b == '-':
                        reward += Config.GAP_PENALTY * 0.05
                    pairs += 1
            if pairs > 0:
                reward /= pairs  # Normalize by number of pairs
        
        # 4. Completion bonus (standardized)
        if done and all(len(seq) == 0 for seq in self.remaining):
            reward += 100.0
            
        return float(reward)

class SPProteinAlignmentEnv(BaseProteinAlignmentEnv):
    def _calculate_sp_score(self, current_alignment):
        """Calculate normalized SP score (Section 5.2)"""
        total_score = 0.0
        n_cols = len(current_alignment[0]) if current_alignment else 0
        
        for col_idx in range(n_cols):
            column = [seq[col_idx] for seq in current_alignment]
            
            for i in range(len(column)):
                for j in range(i+1, len(column)):
                    a, b = column[i], column[j]
                    
                    if a == '-' and b == '-':
                        pair_score = 0
                    elif a == '-' or b == '-':
                        # Apply appropriate gap penalty
                        if (a == '-' and self.gap_status[i]) or (b == '-' and self.gap_status[j]):
                            pair_score = Config.GAP_EXTEND_PENALTY
                        else:
                            pair_score = Config.GAP_OPEN_PENALTY
                    else:
                        try:
                            pair_score = Config.BLOSUM_MATRIX[a, b]
                        except KeyError:
                            pair_score = -4  # Default penalty for rare AA pairs
                    
                    total_score += pair_score
        
        # Normalize by number of columns and pairs
        if n_cols > 0:
            total_score /= (n_cols * (self.n_seqs * (self.n_seqs - 1)) / 2)
        
        return total_score


    def _calculate_reward(self, current_alignment, done):
        """Standardized SP score reward"""
        if not current_alignment:
            return 0.0
            
        # Calculate normalized SP score
        sp_score = self._calculate_sp_score(current_alignment)
        
        # Scale to target range (0-100 for perfect alignment)
        reward = sp_score * 50  # Adjusted scaling
        
        # Standard completion bonus
        if done and all(len(seq) == 0 for seq in self.remaining):
            reward += 100.0
            
        return float(reward)


    """RQ1 Pathway: Biological alignment rewards (Equation 2)"""
    def _calculate_reward(self, current_alignment, done):
        reward = 0
        
        # 1. Perfect alignment bonus (R_complete)
        if self.true_alignment and current_alignment == self.true_alignment:
            return 500.0
        
        # 2. Column-wise matching reward (5R_match + 2R_consensus + 0.5R_gap)
        if self.true_alignment:
            min_len = min(len(current_alignment[0]), len(self.true_alignment[0]))
            for col in range(min_len):
                current_col = tuple(seq[col] for seq in current_alignment)
                true_col = tuple(seq[col] for seq in self.true_alignment)
                
                if current_col == true_col:
                    reward += 5.0 * self.n_seqs  # Perfect column match
                else:
                    for a, b in zip(current_col, true_col):
                        if a == b and a != '-':
                            reward += 2.0  # AA match
                        elif a == '-' and b == '-':
                            reward += 0.5  # Correct gap
                        elif a != b and (a == '-' or b == '-'):
                            reward -= 1.0  # Incorrect gap
        
        # 3. BLOSUM score for the last column (70% of reward)
        if len(self.aligned[0]) > 0:
            last_col = [seq[-1] for seq in self.aligned]
            for i in range(self.n_seqs):
                for j in range(i+1, self.n_seqs):
                    a, b = last_col[i], last_col[j]
                    if a != '-' and b != '-':
                        try:
                            reward += Config.BLOSUM_MATRIX[a, b] * 0.7  # Weighted BLOSUM
                        except KeyError:
                            pass
                    elif a == '-' or b == '-':
                        reward += Config.GAP_PENALTY * 0.7  # Weighted gap penalty
        
        # 4. Length penalty (-0.5L_penalty)
        reward -= 0.5 * len(self.aligned[0])
        
        # 5. Completion bonus (100R_complete)
        if done and all(len(seq) == 0 for seq in self.remaining):
            reward += 100.0
            
        return float(reward)

class TrueAlignmentProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # Dynamic penalty scaling based on sequence length
        self.base_gap_penalty = -2 * (1 + np.log10(max(10, self.max_len)))
        self.progressive_weight = 0.3  # Starts low, increases with progress

    def _calculate_reward(self, current_alignment, done):
        if not current_alignment or not self.true_alignment:
            return 0.0

        reward = 0
        current_progress = len(current_alignment[0])/self.max_len if self.max_len > 0 else 0
        self.progressive_weight = min(1.0, 0.3 + current_progress*0.7)  # Progressive weighting

        # 1. Column-wise alignment scoring (weighted by progress)
        min_cols = min(len(current_alignment[0]), len(self.true_alignment[0]))
        perfect_cols = 0
        
        for col in range(min_cols):
            current_col = [seq[col] for seq in current_alignment]
            true_col = [seq[col] for seq in self.true_alignment]
            
            # Perfect column match (high reward)
            if current_col == true_col:
                perfect_cols += 1
                reward += 8.0 * self.n_seqs * self.progressive_weight
            else:
                # Partial matches
                matches = sum(1 for c,t in zip(current_col, true_col) if c == t and c != '-')
                reward += 2.5 * matches * self.progressive_weight
                
                # Conservative mismatches penalty
                mismatches = sum(1 for c,t in zip(current_col, true_col) if c != t)
                reward -= 0.1 * mismatches

        # 2. Gap pattern scoring (less penalizing early on)
        gap_score = 0
        for seq_idx in range(self.n_seqs):
            aligned = current_alignment[seq_idx]
            true = self.true_alignment[seq_idx]
            
            # Compare gap patterns
            for c, t in zip(aligned, true):
                if c == '-' and t == '-':
                    gap_score += 0.7  # Reward correct gaps
                elif c == '-' or t == '-':
                    gap_score -= 0.3 * (1 - current_progress)  # Reduced early penalty
        
        reward += gap_score * self.n_seqs

        # 3. Completion and length normalization
        if done:
            completion_ratio = sum(1 for seq in self.remaining if len(seq) == 0)/self.n_seqs
            reward += 150.0 * completion_ratio * self.progressive_weight
            
        # Normalize by length and sequence count
        return reward / (min_cols * self.n_seqs) if min_cols > 0 else 0


### Alternative reward functions


These functions are taken from the following papers and have been modified to function as part of our base model implementation, with various scaling factors applied to balance with the original ground truth reward function. 

1. RLALIGN: Kinattinkara Ramakrishnan, R., Singh, J. & Blanchette, M. RLALIGN: A Reinforcement Learning Approach for Multiple Sequence Alignment. in 2018 IEEE 18th International Conference on Bioinformatics and Bioengineering (BIBE) 61–66 (2018). doi:10.1109/BIBE.2018.00019.
2. MSADRL: Joeres, R. Multiple Sequence Alignment using Deep Reinforcement Learning. in 101–112 (Gesellschaft für Informatik, Bonn, 2021).
3. DQNAlign: Song, Y.-J., Ji, D. J., Seo, H., Han, G.-B. & Cho, D.-H. Pairwise Heuristic Sequence Alignment Algorithm Based on Deep Reinforcement Learning. IEEE Open J Eng Med Biol 2, 36–43 (2021).
4. DNPMSA: Zhang, Y., Zhang, Q., Liu, Y., Lin, M. & Ding, C. Multiple Sequence Alignment based on deep Q network with negative feedback policy. Computational Biology and Chemistry 101, 107780 (2022).
5. EdgeAlign: Lall, A. & Tallur, S. Deep reinforcement learning-based pairwise DNA sequence alignment method compatible with embedded edge devices. Sci Rep 13, 2773 (2023).
6. DPAMSA: Liu, Y. et al. Multiple sequence alignment based on deep reinforcement learning with self-attention and positional encoding. Bioinformatics 39, btad636 (2023).
7. IntelliAlign: Κοτζιά, Ε. Solving multiple sequence alignment using deep reinforcement learning. Επίλυση ευθυγράμμισης πολλαπλών ακολουθιών χρησιμοποιώντας βαθιά ενισχυτική μάθηση (2024) doi:10.26265/polynoe-5855.

In [5]:

class RLALIGNProteinEnv(BaseProteinAlignmentEnv):
    def _calculate_reward(self, current_alignment, done):
        """Standardized RLALIGN reward"""
        if not current_alignment:
            return 0.0
            
        n_cols = len(current_alignment[0])
        n_pairs = (self.n_seqs * (self.n_seqs - 1)) // 2
        reward = 0.0
        
        # Calculate column scores
        for col in range(n_cols):
            column = [seq[col] for seq in current_alignment]
            for i in range(self.n_seqs):
                for j in range(i+1, self.n_seqs):
                    a, b = column[i], column[j]
                    if a == b and a != '-':
                        reward += 0.5  # Match
                    elif a != '-' and b != '-' and a != b:
                        reward -= 0.3  # Mismatch
                    elif a == '-' or b == '-':
                        reward -= 0.2  # Gap penalty
        
        # Normalize and scale
        if n_cols > 0 and n_pairs > 0:
            reward = (reward / (n_cols * n_pairs)) * 50
            
        # Standard completion bonus
        if done:
            reward += 100.0
            
        return float(reward)

class MSADRLProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None, score_type='SP'):
        super().__init__(sequences, true_alignment)
        # MSADRL parameters
        self.score_type = score_type  # 'SP' or 'C'
        self.match_reward = 2.0       # Reward for matching amino acids
        self.mismatch_penalty = -1.0  # Penalty for mismatched amino acids
        self.gap_penalty = -5.0       # Penalty for gaps (protein-specific)
        
    def _calculate_sp_score(self, current_alignment):
        """Calculate SP score according to MSADRL paper"""
        total_score = 0.0
        n_cols = len(current_alignment[0]) if current_alignment else 0
        
        for col_idx in range(n_cols):
            column = [seq[col_idx] for seq in current_alignment]
            
            for i in range(len(column)):
                for j in range(i+1, len(column)):
                    a, b = column[i], column[j]
                    
                    if a == '-' and b == '-':
                        continue  # No score for gap-gap
                    elif a == '-' or b == '-':
                        total_score += self.gap_penalty
                    elif a == b:
                        total_score += self.match_reward
                    else:
                        total_score += self.mismatch_penalty
        
        # Normalize by number of columns and pairs
        if n_cols > 0:
            total_score /= (n_cols * (self.n_seqs * (self.n_seqs - 1)) / 2)
        
        return total_score
    
    def _calculate_c_score(self, current_alignment):
        """Calculate column score (fraction of perfectly aligned columns)"""
        if not current_alignment:
            return 0.0
            
        n_cols = len(current_alignment[0])
        perfect_cols = 0
        
        for col in range(n_cols):
            column = [seq[col] for seq in current_alignment]
            first = column[0]
            if all(c == first for c in column) and first != '-':
                perfect_cols += 1
                
        return perfect_cols / n_cols
      
    def _calculate_reward(self, current_alignment, done):
        """Standardized MSADRL reward"""
        if not current_alignment:
            return 0.0
            
        if self.score_type == 'SP':
            score = self._calculate_sp_score(current_alignment)
            reward = score * 50  # Adjusted scaling
        else:
            score = self._calculate_c_score(current_alignment)
            reward = score * 100  # C-score already 0-1
            
        # Standard completion bonus
        if done:
            reward += 100.0
            
        return float(reward)

class DQNAlignProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # DQNALIGN parameters
        self.match_reward = 1.0      # Reward for matching amino acids
        self.mismatch_penalty = -1.0 # Penalty for mismatched amino acids
        self.gap_penalty = -1.0      # Penalty for gaps (no distinction between open/extend)
        
    def _calculate_column_score(self, column):
        """Calculate score for a single column"""
        score = 0
        for i in range(len(column)):
            for j in range(i+1, len(column)):
                a, b = column[i], column[j]
                
                if a == '-' and b == '-':
                    continue  # No score for gap-gap
                elif a == '-' or b == '-':
                    score += self.gap_penalty
                elif a == b:
                    score += self.match_reward
                else:
                    score += self.mismatch_penalty
        return score
    
    
    def _calculate_reward(self, current_alignment, done):
        """Standardized DQNAlign reward"""
        if not current_alignment:
            return 0.0
            
        # Score only last column
        last_column = [seq[-1] for seq in current_alignment]
        column_score = self._calculate_column_score(last_column)
        
        # Normalize by max possible column score
        max_pairs = (self.n_seqs * (self.n_seqs - 1)) // 2
        normalized_score = column_score / (max_pairs * self.match_reward) if max_pairs > 0 else 0
        
        # Scale and add completion bonus
        reward = normalized_score * 50
        if done:
            reward += 100.0
            
        return float(reward)

class DNPMSAProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # DNPMSA-specific parameters
        self.MATCH_REWARD = 5.0       # Reward for matching residues
        self.MISMATCH_PENALTY = -3.0  # Penalty for mismatches
        self.GAP_PENALTY = -4.0       # Penalty for gaps
        self._max_reward = self._calculate_max_reward()
        
    def _calculate_max_reward(self):
        """Calculate maximum possible reward for normalization"""
        max_cols = self.max_len * 2  # Worst case: all gaps
        max_pairs = (self.n_seqs * (self.n_seqs - 1)) // 2
        return max_cols * max_pairs * self.MATCH_REWARD
      
    
    def _calculate_reward(self, current_alignment, done):
        """Standardized DNPMSA reward"""
        if not current_alignment:
            return 0.0
            
        # Calculate raw score
        reward = 0.0
        n_cols = len(current_alignment[0])
        for col in range(n_cols):
            for i in range(self.n_seqs):
                for j in range(i+1, self.n_seqs):
                    a, b = current_alignment[i][col], current_alignment[j][col]
                    if a == '-' or b == '-':
                        reward += self.GAP_PENALTY
                    elif a == b:
                        reward += self.MATCH_REWARD
                    else:
                        reward += self.MISMATCH_PENALTY
        
        # Normalize and scale
        normalized = reward / max(1, self._max_reward)
        reward = normalized * 100
        
        # Standard completion bonus
        if done:
            reward += 100.0
            
        return float(reward)

class EdgeAlignProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # EDGEALIGN parameters (normalized values from paper)
        self.match_reward = 1.0         # Reward for matching amino acids
        self.mismatch_penalty = -0.6    # Penalty for mismatched amino acids
        self.gap_open_penalty = -1.0    # Penalty for opening a new gap
        self.gap_extend_penalty = -0.4  # Penalty for extending an existing gap
        
    def _calculate_column_score(self, column):
        """Calculate score for a single column with gap tracking"""
        score = 0
        for i in range(len(column)):
            for j in range(i+1, len(column)):
                a, b = column[i], column[j]
                
                if a == '-' and b == '-':
                    continue  # No score for gap-gap
                elif a == '-' or b == '-':
                    # Check if this is a gap opening or extension
                    if (a == '-' and not self.gap_status[i]) or (b == '-' and not self.gap_status[j]):
                        score += self.gap_open_penalty
                    else:
                        score += self.gap_extend_penalty
                elif a == b:
                    score += self.match_reward
                else:
                    score += self.mismatch_penalty
        return score
    def _calculate_reward(self, current_alignment, done):
        """Standardized EdgeAlign reward"""
        if not current_alignment:
            return 0.0
            
        # Score only last column
        last_column = [seq[-1] for seq in current_alignment]
        column_score = self._calculate_column_score(last_column)
        
        # Normalize by max possible
        max_pairs = (self.n_seqs * (self.n_seqs - 1)) // 2
        max_score = max_pairs * self.match_reward
        normalized = column_score / max_score if max_score > 0 else 0
        
        # Scale and add completion bonus
        reward = normalized * 50
        if done:
            reward += 100.0
            
        return float(reward)

class DPAMSAProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # DPAMSA parameters
        self.match_reward = 2        # Reward for matching amino acids
        self.mismatch_penalty = -1   # Penalty for mismatched amino acids
        self.gap_penalty = -2        # Penalty for gaps
        self.max_sp_per_column = self._calculate_max_sp_per_column()
        
    def _calculate_max_sp_per_column(self):
        """Calculate maximum possible SP score for one column (all matches)"""
        n = self.n_seqs
        return (n * (n - 1)) // 2 * self.match_reward
        
    def _calculate_column_score(self, column):
        """Calculate SP score for a single column"""
        score = 0
        for i in range(len(column)):
            for j in range(i+1, len(column)):
                a, b = column[i], column[j]
                if a == '-' or b == '-':
                    score += self.gap_penalty
                elif a == b:
                    score += self.match_reward
                else:
                    score += self.mismatch_penalty
        return score
        
    def _calculate_reward(self, current_alignment, done):
        """Standardized DPAMSA reward"""
        if not current_alignment:
            return 0.0
            
        # Score only last column
        last_column = [seq[-1] for seq in current_alignment]
        column_score = self._calculate_column_score(last_column)
        
        # Normalize and scale
        normalized = column_score / self.max_sp_per_column if self.max_sp_per_column > 0 else 0
        reward = normalized * 50
        
        # Standard completion bonus
        if done:
            reward += 100.0
            
        return float(reward)

class IntelliAlignProteinEnv(BaseProteinAlignmentEnv):
    def __init__(self, sequences, true_alignment=None):
        super().__init__(sequences, true_alignment)
        # IntelliAlign parameters
        self.reward_values = {
            'match': 2,       # LL: Letter with same letter
            'mismatch': -2,   # LDL: Letter with different letter
            'gap_letter': -1, # GL: Gap with letter
            'gap_gap': 0      # GG: Gap with gap
        }
        self.gap_costs = {
            'open': -10,     # GOC: Gap opening cost
            'extend': -1      # GEC: Gap extension cost
        }
        self.weights = {
            'sp': 0.6,       # Weight for Sum-of-Pairs
            'tc': 0.4        # Weight for Total Column
        }

    def _calculate_sp_score(self, alignment):
        """Calculate Sum-of-Pairs score with IntelliAlign parameters"""
        total = 0.0
        n_cols = len(alignment[0]) if alignment else 0
        n_seqs = len(alignment)
        
        for col in range(n_cols):
            column = [seq[col] for seq in alignment]
            
            for i in range(n_seqs):
                for j in range(i+1, n_seqs):
                    a, b = column[i], column[j]
                    
                    if a == '-' and b == '-':
                        total += self.reward_values['gap_gap']
                    elif a == '-' or b == '-':
                        total += self.reward_values['gap_letter']
                    elif a == b:
                        total += self.reward_values['match']
                    else:
                        total += self.reward_values['mismatch']
        
        # Normalize by number of columns and pairs
        if n_cols > 0:
            total /= (n_cols * (n_seqs * (n_seqs - 1)) / 2)
        
        return total

    def _calculate_tc_score(self, alignment):
        """Calculate Total Column score (fully conserved columns)"""
        if not alignment:
            return 0.0
            
        n_cols = len(alignment[0])
        tc = 0.0
        
        for col in range(n_cols):
            column = [seq[col] for seq in alignment]
            first = column[0]
            if all(c == first for c in column) and first != '-':
                tc += 1.0
        
        # Normalize by number of columns
        return tc / n_cols

    def _calculate_affine_gap_penalty(self, alignment):
        """Calculate affine gap penalty for the entire alignment"""
        if not alignment:
            return 0.0
            
        total_penalty = 0.0
        n_seqs = len(alignment)
        
        for seq in alignment:
            in_gap = False
            for char in seq:
                if char == '-':
                    if not in_gap:
                        # Gap opening
                        total_penalty += self.gap_costs['open']
                        in_gap = True
                    else:
                        # Gap extension
                        total_penalty += self.gap_costs['extend']
                else:
                    in_gap = False
        
        # Normalize by number of sequences and columns
        n_cols = len(alignment[0])
        return total_penalty / (n_seqs * n_cols)

    def _calculate_reward(self, current_alignment, done):
        """Standardized IntelliAlign reward"""
        if not current_alignment:
            return 0.0
            
        # Calculate components
        sp_score = self._calculate_sp_score(current_alignment)  # Already normalized
        tc_score = self._calculate_tc_score(current_alignment)  # 0-1
        gap_penalty = self._calculate_affine_gap_penalty(current_alignment)
        
        # Weighted combination (SP and TC already comparable)
        reward = (0.6 * sp_score + 0.4 * tc_score + gap_penalty) * 50
        
        # Standard completion bonus
        if done:
            reward += 100.0
            
        return float(reward)

### Results

Production-grade alignment systems often require weeks of training on GPU clusters. This is not a system for production-grade alignments, this is a piece of research investigating how different reward structures affect alignments. Specifically, wether or not SP scores should be the de facto scoring funciton used within the reward structure. Interesting things to explore here include how creating more diverse datasets affects the model.

In [8]:
# Easy Curriculum Parameters
Config.CURRICULUM_LEVELS = 7
Config.MIN_SEQ_LENGTH = 8
Config.MAX_SEQ_LENGTH = 20
Config.MIN_SEQUENCES = 2
Config.MAX_SEQUENCES = 4
Config.NUM_OF_SEQUENCES = 10  # Replicates per level

Config.INDEL_MODEL_PARAMS = {
        'pow_shape': 1.7,    # Power law shape parameter
        'pow_scale': 100,    # Power law scale parameter
        'min_indel_rate': 0.03,  # Starting indel rate
        'max_indel_rate': 0.03   # Maximum indel rate
    }

easy_curriculum = generate_synthetic_curriculum()

Generated level 1 with 10 alignment blocks
Generated level 2 with 10 alignment blocks
Generated level 3 with 10 alignment blocks
Generated level 4 with 10 alignment blocks
Generated level 5 with 10 alignment blocks
Generated level 6 with 10 alignment blocks
Generated level 7 with 10 alignment blocks


In [9]:
all_results_on_easy_curriculum = {}
def save_results(results, filename):
            """Save training results (without agent objects)"""
            serializable_results = {}
            for name, data in results.items():
                serializable_results[name] = {
                    'level': data['level'],
                    'scoring_method': data['scoring_method'],
                    'training_rewards': data['training_rewards'],
                    'losses': data['losses'],
                    'eval_results': data['eval_results'],
                }
            
            with open(filename, 'w') as f:
                json.dump(serializable_results, f)

for scoring_method in ['blosum', 'true', 'rlalign', 'msadrl', 'edgealign', 'dpamsa', 'intellialign']:
    try:
        agent, results = train_agent(easy_curriculum, scoring_method)

        all_results_on_easy_curriculum[scoring_method] = {
            'agent': agent,
            'results': results
        }
        
        save_results(results, f"{scoring_method}_model.json")
        
        if agent is not None:
            torch.save(agent.policy_net.state_dict(), f"{scoring_method}_model.pth")
        
        
    except Exception as e:
        print(f"Error processing {scoring_method}: {str(e)}")
        continue


=== Starting Level 1-0 (BLOSUM scoring) ===
Episode 0, Avg Reward: -522.47, Avg Loss: 0.0000, Epsilon: 0.846
Episode 50, Avg Reward: -334.80, Avg Loss: 19.5044, Epsilon: 0.687
Episode 100, Avg Reward: 23.42, Avg Loss: 15.7560, Epsilon: 0.558
Episode 150, Avg Reward: 99.32, Avg Loss: 12.5231, Epsilon: 0.453
Episode 200, Avg Reward: 111.59, Avg Loss: 11.9338, Epsilon: 0.367
Episode 250, Avg Reward: 113.40, Avg Loss: 11.3861, Epsilon: 0.298
Episode 300, Avg Reward: 117.76, Avg Loss: 10.9151, Epsilon: 0.242
Episode 350, Avg Reward: 116.37, Avg Loss: 10.7731, Epsilon: 0.196
Early stopping at episode 358 due to no improvement

Example produced alignments:
  ['QPNLSFHFG-', 'GP-QLSFHFG', 'QPQMSFSFG-']
  ['QPNLSFHFG', 'GPQLSFHFG', 'QPQMSFSFG']
  ['QPNLSFHFG-', '-GPQLSFHFG', 'QPQMSFSFG-']

Evaluation for Level 1-0:
  Average Reward: 122.72 ± 4.02
  Average SP Score: 2.43 ± 2.09
  Perfect Match Rate: 0.65
  Column Accuracy: 0.66
  Avg Alignment Length: 9.3

=== Starting Level 1-1 (BLOSUM scoring

### Plotting

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.decomposition import PCA
import pandas as pd

# Create plots directory
os.makedirs("plots", exist_ok=True)

def save_selected_plots(all_results):
    """Save the 5 selected plots for the paper"""
    
    # 1. Key Metrics Comparison
    plt.figure(figsize=(12, 8))
    metrics_to_show = ['avg_reward', 'avg_sp_score', 'perfect_match_rate']
    methods = list(all_results.keys())
    colors = plt.cm.viridis(np.linspace(0, 1, len(methods)))
    
    x = np.arange(len(metrics_to_show))
    width = 0.8 / len(methods)
    
    for i, method in enumerate(methods):
        avg_values = []
        for metric in metrics_to_show:
            values = [data['eval_results'][metric] 
                     for data in all_results[method]['results'].values()]
            avg_values.append(np.mean(values))
        
        plt.bar(x + i*width, avg_values, width, color=colors[i], label=method)
    
    plt.title("Key Metrics Comparison Across Scoring Methods")
    plt.xticks(x + width*(len(methods)-1)/2, 
              ['Avg Reward', 'Avg SP Score', 'Perfect Match %'])
    plt.ylabel("Score")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("plots/key_metrics_comparison.png")
    plt.close()
    
    # 2. Reward vs SP Score Correlation
    plt.figure(figsize=(10, 8))
    for i, method in enumerate(methods):
        rewards = []
        sp_scores = []
        for data in all_results[method]['results'].values():
            rewards.append(data['eval_results']['avg_reward'])
            sp_scores.append(data['eval_results']['avg_sp_score'])
        
        plt.scatter(rewards, sp_scores, color=colors[i], s=100, alpha=0.7, label=method)
        z = np.polyfit(rewards, sp_scores, 1)
        p = np.poly1d(z)
        plt.plot(rewards, p(rewards), color=colors[i], linestyle='--', alpha=0.5)
    
    plt.title("Reward vs SP Score Correlation")
    plt.xlabel("Average Reward")
    plt.ylabel("Average SP Score")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("plots/reward_sp_correlation.png")
    plt.close()
    
    # 3. Perfect Match Rate by Level
    plt.figure(figsize=(12, 8))
    for i, method in enumerate(methods):
        levels = sorted(set(data['level'] for data in all_results[method]['results'].values()))
        perfect_rates = []
        for level in levels:
            level_data = [data['eval_results']['perfect_match_rate'] 
                         for data in all_results[method]['results'].values() 
                         if data['level'] == level]
            perfect_rates.append(np.mean(level_data))
        
        plt.plot(levels, perfect_rates, 'o-', color=colors[i], label=method, linewidth=2.5)
    
    plt.title("Perfect Match Rate by Curriculum Level")
    plt.xlabel("Curriculum Level")
    plt.ylabel("Perfect Match Rate")
    plt.xticks(range(1, 8))
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("plots/perfect_match_by_level.png")
    plt.close()
    
    # 4. SP Score vs True Alignment Match Rate
    plt.figure(figsize=(10, 8))
    true_rates = {data['level']: data['eval_results']['perfect_match_rate'] 
                 for data in all_results['true']['results'].values()}
    
    methods_rq2 = [m for m in methods if m != 'true']
    for i, method in enumerate(methods_rq2):
        sp_scores = []
        match_rates = []
        for data in all_results[method]['results'].values():
            level = data['level']
            sp_scores.append(data['eval_results']['avg_sp_score'])
            match_rates.append(true_rates.get(level, 0))
        
        corr = np.corrcoef(sp_scores, match_rates)[0,1]
        sc = plt.scatter(sp_scores, match_rates, color=colors[i], 
                        label=f"{method} (r={corr:.2f})", alpha=0.7)
        z = np.polyfit(sp_scores, match_rates, 1)
        p = np.poly1d(z)
        plt.plot(sp_scores, p(sp_scores), color=colors[i], linestyle='--')
    
    plt.title("SP Score vs True Alignment Match Rate")
    plt.xlabel("SP Score")
    plt.ylabel("True Alignment Match Rate")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("plots/sp_vs_true_match.png")
    plt.close()
    
    # 5. Alignment Quality Divergence Heatmap
    plt.figure(figsize=(12, 8))
    true_alignments = {data['level']: data['eval_results']['perfect_match_rate'] 
                      for data in all_results['true']['results'].values()}
    
    divergence = []
    for method in methods:
        method_rates = []
        for level in range(1, 8):
            level_data = [data['eval_results']['perfect_match_rate'] 
                         for data in all_results[method]['results'].values() 
                         if data['level'] == level]
            method_rates.append(np.mean(level_data) if level_data else 0)
        divergence.append([abs(m - true_alignments.get(level, 0)) for level, m in zip(range(1,8), method_rates)])
    
    sns.heatmap(divergence, cmap="YlOrRd", 
                xticklabels=range(1,8), yticklabels=methods,
                annot=True, fmt=".2f")
    plt.title("Alignment Quality Divergence from True Alignment")
    plt.xlabel("Curriculum Level")
    plt.ylabel("Scoring Method")
    plt.tight_layout()
    plt.savefig("plots/divergence_heatmap.png")
    plt.close()

    # 6. SP Score Distribution by Method
    plt.figure(figsize=(12, 8))    
    sp_data = []
    for method in methods:
        for data in all_results[method]['results'].values():
            sp_data.append({
                'Method': method,
                'SP Score': data['eval_results']['avg_sp_score']
            })
    
    df = pd.DataFrame(sp_data)
    sns.violinplot(data=df, x='Method', y='SP Score', palette=colors)
    plt.title("SP Score Distribution by Method")
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.savefig("plots/sp_distribution.png")
    plt.close()

# Generate and save plots
save_selected_plots(all_results_on_easy_curriculum)


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(data=df, x='Method', y='SP Score', palette=colors)
  sns.violinplot(data=df, x='Method', y='SP Score', palette=colors)


In [12]:
def plot_divergence_heatmap(all_results):
    plt.figure(figsize=(12, 8))
    
    # Get true alignment rates by level
    true_rates = {
        data['level']: data['eval_results']['perfect_match_rate'] 
        for data in all_results['true']['results'].values()
    }
    
    # Calculate divergence for each method
    methods = [m for m in all_results.keys() if m != 'true']  # Exclude 'true' itself
    levels = sorted(set(data['level'] for data in all_results['true']['results'].values()))
    
    divergence = []
    for method in methods:
        method_rates = []
        for level in levels:
            level_data = [
                data['eval_results']['perfect_match_rate'] 
                for data in all_results[method]['results'].values() 
                if data['level'] == level
            ]
            avg_rate = np.mean(level_data) if level_data else 0
            true_rate = true_rates.get(level, 0)
            method_rates.append(abs(avg_rate - true_rate))
        divergence.append(method_rates)
    
    # Plot heatmap
    sns.heatmap(
        divergence, 
        cmap="YlOrRd", 
        xticklabels=levels, 
        yticklabels=methods,
        annot=True, 
        fmt=".2f",
        linewidths=0.5
    )
    
    plt.title("Divergence from True Alignment by Level", pad=20)
    plt.xlabel("Curriculum Level")
    plt.ylabel("Alignment Method")
    plt.tight_layout()
    plt.savefig("plots/divergence_heatmap.png")
    plt.close()

plot_divergence_heatmap(all_results_on_easy_curriculum)

In [11]:
def plot_state_sp_correlation(all_results):
    """Plot correlation between state features and SP scores"""
    plt.figure(figsize=(15, 6))
    
    # Get feature importance from evaluation
    methods = list(all_results.keys())
    feature_importances = []
    
    for method in methods:
        # Collect all feature importance arrays for this method
        imp_arrays = []
        for data in all_results[method]['results'].values():
            if 'state_feature_importance' in data['eval_results']:
                imp = data['eval_results']['state_feature_importance']
                if imp is not None and len(imp) > 0:
                    # Convert to numpy array and replace NaN with 0
                    arr = np.nan_to_num(np.array(imp), nan=0.0)
                    imp_arrays.append(arr)
        
        if imp_arrays:
            # Find the maximum length across all arrays
            max_len = max(len(arr) for arr in imp_arrays)
            # Pad shorter arrays with zeros
            padded_arrays = []
            for arr in imp_arrays:
                if len(arr) < max_len:
                    padded = np.zeros(max_len)
                    padded[:len(arr)] = arr
                    padded_arrays.append(padded)
                else:
                    padded_arrays.append(arr)
            
            # Calculate average importance (ignoring zeros from padding)
            avg_imp = np.mean(padded_arrays, axis=0)
            feature_importances.append((method, avg_imp))
    
    # Plot if we have data
    if feature_importances:
        # Get the first method's feature importance to determine groups
        _, first_imp = feature_importances[0]
        n_features = len(first_imp)
        
        # Define feature groups based on your state representation
        # These numbers should match your state vector construction
        aa_features = len(Config.AMINO_ACIDS)  # Number of amino acids
        lookahead_features = 3 * aa_features   # 3 lookahead positions
        progress_features = 2 * 7              # Assuming 7 sequences (2 features per seq)
        gap_features = 4                       # 4 gap context features
        
        # Calculate group boundaries
        group_boundaries = [
            ('AA Composition', 0, aa_features),
            ('Lookahead 1', aa_features, aa_features),
            ('Lookahead 2', 2*aa_features, aa_features),
            ('Lookahead 3', 3*aa_features, aa_features),
            ('Progress', 4*aa_features, progress_features),
            ('Gap Context', 4*aa_features + progress_features, gap_features)
        ]
        
        # Verify we're not exceeding available features
        total_defined = sum(g[2] for g in group_boundaries)
        if total_defined < n_features:
            group_boundaries.append(('Other', total_defined, n_features - total_defined))
        
        # Prepare data for plotting
        plot_data = []
        for method, imp in feature_importances:
            for group, start, length in group_boundaries:
                end = start + length
                if end > len(imp):
                    end = len(imp)
                if start < end:
                    group_mean = np.mean(imp[start:end])
                    plot_data.append({'Method': method, 'Feature Group': group, 'Mean Correlation': group_mean})
        
        # Convert to DataFrame for easier plotting
        df = pd.DataFrame(plot_data)
        
        # Create the plot
        plt.figure(figsize=(14, 6))
        sns.barplot(data=df, x='Feature Group', y='Mean Correlation', hue='Method', 
                   palette='viridis', errorbar=None)
        
        plt.title("State Feature Correlation with SP Scores by Group")
        plt.ylabel("Absolute Correlation Coefficient")
        plt.legend(title="Method")
        plt.grid(True, alpha=0.3)
        
        # Rotate x-axis labels if needed
        if len(group_boundaries) > 5:
            plt.xticks(rotation=45, ha='right')
        
        plt.tight_layout()
        plt.savefig("plots/state_sp_correlation.png")
        plt.close()
    else:
        print("Warning: No state feature importance data found in results")

plot_state_sp_correlation(all_results_on_easy_curriculum)

<Figure size 1500x600 with 0 Axes>

In [13]:
def plot_normalized_metrics_radar(all_results):
    """Plot normalized key metrics using a radar chart"""
    # Metrics to compare (use same as before)
    metrics = ['avg_reward', 'avg_sp_score', 'perfect_match_rate']
    metric_labels = ['Avg Reward', 'Avg SP Score', 'Perfect Match %']
    methods = list(all_results.keys())
    
    # Normalize each metric to [0,1] range across methods
    normalized_data = {}
    for metric in metrics:
        # Get all values for this metric across methods
        values = []
        for method in methods:
            method_values = [data['eval_results'][metric] 
                           for data in all_results[method]['results'].values()]
            values.append(np.mean(method_values))
        
        # Normalize (min-max scaling)
        min_val, max_val = min(values), max(values)
        for i, method in enumerate(methods):
            if max_val - min_val > 0:
                norm_val = (values[i] - min_val) / (max_val - min_val)
            else:
                norm_val = 0.5  # if all values are equal
            normalized_data.setdefault(method, []).append(norm_val)
    
    # Prepare radar plot
    angles = np.linspace(0, 2*np.pi, len(metrics), endpoint=False).tolist()
    angles += angles[:1]  # Close the loop
    
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'polar': True})
    
    # Plot each method
    colors = plt.cm.viridis(np.linspace(0, 1, len(methods)))
    for idx, method in enumerate(methods):
        values = normalized_data[method]
        values += values[:1]  # Close the loop
        ax.plot(angles, values, color=colors[idx], linewidth=2, label=method)
        ax.fill(angles, values, color=colors[idx], alpha=0.25)
    
    # Customize plot
    ax.set_theta_offset(np.pi/2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), metric_labels)
    
    # Set yticks to show original scale reference
    ax.set_rlabel_position(180)
    plt.yticks([0, 0.5, 1], ["Worst", "Medium", "Best"], color="grey", size=8)
    plt.ylim(0, 1)
    
    plt.title("Normalized Performance Comparison", pad=20)
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    plt.savefig("plots/normalized_metrics_radar.png", dpi=300, bbox_inches='tight')
    plt.close()

plot_normalized_metrics_radar(all_results_on_easy_curriculum)

In [14]:
def plot_metrics_small_multiples(all_results):
    """Plot each metric in its own subplot"""
    metrics = {
        'avg_reward': 'Avg Reward',
        'avg_sp_score': 'Avg SP Score', 
        'perfect_match_rate': 'Perfect Match %'
    }
    methods = list(all_results.keys())
    colors = plt.cm.viridis(np.linspace(0, 1, len(methods)))
    
    fig, axes = plt.subplots(1, len(metrics), figsize=(15, 5))
    
    for ax, (metric, label) in zip(axes, metrics.items()):
        # Calculate means for each method
        method_means = []
        for method in methods:
            values = [data['eval_results'][metric] 
                     for data in all_results[method]['results'].values()]
            method_means.append(np.mean(values))
        
        # Plot with separate y-scales
        ax.bar(methods, method_means, color=colors)
        ax.set_title(label)
        ax.tick_params(axis='x', rotation=45)
        
        # Add value labels
        for i, v in enumerate(method_means):
            ax.text(i, v, f"{v:.2f}", ha='center', va='bottom')
    
    plt.suptitle("Key Metrics Comparison", y=1.05)
    plt.tight_layout()
    plt.savefig("plots/metrics_small_multiples.png", dpi=300)
    plt.close()

plot_metrics_small_multiples(all_results_on_easy_curriculum)