In [None]:
import os
import numpy
from plotly import offline as plotly
from plotly import graph_objects
from scipy import stats
from statsmodels.stats import multitest

from capblood_seq_poc.dataset import Capblood_Seq_Dataset
from capblood_seq_poc import common as cbs
from capblood_seq_poc import viz as cbs_viz

In [None]:
# Build the dataset handle for the "normalized" pipeline, rooted at ./data,
# then load it into memory (load() downloads the data first if it is not
# already present locally).
dataset = Capblood_Seq_Dataset(
    pipeline_name="normalized",
    data_directory="data",
)
dataset.load()

In [None]:
# Gene whose AM vs PM expression is compared
GENE = "EAF2"

# Cell types to analyse; the trailing None entry is plotted as "All Cells"
CELL_TYPES = cbs.CELL_TYPES + [None]

# True: pool every subject's sample means into a single t-test per cell type.
# False: run one t-test per subject and combine the resulting p-values with
# Stouffer's method.
POOL_SUBJECTS = True

# Center each subject's sample means on that subject's own mean before
# testing/plotting. Recommended whenever POOL_SUBJECTS is True.
NORMALIZE_WITHIN_SUBJECT = True

In [None]:
# For every cell type, collect the mean expression of GENE in each AM and PM
# sample and test AM against PM.
#
# The AM_* lists are parallel to one another (one entry per AM sample, ordered
# by cell type, then subject, then sample); the same holds for the PM_* lists.
# p_values and cell_type_label_list hold one entry per cell type.
AM_colors = []
PM_colors = []
AM_means = []
PM_means = []
p_values = []
AM_mean_groups = []
PM_mean_groups = []
cell_type_label_list = []

for cell_type_index, cell_type in enumerate(CELL_TYPES):

    cell_type_AM_means = []
    cell_type_PM_means = []

    # A cell type of None is labelled as the "All Cells" group
    cell_type_label = "All Cells" if cell_type is None else cell_type
    cell_type_label_list.append(cell_type_label)

    # Per-subject test results; only consumed when POOL_SUBJECTS is False
    num_samples_per_subject = []
    p_values_per_subject = []

    for subject_id in cbs.SUBJECT_IDS:

        subject_AM_means = []
        subject_PM_means = []

        for sample in cbs.SAMPLE_NAMES:

            transcript_counts = dataset.get_transcript_counts(
                sample,
                cell_type=cell_type,
                subject_id=subject_id,
                normalized=True,
                genes=GENE
            )

            # No data for this sample/subject/cell type combination
            if transcript_counts is None:
                continue

            transcript_counts = transcript_counts.to_array()

            # Samples whose name contains "AM" are morning samples; every
            # other sample is treated as PM
            if "AM" in sample:
                subject_AM_means.append(transcript_counts.mean())
                AM_colors.append(cbs_viz.SUBJECT_ID_COLORS[subject_id])
                AM_mean_groups.append(cell_type_index)
            else:
                subject_PM_means.append(transcript_counts.mean())
                PM_colors.append(cbs_viz.SUBJECT_ID_COLORS[subject_id])
                PM_mean_groups.append(cell_type_index)

        subject_AM_means = numpy.array(subject_AM_means)
        subject_PM_means = numpy.array(subject_PM_means)

        num_samples = len(subject_AM_means) + len(subject_PM_means)

        if NORMALIZE_WITHIN_SUBJECT and num_samples > 0:
            # Center on the mean of ALL of this subject's sample means, so the
            # centered values sum to zero even when the AM and PM sample
            # counts differ. Taking the mean of the concatenation also avoids
            # the NaN that mean() of an empty AM or PM array would produce; a
            # subject with samples from only one time of day is simply
            # centered on its own mean.
            subject_mean = numpy.concatenate(
                (subject_AM_means, subject_PM_means)
            ).mean()
            subject_AM_means = subject_AM_means - subject_mean
            subject_PM_means = subject_PM_means - subject_mean

        cell_type_AM_means.extend(subject_AM_means)
        cell_type_PM_means.extend(subject_PM_means)

        if not POOL_SUBJECTS:
            z, p_value = stats.ttest_ind(subject_AM_means, subject_PM_means)
            # The statistic is NaN when the test is undefined for this
            # subject (e.g. too few samples); leave the subject out of the
            # combined p-value
            if numpy.isnan(z):
                continue
            num_samples_per_subject.append(num_samples)
            p_values_per_subject.append(p_value)

    if POOL_SUBJECTS:
        _, p_value = stats.ttest_ind(cell_type_AM_means, cell_type_PM_means)
    elif p_values_per_subject:
        # Stouffer's method, weighting each subject by its number of samples
        _, p_value = stats.combine_pvalues(
            p_values_per_subject,
            weights=num_samples_per_subject,
            method="stouffer"
        )
    else:
        # No subject yielded a usable test for this cell type
        p_value = numpy.nan

    p_values.append(p_value)
    AM_means.extend(cell_type_AM_means)
    PM_means.extend(cell_type_PM_means)

# Build the grouped AM/PM box plot with per-sample points overlaid, then
# annotate each cell type with a significance bracket and its p-value.

# Horizontal offset of the AM points relative to the cell type's tick position
# (PM points sit on the tick itself), and the upper bound of the random jitter
AM_POINT_OFFSET = -0.35
JITTER_SCALE = 1 / 40

# Half-width of the significance bracket drawn above each cell type
BRACKET_HALF_WIDTH = 0.175

# Symmetric-ish axis limit with 10% headroom. nanmax ignores NaN means so a
# single undefined sample mean cannot turn the whole axis range into NaN.
y_max = max(
    numpy.nanmax(numpy.abs(AM_means)),
    numpy.nanmax(numpy.abs(PM_means))
) * 1.1

AM_box_trace = graph_objects.Box(
    x=AM_mean_groups,
    y=AM_means,
    line={
        "color": cbs_viz.AM_COLOR
    },
    name="AM"
)

AM_mean_groups = numpy.array(AM_mean_groups)
PM_mean_groups = numpy.array(PM_mean_groups)

# Shift the AM points to the left of the tick and leave the PM points on it
AM_mean_groups_jittered = AM_mean_groups.astype(numpy.float32) + AM_POINT_OFFSET
PM_mean_groups_jittered = PM_mean_groups.astype(numpy.float32).copy()

# Fixed seed so the jitter (and therefore the figure) is identical on every run
jitter_rng = numpy.random.RandomState(0)
AM_mean_groups_jittered += jitter_rng.rand(len(AM_mean_groups_jittered)) * JITTER_SCALE
PM_mean_groups_jittered += jitter_rng.rand(len(PM_mean_groups_jittered)) * JITTER_SCALE

AM_scatter_trace = graph_objects.Scatter(
    x=AM_mean_groups_jittered,
    y=AM_means,
    marker_color=AM_colors,
    mode="markers",
    showlegend=False,
    name="AM"
)

PM_box_trace = graph_objects.Box(
    x=PM_mean_groups,
    y=PM_means,
    line={
        "color": cbs_viz.PM_COLOR
    },
    name="PM"
)

PM_scatter_trace = graph_objects.Scatter(
    x=PM_mean_groups_jittered,
    y=PM_means,
    marker_color=PM_colors,
    mode="markers",
    showlegend=False,
    name="PM"
)

# Centered values are plotted on a range symmetric about zero; otherwise the
# axis starts at zero
if not NORMALIZE_WITHIN_SUBJECT:
    y_min = 0
else:
    y_min = -y_max

title = "Mean expression"
if NORMALIZE_WITHIN_SUBJECT:
    title += " normalized within subject"

layout = graph_objects.Layout(
    {
        "yaxis": {
            "range": [y_min, y_max],
            "title": title,
            "exponentformat": "power"
        },
        "plot_bgcolor": "rgba(0, 0, 0, 0)",
        "width": 800,
        "title": {
            "text": "%s Time-of-Day Expression" % (GENE),
            "xanchor": "center",
            "xref": "container",
            "x": 0.5
        },
        "boxmode": "group",
        "xaxis": {
            "tickvals": list(range(len(cell_type_label_list))),
            "ticktext": cell_type_label_list
        }
    }
)

figure = graph_objects.Figure(
    data=[AM_box_trace, AM_scatter_trace, PM_box_trace, PM_scatter_trace],
    layout=layout
)

for cell_type_index, cell_type_label in enumerate(cell_type_label_list):

    bracket_left_x = cell_type_index - BRACKET_HALF_WIDTH
    bracket_right_x = cell_type_index + BRACKET_HALF_WIDTH

    # Horizontal bar of the bracket, drawn at the top of the y range
    significance_line = graph_objects.layout.Shape(
        type="line",
        x0=bracket_left_x,
        x1=bracket_right_x,
        y0=y_max,
        y1=y_max,
        line=dict(
            color="Black",
            width=5
        )
    )

    # Short vertical ticks hanging from either end of the bar
    significance_bracket_left = graph_objects.layout.Shape(
        type="line",
        x0=bracket_left_x,
        x1=bracket_left_x,
        y0=y_max,
        y1=y_max * 0.95,
        line=dict(
            color="Black",
            width=4
        )
    )

    significance_bracket_right = graph_objects.layout.Shape(
        type="line",
        x0=bracket_right_x,
        x1=bracket_right_x,
        y0=y_max,
        y1=y_max * 0.95,
        line=dict(
            color="Black",
            width=4
        )
    )

    figure.add_shape(significance_line)
    figure.add_shape(significance_bracket_left)
    figure.add_shape(significance_bracket_right)

    # p-value label centered just above the bracket
    figure.add_annotation(
        graph_objects.layout.Annotation(
            text="p=%.1e" % p_values[cell_type_index],
            showarrow=False,
            yanchor="bottom",
            yref="y",
            y=y_max,
            x=cell_type_index,
            xref="x",
            xanchor="center"
        )
    )

plotly.iplot(figure)


# Uncomment to export the figure as an SVG named after the gene:
# figure.write_image("%s_AM_vs_PM_cell_types_vs_all.svg" % (GENE))