### Modules

In [None]:
# basic
import os, sys, glob, umap
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc, silhouette_samples
from matplotlib.lines import Line2D

### Figure 2A and 2B

In [None]:
def model_performance(pred_pos, pred_neg):
    """Compute accuracy, precision and recall at a 0.5 probability cutoff.

    Every row of ``pred_pos`` is taken as a true positive-class sample
    (label 1) and every row of ``pred_neg`` as a true negative-class sample
    (label 0). A sample is predicted positive when its ``'probability'``
    value is ``>= 0.5``.

    Parameters
    ----------
    pred_pos, pred_neg : pandas.DataFrame
        Prediction tables; each must provide a ``'probability'`` column.

    Returns
    -------
    list
        ``[accuracy, precision, recall]``, each rounded to two decimals.
        A metric whose denominator is zero is returned as ``nan``.
    """

    pred_table = pd.concat([pred_pos, pred_neg], axis = 0)

    y_pred = [1 if i>=0.5 else 0 for i in pred_table['probability']]
    y_true = [1]*pred_pos.shape[0]+[0]*pred_neg.shape[0]

    # Pin the label set: without it the matrix shrinks to 1x1 whenever only
    # one class shows up in y_true/y_pred, and the cell lookups below fail.
    # With labels [0, 1] the flattened order is TN, FP, FN, TP.
    TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels = [0, 1]).ravel()

    def _ratio(numerator, denominator):
        # An empty denominator leaves the metric undefined; report nan
        # explicitly instead of relying on a NumPy divide warning.
        if denominator == 0:
            return float('nan')
        return round(numerator/denominator, 2)

    precision = _ratio(TP, TP+FP)
    recall = _ratio(TP, TP+FN)
    accuracy = _ratio(TP+TN, TP+TN+FP+FN)

    out = [accuracy, precision, recall]

    return out

def roc_val(pred_pos, pred_neg):
    """Compute the ROC curve and its area for positive vs. negative tables.

    Rows of ``pred_pos`` are labelled 1 and rows of ``pred_neg`` are
    labelled 0; the ``'probability'`` column of each table is the score.

    Returns
    -------
    list
        ``[fpr, tpr, roc_auc]`` as produced by ``roc_curve`` and ``auc``.
    """
    # Scores and ground truth are both ordered positives first, then negatives.
    scores = pd.concat([pred_pos, pred_neg], axis = 0)['probability']
    n_pos, n_neg = pred_pos.shape[0], pred_neg.shape[0]
    truth = [1]*n_pos + [0]*n_neg

    # The thresholds returned by roc_curve are not needed here.
    fpr, tpr, _ = roc_curve(truth, scores, pos_label = 1)

    return [fpr, tpr, auc(fpr, tpr)]

In [None]:
### m6ATM prediction results in IVT data
# One validation table per positive mixture; the 'ivtmix0' table is used as
# the shared negative set for every comparison below.
file_list = ['../data/ivtmix20_valid.csv',
             '../data/ivtmix40_valid.csv',
             '../data/ivtmix60_valid.csv',
             '../data/ivtmix80_valid.csv',
             '../data/ivtmix100_valid.csv',]

pred_neg = pd.read_csv('../data/ivtmix0_valid.csv', index_col = 0)
pred_pos_list = [pd.read_csv(path, index_col = 0) for path in file_list]


### Figure 2A: bar plot for performance evaluation
# data: one [accuracy, precision, recall] triple per positive table
results = [model_performance(pred, pred_neg) for pred in pred_pos_list]

# plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (22, 6)) # figure size

x = np.array([0.4, 1, 1.6])  # the label locations
width = 0.1  # the width of the bars

# One bar series per mixture, shifted sideways so the five series sit next to
# each other around every metric position.
group_labels = ['IVT-20%', 'IVT-40%', 'IVT-60%', 'IVT-80%', 'IVT-100%']
offsets = [-0.2, -0.1, 0.0, 0.1, 0.2]
rects1, rects2, rects3, rects4, rects5 = [
    ax.bar(x+shift, scores, width, label = name, alpha = 0.8)
    for shift, scores, name in zip(offsets, results, group_labels)
]

plt.xticks(x, ['Accuracy', 'Precision', 'Recall'])

# Print the value on top of every bar.
for rects in (rects1, rects2, rects3, rects4, rects5):
    ax.bar_label(rects, fmt = '%.2f', padding = 2, fontsize = 18, weight = 'bold')


plt.legend(labels = group_labels,
           fontsize = 20, loc = 'upper center', ncol = 5, bbox_to_anchor = (0.5, 1.25))

ax.set(xlim = (0.0, 1.9), ylim = (0.0, 1.0))
ax.tick_params(labelsize = 30)
sns.despine()

In [None]:
### roc plot
# data: one [fpr, tpr, auc] triple per positive table
results = [roc_val(pred, pred_neg) for pred in pred_pos_list]

# plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette()

# Legend templates, in the same order as `results`; zip pairs each curve with
# the palette colour at the same position.
curve_labels = ['IVT-20% (area = {0:.3f})',
                'IVT-40% (area = {0:.3f})',
                'IVT-60% (area = {0:.3f})',
                'IVT-80% (area = {0:.3f})',
                'IVT-100% (area = {0:.3f})']

for (fpr, tpr, roc_auc), color, template in zip(results, tab_color, curve_labels):
    plt.plot(fpr, tpr, color = color, lw = 2, label = template.format(roc_auc))

# dashed diagonal for reference
plt.plot([0, 1], [0, 1], color = 'gray', lw = 2, linestyle = '--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])

plt.xlabel('False Positive Rate', fontsize = 25)
plt.ylabel('True Positive Rate', fontsize = 25)
plt.legend(loc = 'lower right', fontsize = 12)

plt.tick_params(labelsize = 16)

### Figure 2C and 2D

In [None]:
# 'umap_**.csv' files contain the UMAP-transformed site-level features for each group

In [None]:
def plot_umap(df, color_set2 = False):
    """Draw a KDE contour plot of 2-D coordinates, coloured by group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``'x'``, ``'y'`` and ``'label'``.
    color_set2 : bool, optional
        When True use palette entries 0 and 1; otherwise entries 9 and 6.
    """

    ### plot
    sns.set_theme(style = 'white') # theme
    tab_color = sns.color_palette() # color palette
    color_1 = tab_color[0] if color_set2 else tab_color[9]
    color_2 = tab_color[1] if color_set2 else tab_color[6]

    fig, ax = plt.subplots(figsize = (10, 8)) # figure size

    # Draw on the axes created above explicitly rather than on the implicit
    # "current" axes.
    sns.kdeplot(data = df, x = 'x', y = 'y', hue = 'label', linewidths = 3, palette = [color_1, color_2], levels = 5, thresh = .3, ax = ax)

    ### legend
    # The proxy lines carry the thick stroke themselves, so the legend no longer
    # has to be patched through Legend.legendHandles, which was deprecated in
    # Matplotlib 3.7 and is absent from newer releases.
    # NOTE(review): the legend text assumes color_1 is assigned to the m6A
    # group by the hue ordering of df['label'] - confirm against the data.
    custom_obj = [Line2D([0], [0], color = color_1, linewidth = 6),
                  Line2D([0], [0], color = color_2, linewidth = 6)]
    ax.legend(custom_obj, ['m6A', 'Unmodified'], fontsize = 20)

    ### axis
    ax.set_xlabel('UMAP1', fontsize = 30)
    ax.set_ylabel('UMAP2', fontsize = 30)
    ax.tick_params(labelsize = 25)

    
def silhouette_plot(ax, slh_vals, labels, color_list):
    """Draw per-cluster silhouette profiles and mark the overall mean score.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    slh_vals : array-like
        Silhouette coefficient of every sample.
    labels : array-like
        Cluster label of every sample, aligned with ``slh_vals``.
    color_list : sequence
        Fill colours, paired in order with the sorted unique labels.
        Clusters beyond ``len(color_list)`` are not drawn.
    """

    slh_vals = np.asarray(slh_vals)
    labels = np.asarray(labels)

    # Stack the clusters on the y axis: each one occupies as many rows as it
    # has samples, starting where the previous cluster ended.
    y_lower = 0
    for cluster, color in zip(np.unique(labels), color_list):
        slh_vals_k = np.sort(slh_vals[labels == cluster])
        y_upper = y_lower + len(slh_vals_k)
        ax.fill_betweenx(np.arange(y_lower, y_upper), 0, slh_vals_k, facecolor = color, edgecolor = color, alpha = 0.7)
        y_lower = y_upper

    # The mean is taken over all samples, so it is drawn once after the loop
    # instead of being redrawn for every cluster.
    if slh_vals.size:
        avg_score = np.mean(slh_vals)
        ax.axvline(avg_score, linestyle = '--', linewidth = 2, color = '#81B900')
        ax.text(avg_score+0.02, 1, str(round(avg_score, 2)), color = '#81B900', fontsize = 25)

    # Axis formatting does not depend on the clusters either.
    ax.set_yticks([])
    ax.set_xlim([-0.8, 0.8])
    ax.set_xlabel('Silhouette coefficient values', fontsize = 25)
    ax.set_ylabel('Cluster labels', fontsize = 30)

    ax.tick_params(labelsize = 20)
        
        
def plot_silhouette(df, color_set2 = False):
    """Compute per-sample silhouette coefficients and plot them by group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``'x'``, ``'y'`` (the coordinates the
        silhouette is computed on) and ``'label'`` (the grouping).
    color_set2 : bool, optional
        When True use palette entries 0 and 1; otherwise entries 9 and 6.
    """

    ### plot
    sns.set_theme(style = 'white') # theme
    tab_color = sns.color_palette() # color palette
    color_1 = tab_color[0] if color_set2 else tab_color[9]
    color_2 = tab_color[1] if color_set2 else tab_color[6]

    fig, ax = plt.subplots(figsize = (12, 6)) # figure size

    slh_vals = silhouette_samples(df.loc[:,['x', 'y']], df.loc[:,'label'])
    label_list = df.loc[:,'label']
    color_list = [color_1, color_2]

    silhouette_plot(ax, slh_vals, label_list, color_list)

    ### legend
    # The proxy lines carry the thick stroke themselves, so the legend no longer
    # has to be patched through Legend.legendHandles, which was deprecated in
    # Matplotlib 3.7 and is absent from newer releases.
    # NOTE(review): the legend text assumes color_1 ends up on the m6A group;
    # silhouette_plot pairs colours with the sorted unique labels - confirm.
    custom_obj = [Line2D([0], [0], color = color_1, linewidth = 6),
                  Line2D([0], [0], color = color_2, linewidth = 6)]
    ax.legend(custom_obj, ['m6A', 'Unmodified'], fontsize = 20, loc = 'upper left')

    ### axis
    # Replaces the narrower x range applied inside silhouette_plot.
    ax.set_xlim([-1, 1])
    

In [None]:
### Figure 2C: 100% m6A-modified vs. 0% modified (signal+trace)
# Load the precomputed coordinates and group labels for this comparison.
df = pd.read_csv('../data/umap_ivtmix100.csv')

# plots: UMAP analysis first, then silhouette analysis, from the same table
for draw_panel in (plot_umap, plot_silhouette):
    draw_panel(df)

In [None]:
### Figure S4: 60% m6A-modified vs. 0% modified (signal+trace)
# Load the precomputed coordinates and group labels for this comparison.
df = pd.read_csv('../data/umap_ivtmix60.csv')

# plots: UMAP analysis first, then silhouette analysis, from the same table
for draw_panel in (plot_umap, plot_silhouette):
    draw_panel(df)

In [None]:
### Figure S4: 20% m6A-modified vs. 0% modified (signal+trace)
# Load the precomputed coordinates and group labels for this comparison.
df = pd.read_csv('../data/umap_ivtmix20.csv')

# plots: UMAP analysis first, then silhouette analysis, from the same table
for draw_panel in (plot_umap, plot_silhouette):
    draw_panel(df)

In [None]:
### Figure S4: 100% m6A-modified vs. 0% modified (signal only)
# Load the precomputed coordinates and group labels for this comparison.
df = pd.read_csv('../data/umap_ivtmix100_s.csv')

# plots: UMAP analysis first, then silhouette analysis, both drawn with the
# alternative colour pair
for draw_panel in (plot_umap, plot_silhouette):
    draw_panel(df, color_set2 = True)