# Load modules

In [50]:
# Necessary imports
import matplotlib.pyplot as plt
import numpy as np
import skbio
import torch
from deepblast.dataset.utils import get_sequence, pack_sequences, revstate_f
from deepblast.utils import load_model
from skbio.alignment import PairAlignPath
from skbio.embedding import ProteinEmbedding
from skbio.sequence import Protein
from tqdm import tqdm

def condense_cigar(cigar_str):
    """
    Convert full length to condensed (run-length encoded) CIGAR
    Example: MMMIIII = 3M4I

    Parameters
    ----------
    cigar_str : str
        Expanded CIGAR string with one state character per position.

    Returns
    -------
    str
        Each maximal run of identical characters written as the run length
        followed by the character. An empty input yields an empty string
        (rather than the malformed ``'0'``).
    """
    # Local import keeps this helper self-contained within the cell.
    from itertools import groupby

    return ''.join(
        f'{sum(1 for _ in run)}{state}'
        for state, run in groupby(cigar_str)
    )


def tm_to_cigar(tm_alignment_string, condensed=False):
    """
    Convert TMalign style alignment string to CIGAR string

    The mapping applied per character is ``':' -> 'M'``, ``'1' -> 'I'`` and
    ``'2' -> 'D'``.

    Parameters
    ----------
    tm_alignment_string : str
        TMalign style state string.
    condensed : bool, optional
        If True, the result is passed through ``condense_cigar`` so that
        runs are length-encoded (e.g. ``3M4I``). The default returns the
        expanded form with one character per state.

    Returns
    -------
    str
        The CIGAR string.
    """
    # NOTE(review): characters outside this table are dropped silently, so
    # the CIGAR can be shorter than the input string - confirm this is the
    # intended handling for any other symbols the aligner may emit.
    state_to_op = {':': 'M', '1': 'I', '2': 'D'}

    cigar = ''.join(
        state_to_op[state]
        for state in tm_alignment_string
        if state in state_to_op
    )

    if condensed:
        return condense_cigar(cigar)
    return cigar

def align(x, y, model):
    """
    Align two sequences with `model` and return the alignment as a path.

    Parameters
    ----------
    x, y : object
        The sequences to align; each is passed to the model as ``str(...)``.
    model : object
        Must provide ``align(str, str)`` returning a TMalign style state
        string (the format consumed by ``tm_to_cigar``).

    Returns
    -------
    PairAlignPath
        Path built from the condensed CIGAR of the predicted alignment.
    """
    tm_states = model.align(str(x), str(y))
    # TMalign states -> expanded CIGAR -> run-length CIGAR
    condensed_cigar = condense_cigar(tm_to_cigar(tm_states))
    return PairAlignPath.from_cigar(condensed_cigar)

# Align sequences

In [24]:
# Identifier strings for the model and its tokenizer (same value for both).
# NOTE(review): neither name is used in the cells visible here - confirm a
# later cell still needs them.
model_name = "Rostlab/prot_t5_xl_uniref50"
tokenizer_name = "Rostlab/prot_t5_xl_uniref50"
# NOTE(review): Protein is already imported from skbio.sequence in the
# imports cell, so this re-import looks redundant.
from skbio import Protein

# Parse bagel.fa
# Each FASTA record is built with the Protein constructor. The result is
# consumed with next() in the following cell, so it is presumably a lazy
# iterator rather than a list.
sequence_list = skbio.io.read("bagel.fa", format='fasta', constructor=Protein)

In [25]:
# Take the first two records from the FASTA reader; this is the pair that
# gets aligned further down.
x = next(sequence_list)
y = next(sequence_list)

In [51]:
# Smith-Waterman errored out because of the string/path length mismatch
# (presumably why alignment_mode is set to "needleman-wunsch" below).

# NOTE(review): the checkpoint is read from a hard-coded absolute path, so
# this cell only runs where that mount is available.
model = load_model("/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/vbezshapkin/tm-vec/models/deepblast-v3.ckpt", device="cpu",
                   alignment_mode="needleman-wunsch"
                   )

# Align the two sequences read from bagel.fa and keep the resulting path.
path = align(x, y, model)



In [52]:
# Display the alignment path (the trailing dot was an unfinished attribute
# access and is a syntax error on its own).
path

<PairAlignPath, shape: Shape(sequence=2, position=50), CIGAR: '3M1I1M3D1M1D1M3D1M1I5D1M4D2M1I10D1M1D4M2D3M'>

In [None]:
from scipy.sparse import coo_matrix

def state_diff_f(X):
    """ Constructs a state transition element.

    Parameters
    ----------
    X : tuple
        Pair ``(a, b)`` of consecutive alignment states: ``a`` is the state
        being left and ``b`` the state being entered.

    Returns
    -------
    tuple of int
        ``(dx, dy)``, the amount by which the X tape and the Y tape advance
        for this transition: ``(1, 0)`` when entering the X state,
        ``(1, 1)`` when entering the match state and ``(0, 1)`` when
        entering the Y state.

    Raises
    ------
    ValueError
        If either ``a`` or ``b`` is not one of the three known states.

    Notes
    -----
    There is a bit of a paradox regarding beginning / ending gaps.
    To see this, try to derive an alignment matrix for the
    following alignments
    XXXMMMXXX
    MMYYXXMM
    It turns out it isn't possible to derive traversal rules
    that are consistent between these two alignments
    without explicitly handling start / end states as separate
    end states. The current workaround is to force the start / end
    states to be match states (similar to the needleman-wunsch algorithm).
    """
    # The state constants are imported here under private aliases because
    # this notebook never defines a global `m` and rebinds the globals `x`
    # and `y` to the loaded sequences, so the bare names cannot be used.
    # NOTE(review): assumes deepblast.constants exposes x, m and y as the
    # state codes - confirm against the installed deepblast version.
    from deepblast.constants import x as x_state
    from deepblast.constants import m as m_state
    from deepblast.constants import y as y_state

    a, b = X

    # All nine (a, b) combinations of known states are allowed; only an
    # unknown source state has to be rejected up front.
    if not (a == x_state or a == m_state or a == y_state):
        raise ValueError(f'`Transition` ({a}, {b}) is not allowed.')

    # The step is determined solely by the state being entered.
    if b == x_state:
        # Entering X: increase tape on X
        return (1, 0)
    if b == m_state:
        # Entering M: increase tape on both X and Y
        return (1, 1)
    if b == y_state:
        # Entering Y: increase tape on Y
        return (0, 1)
    raise ValueError(f'`Transition` ({a}, {b}) is not allowed.')

def states2edges(states):
    """ Converts state string to bipartite matching.

    Parameters
    ----------
    states : list
        The state string; consecutive entries are fed pairwise to
        ``state_diff_f``.

    Returns
    -------
    list of tuple of int
        Coordinates visited by the alignment, starting at ``(0, 0)`` and
        followed by the running sum of the per-transition steps. The list
        therefore has one entry per state (or just ``[(0, 0)]`` when there
        are no transitions).
    """
    # A plain running sum is enough here; no array round-trip is needed.
    row, col = 0, 0
    coords = [(0, 0)]
    for transition in zip(states[:-1], states[1:]):
        d_row, d_col = state_diff_f(transition)
        row += d_row
        col += d_col
        coords.append((row, col))
    return coords


def states2matrix(states, sparse=False):
    """ Converts state string to alignment matrix.

    Parameters
    ----------
    states : list
       The state string
    sparse : bool, optional
       If True, return the scipy ``coo_matrix``; otherwise (default) return
       its dense array form.

    Returns
    -------
    scipy.sparse.coo_matrix or numpy.ndarray
       Matrix with a 1 at every coordinate produced by ``states2edges`` and
       0 elsewhere. Its shape is one more than the largest row / column
       coordinate.
    """
    coords = states2edges(states)
    rows, cols = (np.array(axis) for axis in zip(*coords))
    # Coordinates are zero-based, hence the +1 on each dimension.
    shape = (rows.max() + 1, cols.max() + 1)
    mat = coo_matrix((np.ones(len(coords)), (rows, cols)), shape=shape)
    return mat if sparse else mat.toarray()