In [1]:
import tensorflow as tf
import numpy as np
import primo.tools.sequences as seqtools

In [2]:
# One-hot test batch laid out as (pair, top/bottom sequence, position, channel).
# The four channels follow the column order A, T, C, G, so a base is encoded
# by selecting the matching row of a 4x4 integer identity matrix.
test_pairs = np.eye(4, dtype=int)[
    np.array(
        [
            [["ATCG".index(base) for base in strand] for strand in pair]
            for pair in (
                ("TGACA", "CGTCG"),  # pair 0: (top sequence, bottom sequence)
                ("GCACT", "ACATG"),  # pair 1: (top sequence, bottom sequence)
            )
        ]
    )
]

In [3]:
test_pairs.shape

(2, 2, 5, 4)

In [4]:
# Decode the one-hot pairs back into strings: argmax over the channel axis
# yields a base index per position, which is then looked up in seqtools.bases.
# NOTE(review): assumes seqtools.bases is an array of single-character bases
# that supports integer-array indexing -- confirm against primo.tools.sequences.
[
    ("".join(top), "".join(bottom))
    for top, bottom in seqtools.bases[test_pairs.argmax(-1)]
]

[('TGACA', 'CGTCG'), ('GCACT', 'ACATG')]

In [5]:
def match_layer(**lambda_args):
    """Build a Keras Lambda layer that encodes per-position base pairings.

    The layer takes one-hot sequence pairs of shape
    (batch, 2 sequences, length, 4 channels) and returns a tensor of shape
    (batch, length, 16). At every position the 4-vector of the first
    sequence is combined with the 4-vector of the second sequence by an outer
    product, and the resulting 4x4 matrix is flattened row-major, so output
    channel ``4 * i + j`` equals ``top[i] * bot[j]``.

    Args:
        **lambda_args: keyword arguments forwarded unchanged to
            ``tf.keras.layers.Lambda`` (e.g. ``input_shape``, ``name``).

    Returns:
        A ``tf.keras.layers.Lambda`` layer wrapping the match computation.
    """
    def match(seq_pairs):

        # ensure that sequence pairs have dimension: (batch, 2 sequences, length, 4 channels)
        seq_pairs.shape.assert_is_compatible_with([None, 2, None, 4])

        # The assertion above accepts an unknown sequence length. Use the
        # static length when it is known (keeps the output shape fully
        # defined), otherwise fall back to the runtime length so that the
        # reshape below never receives None as a dimension.
        seq_len = tf.compat.dimension_value(seq_pairs.shape[2])
        if seq_len is None:
            seq_len = tf.shape(seq_pairs)[2]

        # separate first and second sequences from each pair
        top = seq_pairs[:, 0, :, :]
        bot = seq_pairs[:, 1, :, :]

        # computes the outer product of one-hot vectors at each position for each pair
        # for syntax see https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
        matches = tf.einsum('...i,...j->...ij', top, bot)

        # flatten match matrix at each position into match vector
        return tf.reshape(matches, [-1, seq_len, 16])

    return tf.keras.layers.Lambda(match, **lambda_args)

In [6]:
# Single-layer model around the match layer; input_shape omits the batch axis:
# 2 sequences per pair, 5 positions, 4 one-hot channels.
model = tf.keras.models.Sequential()
model.add(match_layer(input_shape=[2, 5, 4]))

In [7]:
# Apply the model directly to the NumPy batch of one-hot sequence pairs.
result = model(test_pairs)

In [8]:
result

<tf.Tensor: shape=(2, 5, 16), dtype=float32, numpy=
array([[[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]]],
      dtype=float32)>

In [15]:
# Index of the active channel (0-15) at every position of every pair.
np.argmax(result.numpy(), axis=-1)

array([[ 6, 15,  1, 10,  3],
       [12, 10,  0,  9,  7]])

In [9]:
# Human-readable labels for the 16 (first base, second base) combinations.
# The label for ("ATCG"[i], "ATCG"[j]) is stored at index 4 * i + j.
substs = np.array([first + "->" + second for first in "ATCG" for second in "ATCG"])

In [10]:
substs

array(['A->A', 'A->T', 'A->C', 'A->G', 'T->A', 'T->T', 'T->C', 'T->G',
       'C->A', 'C->T', 'C->C', 'C->G', 'G->A', 'G->T', 'G->C', 'G->G'],
      dtype='<U4')

In [11]:
# Translate each position's active channel index into its "X->Y" label.
substs[np.argmax(result.numpy(), axis=-1)]

array([['T->C', 'G->G', 'A->T', 'C->C', 'A->G'],
       ['G->A', 'C->C', 'A->A', 'C->T', 'T->G']], dtype='<U4')

In [12]:
# Decode the one-hot input pairs to strings again so they can be compared by
# eye with the per-position labels produced from the model output.
[
    ("".join(top), "".join(bottom))
    for top, bottom in seqtools.bases[test_pairs.argmax(-1)]
]

[('TGACA', 'CGTCG'), ('GCACT', 'ACATG')]

In [None]:
# NOTE(review): the source and meaning of these constants are not shown in this
# notebook -- presumably fitted parameters from an external model; TODO record
# their provenance and the base ordering of the rows/columns of `subtrans`.

# 24 scalar values.
subpen = np.array([
    -1.7449405080809126, -1.275485084790358, -1.8001827224086722, -1.9323849500279549,
    -1.6677722398632207, -1.6537370694565101, -1.8981469677400609, -1.0814292717607923,
    -1.3231152511430453, -0.99840146446464273, -1.2766126030502924, -1.073338813454068,
    -1.5614374592181826, -1.4737507320504855, -1.298392565410591, -1.0105000195452765,
    -0.43349702574711524, -0.11665543376814178, -0.17370266801790191, 0.2676084623705467,
    0.051835157750172757, 0.08920809165894289, 0.075459598643889569, 0.046975071077932237,
])

# 4x4 matrix with a zero diagonal.
subtrans = np.array([
    [0.0, 1.16616601, 0.96671383, 0.94917742],
    [0.94076049, 0.0, 1.18426595, 0.87129983],
    [0.58224486, 1.11064886, 0.0, 1.04707949],
    [0.9633753, 0.98895548, 1.2293125, 0.0],
])

In [None]:
# Base ordering used to re-index the rows and columns of `subtrans`.
finkel_bases = 'ACGT'

# shift[k] is the position of finkel_bases[k] within seqtools.bases.
shift = np.array(list(map(list(seqtools.bases).index, finkel_bases)))

# Permute rows and columns together:
# subtrans_shift[i, j] == subtrans[shift[i], shift[j]].
# NOTE(review): this is the right permutation only if `subtrans` is stored in
# seqtools.bases order and is being converted to finkel_bases order; going the
# other way needs the inverse permutation -- confirm the intended direction.
subtrans_shift = subtrans[np.ix_(shift, shift)]