In [1]:
import time
import numpy as np
import torch

In [2]:
# Path (relative to the notebook's working directory) of the FASTA file that
# holds the aligned sequences read by the loading cell below.
seq_filename = "../sequence_sets/cmx_aligned_blank_90.fasta"

In [3]:
from Bio import SeqIO

# Read every FASTA record, upper-case its sequence and split it into a list of
# single characters (one inner list per record).
seqs = [
    list(str(record.seq.upper()))
    for record in SeqIO.parse(seq_filename, "fasta")
]

In [4]:
# Pack the per-sequence character lists into a 2-D array of one-byte strings.
msa = np.array(seqs, dtype="|S1")

# N = number of sequences (rows), L = alignment length (columns).
N, L = msa.shape
print((N, L))
msa

(14441, 559)


array([[b'K', b'N', b'A', ..., b'Q', b'A', b'D'],
       [b'E', b'D', b'A', ..., b'Q', b'A', b'D'],
       [b'K', b'D', b'A', ..., b'Q', b'A', b'D'],
       ...,
       [b'-', b'-', b'-', ..., b'-', b'-', b'-'],
       [b'-', b'-', b'-', ..., b'-', b'-', b'-'],
       [b'-', b'-', b'A', ..., b'-', b'-', b'-']], dtype='|S1')

In [5]:
# Uncomment the next line to keep only the first 10000 sequences for quick tests:
# msa = msa[:10000, :]
msa.shape

(14441, 559)

In [6]:
# No gradients are needed anywhere in this notebook, so turn autograd off globally.
torch.set_grad_enabled(False)

# Reinterpret each one-byte character as its unsigned 8-bit code and wrap the
# result in a uint8 tensor with the same (sequences, positions) layout.
msa_int = torch.as_tensor(msa.view(np.uint8), dtype=torch.uint8)
print(msa_int.shape)
msa_int

torch.Size([14441, 559])


tensor([[75, 78, 65,  ..., 81, 65, 68],
        [69, 68, 65,  ..., 81, 65, 68],
        [75, 68, 65,  ..., 81, 65, 68],
        ...,
        [45, 45, 45,  ..., 45, 45, 45],
        [45, 45, 45,  ..., 45, 45, 45],
        [45, 45, 65,  ..., 45, 45, 45]], dtype=torch.uint8)

In [7]:
# Fractional cutoff: two sequences are treated as neighbors when they agree on
# at least 80% of the alignment columns, i.e. when their Hamming distance is at
# most (1 - 0.8) * L mismatching positions.
distance_to_threshold = 0.8

# Small nudge so that a product landing just below an integer because of
# floating-point rounding is not truncated one too low by int().
epsilon = 1e-6

# Largest number of mismatching positions still counted as a neighbor.
distance_from_threshold_int = int(epsilon + L * (1 - distance_to_threshold))
distance_from_threshold_int

111

In [8]:
# Device
# Prefer the first CUDA GPU when one is visible; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

Device: cpu


In [9]:
# Reference timestamp (seconds since the epoch) taken before the weight
# computation, so the elapsed wall-clock time can be reported afterwards.
start_time = time.time()

In [10]:
# Create a torch scalar for the threshold
# One-element int16 tensor holding the mismatch cutoff, placed on the compute
# device so it can be compared directly against per-sequence distances.
torch_threshold = torch.as_tensor(
    [distance_from_threshold_int], dtype=torch.short
).to(device)

In [11]:
# Put the encoded alignment on the compute device ...
msa_int = msa_int.to(device)
# ... and expose it as a tuple of per-sequence row tensors.
torch_seqs = msa_int.unbind(dim=0)

In [12]:
def count_neighbors(torch_seq):
    """Count the rows of ``msa_int`` that lie within the cutoff of ``torch_seq``.

    Reads two notebook-level globals: ``msa_int`` (the integer-encoded
    alignment, one sequence per row) and ``torch_threshold`` (the largest
    number of mismatching positions still counted as a neighbor).

    Parameters
    ----------
    torch_seq : torch.Tensor
        A single encoded sequence, broadcastable against the rows of
        ``msa_int``.

    Returns
    -------
    torch.Tensor
        Zero-dimensional int32 tensor with the number of rows whose Hamming
        distance to ``torch_seq`` is at most ``torch_threshold``. When
        ``torch_seq`` is itself a row of ``msa_int`` and the threshold is
        non-negative, that row is included in the count.
    """
    # Accumulate in int32: an int16 accumulator wraps to negative values once
    # a count exceeds 32767, which happens for large alignments and would
    # yield negative sequence weights downstream.
    dist_from_seq = (torch_seq != msa_int).sum(dim=1, dtype=torch.int32)
    threshold_count = (dist_from_seq <= torch_threshold).sum(dtype=torch.int32)
    return threshold_count

# Neighbor count for every sequence: each sequence is compared against the
# whole alignment in one vectorised call.
neighbors_count = torch.stack(
    [count_neighbors(torch_seq) for torch_seq in torch_seqs]
)

# Each sequence is weighted by the inverse of its neighbor count.
weights = 1 / neighbors_count.float()

# Tensor.numpy() only works for CPU tensors, so bring the result back from the
# compute device first (a no-op when already on the CPU).
weights_np = weights.detach().cpu().numpy()
print(weights_np)

[0.0625 0.0625 0.0625 ... 1.     0.5    1.    ]


In [13]:
# Report the wall-clock time, in minutes, elapsed since start_time was recorded.
print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

Time elapsed: 0.23 min
