In [24]:
from pathlib import Path
import numpy as np
import pysam
from tqdm import tqdm
from itertools import repeat
import polars as pl
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import plotly
import plotly.io as pio
import argparse
from random import seed,uniform


def _count_modification_evidence(bam_path: str, reference_sequence: str, max_reads=None, min_probability: float = 0.95, pseU_key: tuple = ("T", 0, 17802)) -> dict:
    """Accumulate per-reference-position counts from one BAM file.

    Parameters
    ----------
    bam_path : path of the BAM file to scan (read sequentially until EOF).
    reference_sequence : sequence the reads were aligned to; used both for the
        array length and for the reference base at each aligned position.
    max_reads : stop after this many usable (mapped, non-supplementary) reads;
        ``None`` scans the whole file.
    min_probability : a modification call is counted when
        ``(quality + 1) / 256 >= min_probability``.
    pseU_key : key looked up in ``AlignedSegment.modified_bases`` for the
        pseudouridine calls.

    Returns
    -------
    dict with integer numpy arrays of ``len(reference_sequence)``:
    ``reads_aligning`` (coverage), ``n_pseU`` (confident pseU calls),
    ``n_m6a`` (always zero, m6A counting is currently disabled) and ``n_C``
    (read base C aligned to reference base T).

    Raises
    ------
    ValueError if a read extends beyond ``reference_sequence``.
    """
    reference_length = len(reference_sequence)
    coverage = np.zeros(reference_length, dtype=np.int64)
    pseU_calls = np.zeros(reference_length, dtype=np.int64)
    # m6A counting is switched off; the array is kept so the output tables
    # retain their n_m6a / rel_n_m6a columns.
    m6a_calls = np.zeros(reference_length, dtype=np.int64)
    c_for_t_calls = np.zeros(reference_length, dtype=np.int64)

    reads_used = 0
    with pysam.AlignmentFile(bam_path, mode="rb") as bamfile:
        for read in tqdm(bamfile.fetch(until_eof=True)):
            if max_reads is not None and reads_used >= max_reads:
                break
            if read.is_supplementary:
                continue
            # until_eof=True also yields unmapped reads, which have no
            # reference span and would break the coverage update below.
            if read.is_unmapped or read.reference_end is None:
                continue
            if read.reference_end > reference_length:
                raise ValueError(
                    f"Read {read.query_name} ends at {read.reference_end}, beyond the "
                    f"reference length {reference_length}; BAM and FASTA do not match."
                )
            reads_used += 1
            coverage[read.reference_start:read.reference_end] += 1

            mod_obj = read.modified_bases
            # Positions from get_aligned_pairs() and modified_bases index into
            # query_sequence (alignment orientation), not the forward sequence.
            query_sequence = read.query_sequence
            # NOTE(review): as before, C-for-U mismatches are only counted for
            # reads that carry modification tags - confirm this is intended.
            if mod_obj is None or query_sequence is None:
                continue

            # matches_only drops insertions, deletions and soft clips; no MD
            # tag is needed because the reference base comes from the FASTA.
            query_to_reference = dict(read.get_aligned_pairs(matches_only=True))

            for query_index, quality in mod_obj.get(pseU_key, ()):
                if (quality + 1) / 256 >= min_probability:
                    reference_index = query_to_reference.get(query_index)
                    if reference_index is not None:
                        pseU_calls[reference_index] += 1

            for query_index, reference_index in query_to_reference.items():
                if query_sequence[query_index] == "C" and reference_sequence[reference_index] == "T":
                    c_for_t_calls[reference_index] += 1

    return {
        "reads_aligning": coverage,
        "n_pseU": pseU_calls,
        "n_m6a": m6a_calls,
        "n_C": c_for_t_calls,
    }


def _build_quantification_table(counts: dict) -> pd.DataFrame:
    """Turn the count arrays into the per-position table with absolute and relative values."""
    reads_aligning = counts["reads_aligning"].copy()
    # Avoid division by zero; uncovered positions are therefore reported with
    # a coverage of 1 in the table (unchanged from the previous output format).
    reads_aligning[reads_aligning == 0] = 1
    return pd.DataFrame(
        {
            "position": np.arange(1, len(reads_aligning) + 1),
            "reads_aligning": reads_aligning,
            "n_pseU": counts["n_pseU"],
            "n_m6a": counts["n_m6a"],
            "n_C": counts["n_C"],
            "rel_n_pseU": counts["n_pseU"] / reads_aligning,
            "rel_n_m6a": counts["n_m6a"] / reads_aligning,
            "rel_n_C": counts["n_C"] / reads_aligning,
        }
    )


def view_modifications(bamfile_path:str, IVT_path:str, reference_path:str, literature_mod_df_path:str, condition:str, max_IVT_reads:int=20000, min_probability:float=0.95):
    """Quantify and plot pseU calls and C-for-U mismatches of a sample against an IVT control.

    Parameters
    ----------
    bamfile_path : BAM file of the sample; every mapped, non-supplementary read is used.
    IVT_path : BAM file of the IVT control; at most ``max_IVT_reads`` reads are used.
    reference_path : FASTA file; only its first sequence is used as reference.
    literature_mod_df_path : tab-separated, header-less table with seven columns
        (reference, start, end, modification, A, B, C). Rows whose modification
        is ``psu`` or ``Um`` are drawn as vertical markers at ``end - 1``.
    condition : prefix for the sample output files.
    max_IVT_reads : cap on the number of IVT reads processed.
    min_probability : minimum call probability for a pseU call to be counted.

    Side effects
    ------------
    Writes ``{condition}_modification_quantification.csv`` and
    ``IVT_modification_quantification.csv`` (``;``-separated) plus four SVG
    figures (whole reference, 5.8S, 18S and 28S windows) into the working
    directory and shows each figure. The sample is plotted upwards, the IVT
    control mirrored downwards.
    """
    literature_mod_df = pd.read_csv(literature_mod_df_path, sep="\t", header=None, index_col=None)
    literature_mod_df.columns = ["reference","start","end","modification","A","B","C"]
    psu_mod_df = literature_mod_df.loc[literature_mod_df["modification"] == "psu"]
    Um_mod_df = literature_mod_df.loc[literature_mod_df["modification"] == "Um"]

    with pysam.FastaFile(reference_path) as fasta_file:
        reference = fasta_file.references[0]
        reference_sequence = str(fasta_file.fetch(reference))

    sample_counts = _count_modification_evidence(
        bamfile_path, reference_sequence, min_probability=min_probability
    )
    modification_df = _build_quantification_table(sample_counts)
    modification_df.to_csv(f"{condition}_modification_quantification.csv", sep=";", header=True, index=False)

    IVT_counts = _count_modification_evidence(
        IVT_path, reference_sequence, max_reads=max_IVT_reads, min_probability=min_probability
    )
    IVT_modification_df = _build_quantification_table(IVT_counts)
    IVT_modification_df.to_csv("IVT_modification_quantification.csv", sep=";", header=True, index=False)

    # The x axis of the figure is the 0-based reference index.
    x_positions = np.arange(len(reference_sequence))
    rel_pseU = modification_df["rel_n_pseU"].to_numpy()
    rel_C = modification_df["rel_n_C"].to_numpy()
    IVT_rel_pseU = IVT_modification_df["rel_n_pseU"].to_numpy()
    IVT_rel_C = IVT_modification_df["rel_n_C"].to_numpy()

    fig = go.Figure(layout=go.Layout(height=800))

    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=rel_pseU + rel_C,
            line_color="rgba(153,0,0,0.4)",
            showlegend=True,
            name="C/U + pseU freq.",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=rel_pseU,
            line_color="rgba(153,0,0,1)",
            showlegend=True,
            name="pseU freq.",
        )
    )
    # The IVT control is mirrored below the x axis.
    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=-IVT_rel_pseU - IVT_rel_C,
            line_color="rgba(0, 181, 204, 1)",
            showlegend=True,
            name="IVT 18S C/U + pseU freq.",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=-IVT_rel_pseU,
            line_color="rgba(30, 81, 123, 1)",
            showlegend=True,
            name="IVT 18S pseU freq.",
        )
    )
    # White baseline at y = 0 separating the sample from the control.
    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=np.zeros(len(reference_sequence)),
            name="",
            showlegend=False,
            line_color="white",
        )
    )

    psu_marker_color = "rgba(90,34,139,0.5)"
    Um_marker_color = "rgba(4,59,92,0.5)"
    for known_sites, marker_color in ((psu_mod_df, psu_marker_color), (Um_mod_df, Um_marker_color)):
        for end_position in known_sites["end"]:
            fig.add_shape(
                type="line",
                x0=end_position - 1,
                x1=end_position - 1,
                y0=0,
                y1=1,
                line=dict(color=marker_color, width=0.3, dash="dash"),
            )

    # Shapes get no legend entry, so empty traces stand in for them. The pseU
    # entry uses the same colour as the pseU markers drawn above.
    for legend_name, legend_color in (("known pseU", psu_marker_color), ("known Um", "rgba(4,59,92,0.7)")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color=legend_color, width=1, dash="dash"),
                showlegend=True,
                name=legend_name,
            )
        )

    fig.update_layout(
        title="Modification basecalling",
        xaxis=dict(title="Position on reference", gridcolor="white", tickformat="d"),
        yaxis=dict(title="Modification frequency", gridcolor="white"),
        plot_bgcolor="rgba(0,0,0,0)",
    )

    # One export for the whole reference, then one per x-axis window.
    zoom_regions = (
        ("general", None),
        ("5-8S", [6550, 6800]),
        ("18S", [3655, 5620]),
        ("28S", [7910, 13000]),
    )
    for region_label, x_range in zoom_regions:
        if x_range is not None:
            fig.update_layout(xaxis=dict(range=x_range))
        fig.write_image(f"{condition}_{region_label}_modification_ratio_with_Um.svg", format="svg")
        fig.show()
    

In [None]:
# Run the modification overview for the condition labelled "NP_Cytoplasm".
view_modifications(
    bamfile_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/20231114_RNA004_NP_Cyt/filtered_pod5/filtered_pod5_rebasecalled_psU_m6A_aligned.bam",
    IVT_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/IVT_18S/filtered_pod5/filtered_pod5_basecalled.bam",
    reference_path="/home/stefan/wf-nanoribolyzer/references/RNA45SN1.fasta",
    literature_mod_df_path="/home/stefan/wf-nanoribolyzer/references/rRNA_modifications_conv.bed",
    condition="NP_Cytoplasm",
)

In [None]:
# Run the modification overview for the condition labelled "NP_Nuc".
view_modifications(
    bamfile_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/20231114_RNA004_NP_Nuc/filtered_pod5/filtered_pod5_rebasecalled_psU_m6A_aligned.bam",
    IVT_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/IVT_18S/filtered_pod5/filtered_pod5_basecalled.bam",
    reference_path="/home/stefan/wf-nanoribolyzer/references/RNA45SN1.fasta",
    literature_mod_df_path="/home/stefan/wf-nanoribolyzer/references/rRNA_modifications_conv.bed",
    condition="NP_Nuc",
)

In [None]:
# Run the modification overview for the condition labelled "IVPA_Nuc".
view_modifications(
    bamfile_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/20231114_RNA004_IVPA_Nuc/filtered_pod5/filtered_pod5_rebasecalled_psU_m6A_aligned.bam",
    IVT_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/IVT_18S/filtered_pod5/filtered_pod5_basecalled.bam",
    reference_path="/home/stefan/wf-nanoribolyzer/references/RNA45SN1.fasta",
    literature_mod_df_path="/home/stefan/wf-nanoribolyzer/references/rRNA_modifications_conv.bed",
    condition="IVPA_Nuc",
)

In [None]:
# Run the modification overview for the condition labelled "IVPA_Cyt".
view_modifications(
    bamfile_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/20231114_RNA004_IVPA_Cyt/filtered_pod5/filtered_pod5_rebasecalled_psU_m6A_aligned.bam",
    IVT_path="/home/stefan/Synology/Data_nano_ribolyzer/directRNA_004/IVT_18S/filtered_pod5/filtered_pod5_basecalled.bam",
    reference_path="/home/stefan/wf-nanoribolyzer/references/RNA45SN1.fasta",
    literature_mod_df_path="/home/stefan/wf-nanoribolyzer/references/rRNA_modifications_conv.bed",
    condition="IVPA_Cyt",
)