# Improving the prediction of efficiency of CRISPR/Cas9 guides using method combinations

General instructions
====================

The following sections should be run from start to finish, since they form the backend of this notebook:
- Prerequisites
- Utils
- Comparison tools
- Architecture

The Pipeline section is where you configure and train the model. Choose between creating a new experiment or loading an existing one.

Next follows the Comparisons section. It provides various plots and tables for evaluating the model and comparing it to existing tools.

Finally, the Feature importance section produces SHAP charts for the current model.

# Prerequisites

In [None]:
#@title Setup { form-width: "150px" }

# Some warnings and errors will be displayed; however, they do not affect the
# functionality of the notebook.
# After running this cell, the runtime should be restarted.

!pip freeze | grep == | sed 's/==/>=/' > constraints.txt

!pip install -c constraints.txt spacecutter
!pip install -c constraints.txt -U skorch
!pip install -c constraints.txt https://github.com/ceshine/shap/archive/master.zip


In [None]:
#@title Constants { form-width: "150px" }

import os

# --- Locations ----------------------------------------------------------------
# Google Drive is mounted at this path; every input file is read through it.
GDRIVE_PATH = "/content/drive"
# Main work directory inside the drive. Adjust this to wherever the project
# lives in your drive.
WORK_DIR = os.path.join(GDRIVE_PATH, "My Drive/some_path/work_dir")
# Export the work directory as the WORK_DIR environment variable (e.g. for
# `!` shell commands run from the notebook).
os.environ["WORK_DIR"] = WORK_DIR
DATA_DIR = os.path.join(WORK_DIR, "data")

# --- Datasets -----------------------------------------------------------------
# Dataset files inside DATA_DIR. These match the default layout; adjust them
# if the files were placed elsewhere in the work directory.
CHARI_DATA_PATH = os.path.join(DATA_DIR, "chari1_features.csv")
GENOME_DATA_PATH = os.path.join(DATA_DIR, "genome1_features.csv")
XU_DATA_PATH = os.path.join(DATA_DIR, "xu1_features.csv")      # Ribosomal
NR_XU_DATA_PATH = os.path.join(DATA_DIR, "xu2_features.csv")   # Nonribosomal
DOENCH_DATA_PATH = os.path.join(DATA_DIR, "doench1_features.csv")
MIXTURE_DATA_PATH = os.path.join(DATA_DIR, "mixture", "mixture.pkl")
# DeepCRISPR outputs; only needed for the DeepCRISPR comparisons. To produce
# them, follow the instructions in README.txt inside data/deepcrispr.
DEEPCRISPR_PATH = os.path.join(DATA_DIR, "deepcrispr")

# --- Sizes --------------------------------------------------------------------
# Common (training, validation, testing) dataset sizes.
SIZES_6K, SIZES_12K = (6000, 1000, 1000), (12000, 1500, 1500)

# Length of the feature representation. Update this if the representation is
# changed from the original one.
NUM_FEATURES = 18



In [None]:
#@title Features { form-width: "150px" }

# Display names of the features, in the order the feature columns appear in
# the dataset CSV files (get_data reads feature i from column i + 2). The "\n"
# characters split long names over two lines in plots and tables.
FEATURES = [
    "CHOPCHOP: GC content",                     # 0
    "CHOPCHOP: self\ncomplementarity",          # 1
    "CHOPCHOP: Xu 2015",                        # 2
    "CHOPCHOP: Doench 2014",                    # 3
    "CHOPCHOP: Moreno-\nMateos 2015",           # 4
    "CHOPCHOP: G20",                            # 5
    "FlashFry: Doench 2014",                    # 6
    "FlashFry: Moreno-\nMateos 2015",           # 7
    "mm10db: AT content",                       # 8
    "mm10db: multiple\nmatches",                # 9
    "mm10db: secondary\nstructure or energy",   # 10
    "mm10db: reverse primer",                   # 11
    "mm10db: TTTT",                             # 12
    "mm10db: off-target",                       # 13
    "mm10db: accepted",                         # 14
    "PhytoCRISP-Ex",                            # 15
    "sgRNA Scorer 2.0",                         # 16
    "SSC",                                      # 17
]

# Feature column index -> display name, for the columns compared as score
# tools and as decision tools respectively.
COL_TO_SCORE_TOOL = {i: FEATURES[i] for i in [2, 3, 4, 6, 7, 16, 17]}
COL_TO_DECISION_TOOL = {i: FEATURES[i] for i in [5, 14, 15]}


In [None]:
#@title Mount Google Drive { form-width: "150px" }

# Mount Google Drive at GDRIVE_PATH so that the data files under WORK_DIR can
# be read by the rest of the notebook.
from google.colab import drive
drive.mount(GDRIVE_PATH)


In [None]:
#@title Imports { form-width: "150px" }

import gc
import sys
import time
import math
import copy
import shap
import numpy as np
import pickle
import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as clr

from random import randint
from sklearn import metrics
from datetime import datetime
from itertools import product
from IPython.display import display, HTML
from spacecutter.losses import CumulativeLinkLoss
from spacecutter.models import OrdinalLogisticModel
from spacecutter.callbacks import AscensionCallback

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss
from torch.autograd import Variable

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Utils

## Data utils

In [None]:
#@title DataPoint class { form-width: "150px" }

class DataPoint(object):
    """Represents a single datapoint.

    Attributes:
        target: A string of the target sequence.
        label: The GenomeCRISPR effect label of the target (get_data parses
            it as an int).
        features: The feature representation of the target (get_data
            produces a list of floats).
    """

    def __init__(self, target, label, features):
        self.target = target
        self.label = label
        self.features = features

    def __repr__(self):
        # Without this, datapoints (and lists of them) print as opaque
        # "<DataPoint object at 0x...>" strings, which makes debugging hard.
        return (f"{type(self).__name__}(target={self.target!r}, "
                f"label={self.label!r}, features={self.features!r})")


In [None]:
#@title Read data { form-width: "150px" }

def from_csv_line(line):
    """Split one CSV line into its list of string fields.

    Leading/trailing whitespace (including the newline) is stripped first,
    then the line is split on every comma; quoting is not interpreted.
    """
    stripped = line.strip()
    return stripped.split(",")

def get_data(path, req_features=None):
    """Reads data from a dataset CSV file into a list of DataPoints.

    Each non-blank line holds ``target,label,feature_0,feature_1,...``: the
    target sequence, an integer label and the float feature values.

    Args:
        path: The full path of the dataset.
        req_features: An optional sequence of indices of features to include
            in the output, in the given order. If it is None or empty (the
            default), all features are included.

    Returns:
        A list of DataPoint instances, each matching a single datapoint in the
        dataset, in file order.

    Raises:
        ValueError: If a label or feature value cannot be parsed as a number.
        IndexError: If a non-blank line has fewer than two fields, or
            req_features refers to a feature the line does not have.
    """
    datapoints = []
    with open(path, 'r') as fd:
        for line in fd:
            # Skip blank lines (e.g. a trailing empty line at the end of the
            # file); otherwise reading the label would raise IndexError.
            if not line.strip():
                continue
            values = from_csv_line(line)
            target = values[0]
            label = int(values[1])
            features = [float(score) for score in values[2:]]
            if req_features:
                features = [features[i] for i in req_features]
            datapoints.append(DataPoint(target, label, features))
    return datapoints


In [None]:
#@title Parse data { form-width: "150px" }

"""Utilities for extracting information from a list of DataPoint instances."""

def get_targets(datapoints):
    """Return the target sequence of every datapoint, preserving order."""
    return list(map(lambda point: point.target, datapoints))

def get_labels(datapoints):
    """Return the labels of the datapoints as a float32 column vector.

    The result has shape (len(datapoints), 1).
    """
    raw = [point.label for point in datapoints]
    column = np.array(raw, dtype=np.float32)
    return column.reshape((-1, 1))

def get_features(datapoints, num_features):
    """Stack the feature lists of the datapoints into a float32 matrix.

    The result has shape (len(datapoints), num_features).
    """
    rows = [point.features for point in datapoints]
    matrix = np.array(rows, dtype=np.float32)
    return matrix.reshape((-1, num_features))

def labels_to_indices(labels):
    """Shifts the values of the labels such that the lowest one is 0.

    Args:
        labels: An iterable of numeric labels (a list, array or generator).

    Returns:
        A list with every label minus the minimum label, in the original
        order. An empty input yields an empty list.
    """
    # Materialize once: the labels are traversed twice (min + shift), which
    # would silently exhaust a generator.
    labels = list(labels)
    if not labels:
        # min() raises ValueError on an empty sequence; nothing to shift.
        return []
    min_label = min(labels)
    return [label - min_label for label in labels]

def get_labels_as_indices(datapoints):
    """Extracts the labels and shifts them such that the lowest one is 0.

    Args:
        datapoints: A list of DataPoint instances.

    Returns:
        A 1-D integer NumPy array with one shifted label per datapoint.
    """
    labels = [dp.label for dp in datapoints]
    label_indices = labels_to_indices(labels)
    # `np.int` was merely an alias of the builtin `int` and was removed in
    # NumPy 1.24 (AttributeError); `int` is the exact equivalent dtype.
    return np.array(label_indices, dtype=int)


In [None]:
#@title DeepCRISPR configuration { form-width: "150px" }

# DeepCRISPR versions; used as the first directory level under DEEPCRISPR_PATH.
REGRESSIONS = "regression"
CLASSIFICATION = "classification"

# Variants available for each version; used as the second directory level.
VARIANTS = {}
VARIANTS[REGRESSIONS] = ["seq_only", "epigenetics"]
VARIANTS[CLASSIFICATION] = ["epigenetics"]

def get_deepcrispr_path(version, variant, data_name):
    """Returns the path of the desired DeepCRISPR output.

    Args:
        version: Either REGRESSIONS or CLASSIFICATION.
        variant: One of the variants listed in VARIANTS for the chosen version.
        data_name: The name of the dataset required. One of: "doench",
            "mixture", "xu", "nr_xu".

    Returns:
        The path ``DEEPCRISPR_PATH/<version>/<variant>/<data_name>_results.pkl``.
        The path is only built, not checked for existence.
    """
    return os.path.join(DEEPCRISPR_PATH, version, variant,
                        f"{data_name}_results.pkl")


In [None]:
#@title Mixture dataset { form-width: "150px" }

# If the feature representation differs from the original one presented in the
# paper, mixture.pkl must be regenerated (see create_mixture_dataset) so that it
# contains the new representations.

def get_mixture_targets(data_dir):
    """Read the ordered list of Mixture dataset targets.

    Args:
        data_dir: The data directory. Targets are read from
            ``<data_dir>/mixture/mixture_targets.txt``, one per line.

    Returns:
        The targets with surrounding whitespace stripped, in file order.
    """
    path = os.path.join(data_dir, "mixture", "mixture_targets.txt")
    with open(path, 'r') as handle:
        targets = [row.strip() for row in handle]
    return targets

def get_datapoints(dataset, targets):
    """Collect the datapoints of a dataset whose target is in ``targets``.

    Every matched target is removed from ``targets``, which is mutated in
    place. As a result, each entry of ``targets`` is matched at most once,
    both within this call and in later calls that share the same collection.

    Args:
        dataset: An iterable of DataPoint-like objects with a ``target``.
        targets: A mutable collection (e.g. a set) of wanted targets.

    Returns:
        The matched datapoints, in dataset order.
    """
    selected = []
    for point in dataset:
        if point.target in targets:
            selected.append(point)
            targets.remove(point.target)
    return selected

def get_mixture_dataset(datasets):
    """Assemble the Mixture dataset from the given source datasets.

    Args:
        datasets: A list of datasets (lists of DataPoints). A target present
            in several datasets is taken from the earliest one.

    Returns:
        The selected DataPoints, ordered as in mixture_targets.txt. Targets
        that are not found in any dataset are omitted.
    """
    targets = get_mixture_targets(DATA_DIR)
    # get_datapoints removes each matched target from this set, so every
    # target is taken from the first dataset (and datapoint) containing it.
    targets_set = set(targets)
    mixture = []
    for dataset in datasets:
        mixture += get_datapoints(dataset, targets_set)

    # Index the selected datapoints by target once (keeping the first one per
    # target) instead of rescanning the whole mixture for every target, which
    # was quadratic in the dataset size.
    by_target = {}
    for dp in mixture:
        by_target.setdefault(dp.target, dp)
    return [by_target[target] for target in targets if target in by_target]

def create_mixture_dataset(chari_data, xu_data, doench_data, genome_data):
    """Build the Mixture dataset and pickle it to MIXTURE_DATA_PATH.

    The sources are consulted in the order chari, xu, doench, genome; a
    target present in several of them is taken from the first.
    """
    sources = [chari_data, xu_data, doench_data, genome_data]
    mixture_dataset = get_mixture_dataset(sources)
    with open(MIXTURE_DATA_PATH, 'wb') as out_file:
        pickle.dump(mixture_dataset, out_file)
    

## Progress tracking utils

In [None]:
#@title Training progress { form-width: "150px" }

def to_mins_secs(time_in_secs):
    """Format a duration in seconds as "<minutes>m <seconds>s".

    Fractions of a second are truncated.
    """
    mins, secs = divmod(int(time_in_secs), 60)
    return f"{mins}m {secs}s"


def time_estimate(start_time, completed_fraction):
    """Computes the elapsed time and the remaining time for the task.

    The remaining time is extrapolated linearly from the elapsed time and the
    completed fraction.

    Args:
        start_time: The start time of the task, as returned by time.time().
        completed_fraction: The fraction of the task already completed.

    Returns:
        A string with the elapsed and remaining time. If no progress has been
        made yet (completed_fraction <= 0), the remaining time is reported as
        unknown instead of raising ZeroDivisionError.
    """
    elapsed = time.time() - start_time
    if completed_fraction <= 0:
        # Nothing completed yet: the remaining time cannot be extrapolated.
        return f"{to_mins_secs(elapsed)}\t(unknown remaining)"
    total_estimate = elapsed / completed_fraction
    remaining = total_estimate - elapsed
    return f"{to_mins_secs(elapsed)}\t({to_mins_secs(remaining)} remaining)"


def report_progress(experiment, start_time, epoch, epochs,
                    tot_loss, cur_loss, print_every):
    """Prints a training progress report string.

    Also updates the experiment's training progress and runs
    ``experiment.test()`` before printing.

    Args:
        experiment: An Experiment instance to produce the report for.
        start_time: The start time of the training.
        epoch: The number of the latest epoch completed.
        epochs: The total number of epochs in the training session.
        tot_loss: The total loss since the last report.
        cur_loss: The loss incurred in the latest epoch.
        print_every: The number of epochs between reports.
    """
    avg_loss = tot_loss / print_every

    experiment.training.update_progress(
        experiment.model.epoch_counter, cur_loss)
    # NOTE(review): called for its effect on the experiment (presumably
    # recording validation/test losses); its return value is not reported.
    experiment.test()

    fraction = epoch / epochs
    elapsed_time = time_estimate(start_time, fraction)
    # Round the percentage so the string fits the ljust(15) column; an
    # unformatted float such as 33.33333333333333 overflowed it.
    progress_string = f"({epoch} {fraction * 100:.1f}%)".ljust(15)
    print(f"{elapsed_time}\t{progress_string} {avg_loss:.4f}")


In [None]:
#@title DataFrame utils { form-width: "150px" }

# Name of the first column in tool-comparison reports, which lists the tools.
TOOL_COL = "Tool"


def pretty_print(df):
    """Display a DataFrame as a left-justified HTML table.

    Every literal backslash-n sequence in the generated HTML is replaced by
    a <br> tag, so multi-line cell text renders as line breaks. Returns
    whatever IPython's ``display`` returns.
    """
    table_html = df.to_html(justify="left")
    table_html = table_html.replace("\\n", "<br>")
    return display(HTML(table_html))


def _create_report(stats, tools, format_str=None):
    """Build a DataFrame comparing tools across datasets.

    Args:
        stats: A dictionary mapping dataset name to a list of the results of
            all the tools, ordered to match the ordering of tools.
        tools: A list of the names of the compared tools.
        format_str: An optional %-style format applied to every result. If
            not provided, the results are used as-is.

    Returns:
        A DataFrame whose first column (TOOL_COL) holds the tool names,
        followed by one column of results per dataset, in the order of stats.
    """
    columns = [pd.DataFrame({TOOL_COL: tools})]
    for data_name, results in stats.items():
        if format_str:
            results = [format_str % r for r in results]
        columns.append(pd.DataFrame({data_name: results}))
    # A single column-wise concatenation instead of re-concatenating the
    # growing frame once per dataset.
    return pd.concat(columns, axis=1)
    

In [None]:
#@title Plot losses { form-width: "150px" }

def to_colour(c1, c2, c3):
    """Scale an 8-bit RGB triple (0-255 per channel) to a 0-1 float tuple."""
    return tuple(channel / 255.0 for channel in (c1, c2, c3))


def prep_plot():
    """Clear the current figure and apply the colour-blind-friendly style.

    Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    and 3.8 removed the old names, so the new name is used when available
    and the old one only as a fallback for older Matplotlib versions.
    """
    plt.clf()
    style = "seaborn-v0_8-colorblind"
    if style not in plt.style.available:
        style = "seaborn-colorblind"
    plt.style.use(style)


def plot_losses(data_dict, ylabel="Loss", save_name=""):
    """Plots loss-per-epoch.

    Args:
        data_dict: A dictionary mapping dataset name to a corresponding Data
            instance (its ``epochs`` and ``losses`` are plotted).
        ylabel: The title of the y axis ("Loss" by default).
        save_name: If provided, the plot is saved under this name in the
            WORK_DIR.
    """
    prep_plot()
    fig = plt.figure()
    for name, data in data_dict.items():
        plt.plot(data.epochs, data.losses, label=f"{name} dataset",
                 marker='o', alpha=0.5)
    plt.legend(loc='best')
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)

    # Save before showing: some backends close the figure in plt.show(), so
    # saving afterwards is not guaranteed to capture the plot.
    if save_name:
        save_path = os.path.join(WORK_DIR, save_name)
        fig.savefig(save_path, bbox_inches="tight", dpi=100)
    plt.show()


def plot_standard_losses(experiment, train=True, validation=True, test=False):
    """Plot loss-per-epoch for the selected standard splits of an experiment.

    Args:
        experiment: The Experiment instance to plot for.
        train: Whether to include the training set (experiment.training).
        validation: Whether to include the validation set
            (experiment.validation).
        test: Whether to include the test set (experiment.testing).
    """
    # (legend name, enabled flag, experiment attribute), in plotting order.
    # Attributes are only read for enabled splits.
    splits = (
        ("Training", train, "training"),
        ("Validation", validation, "validation"),
        ("Test", test, "testing"),
    )
    datasets = {
        name: getattr(experiment, attr)
        for name, enabled, attr in splits
        if enabled
    }
    plot_losses(datasets)



## Feature importance utils

In [None]:
#@title Colours { form-width: "150px" }

# 256-step colour map built from seaborn's "RdYlBu" palette, for SHAP charts.
SHAP_CMAP = clr.ListedColormap(sns.color_palette("RdYlBu", 256))
# First colour of seaborn's "colorblind" palette, for single-colour SHAP charts.
SHAP_COLOUR = sns.color_palette("colorblind", 10)[0]


In [None]:
#@title Summary plot { form-width: "150px" }

"""Adapted from the shap library by CeShine Lee.
https://github.com/ceshine/shap

Main changes: the addition of the cmap option, and removing the feature names
from  the bar plot.
"""

from scipy.stats import gaussian_kde
from shap.plots import labels
from shap.plots import colors

# Short alias used by summary_plot below, which is adapted from the shap
# library source.
pl = plt

def summary_plot(shap_values, features=None, feature_names=None,
                 max_display=None, plot_type="dot", color=None,
                 axis_color="#333333", title=None, alpha=1, show=True,
                 sort=True, color_bar=True, auto_size_plot=True,
                 layered_violin_max_num_bins=20, class_names=None,
                 cmap=colors.red_blue):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    Parameters
    ----------
    shap_values : numpy.array or list of numpy.array
        Matrix of SHAP values (# samples x # features). A list of matrices
        (one per class) is treated as multi-class output and is always drawn
        as a stacked bar plot. A 3-D array (# samples x # features x
        # features) is treated as SHAP interaction values and drawn as one
        dot-plot panel per top feature.
    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names
        list as shorthand. A DataFrame also supplies the feature names.
    feature_names : list or numpy.array
        Names of the features (length # features).
    max_display : int
        How many top features to include in the plot (default is 20, or 7 for
        interaction plots).
    plot_type : "dot" (default), "violin", "layered_violin" or "bar"
        What type of summary plot to produce.
    color : str
        Colour for single-colour plots, or the colour map name for
        "layered_violin". Defaults to "coolwarm" for "layered_violin" and
        "#1E88E5" otherwise.
    axis_color : str
        Colour of the axis ticks and tick labels.
    title : str
        Currently unused.
    alpha : float
        Opacity of the plotted points/areas.
    show : bool
        Whether to call ``pl.show()`` before returning.
    sort : bool
        Whether to order features by the total magnitude of their SHAP values.
    color_bar : bool
        Whether to draw the feature-value colour bar (never drawn for bar
        plots or when no feature values are given).
    auto_size_plot : bool
        Whether to resize the figure according to the number of features.
    layered_violin_max_num_bins : int
        Maximum number of feature-value bins per "layered_violin" row.
    class_names : list
        Names of the classes for multi-class bar plots (default "Class <i>").
    cmap : matplotlib.colors.Colormap
        A colourmap for the dots of "dot" plots and for the colour bar.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure created at the start of the call, or None for interaction
        value plots (which draw into a separate, newly created figure).
    """
    # Used by the layered-violin branch below. It was never imported in this
    # notebook, so that branch previously failed with a NameError.
    import warnings

    fig = pl.figure()

    multi_class = False
    if isinstance(shap_values, list):
        multi_class = True
        plot_type = "bar" # only type supported for now
    else:
        assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#1E88E5" #"#ff0052"

    # convert from a DataFrame or other types
    if isinstance(features, pd.DataFrame):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    num_features = (shap_values[0].shape[1] if multi_class else shap_values.shape[1])

    if feature_names is None:
        feature_names = np.array([labels['FEATURE'] % str(i) for i in range(num_features)])

    # plotting SHAP interaction values
    if not multi_class and len(shap_values.shape) == 3:
        if max_display is None:
            max_display = 7
        else:
            max_display = min(len(feature_names), max_display)

        sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0))

        # get plotting limits
        delta = 1.0 / (shap_values.shape[1] ** 2)
        slow = np.nanpercentile(shap_values, delta)
        shigh = np.nanpercentile(shap_values, 100 - delta)
        v = max(abs(slow), abs(shigh))
        slow = -v
        shigh = v

        pl.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
        pl.subplot(1, max_display, 1)
        proj_shap_values = shap_values[:, sort_inds[0], sort_inds]
        proj_shap_values[:, 1:] *= 2  # because off diag effects are split in half
        # NOTE(review): feature_names must be a NumPy array here; a plain list
        # cannot be indexed by the sort_inds array.
        summary_plot(
            proj_shap_values, features[:, sort_inds] if features is not None else None,
            feature_names=feature_names[sort_inds],
            sort=False, show=False, color_bar=False,
            auto_size_plot=False,
            max_display=max_display
        )
        pl.xlim((slow, shigh))
        pl.xlabel("")
        title_length_limit = 11
        pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
        for i in range(1, min(len(sort_inds), max_display)):
            ind = sort_inds[i]
            pl.subplot(1, max_display, i + 1)
            proj_shap_values = shap_values[:, ind, sort_inds]
            proj_shap_values *= 2
            proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
            summary_plot(
                proj_shap_values, features[:, sort_inds] if features is not None else None,
                sort=False,
                feature_names=["" for i in range(len(feature_names))],
                show=False,
                color_bar=False,
                auto_size_plot=False,
                max_display=max_display
            )
            pl.xlim((slow, shigh))
            pl.xlabel("")
            if i == min(len(sort_inds), max_display) // 2:
                pl.xlabel(labels['INTERACTION_VALUE'])
            pl.title(shorten_text(feature_names[ind], title_length_limit))
        pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
        pl.subplots_adjust(hspace=0, wspace=0.1)
        if show:
            pl.show()
        return

    if max_display is None:
        max_display = 20

    if sort:
        # order features by the sum of their effect magnitudes
        if multi_class:
            feature_order = np.argsort(np.sum(np.mean(np.abs(shap_values), axis=0), axis=0))
        else:
            feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0))
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, num_features)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 5.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]
            inds = np.arange(len(shaps))
            np.random.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]
            colored_feature = True
            try:
                values = np.array(values, dtype=np.float64)  # make sure this can be numeric
            except (TypeError, ValueError):
                # Non-numeric (e.g. categorical string) feature values are
                # drawn uncoloured instead of failing the whole plot.
                colored_feature = False
            N = len(shaps)
            # hspacing = (np.max(shaps) - np.min(shaps)) / 200
            # curr_bin = []
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            inds = np.argsort(quant + np.random.randn(N) * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
                nan_mask = np.isnan(values)
                pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
                           vmax=vmax, s=16, alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
                pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
                           cmap=pl.get_cmap(cmap), vmin=vmin, vmax=vmax, s=16,
                           c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
            else:

                pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                           color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)

    elif plot_type == "violin":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                rng = shap_max - shap_min
                xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
                if np.std(shaps) < (global_high - global_low) / 100:
                    ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                values = features[:, i]
                window_size = max(10, len(values) // 20)
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):

                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)
                pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=colors.red_blue_solid, vmin=vmin, vmax=vmax,
                           c=values, alpha=alpha, linewidth=0, zorder=1)
                # smooth_values -= nxp.nanpercentile(smooth_values, 5)
                # smooth_values /= np.nanpercentile(smooth_values, 95)
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for i in range(len(xs) - 1):
                    if ds[i] > 0.05 or ds[i + 1] > 0.05:
                        pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
                                        [pos - ds[i], pos - ds[i + 1]], color=colors.red_blue_solid(smooth_values[i]),
                                        zorder=2)

        else:
            parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                                  widths=0.7,
                                  showmeans=False, showextrema=False, showmedians=False)

            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_edgecolor('none')
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
            'int')  # the indices of the feature data corresponding to each bin
        shap_min, shap_max = np.min(shap_values), np.max(shap_values)
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
            # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # x axis is located at y0 = pos, with pos being there for offset
            y0 = np.ones(num_x_points) * pos
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
                # if there's only one element, then we can't
                if shaps.shape[0] == 1:
                    warnings.warn(
                        "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                        % (i, feature_names[ind]))
                    # to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's
                    # nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
                ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
                # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
                # do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1%
                # female, we want the 1% to appear a lot smaller.
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
            # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
            # whitespace
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                c = pl.get_cmap(color)(i / (
                        nbins - 1)) if color in pl.cm.datad else color  # if color is a cmap, use it, otherwise use a color
                pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
        pl.xlim(shap_min, shap_max)

    elif not multi_class and plot_type == "bar":
        feature_inds = feature_order[:max_display]
        y_pos = np.arange(len(feature_inds))
        global_shap_values = np.abs(shap_values).mean(0)
        pl.barh(y_pos, global_shap_values[feature_inds], 0.7, align='center', color=color)
        pl.yticks(y_pos, fontsize=13)
        pl.gca().set_yticklabels([feature_names[i] for i in feature_inds])

    elif multi_class and plot_type == "bar":
        if class_names is None:
            class_names = ["Class "+str(i) for i in range(len(shap_values))]
        feature_inds = feature_order[:max_display]
        y_pos = np.arange(len(feature_inds))
        left_pos = np.zeros(len(feature_inds))

        class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])
        for i,ind in enumerate(class_inds):
            global_shap_values = np.abs(shap_values[ind]).mean(0)
            pl.barh(
                y_pos, global_shap_values[feature_inds], 0.7, left=left_pos, align='center',
                color=colors.default_blue_colors[min(i, len(colors.default_blue_colors)-1)], label=class_names[ind]
            )
            left_pos += global_shap_values[feature_inds]
        pl.yticks(y_pos, fontsize=13)
        pl.gca().set_yticklabels([feature_names[i] for i in feature_inds])
        pl.legend(frameon=False, fontsize=12)

    # draw the color bar
    if color_bar and features is not None and plot_type != "bar" and \
            (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=cmap if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)
        # cb.draw_all()

    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('none')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().spines['left'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    if plot_type != "bar":
        pl.gca().tick_params('y', length=20, width=0.5, which='major')
    pl.gca().tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    if plot_type == "bar":
        # Bar plots hide the feature names (a change from the shap original).
        pl.gca().axes.get_yaxis().set_visible(False)
        pl.xlabel("Mean |SHAP| value", fontsize=13)
    else:
        pl.xlabel("SHAP value", fontsize=13)
    if show:
        pl.show()
    return fig

def shorten_text(text, length_limit):
    """Truncates text so that it is at most length_limit characters long.

    When truncation is needed and length_limit leaves room for it, the tail
    is replaced with "..." so the result is exactly length_limit characters.
    If length_limit is smaller than 3 (too small for an ellipsis), the text
    is simply cut to length_limit characters.

    Args:
        text: The string to shorten.
        length_limit: The maximal length of the returned string.

    Returns:
        text itself if it already fits, otherwise a shortened copy.
    """
    if len(text) <= length_limit:
        return text
    if length_limit < 3:
        # No room for "...": slicing with length_limit - 3 (a negative bound)
        # would keep almost all of the text and exceed the limit.
        return text[:max(length_limit, 0)]
    return text[:length_limit - 3] + "..."

# Comparison tools

## General comparison utils

In [None]:
#@title Functions { form-width: "150px" }

def to_percent(fraction):
    """Formats a fraction (e.g. 0.25) as a percentage string (e.g. "25.00%")."""
    percentage = fraction * 100
    return "{:.2f}%".format(percentage)

    
def get_true_ranking(data):
    """Produces a pairwise ranking dictionary from the data.

    Args:
        data: The Data instance to rank. Each element of data.labels must
            support .item() (e.g. a tensor / numpy scalar).

    Returns:
        A dictionary mapping pairs of indices (i1, i2), with i1 < i2, of two
        guides in the dataset to a boolean indicating whether the first guide
        in the pair is better than the second, according to the labels in the
        data. Guides that have the same label are not paired. Pairs appear in
        increasing (i1, i2) order.

        For example:
        {
            (1, 4): True,
            (2, 5): False,
        }
        means guide number 1 is better than guide number 4, and that guide
        number 5 is better than guide number 2.
    """
    # Convert every label once; calling .item() inside the pair loop would
    # repeat the conversion O(n^2) times.
    labels = [label.item() for label in data.labels]
    num_guides = len(labels)

    pairs = {}
    for i1 in range(num_guides):
        label1 = labels[i1]
        for i2 in range(i1 + 1, num_guides):
            label2 = labels[i2]
            if label1 == label2:
                # Equal labels carry no ordering information.
                continue
            # A lower label means better efficiency in negative screenings.
            pairs[(i1, i2)] = label1 < label2
    return pairs


In [None]:
#@title Comparison container { form-width: "150px" }

class Comparison:
    """Holds the information required for making comparisons.

    Attributes:
        datasets: A dictionary mapping dataset names to Data instances.
        threshold: Guides with labels below the threshold are considered to be
            efficient, the rest are inefficient.
        true_rankings: A dictionary mapping each dataset name to its true
            pairwise ranking, in the format described for the return value of
            get_true_ranking.
        efficients: A dictionary mapping each dataset name to the result of
            data.get_efficient(threshold) for that dataset (its efficient
            guides).
    """

    def __init__(self, datasets, threshold=-1):
        """Initialises a Comparison container.

        Args:
            datasets: A dictionary mapping dataset names to Data instances.
            threshold: Guides with labels below the threshold are considered
                to be efficient, the rest are inefficient (by default, the
                threshold is -1).
        """
        self.datasets = datasets
        self.threshold = threshold

        rankings = {}
        efficient_guides = {}
        for name in datasets:
            dataset = datasets[name]
            rankings[name] = get_true_ranking(dataset)
            efficient_guides[name] = dataset.get_efficient(threshold)
        self.true_rankings = rankings
        self.efficients = efficient_guides


## Compare pairwise ranking

In [None]:
#@title Functions { form-width: "150px" }


def ranking_from_scores(data, score_col, true_ranking):
    """Produces the ranking precision of a specific tool.

    Tool scores are treated as "higher is better": a pair is counted as
    correct when (score of first > score of second) agrees with the truth.

    Args:
        data: The Data instance to rank.
        score_col: The index of the tool's score within the feature
            representation of the guides.
        true_ranking: The true pairwise ranking, in the format described for the
            return value of get_true_ranking.

    Returns:
        The percentage of pairs the tool got right out of all the pairs in the
        true_ranking (as a string). Raises ZeroDivisionError if true_ranking is
        empty.
    """
    # NOTE(review): a tie (equal scores) reads as "first is not better", so a
    # tied pair counts as correct exactly when the second guide is the truly
    # better one; the outcome for ties therefore depends on guide order.
    correct = sum(
        1
        for (first, second), is_first_better in true_ranking.items()
        if (data.features[first][score_col] > data.features[second][score_col])
        == is_first_better
    )
    return to_percent(float(correct) / len(true_ranking))


def ranking_from_predictions(data, model, true_ranking, predictions=None):
    """Produces the ranking precision of a model.

    Args:
        data: The Data instance to rank.
        model: A model which inherits from BaseModel.
        true_ranking: The true pairwise ranking, in the format described for the
            return value of get_true_ranking.
        predictions: If provided, these are considered to be the scores
            predicted by the model for the guides in the data. Otherwise, the
            predictions are produced using the model provided.

    Returns:
        The percentage of pairs the model got right out of all the pairs in the
        true_ranking (as a string). Raises ZeroDivisionError if true_ranking is
        empty.
    """
    if predictions is None:
        scores = model.get_processed_predictions(data)
    else:
        scores = predictions

    correct = 0
    for (first, second), is_first_better in true_ranking.items():
        # The models are trained to give lower scores to better guides.
        predicted_first_better = scores[first].item() < scores[second].item()
        correct += int(predicted_first_better == is_first_better)

    return to_percent(float(correct) / len(true_ranking))


def ranking_from_majority(data, true_ranking):
    """Produces the ranking precision of the Majority Vote method.

    Every scoring tool (COL_TO_SCORE_TOOL) and every decision tool
    (COL_TO_DECISION_TOOL) that is not also a scoring tool casts a vote on
    each pair. A voter backs each guide whose value is at least as high as
    the other guide's, so a tie gives a vote to both guides. The first guide
    is predicted to be better only if it gets strictly more votes.

    Args:
        data: The Data instance to rank.
        true_ranking: The true pairwise ranking, in the format described for the
            return value of get_true_ranking.

    Returns:
        The percentage of pairs the Majority Vote got right out of all the pairs
        in the true_ranking (as a string).
    """
    score_cols = list(COL_TO_SCORE_TOOL.keys())
    # Columns that are both scoring and decision tools only vote once, as
    # scoring tools.
    decision_cols = list(set(COL_TO_DECISION_TOOL.keys()) - set(score_cols))
    decisions = [COL_TO_DECISION_FUNCTION[col](data, col)
                 for col in decision_cols]

    def voter_values(guide):
        """Returns each voter's value for a guide: raw scores, then decisions."""
        return ([data.features[guide][col] for col in score_cols] +
                [decision[guide] for decision in decisions])

    correct = 0
    for (i1, i2), is_first_better in true_ranking.items():
        paired = list(zip(voter_values(i1), voter_values(i2)))
        votes1 = sum(1 for first, second in paired if first >= second)
        votes2 = sum(1 for first, second in paired if second >= first)
        if (votes1 > votes2) == is_first_better:
            correct += 1

    return to_percent(float(correct) / len(true_ranking))


def ranking_from_deepcrispr(data, true_ranking, results_path):
    """Produces the ranking precision of DeepCRISPR.

    Args:
        data: The Data instance to rank.
        true_ranking: The true pairwise ranking, in the format described for the
            return value of get_true_ranking.
        results_path: The path of a pickled dictionary mapping each target in
            data.targets to its DeepCRISPR score.

    Returns:
        The percentage of pairs DeepCRISPR got right out of all the pairs in
        the true_ranking (as a string).
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load result
    # files produced by a trusted DeepCRISPR run.
    with open(results_path, 'rb') as results_file:
        deepcrispr_scores = pickle.load(results_file)

    # Wrap the DeepCRISPR scores in a one-feature Data instance so that the
    # generic score ranking (on feature column 0) can be reused.
    score_data = Data(num_features=1)
    score_data.targets = data.targets
    score_data.labels = data.labels
    for target in data.targets:
        score_data.features.append([deepcrispr_scores[target]])

    return ranking_from_scores(score_data, 0, true_ranking)


In [None]:
#@title Report { form-width: "150px" }

def tools_ranking_report(comparison):
    """Produces the ranking precision report of all the scoring tools.

    Args:
        comparison: A Comparison instance for the desired datasets.

    Returns:
        A DataFrame with the precision report, one entry per scoring tool in
        COL_TO_SCORE_TOOL order.
    """
    score_cols = list(COL_TO_SCORE_TOOL.keys())
    rankings = {
        data_name: [
            ranking_from_scores(data, col, comparison.true_rankings[data_name])
            for col in score_cols
        ]
        for data_name, data in comparison.datasets.items()
    }
    return _create_report(rankings, list(COL_TO_SCORE_TOOL.values()))


def majority_ranking_report(comparison):
    """Produces the ranking precision report of the Majority Vote.

    Args:
        comparison: A Comparison instance for the desired datasets.

    Returns:
        A DataFrame with the precision report.
    """
    rankings = {}
    for name in comparison.datasets:
        precision = ranking_from_majority(
            comparison.datasets[name], comparison.true_rankings[name])
        rankings[name] = [precision]
    return _create_report(rankings, ["Majority vote"])


def model_ranking_report(comparison, model, name="Model"):
    """Produces the ranking precision report of a model.

    Args:
        comparison: A Comparison instance for the desired datasets.
        model: A model which inherits from BaseModel, to produce the report for.
        name: The name to give the model in the report ("Model" by default).

    Returns:
        A DataFrame with the precision report.
    """
    rankings = {
        data_name: [ranking_from_predictions(
            data, model, comparison.true_rankings[data_name])]
        for data_name, data in comparison.datasets.items()
    }
    return _create_report(rankings, [name])


def get_ranking_report(comparison, model):
    """Produces the full ranking precision report.

    Args:
        comparison: A Comparison instance for the desired datasets.
        model: A model which inherits from BaseModel, to include in the report.

    Returns:
        A DataFrame with the precision report for all the scoring tools, the
        Majority Vote method and the model, stacked in that order with a fresh
        0..n-1 index.
    """
    import pandas as pd

    tools_report = tools_ranking_report(comparison)
    majority_report = majority_ranking_report(comparison)
    model_report = model_ranking_report(comparison, model)
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat with ignore_index=True stacks the reports the same way.
    return pd.concat([tools_report, majority_report, model_report],
                     ignore_index=True)


def deepcrispr_ranking_report(comparison, variants_dict=VARIANTS):
    """Produces the ranking precision report of DeepCRISPR.

    Args:
        comparison: A Comparison instance for the desired datasets.
        variants_dict: A dictionary mapping DeepCRISPR versions to iterables of
            their variants (VARIANTS by default). Every (version, variant)
            combination becomes one entry in the report.

    Returns:
        A DataFrame with the precision report.
    """
    # Flatten the version -> variants mapping once, preserving its order, so
    # the report names and the per-dataset precisions line up.
    version_variants = [(version, variant)
                        for version, variants in variants_dict.items()
                        for variant in variants]
    models = [f"DeepCRISPR ({version} {variant})"
              for version, variant in version_variants]

    rankings = {}
    for data_name, data in comparison.datasets.items():
        rankings[data_name] = [
            ranking_from_deepcrispr(
                data, comparison.true_rankings[data_name],
                get_deepcrispr_path(version, variant, data_name))
            for version, variant in version_variants
        ]

    return _create_report(rankings, models)



## Compare binary decisions

In [None]:
#@title Decision functions { form-width: "150px" }

"""Utilities for extracting binary decisions from the dataset"""

def _get_scores(data, col, score_type=int):
    """Extracts the raw scores from the data.

    Args:
        data: A Data instance to extract from.
        col: The index of the score within the feature representaion of a guide.
        score_type: The type to cast the score into (int by default).
    
    Returns:
        A list with the score of each guide in the data, with the order
        preserved.
    """
    return [score_type(features[col]) for features in data.features]

def get_phyto_decisions(data, col):
    """Translates the PhytoCRISP-Ex scores to 0 (reject) / 1 (accept).

    Args:
        data: A Data instance to extract from.
        col: The index of the tool's score within the feature representation
            of a guide.

    Returns:
        A list with one decision per guide: 0 for a raw score of 0, 1 for raw
        scores 1 to 4.
    """
    # NOTE(review): (score + 3) // 4 maps 0 -> 0, 1..4 -> 1 but 5..8 -> 2, so
    # the output is strictly 0/1 only if raw scores stay within [-3, 4] —
    # TODO confirm the PhytoCRISP-Ex score range.
    return [(raw + 3) // 4 for raw in _get_scores(data, col)]

def get_sgrna_decisions(data, col):
    """Translates the sgRNA Scorer 2.0 scores to 0 (reject) / 1 (accept).

    Args:
        data: A Data instance to extract from.
        col: The index of the tool's score within the feature representation
            of a guide.

    Returns:
        A list with one decision per guide: 1 if its score (as a float) is
        positive, 0 otherwise.
    """
    float_scores = _get_scores(data, col, score_type=float)
    return [int(value > 0) for value in float_scores]

def get_model_decisions(scores, threshold):
    """Translates the model scores to 0 (reject) / 1 (accept).

    Args:
        scores: An iterable of the scores given by the model.
        threshold: Scores strictly below threshold are accepted, the rest are
            rejected.

    Returns:
        A list with one decision per score: 1 for accepted, 0 for rejected.
    """
    decisions = []
    for score in scores:
        decisions.append(1 if score < threshold else 0)
    return decisions


# A mapping between a decision tool's column in the feature representation and
# the function that turns that column into a list of 0 (reject) / 1 (accept)
# decisions, one per guide. Each function is called as func(data, col).
# CHOPCHOP and mm10db use their raw column values directly (cast to int by
# _get_scores; presumably already stored as 0/1 — TODO confirm), while
# PhytoCRISP-Ex scores are converted by get_phyto_decisions.
COL_TO_DECISION_FUNCTION = {
    5: _get_scores,             # CHOPCHOP: G20
    14: _get_scores,            # mm10db: accepted
    15: get_phyto_decisions,    # phytoCRISP-Ex
}


In [None]:
#@title Stats functions { form-width: "150px" }


# A mapping for constructing the output of get_majority_stats, for internal use.
# Index i names the i-th Majority Vote method, in increasing order of the
# number of accepting tools it requires.
_IDX_TO_MAJORITY_TYPE = {
    0: "Any",           # at least one accepting tool
    1: "Majority",      # at least ceil(n / 2) accepting tools
    2: "Big majority",  # ceil(n / 2) + 1 accepting tools; only for even n
}


def get_tool_stats(data, decision_col, efficient, stats_func,
                   col_decision_map=COL_TO_DECISION_FUNCTION):
    """Computes a statistic of one decision tool on a dataset.

    Args:
        data: A Data instance to use the tool on.
        decision_col: The index of the decision tool's score in the feature
            representation of a guide.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        stats_func: A function which takes efficient and a list of binary
            decisions and returns some statistic.
        col_decision_map: A dictionary mapping indices in the feature
            representation to their decision functions
            (COL_TO_DECISION_FUNCTION by default).

    Returns:
        Whatever stats_func returns for the tool's decisions.
    """
    to_decisions = col_decision_map[decision_col]
    return stats_func(efficient, to_decisions(data, decision_col))


def get_model_stats(data, model, efficient, threshold, stats_func,
                    predictions=None):
    """Computes a statistic of a model's accept/reject decisions on a dataset.

    Args:
        data: A Data instance to use the model on.
        model: A model which inherits from BaseModel.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        threshold: Scores below threshold are considered to belong to accepted
            guides, the rest are rejected.
        stats_func: A function which takes efficient and a list of binary
            decisions and returns some statistic.
        predictions: If provided, these are used as-is as the model's scores
            for the guides in the data. Otherwise, the model is run on the data
            and its processed predictions are flattened.

    Returns:
        Whatever stats_func returns for the model's decisions.
    """
    scores = (predictions if predictions is not None
              else model.get_processed_predictions(data).flatten())
    return stats_func(efficient, get_model_decisions(scores, threshold))


def get_majority_stats(data, efficient, stats_func,
                       col_decision_map=COL_TO_DECISION_FUNCTION):
    """Computes some statistic for decision Majority Vote methods.

    Args:
        data: A Data instance to use the Majority Vote on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        stats_func: A function which takes efficient and a list of binary
            decisions and returns some statistic.
        col_decision_map: A dictionary mapping indices in the feature
            representation to their corresponding decision tools
            (COL_TO_DECISION_FUNCTION by default).

    Returns:
        A dictionary mapping from the name of the Majority Vote method to its
        statistic. The following methods are implemented:
            Any: Not actually a majority vote. Accepts all guides which were
                accepted by at least one decision tool.
            Majority: Accepts all guides which were accepted by at least half of
                the decision tools (ceil(n / 2) of n tools).
            Big majority: Only included when the number of decision tools
                considered is even. Accepts all guides which were accepted by
                more than half of the decision tools.
        All applicable methods are always reported, even when two of them
        happen to use the same vote threshold (e.g. with one or two tools,
        "Any" and "Majority" both need a single vote).
    """
    decision_cols = list(col_decision_map.keys())
    num_tools = len(decision_cols)
    # ceil(num_tools / 2) in integer arithmetic.
    half_votes = (num_tools + 1) // 2

    # tool_decisions[i][j] is the decision made by tool number i regarding
    # target number j.
    tool_decisions = [col_decision_map[col](data, col) for col in decision_cols]

    # (method name, minimal number of accepting tools). Kept as a list of
    # pairs rather than a dict keyed by threshold: equal thresholds would
    # collapse into one key and shift the remaining methods' names.
    min_votes = [
        (_IDX_TO_MAJORITY_TYPE[0], 1),
        (_IDX_TO_MAJORITY_TYPE[1], half_votes),
    ]
    # If the number of tools is even, include the Big Majority method.
    if num_tools % 2 == 0:
        min_votes.append((_IDX_TO_MAJORITY_TYPE[2], half_votes + 1))

    # Number of tools accepting each guide.
    vote_counts = [sum(decision[i] for decision in tool_decisions)
                   for i in range(data.num)]

    return {
        name: stats_func(
            efficient,
            [1 if votes >= threshold else 0 for votes in vote_counts])
        for name, threshold in min_votes
    }
    

def get_deepcrispr_stats(data, efficient, results_path, stats_func):
    """Computes some statistic for DeepCRISPR as a decision tool on a dataset.

    Args:
        data: A Data instance to evaluate DeepCRISPR on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        results_path: The path of a pickled dictionary mapping each target in
            data.targets to its DeepCRISPR score.
        stats_func: A function which takes efficient and a list of binary
            decisions and returns some statistic.

    Returns:
        Whatever stats_func returns for DeepCRISPR's decisions, where a guide
        is accepted (1) when its score is above 0.5 and rejected (0) otherwise.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # results_path at trusted DeepCRISPR output files.
    with open(results_path, 'rb') as results_file:
        deepcrispr_scores = pickle.load(results_file)

    decisions = [1 if deepcrispr_scores[target] > 0.5 else 0
                 for target in data.targets]
    return stats_func(efficient, decisions)



In [None]:
#@title Precision functions { form-width: "150px" }

def get_precision(efficient, decisions):
    """Returns the decision precision.

    Args:
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        decisions: A list of binary decisions for each of the guides.

    Returns:
        The percentage of efficient guides out of the accepted guides, or 0
        when no guide was accepted.
    """
    accepted = [i for i, decision in enumerate(decisions) if decision]
    if not accepted:
        return 0
    correct = sum(1 for i in accepted if i in efficient)
    return 100*float(correct)/len(accepted)

def get_tool_precision(data, decision_col, efficient):
    """Returns the decision precision of one decision tool.

    Args:
        data: A Data instance to use the tool on.
        decision_col: The index of the decision tool's score in the feature
            representation of a guide.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.

    Returns:
        The percentage of efficient guides out of the accepted guides.
    """
    return get_tool_stats(data, decision_col, efficient,
                          stats_func=get_precision)

def get_model_precision(data, model, efficient, threshold, predictions=None):
    """Returns the decision precision of a model.

    Args:
        data: A Data instance to use the model on.
        model: A model which inherits from BaseModel.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        threshold: Scores below threshold are considered to belong to accepted
            guides, the rest are rejected.
        predictions: If provided, these are considered to be the scores
            predicted by the model for the guides in the data. Otherwise, the
            predictions are produced using the model provided.

    Returns:
        The percentage of efficient guides out of the accepted guides.
    """
    return get_model_stats(data, model, efficient, threshold,
                           stats_func=get_precision, predictions=predictions)

def get_majority_precision(data, efficient,
                           col_decision_map=COL_TO_DECISION_FUNCTION):
    """Returns the decision precision of the Majority Vote methods.

    Args:
        data: A Data instance to use the Majority Vote on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        col_decision_map: A dictionary mapping indices in the feature
            representation to their corresponding decision tools
            (COL_TO_DECISION_FUNCTION by default). These are the tools which
            will be used to implement the votes.

    Returns:
        The percentage of efficient guides out of the accepted guides for each
        of the Majority Vote methods, according to the format defined for the
        output of get_majority_stats.
    """
    return get_majority_stats(data, efficient, stats_func=get_precision,
                              col_decision_map=col_decision_map)

def get_deepcrispr_precision(data, efficient, results_path):
    """Returns the decision precision of DeepCRISPR as a decision tool.

    Args:
        data: A Data instance to use the tool on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        results_path: The path of the DeepCRISPR results file.

    Returns:
        The percentage of efficient guides out of the accepted guides.
    """
    return get_deepcrispr_stats(data, efficient, results_path,
                                stats_func=get_precision)


In [None]:
#@title Coverage functions { form-width: "150px" }

def get_coverage(efficient, decisions):
    """Returns the decision coverage.

    Args:
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        decisions: A list of binary decisions for each of the guides.

    Returns:
        The percentage of efficient guides that were accepted, or 0 when there
        are no efficient guides (mirroring get_precision, which returns 0 when
        nothing was accepted, instead of dividing by zero).
    """
    num_efficient = len(efficient)
    if num_efficient == 0:
        return 0
    covered = sum(1 for i in efficient if decisions[i])
    return 100*float(covered)/num_efficient

def get_tool_coverage(data, decision_col, efficient):
    """Returns the decision coverage of one decision tool.

    Args:
        data: A Data instance to use the tool on.
        decision_col: The index of the decision tool's score in the feature
            representation of a guide.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.

    Returns:
        The percentage of efficient guides that were accepted.
    """
    return get_tool_stats(data, decision_col, efficient,
                          stats_func=get_coverage)

def get_model_coverage(data, model, efficient, threshold, predictions=None):
    """Returns the decision coverage of a model.

    Args:
        data: A Data instance to use the model on.
        model: A model which inherits from BaseModel.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        threshold: Scores below threshold are considered to belong to accepted
            guides, the rest are rejected.
        predictions: If provided, these are considered to be the scores
            predicted by the model for the guides in the data. Otherwise, the
            predictions are produced using the model provided.

    Returns:
        The percentage of efficient guides that were accepted.
    """
    return get_model_stats(data, model, efficient, threshold,
                           stats_func=get_coverage, predictions=predictions)

def get_majority_coverage(data, efficient,
                          col_decision_map=COL_TO_DECISION_FUNCTION):
    """Returns the decision coverage of the Majority Vote methods.

    Args:
        data: A Data instance to use the Majority Vote on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        col_decision_map: A dictionary mapping indices in the feature
            representation to their corresponding decision tools
            (COL_TO_DECISION_FUNCTION by default). These are the tools which
            will be used to implement the votes.

    Returns:
        The percentage of efficient guides that were accepted for each of the
        Majority Vote methods, according to the format defined for the output of
        get_majority_stats.
    """
    return get_majority_stats(data, efficient, stats_func=get_coverage,
                              col_decision_map=col_decision_map)

def get_deepcrispr_coverage(data, efficient, results_path):
    """Returns the decision coverage of DeepCRISPR as a decision tool.

    Args:
        data: A Data instance to use the tool on.
        efficient: A list of the indices of the guides in the data considered to
            be efficient.
        results_path: The path of the DeepCRISPR results file.

    Returns:
        The percentage of efficient guides that were accepted.
    """
    return get_deepcrispr_stats(data, efficient, results_path,
                                stats_func=get_coverage)
 

### Reporting utils

In [None]:
#@title Precision and coverage plot { form-width: "150px" }

def plot_precision_coverage(data, model, num_points, label_cutoff=-1,
                            save_name=""):
    """Plots the model's precision and coverage as a function of threshold.

    The model is run once on the data and its scores are reused for every
    threshold. num_points evenly spaced thresholds span the range from the
    lowest to the highest score (inclusive). One extra threshold, one step
    above the highest score, is appended. Since guides are accepted only when
    their score is strictly below the threshold, that last point accepts every
    guide whenever the scores are not all equal.

    Args:
        data: A Data instance to use the model on.
        model: A model which inherits from BaseModel.
        num_points: Number of evenly spaced thresholds between the lowest and
            highest scores (inclusive). Must be at least 2.
        label_cutoff: The label below which guides are considered efficient.
        save_name: If provided, the plot is saved under this name in WORK_DIR.

    Returns:
        A list of the num_points + 1 thresholds used.

    Raises:
        ValueError: If num_points is smaller than 2.
    """
    if num_points < 2:
        # The threshold step below divides by num_points - 1.
        raise ValueError(f"num_points must be at least 2, got {num_points}")

    scores = model.get_processed_predictions(data).flatten().double()
    efficient = data.get_efficient(label_cutoff)

    low_bound = min(scores)
    up_bound = max(scores)
    interval = (up_bound - low_bound) / (num_points - 1)

    # range(num_points + 1): the final threshold lies one interval above the
    # highest score.
    thresholds = [low_bound + n*interval for n in range(num_points + 1)]
    precisions = [get_model_precision(data, model, efficient, t, scores)
                  for t in thresholds]
    coverages = [get_model_coverage(data, model, efficient, t, scores)
                 for t in thresholds]

    prep_plot()
    fig = plt.figure()

    plt.plot(thresholds, precisions, label='Precision', marker='o', alpha=0.5)
    plt.plot(thresholds, coverages, label='Coverage', marker='o', alpha=0.5)

    plt.axhline(50, linestyle='--', color='gray')
    plt.legend(loc='best')
    plt.xlabel("Threshold")

    # Save before show(), so the saved file does not depend on what the
    # active backend's show() does to the figure.
    if save_name:
        fig.savefig(os.path.join(WORK_DIR, save_name),
                    bbox_inches="tight", dpi=100)
    plt.show()
    return thresholds


In [None]:
#@title Precision and coverage report { form-width: "150px" }

# Cell template for the decision reports, filled with (precision, coverage):
# both are percentages, printed with two decimals on separate lines.
STATS_STR = "precision: %.02f\ncoverage:  %.02f"

def tool_decision_report(comparison):
    """Produces a precision/coverage report for the decision tools.

    Args:
        comparison: A Comparison instance for the desired datasets.

    Returns:
        A DataFrame with the decision precision and coverage report, one entry
        per tool in COL_TO_DECISION_TOOL.
    """
    stats = {}
    for data_name, data in comparison.datasets.items():
        efficient = comparison.efficients[data_name]
        stats[data_name] = [
            STATS_STR % (get_tool_precision(data, col, efficient),
                         get_tool_coverage(data, col, efficient))
            for col in COL_TO_DECISION_TOOL
        ]

    return _create_report(stats, list(COL_TO_DECISION_TOOL.values()))

def majority_decision_report(comparison,
                             col_decision_map=COL_TO_DECISION_FUNCTION):
    """Produces a precision/coverage report for the Majority Vote methods.

    Args:
        comparison: A Comparison instance for the desired datasets.
        col_decision_map: A dictionary mapping indices in the feature
            representation to their corresponding decision tools
            (COL_TO_DECISION_FUNCTION by default). These are the tools which
            will be used to implement the votes.

    Returns:
        A DataFrame with the decision precision and coverage report, one entry
        per Majority Vote method (named "<method> vote").
    """
    stats = {}
    vote_tools = []
    for data_name, data in comparison.datasets.items():
        efficient = comparison.efficients[data_name]
        precisions = get_majority_precision(data, efficient, col_decision_map)
        coverages = get_majority_coverage(data, efficient, col_decision_map)

        # The method names only need to be collected once.
        if not vote_tools:
            vote_tools = [f"{method} vote" for method in precisions]

        stats[data_name] = [STATS_STR % (precisions[method], coverages[method])
                            for method in precisions]

    return _create_report(stats, vote_tools)

def model_decision_report(comparison, model, threshold, name="Model"):
    """Produces a precision/coverage report for a model.

    Args:
        comparison: A Comparison instance for the desired datasets.
        model: A model which inherits from BaseModel.
        threshold: Scores below threshold are considered to belong to accepted
            guides, the rest are rejected.
        name: The name to give the model in the report ("Model" by default).

    Returns:
        A DataFrame with the decision precision and coverage report.
    """
    stats = {}
    for data_name, data in comparison.datasets.items():
        efficient = comparison.efficients[data_name]
        # Run the model once per dataset and share its scores between the
        # precision and coverage computations, which would otherwise each run
        # it. Flattened exactly as get_model_stats does when it runs the model.
        predictions = model.get_processed_predictions(data).flatten()
        precision = get_model_precision(
            data, model, efficient, threshold, predictions)
        coverage = get_model_coverage(
            data, model, efficient, threshold, predictions)
        stats[data_name] = [STATS_STR % (precision, coverage)]

    return _create_report(stats, [name])

def get_decision_report(comparison, model, threshold):
    """Produces a complete precision/coverage report.

    Args:
        comparison: A Comparison instance for the desired datasets.
        model: A model which inherits from BaseModel.
        threshold: Model scores below threshold are considered to belong to
            accepted guides, the rest are rejected.

    Returns:
        A DataFrame with the decision precision and coverage report for the
        decision tools, the Majority Vote methods and the model, stacked in
        that order with a fresh 0..n-1 index.
    """
    import pandas as pd

    tools_report = tool_decision_report(comparison)
    majority_report = majority_decision_report(comparison)
    model_report = model_decision_report(comparison, model, threshold)
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat with ignore_index=True stacks the reports the same way.
    return pd.concat([tools_report, majority_report, model_report],
                     ignore_index=True)


def deepcrispr_decision_report(comparison, variants_dict):
    """Produce a precision/coverage report for DeepCRISPR as a decision tool.

    Args:
        comparison: A Comparison instance for the desired datasets.
        variants_dict: An iterable of DeepCRISPR variant names (typically a
            dictionary, iterated over its keys). One report row is produced
            per variant, in iteration order.

    Returns:
        A DataFrame with the decision precision and coverage report, with one
        row per variant labelled "DeepCRISPR (<variant>)".
    """
    stats = {}

    for data_name, data in comparison.datasets.items():
        # The efficient guides depend only on the dataset, not on the variant.
        efficient = comparison.efficients[data_name]
        stats[data_name] = []
        for variant in variants_dict:
            results_path = get_deepcrispr_path(
                CLASSIFICATION, variant, data_name)
            precision = get_deepcrispr_precision(data, efficient, results_path)
            coverage = get_deepcrispr_coverage(data, efficient, results_path)
            stats[data_name].append(STATS_STR % (precision, coverage))

    variants_names = [f"DeepCRISPR ({variant})" for variant in variants_dict]
    return _create_report(stats, variants_names)


## Compare ROC-AUC

In [None]:
#@title Functions { form-width: "150px" }

def get_model_rocauc(data, model, binary_labels):
    """Returns the ROC-AUC of a model.

    Args:
        data: A Data instance to evaluate the model on.
        model: A model which inherits from BaseModel.
        binary_labels: A list of 0s and 1s, where a 1 at index i indicates that
            guide i in the data is efficient, and a 0 indicates it is
            inefficient.

    Returns:
        The ROC-AUC.
    """
    predictions = model.get_processed_predictions(data)
    if torch.is_tensor(predictions):
        # Predictions may live on the GPU; scikit-learn needs a host array.
        # Converting once also avoids building a list of per-row tensors.
        predictions = predictions.detach().cpu().numpy()
    # The model's scores are inverted (lower means better), so negate them to
    # make higher scores mean "more likely efficient", as roc_auc_score expects.
    scores = -np.asarray(predictions, dtype=np.float64).reshape(-1)
    return metrics.roc_auc_score(binary_labels, scores)

def get_deepcrispr_rocauc(data, results_path, binary_labels):
    """Compute the ROC-AUC of DeepCRISPR's stored scores on a dataset.

    Args:
        data: A Data instance to use the tool on.
        results_path: The path of the pickled DeepCRISPR results, a mapping
            from target sequence to score.
        binary_labels: A list of 0s and 1s, where a 1 at index i indicates that
            guide i in the data is efficient, and a 0 indicates it is
            inefficient.

    Returns:
        The ROC-AUC.
    """
    with open(results_path, 'rb') as results_file:
        score_by_target = pickle.load(results_file)

    # Align the scores with the order of the targets (and thus the labels).
    ordered_scores = [score_by_target[target] for target in data.targets]
    return metrics.roc_auc_score(binary_labels, ordered_scores)



In [None]:
#@title Report { form-width: "150px" }


def compare_auc(comparison, model, deepcrispr_variant="seq_only"):
    """Produces a ROC-AUC report.

    Args:
        comparison: A Comparison instance for the desired datasets.
        model: A model which inherits from BaseModel.
        deepcrispr_variant: The DeepCRISPR variant whose regression results
            are compared against ("seq_only" by default).

    Returns:
        A DataFrame with the ROC-AUC report.
    """
    aucs = {}
    for data_name, data in comparison.datasets.items():
        # A set makes each membership test O(1) instead of a list scan.
        efficient = set(comparison.efficients[data_name])
        binary_labels = [int(i in efficient) for i in range(data.num)]

        model_rocauc = get_model_rocauc(data, model, binary_labels)
        results_path = get_deepcrispr_path(
            REGRESSIONS, deepcrispr_variant, data_name)
        deepcrispr_rocauc = get_deepcrispr_rocauc(
            data, results_path, binary_labels)

        aucs[data_name] = [model_rocauc, deepcrispr_rocauc]

    return _create_report(aucs, ["Model", "DeepCRISPR"], "%.2f")


# Architecture

In [None]:
#@title Loss functions { form-width: "150px" }

"""Adapted from the allRank library by Przemek Pobrotyn.
https://github.com/allegro/allRank

Main changes: the addition of stochastic rankNet, which does not consider all
    the possible pairs, rather it samples a limited number of datapoints twice,
    and considers all the possible pairs between them, thus avoiding filling up
    the RAM with a huge list of pairs. The size of the sample is controlled by
    the new argument sample_size. The wrapper function stochasticRankNet takes
    advantage of this argument.
"""

DEFAULT_EPS = 1e-10
PADDED_Y_VALUE = -100
# Number of indices drawn (twice) by stochasticRankNet; the pair count is
# SAMPLE_SIZE ** 2 rather than N ** 2.
SAMPLE_SIZE = 6000


def rankNet(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, weight_by_diff=False, weight_by_diff_powed=False, sample_size=0):
    """
    RankNet loss introduced in "Learning to Rank using Gradient Descent".

    The inputs are reshaped with unsqueeze(0).squeeze(2), so all N items are
    scored as a single slate of shape [1, N]; this expects inputs shaped
    [N, 1] (e.g. single-output model predictions).
    :param y_pred: predictions from the model, shape [N, 1]
    :param y_true: ground truth labels, shape [N, 1]
    :param padded_value_indicator: label value marking padded items, which are excluded from every pair
    :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
    :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
    :param sample_size: if non-zero and smaller than N, two independent samples of sample_size indices are drawn
        (with Python's `random` module) and only pairs from their Cartesian product are used; otherwise all N * N
        pairs are considered
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone().unsqueeze(0).squeeze(2)
    y_true = y_true.clone().unsqueeze(0).squeeze(2)

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    num_items = y_true.shape[1]
    if sample_size and sample_size < num_items:
        # sample sample_size twice, and generate every pair of indices from the samples
        sample1 = random.sample(range(num_items), sample_size)
        sample2 = random.sample(range(num_items), sample_size)
        document_pairs_candidates = list(product(sample1, sample2))
    else:
        # here we generate every pair of indices from the range of document length in the batch
        document_pairs_candidates = list(product(range(num_items), repeat=2))

    pairs_true = y_true[:, document_pairs_candidates]
    selected_pred = y_pred[:, document_pairs_candidates]

    # here we calculate the relative true relevance of every candidate pair
    true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
    pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]

    # here we filter just the pairs that are 'positive' and did not involve a padded instance
    # we can do that since in the candidate pairs we had symmetric pairs so we can stick with
    # positive ones for a simpler loss function formulation
    the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))

    pred_diffs = pred_diffs[the_mask]

    weight = None
    if weight_by_diff:
        abs_diff = torch.abs(true_diffs)
        weight = abs_diff[the_mask]
    elif weight_by_diff_powed:
        true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
        abs_diff = torch.abs(true_pow_diffs)
        weight = abs_diff[the_mask]

    # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know
    # whether one document is better than the other and not about the actual difference in
    # their relevancy levels
    true_diffs = (true_diffs > 0).type(torch.float32)
    true_diffs = true_diffs[the_mask]

    return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)


def stochasticRankNet(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE,
                      weight_by_diff=False, weight_by_diff_powed=False):
    """RankNet restricted to pairs drawn from two samples of SAMPLE_SIZE items.

    See rankNet for the meaning of the arguments. When the slate has no more
    than SAMPLE_SIZE items, this is identical to rankNet over all pairs.
    """
    return rankNet(y_pred, y_true, padded_value_indicator, weight_by_diff,
                   weight_by_diff_powed, SAMPLE_SIZE)


def lambdaLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, weighing_scheme=None, k=None, sigma=1., mu=10.,
               reduction="mean", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.

    The inputs are reshaped with unsqueeze(0).squeeze(2), so all items are
    scored as one slate of shape [1, slate_length]. NOTE(review): this
    presumably expects [N, 1] model outputs as in rankNet - confirm callers.
    :param y_pred: predictions from the model, reshaped internally to [1, slate_length]
    :param y_true: ground truth labels, same shape as y_pred
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: the y_true value marking a padded item (PADDED_Y_VALUE by default)
    :param weighing_scheme: None for unit weights, or the name of a module-level function looked up via globals()
        and called as fn(G, D, mu, true_sorted_by_preds)
    :param k: rank at which the loss is truncated; None keeps every rank
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, either "sum" or "mean" (default "mean")
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either "binary" or "natural"
    :return: loss value, a torch.Tensor
    :raises ValueError: if reduction or reduction_log is not one of the supported values
    """
    device = y_pred.device
    # Clone so the in-place masking/clamping below never touches the caller's tensors.
    y_pred = y_pred.clone().unsqueeze(0).squeeze(2)
    y_true = y_true.clone().unsqueeze(0).squeeze(2)

    # Padded items get -inf so they sort last and produce non-finite pair differences.
    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    # Any difference involving -inf is -inf/+inf/NaN, hence not finite.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)

    # Except for ndcgLoss1_scheme, only pairs where item i is truly more relevant than j are kept.
    if weighing_scheme != "ndcgLoss1_scheme":
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

    # Restrict to pairs within the top-k predicted ranks; with k=None the slice [:None] keeps all.
    ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
    ndcg_at_k_mask[:k, :k] = 1

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    # D[i] = log2(1 + rank), G = (2^relevance - 1) normalised by the ideal DCG@k.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

    # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
    # -inf minus -inf (two padded predictions) yields NaN; zero those differences.
    scores_diffs[torch.isnan(scores_diffs)] = 0.
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")

    # Only valid pairs within the top-k contribute; the sign flip turns log-likelihood into a loss.
    masked_losses = losses[padded_pairs_mask & ndcg_at_k_mask]
    if reduction == "sum":
        loss = -torch.sum(masked_losses)
    elif reduction == "mean":
        loss = -torch.mean(masked_losses)
    else:
        raise ValueError("Reduction method can be either sum or mean")

    return loss


In [None]:
#@title Data class { form-width: "150px" }

class Data(object):
    """Represents a dataset.

    Attributes:
        targets: A list of the target sequences in the dataset.
        features: An array of feature representations of the targets (with
            matching order). It is converted with torch.from_numpy, so it is
            expected to be a NumPy array.
        labels: An array of the labels of the targets (with matching order),
            likewise expected to be a NumPy array.
        num: The number of targets in the dataset.
        num_features: The length of the feature representations.
        epoch_counter: The number of training epochs performed (for tracking
            purposes).
        epochs: A list of epochs after which the loss for this dataset was
            sampled.
        losses: A list of losses measured at the epochs in the epochs list.
    """
    def __init__(self, datapoints=None, num_features=0,
                 labels_extractor=get_labels):
        """Initialises a dataset.

        Args:
            datapoints: A list of DataPoint instances.
            num_features: The length of the feature representations of the
                targets.
            labels_extractor: A function which extracts labels from the
                datapoints.
        """
        if datapoints:
            self.targets = get_targets(datapoints)
            self.features = get_features(datapoints, num_features)
            self.labels = labels_extractor(datapoints)
        else:
            self.targets = []
            self.features = []
            self.labels = []

        self.num = len(self.labels)
        self.num_features = num_features
        self.epoch_counter = 0
        self.epochs = []
        self.losses = []

    def shuffle(self):
        """Returns a shuffled copy of the dataset.

        Features, labels and (when present) targets are permuted with the same
        permutation so they stay aligned. Training history is not copied.
        """
        # The permutation comes from torch's RNG (as before); convert it to a
        # NumPy index array since features/labels are NumPy arrays.
        permutation = torch.randperm(self.num).numpy()
        shuffled = Data(num_features=self.num_features)
        shuffled.features = self.features[permutation]
        shuffled.labels = self.labels[permutation]
        if len(self.targets):
            shuffled.targets = [self.targets[i] for i in permutation]
        shuffled.num = self.num
        return shuffled

    def get_efficient(self, threshold=-1):
        """Returns a list of the indices of the efficient guides in the dataset.

        Args:
            threshold: Guides with a label below threshold are considered
                efficient.

        Returns:
            The list of indices of efficient guides.
        """
        return [i for i, label in enumerate(self.labels) if label < threshold]

    def update_progress(self, epoch_counter, loss):
        """Updates the training progress of the dataset.

        Args:
            epoch_counter: The number of elapsed epochs.
            loss: The latest loss registered for the dataset.
        """
        self.epoch_counter = epoch_counter
        self.epochs.append(epoch_counter)
        self.losses.append(loss)


In [None]:
#@title BaseModel { form-width: "150px" }

class BaseModel(nn.Module):
    """A basic model that all models should inherit from.

    Attributes:
        input_size: The dimension of the input, should correspond to the size of
            the feature representation of a guide.
        output_size: The number of outputs.
        hidden_size: The number of units in a hidden layer.
        epoch_counter: The number of epochs the model has been trained for.
    """

    def __init__(self, input_size, output_size, hidden_size=0):
        super(BaseModel, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.epoch_counter = 0

    def get_name(self):
        """Returns the name of the model (which is the class name)."""
        return self.__class__.__name__

    def get_predictions(self, data, is_train=True):
        """Returns the raw predictions of the model for the data.

        Args:
            data: A Data instance whose features are a NumPy array.
            is_train: A boolean indicating if the model is being trained on the
                data. If True, gradients will be computed.

        Returns:
            The predictions of the model for the data.
        """
        with torch.set_grad_enabled(is_train):
            # autograd.Variable is a deprecated no-op wrapper; a plain tensor
            # behaves identically.
            x_features = torch.from_numpy(data.features).to(device)
            return self(x_features)

    def get_processed_predictions(self, data):
        """Returns the post-processed predictions of the model for the data.

        Args:
            data: A Data instance.

        Returns:
            The processed predictions of the model for the data.
        """
        # In this basic model, no post-processing is needed, so just return the
        # predictions (computed without gradients).
        return self.get_predictions(data, False)

    def count_params(self):
        """Returns the number of learnt parameters of the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [None]:
#@title Experiment class { form-width: "150px" }

class Experiment(object):
    """A class to manage an experiment.

    Attributes:
        model: The experimental model (which inherits from BaseModel).
        optimizer: A PyTorch Optimizer (initialised).
        criterion: A loss function.
        num_features: The size of the feature representation.
        labels_extractor: A method for extracting labels from datapoints.
        lr: The learning rate for the optimizer.
        name: A name representing the experiment.
        training: A Data instance for the training set.
        validation: A Data instance for the validation set.
        testing: A Data instance for the test set.
        tracked_datasets: A dictionary mapping dataset names to Data instances
            for these datasets. These will be updated with the model's loss for
            them throughout the training.
    """

    def __init__(self, model, optimizer, criterion,
                 num_features, labels_extractor=get_labels,
                 lr=0.01, weight_decay=0.0001, suffix=""):
        """Initialise an experiment.

        Args:
            model: The experimental model (which inherits from BaseModel).
            optimizer: A PyTorch Optimizer (uninitialised).
            criterion: A loss function.
            num_features: The size of the feature representation.
            labels_extractor: A method for extracting labels from datapoints.
            lr: The learning rate for the optimizer.
            weight_decay: The L2 regularisation parameter.
            suffix: A suffix for the name of the experiment.
        """
        self.model = model
        self.optimizer = optimizer(self.model.parameters(), lr=lr,
                                   weight_decay=weight_decay)
        self.criterion = criterion
        self.num_features = num_features
        self.labels_extractor = labels_extractor
        self.lr = lr

        date_str = datetime.now().strftime("%y%m%d_%H%M%S")
        self.name = "%s_%s" % (model.get_name(), date_str)
        if suffix:
            self.name += f"_{suffix}"

        print("Starting experiment with model %s with %s params" %
              (self.model.get_name(), self.model.count_params()))

        self.training = None
        self.validation = None
        self.testing = None
        self.tracked_datasets = {}

    def set_data(self, train_data, validation_data, test_data):
        """Initialises basic datasets for the experiment.

        Args:
            train_data: A list of DataPoint instances for training.
            validation_data: A list of DataPoint instances for validation.
            test_data: A list of DataPoint instances for testing.
        """
        self.training = Data(
            train_data, self.num_features, self.labels_extractor)
        self.validation = Data(
            validation_data, self.num_features, self.labels_extractor)
        self.testing = Data(
            test_data, self.num_features, self.labels_extractor)

        self.tracked_datasets["validation"] = self.validation
        self.tracked_datasets["test"] = self.testing

    def add_dataset(self, raw_data, name):
        """Adds a dataset to the tracked datasets.

        Args:
            raw_data: The list of DataPoint instances of the dataset.
            name: A name to represent the dataset.
        """
        self.tracked_datasets[name] = Data(
            raw_data, self.num_features, self.labels_extractor)

    def prep_and_set_data(self, datasets, genome_data, sizes):
        """Initialises all the required datasets.

        Args:
            datasets: A dictionary mapping dataset names to their list of
                DataPoint instances. These will be used when composing the work
                data (that is, the training, validation and test sets) along
                with the genome data. They will also be added to the tracked
                datasets.
            genome_data: A list of DataPoint instances which will be sampled
                from to complete the work data.
            sizes: A tuple with the sizes of the three standard datasets:
                training, validation and test (in that order).

        Raises:
            ValueError: (from random.sample) if the datasets already hold more
                datapoints than sum(sizes), or genome_data is too small.
        """
        # Concatenate the datasets into a fresh list (the callers' lists are
        # not mutated by the padding and shuffling below).
        work_data = [point for data in datasets.values() for point in data]
        total_size = sum(sizes)

        genome_padding_len = total_size - len(work_data)
        work_data += random.sample(genome_data, genome_padding_len)
        np.random.shuffle(work_data)

        # Cut the shuffled work data into consecutive slices of the given sizes.
        work_datasets = []
        start = 0
        for size in sizes:
            end = start + size
            work_datasets.append(work_data[start:end])
            start = end

        self.set_data(*work_datasets)
        for name, dataset in datasets.items():
            self.add_dataset(dataset, name)

    def set_lr(self, lr):
        """Changes the learning rate."""
        self.lr = lr
        for g in self.optimizer.param_groups:
            g['lr'] = lr

    def train(self, epochs=int(1e5), report_every=0):
        """Trains the model.

        The general structure of this training procedure was adapted from the
        PyTorch tutorial:
        NLP From Scratch: Translation with a Sequence to Sequence Network and
        Attention.
        https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

        Args:
            epochs: Number of epochs to train for.
            report_every: The number of epochs after which to report progress
                and update the datasets with the current loss. Defaults to a
                tenth of the epochs (at least 1).
        """
        if not report_every:
            # Never let the interval be 0: `epoch % 0` raises ZeroDivisionError
            # whenever fewer than 10 epochs are requested.
            report_every = max(1, epochs // 10)
        start_time = time.time()
        data = self.training.shuffle()
        # The labels are the same every epoch, so copy them to the device once.
        y = torch.from_numpy(data.labels).to(device)
        print_loss = 0

        for epoch in range(1, epochs + 1):
            # report_progress may switch the model to eval mode, so re-enable
            # training mode every epoch.
            self.model.train()
            self.model.epoch_counter += 1
            self.optimizer.zero_grad()

            outputs = self.model.get_predictions(data, True)
            loss = self.criterion(outputs, y)
            # Accumulate a detached value: summing the live loss would keep
            # every epoch's autograd graph in memory until the next report.
            print_loss += loss.detach()
            loss.backward()
            self.optimizer.step()

            if epoch % report_every == 0:
                report_progress(self, start_time, epoch, epochs,
                                print_loss, loss, report_every)
                print_loss = 0

    def test(self):
        """Tests the model and updates progress for tracked datasets.

        Returns:
            A list with two items (in this order): the current loss for the
            validation set and the current loss for the test set.
        """
        self.model.eval()

        for data in self.tracked_datasets.values():
            y = torch.from_numpy(data.labels).to(device)
            outputs = self.model.get_predictions(data, False)
            loss = self.criterion(outputs, y)
            data.update_progress(self.model.epoch_counter, loss)

        return [
            self.tracked_datasets["validation"].losses[-1],
            self.tracked_datasets["test"].losses[-1],
        ]

    def get_experiment_dir(self):
        """Returns the path of the experiment's directory."""
        return self._get_experiment_dir(self.name)

    def get_model_params_path(self):
        """Returns the path of the model's state file."""
        return self._get_model_params_path(self.name)

    def get_experiment_path(self):
        """Returns the path of the experiment's Pickle file."""
        return self._get_experiment_path(self.name)

    def save(self):
        """Saves the experiment and the model."""
        os.makedirs(self.get_experiment_dir(), exist_ok=True)

        torch.save(self.model.state_dict(), self.get_model_params_path())
        with open(self.get_experiment_path(), 'wb') as fd:
            pickle.dump(self, fd)

    @classmethod
    def _get_experiment_dir(cls, name):
        """Returns the path of the experiment's directory with this name."""
        return os.path.join(WORK_DIR, "models", name)

    @classmethod
    def _get_model_params_path(cls, name):
        """Returns the path of the model's state given the experiment name."""
        return os.path.join(cls._get_experiment_dir(name), "model_params")

    @classmethod
    def _get_experiment_path(cls, name):
        """Returns the path of the experiment's Pickle file."""
        return os.path.join(cls._get_experiment_dir(name), "experiment.pkl")

    @classmethod
    def load(cls, name):
        """Loads the experiment with the given name.

        Note: this unpickles a file from the work directory; only load
        experiments you saved yourself.
        """
        path = cls._get_experiment_path(name)
        with open(path, 'rb') as fd:
            experiment = pickle.load(fd)
        experiment.name = name
        experiment.model.load_state_dict(
            torch.load(cls._get_model_params_path(name), map_location=device))
        experiment.model = experiment.model.to(device)
        return experiment
        

In [None]:
#@title Fully connected experiment { form-width: "150px" }

class FullyConnected(BaseModel):
    """A feedforward network of Linear layers with Tanh activations.

    Attributes (in addition to BaseModel's):
        num_layers: The number of Linear layers, counting the hidden layers
            and the output layer but not the input.
        nn: A PyTorch Sequential container holding every layer.
    """

    def __init__(self, input_size, output_size, hidden_size, num_layers):
        """Builds the network.

        Args:
            input_size: The dimension of the input, which should match the
                size of a guide's feature representation.
            output_size: The number of outputs.
            hidden_size: The width of each hidden layer.
            num_layers: The number of Linear layers (hidden layers plus the
                output layer, excluding the input).
        """
        super(FullyConnected, self).__init__(
            input_size, output_size, hidden_size)
        self.num_layers = num_layers
        self.nn = nn.Sequential(*self.get_layers())

    def get_layers(self):
        """Creates the layers, from the input side to the output side."""
        if self.num_layers == 1:
            # A single layer maps the input straight to the output.
            return [nn.Linear(self.input_size, self.output_size)]

        # Widths of everything before the output layer: the input, then one
        # entry per hidden layer (always at least one hidden layer here).
        hidden_count = max(self.num_layers - 1, 1)
        widths = [self.input_size] + [self.hidden_size] * hidden_count

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        stack.append(nn.Linear(self.hidden_size, self.output_size))
        return stack

    def forward(self, x):
        return self.nn(x)


class FullyConnectedExperiment(Experiment):
    """An experiment that trains a single-output FullyConnected network."""

    def __init__(self, num_features, hidden_size, num_layers, optimizer,
                 criterion, lr=0.01, weight_decay=0.0001, suffix=""):
        """Builds the network and sets up the experiment around it.

        Args:
            num_features: The size of a guide's feature representation.
            hidden_size: The width of each hidden layer.
            num_layers: The number of Linear layers (hidden layers plus the
                output layer, excluding the input).
            optimizer: A PyTorch Optimizer class (not yet instantiated).
            criterion: A loss function.
            lr: The learning rate for the optimizer.
            weight_decay: The L2 regularisation parameter.
            suffix: A suffix appended to the experiment's name.
        """
        network = FullyConnected(input_size=num_features, output_size=1,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers)
        super(FullyConnectedExperiment, self).__init__(
            model=network,
            optimizer=optimizer,
            criterion=criterion,
            num_features=num_features,
            labels_extractor=get_labels,
            lr=lr,
            weight_decay=weight_decay,
            suffix=suffix,
        )


In [None]:
#@title Ordinal classification experiment { form-width: "150px" }

class OrdinalClassifier(BaseModel):
    """An ordinal classification model.

    Follows the architecture proposed by:
    Frank, E. and Hall, M. (2001, September)
    A simple approach to ordinal classification.
    In European Conference on Machine Learning (pp. 145-156). Springer, Berlin,
    Heidelberg.
    https://link.springer.com/chapter/10.1007/3-540-44795-4_13

    Attributes (changes from BaseModel):
        num_labels: The number of different labels or classes.
        num_layers: The number of neural layers (this includes the hidden
            layers and the output layer, but not the input layer) in the
            classifiers.
        classifiers: A PyTorch ModuleList of the classifiers; classifier i
            predicts (as a logit) whether the label is greater than i.
        datasets: The copies of the training dataset, after it has been
            transformed to a classification dataset for each of the classifiers.
    """

    def __init__(self, input_size, num_labels, hidden_size, num_layers):
        """Initialises the model.

        Args:
            input_size: The dimension of the input, should correspond to the
                size of the feature representation of a guide.
            num_labels: The number of different labels or classes.
            hidden_size: The number of units in a hidden layer.
            num_layers: The number of neural layers (this includes the hidden
                layers and the output layer, but not the input layer) in the
                classifiers.
        """
        super(OrdinalClassifier, self).__init__(
            input_size, num_labels, hidden_size)
        self.num_labels = num_labels
        self.num_layers = num_layers
        self.classifiers = nn.ModuleList([
            FullyConnected(input_size, 1, hidden_size, num_layers)
            for _ in range(num_labels - 1)
        ])
        self.datasets = []

    def forward(self, x, label):
        """Produces predictions for x using the classifier of the label."""
        # self.classifiers[label] is the classifier for > label
        return self.classifiers[label](x)

    def classify(self, data, choose=False):
        """Produces predictions for the data.

        Args:
            data: A Data instance to classify.
            choose: If False, produces probability distributions over all the
                possible classes. If True, chooses the most probable class for
                each item in the dataset.

        Returns:
            A CPU tensor: either the probability distributions, shape
            [num_datapoints, num_labels], or the predicted classes, shape
            [num_datapoints, 1], depending on the value of choose.
        """
        x = torch.from_numpy(data.features).to(device)

        with torch.no_grad():
            # Column i holds P(label > i) for every datapoint; each classifier
            # has a single output, so the columns concatenate to
            # [num_datapoints, num_labels - 1].
            prob_more_than = torch.cat(
                [torch.sigmoid(classifier(x)) for classifier in self.classifiers],
                dim=1)

            # Row i, column j holds P(datapoint i has label j):
            #   P(0)   = 1 - P(> 0)
            #   P(j)   = P(> j - 1) - P(> j)   for 0 < j < top label
            #   P(top) = P(> top - 1)
            prob_equals = torch.cat([
                1 - prob_more_than[:, :1],
                prob_more_than[:, :-1] - prob_more_than[:, 1:],
                prob_more_than[:, -1:],
            ], dim=1).cpu()

        if choose:
            # Like np.argmax, torch.argmax returns the first maximal index.
            return prob_equals.argmax(dim=1).unsqueeze(1)

        return prob_equals.float()

    def split_data(self, data):
        """Splits the dataset into the required classification datasets.

        Dataset i is a copy of data whose labels are 1 where the original label
        is greater than i and 0 otherwise.
        """
        datasets = [copy.deepcopy(data) for _ in range(self.num_labels - 1)]
        for min_label, dataset in enumerate(datasets):
            for i, label in enumerate(dataset.labels):
                dataset.labels[i] = 1 if label > min_label else 0
        return datasets

    def get_processed_predictions(self, data):
        """Returns the predicted classes for the data."""
        return self.classify(data, choose=True)
    

class OrdinalClassificationExperiment(Experiment):
    """An ordinal classification experiment.

    The ordinal problem is decomposed into ``num_labels - 1`` binary problems
    ("is the label greater than t?"). Each binary problem is solved by its own
    classifier, and each classifier has its own optimizer.

    Attributes (changes from Experiment):
        optimizers: PyTorch Optimizers (initialised) for all the classifiers,
            in the same order as ``model.classifiers``.
        weight_decay: The L2 regularisation parameter.
    """

    def __init__(self, num_features, num_labels, hidden_size, num_layers,
                 optimizer, criterion, lr=0.01, weight_decay=0.0001, suffix=""):
        """Initialises the experiment.

        Args:
            num_features: The size of the feature representation of guides.
            num_labels: The number of different labels or classes.
            hidden_size: The number of units in a hidden layer.
            num_layers: The number of neural layers (this includes the hidden
                layers and the output layer, but not the input layer) in the
                classifiers.
            optimizer: A PyTorch Optimizer (uninitialised).
            criterion: A loss function.
            lr: The learning rate for the optimizer.
            weight_decay: The L2 regularisation parameter.
            suffix: A suffix for the name of the experiment.
        """
        model = OrdinalClassifier(
            input_size=num_features,
            num_labels=num_labels,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        super(OrdinalClassificationExperiment, self).__init__(
            model, optimizer, criterion, num_features,
            get_labels_as_indices, lr, weight_decay, suffix
        )
        self.weight_decay = weight_decay
        # `optimizer` is the uninitialised optimizer class (see Args). One
        # instance is built per binary classifier, honouring the configured
        # learning rate and weight decay.
        self.optimizers = [
            optimizer(classifier.parameters(), lr=lr, weight_decay=weight_decay)
            for classifier in self.model.classifiers
        ]

    def train(self, epochs=int(1e5), report_every=0):
        """Trains every binary classifier of the model.

        Args:
            epochs: Number of epochs to train for.
            report_every: The number of epochs after which to report progress
                and update the datasets with the current loss. Defaults to a
                tenth of ``epochs`` (at least 1).
        """
        if not report_every:
            # With epochs < 10, `epochs // 10` would be 0 and the modulo
            # below would raise ZeroDivisionError.
            report_every = max(1, epochs // 10)
        start_time = time.time()
        print_loss = 0

        # Every sub-problem is binary, so its logits are scored with BCE.
        criterion = nn.BCEWithLogitsLoss()
        # Shuffle once, then derive one binary dataset per classifier.
        datasets = self.model.split_data(self.training.shuffle())

        for epoch in range(1, epochs + 1):
            self.model.train()
            self.model.epoch_counter += 1
            for classifier, optimizer, data in zip(
                    self.model.classifiers, self.optimizers, datasets):
                optimizer.zero_grad()

                y = Variable(torch.from_numpy(data.labels).float()).\
                    to(device).unsqueeze(1)
                outputs = classifier.get_predictions(data, True)
                loss = criterion(outputs, y)
                # Detach before accumulating so the running total does not
                # chain every iteration's autograd graph together.
                print_loss += loss.detach()
                loss.backward()
                optimizer.step()

            if epoch % report_every == 0:
                report_progress(self, start_time, epoch, epochs,
                                print_loss, loss, report_every)
                print_loss = 0

    def test(self):
        """Tests the model and updates progress for the tracked datasets.

        Returns:
            A list with two items (in this order): the current loss for the
            validation set and the current loss for the test set.
        """
        self.model.eval()
        # NOTE(review): `classify(..., choose=False)` returns probability
        # distributions, but CrossEntropyLoss expects unnormalised logits.
        # The value recorded here is therefore not a true cross-entropy.
        # It is kept unchanged so that losses stay comparable with earlier
        # runs; confirm that this is intended.
        criterion = nn.CrossEntropyLoss()

        for data in self.tracked_datasets.values():
            y = Variable(torch.from_numpy(data.labels)).to(device)
            outputs = self.model.classify(data, choose=False)
            loss = criterion(outputs, y)
            data.update_progress(self.model.epoch_counter, loss)

        return [
            self.tracked_datasets["validation"].losses[-1],
            self.tracked_datasets["test"].losses[-1],
        ]

    def set_lr(self, lr):
        """Sets the learning rate of every classifier's optimizer.

        Args:
            lr: The new learning rate.
        """
        self.lr = lr
        for optimizer in self.optimizers:
            for group in optimizer.param_groups:
                group['lr'] = lr
        

In [None]:
#@title Ordinal regression experiment { form-width: "150px" }

class OrdinalRegression(BaseModel):
    """An ordinal regression model.

    Wraps a single-output ``FullyConnected`` predictor in spacecutter's
    ``OrdinalLogisticModel``, following the architecture proposed by:
    Rosenthal, E. (2018)
    spacecutter: Ordinal Regression Models in PyTorch.
    https://www.ethanrosenthal.com/2018/12/06/spacecutter-ordinal-regression/

    Attributes (changes from BaseModel):
        num_labels: The number of different labels or classes.
        num_layers: The number of neural layers in the predictor.
        regression: The ordinal logistic model. Its per-label outputs are
            arg-maxed by ``get_processed_predictions``.
        ascension: Responsible for clipping the cutpoints to preserve their
            ascending order.
    """

    def __init__(self, input_size, num_labels, hidden_size, num_layers):
        """Initialises the model.

        Args:
            input_size: The dimension of the input; it should match the size
                of the feature representation of a guide.
            num_labels: The number of different labels or classes.
            hidden_size: The number of units in a hidden layer.
            num_layers: The number of neural layers (this includes the hidden
                layers and the output layer, but not the input layer).
        """
        super(OrdinalRegression, self).__init__(
            input_size, num_labels, hidden_size)
        self.num_labels = num_labels
        self.num_layers = num_layers
        # The predictor emits a single score per guide (output size 1).
        # OrdinalLogisticModel maps that score onto the num_labels classes
        # using its cutpoints.
        latent_predictor = FullyConnected(
            input_size, 1, hidden_size, num_layers)
        self.regression = OrdinalLogisticModel(latent_predictor, num_labels)
        self.ascension = AscensionCallback()

    def forward(self, x):
        """Runs the ordinal logistic model on a batch of features."""
        return self.regression(x)

    def get_processed_predictions(self, data):
        """Returns the most likely class index for every guide in ``data``."""
        scores = self.get_predictions(data, False)
        return scores.argmax(dim=1)


class OrdinalRegressionExperiment(Experiment):
    """An ordinal regression experiment.

    Trains an ``OrdinalRegression`` model. After every optimisation step the
    model's ascension callback clips the cutpoints so they stay ordered.
    """

    def __init__(self, num_features, num_labels, hidden_size, num_layers,
                 optimizer, criterion, lr=0.01, weight_decay=0.0001, suffix=""):
        """Initialises the experiment.

        Args:
            num_features: The size of the feature representation of guides.
            num_labels: The number of different labels or classes.
            hidden_size: The number of units in a hidden layer.
            num_layers: The number of neural layers (this includes the hidden
                layers and the output layer, but not the input layer).
            optimizer: A PyTorch Optimizer (uninitialised).
            criterion: A loss function.
            lr: The learning rate for the optimizer.
            weight_decay: The L2 regularisation parameter.
            suffix: A suffix for the name of the experiment.
        """
        model = OrdinalRegression(
            input_size=num_features,
            num_labels=num_labels,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        super(OrdinalRegressionExperiment, self).__init__(
            model, optimizer, criterion, num_features, get_labels_as_indices,
            lr, weight_decay, suffix
        )

    def train(self, epochs=int(1e5), report_every=0):
        """Trains the model.

        Args:
            epochs: Number of epochs to train for.
            report_every: The number of epochs after which to report progress
                and update the datasets with the current loss. Defaults to a
                tenth of ``epochs`` (at least 1).
        """
        if not report_every:
            # With epochs < 10, `epochs // 10` would be 0 and the modulo
            # below would raise ZeroDivisionError.
            report_every = max(1, epochs // 10)
        start_time = time.time()
        data = self.training.shuffle()
        print_loss = 0

        # The training data is fixed for the whole run, so its targets only
        # need to be built once.
        y = Variable(torch.from_numpy(data.labels)).\
            to(device).reshape(-1, 1)

        for epoch in range(1, epochs + 1):
            self.model.train()
            self.model.epoch_counter += 1
            self.optimizer.zero_grad()

            outputs = self.model.get_predictions(data, True)

            loss = self.criterion(outputs, y)
            # Detach before accumulating so the running total does not chain
            # every iteration's autograd graph together.
            print_loss += loss.detach()
            loss.backward()
            self.optimizer.step()

            # AscensionCallback.clip only acts on modules that own a
            # `cutpoints` parameter. That parameter lives inside the model,
            # not in the loss, so the callback is applied to every sub-module
            # of the model (as spacecutter's own on_batch_end does).
            self.model.apply(self.model.ascension.clip)

            if epoch % report_every == 0:
                report_progress(self, start_time, epoch, epochs,
                                print_loss, loss, report_every)
                print_loss = 0

    def test(self):
        """Tests the model and updates progress for the tracked datasets.

        Returns:
            A list with two items (in this order): the current loss for the
            validation set and the current loss for the test set.
        """
        self.model.eval()

        # Evaluation only. Without no_grad, every recorded loss would keep its
        # autograd graph alive for the lifetime of the experiment.
        with torch.no_grad():
            for data in self.tracked_datasets.values():
                y = Variable(torch.from_numpy(data.labels)).\
                    to(device).reshape(-1, 1)
                outputs = self.model.get_predictions(data)
                loss = self.criterion(outputs, y)
                data.epoch_counter = self.model.epoch_counter
                data.epochs.append(data.epoch_counter)
                data.losses.append(loss)

        return [
            self.tracked_datasets["validation"].losses[-1],
            self.tracked_datasets["test"].losses[-1],
        ]


# Pipeline

In [None]:
#@title Get data { form-width: "150px" }

# Load the feature files whose paths are defined in the Constants cell.
chari_data = get_data(CHARI_DATA_PATH)
genome_data = get_data(GENOME_DATA_PATH)
xu_data = get_data(XU_DATA_PATH)        # Ribosomal
nr_xu_data = get_data(NR_XU_DATA_PATH)  # Nonribosomal
doench_data = get_data(DOENCH_DATA_PATH)

num_features = NUM_FEATURES
# NOTE(review): num_labels is taken from NUM_FEATURES, i.e. the feature count,
# not a label count. The ordinal experiments use num_labels as the number of
# classes, so this is probably meant to be a label-count constant. Confirm,
# then point it at the correct constant.
num_labels = NUM_FEATURES


## Initialise experiment

The following examples show how to initialise each of the three available experiment types.
Choose one of them, or load an existing experiment instead.

1. Fully connected experiment:

        experiment = FullyConnectedExperiment(
            num_features=num_features,
            hidden_size=13,
            num_layers=4,
            optimizer=torch.optim.SGD,
            criterion=rankNet,
            lr=0.1,
            weight_decay=0.0001,
            suffix="description")

2. Ordinal classification experiment:

        experiment = OrdinalClassificationExperiment(
            num_features=num_features,
            num_labels=num_labels,
            hidden_size=10,
            num_layers=4,
            optimizer=torch.optim.SGD,
            criterion=nn.BCEWithLogitsLoss(),
            lr=0.01,
            weight_decay=0.0001,
            suffix="description")

3. Ordinal regression experiment:

        experiment = OrdinalRegressionExperiment(
            num_features=num_features,
            num_labels=num_labels,
            hidden_size=15,
            num_layers=3,
            optimizer=torch.optim.SGD,
            criterion=CumulativeLinkLoss(),
            lr=0.1,
            weight_decay=0.0001,
            suffix="description")



In [None]:
#@title New experiment { form-width: "150px" }

# Named datasets handed to prep_and_set_data together with the genome data.
datasets = {
    "xu": xu_data,
    "doench": doench_data,
    "chari": chari_data,
}

# A fully connected model trained with the `lambdaLoss` criterion.
experiment = FullyConnectedExperiment(
    num_features=num_features,
    hidden_size=18,
    num_layers=4,
    optimizer=torch.optim.SGD,
    criterion=lambdaLoss,
    lr=0.1,
    weight_decay=0.0001,
    suffix="description",
)
experiment.prep_and_set_data(datasets, genome_data, SIZES_6K)
print(experiment.name)

In [None]:
#@title Load experiment { form-width: "150px" }

# Replace the placeholder with the name of a previously saved experiment.
experiment_name = "experiment_name"
experiment = FullyConnectedExperiment.load(experiment_name)


## Train

In [None]:
# Train in rounds of EPOCHS_PER_ROUND epochs, saving the experiment after
# each round.
NUM_ROUNDS = 1
EPOCHS_PER_ROUND = 100
for _round in range(NUM_ROUNDS):
    experiment.train(EPOCHS_PER_ROUND)
    experiment.save()

## Evaluate progress

In [None]:
# Plot the tracked loss curves, showing the test split but not training.
plot_standard_losses(experiment, test=True, train=False)

In [None]:
# Compare the loss curves of the Xu dataset and the validation split; the
# legend labels are the capitalised dataset names.
to_plot = ["xu", "validation"]
curves = {name.capitalize(): experiment.tracked_datasets[name]
          for name in to_plot}
plot_losses(curves, ylabel="RankNet Loss")

# Comparisons

In [None]:
#@title Initialise comparisons { form-width: "150px" }

# The mixture dataset is stored as a pickle.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(MIXTURE_DATA_PATH, 'rb') as mixture_file:
    mixture = pickle.load(mixture_file)

# Comparison over the experiment's own tracked datasets.
standard_cmp = Comparison(experiment.tracked_datasets)

# Comparison over the datasets used in the final reports below, each
# wrapped as a Data object with the shared feature size.
final_sources = {
    "mixture": mixture,
    "nr_xu": nr_xu_data,
    "xu": xu_data,
    "doench": doench_data,
}
final_cmp = Comparison({
    name: Data(source, num_features)
    for name, source in final_sources.items()
})


## Against scoring tools

In [None]:
#@title Short summary { form-width: "150px" }

# For each comparison dataset, print the model's ranking result against the
# dataset's true ranking.
for data_name, data in final_cmp.datasets.items():
    success_rate = ranking_from_predictions(
        data, experiment.model, final_cmp.true_rankings[data_name])
    padded_name = data_name.ljust(18)
    print(f"\t{padded_name} {success_rate}")


In [None]:
#@title Full report { form-width: "150px" }

pretty_print(get_ranking_report(final_cmp, experiment.model))


## Against decision tools

In [None]:
#@title Precision-coverage plot { form-width: "150px" }

thresholds = plot_precision_coverage(
    experiment.validation, experiment.model, 100)


In [None]:
#@title Full report { form-width: "150px" }

# Decision threshold passed to get_decision_report.
decision_threshold = 0.12
df = get_decision_report(final_cmp, experiment.model, decision_threshold)
pretty_print(df)


## Against similar tools

In [None]:
#@title Crackling configuration { form-width: "150px" }

# Decision functions combined into the Crackling majority vote. Keys 5 and 14
# come from COL_TO_DECISION_FUNCTION; key 16 uses the sgRNA decision function.
CRACKLING_TOOLS = {
    column: COL_TO_DECISION_FUNCTION[column] for column in (5, 14)
}
CRACKLING_TOOLS[16] = get_sgrna_decisions


In [None]:
#@title Ranking comparison { form-width: "150px" }

pretty_print(deepcrispr_ranking_report(final_cmp))


In [None]:
#@title Decision comparison { form-width: "150px" }

deepcrispr_df = deepcrispr_decision_report(
    final_cmp, VARIANTS[CLASSIFICATION])
crackling_df = majority_decision_report(
    final_cmp, CRACKLING_TOOLS).\
    drop(0).replace("Majority vote", "Crackling")
model_df = model_decision_report(final_cmp, experiment.model, 0.12)

pretty_print(
    deepcrispr_df.\
    append(crackling_df, ignore_index=True).\
    append(model_df, ignore_index=True)
)



In [None]:
#@title ROC-AUC comparison { form-width: "150px" }

auc_table = compare_auc(final_cmp, experiment.model)
pretty_print(auc_table)

# Feature Importance

In [None]:
#@title Prepare explainer { form-width: "150px" }

SAMPLE_SIZE = 1000

x_train = experiment.training.features
# Sampling is without replacement, so it cannot draw more rows than the
# training set contains. Cap the size instead of crashing on small sets.
sample_size = min(SAMPLE_SIZE, len(x_train))


def _sample_rows(features, size):
    """Returns `size` distinct rows of `features`, picked uniformly at random."""
    return features[np.random.choice(len(features), size, replace=False)]


# NOTE(review): the experiments' train() methods put the model into training
# mode. Switch to inference mode so that layers such as dropout (if present)
# do not perturb the attributions.
experiment.model.eval()

# Background sample passed to the explainer.
e = shap.DeepExplainer(
    experiment.model,
    torch.from_numpy(_sample_rows(x_train, sample_size)).to(device))

# An independently drawn sample of guides to explain.
x_samples = _sample_rows(x_train, sample_size)

shap_values = e.shap_values(
    torch.from_numpy(x_samples).to(device)
)


In [None]:
#@title Plot dot { form-width: "150px" }

# SHAP summary (dot) plot of the sampled guides, saved to the work directory.
dot_path = os.path.join(WORK_DIR, "shap_dot.png")
shapfig_dot = summary_plot(
    shap_values,
    features=x_samples,
    feature_names=FEATURES,
    axis_color="black",
    cmap=SHAP_CMAP,
)
shapfig_dot.savefig(dot_path, bbox_inches="tight", dpi=100)

In [None]:
#@title Plot bar { form-width: "150px" }

# SHAP summary bar plot of the sampled guides, saved to the work directory.
bar_path = os.path.join(WORK_DIR, "shap_bar.png")
shapfig_bar = summary_plot(
    shap_values,
    features=x_samples,
    feature_names=FEATURES,
    plot_type="bar",
    axis_color="black",
    color=SHAP_COLOUR,
)
shapfig_bar.savefig(bar_path, bbox_inches="tight", dpi=100)