In [28]:
import os
import pandas as pd
import numpy as np 


from Bio import Entrez
from Bio import SeqIO


from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import fcluster

from sklearn.cluster import AgglomerativeClustering


import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm


pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

In [29]:
def find_mash_dist_file(folder):
    folder = os.path.abspath(folder)
    files = os.listdir(folder)
    mash_file = [file for file in files if 'mash.distances.tab' in file][0]
    mash_file_path = os.path.join(folder, mash_file)
    return mash_file_path
    

In [34]:
mash_file = find_mash_dist_file('/home/pjsola/TMP/mashclust_test/')

In [35]:
mash_file

'/home/pjsola/TMP/mashclust_test/K10339_conda_52.coverage_adapted_filtered_80_term.mash.distances.tab'

In [42]:
def mash_dist_to_pairwise(distance_file, distance_type='hash_distance'):
    df = pd.read_csv(distance_file, sep='\t', names=['reference_ID', 'query_ID', 'distance', 'p_value', 'shared_hashes'])
    df[['hash_1', 'hash_2']] = df['shared_hashes'].str.split('/', expand=True)
    df.hash_1 = df.hash_1.astype(float)
    df.hash_2 = df.hash_2.astype(float)
    df['hash_distance'] = 1 - (df.hash_1 / df.hash_2)
    dfpair = df[['reference_ID', 'query_ID', distance_type]]
    
    return dfpair

In [86]:
def pairwise_to_matrix(pwdf):
    dist_matrix = pwdf.groupby(['reference_ID', 'query_ID']).mean().unstack()
    dist_matrix = dist_matrix.droplevel(0, axis=1)
    return dist_matrix

In [46]:
pairwise_distance = mash_dist_to_pairwise(mash_file)

In [47]:
pairwise_distance.head()

Unnamed: 0,reference_ID,query_ID,hash_distance
0,NZ_CP026852.1,NZ_CP026852.1,0.0
1,NZ_CM004622.1,NZ_CP026852.1,0.996
2,NZ_CP035364.1,NZ_CP026852.1,1.0
3,NZ_CP035347.1,NZ_CP026852.1,1.0
4,NZ_LR025100.1,NZ_CP026852.1,0.997


In [87]:
distamce_matrix = pairwise_to_matrix(pairwise_distance)

In [88]:
distamce_matrix.head()

query_ID,CP029217.1,LR134132.1,NZ_CM004622.1,NZ_CM017034.1,NZ_CM017091.1,NZ_CM017179.1,NZ_CP011334.1,NZ_CP015133.1,NZ_CP015501.1,NZ_CP018443.1,NZ_CP020850.1,NZ_CP023947.1,NZ_CP024516.1,NZ_CP026852.1,NZ_CP027151.1,NZ_CP027154.1,NZ_CP027156.1,NZ_CP027163.1,NZ_CP028818.1,NZ_CP029222.1,NZ_CP031580.1,NZ_CP034282.1,NZ_CP035347.1,NZ_CP035364.1,NZ_CP036325.1,NZ_CP041935.1,NZ_CP042869.1,NZ_CP044032.1,NZ_CP045282.1,NZ_LR025100.1,NZ_LT904874.1,WMHT01000002.1
reference_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1
CP029217.1,0.0,0.996,1.0,1.0,1.0,1.0,1.0,1.0,0.999,1.0,1.0,0.999,0.999,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.999,1.0,1.0,1.0,0.999,1.0,1.0,1.0,1.0,1.0,0.999
LR134132.1,0.996,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
NZ_CM004622.1,1.0,1.0,0.0,1.0,0.563,0.476,0.674,1.0,1.0,0.998,0.989,0.973,0.981,0.996,1.0,1.0,0.995,1.0,0.998,1.0,1.0,0.978,1.0,0.999,1.0,1.0,1.0,0.998,0.997,0.998,1.0,1.0
NZ_CM017034.1,1.0,1.0,1.0,0.0,1.0,1.0,0.913,1.0,1.0,1.0,1.0,0.987,0.989,1.0,1.0,1.0,0.997,1.0,1.0,1.0,1.0,0.989,1.0,1.0,1.0,0.994,1.0,1.0,0.999,1.0,1.0,1.0
NZ_CM017091.1,1.0,1.0,0.563,1.0,0.0,0.207,0.882,1.0,1.0,1.0,1.0,0.973,0.981,1.0,1.0,1.0,0.989,1.0,1.0,1.0,1.0,0.978,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [52]:
pairwise_distance.columns[2]

'hash_distance'

In [66]:
def pairwise_to_cluster(pw,threshold = 0.5):
    groups = {}
    sorted_df = pw[(pw[pw.columns[0]] != pw[pw.columns[1]]) & (pw[pw.columns[2]] <= threshold)].sort_values(by=[pw.columns[2]])
    
    def rename_dict_clusters(cluster_dict):
        reordered_dict = {}
        for i, k in enumerate(list(cluster_dict)):
            reordered_dict[i] = cluster_dict[k]
        return reordered_dict
    
    def regroup_clusters(list_keys, groups_dict, both_samples_list):
        #sum previous clusters
        list_keys.sort()
        new_cluster = sum([groups_dict[key] for key in list_keys], [])
        #add new cluster
        cluster_asign = list(set(new_cluster + both_samples_list))
        #Remove duped cluster
        first_cluster = list_keys[0]
        groups_dict[first_cluster] = cluster_asign
        rest_cluster = list_keys[1:]
        for key in rest_cluster:
            del groups_dict[key]
        groups_dict = rename_dict_clusters(groups_dict)
        return groups_dict
        
    for index, row in sorted_df.iterrows():
        group_number = len(groups)
        cluster_name = 'cluster_' + str(group_number)

        sample_1 = str(row[0])
        sample_2 = str(row[1])
        both_samples_list = row[0:2].tolist()
                
        if group_number == 0:
            groups[group_number] = both_samples_list
        
        all_samples_dict = sum(groups.values(), [])
                
        if sample_1 in all_samples_dict or sample_2 in all_samples_dict:
            #extract cluster which have the new samples
            key_with_sample = {key for (key,value) in groups.items() if (sample_1 in value or sample_2 in value)}
            
            cluster_with_sample = list(key_with_sample)
            cluster_with_sample_name = cluster_with_sample[0]
            number_of_shared_clusters = len(key_with_sample)
            if number_of_shared_clusters > 1:
                groups = regroup_clusters(cluster_with_sample, groups, both_samples_list)
            else:
                groups[cluster_with_sample_name] = list(set(groups[cluster_with_sample_name] + both_samples_list))
        else:
            groups[group_number] = both_samples_list
            
    for index, row in pw[(pw[pw.columns[0]] != pw[pw.columns[1]]) & (pw[pw.columns[2]] > threshold)].iterrows():
        sample_1 = str(row[0])
        sample_2 = str(row[1])
        all_samples_dict = sum(groups.values(), [])
        if sample_1 not in all_samples_dict:
            group_number = len(groups)
            groups[group_number] = [sample_1]
        
        if sample_2 not in all_samples_dict:
            group_number = len(groups)
            groups[group_number] = [sample_2]
            
    cluster_df = pd.DataFrame(groups.values(),index=list(groups))
    
    cluster_df_return = cluster_df.stack().droplevel(1).reset_index().rename(columns={'index': 'group', 0: 'id'})
            
    return cluster_df_return

In [96]:
clusters = pairwise_to_cluster(pairwise_distance,threshold = 0.5)

In [97]:
clusters.head()

Unnamed: 0,group,id
0,0,NZ_CP028818.1
1,0,NZ_LR025100.1
2,0,NZ_CP018443.1
3,0,NZ_CP044032.1
4,1,NZ_CP023947.1


In [98]:
clusters.groupby('group')['id'].apply(list).reset_index(name='accs')

Unnamed: 0,group,accs
0,0,"[NZ_CP028818.1, NZ_LR025100.1, NZ_CP018443.1, NZ_CP044032.1]"
1,1,"[NZ_CP023947.1, WMHT01000002.1, NZ_CP024516.1, NZ_CP034282.1]"
2,2,"[NZ_CM017091.1, NZ_CM004622.1, NZ_CM017179.1]"
3,3,"[NZ_CP015501.1, NZ_CP041935.1]"
4,4,[NZ_CP026852.1]
5,5,[NZ_CP035364.1]
6,6,[NZ_CP035347.1]
7,7,[NZ_CP042869.1]
8,8,[NZ_CP045282.1]
9,9,[NZ_CP011334.1]


In [92]:
seqnames = ['NZ_CP023947.1', 'WMHT01000002.1','NZ_CP024516.1', 'NZ_CP034282.1']
distamce_matrix.filter(seqnames, axis=1).filter(seqnames, axis=0)

query_ID,NZ_CP023947.1,WMHT01000002.1,NZ_CP024516.1,NZ_CP034282.1
reference_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
NZ_CP023947.1,0.0,0.551,0.55,0.455
WMHT01000002.1,0.551,0.0,0.15,0.393
NZ_CP024516.1,0.55,0.15,0.0,0.333
NZ_CP034282.1,0.455,0.393,0.333,0.0


In [None]:
def acc_to_len_id(row):
    accession_number = row.aacc
    Entrez.email = "A.N.Other@example.com"
    try:
        handle = Entrez.efetch(db="nucleotide", id=accession_number, rettype="fasta", retmode="text")
        record = SeqIO.read(handle, "fasta")
        handle.close()
        #print("Downloaded: " + record.description)
        #print("Downloaded: " + str(len(record)))
        #return record
        description = ' '.join(record.description.split()[3:])
        species = ' '.join(record.description.split()[1:3])
        return len(record), species, description
    except:
        print(record.id + " failed to download")
        sys.exit(1)
    
    '''    
    #SeqIO.write(record, output_handle, "fasta")
    '''