# Load modules

In [1]:
# Necessary imports
from skbio.embedding import ProteinEmbedding
from skbio.sequence import Protein
from tqdm import tqdm
import skbio

import torch

from deepblast.utils import load_model
from skbio.alignment import PairAlignPath
from deepblast.dataset.utils import get_sequence, pack_sequences, revstate_f
import matplotlib.pyplot as plt

def condense_cigar(cigar_str):
    """
    Convert a full-length CIGAR string to its condensed (run-length) form.

    Example: MMMIIII = 3M4I

    Parameters
    ----------
    cigar_str : str
        One operation character per alignment column, e.g. ``"MMMIIII"``.

    Returns
    -------
    str
        Each run of identical characters written as ``<count><char>``.
        An empty input yields an empty string.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import groupby

    # groupby yields one (char, run) pair per maximal run of equal characters,
    # so an empty input produces no pairs and therefore an empty result.
    return ''.join(
        f"{sum(1 for _ in run)}{state}" for state, run in groupby(cigar_str)
    )


def tm_to_cigar(tm_alignment_string, condensed=False):
    """
    Convert TMalign style alignment string to CIGAR string

    Parameters
    ----------
    tm_alignment_string : str
        Alignment states, one character per column. ``':'`` is written as
        ``'M'``, ``'1'`` as ``'I'`` and ``'2'`` as ``'D'``. Any other
        character is skipped.
    condensed : bool, optional
        If True, return the run-length form (e.g. ``"2M1I"``); otherwise
        return one operation character per column (e.g. ``"MMI"``).

    Returns
    -------
    str
        The CIGAR string, empty when no recognised state is present.
    """
    state_to_op = {':': 'M', '1': 'I', '2': 'D'}
    ops = [state_to_op[state] for state in tm_alignment_string
           if state in state_to_op]

    if not condensed:
        return ''.join(ops)

    # Function-scope import keeps this block self-contained.
    from itertools import groupby

    # One "<count><op>" token per maximal run of identical operations.
    return ''.join(
        f"{sum(1 for _ in run)}{op}" for op, run in groupby(ops)
    )

def align(x, y, model):
    """
    Align two sequences with ``model`` and return the result as an alignment path.

    ``x`` and ``y`` are passed to ``model.align`` as strings. The returned
    TMalign style state string is converted to a full-length CIGAR,
    condensed, and parsed with ``PairAlignPath.from_cigar``.
    """
    tm_states = model.align(str(x), str(y))
    condensed_cigar = condense_cigar(tm_to_cigar(tm_states))
    return PairAlignPath.from_cigar(condensed_cigar)


def predict_aln_matrix(query_seq, target_seq, model):
    """
    Return the alignment matrix the model's aligner predicts for one pair.

    Both sequences are stringified, encoded with ``model.tokenizer``, moved
    to ``model.device`` and packed as a single-pair batch. The second
    element of the first item produced by ``model.aligner.traceback`` is
    returned after ``squeeze()``.
    """
    def _encode(sequence):
        # get_sequence returns an indexable result; only element 0 is used.
        return get_sequence(str(sequence), model.tokenizer)[0].to(model.device)

    query_code = _encode(query_seq)
    target_code = _encode(target_seq)
    packed, order = pack_sequences([query_code], [target_code])

    with torch.no_grad():
        traceback_iter = model.aligner.traceback(packed, order)
    # NOTE(review): if traceback is a lazy generator, its body runs on the
    # next() call below, i.e. outside the no_grad block - confirm intended.
    _, aln_mat = next(traceback_iter)

    return aln_mat.squeeze()

# Align sequences

In [3]:
# Presumably Hugging Face ids of the ProtT5 model and its tokenizer.
# NOTE(review): neither name is referenced by the other visible cells -
# confirm they are still needed.
model_name = "Rostlab/prot_t5_xl_uniref50"
tokenizer_name = "Rostlab/prot_t5_xl_uniref50"
# Protein is also imported from skbio.sequence in the first cell.
from skbio import Protein

# Parse bagel.fa
# Each FASTA record is constructed as a Protein; the reader is consumed
# one record at a time with next().
sequence_list = skbio.io.read("bagel.fa", format='fasta', constructor=Protein)

In [5]:
# Pull the next two records from the FASTA reader to use as the pair to align.
x = next(sequence_list)
y = next(sequence_list)

0it [00:00, ?it/s]You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
1it [00:06,  6.67s/it]

In [None]:
# smith-waterman errored out because of the string/path length mismatch

# Load the checkpoint on CPU. The path is an absolute location on an NFS
# share; adjust it for your environment.
model = load_model("/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/vbezshapkin/tm-vec/models/deepblast-v3.ckpt", device="cpu",
                   alignment_mode="smith-waterman"
                   )

# Align x against y and keep the resulting alignment path.
path = align(x, y, model)

## Visualize Predicted Alignment Matrix

In [None]:
# Predicted alignment matrix for the (x, y) pair.
matrix = predict_aln_matrix(x, y, model)

# Visualise the matrix as a heat map with a colour bar
plt.imshow(matrix, cmap='viridis')
plt.colorbar()
plt.show();