In [None]:
import os,re
import numpy as np
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import plotly_express as px
import plotly.graph_objs as go
from plotly.offline import iplot

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Per-class colour (hex) and matplotlib marker, keyed by the label values used
# as `hue`/`style` in every t-SNE scatter plot below.
palette = dict(
    noncas='#7E7E7F',
    cas12='#BAD3E9',
    cas9='#F3C8AB',
    cas13='#C1DBB1',
)
markers = dict(
    noncas='o',
    cas12='X',
    cas9='P',
    cas13='s',
)

## PCA & t-SNE

In [None]:
def pca_n_components(df):
    """Plot the cumulative explained-variance ratio of a full PCA fit.

    Used to choose how many principal components to keep before t-SNE
    (the ``n_pca`` argument of ``tsne_data_process``).

    Parameters
    ----------
    df : array-like of shape (n_samples, n_features)
        Data to decompose; passed unchanged to ``PCA().fit``.

    Returns
    -------
    None
        The figure is rendered inline via plotly's ``iplot``.
    """
    pca = PCA().fit(df)
    cumulative = np.cumsum(pca.explained_variance_ratio_)
    # x is the number of retained components (1..k), so the y value at x=k is
    # the variance explained by the first k components (not k+1).
    n_components = np.arange(1, len(cumulative) + 1)
    trace = go.Scatter(x=n_components, y=cumulative)
    layout = dict(
        xaxis=dict(title=dict(text='Number of principal components')),
        yaxis=dict(title=dict(text='Cumulative explained variance ratio')),
    )
    iplot(dict(data=[trace], layout=layout))

def tsne_data_process(df, n_pca, labels, random_state=None):
    """Reduce ``df`` with PCA, embed the result in 2-D with t-SNE, attach labels.

    Parameters
    ----------
    df : array-like of shape (n_samples, n_features)
        Feature matrix (in this notebook, the embedding columns of a .tab file).
    n_pca : int or None
        Number of principal components kept before t-SNE; forwarded as
        ``PCA(n_components=n_pca)``.
    labels : array-like of length n_samples
        Label of each row, stored in the ``cluster`` column. Matched to rows
        by position, not by pandas index.
    random_state : int, numpy RandomState or None, default None
        Seed forwarded to both ``PCA`` and ``TSNE``. ``None`` keeps the
        unseeded behaviour; pass an int for reproducible layouts.

    Returns
    -------
    pandas.DataFrame
        Columns ``x1``, ``x2`` (t-SNE coordinates) and ``cluster`` (labels).

    Raises
    ------
    ValueError
        If ``labels`` and ``df`` have a different number of rows.
    """
    # Positional labels: assigning a pandas Series to a new column would align
    # on its index and silently produce NaN when that index is not 0..n-1.
    labels = np.asarray(labels)
    if len(labels) != len(df):
        raise ValueError(
            f'labels has {len(labels)} entries but df has {len(df)} rows'
        )
    reduced = PCA(n_components=n_pca, random_state=random_state).fit_transform(df)
    embedded = TSNE(n_components=2, random_state=random_state).fit_transform(reduced)
    return pd.DataFrame({
        'x1': embedded[:, 0],
        'x2': embedded[:, 1],
        'cluster': labels,
    })
    
def tsne_clst_fig(bef_df, bef_n_pca, bef_labs, aft_df, aft_n_pca, aft_labs, out_tag, lim=130):
    """Draw stacked t-SNE scatters: before fine-tuning (top), after (bottom).

    Each panel is produced by ``tsne_data_process`` and drawn with the
    module-level ``palette``/``markers`` keyed on the label values.

    Parameters
    ----------
    bef_df, aft_df : array-like of shape (n_samples, n_features)
        Features before / after fine-tuning.
    bef_n_pca, aft_n_pca : int or None
        PCA components kept before t-SNE for each panel.
    bef_labs, aft_labs : array-like
        Per-row labels for each panel.
    out_tag : str
        Output path prefix; writes ``<out_tag>.tsne.pdf`` and
        ``<out_tag>.tsne.png`` (720 dpi).
    lim : float, default 130
        Both axes of both panels are fixed to ``(-lim, lim)``.

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Calls ``sns.set``, which changes global matplotlib rcParams for the rest
    of the session.
    """
    panels = (
        tsne_data_process(bef_df, bef_n_pca, bef_labs),
        tsne_data_process(aft_df, aft_n_pca, aft_labs),
    )
    bwidth = 0.5
    color = 'black'
    # Apply the theme *before* creating the figure: figure/axes facecolor are
    # read from rcParams when the Figure and Axes are created, so setting them
    # afterwards only affected the *next* figure (first run looked different).
    sns.set(rc={'axes.facecolor':(0,0,0,0), 'figure.facecolor':(0,0,0,0)})
    fig, ax = plt.subplots(2, 1, figsize=(6,8))
    for panel_ax, tsne_df in zip(ax, panels):
        # NOTE(review): `size` is a seaborn semantic (grouping) variable, not a
        # marker size; a constant 0.1 gives every point the same mapped size.
        # Kept unchanged so figures match earlier output; use `s=` for marker area.
        sns.scatterplot(data=tsne_df, x='x1', y='x2', hue='cluster', style='cluster',
                        legend=False, markers=markers, alpha=0.5, ax=panel_ax,
                        size=0.1, palette=palette)
        panel_ax.set_xlim(-lim, lim)
        panel_ax.set_ylim(-lim, lim)
        for spine in panel_ax.spines.values():
            spine.set_color(color)
            spine.set_linewidth(bwidth)
        # Fully transparent tick marks/labels and grid lines.
        panel_ax.tick_params(axis='both', colors=(0,0,0,0), grid_alpha=0)
        panel_ax.set_xlabel(None)
        panel_ax.set_ylabel(None)
    fig.savefig(out_tag+'.tsne.pdf',format='pdf')
    fig.savefig(out_tag+'.tsne.png',dpi=720,format='png')
    return fig

def tsne_clst_fig_single(df, n_pca, labs, lim, out_tag):
    """Draw one t-SNE scatter of ``df`` coloured/marked by ``labs`` and save it.

    Parameters
    ----------
    df : array-like of shape (n_samples, n_features)
        Features to embed (via ``tsne_data_process``).
    n_pca : int or None
        PCA components kept before t-SNE.
    labs : array-like
        Per-row labels; mapped through the module-level ``palette``/``markers``.
    lim : float
        Both axes are fixed to ``(-lim, lim)``.
    out_tag : str
        Output path prefix; writes ``<out_tag>.tsne.pdf`` and
        ``<out_tag>.tsne.png`` (720 dpi).

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Calls ``sns.set``, which changes global matplotlib rcParams for the rest
    of the session.
    """
    tsne_df = tsne_data_process(df, n_pca, labs)
    bwidth = 0.5
    color = 'black'
    # Apply the theme *before* creating the figure: figure/axes facecolor are
    # read from rcParams when the Figure and Axes are created.
    sns.set(rc={'axes.facecolor':(0,0,0,0), 'figure.facecolor':(0,0,0,0)})
    fig, ax = plt.subplots(1, 1, figsize=(6,4))
    # NOTE(review): `size` is a seaborn semantic (grouping) variable, not a
    # marker size; a constant 0.1 gives every point the same mapped size.
    # Kept unchanged so figures match earlier output; use `s=` for marker area.
    sns.scatterplot(data=tsne_df, x='x1', y='x2', hue='cluster', style='cluster',
                    legend=False, markers=markers, alpha=0.5, ax=ax,
                    size=0.1, palette=palette)
    # Explicit limits turn autoscaling off, so no margins are applied.
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    for spine in ax.spines.values():
        spine.set_color(color)
        spine.set_linewidth(bwidth)
    # Fully transparent tick marks/labels and grid lines.
    ax.tick_params(axis='both', colors=(0,0,0,0), grid_alpha=0)
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    fig.savefig(out_tag+'.tsne.pdf',format='pdf')
    fig.savefig(out_tag+'.tsne.png',dpi=720,format='png')
    return fig

In [None]:
# Validation-set embeddings before fine-tuning (tab-separated).
before = pd.read_csv('esm2_650M/val.emb.tab', sep='\t')
bef_labs = before['Label']
# NOTE(review): assumes the first three columns are metadata (including
# `Label`) and every remaining column is a feature — TODO confirm file layout.
bef_df = before.iloc[:, 3:]

In [None]:
# Validation-set embeddings after fine-tuning (tab-separated).
after = pd.read_csv('esm2_650M/val.ft_emb.tab', sep='\t')
aft_labs = after['Label']
# NOTE(review): same assumed layout as the pre-fine-tuning file (first three
# columns are metadata) — TODO confirm.
aft_df = after.iloc[:, 3:]

In [None]:
# Cumulative explained variance before, then after, fine-tuning; used to pick
# the PCA component counts in the next cell.
for emb_df in (bef_df, aft_df):
    pca_n_components(emb_df)

In [None]:
# Per-dataset t-SNE settings: PCA components kept before t-SNE for the
# pre-/post-fine-tuning embeddings, and the symmetric axis limit of the plots.
# NOTE(review): the component counts were presumably read off the
# cumulative-variance curves above — TODO confirm the selection criterion.
TSNE_PRESETS = {
    'proka': {'bef_n_pca': 106, 'aft_n_pca': 45, 'lim': 130},
    'virus': {'bef_n_pca': 39, 'aft_n_pca': 3, 'lim': 80},
}
DATASET = 'proka'  # set to 'virus' for the virus dataset

bef_n_pca = TSNE_PRESETS[DATASET]['bef_n_pca']
aft_n_pca = TSNE_PRESETS[DATASET]['aft_n_pca']
lim = TSNE_PRESETS[DATASET]['lim']

In [None]:
# One t-SNE figure per embedding set; each is also saved as PDF and PNG.
fig_org = tsne_clst_fig_single(
    df=bef_df, n_pca=bef_n_pca, labs=bef_labs, lim=lim, out_tag='val_before_ft'
)
fig_ft = tsne_clst_fig_single(
    df=aft_df, n_pca=aft_n_pca, labs=aft_labs, lim=lim, out_tag='val_after_ft'
)