### Modules

In [None]:
# basic
import os, sys, glob, umap
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc, silhouette_samples
from matplotlib.lines import Line2D

In [None]:
sys.path.insert(1, '/home/bo-yi/package/m6atm/m6atm/preprocess')
from ReadClass import *

In [None]:
sys.path.insert(1, '/home/bo-yi/package/m6atm/m6atm/train')
from ModelData import *

In [None]:
# Alternative to the sys.path inserts above, when m6atm_analysis is installed as a package:
# from m6atm_analysis.preprocess.ReadClass import *
# from m6atm_analysis.train.ModelData import *

### Fig. 2A and 2B

In [None]:
def model_performance(pred_pos, pred_neg, threshold = 0.5):
    '''
    Threshold-based accuracy, precision and recall for a positive/negative prediction pair.

    Arguments:
        pred_pos: DataFrame with a 'probability' column; every row is treated as a true
                  positive site (label 1)
        pred_neg: DataFrame with a 'probability' column; every row is treated as a true
                  negative site (label 0)
        threshold: a site is called positive when probability >= threshold (default 0.5)
    Returns:
        [accuracy, precision, recall], each rounded to 2 decimals. A metric whose
        denominator is zero (e.g. precision when nothing is called positive) is NaN.
    '''
    pos_prob = np.asarray(pred_pos['probability'], dtype = float)
    neg_prob = np.asarray(pred_neg['probability'], dtype = float)

    # Count the four cells directly: unlike confusion_matrix, this never collapses to a
    # 1x1 matrix when one of the classes is absent. NaN probabilities count as negative calls.
    TP = int(np.sum(pos_prob >= threshold))
    FN = pos_prob.size - TP
    FP = int(np.sum(neg_prob >= threshold))
    TN = neg_prob.size - FP

    def _rounded_ratio(num, den):
        # explicit NaN instead of a numpy 0/0 RuntimeWarning
        return round(num/den, 2) if den else float('nan')

    precision = _rounded_ratio(TP, TP+FP)
    recall = _rounded_ratio(TP, TP+FN)
    accuracy = _rounded_ratio(TP+TN, TP+FP+TN+FN)

    return [accuracy, precision, recall]

def roc_val(pred_pos, pred_neg):
    '''
    ROC curve and area under it for a positive/negative prediction pair.

    Arguments:
        pred_pos: DataFrame with a 'probability' column; all rows are labelled 1
        pred_neg: DataFrame with a 'probability' column; all rows are labelled 0
    Returns:
        [fpr, tpr, roc_auc] as produced by sklearn's roc_curve / auc
    '''
    scores = pd.concat([pred_pos, pred_neg], axis = 0)['probability']
    labels = [1]*len(pred_pos) + [0]*len(pred_neg)

    # thresholds returned by roc_curve are not needed for plotting
    fpr, tpr, _ = roc_curve(labels, scores, pos_label = 1)

    return [fpr, tpr, auc(fpr, tpr)]

In [None]:
### m6ATM prediction results in IVT data
# Root folder of the IVT datasets; set the M6ATM_IVT_ROOT environment variable to run elsewhere.
IVT_DATA_ROOT = os.environ.get('M6ATM_IVT_ROOT', '/home/bo-yi/data/DRS')

def _ivt_results(mix):
    '''Path of the m6ATM results.csv for the IVT dataset with `mix`% m6A-modified transcripts.'''
    return os.path.join(IVT_DATA_ROOT, '20230326_IVT1mix{0}'.format(mix), 'preprocessed', 'results.csv')

# positive datasets in increasing modification ratio; 0% is the negative control
file_list = [_ivt_results(mix) for mix in (20, 40, 60, 80, 100)]

pred_pos_list = [pd.read_csv(file, index_col = 0) for file in file_list]
pred_neg = pd.read_csv(_ivt_results(0), index_col = 0)

# every table must carry the score column read by model_performance / roc_val
for table in pred_pos_list + [pred_neg]:
    assert 'probability' in table.columns, 'results.csv is missing the "probability" column'


### Figure 2A: bar plot for performance evaluation
# data: one [accuracy, precision, recall] triple per positive dataset
results = [model_performance(pred, pred_neg) for pred in pred_pos_list]
bar_labels = ['IVT-20%', 'IVT-40%', 'IVT-60%', 'IVT-80%', 'IVT-100%']  # same order as pred_pos_list
assert len(bar_labels) == len(results), 'one legend label is needed per positive dataset'

# plot
sns.set_theme(style = 'white') # theme
fig, ax = plt.subplots(figsize = (22, 6)) # figure size

x = np.array([0.4, 1, 1.6])  # locations of the Accuracy / Precision / Recall groups
width = 0.1  # the width of the bars

# Centre each group of bars on its x location: offsets -0.2, -0.1, 0, +0.1, +0.2 for five datasets.
for k, (label, metrics) in enumerate(zip(bar_labels, results)):
    offset = (k - (len(results)-1)/2)*width
    rects = ax.bar(x+offset, metrics, width, label = label, alpha = 0.8)
    ax.bar_label(rects, fmt = '%.2f', padding = 2, fontsize = 18, weight = 'bold')

ax.set_xticks(x)
ax.set_xticklabels(['Accuracy', 'Precision', 'Recall'])

# legend entries come from the bar labels, so they cannot drift out of sync with the bars
ax.legend(fontsize = 20, loc = 'upper center', ncol = len(results), bbox_to_anchor = (0.5, 1.25))

ax.set(xlim = (0.0, 1.9))
ax.set(ylim = (0.0, 1.0))
ax.tick_params(labelsize = 30)
sns.despine()

In [None]:
### Figure 2B: ROC curves
# data: one [fpr, tpr, roc_auc] triple per positive dataset
results = [roc_val(pred, pred_neg) for pred in pred_pos_list]
curve_labels = ['IVT-20%', 'IVT-40%', 'IVT-60%', 'IVT-80%', 'IVT-100%']  # same order as pred_pos_list

# plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette()
# Explicit figure: do not rely on the previous cell's figure having been closed by the backend.
fig, ax = plt.subplots()

for k, (label, (fpr, tpr, roc_auc)) in enumerate(zip(curve_labels, results)):
    ax.plot(fpr, tpr, color = tab_color[k % len(tab_color)], lw = 2,
            label = '{0} (area = {1:.3f})'.format(label, roc_auc))

# chance diagonal
ax.plot([0, 1], [0, 1], color = 'gray', lw = 2, linestyle = '--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])

ax.set_xlabel('False Positive Rate', fontsize = 25)
ax.set_ylabel('True Positive Rate', fontsize = 25)
ax.legend(loc = 'lower right', fontsize = 12)

ax.tick_params(labelsize = 16)

### Fig. 2C and 2D

In [None]:
# Please run the script "get_preprocessed.sh" first to generate the intermediate files for each dataset.
# The intermediate files of each dataset should include "***_data.npy" and "***_label.npy".
# Because the input data are very large, we alternatively provide preprocessed feature files in DataFrame format ("feature.csv").

In [None]:
def get_umap_feature(data_dir, model_file, signal_only = False, processes = 24):
    
    '''
    Extract one bag-level feature vector per site with a pre-trained DSMIL encoder.

    Arguments:
        data_dir: path should include "***_data.npy" and "***_label.npy"
        model_file: pre-trained feature encoder (state_dict loaded with torch.load)
        signal_only: True for the signal-only architecture, False for signal + trace;
                     must match the architecture the weights in model_file belong to
        processes: worker processes for ATMbag when the bag files have to be built (default 24)
    Returns:
        list of numpy arrays (the squeezed B output of the model), one per bag of the
        first "bag_*.npy" file in <data_dir>/temp, in dataloader order
    Raises:
        FileNotFoundError: if no bag file exists even after building the bags
    '''
    
    ### data
    # bags are cached in <data_dir>/temp and only built when none are present
    temp_dir = os.path.join(data_dir, 'temp')
    os.makedirs(temp_dir, exist_ok = True)
    
    bag_files = sorted(glob.glob(os.path.join(temp_dir, 'bag_*.npy')))
    if len(bag_files) == 0:
        bag_class = ATMbag(data_dir, processes = processes)
        bag_class.to_bag(temp_dir)
        bag_files = sorted(glob.glob(os.path.join(temp_dir, 'bag_*.npy')))
    if len(bag_files) == 0:
        raise FileNotFoundError('no bag_*.npy files found in {0}'.format(temp_dir))
    
    # NOTE(review): only the first bag/site file is used — confirm to_bag writes a single chunk here.
    # allow_pickle is required for the ragged bag arrays; only load locally generated files.
    bag_data = np.load(bag_files[0], allow_pickle = True)
    bag_meta = pd.read_csv(sorted(glob.glob(os.path.join(temp_dir, 'site_*.csv')))[0], index_col = 0)
    
    ### dataloader
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = WNBagloader(data = list(bag_data),
                          transform = transforms.Compose([ToTensor(device = device)]),
                          site = bag_meta['site'],
                          coverage = bag_meta['coverage'],
                          signal_only = signal_only)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size = 1, shuffle = False)
    
    ### model
    b_classifier = BClassifier(input_size = 1024, output_class = 1) 

    if signal_only:
        # model - signal only
        i_classifier = WaveNetModel(layers = 3, blocks = 2, input_channels = 1, kernel_size = 2, dropout = 0.2, num_classes = 1, last_channels = 32*61)
    else: 
        # model - signal + trace
        i_classifier = WaveNetModel(layers = 3, blocks = 2, input_channels = 1, kernel_size = 2, dropout = 0.2, num_classes = 1)
    
    model = DSMIL(i_classifier, b_classifier).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(model_file, map_location = device))
    
    ### feature
    model.eval()
    feat_list = []
    # inference only: no autograd graph, which saves memory on large datasets
    with torch.no_grad():
        for batch in tqdm(dataloader, total = len(dataloader)):

            # forward; label/metadata tensors (batch[1], batch[2]) are not needed for features
            data = batch[0].squeeze(0).float().to(device)
            _, _, _, B = model(data)
            feat_list.append(B.squeeze().cpu().numpy().copy())

    return feat_list

def get_umap_df(data_dir1, data_dir2, model_file, signal_only = False,
                n_neighbors = 5, n_epochs = 1000, random_state = None):
    '''
    Embed the bag features of two datasets into 2-D with UMAP.

    Arguments:
        data_dir1: dataset whose bags are labelled 'm6A'
        data_dir2: dataset whose bags are labelled 'Unmodified'
        model_file: pre-trained feature encoder passed to get_umap_feature
        signal_only: encoder architecture flag passed to get_umap_feature
        n_neighbors, n_epochs: UMAP hyper-parameters (defaults 5 and 1000)
        random_state: UMAP seed; set an int for a reproducible embedding (UMAP then runs
                      single-threaded). None keeps UMAP's non-deterministic default.
    Returns:
        DataFrame with columns 'x', 'y' (UMAP coordinates) and 'label', with all
        data_dir1 rows first, then all data_dir2 rows
    Raises:
        ValueError: if either dataset yields no features
    '''
    feat1 = get_umap_feature(data_dir1, model_file, signal_only = signal_only)
    feat2 = get_umap_feature(data_dir2, model_file, signal_only = signal_only)

    for feats, data_dir in ((feat1, data_dir1), (feat2, data_dir2)):
        if len(feats) == 0:
            raise ValueError('no bag features were extracted from {0}'.format(data_dir))

    ubags = np.vstack(feat1+feat2)
    embedding = umap.UMAP(n_neighbors = n_neighbors, n_epochs = n_epochs,
                          random_state = random_state).fit_transform(ubags)
    df = pd.DataFrame({'x': embedding[:,0], 'y': embedding[:,1], 'label': ['m6A']*len(feat1)+['Unmodified']*len(feat2)})
    
    return df

def silhouette_plot(ax, slh_vals, labels, color_list):
    '''
    Draw a silhouette plot: one band of sorted silhouette coefficients per cluster,
    stacked bottom-up, plus a dashed line at the mean coefficient over all samples.

    Arguments:
        ax: matplotlib Axes to draw on
        slh_vals: per-sample silhouette coefficients (e.g. from silhouette_samples)
        labels: per-sample cluster labels, same length as slh_vals (Series, list or array)
        color_list: one color per cluster; the i-th color goes to the i-th distinct label
                    in order of first appearance in `labels`
    '''
    slh_vals = np.asarray(slh_vals)
    labels = np.asarray(labels)  # element-wise comparison below also works for list input

    y_lower = 0
    # pd.unique keeps order of first appearance; np.unique would sort the labels
    # ('Unmodified' < 'm6A') and pair colors with the wrong clusters
    for cluster, color in zip(pd.unique(labels), color_list):
        slh_vals_k = np.sort(slh_vals[labels == cluster])
        y_upper = y_lower + len(slh_vals_k)
        ax.fill_betweenx(np.arange(y_lower, y_upper), 0, slh_vals_k, facecolor = color, edgecolor = color, alpha = 0.7)
        y_lower = y_upper

    # average silhouette score over all samples, drawn once (not once per cluster)
    avg_score = np.mean(slh_vals)
    ax.axvline(avg_score, linestyle = '--', linewidth = 2, color = '#81B900')
    ax.text(avg_score+0.02, 1, str(round(avg_score, 2)), color = '#81B900', fontsize = 25)

    ax.set_yticks([])
    ax.set_xlim([-0.8, 0.8])
    ax.set_xlabel('Silhouette coefficient values', fontsize = 25)
    ax.set_ylabel('Cluster labels', fontsize = 30)
    ax.tick_params(labelsize = 20)

In [None]:
def plot_umap(df):
    '''
    Kernel-density contour plot of a 2-D UMAP embedding, one color per label.

    Arguments:
        df: DataFrame with columns 'x', 'y' (UMAP coordinates) and 'label' with values
            'm6A' / 'Unmodified' (e.g. the output of get_umap_df); rows with other
            labels are not drawn
    '''
    ### plot
    sns.set_theme(style = 'white') # theme
    tab_color = sns.color_palette() # color palette
    fig, ax = plt.subplots(figsize = (10, 8)) # figure size

    label_order = ['m6A', 'Unmodified']
    label_colors = [tab_color[9], tab_color[6]]

    # hue_order pins each label to its color, so the contours always match the legend below
    sns.kdeplot(data = df, x = 'x', y = 'y', hue = 'label', hue_order = label_order, palette = label_colors,
                linewidths = 3, levels = 5, thresh = .3, ax = ax)

    ### legend
    # Thick proxy lines replace seaborn's default legend. Setting lw on the proxies avoids
    # Legend.legendHandles, which newer matplotlib versions removed.
    custom_obj = [Line2D([0], [0], color = c, lw = 6) for c in label_colors]
    ax.legend(custom_obj, label_order, fontsize = 20)

    ### axis
    ax.set_xlabel('UMAP1', fontsize = 30)
    ax.set_ylabel('UMAP2', fontsize = 30)
    ax.tick_params(labelsize = 25)

    
def plot_silhouette(df):
    '''
    Silhouette plot of a 2-D UMAP embedding, using the labels as clusters.

    Arguments:
        df: DataFrame with columns 'x', 'y' (UMAP coordinates) and 'label' with values
            'm6A' / 'Unmodified' (e.g. the output of get_umap_df)
    '''
    ### plot
    sns.set_theme(style = 'white') # theme
    tab_color = sns.color_palette() # color palette
    fig, ax = plt.subplots(figsize = (12, 6)) # figure size

    label_order = ['m6A', 'Unmodified']
    color_map = {'m6A': tab_color[9], 'Unmodified': tab_color[6]}

    # Sort rows by label so that sorted order and order of first appearance coincide.
    # color_list then lines up with the clusters whichever of the two orders
    # silhouette_plot walks them in (np.unique sorts 'Unmodified' before 'm6A').
    df_sorted = df.sort_values('label', kind = 'stable')
    label_list = df_sorted['label'].to_numpy()
    slh_vals = silhouette_samples(df_sorted[['x', 'y']], label_list)
    color_list = [color_map[c] for c in np.unique(label_list)]

    silhouette_plot(ax, slh_vals, label_list, color_list)

    ### legend
    # lw on the proxy lines avoids Legend.legendHandles, removed in newer matplotlib
    custom_obj = [Line2D([0], [0], color = color_map[c], lw = 6) for c in label_order]
    ax.legend(custom_obj, label_order, fontsize = 20, loc = 'upper left')

    ### axis
    ax.set_xlim([-1, 1])
    

In [None]:
### Pre-trained model
# Folder containing dsmil_s.pth / dsmil_f.pth; set the M6ATM_MODEL_DIR environment variable
# to <PATH_TO> that folder when running on another machine.
MODEL_DIR = os.environ.get('M6ATM_MODEL_DIR', '/home/bo-yi/paper/m6atm_analysis')
model_file1 = os.path.join(MODEL_DIR, 'dsmil_s.pth') # pre-trained signal only model 
model_file2 = os.path.join(MODEL_DIR, 'dsmil_f.pth') # pre-trained signal + trace model

In [None]:
### Figure 2C: 100% m6A-modified vs. 0% modified (signal+trace)
# data1 = m6A-modified sample, data2 = unmodified control
data_dir_ivt100, data_dir_ivt0 = ('/home/bo-yi/data/DRS/20230326_IVT1mix100/preprocessed',
                                  '/home/bo-yi/data/DRS/20230326_IVT1mix0/preprocessed')

# embed with the signal + trace encoder
df = get_umap_df(data_dir_ivt100, data_dir_ivt0, model_file2, signal_only = False)

plot_umap(df)       # density contours of the embedding
plot_silhouette(df) # per-sample silhouette coefficients of the embedding

In [None]:
### Figure S4: 60% m6A-modified vs. 0% modified (signal+trace)
# data1 = m6A-modified sample, data2 = unmodified control
data_dir_ivt60, data_dir_ivt0 = ('/home/bo-yi/data/DRS/20230326_IVT1mix60/preprocessed',
                                 '/home/bo-yi/data/DRS/20230326_IVT1mix0/preprocessed')

# embed with the signal + trace encoder
df = get_umap_df(data_dir_ivt60, data_dir_ivt0, model_file2, signal_only = False)

plot_umap(df)       # density contours of the embedding
plot_silhouette(df) # per-sample silhouette coefficients of the embedding

In [None]:
### Figure S4: 20% m6A-modified vs. 0% modified (signal+trace)
# data1 = m6A-modified sample, data2 = unmodified control
data_dir_ivt20, data_dir_ivt0 = ('/home/bo-yi/data/DRS/20230326_IVT1mix20/preprocessed',
                                 '/home/bo-yi/data/DRS/20230326_IVT1mix0/preprocessed')

# embed with the signal + trace encoder
df = get_umap_df(data_dir_ivt20, data_dir_ivt0, model_file2, signal_only = False)

plot_umap(df)       # density contours of the embedding
plot_silhouette(df) # per-sample silhouette coefficients of the embedding

In [None]:
### Figure S4: 100% m6A-modified vs. 0% modified (signal only)
# data1 = m6A-modified sample, data2 = unmodified control
data_dir_ivt100, data_dir_ivt0 = ('/home/bo-yi/data/DRS/20230326_IVT1mix100/preprocessed',
                                  '/home/bo-yi/data/DRS/20230326_IVT1mix0/preprocessed')

# embed with the signal-only encoder
df = get_umap_df(data_dir_ivt100, data_dir_ivt0, model_file1, signal_only = True)

plot_umap(df)       # density contours of the embedding
plot_silhouette(df) # per-sample silhouette coefficients of the embedding

### Start from CSV files