In [1]:
# Sepkoski Maker V6
# If you run into bugs, feel free to email parsonsc@mit.edu

In [2]:
import os

from IPython.display import clear_output

import pandas as pd
import numpy as np

import scipy
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib as mpl
mpl.rcParams['figure.dpi']= 300

from scipy.stats import gaussian_kde
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from geopy import distance

In [2]:
# Returns a dictionary of the midpoints of the age ranges
# of all confirmed occurrences of a given clade.
# By default, resolution is genus-level, but that can also be set
# to "family", "order", "class", or "phylum" (NOT "species" or "kingdom").
# A list of these ranks may also be given, or None for all five.
#
# The data is stored in the output dictionary keyed first by rank, as follows:
# {"genus": {"subclade_1":[occurrenceTime_1, occurrenceTime_2, ...], "subclade_2":[occurrenceTime_1, ...], ...}}
#
# This queries the Paleobiology Database (paleobiodb.org). It saves the search 
# results to the current working directory by default, so as to not redundantly 
# pull from the database. This can be disabled by setting save_locally to False.
#
# Set no_simultaneous to True to drop repeated occurrences of a subclade that share
# an identical mid-point age.
#
# By default, ALL subclades of the desired taxonomic rank are returned, but the optional 
# exclude_modern or recency_cutoff arguments can be used to filter out clades that persist 
# into the modern and therefore can be argued to not be valid data when considering clade 
# longevity (otherwise stated: "How can you measure the longevity of a thing that hasn't 
# yet died?")
#
# Set exclude_modern to True to exclude any subclade with an occurrence in the
# "Holocene", "Late Pleistocene", "Pleistocene", "Middle Pleistocene", "Calabrian", 
# "Early Pleistocene", "Quaternary", "Irvingtonian", "Rancholabrean", "Ionian",
# "Castlecliffian", or "Lujanian" (there may be more synonyms for "modern"
# in the database, but these are what I've seen so far).
#
# The recency_cutoff can be set to serve as a more stringent version of exclude_modern,
# removing subclades that have representatives younger than the provided cutoff. (e.g. if 
# recency_cutoff=50, for any subclades that have occurrences younger than 50 Ma, the 
# entire subclade will be excluded from the resultant dictionary) 
# 
# If a recency_cutoff value is defined, the value of exclude_modern will be irrelevant.
#
# NOTE: the filtering code behind exclude_modern, recency_cutoff, colocalized_threshold
# and exclude_zero_min_age is currently commented out in the function body, so these
# options presently have no effect on the result.

def getSubcladeOccurences(clade, resolution="genus", exclude_modern=False, save_locally=True,
                          recency_cutoff=None, colocalized_threshold=None, exclude_zero_min_age=False,
                          no_simultaneous=False, environment_only=None, plot=False):
    """Collect occurrence mid-point ages (Ma) for every subclade of `clade`.

    Occurrences are read from `<clade>.csv` (or `<clade>_<environment_only>.csv`)
    in the current working directory when that file exists; otherwise they are
    downloaded from the Paleobiology Database and, if `save_locally`, cached
    under that filename.

    Parameters
    ----------
    clade : str
        PBDB base taxon name (sent as `base_name` in the query).
    resolution : str, list of str, or None
        Taxonomic rank(s) to group by ("phylum", "class", "order", "family",
        "genus"); None means all five. Case-insensitive.
    save_locally : bool
        Write a freshly downloaded table to the working directory.
    no_simultaneous : bool
        Skip an occurrence whose mid-point age is already recorded for that
        subclade.
    environment_only : str or None
        Sent to PBDB as `envtype` and appended to the cache filename.
    plot : bool
        With a single resolution, draw a histogram of subclade durations.
    exclude_modern, recency_cutoff, colocalized_threshold, exclude_zero_min_age
        Accepted for backward compatibility, but their filtering logic is
        currently disabled; a UserWarning is issued if any of them is set.

    Returns
    -------
    dict
        {rank: {subclade_name: [mid_age, ...]}}. Rows whose name at a rank is
        empty, NaN, or contains "SPECIFIED" are skipped for that rank.
    """
    import warnings

    # These options used to filter subclades, but that code is disabled. Warn
    # instead of silently returning unfiltered data.
    disabled_options = (("exclude_modern", exclude_modern),
                        ("recency_cutoff", recency_cutoff),
                        ("colocalized_threshold", colocalized_threshold),
                        ("exclude_zero_min_age", exclude_zero_min_age))
    for option_name, option_value in disabled_options:
        if option_value:
            warnings.warn(option_name + " is currently disabled in getSubcladeOccurences "
                          "and has no effect", stacklevel=2)

    if resolution is None:
        resolution = ["phylum", "class", "order", "family", "genus"]
    elif isinstance(resolution, str):
        resolution = [resolution.lower()]
    else:
        resolution = [x.lower() for x in resolution]
    subcladeDicts = {x: {} for x in resolution}

    if not environment_only:
        localFilename = clade + ".csv"
    else:
        localFilename = clade + "_" + environment_only + ".csv"
    if localFilename in os.listdir():
        print("Local save of requested search found.\n")
        cladeDataframe = pd.read_csv(localFilename, low_memory=False)
    else:
        print("No local save of requested search found. Downloading from PBDB.")
        # skiprows drops the `datainfo` header block PBDB prepends to the CSV;
        # the envtype query carries one extra header line.
        if not environment_only:
            url_search = "https://paleobiodb.org/data1.2/occs/list.csv?datainfo&rowcount&base_name=CLADE_NAME&show=classext,coords,loc,acconly,ecospace,lith,geo,taphonomy".replace("CLADE_NAME", clade)
            cladeDataframe = pd.read_csv(url_search, skiprows=17, low_memory=False)
        else:
            url_search = "https://paleobiodb.org/data1.2/occs/list.csv?datainfo&rowcount&base_name=CLADE_NAME&envtype=ENV_TYPE&show=classext,coords,loc,acconly,ecospace,lith,geo,taphonomy".replace("CLADE_NAME", clade).replace("ENV_TYPE", environment_only)
            cladeDataframe = pd.read_csv(url_search, skiprows=18, low_memory=False)
        if save_locally:
            cladeDataframe.to_csv(localFilename)
            print("Search results saved to", os.getcwd() + "/" + localFilename + "\n")

    for i in range(cladeDataframe.shape[0]):
        for res in resolution:
            subcladeName = cladeDataframe.at[i, res]
            # Skip unassigned names: empty, NaN, or placeholders containing "SPECIFIED".
            if not subcladeName or str(subcladeName) == "nan" or "SPECIFIED" in subcladeName:
                continue
            min_age = cladeDataframe.at[i, "min_ma"]
            max_age = cladeDataframe.at[i, "max_ma"]
            mid_age = (float(min_age) + float(max_age)) / 2

            # setdefault never leaves an empty list behind: `continue` below is
            # only reachable when the list already holds mid_age.
            ages = subcladeDicts[res].setdefault(subcladeName, [])
            if no_simultaneous and mid_age in ages:
                continue
            ages.append(mid_age)

    if len(resolution) == 1:
        subcladeDict = subcladeDicts[resolution[0]]
        print(len(subcladeDict), "unique subclades (%s) found" %resolution)
        print(sum(1 for occ in subcladeDict.values() if len(occ) > 1),
              "subclades with multiple occurences")
        print(sum(1 for occ in subcladeDict.values() if len(set(occ)) > 1),
              "subclades with multiple unique occurences")
        print(sum(1 for occ in subcladeDict.values() if len(occ) > 4),
              "subclades with at least 5 occurences")
        if plot:
            plt.hist([max(occ) - min(occ) for occ in subcladeDict.values()])
            plt.yscale('log', nonpositive='clip')
            plt.xlabel("Total Duration")
            plt.ylabel("Count")
    return subcladeDicts

def getSupercladeCohorts(clade, superclade_resolution="class", subclade_resoltion="genus", 
                        environment_only=None, subclade_resolution=None):
    """Group the distinct subclade names in a cached clade table by superclade.

    Reads `<clade>.csv` (or `<clade>_<environment_only>.csv`) from the current
    working directory, i.e. the cache written by getSubcladeOccurences.

    Parameters
    ----------
    superclade_resolution : str
        Column (rank) whose values become the keys, e.g. "class".
    subclade_resoltion : str
        Column (rank) whose values are grouped, e.g. "genus". (Misspelt name kept
        for backward compatibility; prefer `subclade_resolution`.)
    subclade_resolution : str or None
        Correctly spelt alias; when given it overrides `subclade_resoltion`.

    Returns
    -------
    dict
        {superclade: [subclade, ...]} with each list de-duplicated in
        first-seen order. Rows where either name is empty, NaN, or contains
        "SPECIFIED" are skipped.

    Raises
    ------
    FileNotFoundError
        If the expected CSV is not present in the working directory.
    """
    if subclade_resolution is not None:
        subclade_resoltion = subclade_resolution

    if not environment_only:
        localFilename = clade + ".csv"
    else:
        localFilename = clade + "_" + environment_only + ".csv"
    if localFilename not in os.listdir():
        raise FileNotFoundError("No local save found at " + os.path.join(os.getcwd(), localFilename)
                                + "; run getSubcladeOccurences first to download it")
    print("Local save of requested search found.\n")
    cladeDataframe = pd.read_csv(localFilename, low_memory=False)

    def _unassigned(name):
        # Empty, NaN, or placeholder names containing "SPECIFIED".
        return (not name) or (str(name) == "nan") or ("SPECIFIED" in name)

    cohorts = {}
    for i in range(cladeDataframe.shape[0]):
        supercladeName = cladeDataframe.at[i, superclade_resolution]
        subcladeName = cladeDataframe.at[i, subclade_resoltion]
        if _unassigned(subcladeName) or _unassigned(supercladeName):
            continue
        cohorts.setdefault(supercladeName, []).append(subcladeName)
    # dict.fromkeys de-duplicates while keeping a deterministic (first-seen) order,
    # unlike list(set(...)).
    return {name: list(dict.fromkeys(members)) for name, members in cohorts.items()}

def getSuperCladeRichnessMatrix(supercladeCohorts, subcladeStartEnds, n_bins=50, t_0=1000):
    """Count, per time bin, how many member subclades of each superclade are extant.

    Parameters
    ----------
    supercladeCohorts : dict
        {superclade: [subclade, ...]}, e.g. from getSupercladeCohorts.
    subcladeStartEnds : dict
        {subclade: (old, young)} ages in Ma, e.g. from getSubcladeAgeRanges.
        Cohort members missing from this mapping are skipped rather than
        raising KeyError. getSubcladeAgeRanges can drop subclades that have too
        few occurrences, so this case is common.
    n_bins : int
        Number of time bins.
    t_0 : float
        Centre (Ma) of the first bin; bin i is centred on t_0 - i * t_0 / n_bins.

    Returns
    -------
    (ndarray, list)
        A (len(supercladeCohorts), n_bins) count matrix and the superclade
        labels in row order.

    NOTE(review): bins here are centred on t_0 - i*bin_size, while
    getDurationMatrix's bin i spans [t_0 - (i+1)*bin_len, t_0 - i*bin_len].
    The two grids are offset by half a bin. Confirm this is intended when the
    results are combined in getSurvivalMatrix.
    """
    bin_size = t_0 / n_bins
    half = 0.5 * bin_size
    labels = list(supercladeCohorts)
    superCladeRichnessMatrix = np.zeros((len(labels), n_bins))
    for i in range(n_bins):
        t = t_0 - (i * bin_size)
        for row, superclade in enumerate(labels):
            count = 0
            for subclade in supercladeCohorts[superclade]:
                startEnd = subcladeStartEnds.get(subclade)
                if startEnd is None:
                    continue
                start, end = startEnd
                # The range [end, start] overlaps the window [t - half, t + half).
                if start >= (t - half) and end < (t + half):
                    count += 1
            superCladeRichnessMatrix[row, i] = count
    return superCladeRichnessMatrix, labels
                
                

def getTaxonomyDict(subcladeResolution, supercladeResolution, subset="Metazoa", verbose=False):
    """Map each subclade name to its superclade, as listed in `<subset>.csv`.

    The CSV is read from the current working directory; None is returned if it
    is absent. Rows whose subclade or superclade name is empty, NaN, or contains
    "SPECIFIED" are ignored. When a subclade appears more than once, the last
    row wins. With verbose=True, each distinct unusable subclade name is printed
    once.
    """
    csvName = subset + ".csv"
    if csvName not in os.listdir():
        return None
    frame = pd.read_csv(csvName, low_memory=False)

    def unassigned(name):
        return (not name) or str(name) == "nan" or "SPECIFIED" in name

    mapping = {}
    alreadyReported = []
    for rowIdx in range(frame.shape[0]):
        sub = frame.at[rowIdx, subcladeResolution]
        if unassigned(sub):
            if verbose and sub not in alreadyReported:
                alreadyReported.append(sub)
                print(sub)
            continue
        sup = frame.at[rowIdx, supercladeResolution]
        if unassigned(sup):
            # Bookkeeping only; the matching print is disabled.
            if verbose and (sub not in alreadyReported or sup not in alreadyReported):
                alreadyReported.append(sub)
                alreadyReported.append(sup)
            continue
        mapping[sub] = sup
    return mapping

def getDurationMatrix(subcladeDict, n_bins=50, t_0=550, max_duration=200, 
                      min_occurrences=1, min_unique=False, return_labels=False,
                      occurrence_processor=None):
    """Build a (taxa x time-bin) matrix holding each taxon's age in bins.

    Time runs forward from t_0 Ma (bin 0) toward the present. Bin j covers ages
    from t_0 - j*bin_len down to t_0 - (j+1)*bin_len, with bin_len = t_0 / n_bins.
    For every kept subclade, the bins from its first (oldest) to its last
    (youngest) occurrence hold 1, 2, 3, ...; all other cells are 0.

    Parameters
    ----------
    subcladeDict : dict
        {subclade: [occurrence_age_Ma, ...]}, e.g. one rank from getSubcladeOccurences.
    n_bins, t_0 : int, float
        Number of bins and the age (Ma) at which bin 0 starts.
    max_duration : float
        Subclades whose raw range (oldest - youngest) exceeds this are dropped.
    min_occurrences : int
        Subclades with fewer occurrences are dropped (distinct ages only, if
        min_unique). Subclades with no occurrences are always dropped.
    return_labels : bool
        Also return the subclade name for each row.
    occurrence_processor : callable or None
        Applied to a copy of each kept occurrence list before binning (e.g.
        occurrence_bootstrapper). Filtering always uses the raw occurrences.

    Returns
    -------
    ndarray, or (ndarray, list) if return_labels.

    Notes
    -----
    A subclade whose first occurrence predates t_0 is clipped to start at
    bin 0 with age 1. A subclade whose entire range predates t_0 keeps an
    all-zero row. It was previously marked alive in bin 0.
    """
    bin_len = t_0 / n_bins

    def _is_permissible(occurrences):
        # Occurrence-count and total-duration filters on the raw occurrences.
        if len(occurrences) == 0:
            return False
        n_occ = len(set(occurrences)) if min_unique else len(occurrences)
        if n_occ < min_occurrences:
            return False
        return (max(occurrences) - min(occurrences)) <= max_duration

    # One filtering pass decides both the row count and the row order.
    labels = [name for name in subcladeDict if _is_permissible(subcladeDict[name])]
    durationMatrix = np.zeros((len(labels), n_bins))

    for row, subclade in enumerate(labels):
        occurrences = subcladeDict[subclade]
        if occurrence_processor:
            occurrences = occurrence_processor(list(occurrences))
            if len(occurrences) == 0:
                continue
        first = max(occurrences)
        last = min(occurrences)
        if last > t_0:
            # The whole range lies before the window, so the row stays empty.
            continue

        start_bin = int(max(t_0 - first, 0) // bin_len)
        end_bin = int(max(t_0 - last, 0) // bin_len)
        for age, j in enumerate(range(start_bin, min(n_bins, end_bin + 1)), start=1):
            durationMatrix[row, j] = age
    return durationMatrix if not return_labels else (durationMatrix, labels)

def occurrence_bootstrapper(occurrence_list, n_iter=100, rng=None):
    """Trim an occurrence list to a bootstrap-selected age window.

    Draws n_iter bootstrap resamples (with replacement, same size as the input)
    and records each resample's (oldest, youngest) window, grouped by window
    length. For each length it keeps the first window seen and how many times
    that length occurred. It then picks the window minimising
    (length + 1)**2 / times_seen. The occurrences inside that window are
    returned, in their original order.

    Parameters
    ----------
    occurrence_list : sequence of float
        Occurrence ages (Ma).
    n_iter : int
        Number of bootstrap resamples.
    rng : numpy.random.Generator or None
        Source of randomness, e.g. np.random.default_rng(seed), for
        reproducible results. None uses the global np.random state, as before.

    Returns
    -------
    list
        A new list. Empty input, or n_iter < 1, returns a copy of the input.
    """
    occurrences = list(occurrence_list)
    if not occurrences or n_iter < 1:
        return occurrences
    draw = np.random.choice if rng is None else rng.choice

    durations = {}
    for _ in range(n_iter):
        sample = draw(occurrences, replace=True, size=len(occurrences))
        start = max(sample)
        end = min(sample)
        dur = start - end
        if dur in durations:
            durations[dur][1] += 1
        else:
            durations[dur] = [(start, end), 1]

    # Favour short windows that recur often.
    (oldest, youngest), _ = min(durations.values(),
                                key=lambda x: (x[0][0] - x[0][1] + 1) ** 2 / x[1])
    return [x for x in occurrences if youngest <= x <= oldest]

# Returns survival matrices for a given duration matrix. Can be subdivided
# based on the age or richness of each subclade's superclade by specifying
# a taxDict. Can also return speciation rates.
#
# max_tail: The maximum age for which survivorship is calculated.
# min_per_bin: The minimum number of datapoints that can be used to infer
# a survival rate for a particular age at a particular time bin. The bin is
# filled with a -1 in such instances. Not relevant if proportions=False.
# taxDict: A dictionary mapping taxa in the subclade to their appropriate 
# superclade. Requires exactly one of supercladeDurationMatrix or supercladeRichnessMatrix.
# supercladeDurationMatrix: A duration matrix for the superclade you want to split by.
# Does nothing if taxDict is not specified.
# supercladeRichnessMatrix: A richness matrix per time bin for the superclade you want
# to split by. Does nothing if taxDict is not specified.



def getSurvivalMatrix(durationMatrix, max_tail=10, min_per_bin=1, taxDict=None, 
                      supercladeDurationMatrix=None, taxon_labels=None, 
                      super_labels=None, proportions=True, young_cutoff=5, mid_cutoff=9,
                      speciation=False, supercladeRichnessMatrix=None, verbose=True):
    """Tabulate age-specific survival (or speciation) counts per time bin.

    For every adjacent pair of bins (t, t+1), each taxon whose value in bin t is
    an age between 1 and max_tail is counted. It "survives" when its value in
    bin t+1 is positive. The last bin has no successor, so its column stays 0.

    Parameters
    ----------
    durationMatrix : ndarray, shape (n_taxa, n_bins)
        Cell = taxon's age (in bins) at that time, 0 when absent; typically
        produced by getDurationMatrix.
    max_tail : int
        Ages 1..max_tail are tabulated (one row per age).
    min_per_bin : int
        With proportions=True, an (age, bin) cell with fewer observations is -1.
    taxDict, taxon_labels, super_labels
        When taxDict is truthy, taxon i (label taxon_labels[i]) is looked up in
        taxDict. Its superclade's row is super_labels.index(superclade). Taxa
        whose superclade is missing from super_labels are skipped.
    supercladeDurationMatrix, supercladeRichnessMatrix : ndarray or None
        Exactly one must be given with taxDict. Observations whose superclade
        value at bin t is <= young_cutoff go to the first output matrix and the
        rest go to the second. The third output matrix is never filled, because
        the three-way split on mid_cutoff is disabled. mid_cutoff is currently
        unused.
    proportions : bool
        True: cells hold survivors / total. False: raw survivor counts, and the
        per-age totals are also returned.
    speciation : bool
        Instead count age-1 taxa (first appearances) in each bin. Each output
        is flattened to length n_bins.
    verbose : bool
        Print how many observations fell into each superclade group.

    Returns
    -------
    With proportions=True: a single matrix when no taxDict is given, otherwise
    a list of three. With proportions=False: (list_of_matrices, totals), where
    totals is a list (one per bin, plus a trailing zero row) of per-age totals,
    or a list of three such lists when taxDict is given.

    Raises
    ------
    SyntaxError
        If taxDict is given without exactly one superclade matrix. Kept as
        SyntaxError for backward compatibility with existing callers.
    """
    n_taxa, n_bins = durationMatrix.shape
    age_bins = max_tail if (not speciation) else 1
    survivalMatrices = [np.zeros((age_bins, n_bins))]
    if taxDict:
        if not ((supercladeDurationMatrix is None) ^ (supercladeRichnessMatrix is None)):
            raise SyntaxError("Only provide either a duration matrix or richness "
                              "matrix for the superclade")
        survivalMatrices += [np.zeros((age_bins, n_bins)), np.zeros((age_bins, n_bins))]
    if not proportions and taxDict:
        tot_lists = [[], [], []]
    elif not proportions:
        tot_lists = [[]]
    young, mid = (0, 0)
    assigned, total_checked = (0, 0)
    for t in range(n_bins - 1):
        time_slice = durationMatrix[:, t:t + 2]
        # survDict[age] = [survivors, total] for this bin transition.
        survDicts = [{x: [0, 0] for x in range(1, max_tail + 1)}]
        if taxDict:
            survDicts.append({x: [0, 0] for x in range(1, max_tail + 1)})
            survDicts.append({x: [0, 0] for x in range(1, max_tail + 1)})
        for i in range(n_taxa):
            age = time_slice[i, 0]
            fate = time_slice[i, 1]
            if not (age > 0 and ((age <= age_bins) or speciation)):
                continue
            if len(survDicts) == 3:
                superclade = taxDict.get(taxon_labels[i])
                total_checked += 1
                if superclade is None or superclade not in super_labels:
                    continue
                assigned += 1
                j = super_labels.index(superclade)
                # Exactly one matrix is set (checked above); split on its value at bin t.
                if supercladeDurationMatrix is not None:
                    split_value = supercladeDurationMatrix[j, t]
                else:
                    split_value = supercladeRichnessMatrix[j, t]
                if split_value <= young_cutoff:
                    young += 1
                    survDict = survDicts[0]
                else:
                    mid += 1
                    survDict = survDicts[1]
            else:
                survDict = survDicts[0]
            if not speciation:
                if fate > 0:
                    survDict[age][0] += 1
            elif age == 1:
                survDict[age][0] += 1
            # In speciation mode ages above max_tail have no slot. Only age 1 is
            # read back, so skip them instead of raising KeyError.
            if age in survDict:
                survDict[age][1] += 1
        for i, survDict in enumerate(survDicts):
            if speciation:
                survivalMatrices[i][0, t] = survDict[1][0]
                continue
            survProps = []
            tots = []
            for age in survDict:
                surv, tot = survDict[age]
                if proportions:
                    survProps.append(-1 if tot < min_per_bin else surv / max(tot, 1))
                else:
                    survProps.append(surv)
                    tots.append(tot)
            if not proportions:
                tot_lists[i].append(tots)
            survivalMatrices[i][:, t] = np.array(survProps)
    if speciation:
        survivalMatrices = [x.flatten() for x in survivalMatrices]

    if taxDict and verbose:
        if supercladeDurationMatrix is not None:
            print(young, "observations of genera in young superclades and", mid, "in old superclades")
        else:
            print(young, "observations of genera in poor superclades and ", mid, "in rich superclades")
        ratio = assigned / total_checked if total_checked else float("nan")
        print("checked:", total_checked, "assigned to valid superclade:", assigned, "ratio:", ratio)
    if not proportions:
        for x in tot_lists:
            x.append([0] * max_tail)
        if not taxDict:
            return survivalMatrices, tot_lists[0]
        return survivalMatrices, tot_lists
    return survivalMatrices if len(survivalMatrices) != 1 else survivalMatrices[0]


def getSurvivalsByAge(durationMatrix, max_age=10):
    """Pooled one-bin survival fraction for each age 1..max_age.

    Every cell equal to `age` that has a following time bin counts as one
    observation. It survives when the next cell in the same row is non-zero.
    Returns ({age: survivors / observations, or -1 if none}, [[observations per age]]).
    """
    # Pair every cell with its successor one bin later; the last bin has none.
    current = durationMatrix[:, :-1]
    following = durationMatrix[:, 1:]
    survivedNext = following != 0

    fracSurvival = {}
    counts = []
    for age in range(1, max_age + 1):
        atAge = current == age
        observed = int(np.count_nonzero(atAge))
        survived = int(np.count_nonzero(atAge & survivedNext))
        fracSurvival[age] = survived / observed if observed > 0 else -1
        counts.append(observed)
    return fracSurvival, [counts]
    
# Gets the min and max ages of all confirmed occurrences of each subclade
# of interest, given a dictionary generated by getSubcladeOccurences (above).
# 
# Change min_occurence_threshold to only use subclades that have at least that number
# of occurrences within the database, letting you limit to subclades with
# better fossil records. By default it's set to 2 as that is the minimum number of 
# occurrences that can be used to create an age range.
#
# Set keep_short_lived to False to also drop subclades whose oldest and youngest
# occurrences coincide (zero-length ranges).
# 
# Stored as {"subclade_1":(old_1, young_1), "subclade_2":(old_2, young_2), ...}
def getSubcladeAgeRanges(subcladeDict, min_occurence_threshold=2, keep_short_lived=True):                
    """Reduce {subclade: [ages]} to {subclade: (oldest, youngest)}.

    Subclades with fewer than min_occurence_threshold occurrences are dropped.
    If keep_short_lived is False, subclades whose oldest and youngest ages are
    equal are also dropped.
    """
    ranges = {}
    for name, ages in subcladeDict.items():
        if len(ages) < min_occurence_threshold:
            continue
        oldest, youngest = max(ages), min(ages)
        if oldest == youngest and not keep_short_lived:
            continue
        ranges[name] = (oldest, youngest)
    return ranges

# Gets the longevities (first occurrence - last occurrence) of all subclades
# contained within a dictionary of age ranges previously generated 
# by getSubcladeAgeRanges (above).
#
# Returned as a list of (subclade, longevity) tuples sorted from longest- to
# shortest-lived: [("subclade_1", longevity_1), ("subclade_2", longevity_2), ...]
def getSubcladeLongevities(subcladeStartEnds):
    """Return (subclade, old - young) pairs, longest-lived first.

    The pairs are sorted ascending with a stable sort and then reversed, so
    subclades with equal longevity appear in reverse input order.
    """
    pairs = []
    for name in subcladeStartEnds:
        span = subcladeStartEnds[name]
        pairs.append((name, span[0] - span[1]))
    pairs.sort(key=lambda pair: pair[1])
    pairs.reverse()
    return pairs

def getOccurenceSpacings(subcladeDict):
    """Collect gaps between consecutive distinct occurrence ages in each subclade.

    For each subclade, the ages are visited from oldest to youngest. Each time
    the age changes, the difference (previous - current) is recorded. Repeated
    ages add nothing.

    The caller's lists are left untouched. The old `dict(...)` copy was
    shallow, so the in-place sort/reverse reordered the caller's occurrence
    lists.

    Returns
    -------
    list of float
        All gaps, subclade by subclade.
    """
    intervals = []
    for occurrences in dict(subcladeDict).values():
        prev = None
        for o in sorted(occurrences, reverse=True):
            # `is None` (not truthiness) so a 0.0 Ma age is treated as a real value.
            if prev is None:
                prev = o
            elif o != prev:
                intervals.append(prev - o)
                prev = o
    return intervals
        
        

In [1]:
# Plots a given age-range dictionary (generated
# via getSubcladeAgeRanges) as a pseudo-Sepkoski curve
# The plot is basically: "Of the subclades in the provided dictionary,
# how many are in existence at each given time-point?"
#
# n_bins is the number of time-points. It probably doesn't 
# need to be messed with
#
# If you want to mark particular major events (e.g. mass extinctions), 
# pass a list of their timings to the event_lines argument
def plotSepkoski(subcladeStartEnds, n_bins=100, event_lines=None):
    """Plot how many subclades are extant at evenly spaced time points.

    subcladeStartEnds maps subclade -> (old, young) ages in Ma, as returned by
    getSubcladeAgeRanges. n_bins sample points are spread evenly from the
    oldest start to the youngest end. A subclade is counted at x when
    young <= x <= old. event_lines (ages in Ma) are drawn as dotted red
    vertical lines; if 0 is among them the axis is extended to the present.

    Raises ValueError if subcladeStartEnds is empty.
    """
    if not subcladeStartEnds:
        raise ValueError("subcladeStartEnds is empty; nothing to plot")
    spans = [subcladeStartEnds[subclade] for subclade in subcladeStartEnds]
    oldest = max(span[0] for span in spans)
    youngest = min(span[1] for span in spans)
    xs = np.linspace(oldest, youngest, n_bins)
    # Inclusive bounds. With strict inequalities the first and last sample points
    # (which sit exactly on `oldest` and `youngest`) always counted zero.
    ys = np.array([sum(1 for span in spans if span[1] <= x <= span[0]) for x in xs],
                  dtype=float)

    fig, ax = plt.subplots()
    ax.set_xlabel("Time (Ma)")
    ax.set_ylabel("Count")
    ax.plot(xs, ys)
    if event_lines:
        y_max = ax.get_ylim()[1]
        ax.vlines(event_lines, 0, y_max, colors="tab:red", linestyles=':')
        ax.set_ylim((0, y_max))
        if 0 in event_lines:
            ax.set_xlim(left=0)
    # Older ages on the left.
    ax.invert_xaxis()

def plotSepkoskiBins(startEnds, n_bins=100, event_lines=None, t_0=None):
    """Plot binned richness curves for one or several age-range dictionaries.

    startEnds is either a single {subclade: (old, young)} mapping or a mapping
    {condition_label: {subclade: (old, young)}}. Each condition gets its own line
    (and a legend when more than one is given). At each of n_bins sample points
    x, a subclade is counted when its range overlaps (x - bin_size/2,
    x + bin_size/2). Empty conditions are skipped. If t_0 is given it sets the
    oldest displayed age, and the x ticks run from 0 up to (excluding) t_0,
    or 1000 Ma by default.

    Raises ValueError if startEnds is empty.
    """
    if not startEnds:
        raise ValueError("startEnds is empty; nothing to plot")
    fig, ax = plt.subplots()
    # A flat {subclade: (old, young)} mapping is treated as one unnamed condition.
    if not isinstance(next(iter(startEnds.values())), dict):
        startEnds = {"only": startEnds}
    for condition, subcladeStartEnds in startEnds.items():
        if subcladeStartEnds == {}:
            continue
        oldest = max(subcladeStartEnds[subclade][0] for subclade in subcladeStartEnds)
        youngest = min(subcladeStartEnds[subclade][1] for subclade in subcladeStartEnds)

        xs = np.linspace(oldest, youngest, n_bins)
        bin_size = (oldest - youngest) / n_bins
        half_bin = bin_size / 2
        ys = np.zeros(n_bins)
        for i in range(n_bins):
            ys[i] = sum(1 for old, young in subcladeStartEnds.values()
                        if xs[i] - half_bin < old and xs[i] + half_bin > young)

        ax.plot(xs, ys, label=condition)
    if event_lines:
        y_max = ax.get_ylim()[1]
        ax.vlines(event_lines, 0, y_max, colors="tab:red", linestyles=':')
        ax.set_ylim((0, y_max))
        if 0 in event_lines:
            ax.set_xlim(left=0)
    if t_0 is not None:
        ax.set_xlim(right=t_0)
    if "only" not in startEnds:
        ax.legend()

    tick_end = t_0 if t_0 is not None else 1000
    major_ticks = np.arange(0, tick_end, 100)
    minor_ticks = np.arange(0, tick_end, 10)
    ax.set_xticks(major_ticks)
    ax.set_xticks(minor_ticks, minor=True)
    ax.grid(which='both', axis="x")

    ax.grid(which='minor', alpha=0.2)
    ax.grid(which='major', alpha=0.5)

    ax.set_xlabel("Time (Ma)")
    ax.set_ylabel("Count")
    # Older ages on the left.
    ax.invert_xaxis()
    
    
    
# Plots a given age-range dictionary (generated
# via getSubcladeAgeRanges) as a stack of bars  
# representing the age ranges
# 
# Change res, fig_dims, and line_width to change the 
# resolution, dimensions of the figure, and width of the
# bars, respectively
def plotStackedBars(subcladeStartEnds, res=200, fig_dims=(6, 10), line_width=0.5):
    """Draw one horizontal black bar per subclade spanning its (old, young) range.

    Bars are stacked from y=1 upward, with the subclade that originated earliest
    at the bottom. The x-axis runs from the oldest age (left) to the youngest
    (right).
    """
    fig, ax = plt.subplots(figsize=fig_dims, dpi=res)
    spans = dict(subcladeStartEnds)
    oldest = max(span[0] for span in spans.values())
    youngest = min(span[1] for span in spans.values())
    # Ascending by origination age, then reversed, so the oldest comes first.
    byOrigination = sorted(spans.items(), key=lambda item: item[1][0])[::-1]

    # Limits are fixed before plotting so the bars do not rescale the axes.
    ax.set_xlim(youngest, oldest)
    ax.set_ylim(0, len(byOrigination) + 3)
    for level, (_, (old, young)) in enumerate(byOrigination, start=1):
        ax.plot([young, old], [level, level], 'k-', lw=line_width)

    ax.invert_xaxis()
    ax.set_xlabel("Time (Ma)")
    ax.set_ylabel("Count")

# Plots a given OCCURENCE TIME dictionary (generated
# via getSubcladeOccurences) as a stack of bars  
# representing the age ranges WITH colorful dots
# denoting the timing of each individual occurence
# 
# Change res, fig_dims, and line_width to change the 
# resolution, dimensions of the figure, and width of the
# bars, respectively
def plotStackedBarsWithDots(subcladeAgeDict, res=200, fig_dims=(4,5), line_width=0.5, 
                            min_occurence_threshold=5, colorcode=False, event_lines=None,
                            text=False):
    """Draw each subclade's age range as a bar, with a dot at each distinct occurrence age.

    Parameters
    ----------
    subcladeAgeDict : dict
        {subclade: [occurrence_age_Ma, ...]}, e.g. one rank from getSubcladeOccurences.
    res, fig_dims, line_width
        Figure dpi, figure size and bar width.
    min_occurence_threshold : int
        Only subclades with at least this many occurrences are drawn.
    colorcode : bool
        Alternate dot colours red/maroon for subclades whose youngest
        occurrence is older than 25 Ma, and blue/navy otherwise.
    event_lines : list of float or None
        Ages (Ma) drawn as dotted red vertical lines. If 0 is included the
        axis is extended to the present.
    text : bool
        Label each bar with its subclade name.

    Raises
    ------
    ValueError
        If no subclade meets min_occurence_threshold.
    """
    subcladeStartEnds = getSubcladeAgeRanges(subcladeAgeDict, min_occurence_threshold=min_occurence_threshold)
    if not subcladeStartEnds:
        raise ValueError("No subclade has at least " + str(min_occurence_threshold)
                         + " occurrences; nothing to plot")

    fig, ax = plt.subplots(figsize=fig_dims, dpi=res)
    oldest = max(subcladeStartEnds[subclade][0] for subclade in subcladeStartEnds)
    youngest = min(subcladeStartEnds[subclade][1] for subclade in subcladeStartEnds)
    # Oldest-originating subclade first, i.e. at the bottom of the stack.
    subcladeStartEnds = sorted(subcladeStartEnds.items(), key=lambda item: item[1][0])
    subcladeStartEnds.reverse()
    ax.set_xlim(youngest, oldest)
    ax.set_ylim(0, len(subcladeStartEnds) + 3)
    # From here on the x-axis is inverted: `left` is the old end, `right` the young end.
    ax.invert_xaxis()
    count = 1
    if colorcode:
        alt = 0
        prev = "red"
    for subclade, ages in subcladeStartEnds:
        old, young = ages
        ax.plot([young, old], [count, count], 'k-', lw=line_width)
        xs = list(set(subcladeAgeDict[subclade]))
        ys = [count] * len(xs)
        if colorcode:
            # Reset the alternation whenever the colour family switches.
            if young > 25:
                if prev == "blue":
                    prev = "red"
                    alt = 0
                color = "red" if alt == 0 else "maroon"
            else:
                if prev == "red":
                    prev = "blue"
                    alt = 0
                color = "blue" if alt == 0 else "navy"
            alt = not alt
            ax.scatter(xs, ys, marker='.', color=color)
        else:
            ax.scatter(xs, ys, marker='.')
        if text:
            ax.text(max(xs), count + 0.25, subclade)
        count += 1
    if event_lines:
        y_max = ax.get_ylim()[1]
        ax.vlines(event_lines, 0, y_max, colors="tab:red", linestyles=':')
        ax.set_ylim((0, y_max))
        if 0 in event_lines:
            # The axis is already inverted, so the young end is on the right.
            # Setting `left=0` here collapsed the axis to [0, youngest].
            ax.set_xlim(right=0)

    ax.set_xlabel("Time (Ma)")
    ax.set_ylabel("Count")


# Plots a given age-range dictionary (generated
# via getSubcladeAgeRanges) as a histogram of the longevities
# of each of the subclades
# 
# Increase n_bins if you want more bins
def plotAgeHistogram(subcladeStartEnds, n_bins=10):
    """Print the (up to) ten longest-lived subclades and plot a longevity histogram.

    Parameters
    ----------
    subcladeStartEnds : dict
        Maps subclade name -> (oldest age, youngest age), as produced by
        getSubcladeAgeRanges. Longevity is oldest - youngest.
    n_bins : int
        Number of histogram bins passed to ``ax.hist``.
    """
    # NOTE(review): assumes getSubcladeLongevities returns (subclade, lifespan)
    # pairs sorted by decreasing lifespan -- confirm against its definition.
    longevitiesTop10 = getSubcladeLongevities(subcladeStartEnds)[:10]
    print("Top ten most persistent subclades:")
    # Iterate over what is actually there, so clades with fewer than ten
    # subclades no longer raise IndexError.
    for rank, (subclade, lifespan) in enumerate(longevitiesTop10, start=1):
        print(str(rank) + ". " + subclade + ": " + str(round(lifespan, 1)) + " Ma")
    fig, ax = plt.subplots()
    subcladeAges = [subcladeStartEnds[subclade][0] - subcladeStartEnds[subclade][1]
                    for subclade in subcladeStartEnds]
    ax.hist(subcladeAges, bins=n_bins)
    ax.set_xlabel("Clade Longevity")
    ax.set_ylabel("Count")

# Plots a given age-range dictionary (generated
# via getSubcladeAgeRanges) as a smoothed density curve (Gaussian
# kernel density estimate) of the longevities of each of the subclades
# 
# Decrease cov_factor to get more resolution
# Increase cov_factor to get a smoother curve
def plotAgeDensityDist(subcladeStartEnds, cov_factor=0.2):
    """Plot a Gaussian kernel density estimate of subclade longevities.

    Parameters
    ----------
    subcladeStartEnds : dict
        Maps subclade name -> (oldest age, youngest age), as produced by
        getSubcladeAgeRanges. Longevity is oldest - youngest.
    cov_factor : float
        KDE bandwidth factor. Smaller values resolve finer structure;
        larger values give a smoother curve.
    """
    subcladeAges = np.array([subcladeStartEnds[subclade][0] - subcladeStartEnds[subclade][1]
                             for subclade in subcladeStartEnds])
    # A scalar bw_method is used directly as the bandwidth factor. This is the
    # public equivalent of overriding covariance_factor and calling the private
    # _compute_covariance() method.
    density = gaussian_kde(subcladeAges, bw_method=cov_factor)
    # Evaluate slightly beyond the observed range so the tails are visible.
    xs = np.linspace(int(subcladeAges.min() * 0.9), int(subcladeAges.max() * 1.1), 100)

    fig, ax = plt.subplots()
    ax.plot(xs, density(xs))
    ax.set_xlabel("Clade Longevity")
    ax.set_ylabel("Density")

In [1]:
def getSubcladeEcospace(clade, environment_only=None, resolution="genus"):
    """Collect the distinct ecological/lithological attributes recorded per subclade.

    Reads ``<clade>.csv`` (or ``<clade>_<environment_only>.csv`` when
    ``environment_only`` is given) from the current working directory. For
    every row with a value in the ``resolution`` column, the comma-separated
    entries of each column listed in ``details`` are gathered.

    Parameters
    ----------
    clade : str
        Base name of the CSV file.
    environment_only : str, optional
        Suffix selecting an environment-specific CSV.
    resolution : str
        Column whose values name the subclades (default ``"genus"``).

    Returns
    -------
    dict
        ``{subclade: {detail: [value, ...], ...}, ...}``. Every subclade has
        every detail key, with an empty list when nothing was recorded. Values
        are whitespace-stripped, non-empty and de-duplicated, in first-seen order.

    Raises
    ------
    KeyError
        If a row with a subclade name is read and a ``details`` column is
        missing from the CSV.
    """
    details = ["motility","life_habit","vision","diet","reproduction",
               "ontogeny", "lithology1", "lithification1","composition", "reinforcement"]
    if environment_only is None:
        localFilename = clade + ".csv"
    else:
        localFilename = clade + "_" + environment_only + ".csv"
    cladeDataframe = pd.read_csv(localFilename, low_memory=False)
    ecospace_dict = {}
    for i in range(cladeDataframe.shape[0]):
        name = cladeDataframe.at[i, resolution]
        if pd.isna(name):
            continue
        if name not in ecospace_dict:
            ecospace_dict[name] = {x: [] for x in details}
        for detail in details:
            datapoint = cladeDataframe.at[i, detail]
            if not pd.isna(datapoint):
                ecospace_dict[name][detail] += str(datapoint).split(",")
    for attributes in ecospace_dict.values():
        for detail, values in attributes.items():
            # Strip *before* de-duplicating: "a, b" splits into " b", which a
            # set would keep apart from "b" and then strip into a duplicate.
            cleaned = (v.strip() for v in values)
            attributes[detail] = list(dict.fromkeys(v for v in cleaned if v))
    return ecospace_dict

In [1]:
class LogitRegression(LinearRegression):
    """Least-squares linear regression on the logit of a proportion.

    ``fit`` maps proportions p to log-odds log(p / (1 - p)) and fits an
    ordinary linear regression to them. ``predict`` passes the linear
    prediction back through the logistic function 1 / (1 + exp(-y)).
    """

    def fit(self, x, p, max_p=0.99, min_p=0.01):
        """Fit a line to the log-odds of the proportions ``p``.

        Parameters
        ----------
        x : array-like
            Predictors, passed unchanged to ``LinearRegression.fit``.
        p : array-like
            Proportions. They are clipped to [min_p, max_p] so the log-odds
            stay finite.
        max_p, min_p : float
            Clipping bounds.

        Returns
        -------
        self
            Whatever ``LinearRegression.fit`` returns.
        """
        # Clip a float copy rather than assigning into p. This leaves the
        # caller's array untouched, accepts plain lists, and avoids integer
        # arrays silently truncating 0.99 to 0.
        p = np.clip(np.asarray(p, dtype=float), min_p, max_p)
        y = np.log(p / (1 - p))
        return super().fit(x, y)

    def predict(self, x):
        """Return the logistic transform of the linear prediction for ``x``."""
        y = super().predict(x)
        return 1 / (np.exp(-y) + 1)

# old/defunct (returns odds, not log-odds)
def getLogOddsLinear(survivalMatrix, max_p = 0.99, min_p = 0.01, limit_ages=False):
    """Defunct least-squares counterpart of getLogOdds.

    For each column (time bin) of ``survivalMatrix`` (rows are age classes,
    ``-1`` marks a missing rate), a LogitRegression of the non-missing survival
    proportions on the age-class index is fitted. exp(slope) is recorded: an
    odds ratio per unit age, not a log-odds despite the name.

    A dict input becomes a single column whose values, in insertion order, are
    the age classes. If ``limit_ages`` is truthy, only the first ``limit_ages``
    rows are used.

    ``max_p`` and ``min_p`` are accepted but not referenced in this body. Any
    clipping happens inside ``LogitRegression.fit``.

    Returns ``(odds, models)``: per-bin lists, with NaN in both for bins that
    have fewer than two non-missing rates.
    """
    if isinstance(survivalMatrix, dict):
        survivalMatrix = np.array([list(survivalMatrix.values())]).T
    n_ages, n_bins = survivalMatrix.shape
    age_cap = limit_ages if limit_ages else n_ages

    odds_per_bin, fitted = [], []
    for col in range(n_bins):
        ages = np.arange(age_cap)
        rates = survivalMatrix[:age_cap, col]
        keep = np.argwhere(rates != -1)
        ages = ages[keep].reshape(-1, 1)
        rates = rates[keep]
        if len(rates) < 2:
            odds_per_bin.append(np.nan)
            fitted.append(np.nan)
            continue
        model = LogitRegression()
        model.fit(ages, rates)
        fitted.append(model)
        odds_per_bin.append(np.exp(model.coef_)[0][0])
    return odds_per_bin, fitted

#maximum likelihood estimator (returns odds, not log-odds)
def getLogOdds(survivalMatrix, n_taxa=None, num_bootstraps=10, limit_ages=False,
               ignore_youngest=False, count_by_bin=None, verbose=False):
    """Per-time-bin maximum-likelihood logistic regression of survival on age.

    For each column ``t`` of ``survivalMatrix`` (rows are age classes, columns
    are time bins, ``-1`` marks a missing rate), the survival proportions are
    expanded into pseudo-individual 0/1 outcomes. A ``LogisticRegression`` of
    outcome on age is then fitted.

    Parameters
    ----------
    survivalMatrix : ndarray of shape (n_ages, n_bins), or dict
        Survival proportions. A dict becomes a single column whose values, in
        insertion order, are the age classes.
    n_taxa : int, optional
        Used only when ``count_by_bin`` is None or empty. Each non-missing age
        class then gets ``n_taxa // n_bins // n_non_missing`` pseudo-individuals.
        If both are missing, that division raises TypeError.
    num_bootstraps : int
        Bootstrap replicates per bin. Bootstrapping needs ``count_by_bin``.
        Without it, or when this is <= 0, the bin's bootstrap list is empty.
    limit_ages : int or False
        If truthy, only the first ``limit_ages`` age classes are used.
    ignore_youngest : bool
        Skip the first non-missing age class.
    count_by_bin : sequence, optional
        ``count_by_bin[t][age]`` is the number of taxa of that age in bin t.
    verbose : bool
        Clear the cell output and print a progress line for each bin.

    Returns
    -------
    logOdds : list of float
        exp(slope) per bin (an odds ratio per unit age, despite the name), or
        NaN when the bin cannot be fitted.
    models : list
        LogitRegression objects carrying the fitted coefficients, or NaN.
    all_bootstraps : list of list of float
        Bootstrapped exp(slope) values per bin. Always the same length as
        ``logOdds``; a bin's list is empty when bootstrapping was not possible.
    """
    if isinstance(survivalMatrix, dict):
        survivalMatrix = np.array([list(survivalMatrix.values())]).T
    n_ages, n_bins = survivalMatrix.shape
    # `count_by_bin is []` is always False (identity with a brand-new list), so
    # test for emptiness explicitly.
    has_counts = count_by_bin is not None and len(count_by_bin) > 0
    logOdds = []
    models = []
    all_bootstraps = []
    for t in range(n_bins):
        if verbose:
            clear_output(wait=True)
            print("Time bin:", t, "/", n_bins)
        n_used = limit_ages if limit_ages else n_ages
        ages = np.arange(n_used)
        survival_rates = survivalMatrix[:n_used, t]
        not_missing = np.flatnonzero(survival_rates != -1)
        ages = ages[not_missing]
        survival_rates = survival_rates[not_missing]
        if len(survival_rates) < 2:
            logOdds.append(np.nan)
            models.append(np.nan)
            all_bootstraps.append([])
            continue

        # Expand proportions into individual survive (1) / die (0) outcomes.
        true_x, true_y = [], []
        for i, (age, rate) in enumerate(zip(ages, survival_rates)):
            if ignore_youngest and i == 0:
                continue
            age, rate = int(age), float(rate)
            if has_counts:
                num_at_age = count_by_bin[t][age]
            else:
                num_at_age = n_taxa // n_bins // len(ages)
            num_surv = round(float(num_at_age * rate))
            num_die = round(float(num_at_age * (1 - rate)))
            true_y += [1] * num_surv
            true_y += [0] * num_die
            true_x += [age] * (num_surv + num_die)
        # Logistic regression needs both outcome classes.
        if (0 not in true_y) or (1 not in true_y):
            logOdds.append(np.nan)
            models.append(np.nan)
            all_bootstraps.append([])
            continue

        true_x = np.array(true_x).reshape(-1, 1)
        true_y = np.array(true_y)
        true_model = LogisticRegression()
        true_model.fit(true_x, true_y)
        # Wrap the coefficients in a LogitRegression, whose predict() applies
        # the logistic function to the linear predictor.
        true_model_inverse = LogitRegression()
        true_model_inverse.coef_ = true_model.coef_
        true_model_inverse.intercept_ = true_model.intercept_
        models.append(true_model_inverse)
        logOdds.append(np.exp(true_model.coef_)[0][0])

        # Exactly one entry per bin keeps all_bootstraps aligned with logOdds.
        if num_bootstraps > 0 and has_counts:
            all_bootstraps.append(
                _bootstrap_odds(survivalMatrix, count_by_bin[t], t, n_ages, num_bootstraps))
        else:
            all_bootstraps.append([])
    return logOdds, models, all_bootstraps


def _bootstrap_odds(survivalMatrix, counts_at_t, t, n_ages, num_bootstraps):
    """Bootstrap exp(slope) for time bin ``t``. Returns [] if any replicate fails."""
    try:
        counts_at_t = np.asarray(counts_at_t, dtype=float)
        total = counts_at_t.sum()
        if total <= 0:
            return []
        # One age sample is drawn and shared by all replicates; only the
        # survive/die outcomes are re-drawn for each replicate.
        sampled_ages = np.random.choice(n_ages, size=int(round(total)),
                                        p=counts_at_t / total)
        bootstraps = []
        for _ in range(num_bootstraps):
            outcomes = []
            for age in sampled_ages:
                rate = float(survivalMatrix[age, t])
                outcomes.append(np.random.choice([0, 1], p=[1 - rate, rate]))
            model = LogisticRegression()
            model.fit(sampled_ages.reshape(-1, 1), np.array(outcomes))
            bootstraps.append(np.exp(model.coef_)[0][0])
        return bootstraps
    except ValueError:
        # e.g. invalid probabilities (a missing rate of -1, counts whose length
        # differs from n_ages) or a replicate containing a single outcome class.
        return []


def get_extinction_rate(durationMatrix, proportions=True):
    """Per-bin extinction of taxa between consecutive bins.

    A taxon counts as present in bin i when its entry is > 0, and as extinct
    after bin i when its entry in bin i+1 is == 0.

    Parameters
    ----------
    durationMatrix : 2-D ndarray of shape (n_taxa, n_bins)
    proportions : bool
        True gives the fraction of bin-i taxa that disappear. False gives
        the raw count.

    Returns
    -------
    list
        One value per bin. NaN when no taxa are present in the bin, and NaN
        for the final bin, which has no successor.
    """
    _, n_bins = durationMatrix.shape

    def rate_for(col):
        present = durationMatrix[:, col] > 0
        n_present = int(np.count_nonzero(present))
        if n_present == 0:
            return np.nan
        vanished = present & (durationMatrix[:, col + 1] == 0)
        n_vanished = int(np.count_nonzero(vanished))
        return n_vanished / n_present if proportions else n_vanished

    return [rate_for(col) for col in range(n_bins - 1)] + [np.nan]

def get_speciation_rate(durationMatrix):
    """Count, per column (time bin), the entries of ``durationMatrix`` equal to 1.

    This is a count rather than a rate. Presumably a value of 1 marks the first
    bin of a taxon's range; confirm against whatever builds the matrix.
    """
    first_bin_entries = durationMatrix == 1
    return np.count_nonzero(first_bin_entries, axis=0)

def get_richness(durationMatrix):
    """Count, per column (time bin), the taxa with a positive entry."""
    occupied = durationMatrix > 0
    return np.count_nonzero(occupied, axis=0)

NameError: name 'LinearRegression' is not defined

In [None]:
def get_column_ratios(surv_dicts, max_ind):
    """Ratio of entry ``max_ind`` to entry ``max_ind - 1`` for each item.

    Parameters
    ----------
    surv_dicts : iterable of indexable
        Each item is indexed with ``max_ind`` and ``max_ind - 1`` (dicts keyed
        by those integers, lists, arrays, ...).
    max_ind : int
        Index of the numerator.

    Returns
    -------
    list of float
        One ratio per item, NaN where the denominator is 0.
    """
    ratios = []
    for surv in surv_dicts:
        denominator = surv[max_ind - 1]
        # An undefined ratio becomes NaN instead of raising ZeroDivisionError.
        ratios.append(np.nan if denominator == 0 else surv[max_ind] / denominator)
    return ratios

In [3]:
def monoExp(x, m, t):
    """Single-exponential decay ``m * exp(-t * x)``, the model fitted by exponential_fit."""
    decay = np.exp(-t * x)
    return m * decay

def exponential_fit(x, y, m0, t0):
    """Least-squares fit of ``monoExp`` (y = m * exp(-t * x)) to the data.

    Parameters
    ----------
    x, y : array-like
        Data points.
    m0, t0 : float
        Initial guesses for the amplitude ``m`` and the decay rate ``t``.

    Returns
    -------
    (m, t) : tuple
        Fitted parameters. The covariance estimate from ``curve_fit`` is
        discarded.

    Raises
    ------
    Exception
        Whatever ``scipy.optimize.curve_fit`` raises on failure (e.g.
        RuntimeError when the optimisation does not converge, ValueError for
        NaN inputs).
    """
    # Import the submodule explicitly. A bare `import scipy` only exposes
    # `scipy.optimize` on some SciPy versions, or when another import
    # (e.g. scipy.stats) has already loaded it.
    from scipy.optimize import curve_fit

    params, _ = curve_fit(monoExp, x, y, p0=(m0, t0))
    m, t = params
    return m, t

def getExponentialFits(survivalMatrix):
    """Fit an exponential decay to the age-to-age survival-rate ratios of each time bin.

    For each column of ``survivalMatrix`` (rows are age classes), the values
    ``col[a] / col[a-1] - 1`` are computed for consecutive ages a. A pair is
    skipped when ``col[a]`` is missing (-1) or ``col[a-1]`` is not positive.
    The values are plotted (green) and fitted with ``exponential_fit``.
    Columns whose fit fails are re-plotted in red on top.

    A dict input becomes a single column whose values, in insertion order, are
    the age classes.

    Returns
    -------
    list of tuple
        ``(m, t)`` per column. ``(nan, nan)`` when fewer than two ratios exist
        or the fit fails.
    """
    if isinstance(survivalMatrix, dict):
        survivalMatrix = np.array([list(survivalMatrix.values())]).T
    bins_included, n_bins = survivalMatrix.shape
    params = []
    fig, ax = plt.subplots()
    for i in range(n_bins):
        col = survivalMatrix[:, i]
        xs, ratios = [], []
        for max_age in range(1, bins_included):
            cur, prev = col[max_age], col[max_age - 1]
            if (cur > -1) and (prev > 0):
                xs.append(max_age)
                ratios.append((cur / prev) - 1)
        if len(xs) < 2:
            params.append((np.nan, np.nan))
            continue
        xs = np.array(xs)
        try:
            ax.plot(xs, ratios, color="green")
            m, t = exponential_fit(xs, ratios, max(ratios), 1)
            params.append((m, t))
        except Exception:
            # Catch fit failures, but unlike a bare `except:` let
            # KeyboardInterrupt/SystemExit propagate.
            ax.plot(xs, ratios, color="red", zorder=100)
            print("Bin", i, "failed")
            params.append((np.nan, np.nan))
    return params


In [None]:
def pluralize_rank(rank):
    """Return the lowercase plural form of a taxonomic rank name.

    Rank words are replaced as substrings, so compound ranks such as
    "subclass" or "superfamily" are pluralized too. If no known rank word
    occurs, an error message is printed and the lowercased input is returned.
    """
    singular = rank.lower()
    plural = singular
    for s, p in [("species","species"),("genus","genera"),("family","families"),
                 ("order","orders"),("class","classes"),("phylum","phyla"),
                 ("kingdom","kingdoms"),("clade", "clades")]:
        plural = plural.replace(s, p)
    # Compare against the lowercased input, so mixed-case unknown ranks are
    # still reported. Exempt "species", whose plural is the same word.
    if plural == singular and "species" not in singular:
        print("Error: Could not pluralize %s. Returning as is."%(rank))
    return plural