In [1]:
import pandas as pd
import numpy as np
import pysam
from textwrap import wrap

In [2]:
#when True, extract_utr_seq wraps every oligo sequence in primer_5_end/primer_3_end before it is written
add_seq_primers = False

In [3]:
#root directory of the project data; every path below is built as data_dir + '<relative path>',
#so the trailing slash is required
#NOTE(review): absolute, cluster-specific path - adjust when running elsewhere
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/mpra/griesemer_2021/'

In [4]:
#FASTA handle shared by the whole notebook; extract_utr_seq fetches one record per oligo from it, keyed by UTR_ID
human_fasta = pysam.FastaFile(data_dir + '../../fasta/Homo_sapiens_dna_fwd.fa')

In [5]:
#tab-separated variant table; its var_start/oligo_start/human_UTR_start columns are used below
#to locate each variant inside its oligo and each oligo inside its UTR record
utr_variants = pd.read_csv(data_dir + 'preprocessing/GRCh38_UTR_variants.tsv', sep='\t') #all positions are 0-based [start, end)

In [6]:
#seeding multiple variants into oligo sequence isn't currently supported,
#so keep only the rows without another variant in the oligo window;
#.copy() makes the result an independent frame, so that the column assignment
#and sorting done in later cells do not operate on a slice of the original table
utr_variants = utr_variants[utr_variants.other_var_in_oligo_window.isna()].copy()

In [7]:
#number of rows left in utr_variants after the filter above
len(utr_variants)

18376

In [8]:
#classify each variant by comparing allele lengths:
#equal lengths -> 'SNP', alt shorter than ref -> 'DEL', alt longer than ref -> 'INS'
#(a single pass over the two columns instead of a row-wise DataFrame.apply; also works on an empty table)
utr_variants['vartype'] = ['SNP' if len(ref)==len(alt) else
                           'DEL' if len(ref)>len(alt) else 'INS'
                           for ref, alt in zip(utr_variants['ref'], utr_variants['alt'])]

In [9]:
#order the oligos by oligo_id; rebinding the name instead of sorting in place keeps the cell
#idempotent and avoids an in-place operation on a frame that was produced by filtering
utr_variants = utr_variants.sort_values(by='oligo_id')

In [16]:
def reverse_complement(seq):
    '''
    Return the reverse complement of a nucleotide sequence

    Upper-case A, C, G and T are replaced by their complementary base,
    any other symbol is kept as it is. The complemented sequence is then
    reversed and returned as a string.
    '''
    pairing = dict(zip('ACGT', 'TGCA'))
    complemented = ''.join(map(lambda base: pairing.get(base, base), seq))
    return complemented[::-1]

In [17]:
#oligo primer sequences; extract_utr_seq puts them around the oligo sequence only when add_seq_primers is True
#(original note: don't think we need to include them)

primer_5_end = 'CGAGCTCGCTAGCCT'
primer_3_end = 'AGATCGGAAGAGCGTCG'

In [18]:
def insert_variant(seq, oligo):
    '''
    Seed the alternative allele of a variant into an oligo sequence

    seq: list of sequence elements (one nucleotide per element), in oligo
         coordinates; it is modified in place
    oligo: record with the attributes var_start, oligo_start, ref and alt

    The whole reference allele is replaced by the alternative allele, so the
    same code handles substitutions, insertions and deletions with alleles of
    any length. The alternative allele is stored as a single list element and
    the remaining reference positions are blanked with '', i.e. the list keeps
    its length and all other positions keep their oligo coordinates.

    Returns the modified list (the same object as seq).
    Raises IndexError if the reference allele does not lie inside seq.
    '''
    
    varpos = int(oligo.var_start - oligo.oligo_start) #variant position w.r.t. oligo coordinates
    ref_len = len(oligo.ref)
    
    #a negative index would silently wrap around and a slice past the end would extend the list
    if varpos < 0 or varpos + ref_len > len(seq):
        raise IndexError('variant lies outside of the oligo sequence')
    
    #first reference position receives the full alt allele, the others become empty placeholders
    seq[varpos:varpos+ref_len] = [oligo.alt] + ['']*(ref_len-1)
    
    return seq
        
def check_ref(seq, oligo):
    '''
    Check that the reference allele of a variant matches the oligo sequence

    seq: oligo sequence (list of nucleotides or string), in oligo coordinates
    oligo: record with the attributes var_start, oligo_start and ref

    The full reference allele is compared for every variant type, so
    multi-base reference alleles are supported for substitutions and
    insertions as well as for deletions.

    Returns True if seq carries oligo.ref at the variant position, False on a
    mismatch or when the reference allele does not lie inside seq.
    '''

    varpos = int(oligo.var_start - oligo.oligo_start) #variant position w.r.t. oligo coordinates
    ref_end = varpos + len(oligo.ref)

    #a negative index would wrap around to the end of the sequence, treat it as a mismatch
    if varpos < 0 or ref_end > len(seq):
        return False

    #detect reference mismatches
    return ''.join(seq[varpos:ref_end]) == oligo.ref

In [22]:
def extract_utr_seq(oligo, complement_negative=False):
    '''
    Build the FASTA header and sequence of a single oligo

    oligo: row of utr_variants; the attributes mpra_variant_id, UTR_ID,
           oligo_start, oligo_end, human_UTR_start, tag, strand and oligo_id
           are read here, check_ref/insert_variant read the variant fields
    complement_negative: if True, oligos with strand '-' are reverse complemented

    The record oligo.UTR_ID is fetched from human_fasta, upper-cased and cut
    to the oligo window. For oligos with tag 'alt' the variant is seeded
    into the sequence. Primers are added when add_seq_primers is set.

    Returns (seq_header, seq), or (None, None) when the oligo is skipped:
    mpra_variant_id ends with '_2', the oligo window starts before the
    fetched record, or the reference allele does not match the sequence.
    '''

    #checked before touching the FASTA file, no sequence is needed to reject these oligos
    if oligo.mpra_variant_id.endswith('_2'):
        #seeding multiple variants into oligo sequence isn't currently supported
        return (None, None)

    #oligo position within UTR region
    oligo_start = int(oligo.oligo_start - oligo.human_UTR_start)
    oligo_end = int(oligo.oligo_end - oligo.human_UTR_start)

    if oligo_start < 0:
        #a negative start would be interpreted by slicing as "counted from the end"
        #and silently yield a wrong sequence
        return (None, None)

    utr_seq = human_fasta.fetch(oligo.UTR_ID).upper()

    #extract oligo sequence, no primers here; a list so that the variant can be seeded in place
    seq = list(utr_seq[oligo_start:oligo_end])

    if not check_ref(seq, oligo):
        #mismatch with the reference genome
        return (None, None)

    #seed variant if alt sequence
    if oligo.tag == 'alt':
        seq = insert_variant(seq, oligo)

    seq = ''.join(seq)

    #NOTE(review): primers are added before the reverse complement below, so for
    #negative-strand oligos the primers get reverse complemented as well - confirm
    #that this is intended before switching add_seq_primers on
    if add_seq_primers:
        seq = primer_5_end + seq + primer_3_end

    #for genes on the negative strand, take reverse complement
    if complement_negative and oligo.strand=='-':
        seq = reverse_complement(seq)

    #FASTA sequence header
    seq_header = f'>{oligo.oligo_id}'

    return seq_header, seq

In [23]:
#write FASTA file with variants

def write_fasta(output_fasta, complement_negative=False):
    '''
    Write the sequences of all oligos in utr_variants to a FASTA file

    output_fasta: path of the file to create (overwritten if it exists)
    complement_negative: forwarded to extract_utr_seq

    Every oligo for which extract_utr_seq returns no header is skipped and
    counted; half of that count is printed at the end (presumably because
    each variant contributes a ref and an alt oligo - TODO confirm).
    '''

    skipped = 0

    with open(output_fasta, 'w') as fasta_out:
        for _, oligo_row in utr_variants.iterrows():
            header, sequence = extract_utr_seq(oligo_row, complement_negative=complement_negative)
            if header is None:
                skipped += 1
                continue
            fasta_out.write(header+'\n')
            #wrap sequence with standard FASTA width
            fasta_out.writelines(chunk+'\n' for chunk in wrap(sequence, 80))

    print(f'Variants with mismatched reference: {skipped//2}')

In [None]:
#oligo sequences exactly as cut from human_fasta: entries with strand '-' are NOT reverse complemented
write_fasta(data_dir + 'fasta/variants_dna_fwd.fa', complement_negative=False) 

In [34]:
#oligo sequences with strand '-' entries reverse complemented; this is the file read back below to build seq_df
write_fasta(data_dir + 'fasta/variants_rna.fa', complement_negative=True) 

Variants with mismatched reference: 1728


In [35]:
#! ./RNAfold.sh {data_dir}/fasta/variants_rna.fa  {data_dir}/fasta/free_energy.tsv

# Combine mpra_df

In [24]:
#tab-separated tables from the griesemer_supplementary directory
#expression_df: only mpra_variant_id and the log2FoldChange_Ref_*/log2FoldChange_Alt_* columns are used below
#oligo_info_df: only mpra_variant_id and gene_symbols are used below
expression_df = pd.read_csv(data_dir + 'griesemer_supplementary/Variant_MPRAu_Results.txt', sep='\t')
oligo_info_df = pd.read_csv(data_dir + 'griesemer_supplementary/Oligo_Variant_Info.txt', sep='\t')

In [25]:
#minimal free energy from RNAfold software, recalculate each time when FASTA file changes!
#
#mpra_df['min_free_energy'] = pd.read_csv(data_dir + 
#'fasta/free_energy.tsv', header=None).squeeze() 

In [26]:
#Add sequences from FASTA file

fasta_fa = data_dir + 'fasta/variants_rna.fa'

#oligo_id -> sequence lines of that record, in file order
seq_chunks = {}

with open(fasta_fa, 'r') as fasta_in:
    for raw_line in fasta_in:
        if raw_line.startswith('>'):
            #header line: everything after '>' is the oligo id, a repeated id starts over
            current_id = raw_line[1:].rstrip()
            seq_chunks[current_id] = []
        else:
            #sequence line: belongs to the most recent header
            seq_chunks[current_id].append(raw_line.rstrip())

#undo the line wrapping of the FASTA file
seq = {record_id: ''.join(parts) for record_id, parts in seq_chunks.items()}

seq_df = pd.DataFrame(seq.items(), columns=['oligo_id', 'seq'])

In [27]:
#combine all information together

#keep only the variant id and the log2FoldChange_Ref_*/log2FoldChange_Alt_* columns
fold_change_cols = [x for x in expression_df.columns
                    if x.startswith(('log2FoldChange_Ref_', 'log2FoldChange_Alt_'))]
expression_df = expression_df[['mpra_variant_id'] + fold_change_cols]
oligo_info_df = oligo_info_df[['mpra_variant_id','gene_symbols']]

#inner joins with explicit keys, so that a column that happens to be shared
#between two tables can never silently become part of the join condition
mpra_df = (utr_variants
           .merge(seq_df, on='oligo_id')                  #sequence of each oligo
           .merge(expression_df, on='mpra_variant_id')    #expression measurements of each variant
           .merge(oligo_info_df, on='mpra_variant_id')    #gene annotation of each variant
           .drop_duplicates())

mpra_df.columns

Index(['mpra_variant_id', 'tag', 'oligo_id', 'variant_id', 'chrom', 'ref',
       'alt', 'other_var_in_oligo_window', 'var_start', 'var_end',
       'oligo_start', 'oligo_end', 'UTR_ID', 'human_UTR_start',
       'human_UTR_end', 'strand', 'vartype', 'seq',
       'log2FoldChange_Ref_HEK293FT', 'log2FoldChange_Alt_HEK293FT',
       'log2FoldChange_Ref_HEPG2', 'log2FoldChange_Alt_HEPG2',
       'log2FoldChange_Ref_HMEC', 'log2FoldChange_Alt_HMEC',
       'log2FoldChange_Ref_K562', 'log2FoldChange_Alt_K562',
       'log2FoldChange_Ref_GM12878', 'log2FoldChange_Alt_GM12878',
       'log2FoldChange_Ref_SKNSH', 'log2FoldChange_Alt_SKNSH', 'gene_symbols'],
      dtype='object')

In [18]:
# Define Groups for Group K-fold based on genes

#start from the gene_symbols column as group label; the following cells replace this column
#with labels that are shared between overlapping oligos
mpra_df['group'] = mpra_df['gene_symbols']

In [19]:
#some oligos overlap
#The corresponding oligos should have the same group label

new_groups = {}

def get_overlap(a, b):
    '''
    Length of the intersection of two intervals

    a, b: intervals given as (start, end)
    Returns 0 when the intervals do not intersect.
    '''
    shared_start = max(a[0], b[0])
    shared_end = min(a[1], b[1])
    span = shared_end - shared_start
    return span if span > 0 else 0
    
#one row per variant, sorted so that overlapping oligos of a chromosome are adjacent
df = mpra_df.sort_values(by=['chrom','oligo_start','oligo_end']).drop_duplicates(subset=['mpra_variant_id'])

#group label -> list of mpra_variant_id; (re)initialised here so that the cell can be re-run
new_groups = {}

cluster_label = None #group label of the first oligo of the current cluster
cluster_chrom = None #chromosome of the current cluster
cluster_span = None  #(start, end) covered by the current cluster so far

for var_id, label, chrom, start, end in zip(df['mpra_variant_id'], df['group'], df['chrom'],
                                            df['oligo_start'], df['oligo_end']):
    if (cluster_span is not None
            and chrom == cluster_chrom
            and get_overlap(cluster_span, (start, end))):
        #overlaps the current cluster: extend the cluster, so that chains of
        #overlapping oligos (A-B, B-C) end up together even if A and C do not overlap
        cluster_span = (cluster_span[0], max(cluster_span[1], end))
    else:
        #no overlap: this oligo starts a new cluster and provides its label
        cluster_label, cluster_chrom, cluster_span = label, chrom, (start, end)
    #every variant is registered exactly once (including the last row);
    #setdefault keeps the variants of an earlier cluster that carries the same label
    new_groups.setdefault(cluster_label, []).append(var_id)

In [103]:
#turn the label -> [variant ids] mapping into a long table with one row per (group, mpra_variant_id)
group_table = pd.DataFrame(list(new_groups.items()),columns=['group','mpra_variant_id'])
new_groups = group_table.explode('mpra_variant_id')

#replace the gene based group column by the overlap aware one
mpra_without_group = mpra_df.drop(columns='group')
mpra_df = mpra_without_group.merge(new_groups)

In [104]:
#save the combined table as tab-separated text
#NOTE(review): index=None is presumably treated like index=False (no index column written) - confirm
mpra_df.to_csv(data_dir + 'preprocessing/mpra_rna.tsv', index=None, sep='\t')