In [None]:
#Import necessary packages
import os
import pandas as pd
import numpy as np

In [None]:
#Changes to the Data directory (a sibling of the folder the notebook starts in)
cwd = os.getcwd()
# Compare the last path component instead of searching the whole path string:
# a substring test wrongly skips the chdir whenever any parent folder merely
# contains "Data" in its name (e.g. ".../DataScience/Code"). The cell is still
# safe to re-run: once inside Data the condition is False and nothing happens.
if os.path.basename(cwd) != 'Data':
    os.chdir('../Data')
    cwd = os.getcwd()
print(cwd)

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt

def plot_cluster_and_labels(df, lab, title, xlab, ylab, save_name=None, fig_size=(8,6)):
    """
    Creates a scatter plot of the data passed in, and colors each
    data point based on its label. The figure is saved if a file
    name is provided.

    Arguments:
        df - dataframe with column headers corresponding to lab,
        xlab and ylab
        lab - string, col header for datapoint labels
        title - string, plot title
        xlab, ylab - string, col headers for datapoint coordinates;
        the same strings are used as the axis labels
        save_name - string, filename to save the image under
        (default None: the figure is not saved)
        fig_size - tuple, plot size passed to plt.subplots as figsize

    Outputs:
        Returns None. Draws one scatter series per unique label with a
        legend, and writes the figure to save_name when one is given.
    """
    #Gets unique labels (in order of first appearance) and maps each to a color
    labels = list(pd.unique(df[lab]))
    lut = dict(zip(labels, sns.hls_palette(len(labels), l=0.5, s=0.8)))

    #Plots the scatter plot, one series per label
    fig, ax = plt.subplots(1, figsize=fig_size)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(title)
    handles = []
    for label in labels:
        idx = df[lab] == label
        # color= applies one fixed color to the whole series. The previous
        # cmap= argument expects a colormap and only takes effect together
        # with numeric c values, so the palette colors were never applied.
        handles.append(ax.scatter(df.loc[idx, xlab], df.loc[idx, ylab],
                                  color=lut[label], alpha=1, s=3))

    # Pair each legend entry explicitly with its series instead of relying
    # on the drawing order; skip the legend when there is nothing to show.
    if handles:
        ax.legend(handles, [str(label) for label in labels])

    #Save plot (through the figure created above, not pyplot's "current" one)
    if save_name is not None:
        fig.savefig(save_name)

In [None]:
# Load the two source tables. Both paths are relative to the current working
# directory, which an earlier cell switches to ../Data.
# NOTE(review): the 'aa10681' prefix differs from the 'aaq0681' prefix on the
# next line — possibly a typo; confirm against the actual file name on disk.
df = pd.read_csv('aa10681_TableS9.csv')
# Later cells read the 'ident', 'lit_cell_types', 'tSNE_1' and 'tSNE_2' columns of this table.
healthy = pd.read_csv('aaq0681_TableS5.csv')

In [None]:
# Lookup from the integer cluster id stored in healthy['ident'] to a cell-type
# name; a later cell uses it to relabel that column.
healthy_dict = {
    1: 'fCT1',
    5: 'fCT2',
    0: 'fCT3',
    2: 'fCT4',
    6: 'fCT5',
    4: 'Periskeletal Cells',
    3: 'Tenocytes',
    7: 'Cycling cells',
}

In [None]:
# Scatter of the healthy table on its tSNE_1 / tSNE_2 columns, one series per
# distinct value of 'ident'.
# NOTE(review): 'title' looks like a placeholder plot title — confirm.
plot_cluster_and_labels(
    healthy,
    lab='ident',
    title='title',
    xlab='tSNE_1',
    ylab='tSNE_2',
    save_name=None,
    fig_size=(8, 6),
)

In [None]:
# Replace the cluster ids in 'ident' with their names from healthy_dict.
# Values that are not keys of healthy_dict pass through unchanged, so re-running
# this cell after the column has already been relabeled no longer raises KeyError.
healthy['ident'] = healthy['ident'].apply(lambda x: healthy_dict.get(x, x))

# Still fail loudly, as the strict dict lookup did, when an id has no entry.
unmapped = set(healthy['ident']) - set(healthy_dict.values())
if unmapped:
    raise KeyError('cluster ids missing from healthy_dict: {}'.format(
        sorted(unmapped, key=str)))

# 'lit_cell_types' is kept on purpose: a later cell plots by that column.
# The original line referenced the misspelled name `heathy` and raised NameError.
healthy.to_csv('aaq0681_TableS5_labeled.csv', index=False)

# Preview goes last so the notebook displays it as the cell's output.
healthy.head()

In [None]:
# Same tSNE scatter of the healthy table, this time with one series per
# distinct value of the 'lit_cell_types' column.
# NOTE(review): 'title' looks like a placeholder plot title — confirm.
plot_cluster_and_labels(
    healthy,
    lab='lit_cell_types',
    title='title',
    xlab='tSNE_1',
    ylab='tSNE_2',
    save_name=None,
    fig_size=(8, 6),
)

In [None]:
# tSNE scatter of the second table (df), one series per distinct 'ident' value.
# NOTE(review): 'title' looks like a placeholder plot title — confirm.
plot_cluster_and_labels(
    df,
    lab='ident',
    title='title',
    xlab='tSNE_1',
    ylab='tSNE_2',
    save_name=None,
    fig_size=(8, 6),
)