# Plot

> A collection of plotting functions, including 2-D embedding plots (PCA, t-SNE, UMAP) of feature tables.

In [None]:
#| default_exp plot

In [None]:
#| hide
from nbdev.showdoc import *
%matplotlib inline

In [None]:
#| export
from fastbook import *
import seaborn as sns

#for embeddings
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap.umap_ import UMAP

In [None]:
import sys
# Extend the import path so the project's `tools` code can be found.
# NOTE(review): this appends the *contents* of /notebooks/tools to sys.path,
# but `from tools.feature import *` resolves `tools` as a package, which needs
# its parent directory (/notebooks) on the path or as the working directory —
# confirm which is intended.
# This cell has no `#| export`, so the generated module does not get this tweak.
sys.path.append('/notebooks/tools')

In [None]:
#| export
from tools.feature import *

In [None]:
#| export
# Global plotting defaults; as an exported cell this runs whenever the
# generated module is imported.
# Order matters: `sns.set` (alias of `sns.set_theme`) applies seaborn's default
# theme together with the 300-dpi rc overrides for on-screen and saved figures;
# the two calls after it then select the "notebook" context and the "ticks"
# style on top of that.
sns.set(rc={"figure.dpi":300, 'savefig.dpi':300})
sns.set_context('notebook')
sns.set_style("ticks")

In [None]:
#| export
def get_embedding(df, method='pca', n_components=2, seed=123):
    """Project the feature columns of `df` into `n_components` dimensions.

    The first column of `df` is treated as an identifier/label: it is not fed
    to the reducer but is carried through to the result. Every remaining
    column is used as a feature.

    Args:
        df: DataFrame whose first column is kept as-is and whose other
            columns are the features to embed.
        method: 'pca', 'tsne' or 'umap' (exact, lower-case names).
        n_components: number of output dimensions.
        seed: passed as `random_state` to the chosen reducer.

    Returns:
        DataFrame indexed like `df`, holding `df`'s first column followed by
        columns `<METHOD>1 ... <METHOD>n_components` (e.g. 'PCA1', 'PCA2').

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    if method == 'pca':
        reducer = PCA(n_components=n_components, random_state=seed)
    elif method == 'tsne':
        reducer = TSNE(n_components=n_components, random_state=seed)
    elif method == 'umap':
        reducer = UMAP(n_components=n_components, random_state=seed)
    else:
        raise ValueError(
            f"Invalid method specified: {method!r} "
            "(expected 'pca', 'tsne' or 'umap')"
        )

    proj = reducer.fit_transform(df.iloc[:, 1:])
    prefix = method.upper()
    embedding_df = pd.DataFrame(
        proj,
        # Reuse df's index: with a fresh RangeIndex, the axis=1 concat below
        # aligns on index labels and produces NaN-filled, duplicated rows
        # whenever df is not indexed 0..n-1 (e.g. after filtering/dropna).
        index=df.index,
        columns=[f"{prefix}{i}" for i in range(1, n_components + 1)],
    )
    # Positional iloc (rather than df[df.columns[0]]) always yields a single
    # Series, even if the first column's name is duplicated elsewhere.
    return pd.concat([df.iloc[:, 0], embedding_df], axis=1)

In [None]:
#| export
def plot_embedding(df, method='pca', hue=None, palette='tab10', legend=False, seed=123):
    """Scatter-plot a 2-D embedding of `df` computed by `get_embedding`.

    Args:
        df: DataFrame whose first column is an identifier and whose remaining
            columns are the features to embed (see `get_embedding`).
        method: 'pca', 'tsne' or 'umap'.
        hue: forwarded to `sns.relplot`; a column name of the embedding frame
            or a vector of labels (e.g. a Series such as `df.group`).
        palette: palette forwarded to `sns.relplot` for the hue levels.
        legend: forwarded to `sns.relplot`.
        seed: random seed forwarded to `get_embedding` (same default as there,
            so existing calls produce the same embedding).

    Returns:
        The seaborn FacetGrid, so callers can further style or save the figure.
    """
    embedding_df = get_embedding(df, method=method, n_components=2, seed=seed)
    # Build the axis names with get_embedding's naming scheme instead of
    # prefix-matching the columns: a prefix match would also catch df's first
    # (identifier) column if its name started with the method name (e.g.
    # 'PCA_id') and break the two-way unpacking.
    prefix = method.upper()
    x_col, y_col = f"{prefix}1", f"{prefix}2"
    grid = sns.relplot(data=embedding_df, x=x_col, y=y_col, hue=hue,
                       palette=palette, s=50, alpha=0.8, legend=legend)
    # Hide tick marks on the grid's own axes rather than relying on pyplot's
    # "current axes" global state.
    grid.set(xticks=[], yticks=[])
    return grid

In [None]:
# Load the input table; its `group` column supplies the hue in the embedding
# plots. Presumably it also holds the SMILES strings that smi2prop reads —
# unverified from this notebook.
df = pd.read_csv('kras_smiles.csv')

In [None]:
# Display the (rows, columns) shape of the loaded table.
df.shape

In [None]:
# Property tables computed by smi2prop from `df`: raw values first, then the
# normalized version (`prop_std`), which the embedding plots consume.
prop, prop_std = [smi2prop(df, normalize=flag) for flag in (False, True)]

In [None]:
# UMAP view of the normalized properties, coloured by the `group` column.
plot_embedding(prop_std, method="umap", hue=df["group"], legend=True)

In [None]:
# t-SNE view of the normalized properties, coloured by the `group` column.
plot_embedding(prop_std, method="tsne", hue=df["group"], legend=True)

In [None]:
# PCA view of the normalized properties, coloured by the `group` column.
plot_embedding(prop_std, method="pca", hue=df["group"], legend=True)

In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells
# (the target module is set by the `#| default_exp plot` directive).
import nbdev; nbdev.nbdev_export()