# Comparative genomics of DIR

How conserved are the proteins near iodate reductase (IriA?) across the phylogeny of IriA and IriA/AioA-like proteins?



#### Dependencies

The following must be installed and in your path:
- ncbi-genome-download: https://github.com/kblin/ncbi-genome-download
- HMMER (provides `hmmsearch`)
- perl
- cd-hit
- FastTree
- muscle

#### Methods

1. Download genomes
  - Download genomes from a list of RefSeq/GenBank accessions
2. Define phylogeny
  - Identify HMM hit in each genome
  - Generate phylogenetic tree of HMM hits
  - Define groups within the tree ("clades")
3. Define gene neighborhood composition
  - Obtain genes within +/- 10 positions of HMM hits ("gene neighborhoods")
  - Group proteins from gene neighborhoods by sequence similarity ("subfamilies")

#### Analysis

- Which subfamilies are conserved in which clades?
- What are the functions of those subfamilies?

In [None]:
import subprocess as sp
import numpy as np
import pandas as pd
import os
import csv
import re

# Placeholder path: replace with the project root. The relative './data/...'
# paths used throughout the notebook are resolved against this directory.
os.chdir('/path/to/your/directory')

import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none' # Editable SVG text
# IPython magic (not plain Python): draw figures inside the notebook
%matplotlib inline

# Download assemblies from list of RefSeq/GenBank accessions

RefSeq accessions start with "GCF_", while GenBank accessions start with "GCA_". If both accessions are available for a given genome, use the RefSeq accession.


Only a subset of the available files is downloaded (protein FASTA, feature table, and assembly report).

In [None]:
# Toggle for the download step: nothing below runs unless this is True
download = False 
path_to_genomes = './data/genomes/'
if download == True:

    # Download
    # One pass per database; the accession list is read from
    # <path_to_genomes><db>-accessions.txt
    for db in ['refseq', 'genbank']:
        accessions_to_download = path_to_genomes + db + '-accessions.txt'
        print(accessions_to_download)
        # IPython shell escape: $db and $accessions_to_download are filled in
        # from the Python variables of the same name
        !ncbi-genome-download -s $db -F 'protein-fasta,features,assembly-report' \
        -A $accessions_to_download -m ./data/genomes/metadata.csv -o ./data/genomes all
    
    # Unzip any gzipped files
    # NOTE(review): path_to_genomes already ends with '/', so this glob
    # contains a double slash ('./data/genomes//*/*/*/*gz')
    path_to_downloads = path_to_genomes + '/*/*/*/*gz'
    print(path_to_downloads)
    !gunzip $path_to_downloads


# Define phylogeny

### Identify HMM hit in each genome

In [None]:
# Define input files

hmm = './data/hmm/combined_iriA_aioA.hmm'
threshold = 640  # bit-score cutoff handed to hmmsearch via -T


#######

import os

# Search every genome and record sequence and ID of hits

path_to_hmm = './data/hmm/'
tab = path_to_hmm + 'iriA-hits.out'   # hmmsearch --tblout table, overwritten for every genome
txt = path_to_hmm + 'iriA-hits.txt'   # scratch file holding the gene ID being extracted
hits = path_to_hmm + 'iriA-hits.faa'  # FASTA that accumulates every hit sequence

# Perl one-liner: the first file argument is a list of IDs, the second a FASTA
# file; records whose ID is in the list are printed. Kept as a raw string so
# that '\S' reaches perl unchanged.
extract_by_id = r"if(/^>(\S+)/){$c=$i{$1}}$c?print:chomp;$i{$_}=1 if @ARGV"

# Hits are appended below, so start from an empty FASTA
if os.path.exists(hits):
    os.remove(hits)

hmmhits = {} # genome accession : list of gene IDs that hit the HMM
path_to_hits = {} # genome accession : directory (only genomes with at least one hit)


# For each genome in each database
for db in ['refseq/', 'genbank/']:
    for domain in os.listdir(path_to_genomes + db): # 'bacteria/', 'archaea/'
        for accession in os.listdir(path_to_genomes + db + domain):
            # Search all protein-fasta files
            directory = path_to_genomes + db + domain + '/' + accession + '/'
            for protein_fasta in [directory + x for x in os.listdir(directory) if '.faa' in x]:

                # Hmmsearch. Arguments are passed as a list (no shell), so
                # paths containing spaces or shell metacharacters are safe.
                status = sp.call(['hmmsearch', '--noali', '--tblout', tab,
                                  '-T', str(threshold), hmm, protein_fasta])
                if status != 0:
                    # Without this check the table left over from the previous
                    # genome would be parsed as if it belonged to this one.
                    print('hmmsearch failed for ' + protein_fasta + '; skipping')
                    continue

                # Record hits: the first whitespace-delimited field of every
                # non-comment line is the target (gene) ID
                hit_geneids = []
                with open(tab, 'r') as fh:
                    for line in fh:
                        if line.strip() and not line.startswith('#'):
                            hit_geneids.append(line.split()[0])

                # Record every hit of this genome. Previously only the last
                # hit was stored, and a genome without hits picked up the
                # gene ID left over from the genome searched before it.
                hmmhits[accession] = list(hit_geneids)

                for geneid in hit_geneids:

                    # Append hits to .faa file
                    with open(txt, 'w') as fh:
                        fh.write(geneid)

                    with open(hits, 'a') as out:
                        sp.call(['perl', '-ne', extract_by_id, txt, protein_fasta], stdout=out)

                    # Record path for future use
                    path_to_hits[accession] = directory

In [None]:
### Keep only genomes that have a recorded directory, so that later lookups
### in path_to_hits cannot point at a directory that was never stored
orphaned = [genome_id for genome_id in hmmhits if genome_id not in path_to_hits]
for genome_id in orphaned:
    hmmhits.pop(genome_id)

# One surviving genome ID per line
with open('iodate_reducing_genomes.txt', 'w') as fh:
    fh.writelines(genome_id + '\n' for genome_id in hmmhits)

# Sanity check: both dictionaries should now have the same number of entries
for mapping in (hmmhits, path_to_hits):
    print(len(mapping))
print('Both printed numbers should be equal, if they are not, the code will break!')

### Generate phylogenetic tree of HMM hits

In [None]:
faa_ingroup = hits
faa_outgroup = './data/tree/aioA_bigtree.faa' # Outgroup to idrA proteins
temp = './data/tree/temp.faa'
faa = './data/tree/iriA-all.faa' 
# Both `faa` and `temp` are rewritten from scratch below, so the cell can be
# re-run without deleting them first.

# Add outgroup. '>' (overwrite) instead of '>>' (append): appending stacked a
# second copy of every sequence into `faa` each time the cell was re-run.
concat = ' '.join(['cat', faa_outgroup, faa_ingroup, '>', faa])
sp.call(concat, shell=True)

# Format headers to be compatible with ete3 package:
# drop '(' and ')', turn ':' into '-', and remove the first '*' on each line.
# Raw string so that '\*' reaches sed unchanged.
format_headers = ' '.join(["sed 's|(||g'", faa, r"| sed 's|)||g' | sed 's|:|-|g' | sed 's|\*||' >", temp])
sp.call(format_headers, shell=True)

# Remove empty entries and duplicate entries from FASTA
drop_empties_and_duplicates =' '.join(['sh scripts/clean_fasta_entries.sh', temp, faa])
sp.call(drop_empties_and_duplicates, shell=True)

# Align
aln = './data/tree/iriA.aln'
align = ' '.join(['muscle', '-in', faa, '-out', aln])
sp.call(align, shell=True)

# Trim the alignment to, say, where >80% of positions aren't gaps
filt = 80
f_aln = './data/tree/iriA-'+ str(filt) + '.aln'
filter_gaps = ' '.join(['python3', './scripts/remove-gapped-positions.py', '-p', str(filt), '-i', aln, '-o', f_aln])
sp.call(filter_gaps, shell=True)

# Tree
# NOTE(review): the tree is built from the untrimmed alignment `aln`; the
# trimmed alignment `f_aln` produced above is not used anywhere in this cell.
# Confirm which one is intended before changing it, since it alters the tree.
tree = './data/tree/iriA.nwk' 
fasttree = ' '.join(['fasttree','-boot 10000', aln, ">", tree])
sp.call(fasttree, shell=True)

In [None]:
### Creates a dictionary with accession and names so the tree can be fully annotated later###
# For a header such as '>WP_0001.1 some protein [Genus species]' the accession
# is 'WP_0001.1' and the name is the text inside the first pair of brackets.
names_list = []
accession_list = []
with open('./data/tree/iriA-all.faa', 'r') as fh:
    for line in fh:
        if not line.startswith('>'):
            continue
        header = line[1:].strip()
        if not header:
            continue  # a bare '>' carries no identifier

        # First whitespace-delimited token; stripping above keeps a trailing
        # newline out of the accession when the header has no description
        accession = header.split(None, 1)[0]

        # Fall back to the accession when the header has no '[...]' part
        # (previously this raised IndexError)
        _, bracket, remainder = header.partition('[')
        name = remainder.partition(']')[0] if bracket else accession

        accession_list.append(accession)
        names_list.append(name)


##This is for the pruned tree...remove if needed
asc_names = dict(zip(accession_list, names_list))

### Plot phylogenetic tree

In [None]:
# Create Circular rooted tree
from ete3 import Tree, NodeStyle, TreeStyle, TextFace, PhyloTree

# `tree` is the path of the Newick file; it is parsed with ete3 format code 0
t = Tree(tree, format=0)
#t = PhyloTree(tree, format=0)

def style_tree(t, branch_support_text=False, branch_support_dots=True, dots_on_leaves=False):
    """Assign a node style to every node of *t* and return a new TreeStyle.

    Parameters:
        t: tree whose nodes are styled in place (every node visited by
            ``t.traverse()`` receives a style).
        branch_support_text: value stored in ``show_branch_support`` of the
            returned TreeStyle.
        branch_support_dots: when True, nodes get a size-14 circle coloured by
            support: black for >= 0.99, grey for >= 0.90, light grey for
            >= 0.80. Lower support gets an invisible (size 0) circle.
        dots_on_leaves: when False, leaves always get the invisible circle.

    Returns:
        The TreeStyle, with ``show_leaf_name`` set to True.
    """

    def _circle(size, colour):
        # Build one circular NodeStyle of the given size and colour
        style = NodeStyle()
        style["shape"] = "circle"
        style["size"] = size
        style["fgcolor"] = colour
        return style

    # Tree style
    layout = TreeStyle()
    layout.show_leaf_name = True
    layout.show_branch_support = branch_support_text

    # Node styles, ordered from highest to lowest support cutoff
    invisible = _circle(0, "black")
    support_tiers = [
        (0.99, _circle(14, 'black')),
        (0.90, _circle(14, '#939598')),  # grey
        (0.80, _circle(14, '#E6E7E8')),  # light grey
    ]

    for node in t.traverse():
        chosen = invisible

        if branch_support_dots == True:
            # First cutoff that the support value reaches wins
            for cutoff, tier_style in support_tiers:
                if node.support >= cutoff:
                    chosen = tier_style
                    break

            # Leaves are left without a dot unless explicitly requested
            if dots_on_leaves == False and node.is_leaf() == True:
                chosen = invisible

        node.set_style(chosen)

    return layout

# Set outgroup from FAA
# http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#tree-rooting
with open(faa_outgroup) as fh:
    # ID of outgroups used for tree construction
    # (text between '>' and the first space of every header line)
    outgroup = [line.strip().split(' ')[0].split('>')[1] for line in fh.readlines() if ">" in line]
# Root on the common ancestor of the first and third outgroup sequences.
# NOTE(review): the hard-coded index 2 requires at least three sequences in
# the outgroup file; fewer raises IndexError.
ancestor = t.get_common_ancestor(outgroup[0],outgroup[2])
t.set_outgroup(ancestor)

### Get three or so outgroups from the aioA tree. HMM should be more inclusive outgroup further out ###


# Ladderize (sort tree)
t.ladderize()

#Check if in dictionary
def checkKey(dict, key): 
    """Return ``dict[key]`` when *key* is present, otherwise ``None``."""
    return dict.get(key)

# Record order of node leaves
# `order` keeps the original node names (IDs) in postorder; each of those
# nodes is then renamed to the matching value in asc_names.
order = []
for node in t.traverse('postorder'):
    if len(node.name) > 5: # Remove non-name nodes
        ID = node.name.strip()
        # Returns None when the ID is not a key of asc_names, in which case
        # the node name below becomes None
        ID_new = checkKey(asc_names, ID) # 'KPDAEFOI_00687'
        order.append(ID)
        node.name = ID_new #ID + ' ' + ID_new
# Original ID -> position in the postorder traversal
dict_order = dict(zip(order,range(len(order))))




# Style and plot
ts = style_tree(t, branch_support_text=False, branch_support_dots=True)
# Built-in leaf names are switched off; labels are added as TextFaces below
ts.show_leaf_name  = False
ts.mode = "c"
ts.arc_start = 0 # 0 degrees = 3 o'clock
ts.arc_span = 360


label_size = 20
for node in t.traverse():
    if node.is_leaf():

        # Show leaf labels
        name_face = TextFace(node.name, fgcolor="black", fsize=label_size, tight_text=True, ftype='Arial')
        node.add_face(name_face, column=0, position='branch-right')

#Shading
# Each block colours the background of the common ancestor of two named leaves.
# The names are hard-coded and must match leaf names exactly as they appear in
# the tree after the renaming step.
nst1 = NodeStyle()
nst1["bgcolor"] = "lightgreen"
n1 = t.get_common_ancestor("Halobiforma lacisalsi", "Halorubrum kocurii")
n1.set_style(nst1)

nst2 = NodeStyle()
nst2["bgcolor"] = "orchid"
n2 = t.get_common_ancestor("Dehalococcoidia bacterium", "Vibrio")
n2.set_style(nst2)

nst3 = NodeStyle()
nst3["bgcolor"] = 'silver'
n3 = t.get_common_ancestor("Microvirga guangxiensis", "Pyrobaculum ferrireducens")
n3.set_style(nst3)

ts.scale =  250
#ts.title.add_face(TextFace("Phylogeny of Iodate Reductases and Arsenite Oxidases", fsize=20))
# Re-root on this leaf name just before drawing.
# NOTE(review): a name string is passed rather than a node object — confirm
# the installed ete3 version accepts that.
t.set_outgroup('Pseudomonas aeruginosa PAO1 NasC')
t.show(tree_style=ts)
something = t.render('aioA_tree_publication.png', dpi=300, tree_style=ts)
# The bare string below is an expression statement and has no effect
"%%inline"

### Define groups within tree based on phylogenetic distance

In [None]:
import numpy as np


def cache_distances(tree):
    """Return a dict mapping every node under *tree* to its distance from *tree*.

    *tree* itself maps to 0. Each descendant maps to its own ``dist`` plus the
    value already stored for its parent (``up``); the preorder traversal
    guarantees the parent is stored first.

    Reference: https://www.biostars.org/p/97409/
    """
    distance_from_root = {tree: 0}
    for descendant in tree.iter_descendants('preorder'):
        parent_distance = distance_from_root[descendant.up]
        distance_from_root[descendant] = descendant.dist + parent_distance
    return distance_from_root

def collapse(tree, min_dist):
    """Collapse internal nodes whose tips lie, on average, closer than *min_dist*.

    An internal node is collapsed when the mean distance from it to the tips
    beneath it is below *min_dist*. A collapsed node is renamed to
    ``n<counter>_<number of tips>``, flagged with the feature
    ``collapsed=True``, marked so its descendants are not drawn, and finally
    has its children detached from the tree.

    Reference: https://www.biostars.org/p/97409/

    Parameters:
        tree: tree to modify in place.
        min_dist: threshold on the mean node-to-tip distance.

    Returns:
        ``(tree, mapping)``. The mapping takes each tip name to the name of the
        first collapsed node (in preorder) that contains it. Every entry of
        the module-level list ``order`` that has no mapping is mapped to
        itself.
    """
    tip_to_group = {}  # tip name : collapsed node name

    # Cache tip content and root distances once instead of re-traversing
    tips_below = tree.get_cached_content()
    dist_from_root = cache_distances(tree)

    collapsed_so_far = 0
    for clade in tree.get_descendants('preorder'):
        if clade.is_leaf():
            continue

        tips = tips_below[clade]
        mean_tip_distance = np.mean([dist_from_root[tip] - dist_from_root[clade]
                                     for tip in tips])
        if not mean_tip_distance < min_dist:
            continue

        # Rename the node after its running number and its tip count
        group_name = "n" + str(collapsed_so_far) + '_' + str(len(tips))
        collapsed_so_far += 1
        clade.name = group_name

        # The outermost collapsed node wins: preorder visits it first and
        # setdefault never overwrites an existing entry
        for tip in tips:
            if tip.is_leaf():
                tip_to_group.setdefault(tip.name, group_name)

        # Label, and hide descendants when displayed with tree.show()
        clade.add_features(collapsed=True)
        clade.img_style['draw_descendants'] = False

    # Remove the children of every collapsed node from the tree
    for collapsed_node in tree.search_nodes(collapsed=True):
        for child in collapsed_node.get_children():
            child.detach()

    # Tips that were not collapsed form their own group
    for leaf_id in order:
        tip_to_group.setdefault(leaf_id, leaf_id)

    return tree, tip_to_group
          
    
# Work on a copy so the full tree `t` is left untouched
tcoll = t.copy()                
tcoll, tree_node_groups = collapse(tcoll, 0.4) # Adjust value to desired group size
tcoll.ladderize() # Order by increasing node partition
# Rebuild the tree from its own Newick text
tcoll = Tree(tcoll.write())

# NOTE(review): `ts` is assigned here but is not passed to render() below
ts = style_tree(tcoll, branch_support_text=False, branch_support_dots=True)
tcoll.render("%%inline")

# Define neighborhood composition

### Obtain genes within +/- 10 positions of HMM hits ("gene neighborhoods")

In [None]:
import os
import subprocess as sp
import numpy as np
import pandas as pd

proximity = 10

def filter_by_relative_proximity(df):
    """Keep the genes of one scaffold that lie within `proximity` genes of the hit.

    Reads two module-level names: ``hit`` (index label of the reference gene)
    and ``proximity`` (number of genes kept on each side).

    Coordinates are re-expressed relative to the hit so that the hit always
    reads left to right: for a '+' hit they are offsets from its start; for a
    '-' hit they are negated offsets from its end and strands are flipped.
    Genes are then numbered by position ('Rel #': negative before the hit,
    0 for the hit, positive after it).

    Parameters:
        df: DataFrame of one scaffold, indexed so that ``hit`` is a row label,
            with 'Start Coord', 'End Coord' and 'Strand' columns. Columns are
            added to it in place.

    Returns:
        The rows, sorted by 'Rel Start Coord', whose 'Rel #' lies in
        [-proximity, proximity].
    """
    # Coordinates must be integers for the arithmetic below
    for column in ('Start Coord', 'End Coord'):
        df[column] = df[column].astype(int)
    ref_start = int(df.at[hit, 'Start Coord'])
    ref_end = int(df.at[hit, 'End Coord'])

    # Baseline values for the hit itself
    df.at[hit, 'Rel Start Coord'] = 0
    df.at[hit, 'Rel Strand'] = '+'

    hit_strand = df.at[hit, 'Strand']

    if hit_strand == "+":
        # Top strand: offsets from the hit's start, strands unchanged
        df['Rel Start Coord'] = df['Start Coord'] - ref_start
        df['Rel End Coord'] = df['End Coord'] - ref_start
        df['Rel Strand'] = df['Strand'].map({"+": "+", "-": "-"})

    elif hit_strand == "-":
        # Bottom strand: the hit's end is its true start, so mirror the
        # coordinates around it and flip every strand
        df['Rel Start Coord'] = -1.0 * (df['End Coord'] - ref_end)
        df['Rel End Coord'] = -1.0 * (df['Start Coord'] - ref_end)
        df['Rel Strand'] = df['Strand'].map({"-": "+", "+": "-"})

    # Number genes by position relative to the hit
    df = df.sort_values(by='Rel Start Coord')
    rel_start = df['Rel Start Coord']
    after_hit = rel_start > 0
    before_hit = rel_start < 0
    df.loc[after_hit, 'Rel #'] = range(1, 1 + int(after_hit.sum()))
    df.loc[before_hit, 'Rel #'] = list(range(-int(before_hit.sum()), 0))
    df.at[hit, 'Rel #'] = 0

    # Keep only the window around the hit
    within_window = df['Rel #'].between(-1 * proximity, proximity)
    return df[within_window]


# NCBI feature-table column name -> column name used in this notebook
rename_columns = {'genomic_accession':'Scaffold ID',
                    'start': 'Start Coord',
                    'end' : 'End Coord',
                    'strand' : 'Strand',
                    'product_accession' : 'Gene ID',
                    'locus_tag' : 'Locus Tag',
                    'name' : 'Gene Product Name'}

# One neighborhood table per HMM hit. They are concatenated once after the
# loop, so no hit can overwrite another and the cost stays linear.
neighborhoods = []

for genome in sorted(hmmhits):

    # Locate this genome's files. The names are reset for every genome so that
    # a missing file is detected instead of silently reusing the file name
    # found for the previous genome.
    feature_table = None
    assembly_report = None
    proteins = None
    for FILE in os.listdir(path_to_hits[genome]):
        if genome not in FILE or "lock" in FILE:
            continue
        if "_feature_table.txt" in FILE:
            feature_table = FILE
        if "_assembly_report.txt" in FILE:
            assembly_report = FILE
        if "protein.faa" in FILE:
            proteins = FILE

    if feature_table is None or assembly_report is None:
        print('Skipping ' + genome + ': feature table or assembly report not found')
        continue

    # Extract data from genome*assembly_report
    genome_name = ''
    strain = ''
    genbank = ''
    refseq = ''
    genome_length = 0
    metagenomic = False

    # Read-only access is sufficient; 'r+' needlessly required write permission
    with open(path_to_hits[genome] + assembly_report, 'r', encoding="utf-8") as fh:

        for line in fh:

            if line.startswith('#'):
                # Name
                key = '# Organism name:  '
                if key in line:
                    line = line.strip().split(key)[1]
                    if r' (' in line:
                        line = line.split(r' (')[0]
                    genome_name = line

                # Strain (the last matching header line in the file wins)
                elif '# Infraspecific name:  strain=' in line:
                    strain = line.strip().split('# Infraspecific name:  strain=')[1]
                elif '# Isolate:  ' in line:
                    strain = line.strip().split('# Isolate:  ')[1]
                elif '# Assembly name:  ' in line:
                    strain = line.strip().split('# Assembly name:  ')[1]

                # Genbank
                elif '# GenBank assembly accession: ' in line:
                    genbank = line.strip().split('# GenBank assembly accession: ')[1]

                # Refseq
                elif '# RefSeq assembly accession: ' in line:
                    refseq = line.strip().split('# RefSeq assembly accession: ')[1]

            # Sequence length: ninth tab-separated field of every data line
            else:
                fields = line.strip().split('\t')
                if len(fields) <= 8:
                    continue  # blank or truncated line
                val = fields[8]
                if val == 'na':
                    val = 0
                genome_length += int(val)

    # Extract genome information from genome*_feature_table.txt
    gen = pd.read_csv(path_to_hits[genome] + feature_table, sep='\t')
    gen = gen[(gen['# feature'] == 'CDS') & (gen['class'] != 'without_protein')]
    gen = gen.rename(columns=rename_columns)
    gen = gen.loc[:, list(rename_columns.values())]

    num_cds = len(gen.index)
    gen['Genome ID'] = [genome] * num_cds
    gen['Genome Name'] = [genome_name] * num_cds
    gen['Strain'] = [strain] * num_cds
    gen['Genbank Accession'] = [genbank] * num_cds
    gen['Refseq Accession'] = [refseq] * num_cds
    gen['Gene Count'] = [num_cds] * num_cds
    gen['Genome Length (bp)'] = [genome_length] * num_cds

    gen = gen.set_index('Gene ID')
    gen['Gene ID'] = gen.index.tolist()

    # From genome data, keep data from genes with +/- N genes to gene on same contig/scaffold
    # NOTE: the loop variable must be called `hit` and stay at module level,
    # because filter_by_relative_proximity() reads it as a global.
    for hit in hmmhits[genome]:

        if hit not in gen.index:
            print('Skipping hit ' + str(hit) + ' of ' + genome + ': not a CDS in the feature table')
            continue

        # Scaffold of the hit. A Gene ID can occur more than once in the
        # table; selecting with a list always yields a Series, from which the
        # first instance is taken.
        scaffold = gen.loc[[hit], 'Scaffold ID'].iloc[0]

        # Create dataframe of only genes on same scaffold as hit
        scaf = gen[gen['Scaffold ID'] == scaffold].copy()
        scaf = scaf.drop_duplicates(subset='Gene ID')

        # Filter by relative proximity
        scaf = filter_by_relative_proximity(scaf)

        neighborhoods.append(scaf)

if not neighborhoods:
    raise RuntimeError('No gene neighborhoods could be extracted for any genome')

df = pd.concat(neighborhoods)


# Both a GenBank and RefSeq accession are found in the Genome IDs, remove the Genbank accession
accessions = set(df['Genome ID'].tolist())

# One row per genome, giving its paired GenBank / RefSeq assembly accessions
one_row_per_genome = df[~df['Refseq Accession'].isnull()].drop_duplicates('Genome ID')
genbank = one_row_per_genome['Genbank Accession'].tolist()
refseq = one_row_per_genome['Refseq Accession'].tolist()
paired_accessions = list(zip(genbank, refseq))

# A GenBank assembly is redundant when its RefSeq twin was analysed as well
redundant_accessions = [gbk for gbk, ref in paired_accessions
                        if ref in accessions and gbk in accessions]

# Select rows by Genome ID. Dropping by index label instead would also remove
# rows of other genomes, because the index holds Gene IDs and the same Gene ID
# can occur in more than one genome.
df = df[~df['Genome ID'].isin(redundant_accessions)]
        
# Extract protein sequences from gene neighborhoods to new file
id_file = 'temp.txt'
hit_file = "./data/subfamilies/neighborhoods.faa"

# Sequences are appended below, so start from an empty file
if os.path.exists(hit_file):
    os.remove(hit_file)

# Perl one-liner: the first file argument is a list of IDs, the second a FASTA
# file; records whose ID is in the list are printed. Raw string so that '\S'
# reaches perl unchanged.
extract_by_id = r"if(/^>(\S+)/){$c=$i{$1}}$c?print:chomp;$i{$_}=1 if @ARGV"

for genome in sorted(hmmhits):

    genes = [str(x) for x in df[df['Genome ID'] == genome]['Gene ID'].tolist()]
    if not genes:
        continue  # no neighborhood rows for this genome, nothing to extract

    # Protein FASTA of this genome. Reset per genome so that a missing file is
    # not replaced by the previous genome's file; endswith() ignores leftover
    # compressed copies such as '*_protein.faa.gz'.
    proteins = None
    for FILE in os.listdir(path_to_hits[genome]):
        if FILE.endswith("protein.faa") and genome in FILE:
            proteins = FILE
    if proteins is None:
        print('Skipping ' + genome + ': protein FASTA not found')
        continue

    with open(id_file,'w') as fh: #makes list of genes
        fh.write('\n'.join(genes))

    # gets sequences from list of genes; arguments are passed as a list, so no
    # shell is involved and the output is appended through the file handle
    with open(hit_file, 'a') as out:
        sp.call(['perl', '-ne', extract_by_id, id_file, path_to_hits[genome] + proteins],
                stdout=out, env=os.environ)



### Save hits onto a searchable spreadsheet

In [None]:
# Write the neighborhood table to CSV. The path is a placeholder: replace it
# with a real directory before running.
df.to_csv('/path/to/your/directory/hit_neighborhoods.csv')

### Group proteins from gene neighborhoods by sequence similarity ("subfamilies")

Subfamilies numbered by abundance

In [None]:
#Set variables
# NOTE(review): neither variable is referenced by the code visible in this
# notebook, and the clustering command itself is not shown here. The next cell
# reads its results from
# './data/subfamilies/neighborhoods_clean.faa_proteinClustering'.
seqs ="./data/subfamilies/neighborhoods.faa"
mmseq_output = './data/subfamilies/mmseq' 

### Create a readable tab separated file with protein subfamilies

In [None]:
# Remember the current directory so we can return to it afterwards. The old
# code changed to the relative placeholder 'path/to/your/directory', which
# cannot be resolved from inside the clustering directory.
project_dir = os.getcwd()
os.chdir('./data/subfamilies/neighborhoods_clean.faa_proteinClustering')

try:
    # Copy the table without its header line
    with open('orf2subfamily.tsv') as f, open("orf2subfamily_clean.tsv",'w') as f1:
        next(f, None) # skip header line (tolerates an empty file)
        for line in f:
            f1.write(line)

    # Gene ID (first column) -> subfamily label (second column). Only the part
    # of the label between its first and second 'm' is kept.
    geneid_to_subfamily = {}
    with open('orf2subfamily_clean.tsv') as fh:
        reader = csv.reader(fh, delimiter='\t')
        for line in reader:
            if len(line) < 2:
                continue  # blank or malformed row
            geneid = line[0]
            subfamily = line[1].split(sep='m')[1]
            geneid_to_subfamily[geneid] = subfamily
finally:
    # Always go back, even if reading the table failed
    os.chdir(project_dir)

df['mmseq'] = df['Gene ID'].map(geneid_to_subfamily)
# Subfamilies ordered from most to least abundant; the rank becomes the label
cdhit_sorted = df.groupby('mmseq').count().sort_values('Scaffold ID', ascending=False).index.tolist()
df['Subfamily'] = df['mmseq'].map(dict(zip(cdhit_sorted, range(len(cdhit_sorted))))).astype(str)
df

# Analysis

### What are the functions of these subfamilies?

In [None]:
def _mode_or_nan(values):
    """Return the most frequent non-null value of *values*, or NaN if there is none."""
    modes = values.mode()
    return modes.iloc[0] if len(modes) else np.nan


# Most common product name of each subfamily
subfam_prod = df.groupby('Subfamily')['Gene Product Name'].agg(_mode_or_nan)
subfamily_to_product = dict(zip(subfam_prod.index.tolist(),subfam_prod.tolist()))

# From here on 'Gene Count' is 1 per row, so that counting rows counts genes
df['Gene Count'] = [1] * len(df.index)
df['Tree Node Group'] = df['Gene ID'].map(tree_node_groups)

# Only genes that are tree leaves got a group above; spread that group to
# every gene on the same scaffold
df['Tree Node Group'] = df['Scaffold ID'].map(dict(zip(df[~df['Tree Node Group'].isnull()]['Scaffold ID'].tolist(),
                                                       df[~df['Tree Node Group'].isnull()]['Tree Node Group'].tolist())))

subfams = pd.DataFrame(df.groupby(['Subfamily']).count()['Gene Count'])

subfams['Mode Gene Product Name'] = subfams.index.map(subfamily_to_product)
subfams = subfams.sort_values(by='Gene Count', ascending=False)

# Genes without a subfamily carry the string label 'nan'
subfams = subfams[subfams.index != 'nan']

# Write with an explicit path rather than changing the working directory:
# every later './data/...' path relies on the project root being current.
subfams.to_excel('./data/subfamily_analysis.xlsx')

### Which subfamilies are conserved in each clade?

In [None]:
import numpy as np
plt.rcParams['svg.fonttype'] = 'none' # Editable SVG text

def add_polygon(coordinates):
    """Draw one gene as an arrow-shaped polygon on the current axes.

    Reference: https://nickcharlton.net/posts/drawing-animating-shapes-matplotlib.html

    Parameters:
        coordinates: sequence ``(x1, x2, y, h, color, strand)`` giving the left
            and right x positions, the vertical centre, the height, the fill
            colour, and the strand. A strand of '+' puts the arrow tip on the
            right; any other value puts it on the left.

    Returns:
        The value returned by ``add_patch`` on the current axes.
    """
    x1, x2, y, h, color, strand = coordinates

    top = y + h/2
    bottom = y - h/2

    # The tip is a third of the height long, but never longer than the gene
    gene_length = abs(x1-x2)
    tip = gene_length if h/3 > gene_length else h/3

    # Pull the corners on the pointed side inwards so the centre point sticks out
    if strand == '+':
        left_corner, right_corner = x1, x2 - tip
    else:
        left_corner, right_corner = x1 + tip, x2

    outline = [[left_corner, bottom],   # left bottom
               [x1, y],                 # left center
               [left_corner, top],      # left top
               [right_corner, top],     # right top
               [x2, y],                 # right center
               [right_corner, bottom]]  # right bottom

    arrow = plt.Polygon(outline, fc=color, edgecolor='black', linewidth=1, alpha=0.75)
    axes = plt.gca()
    return axes.add_patch(arrow)

def set_relative_coordinates(df, central_gene, proximity=10):
    """Express gene coordinates relative to *central_gene* and keep its neighbors.

    Coordinates are rewritten so that the central gene always reads left to
    right: for a '+' central gene they are offsets from its start; for a '-'
    central gene they are negated offsets from its end and every strand is
    flipped. Genes are then numbered by position ('Rel #': negative before the
    central gene, 0 for the central gene, positive after it).

    The window is applied to 'Rel #' (a gene count), in line with
    filter_by_relative_proximity(). The earlier version compared the
    base-pair columns 'Rel Start Coord' / 'Rel End Coord' against the gene
    count, which discarded nearly every neighbor.

    Parameters:
        df: DataFrame of one scaffold whose index contains *central_gene*
            exactly once, with 'Start Coord', 'End Coord' and 'Strand'
            columns. Columns are added to it in place.
        central_gene: index label of the reference gene.
        proximity: number of genes to keep on each side (default 10).

    Returns:
        The rows, sorted by 'Rel Start Coord', whose 'Rel #' lies in
        [-proximity, proximity].
    """
    # Calculate proximity to reference gene for other genes on scaffold
    df['Start Coord'] = df['Start Coord'].astype(int)
    df['End Coord'] = df['End Coord'].astype(int)
    startcoord_ref = int(df.at[central_gene, 'Start Coord'])
    endcoord_ref = int(df.at[central_gene, 'End Coord'])

    # Baseline values for the central gene; they survive only when its strand
    # is neither '+' nor '-', in which case just the central gene is returned
    df.loc[central_gene, 'Rel Start Coord'] = 0
    df.loc[central_gene, 'Rel Strand'] = '+'

    # Depending on strand (top=+, bottom=-)
    sign = df.at[central_gene, 'Strand']

    if sign == "+":  # Subtract all coords from initial start coord
        df['Rel Start Coord'] = df['Start Coord'] - startcoord_ref
        df['Rel End Coord'] = df['End Coord'] - startcoord_ref
        df['Rel Strand'] = df['Strand'].map({"+": "+", "-": "-"})

    elif sign == "-":  # -1 * Subtract all coords from initial end coord (true start)
        df['Rel Start Coord'] = -1.0 * (df['End Coord'] - endcoord_ref)
        df['Rel End Coord'] = -1.0 * (df['Start Coord'] - endcoord_ref)
        df['Rel Strand'] = df['Strand'].map({"-": "+", "+": "-"})  # Flip strands

    # Assign Rel # below or above gene. After sorting, upstream genes come
    # first, so they are numbered -n .. -1 and downstream genes 1 .. m.
    df = df.sort_values(by='Rel Start Coord')
    downstream = (df['Rel Start Coord'] > 0).to_numpy()
    upstream = (df['Rel Start Coord'] < 0).to_numpy()

    rel_number = np.full(len(df), np.nan)
    rel_number[downstream] = np.arange(1, int(downstream.sum()) + 1)
    rel_number[upstream] = np.arange(-int(upstream.sum()), 0)
    df['Rel #'] = rel_number
    df.loc[central_gene, 'Rel #'] = 0

    # Filter by proximity, counted in genes (rows left at NaN are dropped)
    return df[df['Rel #'].between(-1 * proximity, proximity)]

def plot_gene_clusters(df, ordered_scaffolds, label_subfams=False, saveas='svg'):
    """Draw the gene neighborhood of every scaffold as a row of arrows.

    Parameters:
        df: DataFrame indexed by gene with the columns 'Scaffold ID',
            'Gene ID', 'Genome Name', 'Rel Start Coord', 'Rel End Coord',
            'Rel Strand', 'Color' and, when labels are requested, 'Subfamily'.
        ordered_scaffolds: scaffold IDs in the order they are drawn, top to
            bottom.
        label_subfams: when true, write each gene's subfamily above it (genes
            whose subfamily is the string 'nan' get a blank label).
        saveas: accepted for compatibility; it is not used by this function
            and no file is written.

    Returns:
        The value returned by ``plt.show()``.
    """

    # Spacing of plot
    y = 0 # initial y-coordinate
    h = 700 # height of genes
    spacing_vertical = 3.25 # aspect ratio between scaffolds

    # Plot genes in Cluster N (0 is likely key genes) arranged by coordinates
    plt.figure(figsize=(20,len(ordered_scaffolds)))

    for scaffold in ordered_scaffolds:

        y = y - h * spacing_vertical # assign new y-coordinate

        # Set relative coordinates for subset
        scaf = df[df['Scaffold ID'] == scaffold].copy()
        scaf = scaf.drop_duplicates('Gene ID')

        # Plot genome text (first 40 characters of the first row's name).
        # .iloc[0] selects by position explicitly: a bare [0] on a
        # label-indexed Series depends on a deprecated positional fallback.
        genome = str(scaf['Genome Name'].iloc[0][0:40])
        plt.text(15000, y, genome, size=20, verticalalignment='center')

        # Plot ea. gene in scaffold
        for GENE in scaf.index.tolist():

            # Gene (coordinates are halved to compress the x axis)
            x1 = scaf.at[GENE,'Rel Start Coord'] / 2
            x2 = scaf.at[GENE,'Rel End Coord'] / 2
            strand = scaf.at[GENE,'Rel Strand']
            color = scaf.at[GENE, 'Color']
            coordinates = [x1, x2, y, h, color, strand]

            if label_subfams:
                # Label, placed above the gene and slightly right of its start
                label_x = x1 + 0.1 * abs(x1 - x2)
                label_y = y + h * spacing_vertical / 2
                if scaf.at[GENE,'Subfamily'] != 'nan':
                    label_text = scaf.at[GENE,'Subfamily']
                else:
                    label_text = ' '
                plt.text(label_x, label_y, label_text, size=12, rotation=45, verticalalignment='center')
            add_polygon(coordinates)

    plt.axis('scaled')
    plt.axis('off')

    # Draw a 0-1 gradient image only to obtain a 'Purples' colorbar as legend
    mat = np.linspace(0, 1, 256)
    mat = np.vstack((mat, mat))
    plt.imshow(mat, aspect='auto', cmap='Purples')
    plt.colorbar(orientation='horizontal', pad = -0.03).ax.tick_params(labelsize=22)
    plt.text(s='Gene Frequency in Selected Genomes', x=0, y=-1)

    return plt.show()

# Color subfamilies by presence in genomes
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

total_genomes = len(set(df['Genome ID'].tolist()))
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
# Matplotlib 3.9; the colormap is resampled to one level per genome
cmap = plt.get_cmap('Purples', total_genomes)

# Gene count of each subfamily scaled to the most abundant subfamily (0-1)
norm = pd.DataFrame(subfams['Gene Count'] / subfams['Gene Count'].max())

norm.columns = ['Count']
norm['Color'] = norm['Count'].apply(cmap)
subfam_to_color = dict(zip(norm.index.tolist(), norm['Color'].tolist()))
#### Assign color for nan
subfam_to_color['nan'] = 'white'

df['Tree Order'] = df['Gene ID'].map(dict_order)

df['Color'] = df['Subfamily'].map(subfam_to_color)

# The working directory is deliberately left unchanged: nothing is written to
# disk here, and './data/subfamilies' does not resolve once an earlier cell
# has moved into './data/'. Give savefig a full path if the figure is saved.
plot_gene_clusters(df=df, 
                   ordered_scaffolds=df.sort_values(by='Tree Order').loc[:,'Scaffold ID'].drop_duplicates().tolist(),
                   label_subfams=True, saveas='svg')

#plt.savefig("subfamily_clusters", format='svg')