# In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Path,PathPatch
from arnie.utils import *
import sys
from arnie.mfe import mfe
import seaborn as sns 

#Normalizing the shape data
def normalize_shape(shape_reacs):
    """Scale reactivities by the mean of the top decile of non-outlier values.

    NaN entries are ignored while the scale factor is computed and remain
    NaN in the result. Outliers are trimmed first: every value at or above
    a cutoff is discarded, where the cutoff is the larger of
    (1) the smallest value exceeding 1.5 x the inter-quartile range and
    (2) the value at the 95th-percentile position.
    The scale factor is the mean of the remaining values that lie above the
    90th-percentile position.

    Degenerate inputs are handled instead of crashing or returning NaN:
    if no value exceeds 1.5 x IQR nothing is trimmed, if trimming would
    discard everything the untrimmed values are used, and if no value lies
    strictly above the 90th-percentile value that value itself is included.

    Args:
        shape_reacs: array-like of numeric reactivities; may contain NaN.

    Returns:
        numpy.ndarray of floats: ``shape_reacs`` divided by the scale factor.

    Raises:
        ValueError: if ``shape_reacs`` contains no non-NaN value.
    """
    shape_reacs = np.asarray(shape_reacs, dtype=float)

    # Work on the sorted, NaN-free values only.
    sorted_shape = np.sort(shape_reacs[~np.isnan(shape_reacs)])
    n_values = len(sorted_shape)
    if n_values == 0:
        raise ValueError("shape_reacs contains no non-NaN values to normalize")

    # Filter 1: index of the first value strictly greater than 1.5 * IQR.
    # searchsorted(side='right') yields n_values when no such value exists,
    # which means "no outliers by this criterion".
    q1 = sorted_shape[int(0.25 * n_values)]
    q3 = sorted_shape[int(0.75 * n_values)]
    iq_range = abs(q3 - q1)
    filter1 = int(np.searchsorted(sorted_shape, 1.5 * iq_range, side='right'))

    # Filter 2: index of the 95% value.
    filter2 = int(0.95 * n_values)

    # Use the more permissive of the two filters and drop everything at or
    # above the cutoff value (strict '<' also removes ties with the cutoff).
    cutoff_index = max(filter1, filter2)
    trimmed = sorted_shape
    if cutoff_index < n_values:
        below_cutoff = sorted_shape[sorted_shape < sorted_shape[cutoff_index]]
        # Constant data would otherwise be trimmed away entirely.
        if len(below_cutoff) > 0:
            trimmed = below_cutoff

    # Scale factor: mean of the values above the 90th-percentile position.
    top90 = trimmed[int(0.9 * len(trimmed))]
    top_values = trimmed[trimmed > top90]
    if len(top_values) == 0:
        # Short inputs or ties at the top: include the percentile value
        # itself rather than averaging an empty selection (which gives NaN).
        top_values = trimmed[trimmed >= top90]
    scalefactor = np.mean(top_values)

    return shape_reacs / scalefactor

# Input: path to a text file of shape values (one value per line); output: the normalized values as a list
def get_normalized_shape_data(shape_filename):
    """Read one reactivity value per line from a text file and normalize them.

    Blank (or whitespace-only) lines are skipped. The missing-data sentinel
    -999 and any NaN spelling accepted by ``float()`` become ``np.nan``.
    Surrounding whitespace, including a trailing carriage return, is ignored.

    Args:
        shape_filename: path of the text file to read.

    Returns:
        list of float: the values after ``normalize_shape``, with NaN where
        the file had missing data.

    Raises:
        ValueError: if a non-blank line cannot be parsed as a float.
    """
    shape_reacs = []
    with open("{}".format(shape_filename), "r") as shape_file:
        for line in shape_file:
            token = line.strip()
            if not token:
                # Skip blank lines without mutating anything mid-iteration,
                # so the value on the following line is never lost.
                continue
            value = float(token)
            # -999 marks "no data"; compare numerically so that variants
            # such as "-999.0" are recognised too.
            shape_reacs.append(np.nan if value == -999 else value)

    # normalize shape data
    return normalize_shape(shape_reacs).tolist()

def draw_contact(i, j, ax,ystart,size_multiple):
    """Add an unfilled arc joining x positions ``i`` and ``j`` to ``ax``.

    The arc is a single cubic curve that starts at ``(i, ystart)`` and ends
    at ``(j, ystart)``; its two control points sit directly above/below the
    end points at ``ystart + size_multiple * (j - i)``, so the arc's extent
    in y grows with the distance between the two positions.

    Args:
        i: x coordinate of the first end point.
        j: x coordinate of the second end point.
        ax: object providing ``add_patch`` (a matplotlib axes).
        ystart: y coordinate shared by both end points.
        size_multiple: factor applied to ``j - i`` to offset the control points.
    """
    control_y = ystart + size_multiple * (j - i)

    # End points first and last, the two control points in between.
    curve_points = [(i, ystart), (i, control_y), (j, control_y), (j, ystart)]
    # One MOVETO to start the path, then the three vertices of one CURVE4.
    curve_codes = [Path.MOVETO] + [Path.CURVE4] * 3

    arc = PathPatch(Path(curve_points, curve_codes), facecolor='none', lw=0.5)
    ax.add_patch(arc)

def get_shape_graphs_for_pks(pk_csv, path_to_shape_data, list_of_shape_sets, output_folder):
    """Save one reactivity heatmap with base-pair arcs per pseudoknot window.

    For every row of ``pk_csv`` a figure is written to
    ``output_folder/<start>_<program>.png``. Each figure shows one heatmap
    row per shape data set for the window beginning at ``start``, the
    sequence as x tick labels, and arcs for the base pairs of the predicted
    structure drawn in blank padding rows beneath the heatmap. All figures
    share one colour scale spanning the smallest and largest non-NaN value
    found in any window of any data set.

    Args:
        pk_csv: path of a CSV file with the columns ``start``, ``sequence``,
            ``structure`` and ``program``. ``start`` is used directly as a
            slice index into the normalized reactivity list.
        path_to_shape_data: directory containing ``<name>.csv`` for every
            name in ``list_of_shape_sets``.
        list_of_shape_sets: names of the shape data sets; also used as the
            y tick labels.
        output_folder: directory the PNG files are written to.

    Returns:
        None. Nothing is plotted when the CSV has no rows or no shape data
        sets were given.
    """
    # Function-scope import: only needed to copy the colormap below.
    import copy

    df = pd.read_csv(pk_csv)

    all_shape_sets = []
    for name in list_of_shape_sets:
        shape_data = get_normalized_shape_data(path_to_shape_data + '/' + name + '.csv')
        all_shape_sets.append(shape_data)

        # Sanity checks: report values that look like unconverted sentinels.
        for idx, value in enumerate(shape_data):
            if isinstance(value, str):
                print('error found a string in shape data')
            elif value == -999:
                print('error found -999')
            elif value < (-600):
                print('error found < -600 ' + str(value) + ' at location ' + str(idx))

    starts = df['start'].to_list()
    seqs = df['sequence'].to_list()
    structs = df['structure'].to_list()
    pk_predictor = df['program'].to_list()

    if not starts or not all_shape_sets:
        return

    # all_pk_shape_data[track][pk] is the reactivity window of one pseudoknot
    # in one data set. The window is as long as the pseudoknot's sequence so
    # that every heatmap column lines up with one x tick label.
    all_pk_shape_data = [
        [track[start:start + len(seq)] for start, seq in zip(starts, seqs)]
        for track in all_shape_sets
    ]

    # Shared colour limits: extremes over every window of every track.
    # nanmin/nanmax ignore missing data; the built-in min/max give
    # order-dependent results when NaN is present.
    pooled = np.concatenate([
        np.asarray(window, dtype=float)
        for track_windows in all_pk_shape_data
        for window in track_windows
    ])
    x_min = np.nanmin(pooled)
    x_max = np.nanmax(pooled)

    num_white_space = 25 # increase if need more space at bottom for arc
    figsize_x = 20
    figsize_y = 10
    arc_offset = 0 # increase if arc needs to move more down
    arc_height = 0.3 # decrease if want arc shorter

    # Copy the colormap before customising it so the shared registered
    # instance is never modified; NaN cells are drawn in gray.
    current_cmap = copy.copy(plt.get_cmap('gist_heat_r'))
    current_cmap.set_bad(color='gray')

    for idx, (seq, struct) in enumerate(zip(seqs, structs)):
        # One heatmap row per data set for this pseudoknot window.
        reacts = [track_windows[idx] for track_windows in all_pk_shape_data]
        react_labels = list(list_of_shape_sets)

        # Blank rows below the data leave room for the base-pair arcs.
        for _ in range(num_white_space):
            reacts.append(np.zeros(len(seq)))
            react_labels.append("")

        # NOTE: figures are not closed here, so they stay available for
        # interactive/inline display; call plt.close('all') afterwards when
        # generating many plots.
        plt.figure(figsize=(figsize_x, figsize_y))
        plt.imshow(reacts, cmap=current_cmap, aspect='auto')
        plt.yticks(list(range(len(reacts))), react_labels)
        plt.xticks(list(range(len(seq))), list(seq))
        plt.clim(x_min, x_max)

        ax = plt.gca()
        # Move the x axis line up to the boundary between data and padding.
        ax.spines['bottom'].set_position(('data', len(reacts) - 0.5 - num_white_space))

        bp_list = convert_dotbracket_to_bp_list(struct, allow_pseudoknots=True)
        for i, j in bp_list:
            draw_contact(i, j, ax, len(reacts) - num_white_space - arc_offset, arc_height)

        plt.colorbar(orientation='horizontal', label='reactivity')
        plt.savefig(output_folder+'/'+str(starts[idx])+'_'+pk_predictor[idx]+'.png', dpi=150, bbox_inches='tight')