In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid', font_scale=1.3, palette='Set2')
%matplotlib inline
from collections import defaultdict
from scipy.stats import entropy
import plotly.plotly as py
import plotly.graph_objs as go
import os
from pprint import pprint

In [3]:
out_path = './'  # prefix for cached per-virus sitewise tables (see aggregate_virus_sites)

data = pd.read_csv('./enrichment.csv', index_col=0)  # read in our data; presumably one row per oligo
metadata_cols = ['virus', 'start', 'end', 'sequence', 'strains']  # per-oligo annotation columns (not sera)
sample_cols = [c for c in data.columns.values if c not in metadata_cols]  # serum/sample value columns

# Function-call form works on both Python 2 and Python 3 kernels
# (the bare `print x` statement is a SyntaxError on Python 3).
print(data.head())  # peek at the first few rows

                16418_CHIKVDay0_20ug  16418_CHIKVDay35_20ug  \
id                                                            
1                           0.664517               0.537128   
100                         0.925310               0.468594   
1001                        0.734245               0.197732   
1002                        2.405997               1.363379   
1008.1177.1346              0.640381               0.497251   

                24961_ZIKVDay28_20ug  25147_ZIKVDay28_20ug  \
id                                                           
1                           0.310872              0.357264   
100                         0.199741              0.657030   
1001                        0.000000              0.295953   
1002                        3.077846              1.176111   
1008.1177.1346              0.384816              0.503280   

                25421_ZIKVDay28_20ug  26021_ZIKVDay0_20ug  \
id                                                          
1

In [4]:
def find_all_oligos(virus, site, data=data):
    '''
    Return the positional (iloc-style) indices of every row of `data`
    that belongs to `virus` and spans `site`, i.e. rows with
    start <= site <= end (both bounds inclusive).

    Note: the default for `data` is the module-level `data` object as it
    existed when this function was defined.
    '''
    belongs_to_virus = data['virus'] == virus
    spans_site = (data['start'] <= site) & (data['end'] >= site)
    positions, = np.where(belongs_to_virus & spans_site)
    return positions
    
def aggregate_site(virus, site, data):
    '''
    Return a Series holding the mean of each numeric column over all
    oligos of `virus` whose [start, end] range contains `site`.

    Rows are selected with find_all_oligos, which returns positional
    indices into this same `data`, so they are applied with .iloc.

    Non-numeric annotation columns (e.g. 'virus', 'sequence', 'strains')
    are excluded explicitly with numeric_only=True. Older pandas dropped
    them silently; pandas >= 2.0 raises TypeError instead. Numeric
    annotation columns such as 'start'/'end' are still averaged, as before.
    If no oligo covers `site`, the returned values are NaN.
    '''
    indices = find_all_oligos(virus, site, data)
    entries = data.iloc[indices]
    return entries.mean(axis=0, numeric_only=True)
        
def aggregate_virus_sites(virus, data=data, path=None):
    '''
    For every site in the span covered by `virus`'s oligos, average the
    values of all oligos containing that site (via aggregate_site).

    Parameters
    ----------
    virus : value matched against data['virus'].
    data : oligo table with 'virus', 'start' and 'end' columns plus value
        columns. Defaults to the module-level `data` bound at definition time.
    path : optional cache file. If it exists, it is read and returned as-is.
        Otherwise the freshly computed table is written there. When omitted,
        the table is written to out_path + virus + '_sitewise_enrichment.csv'.

    Returns
    -------
    DataFrame indexed by site, from the smallest 'start' to the largest
    'end' (inclusive), with one column per column aggregated by aggregate_site.

    Raises
    ------
    ValueError if `data` has no rows for `virus`.
    '''
    if path and os.path.isfile(path):
        return pd.read_csv(path, index_col=0)

    # Filter to this virus once, rather than re-scanning the whole table for every site.
    # find_all_oligos/aggregate_site both index into the frame they are given,
    # so passing the subset keeps their positional indices consistent.
    virus_data = data.loc[data['virus'] == virus]
    if virus_data.empty:
        raise ValueError('no oligos found for virus %r' % (virus,))

    first_site = int(virus_data['start'].min())
    last_site = int(virus_data['end'].max())

    # Oligo ends are inclusive (find_all_oligos matches start <= site <= end),
    # so the last site must be part of the range too.
    sites = range(first_site, last_site + 1)
    aggregated_sites = {site: aggregate_site(virus, site, virus_data) for site in sites}

    df = pd.DataFrame.from_dict(aggregated_sites, orient='index')
    # Write to the same file we read from, so a caller-supplied cache path actually gets populated.
    df.to_csv(path if path else out_path + virus + '_sitewise_enrichment.csv')
    return df

# Per-site aggregated values for every virus in the dataset.
# The cache path is built from out_path, the same prefix aggregate_virus_sites
# writes to by default, so changing out_path keeps reads and writes aligned.
site_maps = {virus: aggregate_virus_sites(virus, path=out_path + '%s_sitewise_enrichment.csv' % virus)
             for virus in pd.unique(data['virus'])}

In [None]:
def plot_interactive_binding_footprints(virus):
    '''
    Send an interactive plotly line chart of per-site values for `virus`
    via py.iplot. The chart has one trace per column in `sample_cols`;
    the 'input' and 'beads' columns are skipped.

    Returns whatever py.iplot returns (previously discarded).
    NOTE(review): with legacy plotly.plotly this appears to be a display
    object that is only rendered when it is the last expression of a cell;
    inside a loop it may need to be displayed explicitly. Confirm.
    '''
    values = site_maps[virus]
    traces = []

    for serum in sample_cols:
        if serum in ['input', 'beads']:
            continue  # control columns, not sera

        vals = values[serum]
        traces.append(go.Scatter(
            x=vals.index.values,
            y=vals.values,
            mode='lines',
            name=serum))

    layout = dict(title='Oligos from %s' % virus,
                  xaxis=dict(title='Genomic position'),
                  yaxis=dict(title='Fold enrichment by sera'),
                  )

    # Bundle traces with the layout. Passing only `traces` (as before)
    # silently dropped the title and axis labels.
    fig = go.Figure(data=traces, layout=layout)
    return py.iplot(fig, filename='2018-01-08_' + virus)


# Plot every virus that has a site map. The list comes from site_maps,
# computed above, so this cell runs on a fresh kernel without relying on
# a list defined in a later cell.
for virus in site_maps:
    plot_interactive_binding_footprints(virus)

In [None]:
# def plot_time_series(virus, serum, ax=None):
    
#     timepoints = [s for s in site_maps[virus].columns.values if serum in s and 'Day' in s]
#     timepoints.sort(key = lambda s: int(s.split('Day')[1].split('_')[0]))
#     timepoint_days = [int(s.split('Day')[1].split('_')[0]) for s in timepoints]
#     timepoint_values = {tp : site_maps[virus][tp] for tp in timepoints}
    
#     if ax == None:
#         fig, ax = plt.subplots(figsize=(12, 6))
    
#     if len(timepoints) != 2: # Currently only looking at *pairs* of timepoints
#         return
    
#     else:
#         baseline = timepoint_values[timepoints[0]]
#         comparison = timepoint_values[timepoints[1]]
        
        
#         difference = comparison.subtract(baseline)
#         label = serum.split('Day')[0] + '@ Day'+ str(timepoint_days[1]) + ' - ' + 'Day '+str(timepoint_days[0])
#         ax.plot(difference.index.values, difference.values, label=label)
#     return ax

# def plot_all_time_series(virus):
#     fig, ax = plt.subplots(figsize=(12,6))
    
#     time_series_samples = set([s.split('Day')[0] for s in sample_cols if 'Day' in s])

#     for serum in time_series_samples:
#         plot_time_series(virus, serum, ax)
#     plt.legend(title='Serum')
#     ax.set_title(virus)
#     ax.set_xlabel('Genome position')
#     ax.set_ylabel('Fold enrichment over input')
# #     ax.set_ylim(0, 150)
#     plt.show()

# flavis = ['DENV1', 'DENV2', 'DENV3', 'DENV4', 'CHIKV']
# for v in flavis:
#     plot_all_time_series(v)

In [None]:
# def calc_entropy(row):
#     total = row[sample_cols].sum()
# #     total = row.sum()
#     if total == 0:
#         return 0.
#     distrib = row[sample_cols].map(lambda x: float(x) / float(total))
# #     distrib = row.map(lambda x: float(x) / float(total))

#     return entropy(distrib.values)

# data['entropy'] = data.apply(calc_entropy, axis=1)
# data['max_enrichment'] = data[sample_cols].max(axis=1)

In [None]:
# def parse_time_course(df, cols):
    
#     def parse_col_header(col):
#         try:
#             sample, virusday, concentration = col.split('_')
#         except ValueError:
#             sample, virusdayconcentration = col.split('Preg')
#             sample = sample+'Preg'
#             virusday, concentration = virusdayconcentration.split('_')
            
#         virus, day = virusday.split('Day')    
#         return {'sample': sample, 'virus': virus, 'day': int(day), 'concentration': concentration}

    
#     values = df[cols]
#     new_header = { c : parse_col_header(c)['day'] for c in cols}
#     values.rename(columns=new_header, inplace=True)
#     return values

# time_courses = defaultdict(list)
# for c in sample_cols:
#     if 'Day' in c:
#         time_courses[c.split('Day')[0]].append(c)

In [None]:
# def plot_binding_footprint(virus, serum, ax=None):
    
#     values = site_maps[virus][serum]
    
#     if ax == None:
#         fig, ax = plt.subplots(figsize=(12, 6))
    
#     ax.plot(values.index.values, values.values, label=serum)
#     return ax

# def plot_all_binding_footprints(virus, serum_subset=sample_cols):
#     values = site_maps[virus]
    
#     fig, ax = plt.subplots(figsize=(9,3))
#     for serum in serum_subset:
#         plot_binding_footprint(virus, serum, ax)
#     plt.legend(title='Serum')
#     ax.set_title(virus)
#     ax.set_xlabel('Genome position')
#     ax.set_ylabel('Fold enrichment over input')
#     ax.set_ylim(0, 150)
#     plt.show()
        