In [1]:
import ete3 as et
import numpy as np
import pandas as pd
import polars as pl

import time

from multiprocessing import Pool

In [2]:
def cotr_corgias(df, cores, gpu=False, num_blocks=0):
    """Run the CORGIAS co-transition analysis on a genome x orthogroup table.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are genomes in the order transitions should be counted along
        (tree-leaf order in this notebook); each column is one orthogroup's
        profile and its label is used as the orthogroup name.
    cores : int
        Number of worker processes used to compute per-orthogroup transitions.
        ``cores == 1`` runs in the current process, avoiding pool start-up and
        pickling overhead.
    gpu, num_blocks :
        Forwarded to ``calculate_k`` (via ``run_transition``).

    Returns
    -------
    The pairwise table produced by ``run_transition``.
    """
    # (column label, column Series) pairs -- the same pairs df.T.iterrows()
    # yields, without materialising the transposed frame.
    profiles = list(df.items())

    if cores == 1:
        counts = [count_transition(og, profile) for og, profile in profiles]
    else:
        with Pool(processes=cores) as pool:
            counts = pool.starmap(count_transition, profiles)

    # Number of adjacent genome pairs, i.e. the length of each transition vector.
    num_genomes = df.shape[0] - 1
    return run_transition(counts, gpu, num_blocks, num_genomes)

def count_transition(og, row):
    """Compute the transition vector of one orthogroup along the genome order.

    Parameters
    ----------
    og : hashable
        Orthogroup identifier; returned unchanged so results can be matched
        back after parallel mapping.
    row : array-like
        The orthogroup's profile across genomes, in genome order.

    Returns
    -------
    tuple
        ``(og, transition, num_transition)``. ``transition[i]`` is the change
        between genome i and i+1 clipped to [-1, 1]. When two adjacent
        transitions cancel (e.g. 0->1->0), the second one is zeroed.
        ``num_transition`` is the number of non-zero entries that remain.
    """
    profile = np.asarray(row)
    # bool arrays cannot be subtracted and unsigned ints wrap (0 - 1 -> 255);
    # move those to a signed type before differencing.
    if profile.dtype.kind in 'bu':
        profile = profile.astype(np.int64)

    # Clip like cotr_original does, so that T @ T.T still equals
    # concordant - discordant when the profile holds counts rather than 0/1.
    # For presence/absence data the clip is a no-op.
    transition = np.clip(np.diff(profile), -1, 1)

    # Sequential on purpose: a zeroed entry takes part in the next comparison.
    for i in range(len(transition) - 1):
        if transition[i] + transition[i + 1] == 0:
            transition[i + 1] = 0

    num_transition = np.count_nonzero(transition)
    return og, transition, num_transition

def run_transition(count, gpu, num_blocks, N):
    """Turn per-orthogroup transition results into the pairwise score table.

    ``count`` holds one ``(og_name, transition_vector, num_transition)`` entry
    per orthogroup, as produced by ``count_transition``. The transition
    vectors are stacked into one matrix, scored with ``calculate_k`` and
    reshaped into a long pair table by ``transition_count2df``; ``N`` is
    passed through to that table.
    """
    names, vectors, totals = [], [], []
    for entry in count:
        names.append(entry[0])
        vectors.append(entry[1])
        totals.append(entry[2])

    stacked = np.vstack(vectors)
    # Scalars stack into an (n, 1) column.
    change_column = np.vstack(totals)

    scores = calculate_k(stacked, gpu, num_blocks)
    return transition_count2df(scores, change_column, names, N)

def calculate_k(t_matrix, gpu=False, num_blocks=0):
    """Compute the pairwise score matrix ``k = T @ T.T``.

    Each row of ``t_matrix`` is one orthogroup's transition vector. When the
    entries are in {-1, 0, 1}, ``k[i, j]`` equals the number of concordant
    minus discordant transitions between rows i and j.

    Parameters
    ----------
    t_matrix : numpy.ndarray, shape (n, m)
    gpu : bool
        If True, compute on the GPU with CuPy (imported lazily, so CPU-only
        environments do not need it).
    num_blocks : int
        GPU only. 0 computes the whole product at once. A positive value
        computes ``k`` in about ``num_blocks`` row slabs, so only one slab of
        the product is on the GPU at a time.

    Returns
    -------
    numpy.ndarray, shape (n, n)

    Raises
    ------
    ImportError
        If ``gpu`` is True and CuPy is not installed.
    ValueError
        If ``num_blocks`` is negative.
    """
    if not gpu:
        return np.dot(t_matrix, t_matrix.transpose())

    if num_blocks < 0:
        raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")

    try:
        import cupy as cp
    except ImportError as err:
        raise ImportError("calculate_k(gpu=True) requires CuPy to be installed") from err

    # With entries in {-1, 0, 1}, |k| is bounded by the number of columns.
    # int16 halves GPU memory but would wrap past 32767, so widen for wide inputs.
    dtype = np.int16 if t_matrix.shape[1] <= np.iinfo(np.int16).max else np.int32
    gpu_matrix = cp.asarray(t_matrix, dtype=dtype)
    gpu_matrix_t = gpu_matrix.T

    if num_blocks == 0:
        return cp.asnumpy(cp.dot(gpu_matrix, gpu_matrix_t))

    n_rows = t_matrix.shape[0]
    # Ceiling division so there are at most num_blocks slabs, never a 0-sized step.
    block_size = max(1, -(-n_rows // num_blocks))
    k = np.empty((n_rows, n_rows), dtype=dtype)
    for start in range(0, n_rows, block_size):
        stop = min(start + block_size, n_rows)
        k[start:stop] = cp.asnumpy(cp.dot(gpu_matrix[start:stop], gpu_matrix_t))
    return k

def transition_count2df(k, num_transition, og_names, N):
    """Build the long-format pair table from the square score matrix ``k``.

    There is one row per orthogroup pair (i, j) with i < j. Rows follow
    ``np.triu_indices`` order: (0,1), (0,2), ..., (1,2), ...

    Parameters
    ----------
    k : numpy.ndarray, shape (n, n)
        Pairwise scores, e.g. from ``calculate_k``.
    num_transition : array-like with n entries
        Per-orthogroup transition counts; the (n, 1) column built by
        ``run_transition`` is accepted.
    og_names : sequence of n orthogroup names, aligned with the rows of ``k``.
    N : int
        Value broadcast into the ``N`` column.

    Returns
    -------
    polars.DataFrame
        Columns ``OG1``, ``OG2``, ``num_change1``, ``num_change2``, ``k``, ``N``.
    """
    k = np.asarray(k)
    first, second = np.triu_indices(k.shape[0], k=1)

    names = pl.Series('OG', og_names)
    changes = pl.Series('num_transition', np.asarray(num_transition).ravel())

    # Every column is gathered positionally from the same (first, second)
    # index arrays, so the columns stay aligned by construction. This does not
    # rely on the row order polars joins happen to produce.
    pairs = pl.DataFrame([
        names.gather(first).alias('OG1'),
        names.gather(second).alias('OG2'),
        changes.gather(first).alias('num_change1'),
        changes.gather(second).alias('num_change2'),
        pl.Series('k', k[first, second]),
    ])
    return pairs.with_columns(pl.lit(N).alias('N'))

def uppermatrix2vector(matrix):
    """Flatten the entries strictly above the diagonal of a 2-D ``matrix``.

    Values are returned row by row: (0,1), (0,2), ..., (1,2), ... -- the same
    order as ``np.triu_indices`` with ``k=1``.
    """
    size, _ = matrix.shape
    row_idx, col_idx = np.triu_indices(size, k=1)
    return matrix[row_idx, col_idx]


def flatten_indices(df):
    """Return the strictly-upper-triangle (row, col) index pairs of ``df``.

    Only ``df.shape`` is used, and it must be 2-D. The result is a two-column
    polars DataFrame (``column_0`` = row, ``column_1`` = col), one row per
    pair, in ``np.triu_indices(..., k=1)`` order.
    """
    size, _ = df.shape
    row_idx, col_idx = np.triu_indices(size, k=1)
    pair_array = np.column_stack((row_idx, col_idx))
    return pl.DataFrame(pair_array)

In [3]:
# modified from https://github.com/lab83bio/Cotransitions/blob/master/cotr_transitions.py

def cotr_original(csv, count_consecutive, outfile):
    """Reference co-transition scoring, kept close to lab83bio/Cotransitions.

    Parameters
    ----------
    csv : pandas.DataFrame
        Rows are orthogroups, columns are organisms in the order transitions
        should be counted along.
    count_consecutive : bool
        If False, when two adjacent transitions cancel (e.g. 0->1->0) only the
        first one is counted.
    outfile : str or path
        Destination of a tab-separated table with columns gene1, gene2, norgs,
        t1, t2, concordant, discordant, k (k = concordant - discordant), one
        row per orthogroup pair i < j.
    """
    ngenes, norgs = len(csv.index), len(csv.columns)

    # Differences between neighbouring organisms, clamped to [-1, 1]. The first
    # column is NaN (no left neighbour) and clip() leaves NaN untouched, as the
    # original element-wise lambda did.
    tr = csv.diff(axis=1).clip(lower=-1, upper=1)
    tr_l = tr.values.tolist()

    if not count_consecutive:
        # count only once consecutive state transitions (-1, 1)
        for r in tr_l:
            for i in range(len(r) - 1):
                if r[i] + r[i + 1] == 0:
                    r[i + 1] = 0

    tr_arr = np.array(tr_l)
    # sets for fast comparison
    t01 = [set(np.nonzero(row > 0)[0]) for row in tr_arr]  # 0->1 transitions
    t10 = [set(np.nonzero(row < 0)[0]) for row in tr_arr]  # 1->0 transitions
    tt = [len(a | b) for a, b in zip(t01, t10)]  # total transitions

    data = []
    for i in range(ngenes - 1):
        for j in range(i + 1, ngenes):
            concordant = len(t01[i] & t01[j]) + len(t10[i] & t10[j])
            discordant = len(t10[i] & t01[j]) + len(t01[i] & t10[j])
            k = concordant - discordant
            data.append([tr.index[i], tr.index[j], norgs, tt[i], tt[j], concordant, discordant, k])

    df = pl.DataFrame(
        data, schema=["gene1", "gene2", "norgs", "t1", "t2", "concordant", "discordant", "k"],
        orient='row'
    )
    df.write_csv(outfile, separator="\t")

In [4]:
# run original cotr
tree = et.Tree('simulation/tree2K.tre')
order = [leaf.name for leaf in tree.get_leaves()]  # genome order = tree-leaf order

# Transpose so rows are orthogroups and columns are genomes, as cotr_original expects.
og_by_genome = (
    pd.read_csv('simulation/df4tree2K.csv')
    .rename(columns={'Unnamed: 0': 'org'})
    .set_index('org')
    .T
    .loc[:, order]
)

# perf_counter is monotonic and meant for elapsed-time measurement.
start = time.perf_counter()
cotr_original(og_by_genome, False, 'original_cotr.csv')
end = time.perf_counter()
print(end - start)

247.6375172138214


In [5]:
# run CORGIAS cotr with 1 CPU
genome_by_og = pl.read_csv('simulation/df4tree2K.csv').to_pandas()
# The first column holds genome names (its header differs between readers).
# Index by it and put the rows in tree-leaf order.
genome_by_og = genome_by_og.set_index(genome_by_og.columns[0]).loc[order]

# perf_counter is monotonic and meant for elapsed-time measurement.
start = time.perf_counter()
cotr = cotr_corgias(genome_by_og, 1)
cotr.write_csv('corgias_cotr.csv')
end = time.perf_counter()
print(end - start)

32.336382150650024


In [6]:
# original_cotr.csv is tab-separated despite its extension
original_result = pl.read_csv('original_cotr.csv', separator='\t')
corgias_result = pl.read_csv('corgias_cotr.csv')

In [7]:
# first pairs produced by the CORGIAS implementation
corgias_result.head(n=5)

OG1,OG2,num_change1,num_change2,k,N
str,str,i64,i64,i64,i64
"""c1""","""c2""",554,542,188,1999
"""c1""","""c3""",554,406,8,1999
"""c1""","""c4""",554,518,-9,1999
"""c1""","""c5""",554,569,-5,1999
"""c1""","""c6""",554,559,-24,1999


In [8]:
# first pairs produced by the reference implementation
original_result.head(n=5)

gene1,gene2,norgs,t1,t2,concordant,discordant,k
str,str,i64,i64,i64,i64,i64,i64
"""c1""","""c2""",2000,554,542,210,22,188
"""c1""","""c3""",2000,554,406,67,59,8
"""c1""","""c4""",2000,554,518,79,88,-9
"""c1""","""c5""",2000,554,569,81,86,-5
"""c1""","""c6""",2000,554,559,68,92,-24


In [9]:
# check equivalence
# Element-wise comparisons are only meaningful if both tables list the same
# orthogroup pairs in the same order, so verify that first.
assert corgias_result.height == original_result.height, "result tables differ in length"
same_pairs = (
    (corgias_result['OG1'] == original_result['gene1'])
    & (corgias_result['OG2'] == original_result['gene2'])
).all()
assert same_pairs, "pair order differs between CORGIAS and original results"

check_t1 = (corgias_result['num_change1'] == original_result['t1']).all()
check_t2 = (corgias_result['num_change2'] == original_result['t2']).all()
check_k = (corgias_result['k'] == original_result['k']).all()
print(check_t1, check_t2, check_k)
# Fail loudly under Run All instead of only printing False.
assert check_t1 and check_t2 and check_k, "CORGIAS and original cotr results disagree"

True True True
