In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from pathlib import Path
import re
import math
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd

In [None]:
def get_attr(string):
    """Parse a GTF attribute column (``key "value"; key2 "value2"; ...``) into a dict.

    Each ``;``-separated field is split on its first run of whitespace into a
    key and a value. Double quotes are removed from the value and it is stripped.
    Empty or whitespace-only fields (e.g. after a trailing ``;``) are ignored.
    A value may contain spaces (``note "two words"`` -> ``"two words"``).
    When a key is repeated, the last occurrence wins.

    Raises:
        ValueError: if a non-empty field has a key but no value.
    """
    dico = {}
    for field in string.split(";"):
        field = field.strip()
        if not field:
            continue
        parts = field.split(None, 1)
        if len(parts) < 2:
            raise ValueError(f"Malformed GTF attribute {field!r} in {string!r}")
        key, value = parts
        dico[key] = value.replace('"', "").strip()
    return dico


def gtf_to_dict(gtf_file, main_ = "gene_id"):
    """Load a GTF file into a nested dict keyed by gene.

    Every non-empty, non-``#`` line is parsed.

    - Rows whose feature type (column 3) is ``"gene"`` set the gene's
      ``start``/``end``.
    - Every other feature type (transcript, exon, ...) is appended to
      ``dico[gene][“transcript”][transcript_id][feature_type]`` as a
      ``{"chr", "start", "end", "strand"}`` record.

    Args:
        gtf_file: path to the GTF file.
        main_: attribute used as the gene key (default ``"gene_id"``).

    Returns:
        dict mapping the ``main_`` attribute value to
        ``{"chr", "symbol", "strand", "transcript"}``, plus ``"start"``/``"end"``
        once a ``"gene"`` row for it has been seen. chr/symbol/strand come from
        the first row seen for that gene. Coordinates are kept exactly as in the
        file (GTF is 1-based).

    Raises:
        Whatever parsing error occurred (e.g. KeyError when a row lacks
        ``main_``), after printing the offending line.
    """
    dico = {}
    with open(gtf_file) as f_in:
        for line in f_in:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            spt = line.split("\t")
            attr = None  # defined up-front so the error report never hits an unbound name
            try:
                chr_ = spt[0]
                type_ = spt[2]
                start = int(spt[3])  # GTF is 1-based; coordinates kept as-is
                end = int(spt[4])
                strand = spt[6]
                attr = get_attr(spt[-1])
                gene_id = attr[main_]
                gene_symbol = attr.get("gene_symbol", "None")
            except Exception:
                print(attr)
                print(line)
                print(spt)
                raise

            if gene_id not in dico:
                dico[gene_id] = {
                    "chr": chr_,
                    "symbol": gene_symbol,
                    "strand": strand,
                    "transcript": {},
                }
            gene = dico[gene_id]

            if type_ == "gene":
                gene["start"] = start
                gene["end"] = end
                continue

            transcript_id = attr.get("transcript_id", "None")
            transcripts = gene["transcript"]
            if transcript_id not in transcripts:
                transcripts[transcript_id] = {
                    "transcript_symbol": attr.get("transcript_symbol", "None"),
                    "transcript_id": transcript_id,
                }
            transcripts[transcript_id].setdefault(type_, []).append({
                "chr": chr_,
                "start": start,
                "end": end,
                "strand": strand,
            })
    return dico

In [None]:
# Placeholder: point this at the GTF describing the annotated Y-linked genes.
gtf_file = "<path to annotated Y genes gtf>"
# Parse the annotation, keyed by each row's gene_id attribute.
gtf_dict = gtf_to_dict(gtf_file, main_="gene_id")

In [None]:
# Placeholder: directory holding the OmniSplice "*table" output files.
omnisplice_run = Path("<Path to directory containing OmniSplice output>")

In [None]:

def data_to_df_for_glm(control, treatment):
    """Stack count vectors from two groups into a frame for a binomial GLM.

    Each row of ``control`` / ``treatment`` is one count vector. Column 0 is
    taken as ``successes`` and the sum of all remaining columns as ``failures``.
    A ``group`` column labels rows ``"control"`` or ``"treatment"``.
    The original count columns (0..k) are kept.
    """
    rows = control + treatment
    group_labels = ["control"] * len(control) + ["treatment"] * len(treatment)
    frame = pd.DataFrame(rows)
    frame["failures"] = frame.apply(lambda row: sum(row.iloc[1:]), axis=1)
    frame["successes"] = frame[0]
    frame["group"] = group_labels
    return frame

def _inverse_logit(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))`` (no overflow)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gln_binomial_test(control, treatment):
    """Binomial GLM testing whether the success fraction differs between groups.

    Fits ``successes + failures ~ group`` (two-column binomial response) on the
    frame built by ``data_to_df_for_glm``. "control" is the reference level.

    Returns:
        (p-value of the group coefficient,
         fitted success probability for control,
         fitted success probability for treatment)
    """
    df = data_to_df_for_glm(control, treatment)
    mod = smf.glm('successes + failures ~ group', family=sm.families.Binomial(), data=df).fit()
    # Coefficients are on the log-odds scale: [Intercept, group[T.treatment]].
    # Use .iloc: positional Series.__getitem__ is deprecated in pandas 2.x.
    intercept = mod.params.iloc[0]
    group_effect = mod.params.iloc[1]
    control_p = _inverse_logit(intercept)
    treatment_p = _inverse_logit(intercept + group_effect)
    return (mod.pvalues.iloc[1], control_p, treatment_p)


def collapse(values):
    """Sum consecutive pairs of vectors element-wise: (v0+v1), (v2+v3), ...

    The length of each output row follows the even-indexed (first) vector of
    the pair. An odd number of non-empty vectors raises IndexError.
    """
    merged = []
    for start in range(0, len(values), 2):
        head = values[start]
        merged.append([item + values[start + 1][pos] for pos, item in enumerate(head)])
    return merged

In [None]:
# Genotype = everything before "_<sample number>_S<digit>" in the file stem.
# Raw string: "\d" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
reg = re.compile(r"(.*)_\d+_S\d")

In [None]:
# Collect read-category counts from every OmniSplice "*table" file into
#   dico_result[gene][exon][genotype] -> list of per-row count vectors.
# Each vector holds the columns "spliced" .. "exon_intron" (inclusive), with
# "exon_other" added onto the last entry. "Acceptor" rows are skipped, but their
# gene is still registered (so a gene may end up with no exons).
dico_result = {}
for file in sorted(omnisplice_run.glob("*table")):
    print(file)
    name_match = reg.search(file.stem)
    if name_match is None:
        raise ValueError(f"Cannot extract genotype from file name {file.name!r}")
    geno = name_match.group(1)

    with open(file) as fi:
        header = {name: idx for idx, name in enumerate(fi.readline().split())}
        for l in fi:
            # Drop only the line terminator: strip() would also eat empty edge fields
            # and shift the tab-separated columns.
            l = l.rstrip("\r\n")
            if not l.strip():
                continue
            spt = l.split('\t')
            # NOTE(review): gene and exon ids are read by position (columns 1 and 3),
            # not via the header — confirm against the OmniSplice table layout.
            gene = spt[1]
            exons = dico_result.setdefault(gene, {})
            if spt[header["exon_type"]] == "Acceptor":
                continue
            exon_n = spt[3]
            value = list(map(int, spt[header["spliced"]: header["exon_intron"] + 1]))
            value[-1] += int(spt[header["exon_other"]])
            exons.setdefault(exon_n, {}).setdefault(geno, []).append(value)

In [None]:
# Placeholder: destination directory for the per-gene PDF figures.
out_dir = Path("<path to output dir>")

In [None]:
# Bar colours indexed by read-category position in the count vectors.
# The first entry is an 8-digit RGBA hex (light grey with alpha 0x85).
color = [
    "#D3D3D385", "#214d4e", "#cc655b", '#41bbc5', "#c6dbae", '#069668',
    '#b3e61c', '#d55e00', '#cc78bc', '#ca9161', '#fbafe4', '#029e73',
]
# Thicker axes lines and no right/top spines for every subsequent figure.
mpl.rcParams.update({
    'axes.linewidth': 1.5,
    "axes.spines.right": False,
    "axes.spines.top": False,
})

In [None]:

# One figure per gene. Each exon gets one horizontal stacked bar showing
# read-category fractions:
#   - "mau" samples on the left, categories in column order;
#   - a 0.05 gap;
#   - "sim_mau" samples on the right, categories in reverse order.
# A " *" suffix on the y label marks a binomial-GLM p-value < 0.001.
# An "_na" suffix marks a fit that raised.
saved_names = set()
for gene in dico_result:
    # gtf_dict may be None/empty if the annotation was not loaded.
    gene_info = gtf_dict.get(gene) if gtf_dict else None
    if gene_info is None:
        # Genes absent from the annotation fall back to the "PRY" label.
        print(f"{gene} not found in the GTF annotation; labelling it PRY")
        gene_name = "PRY"
    else:
        gene_name = gene_info["symbol"]
    if gene_name == "PRY_alt":
        gene_name = "PRY"

    fig, ax = plt.subplots()
    z = 0
    # NOTE(review): keys are sorted (descending) by the text before the first "_"
    # and the last one is dropped — confirm this is the intended exon selection.
    exon_order = list(sorted(dico_result[gene].keys(), key=lambda x: x.split('_')[0], reverse=True))[:-1]
    for i_exon, exon in enumerate(exon_order):
        control = dico_result[gene][exon]["mau"]
        treatment = dico_result[gene][exon]["sim_mau"]

        # Pool replicate count vectors, then convert to per-category fractions.
        pure = np.sum(np.array(control), axis=0)
        pure = pure / sum(pure)
        hyb = np.sum(np.array(treatment), axis=0)
        hyb = hyb / sum(hyb)

        # "<a>_<n>..." -> "<a> <n without leading zeros> ..."
        label_parts = exon.split('_')
        label_parts[1] = str(int(label_parts[1]))
        exon_order[i_exon] = " ".join(label_parts)
        try:
            pval, control_p, treat_p = gln_binomial_test(control, treatment)
            if pval < 0.001:
                exon_order[i_exon] += " *"
            print(i_exon, pval)
        except Exception as err:
            print(f"GLM failed for {gene} / {exon}: {err!r}", control, treatment)
            exon_order[i_exon] += "_na"

        bottom = 0
        for i, e in enumerate(pure):
            ax.barh(y=z, width=e, left=bottom, color=color[i])
            bottom += e
        bottom += 0.05
        # Mirrored bar: category k keeps colour k whatever the number of categories.
        last = len(hyb) - 1
        for i, e in enumerate(hyb[::-1]):
            ax.barh(y=z, width=e, left=bottom, color=color[last - i])
            bottom += e
        z += 1

    ax.set_yticks(range(0, z), labels=exon_order)
    handles, _ = ax.get_legend_handles_labels()
    handles.extend([
        Line2D([], [], label="Spliced", color=color[0], linewidth=6),
        Line2D([], [], label="Unspliced", color=color[1], linewidth=6),
        Line2D([], [], label="Clipped", color=color[2], linewidth=6),
        Line2D([], [], label="Exon_Intron", color=color[3], linewidth=6),
    ])
    fig.legend(handles=handles, loc='outside right upper', bbox_to_anchor=(1.25, 0.8))
    ax.set_xticks([])
    ax.set_title(gene_name, fontsize=24)
    ax.set_xlim(-0.1, 2.2)
    ax.tick_params(axis='y', labelsize=12)

    # Several genes can share a label (e.g. "PRY"); keep their PDFs from overwriting each other.
    file_stem = gene_name if gene_name not in saved_names else f"{gene_name}_{gene}"
    saved_names.add(gene_name)
    fig.savefig(out_dir / "{}.pdf".format(file_stem), bbox_inches="tight")
    plt.show()
    plt.close(fig)
    