In [1]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import numpy as np
import datetime

In [2]:
# Not sure all of these are needed
from tokenizers import Tokenizer
from tokenizers.models import BPE
from sklearn.feature_extraction.text import CountVectorizer

### Read SILVA data

In [3]:
# Heavy file, takes a minute
silva_db = pd.read_table('C:/Users/efrat/Downloads/silva_138_release/Exports/SILVA_parsed_V2.tsv', index_col = 0,
                        dtype={'raw_id': str, 
                               "full_taxonomy": str, 
                               "seq_length": int,
                               "seq": str, 
                               "kingdom": str, 
                               "phylum": str, 
                               "class": str, 
                               "order": str, 
                               "family": str,
                               "genus": str,
                               "species": str,
                               "strain": str})
silva_db.head()

Unnamed: 0,raw_id,full_taxonomy,seq_length,seq,kingdom,phylum,class,order,family,genus,species,strain
0,HG531388.1.1375,Bacteria;Proteobacteria;Alphaproteobacteria;Rh...,1375,AGUCGAGCGGGCGCAGCAAUGCGUCAGCGGCAGACGGGUGAGUAAC...,Bacteria,Proteobacteria,Alphaproteobacteria,Rhizobiales,Xanthobacteraceae,Rhodoplanes,Rhodoplanes oryzae,
1,HL281785.3.1301,Bacteria;Bacteroidota;Bacteroidia;Bacteroidale...,1299,AUUCCGGGAUAGCCUUUCGAAAGAAAGAUUAAUACUGGAUAGCAUA...,Bacteria,Bacteroidota,Bacteroidia,Bacteroidales,Bacteroidaceae,Bacteroides,unidentified,
2,AB002644.1.1485,Bacteria;Firmicutes;Bacilli;Bacillales;Bacilla...,1485,GGCUAAUACAUGCAAGUCGAGCGAGUGAACAAACAGAAGCCUUCGG...,Bacteria,Firmicutes,Bacilli,Bacillales,Bacillaceae,Bacillus,low G+C Gram-positive bacterium HTA454,
3,AB002648.1.1383,Bacteria;Firmicutes;Bacilli;Thermoactinomyceta...,1383,AGCGGCGAACGGGUGAGUAACACGNGGGUAACCUGCCCUCAAGACC...,Bacteria,Firmicutes,Bacilli,Thermoactinomycetales,Thermoactinomycetaceae,Thermoflavimicrobium,low G+C Gram-positive bacterium HTA1422,
4,JN049459.1.1443,Bacteria;Actinobacteriota;Actinobacteria;Strep...,1443,GACAUGGCGCCUCUACCAUGCAGUCGACGAUGACCACCUUCGGGGU...,Bacteria,Actinobacteriota,Actinobacteria,Streptomycetales,Streptomycetaceae,Streptomyces,actinobacterium ZXY010,


In [4]:
# Create list of all sequences in our data and convert to lowercase
seqs = silva_db.seq.values
seqs = [s.lower() for s in seqs]
#seqs[0:2]
print(len(seqs),"sequences total in db")

432033 sequences total in db


### Read BPE tokens

In [5]:
# Load the BPE tokenizer
vocab = "C:/Users/efrat/Documents/DNA_BERT_Data/tokens/vocab_15K/vocab.json"
merges = "C:/Users/efrat/Documents/DNA_BERT_Data/tokens/vocab_15K/merges.txt"
tokenizer_15K = Tokenizer(BPE())
tokenizer_15K.model = BPE(vocab, merges)
vocab_dict = json.load(open(vocab))

### Get k-mer tokens

In [6]:
# Helper functions
# Given a sequence and k-length, return the unique set of k-mers
def get_kmers_seq(seq, k):
    # Start with an empty set
    kmers = set()
    # Calculate how many kmers of length k there are
    num_kmers = len(seq) - k + 1
    # Loop over the kmer start positions
    for i in range(num_kmers):
        # Slice the string to get the kmer
        kmer = seq[i:i+k]
        # Add the kmer to the dictionary if it's not there
        kmers.add(kmer)
    # Return the final counts
    return kmers

def get_kmers(seqs, k):
    kmers = set()
    n = len(seqs)
    i = 0
    for seq in seqs:
        i += 1
        if i%1000 == 0:
            print(".", end=' ')
        kmers = kmers.union(get_kmers_seq(seq, k))
    return kmers
    
def encode_kmers(seq, k):
    kmer_encoding = []
    num_kmers = len(seq) - k + 1
    for i in range(num_kmers):
        kmer = seq[i:i+k]
        kmer_encoding.append(kmer)
    return kmer_encoding

In [7]:
# Here we calculate the actual vocab' size per k = 8/10/12
#kmer_vocab_8 = get_kmers(seqs, 8)
#print("8-mer vocab size:", len(kmer_vocab_8)) # --> 259339
#kmer_vocab_10 = get_kmers(seqs, 10)
#print("10-mer vocab size:", len(kmer_vocab_10)) # --> 1611004
#kmer_vocab_6 = get_kmers(seqs, 6)
#print("6-mer vocab size:", len(kmer_vocab_6)) # --> 15621
#kmer_vocab_4 = get_kmers(seqs, 4)
#print("4-mer vocab size:", len(kmer_vocab_4)) # --> 625

### Create a matrix of tokens-in-sequences
The matrix (all_encodings_cv) will be sparse. 

In [8]:
# Use sklearn to batch encode all sequences and convert to a matrix (takes a few minutes)
print(datetime.datetime.now())
cv = CountVectorizer(tokenizer = lambda x: tokenizer_15K.encode(x).tokens, lowercase = True)
all_encodings_cv_bpe = cv.fit_transform(seqs)
token_strings_bpe = cv.get_feature_names()
print(all_encodings_cv_bpe.shape)
print(datetime.datetime.now())

2020-07-19 10:04:55.990637
(432033, 15620)
2020-07-19 10:13:36.145018


Same for kmer-based tokens

In [9]:
print(datetime.datetime.now())
cv_6mer = CountVectorizer(tokenizer = lambda x: encode_kmers(x, 6), lowercase = True)
all_encodings_cv_6mer = cv_6mer.fit_transform(seqs)
token_strings_6mer = cv_6mer.get_feature_names()
print(all_encodings_cv_6mer.shape)
print(datetime.datetime.now())

2020-07-19 10:13:36.152964
(432033, 15621)
2020-07-19 10:20:11.005341


## Tokens analysis

Per token, we collect the following statistics:
* token + token id (identifiers)
* token_length
* total_num_of_appearances

And for each taxonomic level ('phylum','class','order','family','genus','species','strain'):
* token_id
* tax_level 
* tax_name
* num_appearances

These will help us understand which tokens are widely shared (and not short, so not trivial) and which are clade specific.

### Collect statistics

Helper functions

In [20]:
# For each token, get its total count of appearances (a token may appear multiple times in a single sequence)
def get_token_appearances(encodings_cv):
    token_appearances = encodings_cv.sum(0).tolist()
    token_appearances = [i for sublist in token_appearances for i in sublist]
    return token_appearances

# For each token, get the number of unique sequences it appears in 
def get_token_appearances_uniq(encodings_cv):
    token_appearances_uniq = (encodings_cv != 0).sum(0).tolist()
    token_appearances_uniq = [i for sublist in token_appearances_uniq for i in sublist]
    return token_appearances_uniq

# For each token, collect clade-specific stats
# (Currently takes several minutes)
def get_clade_spec_stats(encodings_cv, column_names, calc_stats_flags, silva_db):
    nseqs = encodings_cv.shape[0]
    # Init' empty dataset
    token_taxon_stats = pd.DataFrame(columns=['token','tax_level','tax_name','n_occurrences']) 

    # Iterate over tokens, only if flag calc_stats_flags = True
    for i in range(0, len(column_names)):
        if i%100 == 0: print('.', end = '')
        if i%5000 == 0: print("Completed",i,"/",len(column_names),"tokens")
        if not calc_stats_flags[i]: continue
        current_token = column_names[i]
        seq_appearances_vec = encodings_cv.getcol(i)
        seq_appearances_vec = [j[0] for j in seq_appearances_vec.toarray()] # flatten
        # Get indices of relevant sequences
        seqs_ids = [j for j, val in enumerate(seq_appearances_vec) if val == 1] 
        # Get metadata rows for the sequences in above indices
        db_subset = silva_db.loc[seqs_ids, ]

        # Iterate over taxonomic levels to get level-specific stats
        for tax_level in ['phylum','class','order','family','genus','species','strain']:
            level_counts = db_subset.groupby(tax_level, as_index = False).agg({'raw_id':'count'})
            level_counts = level_counts[level_counts['raw_id'] > 4] # We remove entities with only few hits, to reduce noise
            level_counts.columns = ['tax_name', 'n_occurrences']
            level_counts['tax_level'] = tax_level
            level_counts['token'] = current_token
            #print(level_counts.head())
            token_taxon_stats = token_taxon_stats.append(level_counts)

    return token_taxon_stats

In [11]:
# Analyze bpe tokens (appearances stats)
bpe_token_appearances = get_token_appearances(all_encodings_cv_bpe)
# Sanity (should be same size as vocab size)
print(len(bpe_token_appearances))

bpe_token_appearances_uniq = get_token_appearances_uniq(all_encodings_cv_bpe)
# Sanity (should be same size as vocab size)
print(len(bpe_token_appearances_uniq))

15620
15620


In [12]:
# Analyze k-mer tokens (appearances stats)
kmer_token_appearances = get_token_appearances(all_encodings_cv_6mer)
# Sanity (should be same size as vocab size)
print(len(kmer_token_appearances))

kmer_token_appearances_uniq = get_token_appearances_uniq(all_encodings_cv_6mer)
# Sanity (should be same size as vocab size)
print(len(kmer_token_appearances_uniq))

15621
15621


In [13]:
# Organize stats into a single data frame - BPE
token_ids = [vocab_dict[t] for t in token_strings_bpe]
token_lengths = [len(t) for t in token_strings_bpe]
token_stats_bpe = pd.DataFrame(list(zip(token_strings_bpe, 
                                    token_ids,
                                    token_lengths, 
                                    bpe_token_appearances, 
                                    bpe_token_appearances_uniq)), 
                           columns =['token', 'token_id', 'token_length', 'total_appearances_mult_in_seq', 'total_seqs_in']) 
token_stats_bpe.head(10)

Unnamed: 0,token,token_id,token_length,total_appearances_mult_in_seq,total_seqs_in
0,a,0,1,20740,20626
1,aa,6,2,82693,72322
2,aaaa,326,4,49389,42937
3,aaaaac,4433,6,4242,4165
4,aaaaagc,4996,7,2373,2371
5,aaaaagg,9724,7,1528,1517
6,aaaaagu,11317,7,1253,1222
7,aaaaau,6347,6,2642,2635
8,aaaaaugacgguac,8190,14,1899,1899
9,aaaacagg,8369,8,1846,1843


In [33]:
# Same for k-mer
token_stats_6mer = pd.DataFrame(list(zip(token_strings_6mer, kmer_token_appearances, kmer_token_appearances_uniq)), 
                           columns =['token', 'total_appearances_mult_in_seq', 'total_seqs_in']) 
token_stats_6mer.head(10)

Unnamed: 0,token,total_appearances_mult_in_seq,total_seqs_in
0,aaaaaa,52262,38457
1,aaaaac,95941,89694
2,aaaaag,107978,97615
3,aaaaan,91,91
4,aaaaau,41319,38144
5,aaaaca,29827,27517
6,aaaacc,152322,130079
7,aaaacg,45603,39375
8,aaaacn,75,75
9,aaaacu,358216,289497


In [36]:
# We next want to collect stats about whether each token is clade-specific or not. 
# Optionally, to reduce runtime / focus on main tokens, we can calculate this for tokens with > 10 apearences only
print("For the BPE vocab, clade-specificity stats will be collected for",
      len(token_stats_bpe[token_stats_bpe.total_seqs_in >= 10]),
      "tokens instead of",len(token_stats_bpe))
print("For the k-mer vocab, clade-specificity stats will be collected for",
      len(token_stats_6mer[token_stats_6mer.total_seqs_in >= 10]),
      "tokens instead of",len(token_stats_6mer))

include_token_bpe = list(token_stats_bpe.total_seqs_in >= 10)
include_token_6mer = list(token_stats_6mer.total_seqs_in >= 10)

For the BPE vocab, clade-specificity stats will be collected for 15603 tokens instead of 15620
For the k-mer vocab, clade-specificity stats will be collected for 14325 tokens instead of 15621


In [18]:
# For each token, collect clade-specific stats
token_taxon_stats_bpe = get_clade_spec_stats(all_encodings_cv_bpe, token_strings_bpe, include_token_bpe, silva_db)

.Completed 0 / 15620 tokens
....................................................................................................Completed 10000 / 15620 tokens
........................................................

In [None]:
token_taxon_stats_6mer = get_clade_spec_stats(all_encodings_cv_6mer, token_strings_6mer, include_token_6mer, silva_db)

.Completed 0 / 15621 tokens
..................................................Completed 5000 / 15621 tokens
..................................................Completed 10000 / 15621 tokens
...........................................

In [32]:
token_stats_6mer.shape

(15620, 4)

In [23]:
# Save to file
token_taxon_stats_bpe.to_csv(path_or_buf = "C:/Users/efrat/Documents/DNA_BERT_Data/tokens/tokens_15K_tax_level_specific_stats.tsv", sep='\t')

In [None]:
token_taxon_stats_6mer.to_csv(path_or_buf = "C:/Users/efrat/Documents/DNA_BERT_Data/tokens/tokens_6mers_tax_level_specific_stats.tsv", sep='\t')

To label tokens as clade-specific, we apply the following logic, per taxonomic level (e.g. phyla):
- for each clade (e.g. each phylum) we count the number of sequences in that clade in which the token appeared (once or more)
- if more than 99% of the sequences containing the token belong to a single clade, we call the taxon clade-specific (we allow 3% for other clades to account for the noisy nature of the data, and the expected annotation mistakes).

In [24]:
def get_summarized_token_taxon_stats(token_stats, token_taxon_stats):
    # We'll take the max count per clade
    token_taxon_sum = token_taxon_stats.groupby(['token', 'tax_level'], as_index = False).agg({'n_occurrences':'max'})
    token_taxon_sum.columns = ['token', 'tax_level', 'max_occurrences_specific_clade']

    # Re-organize the table
    token_taxon_sum = token_taxon_sum.pivot_table(index=['token'], 
                                                  columns='tax_level', 
                                                  values='max_occurrences_specific_clade').reset_index()
    token_taxon_sum.columns = [c if c=="token" else "max_in_single_"+c for c in token_taxon_sum.columns]

    # Merge with the other stats
    token_stats2 = pd.merge(token_stats, token_taxon_sum, on='token')

    # And now we compute the max % of appearances in a single clade
    token_stats2['max_perc_in_single_class'] = token_stats2['max_in_single_class'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_family'] = token_stats2['max_in_single_family'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_genus'] = token_stats2['max_in_single_genus'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_species'] = token_stats2['max_in_single_species'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_phylum'] = token_stats2['max_in_single_phylum'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_order'] = token_stats2['max_in_single_order'] * 100.0 / token_stats2['total_seqs_in']
    token_stats2['max_perc_in_single_strain'] = token_stats2['max_in_single_strain'] * 100.0 / token_stats2['total_seqs_in']

    # And label each token as clade-specific or not...
    token_stats2['clade_specific_class'] = token_stats2['max_perc_in_single_class'] > 99
    token_stats2['clade_specific_family'] = token_stats2['max_perc_in_single_family'] > 99
    token_stats2['clade_specific_genus'] = token_stats2['max_perc_in_single_genus'] > 99
    token_stats2['clade_specific_species'] = token_stats2['max_perc_in_single_species'] > 99
    token_stats2['clade_specific_phylum'] = token_stats2['max_perc_in_single_phylum'] > 99
    token_stats2['clade_specific_order'] = token_stats2['max_perc_in_single_order'] > 99
    token_stats2['clade_specific_strain'] = token_stats2['max_perc_in_single_strain'] > 99

    token_stats2['clade_specific_any'] = (token_stats2['clade_specific_strain'] | 
                token_stats2['clade_specific_species'] |
                token_stats2['clade_specific_genus'] |
                token_stats2['clade_specific_family'] |
                token_stats2['clade_specific_order'] |
                token_stats2['clade_specific_class'] |
                token_stats2['clade_specific_phylum'])

    # One additional statistic
    token_stats2['avg_occur_in_seq'] = token_stats2['total_appearances_mult_in_seq'] / token_stats2['total_seqs_in']
    
    return(token_stats2)

In [25]:
# Summarize - BPE stats 
token_stats_bpe2 = get_summarized_token_taxon_stats(token_stats_bpe, token_taxon_stats_bpe)

# Save
token_stats_bpe2.to_csv(path_or_buf = "C:/Users/efrat/Documents/DNA_BERT_Data/tokens/tokens_15K_tax_level_summarized_stats.tsv", sep='\t')
token_stats_bpe2.head()

Unnamed: 0,token,token_id,token_length,total_appearances_mult_in_seq,total_seqs_in,max_in_single_class,max_in_single_family,max_in_single_genus,max_in_single_order,max_in_single_phylum,...,max_perc_in_single_strain,clade_specific_class,clade_specific_family,clade_specific_genus,clade_specific_species,clade_specific_phylum,clade_specific_order,clade_specific_strain,clade_specific_any,avg_occur_in_seq
0,a,0,1,20740,20626,3741.0,1399.0,1554.0,1524.0,6348.0,...,,False,False,False,False,False,False,False,False,1.005527
1,aa,6,2,82693,72322,9796.0,3390.0,4973.0,3358.0,18807.0,...,,False,False,False,False,False,False,False,False,1.1434
2,aaaa,326,4,49389,42937,6057.0,2134.0,3164.0,2415.0,13196.0,...,,False,False,False,False,False,False,False,False,1.150267
3,aaaaac,4433,6,4242,4165,1319.0,465.0,305.0,468.0,2370.0,...,,False,False,False,False,False,False,False,False,1.018487
4,aaaaagc,4996,7,2373,2371,404.0,188.0,218.0,205.0,918.0,...,,False,False,False,False,False,False,False,False,1.000844


In [None]:
token_stats_bpe2[token_stats_bpe2['clade_specific_any']]

In [None]:
# Same for 8-mers (TODO)

### Analyze

(1) Token length vs. number of sequences appearing in

In [None]:
plt.figure(figsize=(9,5))
plt.rc('font', size=12)
plt.scatter(x = token_stats2['total_seqs_in'], 
            y = token_stats2['token_length'], 
            c = token_stats2['avg_occur_in_seq'], 
            alpha = 0.2,
            cmap = 'viridis')
plt.colorbar(label='Average occurrences per sequence')
plt.xlabel('Total sequences token is contained in')
plt.ylabel('Token length')
plt.show()

(2) numbers of clade-specific tokens

In [None]:
tmp = pd.melt(token_stats2, id_vars=['token'], 
              value_vars=['max_perc_in_single_class', 'max_perc_in_single_family', 
                          'max_perc_in_single_genus', 'max_perc_in_single_species',
                          'max_perc_in_single_phylum', 'max_perc_in_single_order',
                          'max_perc_in_single_strain'])
tmp['variable2'] = tmp['variable'].replace({'max_perc_in_single_':'Number of clade-specific tokens: '}, regex=True)
tmp = tmp[tmp['value'] > 97]
tmp.groupby('variable2', as_index = False)['value'].count()

In [None]:
fig = plt.figure(1, figsize=(6, 5))
ax = fig.add_subplot(111)
bp = ax.boxplot([token_stats2.loc[token_stats2['clade_specific_any'], ]['token_length'],
                 token_stats2.loc[~token_stats2['clade_specific_any'], ]['token_length']], patch_artist=True)
ax.set_xticklabels(['Clade specific tokens', 'Other tokens'])
ax.set_ylabel('Token length')