In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import gpplot as gpp
import csv, requests, warnings, os, matplotlib
# Silence every warning for the whole session.
# NOTE(review): this also hides pandas/matplotlib deprecation and SettingWithCopy warnings.
warnings.filterwarnings("ignore")
# Embed fonts in saved PDFs as TrueType (Type 42) instead of Type 3.
matplotlib.rcParams["pdf.fonttype"] = 42

gpp.set_aesthetics(palette='Set2')

In [1]:
# Finds max length of list of lists
def FindMaxLength(lst):
    """Return the length of the longest element of ``lst``.

    Raises ValueError when ``lst`` is empty, exactly like ``max`` on an empty iterable.
    """
    return max(len(item) for item in lst)

# Returns reverse complement of a sequence
def revcom(s):
    """Return the reverse complement of nucleotide string ``s``.

    Supports A/C/G/T, the IUPAC codes N, K, M, R, Y, S, W, B, V, H, D and the
    alignment gap '-'. Any other character (including lowercase bases) raises KeyError.
    """
    complement = {
        'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C',
        'N': 'N', 'K': 'M', 'M': 'K', 'R': 'Y', 'Y': 'R',
        'S': 'S', 'W': 'W', 'B': 'V', 'V': 'B', 'H': 'D', 'D': 'H',
        '-': '-',
    }
    return ''.join(complement[base] for base in reversed(s))

#Codon to amino acid letter code
def get_codon_map():
    """Return the standard genetic code as a dict {codon: one-letter amino acid}.

    Stop codons (TAA, TAG, TGA) map to '*'. The table below is laid out with the
    first base varying slowest and the third fastest (each in T, C, A, G order).
    Entries are inserted grouped by second base, then first, then third base.
    """
    bases = 'TCAG'
    table = ('FFLLSSSSYY**CC*W'   # first base T
             'LLLLPPPPHHQQRRRR'   # first base C
             'IIIMTTTTNNKKSSRR'   # first base A
             'VVVVAAAADDEEGGGG')  # first base G
    codon_map = {}
    for second_idx, second in enumerate(bases):
        for first_idx, first in enumerate(bases):
            for third_idx, third in enumerate(bases):
                codon_map[first + second + third] = table[16 * first_idx + 4 * second_idx + third_idx]
    return codon_map

# Get information about all plates associated with specified sgrna.
# Input: sgRNA ID, like sg16
def get_plate_info(sg):
    """Return the rows of the plate/guide sheet whose ``sgrna_name`` equals ``sg``.

    Reads ``metafiles/Plate_guide_info_v1.xlsx`` on every call. The sheet's original
    row index is preserved in the returned frame; no match gives an empty frame.
    """
    plate_sheet = pd.read_excel('metafiles/Plate_guide_info_v1.xlsx')
    is_requested_guide = plate_sheet['sgrna_name'] == sg
    return plate_sheet.loc[is_requested_guide]
    
# Get sgrna frame, reverse complement status and relevant condition IDs
# Input: sgRNA ID, like sg16
def get_sgrna_details(sgrna):
    """Look up one sgRNA in ``metafiles/BEV_allele_freq_input_v1.csv``.

    Returns a tuple ``(frame, reverse_complement_status, relevant_condition)`` taken
    positionally from the first row whose ``sg`` column equals ``sgrna``: column 5,
    column 6 and the last column respectively.
    Raises IndexError when the sgRNA is not in the file.
    """
    guide_table = pd.read_csv('metafiles/BEV_allele_freq_input_v1.csv')
    guide_row = guide_table.loc[guide_table['sg'] == sgrna].values[0]
    frame, rev_status, rel_cond = guide_row[5], guide_row[6], guide_row[-1]
    return frame, rev_status, rel_cond

# Gets wells that pass number of aligned reads in a plate
# Input: Plate number, folder with batch files
def get_pass_wells(plate, folder, min_reads=10000):
    """Return the wells of one CRISPRessoBatch plate that have enough aligned reads.

    Parameters
    ----------
    plate : str or int
        Plate identifier, appended to ``'CRISPRessoBatch_on_CRISPResso_batch_file_'``
        to build the plate's output directory name.
    folder : str
        Directory holding the per-plate CRISPRessoBatch outputs (expected to end in '/').
    min_reads : int, default 10000
        Minimum ``READS ALIGNED`` value for a well to pass.

    Returns
    -------
    pass_wells : DataFrame
        Rows of the plate's mapping-statistics table with ``READS ALIGNED >= min_reads``.
    files : list of str
        Sorted listing of the plate folder (used later to locate per-well folders).
    plate_folder : str
        Path of the plate folder, with a trailing '/'.

    Raises
    ------
    FileNotFoundError
        If the plate folder contains no file with 'mapping_statistics' in its name.
    """
    plate_folder = folder + 'CRISPRessoBatch_on_CRISPResso_batch_file_' + str(plate) + '/'
    # Sorted so the file picked below does not depend on os.listdir's arbitrary order.
    files = sorted(os.listdir(plate_folder))
    map_stats_files = [x for x in files if 'mapping_statistics' in x]
    if not map_stats_files:
        raise FileNotFoundError('No mapping_statistics file found in ' + plate_folder)
    map_stats_df = pd.read_csv(os.path.join(plate_folder, map_stats_files[0]), sep='\t')
    pass_wells = map_stats_df[map_stats_df['READS ALIGNED'] >= min_reads]
    return pass_wells, files, plate_folder

# Get the validation (BEV) allele table for one sgRNA
# Input: sgRNA ID, relevant condition
def get_sg_bev_file(s, rel_cond):
    """Load the BEV validation allele table for sgRNA ``s``.

    Reads the first file (by sorted name) in ``BEV_sg_files/`` whose name contains ``s``.
    The returned table is used downstream for its Aligned_Sequence, Translated,
    Category and Colors columns; this function itself does not assign colours.

    NOTE(review): the match is a substring test, so an ID like 'sg1' would also match
    files for 'sg14'...; the IDs used in this notebook (sg14-sg17) do not collide.

    Parameters
    ----------
    s : str
        sgRNA ID, e.g. 'sg16'.
    rel_cond :
        Unused; kept so existing callers keep working.

    Raises
    ------
    FileNotFoundError
        If no file in ``BEV_sg_files/`` contains ``s`` in its name.
    """
    matching_files = sorted(x for x in os.listdir('BEV_sg_files/') if s in x)
    if not matching_files:
        raise FileNotFoundError('No BEV file for ' + s + ' in BEV_sg_files/')
    return pd.read_csv('BEV_sg_files/' + matching_files[0])

# Get alleles per well with %Reads >= 10
# Input: row of passed wells, batch files, CRISPResso output folder for specific plate
def get_alleles(p, files, plate_folder, min_pct=10.0):
    """Load one well's 'around sgRNA' allele table and keep its abundant alleles.

    Parameters
    ----------
    p : row (Series) with a ``Batch`` field
        A row of the plate's mapping-statistics table; ``'CRISPResso_on_' + Batch``
        names the well's output folder.
    files : list of str
        Listing of ``plate_folder``.
    plate_folder : str
        Plate output directory.
    min_pct : float, default 10.0
        Minimum ``%Reads`` for an allele to be kept.

    Returns
    -------
    allele_table_fil : DataFrame
        Alleles with ``%Reads >= min_pct``, sorted by ``%Reads`` descending (may be empty).
    well : str
        Third '_'-separated field of the well folder name (e.g. 'A1' for 'CRISPResso_on_A1').
    """
    prefix = 'CRISPResso_on_' + str(p.Batch)
    candidates = [x for x in files if prefix in x and '.html' not in x]
    # Prefer the exact folder name: a substring test alone lets 'CRISPResso_on_A1'
    # also match 'CRISPResso_on_A10', 'CRISPResso_on_A11', ...
    exact = [x for x in candidates if x == prefix]
    w_folder = (exact or candidates)[0]
    well_dir = os.path.join(plate_folder, w_folder)
    well = w_folder.split('_')[2]
    allele_table_file = [
        x for x in sorted(os.listdir(well_dir))
        if 'Alleles_frequency_table_around_sgRNA' in x and '.png' not in x and '.pdf' not in x
    ][0]
    allele_table = pd.read_csv(os.path.join(well_dir, allele_table_file), sep='\t')
    allele_table = allele_table.sort_values('%Reads', ascending=False)
    allele_table_fil = allele_table.loc[allele_table['%Reads'] >= min_pct]
    return allele_table_fil, well

# Aligns WT translated sequence to allele translated sequence
def align(wt_trans, al_trans):
    """Mask residues of ``al_trans`` that match ``wt_trans`` at the same position.

    Returns a string as long as ``al_trans``. Each residue equal to the WT residue at the
    same index becomes '-', and differing residues are kept. Raises IndexError if
    ``al_trans`` is longer than ``wt_trans``.
    """
    return ''.join(
        '-' if residue == wt_trans[pos] else residue
        for pos, residue in enumerate(al_trans)
    )

# Get colors and categories for all alleles
def get_colors(allele_table, sg_bev, well, sg_frame, sg_rev_status, w, sg, plate_no, condition):
    """Classify each allele of one well and write it, annotated, through ``w``.

    Every output row is the allele's original values followed by: reference
    translation, allele translation, category, colour, well, plate and condition.

    Classification, in order:
      1. A unique exact match of the (oriented) allele in ``sg_bev.Aligned_Sequence``
         gives that row's Category/Colors.
      2. Otherwise, an allele with insertions or deletions (``n_deleted``/``n_inserted``
         non-zero) is 'Indel' ('#ffd92f').
      3. Otherwise ``classify_allele`` decides, and a unique match of its translation in
         ``sg_bev.Translated`` overrides the category/colour.
    For cases 1 and 3, the allele translation is masked with ``align`` ('-' where it
    equals WT) when it is the same length as the reference translation. Indel
    translations are written unmasked.

    Parameters
    ----------
    allele_table : DataFrame
        Allele rows for one well (as returned by ``get_alleles``); may be empty.
    sg_bev : DataFrame
        BEV table with Aligned_Sequence, Translated, Category and Colors columns.
    well, plate_no, condition
        Copied verbatim into every output row.
    sg_frame : int
        Reading frame passed to ``translate`` (reading starts at index ``sg_frame - 1``).
    sg_rev_status : bool
        When True, sequences are reverse-complemented before lookup and translation.
    w : object with ``writerow``
        Destination writer (a ``csv.writer`` in this notebook); one call per allele.
    sg :
        Unused; kept for interface compatibility.
    """
    # No allele passed the %Reads filter for this well, so there is no reference row to read.
    if allele_table.empty:
        return
    wt_seq = allele_table.loc[:, 'Reference_Sequence'].values[0]
    if sg_rev_status == True:
        wt_seq = revcom(wt_seq)
    for _, r in allele_table.iterrows():
        seq = r.Aligned_Sequence
        if sg_rev_status == True:
            seq = revcom(r.Aligned_Sequence)
        matched_row = sg_bev[sg_bev.Aligned_Sequence == seq]
        if len(matched_row) == 1:
            cat = matched_row.Category.values[0]
            color = matched_row.Colors.values[0]
            wt_trans = translate(wt_seq, sg_frame)
            al_trans = translate(seq, sg_frame)
            mask_matches = True
        elif r.n_deleted != 0 or r.n_inserted != 0:
            if sg_rev_status == True:
                indel_wt_seq = revcom(r.Reference_Sequence)
                indel_al_seq = revcom(r.Aligned_Sequence)
            else:
                indel_wt_seq = r.Reference_Sequence
                indel_al_seq = r.Aligned_Sequence
            wt_trans = translate(indel_wt_seq, sg_frame)
            al_trans = translate(indel_al_seq, sg_frame)
            cat = 'Indel'
            color = '#ffd92f'
            mask_matches = False
        else:
            wt_trans, al_trans, cat, color = classify_allele(r.Aligned_Sequence, r.Reference_Sequence, sg_frame, sg_rev_status)
            matched_row_trans = sg_bev[sg_bev.Translated == al_trans]
            if len(matched_row_trans) == 1:
                cat = matched_row_trans.Category.values[0]
                color = matched_row_trans.Colors.values[0]
            mask_matches = True
        if mask_matches and len(wt_trans) == len(al_trans):
            al_trans = align(wt_trans, al_trans)
        row = list(r)
        row.extend([wt_trans, al_trans, cat, color, well, plate_no, condition])
        w.writerow(row)
    return

#Classifies "Other" alleles
def classify_allele(aligned_seq, ref_seq, sg_frame, sg_rev_status):
    """Translate an allele and its reference, and flag synonymous ('Silent') alleles.

    Both sequences are reverse-complemented first when ``sg_rev_status`` is True, and
    then translated with ``translate(..., sg_frame)``.

    Returns
    -------
    (wt_trans_seq, trans_seq, cat, color)
        The reference and allele protein translations. ``cat``/``color`` are
        ('Silent', '#8da0cb') when the two translations are identical, otherwise ('', '').
    """
    # NOTE(review): 'Nonsense' has a colour here but is never assigned by this function.
    colors = {'Silent': '#8da0cb',
              'Nonsense': '#ffd92f'}
    wt_allele = ref_seq
    if sg_rev_status == True:
        wt_allele = revcom(wt_allele)
    wt_trans_seq = translate(wt_allele, sg_frame)
    cat = ''
    color = ''
    allele = aligned_seq
    if sg_rev_status == True:
        allele = revcom(allele)
    trans_seq = translate(allele, sg_frame)
    # Compare protein to protein (not to the nucleotide reference) so that
    # synonymous nucleotide changes are recognised as Silent.
    if trans_seq == wt_trans_seq:
        cat = 'Silent'
        color = colors[cat]
    return wt_trans_seq, trans_seq, cat, color


#Get colors and categories for all control alleles
def get_control_colors(allele_table, well, w, plate_no, cond):
    """Label control-well alleles 'WT' or 'Other' and write them through ``w``.

    An allele is 'WT' (colour '#e78ac3') when its Aligned_Sequence equals the
    Reference_Sequence of the first row of ``allele_table``. Every other allele is
    'Other' (colour '#b3b3b3'). Each output row is the allele's original values,
    followed by two empty translation fields, category, colour, well, plate and
    condition. An empty ``allele_table`` writes nothing.
    """
    # Nothing passed the %Reads filter, so there is no reference row to read.
    if allele_table.empty:
        return
    wt_seq = allele_table.loc[:, 'Reference_Sequence'].values[0]
    for _, r in allele_table.iterrows():
        if r.Aligned_Sequence == wt_seq:
            cat, color = 'WT', '#e78ac3'
        else:
            cat, color = 'Other', '#b3b3b3'
        row = list(r)
        row.extend(['', '', cat, color, well, plate_no, cond])
        w.writerow(row)
    return

# This function returns the translation of a given sequence and frame
def translate(seq, frame):
    """Translate nucleotide string ``seq`` into a one-letter amino-acid string.

    Reading starts at index ``frame - 1``. Gap characters ('-') are skipped, so each
    codon is built from the next three non-gap characters. A codon containing 'N'
    becomes '-', and a trailing partial codon is dropped. Every other codon is looked up
    in ``get_codon_map()``; a codon missing from the map raises KeyError.
    """
    codon_map = get_codon_map()
    residues = []
    codon = ''
    pos = frame - 1
    while pos < len(seq):
        base = seq[pos]
        pos += 1
        if base == '-':
            continue
        codon += base
        if len(codon) < 3:
            continue
        residues.append('-' if 'N' in codon else codon_map[codon])
        codon = ''
    return ''.join(residues)

#Transposes a dataframe and reassigns index and columns
def transpose_df(df):
    """Transpose ``df`` and promote its first column's values to the column labels.

    The first row of the transposed frame (the original first column) becomes the new
    header, and the remaining rows (the other original columns) form the body.
    """
    transposed = df.T
    new_header = transposed.iloc[0]
    result = transposed.iloc[1:]
    result.columns = new_header
    return result

#Get colors for specific categories
def get_color_hash(color_df, cat_df):
    """Map each category to the colour at its first occurrence.

    Scans every column of ``cat_df`` except the first (the 'Well' label column), top to
    bottom. Each category is paired with the value at the same position in the
    same-named column of ``color_df``, and earlier occurrences win.
    """
    first_seen = {}
    for column in list(cat_df.columns)[1:]:
        category_values = list(cat_df[column])
        colour_values = list(color_df[column])
        for idx, category in enumerate(category_values):
            if category in first_seen:
                continue
            first_seen[category] = colour_values[idx]
    return first_seen

#Generate legend for bar plots
def get_legend(color_df, cat_df):
    """Build legend handles for the stacked bar plot, one per distinct colour.

    Categories are mapped to colours with ``get_color_hash``, and the padding
    category '' is dropped. Categories that share a colour are merged into one
    label joined with '/', in first-seen order.

    Returns a list of ``matplotlib.patches.Patch``.
    """
    color_hash = get_color_hash(color_df, cat_df)
    color_hash.pop('', None)
    label_for_color = {}
    for category, colour in color_hash.items():
        if colour in label_for_color:
            label_for_color[colour] = label_for_color[colour] + '/' + category
        else:
            label_for_color[colour] = category
    return [mpatches.Patch(color=colour, label=label) for colour, label in label_for_color.items()]

In [None]:
#Generates aggregate alleles file for each sgRNA
sg = ['control','sg14','sg15','sg16','sg17']
#Specify output folder path containing CRISPResso outputs
outputfolder = 'CRISPResso_plate_outputs/'
for s in sg:
    print(s)
    plate_info = get_plate_info(s)
    if s != 'control':
        sg_frame, sg_rev_status, sg_rel_cond = get_sgrna_details(s)
        sg_bev = get_sg_bev_file(s, sg_rel_cond)
    outputfile = s + '_agg_file.txt'
    # newline='' is what the csv module requires; without it, rows end in '\r\r\n' on Windows.
    with open(outputfile, 'w', newline='') as o:
        w = csv.writer(o, delimiter='\t')
        w.writerow(['Aligned_Sequence','Reference_Sequence','Unedited','n_deleted','n_inserted','n_mutated','#Reads','%Reads','Reference_translation','Aligned_translation','Category','Colors','Well','Plate','Condition'])
        for _, p in plate_info.iterrows():
            pass_wells, files, plate_folder = get_pass_wells(p['Plate #'], outputfolder)
            for _, pw in pass_wells.iterrows():
                allele_table_fil, well = get_alleles(pw, files, plate_folder)
                # No allele of this well reached the %Reads cutoff, so there is nothing to write.
                if allele_table_fil.empty:
                    continue
                if s == 'control':
                    get_control_colors(allele_table_fil, well, w, p['Plate #'], p.condition)
                else:
                    get_colors(allele_table_fil, sg_bev, well, sg_frame, sg_rev_status, w, s, p['Plate #'], p.condition)


In [2]:
# Generates stacked bar plots for treated and untreated samples of each sgRNA
sg = ['control','sg14','sg15','sg16','sg17']


def _allele_matrices(cond_df):
    """Return (reads, colours, categories) tables for the wells of one condition.

    Each table has a 'Well' column holding the Plate_Well label, plus columns
    Allele_1..Allele_k holding each well's %Reads / Colors / Category values in file
    order. Wells with fewer alleles are padded with 0, '#FFFFFF' and '' respectively.
    Rows are sorted by label.
    """
    per_reads, colors, cate = [], [], []
    for label in set(cond_df.Labels):
        w_df = cond_df[cond_df.Labels == label]
        per_reads.append([label] + list(w_df['%Reads']))
        colors.append([label] + list(w_df.Colors))
        cate.append([label] + list(w_df.Category))
    max_len = FindMaxLength(per_reads)
    cols = ['Well'] + ['Allele_' + str(x) for x in range(1, max_len)]

    def _to_frame(rows, fill):
        # Pad to a common width and build the frame once (no row-by-row .loc growth).
        padded = [row + [fill] * (max_len - len(row)) for row in rows]
        return pd.DataFrame(padded, columns=cols).sort_values('Well', ascending=True)

    return _to_frame(per_reads, 0), _to_frame(colors, '#FFFFFF'), _to_frame(cate, '')


def _plot_stacked_bar(bar_df, bar_color_df, bar_category, title, outpath):
    """Draw one stacked bar per well (allele layers bottom-up) and save it to ``outpath``."""
    data = np.array(bar_df.iloc[:, 1:].T, dtype=float)  # rows = allele layers, columns = wells
    color_data = np.array(bar_color_df.iloc[:, 1:].T)
    X = np.arange(data.shape[1])
    fig, ax = plt.subplots(figsize=(3, 2))
    for axis in ['bottom', 'left']:
        ax.spines[axis].set_linewidth(0.5)
    bottom = np.zeros(data.shape[1])
    for layer, layer_colors in zip(data, color_data):
        ax.bar(X, layer, bottom=bottom, color=layer_colors, width=1.0, edgecolor='white')
        bottom = bottom + layer
    # Labels are 'Plate_Well'; keep only the well part, even if the plate ID contains '_'.
    ticklabels = [x.rsplit('_', 1)[-1] for x in bar_color_df.Well]
    ax.set_xticks(X)
    ax.set_xticklabels(ticklabels, rotation=90, fontsize=5)
    ax.tick_params(axis='y', labelsize=6)
    ax.tick_params(axis='both', width=0.5, length=4)
    ax.legend(handles=get_legend(bar_color_df, bar_category), bbox_to_anchor=(1.05, 1), fontsize=6)
    ax.set_ylabel('% of reads', fontsize=7)
    ax.set_title(title, fontsize=7)
    sns.despine(ax=ax)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
    fig.savefig(outpath, dpi=1000, bbox_inches='tight')
    # Render inline and release the figure so figures do not pile up across the loop.
    plt.show()


for s in sg:
    print(s)
    agg_df = pd.read_csv(s + '_agg_file.txt', sep='\t')
    # astype(str): a numeric Plate/Well column would otherwise break the concatenation.
    agg_df['Labels'] = agg_df.Plate.astype(str) + '_' + agg_df.Well.astype(str)
    # NOTE(review): empty Category/Colors fields in the agg file are read back as NaN.
    for co in set(agg_df.Condition):
        cond_df = agg_df[agg_df.Condition == co]
        bar_df, bar_color_df, bar_category = _allele_matrices(cond_df)
        _plot_stacked_bar(
            bar_df, bar_color_df, bar_category,
            title=s + '_' + co,
            outpath='Figures/' + s + '/' + s + '_' + co + '_stacked_bar_plot_v1.pdf',
        )