<a href="https://colab.research.google.com/github/sierra-hunt/github-and-kaggle-ML-work/blob/main/RL_for_DNA_Sequence_Alignment_and_Phylogenetic_Tree_Construction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Phylogenetic tree building whilst aligning multiple sequences through Reinforcement Learning

This research prototype implements a reinforcement learning framework for simultaneously optimizing DNA sequence alignment and phylogenetic tree construction. The system uses DNABERT-2 transformer embeddings to represent genetic sequences, then applies Monte Carlo Tree Search (MCTS) to explore alignment modifications and tree adjustments. The approach considers multiple RL strategies (Q-Learning, Policy Gradient, PPO) for this dual-optimization problem, balancing alignment quality with phylogenetic consistency. The implementation includes parallel simulation, Newick tree parsing, and the skeleton of an RL environment for bioinformatics optimization.

**This wasn't for my university at all, and is just something I was working on in my spare time as a possible extension of the work I did for my Masters. This project demonstrates a *designed but not fully implemented* reinforcement learning system for phylogenetic optimization.**


Rough look into the reasoning for my approach here:

Q-Learning / DQN
- Discrete action space, simple alignment modifications
- Efficient, works well with small to medium-sized datasets
- May struggle with large sequence states; requires dimensionality reduction
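
To make the "discrete action space" point concrete, here is a minimal sketch of a tabular Q-learning update. The action names, the string state keys, and the `lr`/`gamma` values are illustrative assumptions, not part of this prototype.

```python
import random
from collections import defaultdict

# Hypothetical discrete actions for editing an alignment (illustrative only)
ACTIONS = ["insert_gap", "shift_left", "shift_right"]

def q_update(q_table, state, action, reward, next_state, lr=0.1, gamma=0.9):
    """Apply one tabular Q-learning update and return the new Q-value."""
    best_next = max(q_table[next_state][a] for a in ACTIONS)
    td_target = reward + gamma * best_next
    q_table[state][action] += lr * (td_target - q_table[state][action])
    return q_table[state][action]

def choose_action(q_table, state, epsilon=0.1, rng=random):
    """Epsilon-greedy choice over the discrete action set."""
    if rng.random() < epsilon:
        return rng.choice(ACTIONS)
    return max(ACTIONS, key=lambda a: q_table[state][a])

# States are plain hashable keys here; a real alignment state would need
# some compact encoding (the dimensionality-reduction concern noted above)
q_table = defaultdict(lambda: defaultdict(float))
new_q = q_update(q_table, state="s0", action="insert_gap", reward=1.0, next_state="s1")
greedy = choose_action(q_table, "s0", epsilon=0.0)
```

The table grows with the number of distinct states, which is why large sequence states are a problem for this family of methods.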

Policy Gradient Methods
- Continuous or complex action space (e.g., fine-tuning alignments)
- Can model complex actions, better for interdependent decisions
- Sample inefficient, requires fine-tuning for stability

Proximal Policy Optimization (PPO)
- Complex, continuous, interdependent actions, stability-focused
- Stable, sample-efficient, robust learning
- Computationally expensive, requires careful tuning

MCTS
- Represents the search explicitly as a tree of states connected by actions (alignment changes or tree modifications)
- Does not require a learned policy or value function, only a way to score rollouts, making it easier to adapt to unknown or complex problem domains.
- Stable as it relies on incremental backpropagation of rewards
- The effectiveness of MCTS depends heavily on the quality of the simulations (playouts).
- Computationally expensive
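
The selection step that balances exploitation and exploration can be sketched with the standard UCB1 rule. This is a simplified illustration with made-up visit counts; the `best_child` method in the code below uses a smoothed variant (`+ 1e-6`, `log(visits + 1)`) of the same idea.

```python
import math

def ucb1(total_value, visits, parent_visits, exploration_weight=1.0):
    """UCB1 score: mean reward plus an exploration bonus."""
    if visits == 0:
        return math.inf  # Unvisited children are always tried first
    mean_reward = total_value / visits
    bonus = exploration_weight * math.sqrt(math.log(parent_visits) / visits)
    return mean_reward + bonus

# (total_value, visits) for three hypothetical children of a node visited 20 times
children = {"a": (6.0, 10), "b": (4.5, 9), "c": (0.9, 1)}
scores = {name: ucb1(v, n, parent_visits=20) for name, (v, n) in children.items()}
best = max(scores, key=scores.get)
```

Child `c` wins here despite having only one visit, because its exploration bonus is still large. As visits accumulate the bonus shrinks and the mean reward dominates.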

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig
import json
import os
import csv  # Used to read the simulations CSV file

# Load the DNABERT-2 tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True, config=config)

# Define file paths
csv_file = r"/content/drive/MyDrive/2024-25 RL/Prototype/simulations_data_unaligned_MSA_and_trees.csv"
embedding_dir = r"/content/drive/MyDrive/2024-25 RL/Prototype/embeddings"
os.makedirs(embedding_dir, exist_ok=True)

def tokenize_and_embed(sequence):
    """Tokenizes and embeds a DNA sequence with both mean and max pooling."""
    # DNABERT-2 uses byte-pair encoding (BPE) rather than the fixed k-mers of the
    # original DNABERT, so the raw sequence is passed to the tokenizer unchanged
    encoded_input = tokenizer(sequence, return_tensors="pt", truncation=True)

    # Get model output and compute the embeddings
    with torch.no_grad():
        outputs = model(**encoded_input)

        # If the model returns a tuple, the first element contains the hidden states
        hidden_states = outputs[0]  # Shape: [1, sequence_length, 768]

        # Mean pooling: Average hidden states across the sequence length dimension
        embedding_mean = torch.mean(hidden_states[0], dim=0).squeeze().numpy()

        # Max pooling: Take the max hidden state across the sequence length dimension
        embedding_max = torch.max(hidden_states[0], dim=0)[0].squeeze().numpy()

    # Convert NumPy arrays to lists for JSON serialization
    return embedding_mean.tolist(), embedding_max.tolist()

# Open and read the CSV file
with open(csv_file, 'r', encoding='utf-8') as file:
    reader = csv.DictReader(file)
    for row in reader:
        sim_id = row["Simulation"]
        sim_data = json.loads(row["Data (JSON)"])
        unaligned_sequences = sim_data["unaligned_sequences"]

        embeddings_mean = {}
        embeddings_max = {}

        # Process each sequence and compute embeddings
        for seq_id, seq in unaligned_sequences.items():
            embedding_mean, embedding_max = tokenize_and_embed(seq)
            embeddings_mean[seq_id] = embedding_mean
            embeddings_max[seq_id] = embedding_max

        # Save embeddings to JSON files
        output_mean_file = os.path.join(embedding_dir, f"{sim_id}_mean_embeddings.json")
        output_max_file = os.path.join(embedding_dir, f"{sim_id}_max_embeddings.json")

        # Save embeddings for mean pooling
        with open(output_mean_file, 'w') as emb_file_mean:
            json.dump(embeddings_mean, emb_file_mean)

        # Save embeddings for max pooling
        with open(output_max_file, 'w') as emb_file_max:
            json.dump(embeddings_max, emb_file_max)

print(f"Embeddings saved in: {embedding_dir}")


Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
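
The two pooling strategies used in `tokenize_and_embed` can be illustrated without loading the model. The sketch below uses a made-up 3-token, 4-dimensional "hidden state" in place of the real `[tokens, 768]` tensor; `mean_pool` and `max_pool` are illustrative helper names, not functions from the notebook.

```python
def mean_pool(token_vectors):
    """Average each hidden dimension over the token axis."""
    n_tokens = len(token_vectors)
    return [sum(dim_values) / n_tokens for dim_values in zip(*token_vectors)]

def max_pool(token_vectors):
    """Take the maximum of each hidden dimension over the token axis."""
    return [max(dim_values) for dim_values in zip(*token_vectors)]

# Toy hidden states: 3 tokens x 4 dimensions (the real model gives 768 dimensions)
hidden = [
    [1.0, 0.0, -2.0, 4.0],
    [3.0, 2.0, 0.0, 4.0],
    [2.0, -2.0, 5.0, 1.0],
]

mean_vec = mean_pool(hidden)
max_vec = max_pool(hidden)
```

Both results are fixed-length vectors regardless of how many tokens the sequence produced, which is what allows sequences of different lengths to be compared. Note that the notebook code pools over every position returned by the model, so any special tokens the tokenizer adds are included in the average.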


In [None]:
# Representing the Phylogenetic Tree as an Adjacency Matrix

import numpy as np
import json

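def parse_newick_tree(newick):
    """Parses a Newick string into a weighted adjacency matrix and a name-to-index map."""
    from ete3 import Tree
    t = Tree(newick, format=1)
    nodes = list(t.traverse())

    # Unnamed internal nodes share a default name and would collide as dict keys,
    # so every missing or duplicate name is replaced with a unique one
    node_map = {}
    for idx, node in enumerate(nodes):
        if not node.name or node.name in node_map:
            node.name = f"node_{idx}"
        node_map[node.name] = idx

    adjacency_matrix = np.zeros((len(nodes), len(nodes)))

    # Store each branch length symmetrically (undirected parent-child edge)
    for node in nodes:
        for child in node.children:
            adjacency_matrix[node_map[node.name], node_map[child.name]] = child.dist
            adjacency_matrix[node_map[child.name], node_map[node.name]] = child.dist

    return adjacency_matrix, node_map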
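
One thing a weighted adjacency matrix like this makes possible is computing path-length (patristic) distances between leaves, which could later feed a tree-versus-alignment consistency check. The sketch below is an assumption about how that might be done, not part of the prototype: it uses a hand-built matrix and plain lists rather than the output of `parse_newick_tree`, and `patristic_distances` is a hypothetical helper name.

```python
import math

def patristic_distances(adj):
    """All-pairs path lengths on a weighted adjacency matrix (Floyd-Warshall)."""
    n = len(adj)
    # Zero entries mean "no edge", matching the convention of the matrix above
    dist = [
        [0.0 if i == j else (adj[i][j] if adj[i][j] > 0 else math.inf) for j in range(n)]
        for i in range(n)
    ]
    for k in range(n):
        for i in range(n):
            for j in range(n):
                via_k = dist[i][k] + dist[k][j]
                if via_k < dist[i][j]:
                    dist[i][j] = via_k
    return dist

# Hand-built matrix for ((A:0.1,B:0.2):0.3,C:0.4); with an assumed node order
# 0=root, 1=internal, 2=A, 3=B, 4=C
edges = [(0, 1, 0.3), (1, 2, 0.1), (1, 3, 0.2), (0, 4, 0.4)]
adj = [[0.0] * 5 for _ in range(5)]
for u, v, w in edges:
    adj[u][v] = adj[v][u] = w

dist = patristic_distances(adj)
leaf_ab = dist[2][3]  # A to B: 0.1 + 0.2
leaf_ac = dist[2][4]  # A to C: 0.1 + 0.3 + 0.4
```

Because a zero entry is read as "no edge", a genuine zero-length branch would be lost in this representation. Floyd-Warshall is cubic in the number of nodes, which is acceptable for small trees only.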



In [None]:
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import json


class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.action = action  # Store the action that led to this node

    def is_fully_expanded(self):
        """Check if the node has children (fully expanded)."""
        return len(self.children) > 0

    def best_child(self, exploration_weight=1.0):
        """Select the best child based on value and exploration."""
        scores = [
            (child.value / (child.visits + 1e-6)) +
            exploration_weight * np.sqrt(np.log(self.visits + 1) / (child.visits + 1e-6))
            for child in self.children
        ]
        return self.children[np.argmax(scores)]


class MCTSState:
    def __init__(self, embeddings, tree_adj_matrix, tree_node_mapping, alignment_score=0.0, tree_score=0.0):
        self.embeddings = embeddings
        self.tree_adj_matrix = tree_adj_matrix
        self.tree_node_mapping = tree_node_mapping
        self.alignment_score = alignment_score
        self.tree_score = tree_score

    def get_combined_score(self, alpha=0.5):
        """Combines alignment and tree scores."""
        return alpha * self.alignment_score + (1 - alpha) * self.tree_score


def initialize_state(simulation_id, embedding_dir, newick_tree):
    """Initialize MCTS state from embedding and tree data."""
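    # Embeddings are saved as "<id>_mean_embeddings.json" and "<id>_max_embeddings.json";
    # the mean-pooled file is loaded here
    embedding_file = f"{embedding_dir}/{simulation_id}_mean_embeddings.json"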
    with open(embedding_file, 'r') as file:
        embeddings = json.load(file)

    tree_adj_matrix, tree_node_mapping = parse_newick_tree(newick_tree)

    alignment_score = compute_alignment_score(embeddings.values())  # Replace with real scoring function
    tree_score = compute_tree_score(tree_adj_matrix)       # Replace with real scoring function

    return MCTSState(
        embeddings=embeddings,
        tree_adj_matrix=tree_adj_matrix,
        tree_node_mapping=tree_node_mapping,
        alignment_score=alignment_score,
        tree_score=tree_score
    )


def simulate(state):
    """Placeholder for a domain-specific simulation."""
    alignment_score = np.random.rand()  # Replace with real scoring function
    tree_score = np.random.rand()      # Replace with real scoring function
    return 0.5 * alignment_score + 0.5 * tree_score


def expand(node):
    """Expand a node by generating children based on alignment and tree actions."""
    state = node.state
    new_states = []

    # Generate alignment actions
    alignment_actions = generate_alignment_actions(state.embeddings.values())

    # For each alignment action, generate corresponding tree actions
    for align_action in alignment_actions:
        updated_sequences = apply_alignment_action(state.embeddings.values(), align_action)

        # Generate tree actions based on the updated alignment
        tree_actions = generate_tree_actions(state.tree_adj_matrix, updated_sequences)

        for tree_action in tree_actions:
            compound_action = {"alignment": align_action, "tree": tree_action}
            new_state = apply_compound_action(state, compound_action)
            new_states.append(MCTSNode(state=new_state, parent=node, action=compound_action))

    return new_states


def apply_compound_action(state, action):
    """Apply both alignment and tree actions to generate a new state."""
    # Apply alignment action
    alignment_action = action["alignment"]
    new_sequences = apply_alignment_action(state.embeddings.values(), alignment_action)

    # Update tree based on new alignment
    tree_action = action["tree"]
    new_tree_adj_matrix = modify_tree_based_on_action(state.tree_adj_matrix, tree_action, new_sequences)

    return MCTSState(
        embeddings=compute_embeddings(new_sequences),
        tree_adj_matrix=new_tree_adj_matrix,
        tree_node_mapping=state.tree_node_mapping,
        alignment_score=compute_alignment_score(new_sequences),
        tree_score=compute_tree_score(new_tree_adj_matrix)
    )


def insert_gap(sequence, position):
    """Inserts a gap at a specific position in the sequence."""
    return sequence[:position] + "-" + sequence[position:]


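def shift_sequence(sequence, shift):
    """Shifts the sequence left or right, padding with gaps and keeping its length.

    Characters pushed past either end are dropped.
    """
    # Clamp so that an oversized shift cannot lengthen the sequence
    shift = max(-len(sequence), min(len(sequence), shift))
    if shift > 0:
        return "-" * shift + sequence[:-shift]
    elif shift < 0:
        return sequence[-shift:] + "-" * abs(shift)
    return sequence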


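def refine_alignment(sequences, **params):
    """Refines the alignment by minimizing mismatches (placeholder logic)."""
    # Implement a refinement strategy like progressive alignment or a scoring heuristic;
    # **params absorbs the keyword arguments that apply_action forwards from action["params"]
    return list(sequences)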


def compute_reward(state, alpha=0.5):
    """Computes a reward for the state considering alignment and tree scores."""
    alignment_score = compute_alignment_score(state.embeddings.values())
    tree_score = compute_tree_score(state.tree_adj_matrix)
    inconsistency_penalty = compute_inconsistency_penalty(state.embeddings, state.tree_adj_matrix)
    return alpha * alignment_score + (1 - alpha) * tree_score - inconsistency_penalty


def compute_alignment_score(sequences):
    """Computes a score for the alignment."""
    sequences = list(sequences)  # Accept any iterable, e.g. dict.values()
    if not sequences:
        return 0.0
    score = 0
    for i, seq1 in enumerate(sequences):
        for seq2 in sequences[i + 1:]:
            # Columns where both sequences have a gap are not counted as matches
            match_score = sum(1 for a, b in zip(seq1, seq2) if a == b and a != "-")
            gap_penalty = sum(1 for a, b in zip(seq1, seq2) if a == "-" or b == "-")
            score += match_score - 0.5 * gap_penalty  # Example weights
    return score / len(sequences)


def compute_tree_score(adj_matrix):
    """Computes a score for the tree structure."""
    density = np.sum(adj_matrix > 0) / adj_matrix.size  # Fraction of non-zero connections
    return 1 - density  # Prefer sparse trees


def backpropagate(node, reward):
    """Propagates the reward up the tree."""
    while node is not None:
        node.value += reward
        node.visits += 1
        node = node.parent


def apply_action(state, action):
    """Apply a specific action to modify the state."""
    if action["type"] == "align":
        new_sequences = refine_alignment(state.embeddings.values(), **action["params"])
        new_tree_adj_matrix = update_tree_from_alignment(new_sequences)
    elif action["type"] == "tree":
        new_tree_adj_matrix = modify_tree(state.tree_adj_matrix, **action["params"])
        new_sequences = refine_alignment_based_on_tree(state.embeddings.values(), new_tree_adj_matrix)
    else:
        raise ValueError("Unsupported action type")

    return MCTSState(
        embeddings=compute_embeddings(new_sequences),  # Recompute embeddings
        tree_adj_matrix=new_tree_adj_matrix,
        tree_node_mapping=state.tree_node_mapping,
        alignment_score=compute_alignment_score(new_sequences),
        tree_score=compute_tree_score(new_tree_adj_matrix)
    )


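def parallel_mcts(root, simulations=100, parallel_workers=4):
    """Runs MCTS with parallel rollouts and returns the most-visited child of the root."""
    with ThreadPoolExecutor(max_workers=parallel_workers) as executor:
        for _ in range(simulations):
            node = root
            # Selection: descend by UCB score until a node without children is reached
            while node.is_fully_expanded():
                node = node.best_child()

            # Expansion: expand() only returns the new nodes, so attach them here
            node.children = expand(node)
            if not node.children:
                # No actions available: score the leaf itself and move on
                backpropagate(node, simulate(node.state))
                continue

            # Parallel rollouts, one per new child
            rewards = list(executor.map(simulate, [child.state for child in node.children]))

            # Backpropagation
            for child, reward in zip(node.children, rewards):
                backpropagate(child, reward)

    return max(root.children, key=lambda child: child.visits) if root.children else None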
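
`expand` calls `generate_alignment_actions` and `apply_alignment_action`, which are not defined in this notebook. The following is one possible sketch of what they might look like, under the assumption that they operate on raw sequence strings rather than on the embedding vectors the current code passes in. The action dictionary layout (`"seq"`, `"pos"`) and the `max_actions` parameter are invented for illustration.

```python
def insert_gap(sequence, position):
    """Inserts a gap at a specific position in the sequence (as defined above)."""
    return sequence[:position] + "-" + sequence[position:]

def generate_alignment_actions(sequences, max_actions=None):
    """Enumerate every single-gap insertion as a {"seq", "pos"} dict (hypothetical)."""
    actions = []
    for seq_idx, seq in enumerate(sequences):
        for pos in range(len(seq) + 1):
            actions.append({"seq": seq_idx, "pos": pos})
    # Optionally cap the branching factor of the search
    return actions[:max_actions] if max_actions is not None else actions

def apply_alignment_action(sequences, action):
    """Insert one gap, then right-pad with gaps so all rows have equal length."""
    updated = list(sequences)
    updated[action["seq"]] = insert_gap(updated[action["seq"]], action["pos"])
    width = max(len(s) for s in updated)
    return [s.ljust(width, "-") for s in updated]

seqs = ["ACGT", "AGT"]
actions = generate_alignment_actions(seqs)
aligned = apply_alignment_action(seqs, {"seq": 1, "pos": 1})
```

Even this simple action set grows with the total sequence length, and `expand` multiplies it by the number of tree actions, so some cap or sampling scheme would likely be needed before running the search on real data.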
