## Interactive PCA with existing PCA

This is an interactive view of the PCA plots generated by CODEX. You need two files: one of the **PCA result** files produced by CODEX and the **dataset file** that contains the time-series data.

Make sure all required packages are installed, then run the cells below; the PCA plot should appear. Clicking any point in the PCA opens a new browser tab showing the PCA with the clicked point highlighted, that point's time series, and its class probabilities.

Each clicked point is **automatically saved as `time_series_<DATE>_<TIME>.html`** (for example `time_series_2024-01-31_14-05-09.html`) in the current working directory.

The colors can easily be changed, either with the dictionary below or by choosing a matplotlib colormap. The dot marking the clicked time series is red by default, so avoid red in your palette or change the indicator color (`indicator_color`).


**Note: this only seems to work in the classic Jupyter Notebook interface in the browser, not in VS Code.** (This may be specific to my VS Code setup.)

To open Jupyter Notebook, go to the directory containing the interactive PCA in your terminal:</br>
cd ./interactivePCA 

Then start Jupyter Notebook by entering: </br>
jupyter notebook

Then open the notebook and run it.

In [None]:
# CHANGE THIS PART

import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from ipywidgets import Output, VBox
import http.server
import socketserver
import threading
import webbrowser
from collections import OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import datetime
import matplotlib


### If you already have PCA results and the dataset, set their paths here manually
pca_path = './pca_testvalues.csv'
dataset_path = './dataset.csv'

### You might want to change how many of the original PCA points are displayed in the interactive plot, as too many will crash it.
### If the PCA file has fewer rows than this, all rows are used.
pca_sample = 5000
### Set to an int to get the same random subsample on every run (None = different each run)
pca_random_state = None

# load the data
pca = pd.read_csv(pca_path, sep=',', index_col=False)
time_series = pd.read_csv(dataset_path, sep=',', index_col=False)
# Randomly subsample the PCA points. DataFrame.sample draws without replacement and
# raises ValueError when asked for more rows than exist, so cap the size at len(pca).
pca = pca.sample(min(pca_sample, len(pca)), random_state=pca_random_state).reset_index(drop=True)



### Setting the color scheme (any colormap name from matplotlib is fine)
cmap_name = 'viridis'

class_labels = np.unique(pca['Class'])
Nclasses = len(class_labels)
try:
    # matplotlib >= 3.6; plt.cm.get_cmap was deprecated in 3.7 and removed in 3.9
    cmap = matplotlib.colormaps[cmap_name].resampled(Nclasses)
except AttributeError:
    # older matplotlib without the colormap registry or Colormap.resampled
    cmap = plt.cm.get_cmap(cmap_name, Nclasses)
# One hex color per class, sampled evenly from the colormap
color_dict = {label: matplotlib.colors.rgb2hex(cmap(idx)) for idx, label in enumerate(class_labels)}


#### OR MANUALLY SELECT COLORS FOR EACH CLASS
#color_dict = {'WT': 'green', 'PIK3CA_H1047R': 'blue', 'ErbB2': 'lightgreen', 'Akt1_E17K': 'yellow', 'PTEN_del': 'purple', 'PIK3CA_E545K': 'lightblue'}


## color of the indicator in the plot (choose something that is in good contrast with the other colors chosen above)
indicator_color = 'red'


In [None]:
# FROM HERE ON LEAVE AS IS

# Color of every PCA point, looked up from its class (None if the class is not in color_dict)
pca['colors'] = pca['Class'].apply(color_dict.get)

# Hover text "ID, Class"; astype(str) keeps this working when IDs or classes are numeric
hover_text = list(pca['ID'].astype(str) + ', ' + pca['Class'].astype(str))


# Make the PCA
fig = go.FigureWidget()
fig.add_scatter(x=pca["0"],
    y=pca["1"],
    marker_color=list(pca.colors),
    mode='markers',
    marker=dict(size=10, opacity=0.2),
    text=hover_text
)

fig.update_layout(
    title="PCA",
    xaxis_title="PC1",
    yaxis_title="PC2",
    width=800, height=800
    )

# The scatter trace that receives the click events
fig1 = fig.data[0]

# Get the probability values: ID, Class and every column whose name contains 'Prob_',
# rounded to 2 decimals (DataFrame.round leaves non-numeric columns untouched)
prob = pd.concat([pca['ID'], pca['Class'], pca.filter(regex='Prob_', axis=1)], axis=1)
prob = prob.round(2)


# Identify the measurements: every dataset column except 'ID' and 'class', grouped by
# the prefix before the first '_' (groups keep their order of first appearance)
colnames = list(time_series.columns.values)
colnames.remove('ID')
colnames.remove('class')
groups = list(OrderedDict.fromkeys(name.split('_')[0] for name in colnames))

# One DataFrame per group. Columns are matched on their exact prefix; a
# time_series.filter(regex=group) lookup would match the group name anywhere in a
# column name (and interpret it as a regex), e.g. group 'A' would also pick 'CA_1'.
measurements = [time_series[[name for name in colnames if name.split('_')[0] == group]]
                for group in groups]

# The ID column goes last; display_timeseries relies on measurements[-1] being the IDs
measurements.append(time_series['ID'])

# Create the interactive plot
out = Output()

# Spacing between consecutive time-series samples, used for the x axis of the
# time-series subplot (that axis is labelled 'Time in minutes').
_TIME_STEP_MINUTES = 5

# Shared local HTTP server for the generated HTML files; created on the first click.
_plot_server = None


def _ensure_plot_server():
    """Return the port of the shared plot server, starting it on first use.

    The server serves the current working directory (where display_timeseries
    writes its HTML files), listens on 127.0.0.1 only so that directory is not
    exposed to other machines, and runs serve_forever() in a daemon thread.
    """
    global _plot_server
    if _plot_server is None:
        server = socketserver.TCPServer(("127.0.0.1", 0), http.server.SimpleHTTPRequestHandler)
        threading.Thread(target=server.serve_forever, daemon=True).start()
        _plot_server = server
    return _plot_server.server_address[1]


@out.capture(clear_output=True)
def display_timeseries(trace, points, state):
    """Build the detail view for a clicked PCA point and open it in a browser tab.

    Registered with ``fig1.on_click``. The figure contains the PCA scatter with the
    clicked point highlighted in ``indicator_color``, the point's time series (one
    line per measurement group) and its row of ``prob`` as a table. It is saved as
    ``./time_series_<YYYY-mm-dd_HH-MM-SS>.html`` and opened via a local HTTP server.

    Args:
        trace: The clicked scatter trace (unused).
        points: plotly Points object; ``point_inds[0]`` is the clicked row of pca/prob.
        state: Input-device state of the click (unused).

    Returns:
        The URL opened in the browser, or None if no point was clicked.
    """
    if not points.point_inds:
        return None
    idx = points.point_inds[0]
    selected = prob.iloc[idx]
    point_id = selected['ID']

    ts = make_subplots(rows=2, cols=2,
        subplot_titles=['PCA', 'ID: ' + str(point_id) + ' -> Class: ' + str(selected['Class'])],
        specs=[[{"type": "scatter", 'rowspan': 2},
            {"type": "scatter"}],
            [None,
            {"type": "table"}]],
        row_width=[0.1, 0.5]
        )

    # One line per measurement group; measurements[-1] holds the dataset IDs.
    id_mask = measurements[-1] == point_id
    for group, measurement in zip(groups, measurements[:-1]):
        matches = measurement.loc[id_mask]
        if matches.empty:
            # clicked ID is not present in the dataset; nothing to plot for this group
            continue
        ts.add_trace(go.Scatter(
            x=list(range(0, measurement.shape[1] * _TIME_STEP_MINUTES, _TIME_STEP_MINUTES)),
            y=matches.iloc[0, :],
            name=group
            ), row=1, col=2)

    ts.add_trace(go.Scatter(x=pca["0"],
        y=pca["1"],
        marker_color=list(pca.colors),
        mode='markers',
        marker=dict(size=10, opacity=0.5),
        text=list(pca['ID'].astype(str) + ', ' + pca['Class'].astype(str))
    ), row=1, col=1)

    # Probability table for the clicked point (one single-cell column per prob column)
    ts.add_trace(go.Table(header=dict(values=list(prob.columns)),
        cells=dict(values=[prob[col].iloc[idx] for col in prob.columns])),
        row=2, col=2)

    # Highlight the clicked point on top of the PCA scatter
    ts.add_trace(go.Scatter(x=points.xs,
        y=points.ys,
        marker_color=indicator_color,
        mode='markers',
        marker=dict(size=10, opacity=1)
    ), row=1, col=1)

    # xaxis/yaxis belong to the PCA subplot (1,1), xaxis2/yaxis2 to the time series (1,2)
    ts.update_layout(
        title='Table and Scatter Plot',
        xaxis_title_text='PC1',
        yaxis_title_text='PC2',
        xaxis2_title_text='Time in minutes',
        yaxis2_title_text='Ratio',
        yaxis2_range=[0, 2.5],  # fixed y range for the ratio plot
    )

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    plot_file = 'time_series_' + timestamp + '.html'
    ts.write_html('./' + plot_file)

    # The server is bound in this thread before the browser is pointed at it, so the
    # port is known immediately (no sleep or cross-thread variable needed).
    url = f"http://127.0.0.1:{_ensure_plot_server()}/{plot_file}"
    webbrowser.open(url)
    return url


# Attach the click handler to the PCA scatter trace, then render the PCA widget
# with the output area (where display_timeseries messages/errors appear) beneath it.
fig1.on_click(display_timeseries)
VBox(children=[fig, out])
