In [136]:
import numpy as np
from hmmlearn import hmm
from enum import Enum   

In [137]:
class states(Enum):
    """Per-sequence state labels used while scanning alignment columns."""

    # Regular profile-HMM states.
    M = 0  # match
    D = 1  # delete
    I = 2  # insert

    # Bookkeeping label: a gap seen inside an insert region after the
    # sequence has already inserted at least one symbol there.
    I_gap = 3

In [138]:
class ProfileHMM(hmm.CategoricalHMM):
    """Categorical HMM whose layout is derived from a multiple sequence alignment.

    Every alignment column is classified as either a match column or an
    insert region. The model is sized with ``3 * n_match_states + 3``
    components and ``len(alphabet) + 2`` features, and emission/transition
    counts are gathered from the alignment.

    NOTE(review): ``transmat_`` and ``emissionprob_`` are only zero-initialised
    here; the gathered counts are printed and stored on the instance but are
    not yet converted into probabilities.
    """

    def __init__(self, alphabet, gap_symbol, alignment, insertion_criteria=None):
        """Build the model skeleton and collect counts from ``alignment``.

        Parameters
        ----------
        alphabet : iterable
            Emission symbols; each one is converted with ``str``.
        gap_symbol : str
            Symbol that marks a gap in the alignment.
        alignment : sequence of str
            Aligned sequences, all of the same length.
        insertion_criteria : callable, optional
            Called with each column (a string holding one symbol per
            sequence); a truthy result marks the column as an insert region.
            When it is not callable the default criterion is used.

        Raises
        ------
        ValueError
            If the alignment is empty or the sequences differ in length.
        """
        # Normalise the alphabet to a list of strings.
        alphabet = list(map(str, alphabet))
        self.alphabet = alphabet
        self.gap = gap_symbol

        # scikit-learn style estimators look constructor arguments up as
        # attributes with the same name (get_params / repr), so keep them.
        self.gap_symbol = gap_symbol
        self.alignment = alignment
        self.insertion_criteria = insertion_criteria

        # The gap and the silent emission are added as possible emissions.
        n_features = len(alphabet) + 2

        if len(alignment) == 0:
            raise ValueError("El alineamiento debe contener al menos una secuencia")

        # Every sequence must be as long as the first one.
        sequence_length = len(alignment[0])
        if not all(len(seq) == sequence_length for seq in alignment):
            raise ValueError("Todos las secuencias deben tener la misma longitud")

        # One string per alignment column, holding one symbol per sequence.
        regions = ["".join(column) for column in zip(*alignment)]

        # Fall back to the default criterion when none (or a non-callable) is given.
        if callable(insertion_criteria):
            criteria = insertion_criteria
        else:
            criteria = self.__default_insertion_criteria

        # Coerce to bool so that falsy results such as None still count as
        # match columns below.
        insert_regions = [bool(criteria(region)) for region in regions]

        # Match states correspond to the columns that are not insert regions.
        self.n_match_states = insert_regions.count(False)

        super().__init__(n_components=3*self.n_match_states+3, n_features=n_features)
        self.startprob_ = np.zeros(self.n_components)
        self.startprob_[0] = 1
        self.transmat_ = np.zeros((self.n_components, self.n_components))
        self.emissionprob_ = np.zeros((self.n_components, self.n_features))

        self.__compute_probabilities(alignment, regions, insert_regions)

    def __default_insertion_criteria(self, region):
        """Return True when ``region`` (one column) is an insert region.

        A column is an insert region when more than half of its entries are
        gaps, or when no alphabet symbol fills at least half of the column.
        """
        half = len(region) / 2
        if region.count(self.gap) > half:
            return True
        most_frequent = max(region.count(symbol) for symbol in self.alphabet)
        return most_frequent < half

    def __detect_state(self, element, prev_state, in_insert_region):
        """Return the state value a sequence takes at the current column.

        States are 0=M, 1=D, 2=I, 3=I_gap. ``I_gap`` marks a gap inside an
        insert region for a sequence that already inserted a symbol in that
        region, so those gaps are not accumulated by mistake.
        """
        is_gap = element == self.gap

        if not in_insert_region:
            # Match column: a gap is a deletion, anything else is a match.
            return states.D.value if is_gap else states.M.value

        if not is_gap:
            return states.I.value
        if prev_state in (states.I.value, states.I_gap.value):
            return states.I_gap.value
        # Gap in an insert region with no insertion so far: state is unchanged.
        return prev_state

    def __compute_probabilities(self, alignment, regions, insert_regions):
        """Count emissions and transitions over the alignment columns.

        Builds three integer arrays, stores them on the instance
        (``match_emission_counts_``, ``insert_emission_counts_``,
        ``transition_counts_``) and prints them:

        * match emissions, shape ``(n_match_states, len(alphabet))``;
        * insert emissions, shape ``(n_match_states + 1, len(alphabet))``;
        * transitions, shape ``(3, 3, n_match_states + 1)``, indexed as
          ``[from_state, to_state, position]`` with states M=0, D=1, I=2.
          Transitions into the end of the model are stored in the ``to M``
          slot of the last position.
        """
        M = states.M.value
        D = states.D.value
        I = states.I.value
        I_GAP = states.I_gap.value

        n_symbols = len(self.alphabet)
        match_emissions = np.zeros((self.n_match_states, n_symbols), dtype=int)
        insert_emissions = np.zeros((self.n_match_states+1, n_symbols), dtype=int)
        transition_frequencies = np.zeros((3, 3, self.n_match_states+1), dtype=int)

        n_sequence = len(alignment)
        region_index = 0

        # Before the first column every sequence is treated as coming from M
        # (value 0), which makes the first column a regular loop iteration.
        current_states = [M] * n_sequence
        previous_is_insert = False

        for region, is_insert in zip(regions, insert_regions):
            previous_states = current_states
            current_states = [
                self.__detect_state(symbol, previous, is_insert)
                for symbol, previous in zip(region, previous_states)
            ]

            # When leaving an insert region, I_gap counts as I.
            if previous_is_insert and not is_insert:
                previous_states = [I if state == I_GAP else state for state in previous_states]

            transitions = list(zip(previous_states, current_states))
            symbol_counts = [region.count(symbol) for symbol in self.alphabet]

            if is_insert:
                # Consecutive insert columns accumulate on the same position.
                insert_emissions[region_index] += symbol_counts

                for origin in range(3):
                    transition_frequencies[origin, I, region_index] += transitions.count((origin, I))

                if previous_is_insert:
                    # Transitions from I_gap to I also count as I -> I.
                    transition_frequencies[I, I, region_index] += transitions.count((I_GAP, I))
            else:
                match_emissions[region_index] = symbol_counts

                for origin in range(3):
                    transition_frequencies[origin, M, region_index] = transitions.count((origin, M))
                    transition_frequencies[origin, D, region_index] = transitions.count((origin, D))

                region_index += 1

            previous_is_insert = is_insert

        # Transitions into the end of the model. The states after the last
        # processed column are used, so a one-column alignment is covered.
        # If the alignment ends inside an insert region, I_gap counts as I,
        # and sequences that never inserted there still leave from M or D.
        final_states = current_states
        if previous_is_insert:
            final_states = [I if state == I_GAP else state for state in final_states]
        for origin in (M, D, I):
            transition_frequencies[origin, M, region_index] = final_states.count(origin)

        # Keep the raw counts available to callers instead of discarding them.
        self.match_emission_counts_ = match_emissions
        self.insert_emission_counts_ = insert_emissions
        self.transition_counts_ = transition_frequencies

        print(match_emissions)
        print(insert_emissions)
        print(transition_frequencies)

In [139]:
# Example 1: eight sequences of length four; the second column is mostly gaps.
model = ProfileHMM(
    alphabet=["A", "C", "G", "T"],
    gap_symbol='-',
    alignment=['GCAG', 'G--G', 'G-AG', 'GCTG', 'A-AC', 'G-AC', 'G-GG', 'A-AC'],
)

[[2 0 6 0]
 [5 0 1 1]
 [0 3 5 0]]
[[0 0 0 0]
 [0 2 0 0]
 [0 0 0 0]
 [0 0 0 0]]
[[[8 5 7 8]
  [0 1 0 0]
  [0 2 0 0]]

 [[0 0 1 0]
  [0 0 0 0]
  [0 0 0 0]]

 [[0 2 0 0]
  [0 0 0 0]
  [0 0 0 0]]]


In [140]:
# Example 2: four sequences of length six with a single mostly-gap column.
model = ProfileHMM(
    alphabet=["A", "C", "G", "T"],
    gap_symbol='-',
    alignment=['TA--TC', 'TAG-TC', 'TAGA-C', '-AG-TG'],
)

[[0 0 0 3]
 [4 0 0 0]
 [0 0 3 0]
 [0 0 0 3]
 [0 3 1 0]]
[[0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [1 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]
[[[3 3 3 2 3 4]
  [1 0 1 0 0 0]
  [0 0 0 1 0 0]]

 [[0 1 0 1 1 0]
  [0 0 0 0 0 0]
  [0 0 0 0 0 0]]

 [[0 0 0 0 0 0]
  [0 0 0 1 0 0]
  [0 0 0 0 0 0]]]


In [141]:
# Example 3: five sequences of length six with several consecutive gap-heavy columns.
model = ProfileHMM(
    alphabet=["A", "C", "G", "T"],
    gap_symbol='-',
    alignment=['AG---C', 'A-AG-C', 'AG-AA-', '--AAAC', 'AG---C'],
)

# Number of columns that were classified as match columns.
print(model.n_match_states)

[[4 0 0 0]
 [0 0 3 0]
 [0 4 0 0]]
[[0 0 0 0]
 [0 0 0 0]
 [6 0 1 0]
 [0 0 0 0]]
[[[4 3 2 4]
  [1 1 0 0]
  [0 0 1 0]]

 [[0 0 0 1]
  [0 1 0 0]
  [0 0 2 0]]

 [[0 0 2 0]
  [0 0 1 0]
  [0 0 4 0]]]
3
