# Cytosine methylation prediction with neural networks

## Import packages and modules

In [2]:
#dataset and model architectures
from WGBSDataset import WGBSDataset
from NSDataset6 import NSDataset
from TCN_model import TCN_model
from ConvNeXt_model112 import ConvNeXt_model
from Transformer_model_window16 import Transformer_model

#other modules and packages
import os
import numpy as np
import h5py
import math
import torch
import pandas as pd
from torch import split
from time import time
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn import metrics
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pytorch_lightning as pl
import torchmetrics
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import average_precision_score
import random

## Initialization: only has to be done once

### Prepare the WGBS data

In [7]:
# Directory with the nanopore sequencing data; it is passed as `fast5_dir`
# to the dataset classes in the cells below.
NS_dir = '/home/yarivl/temp/'

# Bismark coverage (.cov) files of the two WGBS replicates. Despite the
# `_dir` suffix, these point to files rather than directories.
coverage_dir_1 = '/home/yarivl/WGBS_data_replicate1/SRR8936101.1_1_bismark_bt2_pe.deduplicated.bismark.cov'
coverage_dir_2 = '/home/yarivl/WGBS_data_replicate2/SRR8936102.1_1_bismark_bt2_pe.deduplicated.bismark.cov'

# Output directories for the processed files: filtered data per replicate,
# the merged data, and the label dictionaries.
WGBS_dir_1 = '/home/yarivl/thesis/WGBS_data/labels_per_chromosome_1/'
WGBS_dir_2 = '/home/yarivl/thesis/WGBS_data/labels_per_chromosome_2/'
WGBS_dir_merged = '/home/yarivl/thesis/WGBS_data/labels_per_chromosome_merged/'
label_dict_dir = '/home/yarivl/thesis/WGBS_data/npzs/'

In [None]:
# Create the output directories for all processed files.
# exist_ok=True makes the cell safe to re-run and avoids the race between
# checking for a directory and creating it.
for _out_dir in (WGBS_dir_1, WGBS_dir_2, WGBS_dir_merged, label_dict_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Parameters for dataset creation; both are passed unchanged to the
# WGBSDataset methods below.
min_coverage = 3
interval_size = 1000000

# Create a dataset object on top of the nanopore data directory.
WGBS_dataset = WGBSDataset(fast5_dir = NS_dir)

# Filter each replicate's coverage file into its own output directory.
# NOTE(review): the exact meaning of upper_cutoff / lower_cutoff is defined
# in WGBSDataset.filter_WGBS — confirm there.
WGBS_dataset.filter_WGBS(coverage_dir_1, WGBS_dir_1, min_coverage, upper_cutoff = 90.0, lower_cutoff = 0.0)
WGBS_dataset.filter_WGBS(coverage_dir_2, WGBS_dir_2, min_coverage, upper_cutoff = 90.0, lower_cutoff = 0.0)

# Merge the two filtered replicates.
WGBS_dataset.merge_WGBS(WGBS_dir_1, WGBS_dir_2, WGBS_dir_merged, min_coverage)

# Create the label dictionaries from the merged data.
WGBS_dataset.prepare_WGBS(WGBS_dir_merged, label_dict_dir, min_coverage, interval_size)

### Create dataset object

In [8]:
# Build the nanopore dataset object from the data in NS_dir.
# NOTE(review): label_dict_dir is passed as None although a `label_dict_dir`
# path is defined in the setup cell — confirm that this is intended here.
dataset = NSDataset(fast5_dir = NS_dir, label_dict_dir = None, CpG = True) #17:10 appr 13 min

In [None]:
# Persist the dataset object so it does not have to be rebuilt.
# The object created in the previous cell is named `dataset`; `EC_dataset`
# is not defined in the visible notebook and would raise a NameError.
# NOTE(review): the "Load data" cell reads '.../dataset_E_coli.pt', not this
# file name — confirm which file is the intended one.
torch.save(dataset, 'E_coli_train_val_test_split/dataset.pt')

In [9]:
# Randomly assign every dataset index to the train / validation / test split
# with probabilities of roughly 0.8 / 0.1 / 0.1. No seed is set, so the
# split differs between runs (the indices are saved in the next cell).
# A length-based filter (lens[i] <= 5000 on the sorted dataset.lengths) was
# tried here at some point and is disabled.
train_inds, val_inds, test_inds = [], [], []

for i in range(len(dataset)):
    draw = random.random()
    if draw >= 0.9:
        test_inds.append(i)
    elif draw >= 0.8:
        val_inds.append(i)
    else:
        train_inds.append(i)

In [None]:
# Store the split indices so that the same split can be reused later.
# NOTE(review): these files go to 'ConvNeXt_E_coli/', while the "Load data"
# cell reads from 'E_coli_train_val_test_split/' — confirm the files are
# copied there or that the paths should match.
np.save('ConvNeXt_E_coli/train_inds.npy', train_inds)
np.save('ConvNeXt_E_coli/test_inds.npy', test_inds)
np.save('ConvNeXt_E_coli/val_inds.npy', val_inds)

## Necessary functions

In [13]:
#collate function for dataloader
def custom_collate(data):
    """Collate variable-length (sequence, target) pairs into padded batches.

    Args:
        data: list of ``(seq, target)`` pairs. Each ``seq`` is a tensor whose
            first dimension is the sequence length, followed by a feature
            dimension; each ``target`` is a tensor whose first dimension is
            the sequence length.

    Returns:
        A tuple ``(x, y)``:
        ``x`` has shape ``(batch, features, max_len)``; sequences shorter
        than the longest one in the batch are zero-padded at the end.
        ``y`` has shape ``(batch, max_len)`` for 1-D targets; padded
        positions carry the value ``2.0``, which is the value the evaluation
        cells mask out with ``y_lab != 2``.
    """
    #unpack data
    seqs = [seq for seq, _ in data]
    targets = [target for _, target in data]

    #pad data: (batch, max_len, features) -> (batch, features, max_len)
    x = pad_sequence(seqs, batch_first = True).transpose(1, 2)

    #pad labels with 2.0 so that padded positions are distinct from 0/1 labels
    y = pad_sequence(targets, batch_first = True, padding_value = 2.0)
    return x, y

## Load data

In [4]:
# Load the previously saved dataset object from disk.
# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source; presumably the NSDataset class must be importable for the
# object to be restored — confirm.
dataset = torch.load('/home/yarivl/thesis/E_coli_train_val_test_split/dataset_E_coli.pt')

In [5]:
# Load the stored train / validation / test index arrays.
# NOTE(review): the cell that saves the indices writes to 'ConvNeXt_E_coli/',
# not to this directory — confirm these are the intended files.
train_inds = np.load('E_coli_train_val_test_split/train_inds.npy')
val_inds = np.load('E_coli_train_val_test_split/val_inds.npy')
test_inds = np.load('E_coli_train_val_test_split/test_inds.npy')

In [11]:
# Sanity check: total number of indices across the three splits.
len(test_inds) + len(val_inds) + len(train_inds)

15184

In [10]:
# Restrict the full dataset to each split through its index array.
train = Subset(dataset, train_inds)
val = Subset(dataset, val_inds)
test = Subset(dataset, test_inds)

## Train model

In [None]:
# Create the dataloaders; custom_collate pads each batch to a common length.
# NOTE(review): train_loader does not pass shuffle=True, so the DataLoader
# default (no shuffling) applies and every epoch sees the same batch order —
# confirm this is intended.
train_loader = DataLoader(train, batch_size = 4, num_workers = 16, collate_fn = custom_collate)
val_loader = DataLoader(val, batch_size = 4, num_workers = 16, collate_fn = custom_collate)

# Model hyperparameters, passed positionally to Transformer_model below.
# Their exact meaning is defined in Transformer_model_window16.
dims = [7, 16, 32, 64, 64]
dropout = 0.2
heads = 4
window = 7

# Build the model.
model = Transformer_model(dims, dropout, heads, window)

# Train on GPU for at most 100 epochs, validating on val_loader.
# NOTE(review): devices is given as the string "1"; presumably this selects
# one GPU rather than GPU index 1 — confirm for the installed Lightning version.
trainer = pl.Trainer(accelerator = "gpu", devices = "1", max_epochs = 100)
trainer.fit(model, train_loader, val_loader)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name               | Type                   | Params
--------------------------------------------------------------
0 | tp                 | TransformerPreparation | 0     
1 | connection_layers  | ModuleList             | 6.9 K 
2 | transformer_layers | ModuleList             | 115 K 
3 | decoder            | Sequential             | 65    
--------------------------------------------------------------
122 K     Trainable params
0         Non-trainable params
122 K     Total params
0.490     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Store model

In [9]:
# Save the whole model object (not only its state_dict) with torch.save.
path = '/home/yarivl/thesis/final_models/Transformer_E_coli_2'
torch.save(model, path)

In [10]:
# Additionally write a Lightning checkpoint from the trainer; a file of this
# kind is what the "Resume training" cell passes to the Trainer.
trainer.save_checkpoint('/home/yarivl/thesis/final_models/Transformer_E_coli_2.ckpt')

## Resume training if needed

In [None]:
# Create the dataloaders; custom_collate pads each batch to a common length.
train_loader = DataLoader(train, batch_size = 4, num_workers = 16, collate_fn = custom_collate)
val_loader = DataLoader(val, batch_size = 4, num_workers = 16, collate_fn = custom_collate)

# Hyperparameters of the model being resumed.
# NOTE(review): these dims differ from the ones in the "Train model" cell and
# the checkpoint is Transformer_E_coli_1, not _2; the hyperparameters
# presumably have to match the checkpointed model — confirm.
dims = [7, 96, 192, 192, 384, 384]
dropout = 0.2
heads = 4
window = 7

# Build the model and a trainer that resumes from the given checkpoint.
# NOTE(review): resume_from_checkpoint is, as far as I know, deprecated in
# newer PyTorch Lightning releases in favour of trainer.fit(..., ckpt_path=...)
# — check against the installed version.
model = Transformer_model(dims, dropout, heads, window)
trainer = pl.Trainer(accelerator = "gpu", devices = "1", max_epochs = 100, resume_from_checkpoint = '/home/yarivl/thesis/final_models/Transformer_E_coli_1.ckpt')

# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model, train_loader, val_loader)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Restoring states from the checkpoint file at /home/yarivl/thesis/final_models/Transformer_E_coli_1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Restored all states from the checkpoint file at /home/yarivl/thesis/final_models/Transformer_E_coli_1.ckpt

  | Name               | Type                   | Params
--------------------------------------------------------------
0 | tp                 | TransformerPreparation | 0     
1 | connection_layers  | ModuleList             | 278 K 
2 | transformer_layers | ModuleList             | 4.5 M 
3 | decoder            | Sequential             | 385   
--------------------------------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.307    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]



## Evaluate model

### Calculate precision, recall, false positive rate, true positive rate

In [12]:
# Drop the reference to the model and release cached GPU memory.
# NOTE(review): the evaluation cells below use `model` again, so a model has
# to be re-created or loaded before they run; no such cell is visible here.
del model
torch.cuda.empty_cache()

In [15]:
# Collect predictions and labels over the test split and compute ROC / PR
# statistics. (Timing note from the original run: 17:33, approx. 3 min.)
model.eval()
y_true = []
y_hat = []

# Gradients are not needed for evaluation; no_grad avoids building the
# autograd graph for every forward pass.
with torch.no_grad():
    for i in range(len(test)):
        # Fetch the item once instead of indexing the dataset twice per sample.
        sample = test[i]

        #get predictions: swap the two dimensions and add a batch dimension,
        #mirroring the layout produced by custom_collate
        x = sample[0]
        x = torch.transpose(x, dim0 = 0, dim1 = 1)
        x = torch.unsqueeze(x, 0)
        y_pred = model(x)
        y_pred = y_pred.detach().numpy()[0]

        #get labels
        y_lab = np.array(sample[1])

        #get mask: positions labelled 2 are excluded from the evaluation
        mask = y_lab != 2

        y_true.append(y_lab[mask])
        y_hat.append(y_pred[mask])

y_true = np.concatenate(y_true)
y_hat = np.concatenate(y_hat)

#calculate precision, recall, false positive rate, true positive rate
fpr, tpr, thresholds_ROC = roc_curve(y_true, y_hat)
precision, recall, thresholds_PR = precision_recall_curve(y_true, y_hat)

#calculate area under the curve
AUC_ROC = metrics.auc(fpr, tpr)
AUC_PR = average_precision_score(y_true, y_hat)
print("AUC_ROC: " + str(AUC_ROC))
print("AUC_PR: " + str(AUC_PR))

# Keep copies under test-specific names; y_hat / y_true are reused by the
# validation cell further down.
y_hat_test = np.copy(y_hat)
y_true_test = np.copy(y_true)

AUC_ROC: 0.7527623462383244
AUC_PR: 0.7225570965933833


### Plot ROC curve

In [None]:
# Plot the ROC curve (true positive rate against false positive rate) with
# dashed guide lines at x = 0 and y = 1.
fig, ax = plt.subplots()
ax.set(
    xlabel='False positive rate',
    ylabel='True positive rate',
    title='Receiver Operating Characteristics curve',
)
guide_style = dict(color='black', linestyle='--', linewidth=1)
ax.axvline(x=0, **guide_style)
ax.axhline(y=1, **guide_style)
ax.plot(fpr, tpr)

plt.show()

### Plot PR curve

In [None]:
import matplotlib.pyplot as plt

# Plot the precision-recall curve with dashed guide lines at x = 1 and y = 1.
# (An overlay of the PR thresholds against recall was tried here and is
# disabled.)
fig, ax = plt.subplots()
ax.set(
    xlabel='Recall',
    ylabel='Precision',
    title='Precision-Recall curve',
)
guide_style = dict(color='black', linestyle='--', linewidth=1)
ax.axvline(x=1, **guide_style)
ax.axhline(y=1, **guide_style)
ax.plot(recall, precision)

plt.show()

In [17]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [18]:
# Turn the predicted scores into hard 0/1 predictions with np.round once,
# then score them against the labels.
y_hat_rounded = np.round(y_hat)
acc = accuracy_score(y_true, y_hat_rounded)
prec = precision_score(y_true, y_hat_rounded)
rec = recall_score(y_true, y_hat_rounded)

In [19]:
# Report the thresholded metrics computed in the previous cell.
print("accuracy: " + str(acc))
print("precision: " + str(prec))
print("recall: " + str(rec))

accuracy: 0.6893889572846702
precision: 0.6823890763256731
recall: 0.6829499047534101


In [20]:
# List the parameter count per layer and the total.
# NOTE(review): `test_model` is not defined in the visible notebook; the
# layer names in the output below (conv / pointwise_net) do not match the
# Transformer summary printed during training, so this presumably refers to
# a different model — confirm.
test_model.n_params()

[('layers.0.weight', 672),
 ('layers.0.bias', 96),
 ('layers.1.conv.weight', 1248),
 ('layers.1.conv.bias', 96),
 ('layers.1.norm.weight', 96),
 ('layers.1.norm.bias', 96),
 ('layers.1.pointwise_net.0.weight', 36864),
 ('layers.1.pointwise_net.0.bias', 384),
 ('layers.1.pointwise_net.3.weight', 36864),
 ('layers.1.pointwise_net.3.bias', 96),
 ('layers.2.weight', 10752),
 ('layers.2.bias', 112),
 ('layers.3.conv.weight', 1456),
 ('layers.3.conv.bias', 112),
 ('layers.3.norm.weight', 112),
 ('layers.3.norm.bias', 112),
 ('layers.3.pointwise_net.0.weight', 50176),
 ('layers.3.pointwise_net.0.bias', 448),
 ('layers.3.pointwise_net.3.weight', 50176),
 ('layers.3.pointwise_net.3.bias', 112),
 ('decoder.0.weight', 112),
 ('decoder.0.bias', 1),
 ('total', 190193)]

## Get predictions on validation set

In [None]:
# Collect predictions and labels over the validation split and compute
# ROC / PR statistics. (Timing note from the original run: 17:33, approx. 3 min.)
model.eval()
y_true = []
y_hat = []

# Gradients are not needed for evaluation; no_grad avoids building the
# autograd graph for every forward pass.
with torch.no_grad():
    for i in range(len(val)):
        # Fetch the item once instead of indexing the dataset twice per sample.
        sample = val[i]

        #get predictions: swap the two dimensions and add a batch dimension,
        #mirroring the layout produced by custom_collate
        x = sample[0]
        x = torch.transpose(x, dim0 = 0, dim1 = 1)
        x = torch.unsqueeze(x, 0)
        y_pred = model(x)
        y_pred = y_pred.detach().numpy()[0]

        #get labels
        y_lab = np.array(sample[1])

        #get mask: positions labelled 2 are excluded from the evaluation
        mask = y_lab != 2

        y_true.append(y_lab[mask])
        y_hat.append(y_pred[mask])

y_true = np.concatenate(y_true)
y_hat = np.concatenate(y_hat)

#calculate precision, recall, false positive rate, true positive rate
fpr, tpr, thresholds_ROC = roc_curve(y_true, y_hat)
precision, recall, thresholds_PR = precision_recall_curve(y_true, y_hat)

#calculate area under the curve
AUC_ROC = metrics.auc(fpr, tpr)
AUC_PR = average_precision_score(y_true, y_hat)
print("AUC_ROC: " + str(AUC_ROC))
print("AUC_PR: " + str(AUC_PR))

# Keep copies under validation-specific names for the storage cell below.
y_hat_val = np.copy(y_hat)
y_true_val = np.copy(y_true)

## Store test and validation set predictions

In [None]:
# Write the masked predictions and labels of the test and validation splits
# to disk. The target directory is not created here and must already exist.
np.save('/home/yarivl/thesis/results/Transformer/E_coli/y_hat_test.npy', y_hat_test)
np.save('/home/yarivl/thesis/results/Transformer/E_coli/y_true_test.npy', y_true_test)
np.save('/home/yarivl/thesis/results/Transformer/E_coli/y_hat_val.npy', y_hat_val)
np.save('/home/yarivl/thesis/results/Transformer/E_coli/y_true_val.npy', y_true_val)