In [7]:
# libraries
import gurobipy as gp
import time
import json
from collections import defaultdict
from ete3 import Tree
import numpy as np
from pysam import FastaFile,FastxFile
import re
import torch
from torch.utils.data import Dataset
import pickle


In [8]:
# pytorch dataset to save the extant data
class AjMat_Dataset(Dataset):
    def __init__(self,adj_mat,seq_name,seq_binary):
        self.adj_mat = adj_mat
        self.seq_name = seq_name
        self.seq_binary = seq_binary

    def __len__(self):
        return len(self.adj_mat)
    def __getitem__(self, idx):
        return self.adj_mat[idx],self.seq_name[idx],self.seq_binary[idx]


# pytorch dataset to save the extant data
class AjMat_lean_Dataset(Dataset):
    def __init__(self,seq_name,seq_binary):
        self.seq_name = seq_name
        self.seq_binary = seq_binary

    def __len__(self):
        return len(self.seq_name)
    def __getitem__(self, idx):
        return self.seq_name[idx],self.seq_binary[idx]

class IndelsInfo:
    def __init__(self,AjMat_Dataset,fasta_file,nwk_file_path,folder_location,tree_name,AjMat_lean_Dataset):

        self.AjMat_Dataset = AjMat_Dataset
        self.input_file = fasta_file
        self.nwk_file_path = nwk_file_path
        self.ancestor_list = []
        self.tree_neighbor_dict = defaultdict(list)
        self.ancestor_info = []
        self.sequence_length = 0
        self.Extant_AdjMat_dataset = AjMat_Dataset
        self.folder_location = folder_location
        self.tree_name = tree_name
        self.AjMat_lean_Dataset = AjMat_lean_Dataset
        self.extant_dict = {}


    # create node types for each position for each sequences
    def create_node_type(self, seq_fwd_pog,seq_rvs_pog,seq_name):
        node_type_dict = defaultdict(list)
        node_type_dict[(seq_name,'start')] = [0]
        node_type_dict[(seq_name,'end')] = [self.sequence_length - 1]

        for n in range(1,self.sequence_length - 1):
            if n in seq_fwd_pog.keys() and n in seq_rvs_pog.keys(): #if node has forward and backward
                node_type_dict[(seq_name,'fwd_back_pos')] += [n]
            elif n in seq_fwd_pog.keys() and n not in seq_rvs_pog.keys():
                node_type_dict[(seq_name,'fwd_pos')] += [n]
            elif n not in seq_fwd_pog.keys() and n in seq_rvs_pog.keys():
                node_type_dict[(seq_name,'back_pos')] += [n]
            else:
                node_type_dict[(seq_name,'dead_pos')] += [n]
        return node_type_dict

    # function to find next position that is filled
    def next_pos(self,str1,curr_pos,seq_len):
        start_pos = curr_pos + 1

        while(start_pos < len(str1)):
            if str1[start_pos] != '-':
                return start_pos
            else:
                start_pos = start_pos + 1
        return seq_len

    # function to convert a sequence to adj matrix
    def convert_to_adj_mat(self,seq_str):
        seq_len = len(seq_str)
        aj_mat_array = np.zeros((seq_len,seq_len))

        next_filled = []
        ind = 0

        while(ind < seq_len - 1):
            if seq_str[ind] != '-':
                curr_ind = ind
                ind = self.next_pos(seq_str,curr_ind,seq_len - 1) # find the next filled position
                next_filled.append((curr_ind,ind))
                aj_mat_array[curr_ind,ind] = 1
            else:
                ind = ind + 1
        return aj_mat_array

    # convert adj matrix into pog dictionary
    def create_extant_pog(self,adj_mat_t):
        x_summ = np.column_stack(np.where(adj_mat_t))
        seq_fwd_pog_dict = dict(zip(x_summ[:,0], x_summ[:,1]))
        seq_rvs_pog_dict = dict(zip(x_summ[:,1], x_summ[:,0]))
        return seq_fwd_pog_dict,seq_rvs_pog_dict

    # 1 - convert fasta file to adj matrix, pog, node type, seq binary into pytorch dataset
    def get_extant_data(self):
        adj_mat_list      = []
        seq_name_list     = []
        seq_binary_list   = []


        with FastxFile(self.input_file) as fh:
            for entry in fh:
                # add start and end string to the sequence
                seq_name = entry.name
                new_sequence = 'x' + entry.sequence + 'x'
                self.sequence_length = len(new_sequence)

                # convert to adj matrix
                seq_adj_mat  = self.convert_to_adj_mat(new_sequence)

                # binarise sequences
                seq_binary   = ''.join(sum(seq_adj_mat).astype(int).astype(str))
                # make start pos as 1 for start node
                seq_binary = '1' + seq_binary[1:]

                # convert to pog structure
                seq_fwd_pog,seq_rvs_pog = self.create_extant_pog(seq_adj_mat)

                # create node type dict
                node_type = self.create_node_type(seq_fwd_pog,seq_rvs_pog,seq_name)

                # add to the list
                #adj_mat_t = torch.from_numpy(seq_adj_mat)
                adj_mat_list.append(seq_adj_mat)
                seq_name_list.append(seq_name)
                seq_binary_list.append(seq_binary)

        # save it into pytorch dataset
        self.Extant_AdjMat_dataset = self.AjMat_Dataset(adj_mat_list, seq_name_list, seq_binary_list)
        self.AjMat_lean_Dataset = self.AjMat_lean_Dataset(seq_name_list, seq_binary_list)
        self.extant_dict = dict(zip(seq_name_list, seq_binary_list))

        return self.Extant_AdjMat_dataset,self.AjMat_lean_Dataset,self.extant_dict

    # 2 - create neighbour dict using the tree file
    def get_tree_data(self):

        ''' create neighbor dict '''
        tree_file = open(self.nwk_file_path,"r")
        my_tree = tree_file.read() + ";"
        tree = Tree(my_tree, format=1)

        # add node names to the internal branches
        edge = 0
        for n in tree.traverse():
            if not n.is_leaf():
                n.name = "NODE_%d" %edge
                edge += 1
                self.ancestor_list.append(n.name)

        # create neighbourhood object
        for n in tree.traverse():
            if n.is_leaf() == False:
                for c in n.children:
                    self.tree_neighbor_dict[n.name] += [c.name]

        return self.tree_neighbor_dict

    # 3 - ancestor data - all ancestors, aggregated pog, aggregated adj mat
    def get_ancestor_data(self):

        # all ancestors name
        ancestor_branchpoints = self.ancestor_list

        ancestor_fwd_pog = defaultdict(list)
        ancestor_rvs_pog = defaultdict(list)

        # ancestor adj mat
        ancestor_adj_mat = np.where(sum(self.Extant_AdjMat_dataset[:][0]))

        # ancestor foward and backward pog
        row_col_sum = np.column_stack(np.where(sum(self.Extant_AdjMat_dataset[:][0])))
        for r in row_col_sum:
            pos = r[0]
            next_pos = r[1]
            ancestor_fwd_pog[pos] += [next_pos]
            ancestor_rvs_pog[next_pos] += [pos]

        # create node type dict
        ancestor_node_type = self.create_node_type(ancestor_fwd_pog,ancestor_rvs_pog,'ANCESTOR')
        self.ancestor_info = [ancestor_branchpoints,ancestor_fwd_pog,ancestor_rvs_pog,ancestor_node_type]
        return self.ancestor_info

    # save Dataset
    def save_data(self):
        # neighbor dict
        with open(self.folder_location + self.tree_name + '/neighbor_dict.pkl','wb') as f:
            pickle.dump(self.tree_neighbor_dict,f)
        # ancestor Info
        with open(self.folder_location + self.tree_name + '/ancestor_info.pkl','wb') as f:
            pickle.dump(self.ancestor_info,f)
        # extant info
        print(self.extant_dict)
        with open(self.folder_location + self.tree_name + '/extant_data.pkl','wb') as f:
            pickle.dump(self.extant_dict,f)


def main():
    folder_location         = '/Users/sanjanatule/Documents/uq/Projects/MIPIndel/data/'
    #folder_location         = '/media/WorkingSpace/Share/mipindel/data/'

    ## Sample tree 1
    tree_name               = 'st1'
    nwk_file_path           = folder_location + tree_name + '/input_tree.nwk'
    extant_sequence_file    = folder_location + tree_name + '/input_extants.fasta'

    ## CYP2U - 165
    #tree_name = 'CYP2U_165'
    #nwk_file_path           = folder_location + tree_name + '/CYP2U_165.nwk'
    #extant_sequence_file    = folder_location + tree_name + '/CYP2U_165.aln'

    # ## CYP2U - 359
    # tree_name = 'CYP2U_359'
    # nwk_file_path           = folder_location + tree_name + '/CYP2U_359.nwk'
    # extant_sequence_file    = folder_location + tree_name + '/CYP2U_359.aln'

    # ## DHAD - 1612
    # tree_name = 'DHAD_1612'
    # nwk_file_path           = folder_location + tree_name + '/DHAD_1612.nwk'
    # extant_sequence_file    = folder_location + tree_name + '/DHAD_1612.aln'

    ## MBL
    # tree_name = 'MBL'
    # nwk_file_path           = folder_location + tree_name + '/nuclease_filt_i10.aln.treefile.nwk'
    # extant_sequence_file    = folder_location + tree_name + '/nuclease_filt_i10.aln'

    # CYPU - Anthony
    # tree_name = 'anthony'
    # nwk_file_path           = folder_location + tree_name + '/CYP19_Putative_6_DASH.nwk'
    # extant_sequence_file    = folder_location + tree_name + '/CYP19_Putative_6_DASH.fasta'

    # prepare input Dataset
    print("Processing Input Files")
    MIPIndel      = IndelsInfo(AjMat_Dataset,extant_sequence_file,nwk_file_path,folder_location,tree_name,AjMat_lean_Dataset) # class
    print("1 - Processing Extant Data")
    extant_data,extant_info_lean,extant_dict = MIPIndel.get_extant_data() # pog, adj matrix, binary, node type, sequence name
    print("2 - Preparing Tree Data")
    neighbor_dict = MIPIndel.get_tree_data() # neighbor info
    print("3 - Preparing Ancestor Data")
    ancestor_data = MIPIndel.get_ancestor_data() # ancestor list, ancestor pog, node type
    print("4 - Saving Data")
    MIPIndel.save_data() # save data
    print("Done")

    # Info about the data
    total_sequences = len(ancestor_data[0]) + 1
    print("TOTAL EXTANT SEQUENCES",total_sequences)
    print("SEQUENCE LENGTH",len(extant_info_lean[0][1]))


In [9]:
main()

Processing Input Files
1 - Processing Extant Data
2 - Preparing Tree Data
3 - Preparing Ancestor Data
4 - Saving Data
{'XP_005457042.1': '110000000000100000000000000001111000011111111111011001000011100011100111110000011100111111111111111111111111111111111000000000000000000000000000001110011111111111111111111111111111111111111111111111111111111111111011111111111111111111111111111111111111111111111111111011101111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111100111111110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111001110000000000001', 'XP_004549474.1': '110000000000100000000000000001111000011111111111011001000011100011100111110000011100111111111111111111111111111111111000000000000000000000000000001110011111111111111111111111111111111111111111111111111111111111111