In [None]:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
import copy

In [None]:
# IPython setup: disable jedi-based tab completion and render matplotlib figures inline.
%config Completer.use_jedi = False
%matplotlib inline

In [None]:
# Wide cells (long species lists) but only a handful of columns in DataFrame previews.
pd.set_option('display.max_colwidth', 500, 'display.max_columns', 7)

In [None]:
df_meta = pd.read_csv('metadata_plants.csv', header=0)  # per-species metadata, first row is the header

In [None]:
# Held-out test genomes; they are excluded from every training split built below.
TEST_GENOMES = ['Alyrata', 'Csativus', 'Macuminata', 'Ppatens', 'Taestivum', 'Vcarteri']
df_plants = df_meta[~df_meta['species'].isin(TEST_GENOMES)].reset_index(drop=True)
# The split cells below hard-code 6 splits x 9 species; fail here with a clear message
# instead of a length mismatch later on.
assert len(df_plants) == 54, f'expected 54 non-test species, got {len(df_plants)}'



In [None]:
# Reproducible random split: six split ids with nine species each, shuffled in place.
np.random.seed(11345)
x = np.repeat(np.arange(6), 9)
np.random.shuffle(x)
x

In [None]:
df_plants["random_split"] = x  # shuffled split ids, one per species row

In [None]:
df_plants["length_split"] = -1  # placeholder, filled by the total_len ranking below

In [None]:
df_plants.iloc[:5]

In [None]:
# Rank species by total_len: the 9 shortest get split 0, the next 9 split 1, and so on.
split_ids = np.repeat(range(6), 9)
by_length = np.asarray(np.argsort(df_plants.loc[:, "total_len"]))[:len(split_ids)]
df_plants.loc[by_length, "length_split"] = split_ids[:len(by_length)]

In [None]:
df_plants["gc_split"] = -1  # placeholder, filled by the x2 ranking below

In [None]:
# C / (A + C) from the base-count columns; presumably a GC proxy assuming C~G and A~T
# (TODO confirm). Note the plots below sort by the separate 'gc_content' column.
x2 = df_plants["C"] / (df_plants["A"] + df_plants["C"])

In [None]:
# Rank species by x2: the 9 lowest get split 0, the next 9 split 1, and so on.
gc_split_ids = np.repeat(range(6), 9)
by_gc = np.asarray(np.argsort(x2))[:len(gc_split_ids)]
df_plants.loc[by_gc, "gc_split"] = gc_split_ids[:len(by_gc)]

In [None]:
df_plants.query("random_split == 1")["species"]

In [None]:
df_plants.query("length_split == 1")["species"]

In [None]:
# Phylogenetic split: the five clades below get splits 0-4 in the next cell;
# every species outside them stays in the catch-all split 5.
green_algae = ['Dsalina', 'Creinhardtii', 'Czofingiensis', 'MpusillaCCMP1545', 'MpusillaRCC299',
               'Olucimarinus', 'CsubellipsoideaC169']
monocots = ['Hvulgare', 'Bdistachyon', 'Osativa', 'Sbicolor', 'Zmays', 'Sitalica', 'Othomaeum', 'Acomosus',
            'Aofficinalis', 'Zmarina', 'Spolyrhiza']
asterids = ['Hannuus', 'Lsativa', 'Dcarota', 'Mguttatus', 'Oeuropaea', 'Stuberosum',
            'Slycopersicum']
fabids = ['Mesculenta', 'Rcommunis', 'Lusitatissimum', 'Ptrichocarpa', 'Mdomestica', 'Ppersica',
          'Fvesca', 'Mtruncatula', 'Carietinum', 'Gmax']
malvids = ['Athaliana', 'Crubella', 'Cgrandiflora', 'Esalsugineum', 'Cpapaya', 'Graimondii', 'Csinensis',
           'Cclementina', 'Egrandis', 'Tcacao', 'Boleraceacapitata']

clades = [green_algae, monocots, asterids, fabids, malvids]
clade_members = {sp for clade in clades for sp in clade}
# Clades are assigned one after another, so an overlap would silently go to the later clade.
assert len(clade_members) == sum(len(clade) for clade in clades), 'a species is listed in more than one clade'

# Species that end up in split 5 (a bare mid-cell expression would not be displayed).
remaining = [sp for sp in df_plants['species'] if sp not in clade_members]
print(remaining)
df_plants.loc[:, 'phylo_split'] = 5

In [None]:
# Clade order defines the split id (0 = green algae ... 4 = malvids); unmatched rows keep 5.
for code, clade in enumerate([green_algae, monocots, asterids, fabids, malvids]):
    df_plants.loc[df_plants["species"].isin(clade), "phylo_split"] = code




In [None]:
def mk_bash_denbi(outdir, species,
                  pfx='/mnt/share/ubuntu/data/plants/single_genomes/',
                  merge_script='/mnt/share/ubuntu/repos/github/weberlab-hhu/helixer_scratch/data_scripts/merge-files.py',
                  validation='../../eight_genomes_nosplit_phase/validation_data.h5'):
    """Return a bash snippet that merges per-species h5 files into one training file.

    The snippet creates ``outdir`` and runs ``merge_script`` on
    ``$pfx/<species>/test_data.h5`` for every species, writing
    ``$outdir/training_data.h5``. It then symlinks ``validation`` into
    ``outdir`` and changes back to the parent directory.

    Parameters
    ----------
    outdir : str
        Output directory, inserted verbatim into the script.
    species : iterable of str
        Species directory names below ``pfx``.
    pfx, merge_script, validation : str, optional
        Data root, merge script path and symlink target. The defaults are the
        paths that used to be hard-coded, so existing calls produce identical text.

    Returns
    -------
    str
        The script text; it starts and ends with a newline.
    """
    template = """
pfx={pfx}
outdir={outdir}

mkdir -p $outdir
python {merge_script} \\
        --input-files {inputs} \\
        --output-file $outdir/training_data.h5
cd $outdir
ln -s {validation}
cd ..
"""
    inputs = ' '.join(f'$pfx/{sp}/test_data.h5' for sp in species)
    return template.format(pfx=pfx, outdir=outdir, merge_script=merge_script,
                           inputs=inputs, validation=validation)

In [None]:
print(mk_bash_denbi('gc3', df_plants.query('gc_split == 3')['species']))

In [None]:
print(mk_bash_denbi('length5', df_plants.query('length_split == 5')['species']))

In [None]:
print(mk_bash_denbi('phylo5', df_plants.query('phylo_split == 5')['species']))

In [None]:
# One row per species with its id in each of the four split schemes.
(df_plants[['species', 'random_split', 'length_split', 'gc_split', 'phylo_split']]
 .to_csv('rabbit_splits.csv', index=False))

In [None]:
def mergy_plotty(key, sort_by):
    """Merge one split scheme's genic-F1 results onto ``df_plants`` and plot them.

    Loads ``data_splits/f1_{key}.csv`` via ``import2pivot`` and left-merges it
    onto the notebook-global ``df_plants``. Prints the median over the plotted
    F1 cells, then draws two heatmaps with rows ordered by ``sort_by``: the
    one-hot training split from column ``f'{key}_split'``, and the genic F1
    per species and model.

    Parameters
    ----------
    key : str
        Split scheme; selects both the results file and the ``{key}_split`` column.
    sort_by : str
        Column of ``df_plants`` that orders the rows of both panels.
    """
    n_splits = 6
    # The merge appends the model columns after df_plants' own columns.
    n_meta = df_plants.shape[1]
    merged = df_plants.merge(import2pivot(key), left_on='species', right_index=True, how='left')
    # Row positions in sort order (positional, so use .iloc / numpy indexing below).
    order = np.asarray(np.argsort(merged[sort_by]))
    trainers = np.eye(n_splits)[merged[f'{key}_split']]
    f1 = merged.iloc[order, n_meta:n_meta + n_splits]
    n_rows = len(order)

    # stat
    print(np.median(f1))

    # plotty
    fig, (ax, bx) = plt.subplots(1, 2, figsize=(4, 14))
    fig.suptitle(f'{key}, sort {sort_by}', fontsize=20, y=0.93)

    ax.imshow(trainers[order], aspect='auto', interpolation='none')
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(merged['species'].iloc[order])
    ax.set_xticks(range(n_splits))
    ax.set_xlabel('split')
    ax.set_title("trainers")

    img = bx.imshow(f1, aspect='auto', interpolation='none')
    bx.set_xticks(range(n_splits))
    bx.set_xlabel('split')
    bx.set_yticks(range(n_rows))
    bx.set_yticklabels([])
    bx.set_title("genic_f1")

    fig.colorbar(img, ax=bx, fraction=0.04)



In [None]:
def import2pivot(key, values='genic_f1'):
    """Load ``data_splits/f1_{key}.csv`` as a species x model table of one metric.

    The CSV has no header; its four columns are read as
    ``model, species, genic_f1, sub_genic_f1``.

    Parameters
    ----------
    key : str
        Suffix of the results file.
    values : str, default 'genic_f1'
        Metric to pivot, e.g. ``'sub_genic_f1'``.

    Returns
    -------
    pandas.DataFrame
        Rows are species, columns are models (as ordered by ``DataFrame.pivot``),
        and cells hold the chosen metric. ``'MspRCC299'`` is renamed to
        ``'MpusillaRCC299'`` so the rows match the species names in ``df_plants``.
    """
    res = pd.read_csv(f'data_splits/f1_{key}.csv', header=None)
    res.columns = ['model', 'species', 'genic_f1', 'sub_genic_f1']
    preview = res.pivot(columns='model', index='species', values=values)
    aliases = {'MspRCC299': 'MpusillaRCC299'}
    preview.index = [aliases.get(sp, sp) for sp in preview.index]
    return preview



def plotty_flex(key, indexes, f1_array, trainers, sort_by=None, species=None):
    """Print F1 medians and draw the trainer-membership / genic-F1 heatmap pair.

    Parameters
    ----------
    key : str
        Used in the figure title.
    indexes : array-like of int
        Row positions in display order; applied to ``trainers``, ``f1_array``
        and the row labels.
    f1_array : array-like, 2-D
        Genic F1 values, rows = species, columns = models.
    trainers : numpy.ndarray, 2-D
        Training-membership indicator, same row order as ``f1_array``.
    sort_by : str, optional
        Only used in the title.
    species : array-like of str, optional
        Row labels in the same (unsorted) order as ``f1_array``. Defaults to
        ``df_plants['species']``. Every caller here passes data from a left
        merge onto ``df_plants``, which keeps that row order. The old version
        silently read the notebook-global ``x`` instead, which many cells reassign.

    Returns
    -------
    tuple
        ``(fig, (ax, bx))``: the figure and its trainers / F1 axes.
    """
    f1_array = np.array(f1_array)
    order = np.asarray(indexes)
    if species is None:
        species = df_plants.loc[:, 'species']
    labels = np.asarray(species)[order]
    n_rows = len(order)

    # stat (medians do not depend on row order)
    print(np.median(f1_array))
    print(np.median(f1_array, axis=0))

    # plotty
    fig, (ax, bx) = plt.subplots(1, 2, figsize=(4, 14))
    fig.suptitle(f'{key}, sort {sort_by}', fontsize=20, y=0.93)

    ax.imshow(trainers[order], aspect='auto', interpolation='none')
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(labels)
    ax.set_xticks(range(trainers.shape[1]))
    ax.set_xlabel('split')
    ax.set_title("trainers")

    img = bx.imshow(f1_array[order], aspect='auto', interpolation='none')
    bx.set_xticks(range(f1_array.shape[1]))
    bx.set_xlabel('split')
    bx.set_yticks(range(n_rows))
    bx.set_yticklabels([])
    bx.set_title("genic_f1")

    fig.colorbar(img, ax=bx, fraction=0.04)
    return fig, (ax, bx)
    

    

In [None]:
mergy_plotty('gc', 'gc_content')

In [None]:
mergy_plotty('length', 'total_len')

In [None]:
# ~100k ("lengthlong") vs 20k ("length") models, both plotted against the length splits.
# Each result table is left-merged onto df_plants, so rows keep df_plants' order.
sort_by = 'total_len'  # alternative: 'phylo_split'
hundredk_preview = import2pivot('lengthlong')
length_preview = import2pivot('length')
x = df_plants.merge(hundredk_preview, how='left', left_on='species', right_index=True)
xlength = df_plants.merge(length_preview, how='left', left_on='species', right_index=True)

In [None]:
# Model (F1) columns follow df_plants' own columns in the merged frame.
f1_array = x.iloc[:, df_plants.shape[1]:]
print(f1_array.shape)
indexes = np.argsort(x[sort_by])
# Trainer membership is taken from the length split ids.
trainers = np.eye(6)[x['length_split']]
# plotty_flex prints the overall and per-model medians itself (row order does not change them).
fig, (ax, bx) = plotty_flex('lengthlong', indexes, f1_array, trainers, sort_by)

In [None]:
# ~100k minus 20k genic F1, matched by position (same species rows, same split order).
# Plain DataFrame subtraction would align on column labels and yield all-NaN if the two
# runs name their models differently.
f1_long = x.iloc[:, df_plants.shape[1]:]
f1_short = xlength.iloc[:, df_plants.shape[1]:]
assert f1_long.shape == f1_short.shape, (f1_long.shape, f1_short.shape)
lendiffs = pd.DataFrame(f1_long.to_numpy() - f1_short.to_numpy(),
                        index=f1_long.index, columns=f1_long.columns)

In [None]:
# Four panels sharing the row order `indexes`: trainer membership, ~100k and 20k genic F1
# on a common 0-1 scale, and their difference on a diverging map clipped at +/-0.12.
fig, (zx, ax, bx, cx) = plt.subplots(1, 4, figsize=(10, 14))
y_label_list = x.loc[:, 'species'][indexes]

zx.imshow(trainers[indexes], aspect='auto', interpolation='none', norm=None)
img = ax.imshow(np.array(x.iloc[:, 57:])[indexes],
                aspect='auto', interpolation='none', norm=None, vmin=0, vmax=1)
bx.imshow(np.array(xlength.iloc[:, 57:])[indexes],
          aspect='auto', interpolation='none', norm=None, vmin=0, vmax=1)
fig.colorbar(img, ax=bx, fraction=0.04)

cmap = matplotlib.cm.RdBu_r
img2 = cx.imshow(np.array(lendiffs)[indexes],
                 aspect='auto', interpolation='none', cmap=cmap, vmin=-.12, vmax=0.12)
fig.colorbar(img2, fraction=0.04)

# Only the leftmost panel carries species labels.
for axis in (zx, ax, bx, cx):
    axis.set_yticks(range(54))
    axis.set_yticklabels(y_label_list if axis is zx else [])
    axis.set_xticks(range(6))
    axis.set_xlabel('split')

zx.set_title("trainers")
ax.set_title("~100k genic F1")
bx.set_title("20k genic F1")
cx.set_title("difference gF1")
fig.tight_layout()

In [None]:
mergy_plotty('phylo', 'phylo_split')

In [None]:
mergy_plotty('random', 'random_split')

In [None]:
mergy_plotty('random', 'gc_content')

In [None]:
mergy_plotty('random', 'total_len')

In [None]:
mergy_plotty('random', 'phylo_split')

In [None]:
# Outgroup experiment: the random-split-2 model vs. models trained on split 2 plus one extra
# genome (P. umbilicalis, D. melanogaster, M. musculus, S. cerevisiae) or all four.
sort_by = 'gc_content'  # alternative: 'phylo_split'
og_preview = import2pivot('og')
random_preview = import2pivot('random')
x = df_plants.merge(og_preview, left_on='species', right_index=True, how='left')
x = x.merge(random_preview, left_on='species', right_index=True, how='left')
# Merged column layout: df_plants' columns, then the og models, then the random models.
# Take the random-split-2 model first (assumes the random models sort by split id, as the
# previously hard-coded [64, 57, ..., 61] did), followed by the og models.
n_meta = df_plants.shape[1]
n_og = og_preview.shape[1]
f1_array = x.iloc[:, [n_meta + n_og + 2, *range(n_meta, n_meta + n_og)]]
indexes = np.argsort(x[sort_by])
# Every model trains on random split 2; '+ Pumbilicalis' (col 1) and '+ all 4' (col 5)
# additionally train on Pumbilicalis, which is one of the plotted species.
trainers = np.zeros(shape=f1_array.shape)
trainers[(x['random_split'] == 2).to_numpy(), :] = 1
pumbilicalis_row = np.flatnonzero(x['species'].to_numpy() == 'Pumbilicalis')
assert len(pumbilicalis_row) == 1, 'Pumbilicalis must appear exactly once in df_plants'
trainers[np.ix_(pumbilicalis_row, [1, 5])] = 1
fig, (ax, bx) = plotty_flex('og', indexes, f1_array, trainers, sort_by)
xticklabs = ['random 2', '+ Pumbilicalis', '+ D melanogaster', '+ M musculus', '+ S cerevisiae', '+ all 4']
ax.set_xticklabels(xticklabs, rotation=90)
ax.set_xlabel('')
bx.set_xticklabels(xticklabs, rotation=90)
bx.set_xlabel('');

In [None]:
sort_by = "five_prime_UTR"  # alternative row order; the next cell sets 'phylo_split'


In [None]:
sort_by = "phylo_split"  # row/column order for the single-genome matrix below


In [None]:
# Single-genome models: a species x species genic-F1 matrix
# (rows: validated on, columns: trained on), both axes in sort_by order.
sg_preview = import2pivot('single_genomes')
x = df_plants.merge(sg_preview, left_on='species', right_index=True, how='left')
# Model (F1) columns follow df_plants' own columns in the merged frame.
f1_array = np.array(x.iloc[:, df_plants.shape[1]:])
indexes = np.argsort(x[sort_by])
order = np.asarray(indexes)
# Columns are reordered with the row order, which assumes one model per species with the
# model columns in df_plants' species order (pivot sorts them by model name) -- TODO confirm.
assert f1_array.shape[1] == len(order), f'expected one model per species, got shape {f1_array.shape}'

fig, bx = plt.subplots(1, 1, figsize=(14, 14))
fig.suptitle(f'singles, sort {sort_by}', fontsize=20, y=0.93)
pre_y_label_list = x.loc[:, 'species'].iloc[order]
# Label prefixes: '**' where five_prime_UTR == 0, '*' where five_prime_UTR < 5000.
no_utr_sp = set(x['species'][x['five_prime_UTR'] == 0])
low_utr_sp = set(x['species'][x['five_prime_UTR'] < 5000])
y_label_list = [('**' if sp in no_utr_sp else '*' if sp in low_utr_sp else '') + sp
                for sp in pre_y_label_list]

n_species = len(order)
img = bx.imshow(f1_array[order, :][:, order], aspect='auto', interpolation='none')
bx.set_xticks(range(n_species))
bx.set_xticklabels(y_label_list, rotation=90)
bx.set_xlabel('trained on')
bx.set_yticks(range(n_species))
bx.set_yticklabels(y_label_list)
bx.set_ylabel('validated on')
fig.colorbar(img, fraction=0.04)
bx.set_title("genic_f1");

# Info for eight (the hand-picked training set used in the next cell)
# validation sampled from (to 7.176GB):
#   Acomosus, Ahypochondriacus, Aofficinalis, Atrichopoda, Boleraceacapitata, Carietinum,
#   Cclementina, Cgrandiflora, Cpapaya, Cquinoa, Crubella, Csinensis, CsubellipsoideaC169,
#   Czofingiensis, Dcarota, Egrandis, Esalsugineum, Fvesca, Graimondii, Hannuus, Hvulgare,
#   Kfedtschenkoi, Lsativa, Lusitatissimum, Mdomestica, Mesculenta, MpusillaCCMP1545,
#   MspRCC299, Mtruncatula, Oeuropaea, Olucimarinus, Osativa, Othomaeum, Ppersica,
#   Pumbilicalis, Rcommunis, Sbicolor, Slycopersicum, Smoellendorffii, Spolyrhiza,
#   Stuberosum, Tcacao, Vvinifera, Zmarina, Zmays
# training all of:
#   Athaliana, Bdistachyon, Creinhardtii, Gmax, Mguttatus, Mpolymorpha, Ptrichocarpa, Sitalica

In [None]:
# Outgroup experiment as above, but the base training set is eight hand-picked species
# instead of a random split.
sort_by = 'gc_content'  # alternative: 'phylo_split'
og_preview = import2pivot('eightplusog')
x = df_plants.merge(og_preview, left_on='species', right_index=True, how='left')
# Model (F1) columns follow df_plants' own columns in the merged frame.
f1_array = x.iloc[:, df_plants.shape[1]:]
indexes = np.argsort(x[sort_by])
trainsp = ['Athaliana', 'Bdistachyon', 'Creinhardtii', 'Gmax', 'Mguttatus', 'Mpolymorpha', 'Ptrichocarpa', 'Sitalica']
# Every model trains on the eight species; '+ Pumbilicalis' (col 1) and '+ all 4' (col 5)
# additionally train on Pumbilicalis, which is one of the plotted species.
trainers = np.zeros(shape=f1_array.shape)
trainers[x['species'].isin(trainsp).to_numpy(), :] = 1
pumbilicalis_row = np.flatnonzero(x['species'].to_numpy() == 'Pumbilicalis')
assert len(pumbilicalis_row) == 1, 'Pumbilicalis must appear exactly once in df_plants'
trainers[np.ix_(pumbilicalis_row, [1, 5])] = 1
fig, (ax, bx) = plotty_flex('eightplusog', indexes, f1_array, trainers, sort_by)
xticklabs = ['eight', '+ Pumbilicalis', '+ D melanogaster', '+ M musculus', '+ S cerevisiae', '+ all 4']
ax.set_xticklabels(xticklabs, rotation=90)
ax.set_xlabel('')
bx.set_xticklabels(xticklabs, rotation=90)
bx.set_xlabel('');