In [1]:
import os
import re
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd

# FASTA parser requires Biopython
try:
    from Bio import SeqIO
except:
    !pip install biopython
    from Bio import SeqIO
    
# Retrieve protein alignment file
if not os.path.exists('BLAT_ECOLX_1_b0.5_labeled.fasta'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
        
# Retrieve file with experimental measurements
if not os.path.exists('BLAT_ECOLX_Ranganathan2015.csv'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
        
# Options
batch_size = 16

In [2]:
# One-letter symbol alphabet; the position of a symbol in this string is its
# integer encoding.
aa1 = "ACDEFGHIKLMNPQRSTVWYXZ-"

# Mapping from amino acids to integers (the inverse of indexing into aa1)
aa1_to_index = {symbol: index for index, symbol in enumerate(aa1)}

# Phylum names that are kept as separate classes; 'Other' is the catch-all class
phyla = [
    'Acidobacteria',
    'Actinobacteria',
    'Bacteroidetes',
    'Chloroflexi',
    'Cyanobacteria',
    'Deinococcus-Thermus',
    'Firmicutes',
    'Fusobacteria',
    'Proteobacteria',
    'Other',
]

def _compute_sequence_weights(seqs, similarity_threshold, num_classes, weight_batch_size=1000):
    '''Compute one weight per sequence from its number of similar sequences.

    For a pair of sequences, the similarity is the number of alignment
    positions at which both carry the same symbol with index <= 19, divided
    by the number of such symbols in the first sequence. The weight of a
    sequence is 1 / (number of sequences whose similarity to it exceeds
    similarity_threshold).

    seqs: 2D integer tensor (sequences x alignment positions).
    similarity_threshold: similarity above which two sequences count as similar.
    num_classes: alphabet size used for the one-hot encoding.
    weight_batch_size: number of sequences compared against the full set per
        matrix product.

    Returns a 1D float tensor with one weight per sequence, located on the
    GPU when one is available and on the CPU otherwise.
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Symbols with a larger index ('X', 'Z' and '-') do not contribute to
    # the similarity or to the sequence length.
    last_counted_index = 19

    # pytorch cannot one-hot encode directly to bool, which caused memory
    # issues on colab; encode the two halves separately and merge them.
    # num_classes is passed explicitly: letting one_hot infer it per half can
    # give the halves different widths (torch.cat then fails) and fails
    # outright on an empty half (single-sequence input).
    half = len(seqs) // 2
    one_hot = torch.cat([
        F.one_hot(seqs[:half].long(), num_classes=num_classes).bool(),
        F.one_hot(seqs[half:].long(), num_classes=num_classes).bool(),
    ]).to(device)
    assert len(seqs) == len(one_hot)

    # Blank out the one-hot rows of positions holding non-counted symbols
    one_hot[(seqs > last_counted_index).to(device)] = 0
    flat_one_hot = one_hot.flatten(1).float()
    del one_hot

    weights = []
    for start in range(0, seqs.size(0), weight_batch_size):
        stop = start + weight_batch_size
        # Dot products of flattened one-hot vectors count matching positions
        similarities = torch.mm(flat_one_hot[start:stop], flat_one_hot.T)
        lengths = (seqs[start:stop] <= last_counted_index).sum(1).unsqueeze(-1).to(device)
        n_similar = (similarities / lengths).gt(similarity_threshold).sum(1).float()
        weights.append(1.0 / n_similar)

    return torch.cat(weights)


def get_data(data_filename, calc_weights=False, weights_similarity_threshold=0.8):
    '''Create dataset from FASTA filename.

    Every record is encoded as a row of integers via aa1_to_index (after
    upper-casing and replacing '.' by '-'). Its class label is the first
    "[...]" group found in the record description; labels not listed in
    phyla, and records without such a group, are assigned to 'Other'.

    data_filename: path of the FASTA alignment file.
    calc_weights: any value other than False enables the calculation of
        sequence weights.
    weights_similarity_threshold: similarity above which two sequences count
        as similar when calculating weights.

    Returns (dataset, weights): a TensorDataset of (sequence, class index)
    pairs, and either a 1D tensor of sequence weights or None.

    Raises ValueError if the file contains no records, and KeyError if a
    sequence contains a symbol that is missing from aa1_to_index.
    '''
    labels = []
    seqs = []
    label_re = re.compile(r'\[([^\]]*)\]')
    for record in SeqIO.parse(data_filename, "fasta"):
        sequence = str(record.seq).upper().replace('.', '-')
        seqs.append(np.array([aa1_to_index[aa] for aa in sequence]))

        label_match = label_re.search(record.description)
        label = label_match.group(1) if label_match is not None else 'Other'
        # Only use most common classes
        if label not in phyla:
            label = 'Other'
        labels.append(label)

    if not seqs:
        raise ValueError('No sequences found in {!r}'.format(data_filename))

    seqs = torch.from_numpy(np.vstack(seqs))
    labels = np.array(labels)

    # np.unique sorts the labels, so class indices refer to the sorted set of
    # labels that actually occur in the file, not to positions in phyla.
    phyla_lookup_table, phyla_idx = np.unique(labels, return_inverse=True)

    dataset = torch.utils.data.TensorDataset(seqs, torch.from_numpy(phyla_idx))

    weights = None
    if calc_weights is not False:
        weights = _compute_sequence_weights(
            seqs,
            weights_similarity_threshold,
            num_classes=len(aa1_to_index),
        )

    return dataset, weights


# Build the dataset (without sequence weights) and a shuffling mini-batch loader
dataset, weights = get_data(
    data_filename='BLAT_ECOLX_1_b0.5_labeled.fasta',
    calc_weights=False,
)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [3]:
# Demonstration: iterate over the dataloader and print a single mini-batch
for x, y in dataloader:
    # Lift torch's print threshold so the tensors are shown in full rather
    # than summarised (this changes torch's global print options)
    torch.set_printoptions(threshold=np.inf)

    print("data: ", x)
    print("labels: ", y)

    # one batch is enough for the demonstration
    break

data:  tensor([[22, 22, 22, 22, 22, 22,  3,  7,  8, 14,  9,  3, 15,  3, 15,  5,  5, 14,
          9,  5, 17,  1, 17,  9,  3, 16,  0, 16,  5, 16, 14,  6, 22, 17,  6, 14,
          5,  2,  3, 14,  4, 12, 10,  1, 15, 16,  4,  8,  0,  9,  0,  0,  0,  0,
          7,  9,  0, 14, 17,  2,  0,  5, 11,  3, 13,  9, 16, 14, 14,  7, 16,  4,
          2,  0, 15,  0,  9, 17, 17, 11, 15, 12, 17, 16,  3,  8, 14, 17,  5,  5,
          2, 10, 16, 17,  0,  3,  7,  1,  2,  0,  0, 17, 16, 14, 15,  2, 11, 16,
          0,  5, 11,  9,  9,  9,  0,  5,  7,  5,  5, 12, 15,  5,  9, 16,  0,  4,
          0, 14, 15,  9,  5,  2,  3, 17, 16, 14,  9,  2, 14,  2,  3, 12, 15,  9,
         11,  3,  0,  9, 12,  5,  2, 12, 14,  2, 16, 16, 16, 12, 11,  0, 10,  0,
         15, 11,  9, 13,  0,  9,  7,  9,  5,  8,  0,  9, 15,  0,  0, 15, 14,  3,
         13,  9, 16,  0, 18,  9,  7,  0, 11,  8, 16,  5,  2, 16, 14,  9, 14,  0,
          5,  4,  0,  8,  5, 18, 14, 17,  5,  2,  8, 16,  5, 16,  5,  5, 14,  5,
         16, 11, 11, 

In [4]:
def read_experimental_data(filename, alignment_data, measurement_col_name = '2500', sequence_offset=0):
    '''Read experimental data from a csv file, and check that the amino acids
       match those in the first sequence of the alignment.

       filename is the csv file; it must contain a 'mutant' column with entries
       of the form <wild-type letter><position><mutant letter>.

       alignment_data is indexable; its first entry must be a
       (sequence, label) pair, and that sequence is used as the wild type.

       measurement_col_name specifies which column in the csv file contains the
       experimental observation. In our case, this is the one called 2500.

       sequence_offset is used in case there is an overall offset between the
       indices in the two files.

       Returns a DataFrame indexed by (pos, WT) with one column per mutant
       amino acid, holding the experimental values.

       Raises AssertionError if a mutant maps to a negative alignment position,
       if its wild-type letter disagrees with the alignment, or if the same
       mutant occurs more than once.
       '''

    measurement_df = pd.read_csv(filename, delimiter=',', usecols=['mutant', measurement_col_name])

    # Only the sequence of the first alignment entry is needed, not its label
    wt_sequence, _ = alignment_data[0]

    zero_index = None

    experimental_data = {}
    for mutant, measurement in zip(measurement_df['mutant'], measurement_df[measurement_col_name]):
        mutant_from, position, mutant_to = mutant[:1], int(mutant[1:-1]), mutant[-1:]

        # Use index of first entry as offset (keep track of this in case
        # there are index gaps in experimental data)
        if zero_index is None:
            zero_index = position

        # Corresponding position in our alignment
        seq_position = position - zero_index + sequence_offset

        # A negative index would silently wrap around to the end of the sequence
        assert seq_position >= 0, (
            "mutant {} maps to negative alignment position {}".format(mutant, seq_position))

        # Make sure that the two inputs agree on the indices: the
        # amino acids in the first entry of the alignment should be
        # identical to those in the experimental file.
        wt_aa = aa1[int(wt_sequence[seq_position])]
        assert mutant_from == wt_aa, (
            "mutant {}: wild type is {!r} in the csv file but {!r} at alignment position {}".format(
                mutant, mutant_from, wt_aa, seq_position))

        position_entry = experimental_data.setdefault(seq_position, {})

        # Check that there is only a single experimental value for mutant
        assert mutant_to not in position_entry, (
            "duplicate measurement for mutant {}".format(mutant))

        position_entry['pos'] = seq_position
        position_entry['WT'] = mutant_from
        position_entry[mutant_to] = measurement

    # One row per position: the dict-of-dicts is column-oriented, hence the transpose
    experimental_data = pd.DataFrame(experimental_data).transpose().set_index(['pos', 'WT'])
    return experimental_data
        
        
# Load the measurements, validated against the first sequence of the alignment
experimental_data = read_experimental_data(
    filename="BLAT_ECOLX_Ranganathan2015.csv",
    alignment_data=dataset,
)
experimental_data

Unnamed: 0_level_0,Unnamed: 1_level_0,A,C,E,D,G,F,I,K,M,L,N,Q,P,S,R,T,W,V,Y,H
pos,WT,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,H,-0.00978356,-0.41826,-0.279024,-0.181607,-0.0602417,-0.818487,-0.359191,0.0144696,-0.224781,-0.480347,-0.0430932,-0.135568,-1.01085,0.0361661,-0.00252626,-0.0671875,-1.34759,-0.026874,-0.885025,
1,P,-1.6426,-0.364138,0.143258,-0.0284025,-0.969268,-0.199804,-0.0735238,0.13559,-0.0283657,-0.211869,0.0206405,-0.0282136,,-0.16013,0.054154,-0.0911999,-0.109139,0.045913,0.00174467,0.0457846
2,E,0.0109131,-0.158233,,-0.0757852,0.0813101,-0.232106,-0.153907,0.0871198,-0.036441,-0.0581804,-0.0064688,0.0496907,-0.387232,-0.0395849,-0.220003,-0.135909,-0.44234,-0.0645674,-0.245436,0.0209168
3,T,-1.45459,-2.41902,-2.41446,-2.29488,-2.35671,-2.60457,-0.280446,-1.42789,-1.8431,-0.765521,-2.48572,-1.6671,-1.79017,-1.39248,-2.37509,,-2.8417,0.0341893,-2.78913,-1.84954
4,L,-0.202228,-1.95959,-1.72164,-2.71077,-1.4842,-0.720047,0.0173958,0.0695957,-0.0480697,,-1.42071,-0.222812,-2.19535,-1.2641,-0.0649357,-0.313656,-0.299738,0.0502655,-0.2186,-0.889277
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258,L,-0.103001,-0.617685,-2.95432,-2.52951,-2.41565,-0.663794,0.0786722,-2.57103,0.0126656,,-2.32138,-1.72228,-2.42792,-1.65506,-2.70359,-0.231658,-2.78162,-0.0205734,-2.60799,-3.01127
259,I,-0.537631,-0.657012,-2.6017,-2.78886,-2.24781,-0.123746,,-2.24993,-0.0885522,-0.190349,-2.48529,-1.95522,-2.61388,-1.48797,-1.94825,-0.300995,-1.71315,-0.030749,-0.0821665,-2.07812
260,K,-0.0389621,-0.91867,-0.0274808,-0.0614761,-0.0924616,-0.286851,-0.514835,,-0.107796,-0.0442595,-0.0745288,-0.0715558,-2.59327,-0.0202318,-0.04172,-0.090441,-0.449885,-0.261297,-0.111472,-0.00521766
261,H,-0.465274,-0.251095,-0.338762,-1.53546,-0.509937,-0.149276,-2.72312,-0.271887,-0.226841,-0.553996,0.0240883,-0.141955,-2.72487,-0.160349,-0.157191,-0.113391,-0.0762315,-0.825731,-0.0946859,


In [5]:
# For each of the entries in the dataframe above, you should calculate
# the corresponding difference in ELBO from your VAE, and then finally
# calculate a Spearman correlation between the two.

# You can iterate over all experimental values like this:
for (position, mutant_from), row in experimental_data.iterrows():
    print(position, mutant_from)   # mutant from is the wild type (wt)
    # Series.items() replaces Series.iteritems(), which was removed in pandas 2.0
    for mutant_to, exp_value in row.items():
        print("\t", mutant_to, exp_value)

0 H
	 A -0.009783565
	 C -0.418259925
	 E -0.279024169
	 D -0.18160661
	 G -0.060241672999999996
	 F -0.8184870959999999
	 I -0.359191243
	 K 0.014469576999999999
	 M -0.22478129300000002
	 L -0.480346576
	 N -0.043093163
	 Q -0.135568334
	 P -1.010848138
	 S 0.036166097
	 R -0.002526258
	 T -0.067187546
	 W -1.3475859609999998
	 V -0.026874047
	 Y -0.88502515
	 H nan
1 P
	 A -1.6425950930000002
	 C -0.364138405
	 E 0.143257861
	 D -0.028402505
	 G -0.969267861
	 F -0.19980403100000002
	 I -0.073523836
	 K 0.135590499
	 M -0.02836574
	 L -0.211868922
	 N 0.020640486
	 Q -0.028213622
	 P nan
	 S -0.160130436
	 R 0.054153993
	 T -0.091199948
	 W -0.109138595
	 V 0.04591299099999999
	 Y 0.0017446679999999998
	 H 0.045784618
2 E
	 A 0.010913108000000001
	 C -0.158232612
	 E nan
	 D -0.07578524
	 G 0.08131007700000001
	 F -0.232106116
	 I -0.153907493
	 K 0.087119809
	 M -0.036440986
	 L -0.05818041
	 N -0.006468796999999999
	 Q 0.049690653
	 P -0.387232365
	 S -0.039584928
	 R -0.220002694

	 K -0.8061271990000001
	 M -0.067702575
	 L nan
	 N -1.5260329819999998
	 Q -1.22558591
	 P -2.6049618430000003
	 S -2.0217628519999997
	 R -0.520211329
	 T -0.997385307
	 W -1.28202634
	 V -0.049845220999999995
	 Y -0.120999362
	 H -0.720452442
32 E
	 A -0.087059359
	 C -1.5553789180000002
	 E nan
	 D -0.058707709000000004
	 G -0.051919235
	 F -2.4408351880000003
	 I -2.18407722
	 K -0.101755554
	 M -1.303131735
	 L -1.3142741740000001
	 N -0.101191244
	 Q -0.13106595099999999
	 P -2.66658801
	 S -0.092556475
	 R -0.221984209
	 T -0.103973384
	 W -2.094695901
	 V -1.128714536
	 Y -1.791599132
	 H -1.177880662
33 S
	 A 0.025721015
	 C -1.504226134
	 E 0.005882297
	 D -0.976962278
	 G -0.036135834
	 F -0.258372422
	 I -1.225007314
	 K -0.040180196
	 M -0.128876766
	 L -0.638674441
	 N -0.340335844
	 Q -0.014098500000000002
	 P -2.61056601
	 S nan
	 R -0.129307108
	 T -0.065115516
	 W -0.495673855
	 V -0.09729119
	 Y -0.39567102299999996
	 H -0.200224732
34 F
	 A -0.096847901
	 C -1.092

	 S -1.501669116
	 R -2.347803569
	 T -0.836128657
	 W -1.558419284
	 V -0.121958186
	 Y -1.6328803809999999
	 H -1.85153629
64 Q
	 A -0.00012030000000000001
	 C -0.273911584
	 E 0.029576106
	 D 0.077216808
	 G -0.08186091200000001
	 F -0.058579433
	 I -0.077210232
	 K 0.094749771
	 M -0.014628486000000001
	 L -0.004543879
	 N 0.027462082000000002
	 Q nan
	 P -1.9268623630000001
	 S 0.023680338999999998
	 R 0.016367382
	 T 0.023069693999999998
	 W 0.052154614
	 V -0.02038913
	 Y 0.087859613
	 H 0.015802767
65 L
	 A -1.51750046
	 C -2.106335815
	 E -2.122418487
	 D -2.198119647
	 G -1.746458061
	 F -1.065589855
	 I -0.292830677
	 K -1.9043932769999998
	 M -0.068121536
	 L nan
	 N -1.80739076
	 Q -1.814728086
	 P -1.452205367
	 S -1.6795782819999998
	 R -2.009935743
	 T -1.712095064
	 W -1.733268579
	 V -1.52692744
	 Y -1.435450717
	 H -1.515653905
66 G
	 A -0.011699145
	 C -0.975286612
	 E 0.041201227
	 D -0.005478861999999999
	 G nan
	 F -0.131499981
	 I -0.13522483300000002
	 K -0.108

	 T 0.015904985
	 W 0.068141592
	 V -0.175151046
	 Y 0.022056821
	 H 0.046338758
95 E
	 A -0.286135695
	 C -2.1288413530000003
	 E nan
	 D 0.050682024000000006
	 G -0.211259948
	 F -2.810730425
	 I -2.332929806
	 K -1.414133549
	 M -1.63992501
	 L -2.748624492
	 N -0.608434947
	 Q -0.055457407
	 P -2.8951336789999997
	 S -0.36466954100000004
	 R -1.7349685469999998
	 T -1.904197339
	 W -2.695865472
	 V -2.23644673
	 Y -2.12499127
	 H -0.909007435
96 L
	 A -1.656683027
	 C -1.956852199
	 E -2.6890916860000003
	 D -2.635226081
	 G -2.8045577510000004
	 F -1.8853995069999998
	 I -1.593683033
	 K -2.7327143830000002
	 M -0.31533879
	 L nan
	 N -2.926474207
	 Q -2.673516497
	 P -2.779291773
	 S -2.488623444
	 R -2.676454953
	 T -2.41338899
	 W -2.643231062
	 V -1.783373826
	 Y -2.901109553
	 H -2.436765693
97 C
	 A -1.3671009619999999
	 C nan
	 E -3.29997929
	 D -2.63960149
	 G -1.7945430709999999
	 F -2.977504548
	 I -1.655882633
	 K -3.1473116689999996
	 M -1.306918993
	 L -1.621818619000

	 C -0.17192907899999998
	 E -0.022203867000000002
	 D -0.055729847
	 G -0.020137481999999998
	 F -0.074061544
	 I 0.047126188
	 K nan
	 M -0.076435888
	 L -0.083766349
	 N -0.004878996
	 Q -0.012246845
	 P -0.127841549
	 S -0.054032555999999995
	 R -0.08119536299999999
	 T -0.074526532
	 W -0.261537109
	 V -0.046394863
	 Y -0.17401648399999997
	 H -0.009511754
121 E
	 A -0.027764072999999997
	 C -0.054622107
	 E nan
	 D -0.053888557999999996
	 G -0.009259429
	 F -0.010103721
	 I -0.19679012899999998
	 K -0.034367585
	 M 0.002160508
	 L -0.185923179
	 N -0.082234238
	 Q -0.051717017000000004
	 P -0.36504550799999996
	 S -0.019853986
	 R -0.086826265
	 T -0.100024051
	 W -0.182909331
	 V -0.058114107000000005
	 Y -0.131728189
	 H -0.07193353799999999
122 L
	 A -1.532918971
	 C -1.6835959569999999
	 E -2.67514971
	 D -2.8228956239999996
	 G -2.258073028
	 F -0.11858342699999999
	 I -0.366678145
	 K -2.628974197
	 M -0.136688725
	 L nan
	 N -2.571101493
	 Q -1.8801027080000001
	 P -3.3791

	 I -2.562807203
	 K -0.250014334
	 M -0.688151548
	 L -0.760679487
	 N -0.755379623
	 Q -0.097072605
	 P -2.9011935230000003
	 S -0.660139414
	 R nan
	 T -2.210755806
	 W -1.4253905690000002
	 V -2.112091666
	 Y -1.119486404
	 H -0.693165748
153 D
	 A -2.777766007
	 C -3.2068331110000003
	 E -2.8871319730000002
	 D nan
	 G -3.096966536
	 F -2.8454228230000003
	 I -2.9784870839999997
	 K -2.9624642110000003
	 M -2.7488243939999997
	 L -3.0365156489999996
	 N -3.039174838
	 Q -2.999547418
	 P -2.955764358
	 S -3.044305848
	 R -3.254302042
	 T -3.24570132
	 W -3.193444419
	 V -3.264237045
	 Y -2.3811209509999998
	 H -3.226945425
154 T
	 A -2.022396221
	 C -1.6754330919999998
	 E -2.343746274
	 D -2.8637175839999998
	 G -2.549241462
	 F -2.86912198
	 I -2.0425981159999997
	 K -2.986548055
	 M -2.8319181189999996
	 L -2.776831022
	 N -2.2611833740000002
	 Q -2.094211605
	 P -2.8639030869999997
	 S -0.131550732
	 R -3.014521567
	 T nan
	 W -3.2135716089999997
	 V -2.349552164
	 Y -2.9465325

	 N -2.7640742719999998
	 Q -2.7367553489999996
	 P -2.968101119
	 S -3.040917185
	 R -3.2162640689999997
	 T -2.364069586
	 W -2.016528177
	 V -1.66022784
	 Y -2.834558335
	 H -2.728694618
182 I
	 A -0.09558502199999999
	 C -0.120373889
	 E -0.46814960299999997
	 D -2.292964037
	 G -0.507712402
	 F -0.115699014
	 I nan
	 K -0.166908908
	 M -0.031823082
	 L -0.062484205999999994
	 N -0.545276258
	 Q -0.043686221
	 P -2.7935813310000004
	 S -0.109283774
	 R -0.09609221300000001
	 T -0.092170431
	 W -0.059908443
	 V -0.052187525
	 Y -0.10040778800000001
	 H -0.490274334
183 D
	 A -0.024022197000000002
	 C -0.08019127
	 E 0.004128037
	 D nan
	 G -0.016034947
	 F -0.184123466
	 I -0.183733321
	 K -0.068487126
	 M -0.025048737
	 L -0.064457659
	 N -0.010855562
	 Q -0.072254373
	 P -2.74751331
	 S 0.010865086999999999
	 R -0.013780093
	 T -0.051384057000000004
	 W -0.10318134699999999
	 V -0.15558870000000002
	 Y -0.089946259
	 H -0.02671686
184 W
	 A -1.916460922
	 C -2.212615247
	 E -2.293

	 T -2.7482269919999998
	 W -2.753353303
	 V -2.719604968
	 Y -2.4095241659999997
	 H -2.819360412
209 S
	 A -1.023786368
	 C -1.777677999
	 E -2.908854686
	 D -2.815419103
	 G -0.302187877
	 F -2.679054189
	 I -2.5831744519999997
	 K -2.429145416
	 M -2.495404731
	 L -2.653834687
	 N -2.837440603
	 Q -2.395784855
	 P -1.8622693369999999
	 S nan
	 R -3.0796995989999996
	 T -0.036772807000000005
	 W -2.514866361
	 V -2.8191033360000004
	 Y -2.7167121219999997
	 H -2.540929603
210 G
	 A -0.451187719
	 C -1.561572545
	 E -2.321554784
	 D -2.715611985
	 G nan
	 F -2.606707788
	 I -2.648282175
	 K -3.099599421
	 M -2.433627533
	 L -2.699437057
	 N -2.9832858589999995
	 Q -2.478982746
	 P -3.176723071
	 S -1.605763109
	 R -2.901957822
	 T -3.3294693669999997
	 W -2.90549311
	 V -2.74767668
	 Y -2.784599127
	 H -3.010996295
211 A
	 A nan
	 C -2.08141988
	 E -2.072320436
	 D -2.115127014
	 G -0.19590627800000002
	 F -2.7314868010000004
	 I -1.6415324690000002
	 K -2.4592008780000003
	 M -1.508

	 A 0.007296149
	 C -0.045596657
	 E 0.02715064
	 D 0.018464877
	 G 0.021784772999999997
	 F -0.072124116
	 I 0.013407068000000001
	 K 0.031578232000000005
	 M 0.012483781999999999
	 L 0.009050001
	 N 0.057900198
	 Q nan
	 P -0.002765732
	 S 0.025386272
	 R 0.011668405
	 T 0.016477116
	 W -0.033193385
	 V -0.015321831999999999
	 Y 0.036114718
	 H -0.011547412
242 A
	 A nan
	 C 0.001832707
	 E 0.035025403
	 D -0.064810058
	 G -0.009499068000000001
	 F -0.0305753
	 I -0.24631791600000003
	 K -0.193697368
	 M 0.062910106
	 L -0.008888088
	 N 0.08711741
	 Q -0.066046697
	 P -1.653566087
	 S -0.018780831
	 R -0.0033143129999999997
	 T 0.018746703
	 W -0.663916399
	 V -0.017698685
	 Y -0.05505783400000001
	 H -0.00282678
243 T
	 A -0.036951636
	 C -0.12178815400000001
	 E -0.06588998900000001
	 D -0.036252419
	 G -0.024747957
	 F -0.90864056
	 I -0.28823761800000003
	 K -0.121983019
	 M -0.172746861
	 L -0.154958829
	 N 0.07182240200000001
	 Q -0.078434605
	 P -0.002425347
	 S 0.006278958
	 