In [4]:
import os
import pandas as pd
import numpy as np 


from Bio import Entrez
from Bio import SeqIO


from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import fcluster

from sklearn.cluster import AgglomerativeClustering


import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm


# Lift pandas' display limits so tables are rendered without truncating
# columns, rows or cell contents.
# NOTE(review): with max_rows=None, displaying a large frame prints every row.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

In [5]:
def find_mash_dist_file(folder):
    """Return the absolute path of the Mash distance table found in ``folder``.

    A file qualifies when its name contains ``'mash.distances.tab'``. If
    several files qualify, the alphabetically first one is returned so the
    result does not depend on the directory listing order.

    Parameters
    ----------
    folder : str
        Directory to search (not searched recursively).

    Returns
    -------
    str
        Absolute path of the matching file.

    Raises
    ------
    FileNotFoundError
        If no file name in ``folder`` contains ``'mash.distances.tab'``.
    """
    folder = os.path.abspath(folder)
    # os.listdir returns entries in arbitrary order; sort for a reproducible pick.
    candidates = sorted(name for name in os.listdir(folder) if 'mash.distances.tab' in name)
    if not candidates:
        raise FileNotFoundError("No 'mash.distances.tab' file found in " + folder)
    return os.path.join(folder, candidates[0])
    

In [67]:
# Locate the Mash distance table inside the sample's mapping folder.
# NOTE(review): absolute path is hard-coded; the notebook only runs where this folder exists.
mash_file = find_mash_dist_file('/processing_Data/antibioticos/mperezv/ANALYSIS/Polimixinas_OTA/plasmidID/NO_GROUP/8c/mapping')

In [68]:
# Show which distance file was picked up.
mash_file

'/processing_Data/antibioticos/mperezv/ANALYSIS/Polimixinas_OTA/plasmidID/NO_GROUP/8c/mapping/8c.coverage_adapted_filtered_80_term.mash.distances.tab'

In [6]:
def mash_dist_to_pairwise(distance_file, distance_type='hash_distance'):
    """Load a Mash distance table and return it in pairwise (long) form.

    The tab-separated file is read without a header as the columns
    ``reference_ID, query_ID, distance, p_value, shared_hashes``. The
    ``shared_hashes`` field (``"a/b"``) is split into ``hash_1`` and
    ``hash_2`` and ``hash_distance = 1 - hash_1 / hash_2`` is added.

    Parameters
    ----------
    distance_file : str
        Path of the tab-separated distance table.
    distance_type : str, default 'hash_distance'
        Name of the column to return as the third column
        (e.g. ``'hash_distance'`` or ``'distance'``).

    Returns
    -------
    pandas.DataFrame
        Columns ``reference_ID``, ``query_ID`` and ``distance_type``.
    """
    column_names = ['reference_ID', 'query_ID', 'distance', 'p_value', 'shared_hashes']
    mash_table = pd.read_csv(distance_file, sep='\t', names=column_names)

    # "shared/total" -> two numeric columns
    mash_table[['hash_1', 'hash_2']] = mash_table['shared_hashes'].str.split('/', expand=True)
    mash_table = mash_table.astype({'hash_1': float, 'hash_2': float})
    mash_table['hash_distance'] = 1 - (mash_table['hash_1'] / mash_table['hash_2'])

    selected = ['reference_ID', 'query_ID', distance_type]
    return mash_table[selected]

In [7]:
def pairwise_to_matrix(pwdf):
    """Pivot a pairwise distance table into a square-style matrix.

    Rows are indexed by ``reference_ID`` and columns by ``query_ID``; repeated
    (reference, query) pairs are averaged. The outer column level left by
    ``unstack`` (the name of the value column) is dropped.
    """
    mean_by_pair = pwdf.groupby(['reference_ID', 'query_ID']).mean()
    wide = mean_by_pair.unstack()
    return wide.droplevel(0, axis=1)

In [29]:
# Pairwise table with the default 'hash_distance' column.
pw = mash_dist_to_pairwise(mash_file)

In [30]:
# Preview the first rows of the pairwise table.
pw.head()

Unnamed: 0,reference_ID,query_ID,hash_distance
0,NZ_CP036436.1,NZ_CP036436.1,0.0
1,NC_032101.1,NZ_CP036436.1,1.0
2,NZ_CP018963.1,NZ_CP036436.1,0.833
3,NZ_CP032203.1,NZ_CP036436.1,1.0
4,NZ_CP035364.1,NZ_CP036436.1,1.0


In [32]:
# Count the rows that compare two different sequences (first column != second
# column) and whose distance (third column) is at most 0.5; only the shape
# of the filtered, distance-sorted table is displayed.
columns = pw.columns.tolist()
pw[(pw[columns[0]] != pw[columns[1]]) & (pw[columns[2]] <= 0.5)].sort_values(by=[columns[2]]).shape

(64, 3)

In [28]:
def big_pairwise_to_cluster(pw,threshold = 0.5):
    
    def rename_dict_clusters(cluster_dict):
        reordered_dict = {}
        for i, k in enumerate(list(cluster_dict)):
            reordered_dict[i] = cluster_dict[k]
        return reordered_dict
    
    def regroup_clusters(list_keys, groups_dict, both_samples_list):
        #sum previous clusters
        list_keys.sort()
        new_cluster = sum([groups_dict[key] for key in list_keys], [])
        #add new cluster
        cluster_asign = list(set(new_cluster + both_samples_list))
        #Remove duped cluster
        first_cluster = list_keys[0]
        groups_dict[first_cluster] = cluster_asign
        rest_cluster = list_keys[1:]
        for key in rest_cluster:
            del groups_dict[key]
        groups_dict = rename_dict_clusters(groups_dict)
        return groups_dict
    
    groups = {}
    
    with open(pw, "r") as f:
        for line in f:
            line_split = line.split('\t')
            sample_1 = line_split[0]
            sample_2 = line_split[1]
            hash1, hash2 = line_split[4].split('/')
            hash_distance = 1 - (int(hash1) / int(hash2))

            if hash_distance <= threshold:
                group_number = len(groups)

                both_samples_list = [sample_1,sample_2]

                if group_number == 0:
                    groups[group_number] = both_samples_list

                all_samples_dict = sum(groups.values(), [])

                if sample_1 in all_samples_dict or sample_2 in all_samples_dict:
                    #extract cluster which have the new samples
                    key_with_sample = {key for (key,value) in groups.items() if (sample_1 in value or sample_2 in value)}

                    cluster_with_sample = list(key_with_sample)
                    cluster_with_sample_name = cluster_with_sample[0]
                    number_of_shared_clusters = len(key_with_sample)
                    if number_of_shared_clusters > 1:
                        groups = regroup_clusters(cluster_with_sample, groups, both_samples_list)
                    else:
                        groups[cluster_with_sample_name] = list(set(groups[cluster_with_sample_name] + both_samples_list))
                else:
                    groups[group_number] = both_samples_list
            else:
                if sample_1 not in all_samples_dict:
                    group_number = len(groups)
                    groups[group_number] = [sample_1]

                if sample_2 not in all_samples_dict:
                    group_number = len(groups)
                    groups[group_number] = [sample_2]
            
    cluster_df = pd.DataFrame(groups.values(),index=list(groups))
    
    cluster_df_return = cluster_df.stack().droplevel(1).reset_index().rename(columns={'index': 'group', 0: 'id'})
            
    return cluster_df_return

In [52]:
# Cluster the sequences at hash distance <= 0.5. The distance file is the one
# already resolved into `mash_file` above, so the path is not repeated here.
clusters = big_pairwise_to_cluster(mash_file, threshold = 0.5)

In [53]:
# One row per group with the list of member ids (column 'accs').
clusters.groupby('group')['id'].apply(list).reset_index(name='accs')

Unnamed: 0,group,accs
0,0,"[NZ_CP036441.1, NZ_CP020075.1, NZ_CP036436.1]"
1,1,"[NZ_KX154765.1, NZ_CP041525.1, NZ_CP032203.1, NZ_CP018963.1, NC_032101.1]"
2,2,[NZ_CP035364.1]
3,3,[NZ_CP035347.1]
4,4,[NZ_LS998787.1]
5,5,[NZ_LS998788.1]
6,6,[NZ_CM017254.1]
7,7,[NZ_CP011334.1]
8,8,[NZ_CP010576.1]
9,9,"[NZ_CP033774.1, NZ_CP020499.1, NZ_CP041100.1, CP025817.1, NZ_LR130542.1, NZ_CP011577.1, NZ_CP021856.1, NZ_CP032209.1, NZ_CP026139.1]"


In [54]:
def calculate_seqlen(fasta_file):
    """Return the length of every record in a FASTA file.

    Parameters
    ----------
    fasta_file : str
        Path (or handle) accepted by ``SeqIO.parse`` in ``"fasta"`` format.

    Returns
    -------
    pandas.DataFrame
        One row per record, in file order, with columns ``id`` (record id)
        and ``length`` (``len`` of the record, stored as integers). An
        empty file yields an empty frame with the same two columns.
    """
    # Collect all records first and build the frame once; assigning rows one by
    # one with .loc re-allocates the frame on every record and leaves `length`
    # as an object column.
    records = [(seq_record.id, len(seq_record)) for seq_record in SeqIO.parse(fasta_file, "fasta")]
    len_df = pd.DataFrame(records, columns=['id', 'length'])
    return len_df
    

In [55]:
# FASTA file whose record lengths are measured below.
# NOTE(review): absolute path is hard-coded; presumably it holds the same sequences as the distance table — confirm.
fasta_file = '/processing_Data/antibioticos/mperezv/ANALYSIS/Polimixinas_OTA/plasmidID/NO_GROUP/8c/mapping/8c.coverage_adapted_filtered_80_term.fasta'

In [56]:
# Table of record id -> sequence length for the FASTA file.
len_df = calculate_seqlen(fasta_file)

In [57]:
# Preview the first rows of the length table.
len_df.head()

Unnamed: 0,id,length
0,NZ_CP036436.1,19305
1,NC_032101.1,58120
2,NZ_CP018963.1,54644
3,NZ_CP032203.1,41865
4,NZ_CP035364.1,13142


In [58]:
"""def acc_to_len_id(row):
    accession_number = row.id
    Entrez.email = "A.N.Other@example.com"
    try:
        handle = Entrez.efetch(db="nucleotide", id=accession_number, rettype="fasta", retmode="text")
        record = SeqIO.read(handle, "fasta")
        handle.close()
        #print("Downloaded: " + record.description)
        #print("Downloaded: " + str(len(record)))
        #return record
        description = ' '.join(record.description.split()[3:])
        species = ' '.join(record.description.split()[1:3])
        return len(record), species, description
    except:
        print(record.id + " No record in NCBI, using local data")
        return len(record), 'Unknown', 'Unknown'
        
    #SeqIO.write(record, output_handle, "fasta")
"""

'def acc_to_len_id(row):\n    accession_number = row.id\n    Entrez.email = "A.N.Other@example.com"\n    try:\n        handle = Entrez.efetch(db="nucleotide", id=accession_number, rettype="fasta", retmode="text")\n        record = SeqIO.read(handle, "fasta")\n        handle.close()\n        #print("Downloaded: " + record.description)\n        #print("Downloaded: " + str(len(record)))\n        #return record\n        description = \' \'.join(record.description.split()[3:])\n        species = \' \'.join(record.description.split()[1:3])\n        return len(record), species, description\n    except:\n        print(record.id + " No record in NCBI, using local data")\n        return len(record), \'Unknown\', \'Unknown\'\n        \n    #SeqIO.write(record, output_handle, "fasta")\n'

In [59]:
#clusters[['length', 'species', 'description']] = clusters.apply(acc_to_len_id, axis=1, result_type="expand")

In [60]:
# Attach the sequence length to every clustered id; the left join keeps every
# row of `clusters` even when an id has no entry in `len_df`.
final_cluster = clusters.merge(len_df, on='id', how='left')

In [61]:
# Preview the first rows of the cluster table with lengths.
final_cluster.head()

Unnamed: 0,group,id,length
0,0,NZ_CP036441.1,19305
1,0,NZ_CP020075.1,22062
2,0,NZ_CP036436.1,19305
3,1,NZ_KX154765.1,50800
4,1,NZ_CP041525.1,31500


In [62]:
# Longest sequence of each group: sort by group (ascending) and length
# (descending), then keep the first row of every group.
final_cluster.sort_values(by=['group', 'length'], ascending=[True, False]).groupby('group').head(1)

Unnamed: 0,group,id,length
1,0,NZ_CP020075.1,22062
7,1,NC_032101.1,58120
8,2,NZ_CP035364.1,13142
9,3,NZ_CP035347.1,10695
10,4,NZ_LS998787.1,27017
11,5,NZ_LS998788.1,18395
12,6,NZ_CM017254.1,40463
13,7,NZ_CP011334.1,2954
14,8,NZ_CP010576.1,35843
19,9,NZ_LR130542.1,202379


In [74]:
def extract_representative(row):
    """Remove the row's own ``id`` from its ``clustered`` list, in place.

    Nothing is returned; the list object held in ``row.clustered`` is mutated,
    so callers see the change through any other reference to that list.
    Raises ``ValueError`` (from ``list.remove``) if the id is not in the list.
    """
    members = row.clustered
    own_id = row.id
    members.remove(own_id)

In [75]:
def extract_length(row, final_cluster):
    """Look up the length of every id listed in ``row.clustered``.

    For each id, the first matching ``length`` value in ``final_cluster``
    (matched on its ``id`` column) is taken. The returned list follows the
    order of ``row.clustered``. An id absent from ``final_cluster`` raises
    ``IndexError``.
    """
    lengths = []
    for member_id in row.clustered:
        is_member = final_cluster.id == member_id
        matching_lengths = final_cluster['length'][is_member].tolist()
        lengths.append(matching_lengths[0])
    return lengths

In [76]:
def extract_distance(reprensetative, list_clustered, mash_file):
    """Return the hash distance from a representative to each clustered id.

    The distance file is read once. Only lines whose first field equals
    ``reprensetative`` and whose second field is one of ``list_clustered`` are
    used; the distance is ``1 - shared / total`` taken from the fifth
    tab-separated field (``"shared/total"``). If a pair occurs on several
    lines, the first occurrence is used. Lines with fewer than five fields are
    ignored.

    Parameters
    ----------
    reprensetative : str
        Id expected in the first field of the line.
    list_clustered : list of str
        Ids expected in the second field of the line.
    mash_file : str
        Path of the tab-separated Mash distance file.

    Returns
    -------
    list
        Exactly one entry per element of ``list_clustered``, in the same
        order: the hash distance as a float, or ``None`` when the file has no
        line for that (representative, id) pair. Keeping a placeholder keeps
        the result aligned with ``list_clustered``.
    """
    if not list_clustered:
        # Nothing to look up: do not touch the file at all.
        return []

    wanted = set(list_clustered)
    found = {}
    with open(mash_file, "r") as f:
        for line in f:
            line_split = line.rstrip('\n').split('\t')
            if len(line_split) < 5:
                continue
            sample_1 = line_split[0]
            sample_2 = line_split[1]
            if sample_1 != reprensetative or sample_2 not in wanted or sample_2 in found:
                continue
            hash1, hash2 = line_split[4].split('/')
            found[sample_2] = 1 - (int(hash1) / int(hash2))

    distances = [found.get(idclust) for idclust in list_clustered]
    return distances

In [82]:
def retrieve_fasta_cluster(fasta_file, final_cluster, output_dir, mash_file, kmerdist, quantity_id=1, save_clustered=False):
    """Write the representative sequences of each cluster and a summary table.

    For every group in ``final_cluster`` the ``quantity_id`` longest sequences
    are taken as representatives. Three files may be written to ``output_dir``,
    all named ``<fasta name without extension>.<kmerdist>.<suffix>``:

    * ``.representative.fasta`` - records of ``fasta_file`` whose id is a
      representative (always written);
    * ``.clustered.fasta`` - all remaining records (only if ``save_clustered``
      is truthy);
    * ``.clusters.tab`` - the summary table, tab-separated, without index.

    Parameters
    ----------
    fasta_file : str
        FASTA file the sequences are read from.
    final_cluster : pandas.DataFrame
        Table with at least the columns ``group``, ``id`` and ``length``.
    output_dir : str
        Directory for the output files; created if it does not exist.
    mash_file : str
        Mash distance file passed to ``extract_distance``.
    kmerdist : object
        Value embedded (via ``str``) in the output file names.
    quantity_id : int, default 1
        Number of representatives kept per group.
    save_clustered : bool, default False
        Whether to also write the non-representative records.

    Returns
    -------
    pandas.DataFrame
        The representative rows of ``final_cluster`` plus the columns
        ``clustered`` (other ids of the group), ``lengths_clustered`` and
        ``distance_clustered``.
    """
    input_file = os.path.abspath(fasta_file)
    # basename/splitext instead of manual '/' and '.' splitting: portable, and a
    # file name without extension keeps its name instead of becoming ''.
    prefix = os.path.splitext(os.path.basename(input_file))[0]

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    output_representative = os.path.join(output_dir, prefix + '.' + str(kmerdist) + '.representative.fasta')
    output_clustered = os.path.join(output_dir, prefix +  '.' + str(kmerdist) + '.clustered.fasta')
    output_summary = os.path.join(output_dir, prefix + '.' + str(kmerdist) + '.clusters.tab')

    # Longest `quantity_id` sequences of every group.
    representative_id = final_cluster.sort_values(by=['group', 'length'], ascending=[True, False]).groupby('group').head(quantity_id)
    summary_id_grouped = final_cluster.groupby('group')['id'].apply(list).reset_index(name='clustered')
    representative_ids = set(representative_id.id.tolist())
    representative_and_sumary = representative_id.merge(summary_id_grouped, on='group', how='left')
    # extract_representative mutates each row's `clustered` list in place to drop
    # the representative itself; its return value is intentionally ignored.
    representative_and_sumary.apply(extract_representative, axis=1)
    representative_and_sumary['lengths_clustered'] = representative_and_sumary.apply(lambda x: extract_length(x, final_cluster), axis=1)
    representative_and_sumary['distance_clustered'] = representative_and_sumary.apply(lambda x: extract_distance(x.id, x.clustered, mash_file), axis=1)

    # Split the FASTA records into representatives and everything else.
    representative_records = []
    clustered_records = []
    for seq_record in SeqIO.parse(fasta_file, "fasta"):
        if seq_record.id in representative_ids:
            representative_records.append(seq_record)
        else:
            clustered_records.append(seq_record)

    SeqIO.write(representative_records, output_representative, "fasta")

    if save_clustered:
        SeqIO.write(clustered_records, output_clustered, "fasta")

    representative_and_sumary.to_csv(output_summary, sep='\t', index=False)

    return representative_and_sumary

In [83]:
# Write the representative/clustered FASTA files and the summary table next to
# the input data, keeping one representative per group; 0.5 is only used as
# the label embedded in the output file names.
retrieve_fasta_cluster(fasta_file, final_cluster, '/processing_Data/antibioticos/mperezv/ANALYSIS/Polimixinas_OTA/plasmidID/NO_GROUP/8c/mapping', mash_file , 0.5, quantity_id=1, save_clustered=True)

Unnamed: 0,group,id,length,clustered,lengths_clustered,distance_clustered
0,0,NZ_CP020075.1,22062,"[NZ_CP036441.1, NZ_CP036436.1]","[19305, 19305]","[0.16400000000000003, 0.16500000000000004]"
1,1,NC_032101.1,58120,"[NZ_KX154765.1, NZ_CP041525.1, NZ_CP032203.1, NZ_CP018963.1]","[50800, 31500, 41865, 54644]","[0.562, 0.635, 0.474, 0.531]"
2,2,NZ_CP035364.1,13142,[],[],[]
3,3,NZ_CP035347.1,10695,[],[],[]
4,4,NZ_LS998787.1,27017,[],[],[]
5,5,NZ_LS998788.1,18395,[],[],[]
6,6,NZ_CM017254.1,40463,[],[],[]
7,7,NZ_CP011334.1,2954,[],[],[]
8,8,NZ_CP010576.1,35843,[],[],[]
9,9,NZ_LR130542.1,202379,"[NZ_CP033774.1, NZ_CP020499.1, NZ_CP041100.1, CP025817.1, NZ_CP011577.1, NZ_CP021856.1, NZ_CP032209.1, NZ_CP026139.1]","[184336, 178294, 194739, 101557, 130719, 146168, 161618, 86594]","[0.20699999999999996, 0.247, 0.247, 0.5349999999999999, 0.479, 0.41400000000000003, 0.33299999999999996, 0.712]"
