In [18]:
""" 
Run nSimplices + MDS vs MDS on subset (STOOL and VAGINA only) of HMP dataset 
"""

' \nRun nSimplices + MDS vs MDS on subset (STOOL and VAGINA only) of HMP dataset \n'

In [19]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import pandas as pd
import random as alea
from scipy.spatial.distance import pdist, squareform
from sklearn.manifold import MDS
import plotly.express as px


data_dir = "../data/"
output_dir = "../outputs"

# target_sites = ['NOSE']
# target_id = "N"
# target_colors = ["blue"]

# target_sites = ['NOSE', 'VAGINA', 'THROAT']
# target_id = "NVT"
# target_colors = ["darkgreen", "orange", "deeppink"]

target_sites = ['NOSE', 'THROAT']
target_id = "NT"
target_colors = ["blue", "orange"]

# if color == "deeppink":
#         return "THROAT"
#     if color == "black":
#         return "EARS"
#     if color == "cornflowerblue":
#         return "STOOL"
#     if color == "darkgreen":
#         return "NOSE"
#     if color == "red":
#         return "ELBOWS"
#     if color == "gray":
#         return "MOUTH"
#     if color == "orange":
#         return "VAGINA"

# target_sites = ['THROAT']
# target_id = "T"
# target_colors = ["deeppink"]


In [20]:
"""
Prepare dataset
"""

# iterate over QE and NB
files = ["hmp_v13lqphylotypeQuantNB_rs_c.csv", "hmp_v13lqphylotypeQuantE_rs_c.csv"]
color_df = pd.read_csv(os.path.join(data_dir, "hmp_v13lqphylotypePheno_rs_c.csv"), header=0)

for file in files:
    data_path = os.path.join(data_dir, file)
    df_hmp_ori = np.loadtxt(data_path, delimiter=",")
    df_hmp = []


    for index, row in color_df.iterrows():
        site_exist = False
        for site in target_sites:
            if row[site]:
                site_exist = True
        if site_exist:
            df_hmp.append(df_hmp_ori[index])
    df_hmp = np.array(df_hmp)
    np.savetxt(os.path.join(data_dir, file[:-8]+"target_"+target_id+"_"+file[-8:]), df_hmp, fmt='%1.17f', delimiter=',')


In [21]:
""" 
Prepare colors
"""

colors = []
new_color_df = color_df.copy(deep = True)
drop_indices = []

for index, row in color_df.iterrows():
    site_exist = False
    for i in range(len(target_sites)):
        site = target_sites[i]
        if row[site]:
            colors.append(target_colors[i])
            site_exist = True
    if not site_exist:
        drop_indices.append(index)

colors = np.array(colors)
print(colors.shape)

np.savetxt(os.path.join(data_dir, "hmp_target_"+target_id+"_"+"colors.txt"), colors, fmt="%s")

new_color_df = new_color_df.drop(drop_indices)
new_color_df.to_csv(os.path.join(data_dir, "hmp_v13lqphylotypePheno_target_"+target_id+"_"+"rs_c.csv"), header=True, index=False)

(270,)


In [22]:
""" 
Run nSimplices on HMP dataset
"""
colors = np.loadtxt(os.path.join(data_dir, "hmp_target_"+target_id+"_"+"colors.txt"), dtype="str")
exec(open("../nsimplices.py").read())
alea.seed(42)


In [23]:
""" 
Run
(1) NB normalization + nSimplices + cMDS 
(2) QE normalization + nSimplices + cMDS 

To derive the axes data
""" 

# output_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS_axes.txt", "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt"]
# data_files = ["hmp_v13lqphylotypeQuantNB_target_"+target_id+"_"+"rs_c.csv", "hmp_v13lqphylotypeQuantE_target_"+target_id+"_"+"rs_c.csv"]
output_files = ["hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt"]
data_files = ["hmp_v13lqphylotypeQuantE_target_"+target_id+"_"+"rs_c.csv"]
subspace_dims = []
outlier_indices_list = []
for i in range(len(output_files)):
    output_file = output_files[i]
    data_file = data_files[i]
    axes_output_path = os.path.join(output_dir, output_file)
    print("======== NB/QE normalization + nSimplices + cMDS ========")
    data_path = os.path.join(data_dir, data_file)
    df_hmp = np.loadtxt(data_path, delimiter=",")
    hmp_dis_sq=squareform(pdist(df_hmp))

    feature_num = df_hmp.shape[1]
    dim_start = 1
    dim_end = 50
    # dim_start = 2
    # dim_end = 2

    print("hmp_dis_sq shape is:", hmp_dis_sq.shape)
    outlier_indices, subspace_dim , corr_pairwise_dis, corr_coord = nsimplices(hmp_dis_sq, feature_num, dim_start, dim_end, std_multi=3)
    print("subspace dimension is:", subspace_dim)
    subspace_dims.append(subspace_dim)
    outlier_indices_list.append(outlier_indices)

    # run cMDS to get the corrected coordinates in importance decreasing order
    _, _, Xe = cMDS(corr_pairwise_dis)
    np.savetxt(axes_output_path, Xe, fmt='%f')


hmp_dis_sq shape is: (270, 270)
dim in find_subspace_dim is: 1
dim in find_subspace_dim is: 2
dim in find_subspace_dim is: 3
dim in find_subspace_dim is: 4
dim in find_subspace_dim is: 5
dim in find_subspace_dim is: 6
dim in find_subspace_dim is: 7
dim in find_subspace_dim is: 8
dim in find_subspace_dim is: 9
dim in find_subspace_dim is: 10
dim in find_subspace_dim is: 11
dim in find_subspace_dim is: 12
dim in find_subspace_dim is: 13
dim in find_subspace_dim is: 14
dim in find_subspace_dim is: 15
dim in find_subspace_dim is: 16
dim in find_subspace_dim is: 17
dim in find_subspace_dim is: 18
dim in find_subspace_dim is: 19
dim in find_subspace_dim is: 20
dim in find_subspace_dim is: 21
dim in find_subspace_dim is: 22
dim in find_subspace_dim is: 23
dim in find_subspace_dim is: 24
dim in find_subspace_dim is: 25
dim in find_subspace_dim is: 26
dim in find_subspace_dim is: 27
dim in find_subspace_dim is: 28
dim in find_subspace_dim is: 29
dim in find_subspace_dim is: 30
dim in find_subsp





outlier_indices is: [ 23  42  62  89 108 112 135 143 155 191 204 213 234]
original coord is: [ 7.53157648e-02  2.24461952e-01 -1.68719183e-01  1.42681029e-01
 -2.16866478e-01  2.26650677e-01  1.44288305e-01 -9.33428255e-02
  2.94255223e-02 -1.81767298e-01 -1.27878266e-01 -1.23723295e-02
  8.82061539e-02 -9.79065908e-02 -8.60644240e-02 -5.13398178e-02
 -1.41853218e-01 -1.98794415e-01 -1.19644524e-01  1.48520677e-01
  1.33183080e-01  6.47050203e-02  1.18204351e-02 -5.30286719e-02
  7.21042827e-02  1.28624352e-01  1.59409231e-01  5.56290471e-02
  3.63183709e-02 -1.11728125e-01 -7.55694521e-02 -5.02529248e-02
 -2.19753368e-01 -1.60360567e-01  1.55828949e-02 -2.25614839e-01
 -1.21337311e-01  3.36662805e-02  7.86720950e-02  1.24081945e-01
  8.44655013e-02  1.36760406e-01  8.52986013e-02 -1.34482055e-01
 -9.73650637e-02  1.46603820e-01  1.32298917e-01 -9.89768656e-02
  7.14395853e-02 -2.11675770e-02 -1.40321128e-01 -2.04670385e-01
 -1.50426999e-01 -1.86687629e-01  1.26268522e-01 -3.26910612e-

In [24]:
""" 
Analyze the number of outliers
"""
print("subspace_dim is:", subspace_dim)
print("outlier number is:", len(outlier_indices_list[0]))
print("outlier percent is:", len(outlier_indices_list[0])/len(colors))

subspace_dim is: 2
outlier number is: 13
outlier percent is: 0.04814814814814815


In [25]:
# """ 
# Plot distribution of dimensions over the range of candidate dimensions 
# only for NB normalization
# """
# data_file  = "hmp_v13lqphylotypeQuantNB_target_"+target_id+"_"+"rs_c.csv"
# data_path = os.path.join(data_dir, data_file)
# df_hmp = np.loadtxt(data_path, delimiter=",")
# hmp_dis_sq=squareform(pdist(df_hmp))

# num_point = hmp_dis_sq.shape[0]
# hcolls = []
# start_dim = 1
# end_dim = 50
# for dim in range(start_dim, end_dim+1):
#     print("current dimension is:", dim)
#     heights = nsimplices_all_heights(num_point, hmp_dis_sq, dim, seed=dim+1)
#     hcolls.append(heights)

# # calculate median heights for tested dimension from start_dim to end_dim
# h_meds = []
# for hcoll in hcolls:
#     h_meds.append(np.median(hcoll))

# # calculate the ratio, where h_med_ratios[i] corresponds to h_meds[i-1]/h_meds[i]
# # which is the (median height of dim (i-1+start_dim))/(median height of dim (i+start_dim))
# h_med_ratios = []
# for i in range(1, len(hcolls)):
#     h_med_ratios.append(h_meds[i-1]/h_meds[i])

# # plot the height scatterplot and the ratios
# plt.figure(0)
# fig, ax1 = plt.subplots()
# color = 'red'
# ax1.set_xlabel(r'dimension tested $n$', fontsize=15)
# ax1.set_ylabel(r'median of heights', color = color, fontsize=15)
# ax1.scatter(list(range(start_dim, end_dim+1)), h_meds, color = color, s=10)
# ax1.plot(list(range(start_dim, end_dim+1)), h_meds, color=color)
# ax1.tick_params(axis ='y', labelcolor = color)
 
# # Adding Twin Axes to plot using dataset_2
# ax2 = ax1.twinx()
 
# color = 'black'
# ax2.set_ylabel(r'heights median ratio: $h_{n-1}/h_n$', color = color, fontsize=15)
# ax2.plot(list(range(start_dim+1, end_dim+1)), h_med_ratios, color = color)
# ax2.tick_params(axis ='y', labelcolor = color)
# plt.tight_layout(pad=2)
 
# # Show plot
# plt.savefig(os.path.join(output_dir, "hmp_"+target_id+"_dim.png"))

In [26]:
# print(h_med_ratios)

In [27]:
""" 
Inferred dimension
QE: STOOL - 3, VAGINA - 3, THROAT - 3, EARS - 3, NOSE - 3, ELBOWS - 3, MOUTH - 3
NB: STOOL - 41, VAGINA - 5, THROAT - , EARS - 3, NOSE - 3, ELBOWS - 3, MOUTH - 3

"""

' \nInferred dimension\nQE: STOOL - 3, VAGINA - 3, THROAT - 3, EARS - 3, NOSE - 3, ELBOWS - 3, MOUTH - 3\nNB: STOOL - 41, VAGINA - 5, THROAT - , EARS - 3, NOSE - 3, ELBOWS - 3, MOUTH - 3\n\n'

In [28]:
print("number of samples is:", len(colors))
print("number of outliers is:", len(outlier_indices_list[0]))

number of samples is: 270
number of outliers is: 13


site - number of samples - number of outliers - total distance (nSimplices) - total distance (MDS)

NOSE: 136 - 18 - 0.9809061157393936 - 1.2269812549552055

NOSE, THROAT: 270 - 30 - 1.362760825075196 - 1.3811090861305209

THROAT: 134 - 14

VAGINA: 186 - 26

VAGINA, STOOL, THROAT: 454 - 25


In [29]:
""" 
Run
(1) NB normalization + cMDS 
(2) QE normalization + cMDS 

To derive the axes data
"""  

axes_files = ["hmp_target_"+target_id+"_"+"NB_MDS_cMDS_axes.txt", "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"] # put NB before QE
data_files = ["hmp_v13lqphylotypeQuantNB_target_"+target_id+"_"+"rs_c.csv", "hmp_v13lqphylotypeQuantE_target_"+target_id+"_"+"rs_c.csv"]

for i in range(len(axes_files)):
    axes_file = axes_files[i]
    data_file = data_files[i]

    print("======== QE/NB normalization + MDS + cMDS ========")
    data_path = os.path.join(data_dir, data_file)
    axes_output_path = os.path.join(output_dir, axes_file)

    df_hmp = np.loadtxt(data_path, delimiter=",")
    hmp_dis_sq=squareform(pdist(df_hmp))

    # Plot cMDS embedding using the pairs of axis from the four most significant axes 
    # enforce_dim = subspace_dims[i] # enforcing the dimension to be consistent with nSimplices QE+nsimplices+cMDS
    feature_num = df_hmp.shape[1]
    print("feature_num is:", feature_num)
    embedding = MDS(n_components=feature_num, max_iter=100000000000, dissimilarity='precomputed') 
    corr_coord = embedding.fit_transform(hmp_dis_sq)
    corr_dis_sq=squareform(pdist(corr_coord))
    _, _, Xe = cMDS(corr_dis_sq)

    np.savetxt(axes_output_path, Xe, fmt='%f')

feature_num is: 425






feature_num is: 425






In [30]:

""" 
Compute the average distance between the outlieres and the barycenter of regular points
"""

def centeroid(coords_list):
    """ 
    Computes the barycenter of a list of coordinates in coords_lst
    """

    length = coords_list.shape[0]
    centroid = []
    for i in range(coords_list.shape[1]):
        centroid.append(np.sum(coords_list[:, i])/length)
    return np.array(centroid)
    
axes_files = ["hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt", \
    "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"]

for i in range(len(axes_files)):
    axes_file = axes_files[i]
    subspace_dim = subspace_dims[0] # TODO: this should be change to i%2
    outlier_indices =  outlier_indices_list[0] # TODO: this should be change to i%2
    Xe = np.loadtxt(os.path.join(output_dir, axes_file))
    normal_indices = list(set(range(Xe.shape[0])) - set(outlier_indices))
    normal_centroid = centeroid(Xe[normal_indices, :]) # should only focus on the first inferred dimension coordinates
    normal_centroid = normal_centroid[:subspace_dim]
    outlier_Xe = Xe[outlier_indices, :subspace_dim]
    
    # compute distance
    total_distance = 0
    for outlier_coord in outlier_Xe:
        total_distance += np.linalg.norm(outlier_coord-normal_centroid)
    print("average distance is:", total_distance/len(outlier_indices))

    



average distance is: 1.1023714441729688
average distance is: 1.1190781209051068


In [31]:
""" 
Plot pairwise result
"""

# figure_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS.png", \
#     "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS.png", \
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS.png", \
#     "hmp_target_"+target_id+"_"+"QE_MDS_cMDS.png"]
# axes_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"]
titles = ["QuantE+nSimplices", "QuantE+MDS"]
figure_files = [
    "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS.png", \
    "hmp_target_"+target_id+"_"+"QE_MDS_cMDS.png"]
axes_files = [
    "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt", \
    "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"]
titles = ["QuantE+nSimplices",  "QuantE+MDS"]
num_axes = 3 # show pairwise 2D plot to decompose the 3D plot

for i in range(len(figure_files)):
    # if i == 0 or i == 2:
    #     continue
    figure_file = figure_files[i]
    axes_file = axes_files[i]
    title = titles[i]
    outlier_indices =  outlier_indices_list[0] # TODO: change to i%2

    print("======== plot pairwise 2D plot (subset) ========")
    Xe = np.loadtxt(os.path.join(output_dir, axes_file))
    print(Xe.shape)
    for first_dim in range(num_axes):
        for second_dim in range(first_dim+1, num_axes):
            plt.figure()
            
            target_indices = []
            for color in target_colors:
                cur_indices = [i for i, e in enumerate(colors) if e == color]
                target_indices.append(cur_indices)
            # stool_indices = [i for i, e in enumerate(colors) if e == 'cornflowerblue']
            # ears_indices = [i for i, e in enumerate(colors) if e == 'orange']

            for i in range(len(target_colors)):
                color = target_colors[i]
                site = target_sites[i]
                site_indices = target_indices[i]
                
                # if 'nSimplices' not in figure_file:
                #     plt.scatter(Xe[site_indices, second_dim], \
                #     Xe[site_indices, first_dim], s=5, c=color, label = site)
                # else:
                site_outlier_indices = list(set(site_indices) & set(outlier_indices))
                site_normal_indices = list(set(site_indices) - set(outlier_indices))
            
                plt.scatter(Xe[site_normal_indices, second_dim], \
                    Xe[site_normal_indices, first_dim], s=5, c=color, label = site)
                plt.scatter(Xe[site_outlier_indices, second_dim], \
                    Xe[site_outlier_indices, first_dim], s=15, c=color, label = site+" outlier", marker="x")
                    
            plt.legend()
            plt.title(title, size=10)   
            plt.savefig(os.path.join(output_dir, figure_file[:-4]+"_"+str(first_dim)+"_"+str(second_dim)+".png"))
            

(270, 264)
(270, 269)


In [32]:
def color_to_site(color):
    """
    Returns the site for the color
    """
    if color == "orange":
        return "THROAT"
    if color == "black":
        return "EARS"
    if color == "cornflowerblue":
        return "STOOL"
    if color == "blue":
        return "NOSE"
    if color == "red":
        return "ELBOWS"
    if color == "gray":
        return "MOUTH"
    if color == "yellow":
        return "VAGINA"

def site_to_color(site):
    """
    Returns the color for the site
    """
    color_discrete_map = {'STOOL': 'cornflowerblue', 'VAGINA': 'yellow', 'THROAT': 'orange',\
        "EARS": 'black', "NOSE": "blue", "ELBOWS": 'red', "MOUTH": 'grey' }

    return color_discrete_map[site]

In [33]:
""" 
generate 3D plot of the first three axes 
"""

# axes_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"]
    

# figure_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS.html", \
#     "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS.html",\
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS.html", \
#     "hmp_target_"+target_id+"_"+"QE_MDS_cMDS.html"]


axes_files = [
    "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS_axes.txt", \
    "hmp_target_"+target_id+"_"+"QE_MDS_cMDS_axes.txt"]
    

figure_files = [
    "hmp_target_"+target_id+"_"+"QE_nSimplices_cMDS.html",\
    "hmp_target_"+target_id+"_"+"QE_MDS_cMDS.html"]


for i in range(len(axes_files)):
    axes_file = axes_files[i]
    figure_file = figure_files[i]
    outlier_indices = outlier_indices_list[0] #TODO: change this to %2
    normal_indices = list(set(list(range(Xe.shape[0]))) - set(outlier_indices))
    Xe = np.loadtxt(os.path.join(output_dir, axes_file))
    color_discrete_map = {'STOOL': 'cornflowerblue', 'VAGINA': 'yellow', 'THROAT': 'orange',\
        "EARS": 'black', "NOSE": "blue", "ELBOWS": 'red', "MOUTH": 'grey' }


    Xe_df = pd.DataFrame(Xe[:,:3], columns = ["axis_0", "axis_1", "axis_2"])
    Xe_df['color'] = colors
    Xe_df['label'] = \
        Xe_df.apply(lambda row: color_to_site(row['color']), axis=1)
    outlier_indicator = np.array([0]*len(colors))
    outlier_indicator[outlier_indices] = 1
    Xe_df['outlier'] = outlier_indicator.tolist()

        # specify trace names and symbols in a dict
    symbols = {'1': 'cross',
            '0':'circle'}

    

    fig = px.scatter_3d(Xe_df, x='axis_0', y='axis_1', z='axis_2',
            color='label', color_discrete_map=color_discrete_map, symbol='outlier')

    for i, d in enumerate(fig.data):
        # fig.data[i].marker.symbol = symbols[fig.data[i].name] 
        print(symbols)
        fig.data[i].marker.symbol = symbols[fig.data[i].name.split(',')[1].strip()] 
        
    fig.for_each_xaxis(lambda axis: axis.title.update(font=dict(color = 'blue', size=30)))
    fig.update_layout(scene = dict(
                    xaxis_title='axis 0',
                    yaxis_title='axis 1',
                    zaxis_title='axis 2'))
    fig.update_scenes(xaxis_title_font=dict(size=50))
    fig.update_scenes(yaxis_title_font=dict(size=50))
    fig.update_scenes(zaxis_title_font=dict(size=50))
    fig.update_scenes(xaxis = dict(tickfont=dict(size=15)))
    fig.update_scenes(yaxis = dict(tickfont=dict(size=15)))
    fig.update_scenes(zaxis = dict(tickfont=dict(size=15)))
    fig.write_html(os.path.join(output_dir, figure_file))




{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}
{'1': 'cross', '0': 'circle'}


In [34]:
# """
# Combine three 3D dynamic plots together
# https://stackoverflow.com/questions/59868987/plotly-saving-multiple-plots-into-a-single-html/59869358#59869358
# """
# import plotly.graph_objects as go
# import plotly.express as px
# import plotly.offline as offline
# import pandas as pd

# from plotly.subplots import make_subplots


# """ 
# generate 3D plot of the first three axes 
# """

# target_id = 'M'

# axes_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS_axes.txt", \
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS_axes.txt"]
    

# figure_files = ["hmp_target_"+target_id+"_"+"NB_nSimplices_cMDS.html", \
#     "hmp_target_"+target_id+"_"+"NB_MDS_cMDS.html"]

# fig = make_subplots(\
#     rows=2, cols=1, shared_xaxes=True, \
#     vertical_spacing=0.02)



# for i in range(len(axes_files)):
#     axes_file = axes_files[i]
#     figure_file = figure_files[i]
#     outlier_indices = outlier_indices_list[i%2]
#     normal_indices = list(set(list(range(Xe.shape[0]))) - set(outlier_indices))
#     Xe = np.loadtxt(os.path.join(output_dir, axes_file))
#     color_discrete_map = {'STOOL': 'cornflowerblue', 'VAGINA': 'orange', 'THROAT': 'deeppink',\
#         "EARS": 'black', "NOSE": "darkgreen", "ELBOWS": 'red', "MOUTH": 'grey' }


#     Xe_df = pd.DataFrame(Xe[:,:3], columns = ["axis_0", "axis_1", "axis_2"])
#     Xe_df['color'] = colors
#     Xe_df['label'] = \
#         Xe_df.apply(lambda row: color_to_site(row['color']), axis=1)
#     outlier_indicator = np.array(['nor']*len(colors))
#     outlier_indicator[outlier_indices] = "out"
#     Xe_df['outlier'] = outlier_indicator.tolist()

#         # specify trace names and symbols in a dict
#     symbols = {'out': 'cross',
#             'nor':'circle'}

    

#     # fig = px.scatter_3d(Xe_df, x='axis_0', y='axis_1', z='axis_2',
#     #         color='label', color_discrete_map=color_discrete_map, symbol='outlier')


#     fig.add_trace(go.Scatter3d(x = Xe_df["axis_0"], y = Xe_df["axis_1"], z = Xe_df["axis_2"], \
#         marker=dict(
#             color = colors
#         )), \
#         row=1, col=1)

#     # for i, d in enumerate(fig.data):
#     #     # fig.data[i].marker.symbol = symbols[fig.data[i].name] 
#     #     print(symbols)
#     #     fig.data[i].marker.symbol = symbols[fig.data[i].name.split(',')[1].strip()] 

#     # fig.update_layout(scene = dict(
#     #                 xaxis_title='axis 0',
#     #                 yaxis_title='axis 1',
#     #                 zaxis_title='axis 2'))
#     fig.write_html(os.path.join(output_dir, figure_file))

# offline.plot(fig, filename='name.html')



