### Modules

In [1]:
# basic
import os, sys, glob
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
from scipy import stats
from scipy.stats import pearsonr

In [2]:
# Make the modules under the local m6atm `train` directory importable.
sys.path.insert(1, '/home/bo-yi/package/m6atm/m6atm/train')
# NOTE(review): star imports hide where names come from. Later cells use
# `torch`, `load_mixed_bags`, `get_mixed_loader` and `dsmil_pred` without any
# explicit import, so they presumably arrive through these two modules --
# confirm, and prefer explicit imports.
from MixBag import *
from ModelData import *

In [3]:
# Locations used throughout the notebook.
pkg_root = '/home/bo-yi/package/m6atm/m6atm'

out_dir = '/home/bo-yi/paper'  # figures are written here
data_dir, model_dir = (os.path.join(pkg_root, sub) for sub in ('data', 'model'))

### Fig. 4A-1

In [None]:
### data
# One results table per mixing run (mix20 ... mix100).
dir_list = ['/home/bo-yi/data/DRS/20230326_IVT1mix20',
            '/home/bo-yi/data/DRS/20230326_IVT1mix40',
            '/home/bo-yi/data/DRS/20230326_IVT1mix60',
            '/home/bo-yi/data/DRS/20230326_IVT1mix80',
            '/home/bo-yi/data/DRS/20230326_IVT1mix100',]

pred_list = []
for run_dir in dir_list:
    csv_path = os.path.join(run_dir, 'preprocessed/results.csv')
    pred_list.append(pd.read_csv(csv_path, index_col = 0))

# One group label per row, in the same order the tables are concatenated below.
group = []
for pred, pct in zip(pred_list, range(20, 101, 20)):
    group.extend(['IVT-%s%%'%(pct)]*pred.shape[0])

pred_table = pd.concat(pred_list, axis = 0)
pred_table['group'] = group

In [None]:
### data
# Predicted ratio of every row in pred_table, in concatenation order.
ratio = pred_table.ratio.tolist()

# Nominal ratio of the run each row belongs to. The percentages must be the
# same 20..100 series used to build the 'IVT-xx%' group labels; pairing the
# five runs with range(0, 101, 20) would assign them 0..80% instead.
# (A constant offset does not change Pearson's r, but it makes `gth` wrong
# for any other use.)
gth = [[pct/100]*pred.shape[0] for pred, pct in zip(pred_list, range(20, 101, 20))]
gth = [value for per_run in gth for value in per_run]

corr = round(pearsonr(ratio, gth)[0], 2)

### plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (12, 6)) # figure size

# Draw on the axes created above. The former `size = 3` argument was dropped:
# violinplot has no such parameter (it is rejected or ignored depending on the
# seaborn version).
sns.violinplot(x = 'group', y ='ratio', data = pred_table, ax = ax)
ax.text(0.2, 1.02, 'Pearson\'s = %s'%(corr), horizontalalignment = 'center', fontsize = 20)
ax.tick_params(labelsize = 16)
ax.set(xlabel = None)
ax.set_ylabel('Predicted m6A ratio', fontsize = 25)

plt.savefig(os.path.join(out_dir, 'fig4a-1.png'), bbox_inches = 'tight', dpi = 300)

### Fig. 4A-2

In [None]:
n_bags = 10
bag_data, bag_label, bag_size, bag_site = load_mixed_bags(data_dir, prefix = 'ivt_ratio')

# Walk the bags in chunks of 2*n_bags and keep the first n_bags of each chunk.
# The slice width is tied to n_bags (it used to be a literal 10) so that the
# stride and the slice stay in sync if n_bags is changed.
bag_data_used = []
bag_label_used = []
bag_size_used = []
bag_site_used = []
for i in range(0, len(bag_data), n_bags*2):

    bag_data_used.extend(bag_data[i:i+n_bags])
    bag_label_used.extend(bag_label[i:i+n_bags])
    bag_size_used.extend(bag_size[i:i+n_bags])
    bag_site_used.extend(bag_site[i:i+n_bags])

dataloader = get_mixed_loader(bag_data_used, bag_label_used, bag_size_used, bag_site_used, split = False)

In [None]:
# n_bags = 10
# data_dir1 = '/home/bo-yi/data/DRS/20210806_IVTm6a/preprocessed'
# data_dir2 = '/home/bo-yi/data/DRS/20210806_IVTum/preprocessed'
# bag_data, bag_label, bag_size, bag_site = get_mixed_bags(data_dir1, data_dir2, data_dir, prefix = 'ivt_ratio',
#                                                          pct_range = [0, 1], bag_size = [20, 1000], n_bags = n_bags, processes = 24)

In [None]:
# Weight files live in model_dir; the 'classifer' spelling is the file name as
# referenced throughout this notebook.
dsmil_pth, classifier_pth = (os.path.join(model_dir, fname)
                             for fname in ('dsmil_ivt.pth', 'classifer_ivt.pth'))

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

result = dsmil_pred(dsmil_pth, classifier_pth, dataloader, out_dir = out_dir, thres = 0.5, device = device)

In [None]:
### data
ratio = result.ratio.tolist()  # predicted ratio returned by dsmil_pred
gth = bag_label_used           # labels of the bags that were fed to the model

corr = round(pearsonr(ratio, gth)[0], 2)

### plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (8, 6)) # figure size

# Ground truth goes on x and the prediction on y so that the data agree with
# the axis labels set below (the two were previously passed the other way
# round, i.e. the plot showed prediction on the axis labelled "Ground-truth").
sns.regplot(x = gth, y = ratio,
            scatter_kws = {'alpha': 0.5, 's': 10},
            line_kws = {'color': tab_color[3], 'linewidth': 2, 'alpha': 0.7},
            ax = ax)

ax.text(0.18, 1, 'Pearson\'s = %s'%(corr), horizontalalignment = 'center', fontsize = 18)
ax.set_xlabel('Ground-truth m6A ratio (IVT)', fontsize = 20)
ax.set_ylabel('Predicted m6A ratio', fontsize = 20)

plt.savefig(os.path.join(out_dir, 'fig4a-2.png'), bbox_inches = 'tight', dpi = 300)

### Fig. 4C

In [None]:
# Expose only the GPU with index 1 to CUDA for the cells below.
# NOTE(review): earlier cells have already called torch.cuda.is_available()
# and dsmil_pred; CUDA_VISIBLE_DEVICES is normally only honoured when it is
# set before CUDA is initialised in the process -- confirm, and consider
# moving this line to the top of the notebook.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
# Inputs for the read-coverage experiment.
# NOTE(review): `data_dir` is rebound here, replacing the value assigned in the
# configuration cell near the top; re-running earlier cells after this one
# would pick up the 'ivt_coverage' path instead.
drs_root = '/home/bo-yi/data/DRS'
pkg_root = '/home/bo-yi/package/m6atm/m6atm'

data_dir1 = os.path.join(drs_root, '20210806_IVTm6a', 'preprocessed')
data_dir2 = os.path.join(drs_root, '20210806_IVTum', 'preprocessed')
data_dir = os.path.join(pkg_root, 'data', 'ivt_coverage')
dsmil_file = os.path.join(pkg_root, 'model', 'dsmil_ivt_20.pth')
classifier_file = os.path.join(pkg_root, 'model', 'classifer_ivt.pth')

In [None]:
# coverage_list = [20, 30, 50, 100, 200, 500, 1000]
# for coverage in coverage_list:
#     _ = get_mixed_bags(data_dir1, data_dir2, data_dir, bag_size = [coverage, coverage], pct_range = [0, 1],
#                        prefix = 'ivt_c%s'%(coverage), n_bags = 5, processes = 24)

In [None]:
n_bags = 5
coverage_list = [20, 30, 50, 100, 200, 500, 1000]
corr_list = []
for coverage in coverage_list:

    # Bags stored under the prefix for this coverage value.
    bag_data, bag_label, size_data, site_data = load_mixed_bags(data_dir, prefix = 'ivt_c%s'%(coverage))

    # Start offset of every 2*n_bags chunk; only the first n_bags of each
    # chunk are kept.
    starts = range(0, len(bag_data), n_bags*2)
    bag_data_used = [bag for s in starts for bag in bag_data[s:s+n_bags]]
    bag_label_used = [lab for s in starts for lab in bag_label[s:s+n_bags]]
    bag_size_used = [sz for s in starts for sz in size_data[s:s+n_bags]]
    bag_site_used = [site for s in starts for site in site_data[s:s+n_bags]]

    dataloader = get_mixed_loader(bag_data_used, bag_label_used, bag_size_used, bag_site_used, split = False)
    result = dsmil_pred(dsmil_file, classifier_file, dataloader, out_dir = data_dir)

    ### data
    # Correlation between the predicted ratios and the bag labels, rounded
    # to two decimals, collected once per coverage value.
    ratio = result.ratio.tolist()
    gth = bag_label_used

    corr = round(pearsonr(ratio, gth)[0], 2)
    corr_list.append(corr)

In [None]:
##### main
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (10, 6)) # figure size

# Pearson's coefficient against read coverage, drawn as a line with markers.
x_val = coverage_list
y_val = corr_list
ax.plot(x_val, y_val, marker = 'o', color = tab_color[2], lw = 3)

ax.set_xlabel('Read coverage', fontsize = 25)
ax.set_ylabel('Pearson\'s coefficient', fontsize = 25)
ax.set_ylim([0.5, 1])

for side in ('right', 'top'):
    ax.spines[side].set_visible(False)

ax.tick_params(labelsize = 20)

# Write the coverage value next to each marker, skipping the point at 30.
for idx, cov in enumerate(x_val):
    if cov == 30:
        continue
    ax.annotate(str(cov), (cov+10, y_val[idx]-0.025), fontsize = 16)

plt.savefig(os.path.join(out_dir, 'fig4c.png'), bbox_inches = 'tight', dpi = 300)

### Fig. 4D

In [None]:
# Result tables of the two IVTR runs (m6a50 and m6a20).
ivtr50_csv = os.path.join('/home/bo-yi/data/DRS/20230410_IVTR1m6a50', 'preprocessed/results.csv')
ivtr20_csv = os.path.join('/home/bo-yi/data/DRS/20230410_IVTR1m6a20', 'preprocessed/results.csv')

pred_ivtr50 = pd.read_csv(ivtr50_csv, index_col = 0)
pred_ivtr20 = pd.read_csv(ivtr20_csv, index_col = 0)

In [None]:
# Drop rows whose probability lies in [0.5, 0.9); keep >= 0.9 and < 0.5.
keep_50 = (pred_ivtr50.probability >= 0.9) | (pred_ivtr50.probability < 0.5)
keep_20 = (pred_ivtr20.probability >= 0.9) | (pred_ivtr20.probability < 0.5)

pred_ivtr50 = pred_ivtr50.loc[keep_50]
pred_ivtr20 = pred_ivtr20.loc[keep_20]

In [None]:
ratio_50 = pred_ivtr50.ratio.tolist()
ratio_20 = pred_ivtr20.ratio.tolist()

# Long-format table for plotting: the 20% rows first, then the 50% rows.
ratio_list = [*ratio_20, *ratio_50]
group_list = ['IVTR-20%' for _ in ratio_20] + ['IVTR-50%' for _ in ratio_50]
ratio_table = pd.DataFrame({'ratio': ratio_list, 'group': group_list})

In [None]:
### plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (8, 8)) # figure size

sns.boxplot(x = 'group', y ='ratio', data = ratio_table, showfliers = False, ax = ax)

# asterisks: a bracket spanning both boxes, annotated with the significance level
x1, x2 = 0, 1
y = ratio_table['ratio'].max()*0.9
h = 0.05
col = 'k'

# p-value of the median test between the two groups, mapped to star notation;
# anything not below 0.05 is reported as 'ns'.
pval = stats.median_test(ratio_20, ratio_50)[1]
symbol = 'ns'
for cutoff, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
    if pval < cutoff:
        symbol = stars
        break

ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw = 1.5, c = col)
ax.text((x1+x2)*.5, y+h, symbol, ha = 'center', va = 'bottom', color = col)

ax.tick_params(labelsize = 24)
ax.set_xlabel(None)
ax.set_ylabel('Predicted m6A ratio', size = 32)

plt.savefig(os.path.join(out_dir, 'fig4d.png'), bbox_inches = 'tight', dpi = 300)

In [None]:
# Inspection only: median predicted ratio of the 'IVTR-20%' group
# (displayed as the cell output; nothing is stored).
np.median(ratio_20)

### Fig. 4E

In [4]:
# Result tables of the HEK293T WT1 and KO1 runs.
wt1_csv = os.path.join('/home/bo-yi/data/DRS/20211012_HEK293T-WT1', 'preprocessed/results.csv')
ko1_csv = os.path.join('/home/bo-yi/data/DRS/20211012_HEK293T-KO1', 'preprocessed/results.csv')

pred_hek293twt1 = pd.read_csv(wt1_csv, index_col = 0)
pred_hek293tko1 = pd.read_csv(ko1_csv, index_col = 0)

In [5]:
# Keep rows with a coverage of at least 100 in each table.
wt_covered = pred_hek293twt1.coverage >= 100
ko_covered = pred_hek293tko1.coverage >= 100

pred_hek293twt1 = pred_hek293twt1.loc[wt_covered]
pred_hek293tko1 = pred_hek293tko1.loc[ko_covered]

In [6]:
# Keep rows whose predicted ratio is below 0.5 in each table.
wt_low_ratio = pred_hek293twt1.ratio < 0.5
ko_low_ratio = pred_hek293tko1.ratio < 0.5

pred_hek293twt1 = pred_hek293twt1.loc[wt_low_ratio]
pred_hek293tko1 = pred_hek293tko1.loc[ko_low_ratio]

In [7]:
# Sites with probability >= 0.9 in WT that also have probability < 0.1 in KO,
# reduced to their (transcript, position) keys.
site_keys = ['transcript', 'position']

p_table = pred_hek293twt1.loc[pred_hek293twt1.probability >= 0.9]
n_table = pred_hek293tko1.loc[pred_hek293tko1.probability < 0.1]

table = pd.merge(p_table, n_table, how = 'inner', on = site_keys)[site_keys]

In [None]:
# Predicted ratios of the selected sites, taken separately from each table.
wt_selected = pd.merge(pred_hek293twt1, table, how = 'inner', on = ['transcript', 'position'])
ko_selected = pd.merge(pred_hek293tko1, table, how = 'inner', on = ['transcript', 'position'])
ratio_wt = wt_selected.ratio.tolist()
ratio_ko = ko_selected.ratio.tolist()

# Long-format table for plotting: WT rows first, then KO rows.
ratio_list = [*ratio_wt, *ratio_ko]
group_list = ['HEK293T-WT' for _ in ratio_wt] + ['HEK293T-METTL3_KO' for _ in ratio_ko]
ratio_table = pd.DataFrame({'ratio': ratio_list, 'group': group_list})

In [None]:
### plot
sns.set_theme(style = 'white') # theme
tab_color = sns.color_palette() # color palette
fig, ax = plt.subplots(figsize = (8, 8)) # figure size

sns.boxplot(x = 'group', y ='ratio', data = ratio_table, showfliers = False, ax = ax)

# asterisks: a bracket spanning both boxes, annotated with the significance level
x1, x2 = 0, 1
y = ratio_table['ratio'].max()*1.05
h = 0.01
col = 'k'

# p-value of the median test between WT and KO, mapped to star notation;
# anything not below 0.05 is reported as 'ns'.
pval = stats.median_test(ratio_wt, ratio_ko)[1]
symbol = 'ns'
for cutoff, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
    if pval < cutoff:
        symbol = stars
        break

ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw = 1.5, c = col)
ax.text((x1+x2)*.5, y+h, symbol, ha = 'center', va = 'bottom', color = col)

ax.tick_params(labelsize = 24)
ax.set_xlabel(None)
ax.set_ylabel('Predicted m6A ratio', size = 32)

plt.savefig(os.path.join(out_dir, 'fig4e.png'), bbox_inches = 'tight', dpi = 300)