In [None]:
#default_exp plotting.manifolds

In [None]:
# export 
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render all figure text through LaTeX. This is a global setting applied at
# import time and needs a working LaTeX installation on the machine.
mpl.rcParams['text.usetex'] = True
# amsmath provides the \text command. The preamble must be a single string:
# the list form was deprecated in Matplotlib 3.3 and is unsupported in later
# releases (a list ends up stringified into invalid LaTeX).
mpl.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'
import numpy as np

In [None]:
# export
import numpy as np

def plot_manifold_2d(data, s, alpha, c, title, title_color = 'black', vmin=0, vmax=np.pi):
    """Scatter-plot the first two columns of ``data`` on a single new figure.

    Parameters
    ----------
    data : numpy.ndarray
        Array indexed as ``data[:, j]``. Column 0 is drawn on the x axis and
        column 1 on the y axis; any further columns are ignored.
    s : float or array-like
        Marker size, forwarded to ``Axes.scatter``.
    alpha : float
        Marker opacity, forwarded to ``Axes.scatter``.
    c : color or array-like
        Marker colour(s), forwarded to ``Axes.scatter``.
    title : str
        Figure title (font size 80).
    title_color : color, optional
        Colour of the title text. Defaults to ``'black'``.
    vmin, vmax : float, optional
        Colour-scale limits used when ``c`` is numeric. Default to ``0`` and
        ``np.pi``.

    Only two ticks are drawn per axis: the data minimum and maximum, each
    rounded to two significant digits.
    """
    _, ax = plt.subplots(1, 1, figsize=(15, 10))

    x = data[:, 0]
    y = data[:, 1]

    # Use the caller's ``s`` and ``alpha`` rather than fixed values, consistent
    # with plot_manifold_3d.
    ax.scatter(x, y, s=s, alpha=alpha, marker='.', vmin=vmin, vmax=vmax, c=c)
    ax.set_xlabel(r'$\phi_1$', fontsize = 60)
    ax.set_ylabel(r'$\phi_2$', fontsize = 60)
    ax.set_title(title, fontsize = 80, color= title_color)

    def _rounded_extremes(values):
        # [min, max] rounded to two significant digits. Built-in ``float`` is
        # used because ``np.float`` was removed in NumPy 1.24.
        return [float(np.format_float_positional(v, precision=2, fractional=False))
                for v in (values.min(), values.max())]

    ax.set_xticks(_rounded_extremes(x))
    ax.set_yticks(_rounded_extremes(y))
    ax.tick_params(labelsize=30)
    ax.tick_params(labelsize=30)

In [None]:
# export
def plot_manifold_3d(data, s, alpha, c, title, title_color = 'black', vmin=0, vmax=np.pi):
    """Scatter-plot the first three columns of ``data`` on a single 3-D axes.

    Parameters
    ----------
    data : numpy.ndarray
        Array indexed as ``data[:, j]``. Columns 0, 1 and 2 are drawn on the
        x, y and z axes; any further columns are ignored.
    s : float or array-like
        Marker size, forwarded to ``Axes3D.scatter``.
    alpha : float
        Marker opacity, forwarded to ``Axes3D.scatter``.
    c : color or array-like
        Marker colour(s), forwarded to ``Axes3D.scatter``.
    title : str
        Figure title (font size 80).
    title_color : color, optional
        Colour of the title text. Defaults to ``'black'``.
    vmin, vmax : float, optional
        Colour-scale limits used when ``c`` is numeric. Default to ``0`` and
        ``np.pi``.

    Only two ticks are drawn per axis (data minimum and maximum, rounded to two
    significant digits). Panes and grid lines are made fully transparent.
    """
    _, ax = plt.subplots(1, 1, figsize=(15, 10), subplot_kw={'projection': '3d'})

    x = data[:, 0]
    y = data[:, 1]
    z = data[:, 2]

    ax.scatter(x, y, z, s=s, c=c, alpha=alpha, marker='.', vmin=vmin, vmax=vmax)
    ax.set_xlabel(r'$\phi_1$', fontsize = 60)
    ax.set_ylabel(r'$\phi_2$', fontsize = 60)
    ax.set_zlabel(r'$\phi_3$', fontsize = 60)
    ax.set_title(title, fontsize = 80, color = title_color)

    def _rounded_extremes(values):
        # [min, max] rounded to two significant digits. Built-in ``float`` is
        # used because ``np.float`` was removed in NumPy 1.24.
        return [float(np.format_float_positional(v, precision=2, fractional=False))
                for v in (values.min(), values.max())]

    ax.set_xticks(_rounded_extremes(x))
    ax.set_yticks(_rounded_extremes(y))
    ax.set_zticks(_rounded_extremes(z))
    ax.tick_params(labelsize=30)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        # Transparent background panes.
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        # Transparent grid lines. NOTE(review): ``_axinfo`` is private
        # Matplotlib API and may change between releases.
        axis._axinfo["grid"]['color'] = (1,1,1,0)
    

In [None]:
# export
import math
import matplotlib.pyplot as plt
import numpy as np

def plot_manifold_featurespace(data,title,ncord = 6, s=.1, alpha=.1):
    """Draw a scatter-plot matrix of the first ``ncord`` columns of ``data``.

    The panel in row ``e``, column ``d`` scatters coordinate ``d`` (x) against
    coordinate ``e`` (y). Diagonal panels show a histogram of coordinate ``d``.
    Every coordinate is shown over a window of the same width (the largest
    range among the plotted coordinates), centred on that coordinate's
    midpoint, so spreads are directly comparable between panels. Integer tick
    marks are placed inside each window.

    Parameters
    ----------
    data : numpy.ndarray
        Array indexed as ``data[:, j]`` with at least ``ncord`` columns.
    title : str
        Figure super-title (font size 70).
    ncord : int, optional
        Number of leading coordinates to plot. Defaults to 6.
    s : float, optional
        Marker size of the off-diagonal scatters. Defaults to 0.1.
    alpha : float, optional
        Marker opacity of the off-diagonal scatters. Defaults to 0.1.
    """
    # squeeze=False keeps ``axes`` 2-D even when ncord == 1.
    fig, axes = plt.subplots(ncord, ncord, figsize=(25, 25), squeeze=False)

    coords = data[:, :ncord]
    lows = coords.min(axis=0)
    highs = coords.max(axis=0)
    centres = (lows + highs) / 2
    # Common half-width: half of the largest per-coordinate range.
    half_width = (highs - lows).max() / 2
    xmins = centres - half_width
    xmaxs = centres + half_width

    # Integer ticks for each coordinate, computed once instead of per panel.
    ticks = [list(range(math.ceil(lo), math.ceil(hi))) for lo, hi in zip(xmins, xmaxs)]

    for d in range(ncord):
        for e in range(ncord):
            ax = axes[e, d]
            if d == e:
                ax.hist(data[:, d])
            else:
                ax.scatter(data[:, d], data[:, e], s=s, alpha=alpha)
                ax.set_ylim(xmins[e], xmaxs[e])
                ax.set_yticks(ticks[e])
            # Column d shares one x-window, so diagonal histograms line up
            # with the scatters above and below them.
            ax.set_xlim(xmins[d], xmaxs[d])
            ax.set_xticks(ticks[d])

    for d in range(ncord):
        axes[ncord- 1,d].set_xlabel(r'$\xi_{{{}}}$'.format(d), fontsize= 30)
        axes[d,0].set_ylabel(r'$\xi_{{{}}}$'.format(d), fontsize= 30)

    fig.suptitle(title, fontsize=70)


