# WFA Algorithm

In [3]:
"""
A toy implementation of the WFA algorithm.
Copyright Benedict Paten, Dec 2021
"""
import sys


class WF(object):
    """ Represents a "wavefront", a series of points along the x+y diagonal that represent "furthest points".

    For every anti-diagonal k = x - y in the inclusive range [min_diag, max_diag] the wavefront
    stores one furthest point, given as its x coordinate (the y coordinate is then x - k).
    """
    def __init__(self, min_diag, max_diag):
        assert min_diag <= max_diag
        self.min_diag, self.max_diag = min_diag, max_diag  # Min and max diag are the bounds (inclusive) on the diagonal
        self.fpa = [0] * (max_diag - min_diag + 1)  # fpa[k - min_diag] is the furthest point on antidiagonal k

    def get_fp(self, k):
        """ Returns the furthest point (an x coordinate) on the x - y = k antidiagonal.

        If k lies outside [min_diag, max_diag] the point is not on this wavefront and is treated
        as unreachable: a very small number (-sys.maxsize) is returned instead.
        """
        if k < self.min_diag or k > self.max_diag:
            return -sys.maxsize  # Unreachable: loses every max() comparison it takes part in
        return self.fpa[k - self.min_diag]

    def set_fp(self, k, h):
        """ Sets the furthest point on the x - y = k antidiagonal to the x coordinate h.

        k must lie within [min_diag, max_diag].
        """
        assert self.min_diag <= k <= self.max_diag  # Otherwise we're trying to set a point not on the wavefront
        self.fpa[k - self.min_diag] = h

    def trim_diagonal(self, new_min_diag, new_max_diag):
        """ Trim the wavefront (e.g. to remove outermost points) to the inclusive range
        [new_min_diag, new_max_diag], keeping the furthest points stored for those antidiagonals.
        """
        # We can only trim to make the range of antidiagonals smaller or the same size
        assert self.min_diag <= new_min_diag <= new_max_diag <= self.max_diag
        self.fpa = self.fpa[new_min_diag-self.min_diag:(new_max_diag+1)-self.min_diag]
        self.min_diag, self.max_diag = new_min_diag, new_max_diag


class WFS(object):
    """ A wavefront set: one wavefront (a WF) per alignment score, keyed by that score.
    """
    def __init__(self):
        # The score-0 wavefront covers only antidiagonal k = 0, with its furthest point at x = 0,
        # i.e. the (x=0, y=0) origin of the dp matrix.
        self.wfs = {0: WF(0, 0)}

    def get_wf(self, s):
        """ Return the wavefront for score s, or None when no wavefront has that score.
        """
        if s not in self.wfs:
            return None
        return self.wfs[s]

    def get_fp(self, s, k):
        """ Return the furthest point for score s on antidiagonal k = x - y.

        A score with no wavefront yields a very small number (-sys.maxsize), marking the point
        as unreachable.
        """
        try:
            wavefront = self.wfs[s]
        except KeyError:
            return -sys.maxsize
        return wavefront.get_fp(k)

    def set_fp(self, s, k, h):
        """ Set the furthest point for score s on antidiagonal k = x - y to x coordinate h.
        """
        self.wfs[s].set_fp(k, h)

    def add_wf(self, min_diag, max_diag, s):
        """ Create, store and return a new wavefront for score s spanning [min_diag, max_diag].
        """
        new_wf = WF(min_diag, max_diag)
        self.wfs[s] = new_wf
        return new_wf

    def get_min_diag(self, s):
        """ Return the smallest antidiagonal k = x - y of the score-s wavefront.

        Returns sys.maxsize when there is no such wavefront, so it never wins a min().
        """
        if s not in self.wfs:
            return sys.maxsize
        return self.wfs[s].min_diag

    def get_max_diag(self, s):
        """ Return the largest antidiagonal k = x - y of the score-s wavefront.

        Returns -sys.maxsize when there is no such wavefront, so it never wins a max().
        """
        if s not in self.wfs:
            return -sys.maxsize
        return self.wfs[s].max_diag


class WFA(object):
    """ Optimal global alignment of two strings with the wavefront alignment (WFA) algorithm,
    using a per-character gap penalty and a mismatch penalty (matches are free).
    """
    def __init__(self, string1, string2, gap_score=2,
                 mismatch_score=1):
        """ Finds an optimal global alignment of two strings using the WFA algorithm.
        The algorithm is as described in https://doi.org/10.1093/bioinformatics/btaa777
        The notation/language somewhat follows the paper, but is otherwise as follows:
        The two input strings string1 (x) and string2 (y)
        In the dp matrix we have (x, y) row,column coordinates, thus string1 is along the rows and string2
        is along the columns.
        The anti-diagonal is k = x-y
        The diagonal is x+y
        The coordinates can be visualized as follows:
             y-1 y+0 y+1
         x-1 k+0 k-1 k-2
         x+0 k+1 k+0 k-1
         x+1 k+2 k+1 k+0
        As in the paper, the furthest points, "fp", are represented as x coordinates along the anti-diagonal.

        gap_score and mismatch_score are penalties: the smallest score that reaches the end of the
        dp matrix is the alignment score. Both must be whole numbers >= 1, otherwise ValueError is
        raised. The score search in _next steps one unit at a time and reads the wavefronts at
        s - gap_score and s - mismatch_score. A zero penalty would read the wavefront still being
        built, and a negative one a wavefront that does not exist yet. A fractional one would never
        hit an existing score, so the search would never end.
        """
        for name, score in (("gap_score", gap_score), ("mismatch_score", mismatch_score)):
            if score < 1 or score != int(score):
                raise ValueError("%s must be a whole number >= 1, got %r" % (name, score))

        self.string1, self.string2 = string1, string2
        self.gap_score, self.mismatch_score = gap_score, mismatch_score
        self.wfs = WFS()  # The wavefront set
        self.s = 0  # The starting alignment score

        # Run the wavefront dynamic programming process to find the optimal alignment
        while True:
            self._extend()  # Extend the wavefront
            if self._done():  # We're done if we reach the end of the dp matrix
                break
            self._next()  # Set up the next wavefront

    def _extend(self):
        """ Extends each point on the current wavefront along its antidiagonal by alignment matches.
        """
        wf = self.wfs.get_wf(self.s)
        string1, string2 = self.string1, self.string2
        len1, len2 = len(string1), len(string2)
        for k in range(wf.min_diag, wf.max_diag + 1):
            h = wf.get_fp(k)  # Furthest x on antidiagonal k; the matching y is h - k
            if h < 0 or h - k < 0:
                continue  # Not inside the dp matrix (an unreachable point), nothing to extend
            # Slide along the antidiagonal while the next characters of both strings match
            while h < len1 and h - k < len2 and string1[h] == string2[h - k]:
                h += 1
            wf.set_fp(k, h)

    def _done(self):
        """ Are we at the end of the dp matrix, i.e. has the current wavefront reached
        (x=len(string1), y=len(string2))?
        """
        return self.wfs.get_fp(self.s, len(self.string1) - len(self.string2)) == len(self.string1)

    def _next(self):
        """ Adds the wavefront for the next reachable score to the set.

        The score is increased one step at a time until some wavefront exists at
        s - gap_score or s - mismatch_score. The new wavefront for s is then computed from those
        predecessor wavefronts; it is extended by matches later, in _extend.
        """
        # Find the next score that has a prior wavefront to connect to. This terminates because
        # both penalties are whole numbers >= 1 (checked in __init__).
        while True:
            self.s += 1
            if self.wfs.get_wf(self.s - self.gap_score) is not None or \
                    self.wfs.get_wf(self.s - self.mismatch_score) is not None:
                break

        # A gap moves a point onto a neighbouring antidiagonal, so the new wavefront spans one
        # antidiagonal beyond its predecessors on each side
        min_diag = min(self.wfs.get_min_diag(self.s - self.gap_score),
                       self.wfs.get_min_diag(self.s - self.mismatch_score)) - 1
        max_diag = max(self.wfs.get_max_diag(self.s - self.gap_score),
                       self.wfs.get_max_diag(self.s - self.mismatch_score)) + 1

        wf = self.wfs.add_wf(min_diag, max_diag, self.s)

        for k in range(wf.min_diag, wf.max_diag + 1):
            # Insert in string1: from antidiagonal k - 1, x advances by one
            insert = self.wfs.get_fp(self.s - self.gap_score, k - 1) + 1
            # Insert in string2: from antidiagonal k + 1, x stays the same (only y advances)
            delete = self.wfs.get_fp(self.s - self.gap_score, k + 1)
            # Mismatch: along the same antidiagonal, x (and y) advance by one
            mismatch = self.wfs.get_fp(self.s - self.mismatch_score, k) + 1
            wf.set_fp(k, max(insert, delete, mismatch))

    def get_alignment_score(self):
        """ Return the alignment score (the total penalty of an optimal alignment)
        """
        return self.s

    def get_alignment(self):
        """Returns an optimal alignment of the two strings.
        Implements the traceback algorithm.

        The alignment is a list of (x, y) pairs in increasing order. Each pair gives the indices of
        a character of string1 and a character of string2 that are aligned to each other, as a
        match or a mismatch. Characters that appear in no pair are aligned to gaps.
        """
        t = self.s  # The score of the sub-alignment that we're tracing back
        k = len(self.string1) - len(self.string2)  # The diagonal we're tracing back on
        f = len(self.string1)  # The furthest point
        alignment = []  # The alignment, represented as a sequence of (x, y) pairs
        assert self.wfs.get_fp(t, k) == f  # This is the condition that must be true at the beginning of trace back
        while k != 0 or f != 0:  # While we haven't gotten to the first cell in the dp matrix
            # Furthest points of the three possible predecessors of the current point
            a = self.wfs.get_fp(t - self.mismatch_score, k)  # match
            b = self.wfs.get_fp(t - self.gap_score, k - 1)  # insert in string1 (x)
            c = self.wfs.get_fp(t - self.gap_score, k + 1)  # insert in string2 (y)

            # Walk back along the antidiagonal, recording aligned pairs, down to where the
            # predecessor move lands. The plus one for an insert in string1 is necessary because
            # that move advances x.
            while f > max(a, b+1, c, 0):
                # k = x - y, f = x
                x, y = f, -(k - f)
                alignment.append((x-1, y-1))  # subtract one to get seq coordinates
                f -= 1

            if a >= b and a >= c:  # we must take a mis-match
                # f is not moved: the mismatched pair lies on this same antidiagonal and is
                # recorded by a diagonal walk (this one or the next)
                t -= self.mismatch_score
            elif b >= a and b >= c:  # alignment has insert in string1
                k -= 1
                f -= 1
                t -= self.gap_score
            else:  # alignment has insert in string2
                assert c >= a and c >= b
                k += 1
                t -= self.gap_score

        alignment.reverse()
        return alignment


def main():
    """ Align a small example pair of DNA strings with WFA and print the score and the alignment.
    """
    string1 = "ACTGTTCGCGATGG"
    string2 = "AGTGATTCGCGTGG"
    wfa = WFA(string1, string2, 2, 1)  # gap_score=2, mismatch_score=1
    print("Alignment score is: ", wfa.get_alignment_score())
    print("Alignment: ", wfa.get_alignment())


# Run the demo only when executed directly (this is also true inside a notebook kernel),
# not when the code is imported as a module
if __name__ == "__main__":
    main()

Alignment score is:  5
Alignment:  [(0, 0), (1, 1), (2, 2), (3, 3), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (11, 11), (12, 12), (13, 13)]


# Test Code

In [4]:
import numpy
from random import randint, choice

## Here is an implementation of Needleman-Wunsch, used as a reference to check the WFA results
class NeedlemanWunsch(object):
    """ Reference global aligner: the classic quadratic Needleman-Wunsch dynamic program.
    """
    def __init__(self, string1, string2, gap_score=-2, match_score=3, mismatch_score=-3):
        """ Finds an optimal global alignment of two strings.

        The score is maximized, so penalties are passed as negative numbers. The filled matrix is
        kept in self.editMatrix, where cell [i, j] holds the best score for aligning string1[:i]
        with string2[:j]. self.getScoreTuple(i, j) gives the (diagonal, up, left) candidate scores
        for a cell.
        """
        n_rows, n_cols = len(string1) + 1, len(string2) + 1
        self.editMatrix = numpy.zeros(shape=[n_rows, n_cols], dtype=int)

        # A prefix aligned against nothing but gaps scores gap_score per character
        self.editMatrix[:, 0] = gap_score * numpy.arange(n_rows)
        self.editMatrix[0, :] = gap_score * numpy.arange(n_cols)

        def score_tuple(i, j):
            # Candidates for cell (i, j): align string1[i-1] with string2[j-1] (diagonal),
            # gap in string2 (from the cell above), gap in string1 (from the cell to the left)
            if string1[i - 1] == string2[j - 1]:
                pair_score = match_score
            else:
                pair_score = mismatch_score
            diagonal = self.editMatrix[i - 1, j - 1] + pair_score
            up = self.editMatrix[i - 1, j] + gap_score
            left = self.editMatrix[i, j - 1] + gap_score
            return diagonal, up, left

        self.getScoreTuple = score_tuple

        for i in range(1, n_rows):
            for j in range(1, n_cols):
                self.editMatrix[i, j] = max(score_tuple(i, j))

    def get_alignment_score(self):
        """ Return the alignment score, i.e. the bottom-right cell of the edit matrix
        """
        return self.editMatrix[-1, -1]

    def get_alignment(self):
        """ Returns an optimal global alignment of two strings as an ordered list of aligned
        (i, j) index pairs; characters that appear in no pair are aligned to gaps.

        e.g. For the two strings GATTACA and TACA a global alignment is
           GATTACA
           ---TACA
        This alignment would be returned as:

        [(3, 0), (4, 1), (5, 2), (6, 3)]
        """
        i = self.editMatrix.shape[0] - 1
        j = self.editMatrix.shape[1] - 1

        pairs = []
        while i > 0 and j > 0:
            candidates = self.getScoreTuple(i, j)
            # Index of the first best candidate, so ties prefer diagonal, then up, then left
            best = max(range(3), key=candidates.__getitem__)
            if best == 0:
                i -= 1
                j -= 1
                pairs.append((i, j))
            elif best == 1:
                i -= 1
            else:
                j -= 1

        # The traceback ran from the end backwards
        return pairs[::-1]


# Here we test, on 100 randomly chosen small examples, that the WFA alignment score agrees with
# Needleman-Wunsch. We also check that the alignments returned by both are well formed and
# actually achieve the reported score.
def get_random_string(length=-1):
    """ Return a random DNA string of the given length, or of random length 0-10 if length < 0.
    """
    return "".join([choice("ACTG") for _ in range(length if length >= 0 else randint(0, 10))])


def get_alignment_cost(alignment, string1, string2, mismatch_score, gap_score):
    """ Return the total penalty of an alignment given as a list of aligned (i, j) index pairs.

    An aligned pair costs mismatch_score when its characters differ and nothing when they match.
    Every character of either string that is in no pair costs gap_score. Asserts that the pairs
    are in range and strictly increasing in both coordinates.
    """
    previous_i, previous_j = -1, -1
    for i, j in alignment:
        assert previous_i < i < len(string1) and previous_j < j < len(string2)
        previous_i, previous_j = i, j
    mismatches = sum(1 for i, j in alignment if string1[i] != string2[j])
    gaps = len(string1) + len(string2) - 2 * len(alignment)
    return mismatches * mismatch_score + gaps * gap_score


for test in range(100):
    x, y = get_random_string(), get_random_string()
    mismatch_score, gap_score = randint(1, 10), randint(1, 10)
    nw = NeedlemanWunsch(x, y, match_score=0, mismatch_score=-mismatch_score, gap_score=-gap_score)
    wfa = WFA(x, y, mismatch_score=mismatch_score, gap_score=gap_score)
    wfa_alignment, nw_alignment = wfa.get_alignment(), nw.get_alignment()
    print("Score", wfa.get_alignment_score())
    print("x", x, "y", y)
    print("WFA Alignment", wfa_alignment)
    print("NW Alignment", nw_alignment)
    assert -nw.get_alignment_score() == wfa.get_alignment_score()
    # Both tracebacks must yield valid alignments that achieve the optimal score
    assert get_alignment_cost(wfa_alignment, x, y, mismatch_score, gap_score) == wfa.get_alignment_score()
    assert get_alignment_cost(nw_alignment, x, y, mismatch_score, gap_score) == -nw.get_alignment_score()

Score 34
x GAGATACTG y CATAACA
WFA Alignment [(0, 0), (1, 1), (2, 2), (3, 3), (5, 4), (6, 5), (8, 6)]
NW Alignment [(0, 0), (1, 1), (2, 2), (3, 3), (5, 4), (6, 5), (8, 6)]
Score 12
x GTAGC y CGAC
WFA Alignment [(0, 0), (1, 1), (2, 2), (4, 3)]
NW Alignment [(0, 0), (1, 1), (2, 2), (4, 3)]
Score 10
x CCC y C
WFA Alignment [(0, 0)]
NW Alignment [(2, 0)]
Score 8
x TGG y CCCGAC
WFA Alignment [(0, 0), (1, 1), (2, 3)]
NW Alignment [(0, 2), (1, 3), (2, 5)]
Score 36
x GTTCGTAGC y 
WFA Alignment []
NW Alignment []
Score 72
x CGCGGGTAAA y ATACTGA
WFA Alignment [(0, 0), (1, 1), (2, 3), (3, 4), (4, 5), (7, 6)]
NW Alignment [(2, 0), (3, 1), (4, 2), (5, 3), (6, 4), (8, 5), (9, 6)]
Score 6
x GTA y TT
WFA Alignment [(1, 0), (2, 1)]
NW Alignment [(1, 0), (2, 1)]
Score 36
x GTC y TAATTTAG
WFA Alignment [(1, 0)]
NW Alignment [(0, 7)]
Score 22
x TGG y AACTCA
WFA Alignment [(0, 3), (1, 4), (2, 5)]
NW Alignment [(0, 3), (1, 4), (2, 5)]
Score 4
x C y 
WFA Alignment []
NW Alignment []
Score 6
x GGGCTGA y TCGAG