In [81]:
# Import required libraries
from Bio import AlignIO, Phylo
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor, _DistanceMatrix
import matplotlib.pyplot as plt
from Bio import SeqIO
from Bio.Align import MultipleSeqAlignment
import numpy as np
import math
from enum import Enum
import os
from ete3 import Tree, TreeStyle, NodeStyle, faces, AttrFace

class SequenceType(Enum):
    """Kind of residues contained in an alignment.

    The value passed to ``CustomCalculator`` decides which pairwise distance
    formula is applied: ``NUCLEOTIDE`` selects the K2P+Gap distance, any
    other member selects the JC+Gap distance.
    """
    # DNA-style alignments (A/T/G/C plus '-' gaps).
    NUCLEOTIDE = "nucleotide"
    # Protein alignments (amino-acid letters plus '-' gaps).
    AMINO_ACID = "amino_acid"

    
class CustomDistanceCalculator:
    """Gap-aware pairwise distance formulas for aligned sequences.

    Both formulas are exposed as static methods, so no instance is needed.
    Sequences must already be aligned (equal length); ``-`` marks a gap.
    Comparison is case-insensitive.
    """

    _GAP = "-"
    _NUCLEOTIDES = frozenset("ATGC")
    # Ordered (seq1, seq2) pairs that count as transitions; every other
    # mismatch between two unambiguous nucleotides is a transversion.
    _TRANSITIONS = frozenset({("T", "C"), ("C", "T"), ("A", "G"), ("G", "A")})

    @staticmethod
    def calculate_k2pgap_distance(seq1, seq2):
        """
        Calculate K2P+Gap distance

        The distance is computed from four per-site proportions (all divided
        by the alignment length ``n``, except ``w``):

        * ``S`` - identical unambiguous nucleotides (A/T/G/C)
        * ``P`` - transitions (T<->C, A<->G)
        * ``Q`` - transversions (any other A/T/G/C mismatch)
        * ``w`` - non-gap characters in both sequences divided by ``2 * n``

        Sites where either character is not A/T/G/C contribute to none of
        ``S``, ``P`` or ``Q``; any non-gap character still counts towards ``w``.

        Parameters:
        seq1, seq2 (str): Two sequences to compare

        Returns:
        float: K2P+Gap distance (never negative). Returns None if error
        occurs (different lengths, empty input, or a distance that is
        mathematically undefined for the pair); the reason is printed.
        """
        try:
            if len(seq1) != len(seq2):
                raise ValueError("Sequences have different lengths")

            # Convert sequences to uppercase
            seq1 = seq1.upper()
            seq2 = seq2.upper()

            n = len(seq1)  # Total number of sites
            if n == 0:
                raise ValueError("Sequences are empty")

            gap = CustomDistanceCalculator._GAP
            nucleotides = CustomDistanceCalculator._NUCLEOTIDES
            transition_pairs = CustomDistanceCalculator._TRANSITIONS

            # Classify every site in a single pass over the alignment.
            same = transitions = transversions = residues = 0
            for a, b in zip(seq1, seq2):
                # 2 if both are residues, 1 if exactly one is, 0 if both gaps.
                residues += (a != gap) + (b != gap)
                if a not in nucleotides or b not in nucleotides:
                    continue
                if a == b:
                    same += 1
                elif (a, b) in transition_pairs:
                    transitions += 1
                else:
                    transversions += 1

            S = same / n
            P = transitions / n
            Q = transversions / n
            w = residues / (2 * n)  # Base proportion in two sequences

            # The logarithms below need strictly positive arguments; fail with
            # a meaningful message instead of a bare "math domain error".
            if w <= 0 or S - P <= 0 or S + P - Q <= 0:
                raise ValueError(
                    "distance is undefined for this pair "
                    "(only gaps, or sequences too divergent)"
                )

            # Calculate K2P+Gap distance
            K = (3/4) * w * math.log(w) - (w/2) * math.log((S - P) * math.sqrt(S + P - Q))

            # Prevent negative distance; 0.0 keeps the return type a float.
            return max(0.0, K)

        except Exception as e:
            # Deliberately broad: the documented contract is "None on any error".
            print(f"Error occurred in K2P+Gap distance calculation: {str(e)}")
            return None

    @staticmethod
    def calculate_jcgap_distance(seq1, seq2):
        """
        Calculate JC+Gap distance

        Uses three proportions: ``S`` (identical non-gap residues / ``n``),
        ``P`` (mismatching residue pairs with no gap on either side / ``n``)
        and ``w`` (non-gap characters in both sequences / ``2 * n``).

        Parameters:
        seq1, seq2 (str): Two amino acid sequences to compare

        Returns:
        float: JC+Gap distance (never negative). Returns None if error
        occurs (different lengths, empty input, or a distance that is
        mathematically undefined for the pair); the reason is printed.
        """
        try:
            if len(seq1) != len(seq2):
                raise ValueError("Sequences have different lengths")

            # Convert sequences to uppercase
            seq1 = seq1.upper()
            seq2 = seq2.upper()

            n = len(seq1)  # Total number of sites
            if n == 0:
                raise ValueError("Sequences are empty")

            gap = CustomDistanceCalculator._GAP

            # Count identical/different amino acids and handle gaps in one pass.
            same = different = residues = 0
            for a, b in zip(seq1, seq2):
                residues += (a != gap) + (b != gap)
                if a == gap or b == gap:
                    continue
                if a == b:
                    same += 1
                else:
                    different += 1

            S = same / n
            P = different / n
            w = residues / (2 * n)  # Amino acid proportion in two sequences

            # The logarithm below needs a strictly positive argument.
            if w <= 0 or S - P / 19 <= 0:
                raise ValueError(
                    "distance is undefined for this pair "
                    "(only gaps, or sequences too divergent)"
                )

            # Calculate JC+Gap distance
            K = -19/20 * w * math.log((S - P/19) / w)

            # Prevent negative distance; 0.0 keeps the return type a float.
            return max(0.0, K)

        except Exception as e:
            # Deliberately broad: the documented contract is "None on any error".
            print(f"Error occurred in JC+Gap distance calculation: {str(e)}")
            return None

class CustomCalculator(DistanceCalculator):
    """Custom calculator class inheriting from BioPython's DistanceCalculator

    Replaces the stock pairwise scoring with the gap-aware formulas of
    ``CustomDistanceCalculator`` so it can be handed to
    ``DistanceTreeConstructor``.
    """

    def __init__(self, seq_type):
        """
        Parameters:
        seq_type (SequenceType): NUCLEOTIDE selects K2P+Gap; anything else
            selects JC+Gap
        """
        # Initialise the base class so inherited attributes exist.
        super().__init__()
        self.seq_type = seq_type

    def get_distance(self, msa):
        """
        Calculate distance matrix from alignment

        Parameters:
        msa: alignment whose records expose ``.id`` and ``.seq``

        Returns:
        _DistanceMatrix: lower-triangular matrix (diagonal included, 0.0)

        Raises:
        ValueError: if the distance for any pair cannot be calculated
        """
        # Pick the formula once instead of re-testing seq_type for every pair.
        if self.seq_type == SequenceType.NUCLEOTIDE:
            pairwise_distance = CustomDistanceCalculator.calculate_k2pgap_distance
        else:  # AMINO_ACID
            pairwise_distance = CustomDistanceCalculator.calculate_jcgap_distance

        names = [record.id for record in msa]
        # Convert each record to str once rather than inside the O(n^2) loop.
        sequences = [str(record.seq) for record in msa]

        matrix = []
        for i, seq_i in enumerate(sequences):
            row = []
            for j in range(i):
                distance = pairwise_distance(seq_i, sequences[j])
                if distance is None:
                    # The pairwise functions signal failure with None; stop
                    # here with the offending pair instead of letting the
                    # matrix constructor fail later with an unrelated message.
                    raise ValueError(
                        f"Distance between '{names[i]}' and '{names[j]}' "
                        "could not be calculated"
                    )
                row.append(distance)
            row.append(0.0)  # distance of a sequence to itself
            matrix.append(row)

        return _DistanceMatrix(names, matrix)

def create_phylogenetic_tree(alignment_file, seq_type, format="fasta", output_tree=None):
    """
    Build a neighbor-joining tree from an alignment file

    Side effects: writes the full (symmetric) distance matrix to
    ``<alignment basename>_matrix.csv`` and the tree, in Newick format, to
    ``output_tree``.

    Parameters:
    alignment_file (str): Path of the aligned sequence file
    seq_type (SequenceType): Selects the distance formula
    format (str): Alignment format name passed to AlignIO.read
    output_tree (str): Newick output path (optional); defaults to
        ``<alignment basename>_tree.nwk``

    Returns:
    tuple: (tree, output_tree). Returns (None, None) if any error occurs;
    the error message is printed.
    """
    # Local import keeps this function self-contained.
    import csv

    try:
        if not isinstance(seq_type, SequenceType):
            raise ValueError("seq_type must be either SequenceType.NUCLEOTIDE or SequenceType.AMINO_ACID")

        alignment = AlignIO.read(alignment_file, format)
        calculator = CustomCalculator(seq_type)

        # Calculate the distance matrix once; it feeds both the CSV and the tree.
        distance_matrix = calculator.get_distance(alignment)
        base_name = os.path.splitext(alignment_file)[0]
        matrix_file = f"{base_name}_matrix.csv"

        # Save distance matrix as CSV. csv.writer quotes ids that contain
        # commas or quotes; lineterminator='\n' keeps plain newline rows.
        names = distance_matrix.names
        with open(matrix_file, 'w') as f:
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow([''] + list(names))
            for i, name in enumerate(names):
                # Indexing the matrix with an int yields the full row, i.e.
                # the stored lower-triangle part plus the mirrored upper part.
                writer.writerow([name] + list(distance_matrix[i]))

        # Run neighbor joining directly on the matrix computed above;
        # build_tree(alignment) would recompute every pairwise distance.
        constructor = DistanceTreeConstructor(calculator, method='nj')
        tree = constructor.nj(distance_matrix)

        if output_tree is None:
            output_tree = f"{base_name}_tree.nwk"

        # Blank the internal node names so they are not written to the file.
        for clade in tree.get_nonterminals():
            clade.name = ''
        Phylo.write(tree, output_tree, "newick")

        return tree, output_tree

    except Exception as e:
        # Deliberately broad: callers rely on the (None, None) failure result.
        print(f"An error occurred: {str(e)}")
        return None, None

def visualize_tree(tree_file, output_image=None, show_length=True, show_support=True):
    """
    Draw a Newick tree with ETE3, optionally saving it as an image

    Parameters:
    tree_file (str): Tree file in Newick format
    output_image (str): Output image filename (optional)
    show_length (bool): Whether to display branch lengths
    show_support (bool): Whether to display support values

    Any exception raised while loading or drawing is caught and printed.
    """
    try:
        # Parse the Newick input first so a bad file fails before any styling.
        ete_tree = Tree(tree_file, format=1, quoted_node_names=True)

        # Global layout options, applied in a fixed order.
        layout = TreeStyle()
        layout_options = {
            "show_branch_length": show_length,
            "show_branch_support": show_support,
            "scale": 120,  # Tree zoom level
            "branch_vertical_margin": 15,  # Vertical margin for branches
            "show_leaf_name": True,
        }
        for option, value in layout_options.items():
            setattr(layout, option, value)

        # One style object shared by every node.
        node_style = NodeStyle()
        node_style["size"] = 10
        node_style["fgcolor"] = "black"

        for node in ete_tree.traverse():
            node.set_style(node_style)

            # Label only when requested and the branch length is non-zero.
            if not (show_length and node.dist):
                continue
            length_label = faces.TextFace(f" {node.dist:.3f}")
            length_label.margin_right = 5
            node.add_face(length_label, column=0, position="branch-top")

        # Write the image file when a filename was supplied ...
        if output_image:
            ete_tree.render(output_image, tree_style=layout)

        # ... and always open the interactive viewer.
        ete_tree.show(tree_style=layout)

    except Exception as e:
        print(f"Error occurred during visualization: {str(e)}")

# Execution

In [83]:
alignment_file = "sample_data/sample_NUC.fasta"  # Specify the aligned sequence file

# Create phylogenetic tree and distance matrix.
# The sample file name indicates nucleotide data, so the nucleotide
# (K2P+Gap) distance is selected; on failure both results are None.
tree, tree_file = create_phylogenetic_tree(
    alignment_file,
    seq_type=SequenceType.NUCLEOTIDE,  # Change to AMINO_ACID for amino acid sequences
)

In [79]:
# Visualization with ETE3
# create_phylogenetic_tree returns None for the tree file when it fails,
# so only visualize when a tree file was actually produced.
if tree_file is not None:
    visualize_tree(
        tree_file,
        output_image="tree_visualization.png",  # Save as image
        show_length=True,  # Display branch lengths
        show_support=True,  # Display support values
    )
else:
    print("Skipping visualization: no tree file was produced")