# Color scheme for trees

These functions create color schemes for phylogenetic trees that are then visualized with the ggtree R package (see Figures_code.R).

In [None]:
import argparse
import os
from colour import Color
import seaborn as sns
import distinctipy
import pandas as pd
import string

def create_color_dict(taxa_file_name):
    """Calculate a red-to-blue color gradient for an ordered list of taxa.

    When the taxa labels are drawn in a phylogenetic tree in the listed order,
    the colors assigned here form a smooth gradient.

    Parameters
    ----------
    taxa_file_name : str
        Path to a text file with one taxon name per line, in the desired
        order. Blank lines are ignored.

    Returns
    -------
    dict
        Mapping ``taxon name -> '#rrggbb'`` (lower-case, always 6 hex digits).
        The first taxon gets red and the last one blue. An empty file yields
        an empty dict.
    """
    # read the ordered taxa list; the ``with`` block closes the file itself
    with open(taxa_file_name) as taxa_file:
        # skip blank lines (e.g. a trailing empty line) so they do not take a
        # slot in the gradient
        list_taxa = [line.strip('\n') for line in taxa_file if line.strip()]

    # colour's range_to() cannot build a zero-step gradient
    if not list_taxa:
        return {}

    # one color per taxon from red to blue; get_hex_l() always returns the
    # long '#rrggbb' form, whereas get_hex() may shorten it to '#rgb'
    gradient = Color("red").range_to(Color("blue"), len(list_taxa))
    return {taxon: color.get_hex_l() for taxon, color in zip(list_taxa, gradient)}

def create_color_dict_many(taxa_file_names):
    """Calculate per-clade color gradients for taxa labels in a phylogenetic tree.

    Each clade gets its own distinct base color (from ``distinctipy``). Within
    a clade, the taxa go from a pale, desaturated variant of that color
    (luminance 0.8, saturation 0.3) to the full base color, in file order.

    Parameters
    ----------
    taxa_file_names : list of str
        Paths of text files, one per user-defined clade. Each file lists the
        clade's taxon names, one per line, in the desired order. Blank lines
        are ignored.

    Returns
    -------
    dict
        Mapping ``taxon name -> '#rrggbb'``.
        - With exactly one file, the result of ``create_color_dict`` for that
          file is returned (a red-to-blue gradient).
        - With no files, an empty dict is returned.
        - A taxon listed in several files keeps the color from the last file.
    """
    if not taxa_file_names:
        return {}
    # a single clade uses the plain red-to-blue gradient
    if len(taxa_file_names) == 1:
        return create_color_dict(taxa_file_names[0])

    # get_colors returns one visually distinct color per requested clade
    colors = distinctipy.get_colors(len(taxa_file_names), pastel_factor=0)

    dict_color = {}
    for taxa_file_name, base_rgb in zip(taxa_file_names, colors):
        print(taxa_file_name)
        with open(taxa_file_name) as taxa_file:
            list_taxa = [line.strip('\n') for line in taxa_file if line.strip()]
        # colour's range_to() cannot build a zero-step gradient
        if not list_taxa:
            continue
        full_color = Color(rgb=base_rgb)
        pale_color = Color(rgb=base_rgb, luminance=0.8, saturation=0.3)
        gradient = pale_color.range_to(full_color, len(list_taxa))
        # get_hex_l() always yields the long '#rrggbb' form
        dict_color.update(
            (taxon, color.get_hex_l()) for taxon, color in zip(list_taxa, gradient)
        )
    return dict_color

def lexstrings(max_length: int, alphabet=string.ascii_lowercase):
    """Yield all strings over *alphabet* of length 0..max_length.

    Strings come in lexicographic order with respect to the order of
    *alphabet*. The empty string comes first, and every string is immediately
    followed by all of its extensions. For ``max_length=2, alphabet="ab"``
    the output is "", "a", "aa", "ab", "b", "ba", "bb".

    This is a generator; wrap it in ``list()`` to get a list.
    """
    yield ""
    if max_length != 0:
        shorter = max_length - 1
        for head in alphabet:
            yield from (head + tail for tail in lexstrings(shorter, alphabet=alphabet))

In [None]:
# Create the dictionary with colors for taxa names.
# The input file lists, one per line, the paths of per-clade taxa files that
# are passed to create_color_dict_many. Blank lines are skipped so that a
# trailing empty line is not treated as a file name (open('') would fail).
file_taxa_name = "../data/color_order_16_orf1b.txt"
with open(file_taxa_name) as file_taxa:
    lists_taxa = [x.strip('\n') for x in file_taxa if x.strip()]
dict_color = create_color_dict_many(lists_taxa)

In [None]:
# Add colors (plus 'num' and 'code' helper columns) to the whole-genome
# metadata table and save it as a new CSV.
df_wg = pd.read_csv("../data/MAV_wg_metadata.csv")
ids = list(df_wg['GBAC'])

# Convert taxon labels to GBAC IDs: every '_' becomes '/', except the
# underscore of an 'NC_' prefix, which is kept.
# NOTE(review): presumably the tree labels replaced '/' in GBAC IDs with '_'
# and 'NC_' marks RefSeq accessions; confirm against the label format.
dict_color_sep = {}
for key, color in dict_color.items():
    key_new = '/'.join(key.replace('NC_', 'NC.').split('_')).replace('NC.', 'NC_')
    dict_color_sep[key_new] = color
# records that received no color are drawn in black
for rec_id in ids:
    dict_color_sep.setdefault(rec_id, '#000000')

df_color = pd.DataFrame(list(dict_color_sep.items()), columns=['GBAC', 'color_orf1b16'])
# left join: colors for labels not present in the metadata are dropped
df_wg = df_wg.set_index('GBAC').join(df_color.set_index('GBAC'))
df_wg['num'] = range(1, len(df_wg) + 1)

# lexstrings(4, "abcde") yields "" followed by 780 short codes; skip the "".
# max_length stays at 4 because changing it would reorder every code.
codes = list(lexstrings(4, alphabet="abcde"))[1:len(df_wg) + 1]
if len(codes) < len(df_wg):
    raise ValueError(
        f"Only {len(codes)} codes available for {len(df_wg)} metadata rows; "
        "enlarge the lexstrings() alphabet or max_length."
    )
df_wg['code'] = codes
df_wg.to_csv("../data/MAV_wg_metadata_upd.csv")