In [None]:
import dabest
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def sample_bootstrap(bootstrap, m, n, reverse_neg, abs_rank, chop_tail):
    """Pick ``m * n`` evenly spaced order statistics from a bootstrap sample.

    The values are sorted ascending on a copy, so the caller's array is left
    untouched. Then ``ceil(len * chop_tail / 100)`` values are trimmed from
    each end, and ``m * n`` values are taken at evenly spaced 1-based ranks
    spanning ``1 .. len(trimmed)``.

    Parameters
    ----------
    bootstrap : array-like of numbers
        Bootstrap resamples (1-D).
    m, n : int
        Only their product is used: the number of values returned.
    reverse_neg : bool
        If true and fewer than half of the trimmed values are positive, the
        trimmed values are walked in descending instead of ascending order.
    abs_rank : bool
        If true, the trimmed values are ordered by absolute value (ascending)
        before sampling. This replaces any ordering set by ``reverse_neg``.
    chop_tail : float
        Percentage of values to discard from *each* tail before sampling.

    Returns
    -------
    list
        ``m * n`` sampled values in rank order.

    Raises
    ------
    ValueError
        If trimming the tails leaves no values to sample.
    """
    # np.sort returns a sorted copy. The previous in-place ``.sort()``
    # reordered the caller's array (e.g. a bootstrap stored on a result
    # object) as a side effect. It also accepts plain lists.
    values = np.sort(bootstrap)
    chop = int(np.ceil(len(values) * chop_tail / 100))
    values = values[chop:len(values) - chop]
    if len(values) == 0:
        raise ValueError(
            f"chop_tail={chop_tail} leaves no bootstrap values to sample"
        )

    # Evenly spaced 1-based ranks. linspace starts at 0, so the first rank
    # is clamped to 1 (the smallest value).
    ranks = np.linspace(0, len(values), m * n, dtype=int)
    if ranks.size:
        ranks[0] = 1

    if reverse_neg and np.sum(values > 0) < len(values) / 2:
        values = values[::-1]
    if abs_rank:
        values = values[np.argsort(np.abs(values))]
    return [values[r - 1] for r in ranks]
    
def spiralize(fill, m, n):
    """Write ``fill`` clockwise into an ``m`` x ``n`` array, spiralling inward.

    Starting at the top-left corner, values go along the top row, down the
    right column, back along the bottom row and up the left column. The
    same walk is then repeated one ring further in. The first ``m * n``
    items of ``fill`` are used, so later items end up towards the centre.

    Parameters
    ----------
    fill : sequence of numbers
        Values to place; must contain at least ``m * n`` items.
    m, n : int
        Number of rows and columns of the output. Rectangular shapes are
        supported.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(m, n)``.

    Raises
    ------
    IndexError
        If ``fill`` has fewer than ``m * n`` items.
    """
    array = np.zeros((m, n))
    top, bottom, left, right = 0, m - 1, 0, n - 1
    k = 0
    while top <= bottom and left <= right:
        for col in range(left, right + 1):
            array[top, col] = fill[k]
            k += 1
        for row in range(top + 1, bottom + 1):
            array[row, right] = fill[k]
            k += 1
        # Skip the return legs when the ring has collapsed to a single row
        # or column; otherwise those cells would be written twice.
        if top < bottom:
            for col in range(right - 1, left - 1, -1):
                array[bottom, col] = fill[k]
                k += 1
        if left < right:
            for row in range(bottom - 1, top, -1):
                array[row, left] = fill[k]
                k += 1
        top += 1
        bottom -= 1
        left += 1
        right -= 1
    return array

def spiral_heatmap(
    contrasts,
    n,
    ylabels,
    xlabels,
    sort_by=None,
    vmax=3,
    vmin=-3,
    reverse_neg=True,
    abs_rank=False,
    chop_tail=0,
    ax=None,
    delta_type=None,
):
    """Draw a grid of ``n`` x ``n`` spiral tiles, one per contrast.

    For every cell ``contrasts[i][j]`` a bootstrap array is read:

    - ``.hedges_g.results.bootstraps[0]`` for rows of type ``'delta'``;
    - ``.delta_g.delta_delta.bootstraps_delta_delta`` for rows of type
      ``'deltadelta'``.

    The objects are presumably dabest results; confirm against the caller.
    ``n * n`` values are sampled from that array with ``sample_bootstrap``
    and laid out with ``spiralize``. All tiles are then drawn with one
    seaborn heatmap.

    Parameters
    ----------
    contrasts : sequence of sequences
        Rows x columns of contrast objects (all rows as long as row 0).
    n : int
        Side length of each tile; ``n * n`` samples per cell.
    ylabels, xlabels : sequence of str
        One label per row / column. ``xlabels`` is reordered by ``sort_by``.
    sort_by : sequence of int, optional
        Column order: display column ``j`` shows ``contrasts[i][sort_by[j]]``.
    vmax, vmin : float
        Colour-scale limits passed to ``sns.heatmap``.
    reverse_neg, abs_rank, chop_tail
        Forwarded to ``sample_bootstrap``.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, a new figure is created and sized.
    delta_type : str or sequence of str, optional
        Per-row bootstrap source, ``'delta'`` or ``'deltadelta'``. A single
        string applies to every row. The default is ``'delta'`` for all rows.

    Returns
    -------
    tuple
        ``(fig, ax, mean_delta)`` when ``ax`` is None, else
        ``(ax, mean_delta)``. ``mean_delta`` is a DataFrame (index=ylabels,
        columns=xlabels) holding the mean of each cell's sampled values.

    Raises
    ------
    ValueError
        If ``delta_type`` has fewer entries than rows, or contains an
        unknown type.
    """
    valid_types = ('delta', 'deltadelta')
    n_rows = len(contrasts)
    n_cols = len(contrasts[0])

    # The old default ``['delta']`` only covered one row and raised
    # IndexError for any multi-row grid. Expand it to one entry per row.
    if delta_type is None:
        delta_type = ['delta'] * n_rows
    elif isinstance(delta_type, str):
        delta_type = [delta_type] * n_rows
    if len(delta_type) < n_rows:
        raise ValueError(
            f"delta_type has {len(delta_type)} entries but contrasts has "
            f"{n_rows} rows"
        )
    # Validate up front. An unknown type used to leave ``bootstrap`` unbound,
    # or silently reuse the previous cell's bootstrap.
    for kind in delta_type[:n_rows]:
        if kind not in valid_types:
            raise ValueError(
                f"unknown delta_type {kind!r}; expected one of {valid_types}"
            )

    if sort_by is not None:
        xlabels = [xlabels[i] for i in sort_by]

    spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n)))
    mean_delta = pd.DataFrame(
        np.zeros((n_rows, n_cols)), columns=xlabels, index=ylabels
    )
    for i in range(n_rows):
        kind = delta_type[i]
        for j in range(n_cols):
            src_col = sort_by[j] if sort_by is not None else j
            contrast = contrasts[i][src_col]
            if kind == 'delta':
                bootstrap = contrast.hedges_g.results.bootstraps[0]
            else:  # 'deltadelta', already validated above
                bootstrap = contrast.delta_g.delta_delta.bootstraps_delta_delta
            samples = sample_bootstrap(
                bootstrap, n, n,
                reverse_neg=reverse_neg,
                abs_rank=abs_rank,
                chop_tail=chop_tail,
            )
            spirals.iloc[i * n:(i + 1) * n, j * n:(j + 1) * n] = spiralize(
                samples, n, n
            )
            mean_delta.iloc[i, j] = np.mean(samples)

    own_figure = ax is None
    if own_figure:
        f, a = plt.subplots(1, 1)
    else:
        a = ax
    sns.heatmap(
        spirals, cmap='vlag', cbar_kws={"shrink": 0.5}, ax=a,
        vmax=vmax, vmin=vmin,
    )
    # Place one tick at the centre of each n-wide tile.
    a.set_xticks(np.linspace(n / 2, n_cols * n - n / 2, n_cols))
    a.set_xticklabels(xlabels, rotation=45, ha='right')
    a.set_yticks(np.linspace(n / 2, n_rows * n - n / 2, n_rows))
    a.set_yticklabels(ylabels, ha='right', rotation=0)
    if own_figure:
        a.set_aspect('equal')
        f.set_size_inches(n_cols / 3, len(ylabels))
        return f, a, mean_delta
    return a, mean_delta