In [15]:
import numpy as np

In [48]:
# Inputs for this prototype:
#   pseq    - protein sequence (string)
#   indices - binding-site coordinates, each within [0, len(pseq)) of the raw sequence
#   dim     - side length of the square mask
# Test data: an arbitrary substring of the example in README.md; '[' stands in for BOS, ']' for EOS.
pseq = 'LHVVAAIQARMGSTRLPGKVLVSIAGRPTIQRIAERLAVCQELDAVAVSTSV'
indices = [10, 14, 15, 16, 18, 19, 20]
dim = 64

# Wrap the sequence in the BOS/EOS stand-ins. The prepended '[' moves every residue one slot
# to the right, so each binding-site index is bumped by one to keep pointing at the same residue.
pseq = f'[{pseq}]'
indices = [site + 1 for site in indices]

print(pseq, indices, sep='\n')
len(pseq)

[LHVVAAIQARMGSTRLPGKVLVSIAGRPTIQRIAERLAVCQELDAVAVSTSV]
[11, 15, 16, 17, 19, 20, 21]


54

In [49]:
# display (fake) binding site
def show_seq_idx(seq, idx):
    """Print `seq` on a single line, showing only the positions contained in `idx`.

    Every position whose index is not in `idx` is rendered as '-'. No trailing
    newline is written, so a caller that wants a line break has to print one.

    seq: indexable sequence of characters (e.g. a string)
    idx: container of integer positions; membership is tested with `in`
    """
    rendered = ''.join(str(ch) if pos in idx else '-' for pos, ch in enumerate(seq))
    print(rendered, end='')
            
# render the example: characters at the binding-site indices are shown, every other position prints as '-'
show_seq_idx(pseq, indices)

-----------M---RLP-KVL--------------------------------

In [50]:
# generate a random growth path from the binding site to the full sequence, as a list of ints
# each iteration adds one position that sits directly left or right of an already-included
# segment, drawn uniformly from all such candidate positions
# aiming for correctness, not optimization
SEED = 0  # fixed so Restart & Run All reproduces the same path; use None for a fresh path each run
rng = np.random.default_rng(SEED)

# de-duplicate: the loop condition counts elements, so a repeated index would end the walk
# before the whole sequence is covered
indices_curr = sorted(set(indices))
if not indices_curr:
    raise ValueError('indices must contain at least 1 element')
if indices_curr[0] < 0 or indices_curr[-1] >= len(pseq):
    raise ValueError('indices must lie within [0, len(pseq))')

path = []
while len(indices_curr) < len(pseq):
    included = set(indices_curr)

    # candidate steps = not-yet-included positions adjacent to an included one, i.e. the slot
    # just left of each contiguous segment's start and just right of its end
    # positions falling outside the sequence (-1 or len(pseq)) are dropped
    # building them as a set means a 1-wide gap between two segments is listed once,
    # so its probability is not doubled
    neighbours = {pos + offset for pos in included for offset in (-1, 1)}
    steps = sorted(pos for pos in neighbours - included if 0 <= pos < len(pseq))

    # select move, update path and indices
    # int(...) keeps path a list of plain Python ints instead of numpy integer scalars
    step = int(rng.choice(steps))
    path.append(step)
    indices_curr.append(step)
    indices_curr.sort()

    # visualize path
    show_seq_idx(pseq, indices_curr)
    print() # need newline

----------RM---RLP-KVL--------------------------------
----------RM--TRLP-KVL--------------------------------
----------RMG-TRLP-KVL--------------------------------
----------RMG-TRLP-KVLV-------------------------------
----------RMG-TRLP-KVLVS------------------------------
----------RMGSTRLP-KVLVS------------------------------
----------RMGSTRLPGKVLVS------------------------------
---------ARMGSTRLPGKVLVS------------------------------
--------QARMGSTRLPGKVLVS------------------------------
-------IQARMGSTRLPGKVLVS------------------------------
------AIQARMGSTRLPGKVLVS------------------------------
-----AAIQARMGSTRLPGKVLVS------------------------------
----VAAIQARMGSTRLPGKVLVS------------------------------
---VVAAIQARMGSTRLPGKVLVS------------------------------
---VVAAIQARMGSTRLPGKVLVSI-----------------------------
--HVVAAIQARMGSTRLPGKVLVSI-----------------------------
--HVVAAIQARMGSTRLPGKVLVSIA----------------------------
--HVVAAIQARMGSTRLPGKVLVSIAG---------------------------
--HVVAAIQA

In [64]:
# show the order in which positions were added
# cast to plain ints so any numpy integer scalars in path print as bare numbers
# (NumPy >= 2 renders them as np.int64(...) inside a list)
print([int(step) for step in path])

[10, 14, 12, 22, 23, 13, 18, 9, 8, 7, 6, 5, 4, 3, 24, 2, 25, 26, 27, 28, 29, 1, 30, 31, 32, 0, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53]


In [None]:
# create mask: dim x dim grid that starts fully masked (0); entries set to 1 are unmasked
mask = np.zeros((dim, dim))

# the original binding-site columns are unmasked for every row inside the sequence
mask[:len(pseq), indices] = 1

# construct mask: walk the path in order; the position added at step k becomes an
# unmasked column for its own row and for the rows of all positions added after it
for k in range(len(path)):
    mask[path[k:], path[k]] = 1

# visualize mask; # = masked out, . = unmasked
for row in mask:
    print(''.join('#' if cell == 0 else '.' for cell in row))

.................................###############################
#.............................##################################
##.......................#######################################
###.....................########################################
####....................########################################
#####...................########################################
######..................########################################
#######.................########################################
########................########################################
#########...............########################################
##########..###...#...##########################################
###########.###...#...##########################################
##########...#....#...##########################################
##########........#.....########################################
##########..##....#...##########################################
###########.###...#...###

In [72]:
# offset for positional embeddings: the smallest binding-site index
# min() reads the list once instead of sorting all of it just to take the first element
binding_site_start = min(indices)
binding_site_start

11

In [None]:
# return values: pseq, mask, binding_site_start
# not implemented yet: tokenization