In [None]:
!mkdir sparse_vector

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# Paths under '/content/drive/My Drive/' are read later in this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
Mounted at /content/drive


In [None]:
import torch
from torch.utils import data
import random
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split, StratifiedKFold
from collections import Counter
import pandas as pd
import numpy as np
import scipy
from tqdm import trange
from tqdm import tqdm
from datetime import datetime
import sys
import os
import seaborn as sns
from matplotlib import pyplot as plt
from joblib import Parallel, delayed, dump, load
from matplotlib import pyplot as plt
from sparse_vector.sparse_vector import SparseVector
from scipy.signal import convolve2d, convolve
import time
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

import torch
from transformers import BertModel, BertConfig, PreTrainedTokenizer, AutoTokenizer, BertForTokenClassification, AutoModelForSequenceClassification
import collections

from transformers import utils

from torch.utils.data import DataLoader
import sklearn
from sklearn.metrics import accuracy_score
from torch.nn import CrossEntropyLoss

import gc


import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides usage/deprecation warnings raised by the libraries
# imported above — confirm that blanket suppression is intended.
warnings.filterwarnings("ignore")

In [None]:
def seq2kmer(seq, k):
    """
    Split a sequence into its overlapping k-mers (stride 1).

    Arguments:
    seq -- str, original sequence.
    k -- int, length of every k-mer.

    Returns:
    kmers -- str, the k-mers in order of their start position, separated
             by a single space; empty when seq is shorter than k.
    """
    # A sequence of length n contains n + 1 - k windows of length k.
    n_windows = len(seq) + 1 - k
    return " ".join(seq[start:start + k] for start in range(n_windows))

In [None]:
class PredDataset(data.Dataset):
    """Dataset that turns genomic intervals into tokenized k-mer sequences.

    Arguments:
    chroms -- chromosome names; stored on the instance but not used by this class.
    dna_source -- container indexed by chromosome name; each value is sliced
                  and ``.upper()`` is called on the slice.
    intervals -- sequence of items whose first three entries are
                 (chrom, begin, end).
    tokenizer -- object with an ``encode_plus`` method returning a mapping
                 that has an ``"input_ids"`` entry.
    k -- int, k-mer length (default 6, the value that used to be hard-coded).
    max_length -- int, token limit handed to the tokenizer (default 512).
    """

    def __init__(self, chroms, dna_source, intervals, tokenizer, k=6, max_length=512):

        self.chroms = chroms
        self.intervals = intervals
        self.tokenizer = tokenizer
        self.dna_source = dna_source
        self.k = k
        self.max_length = max_length


    def __len__(self):
        """Return the number of intervals."""
        return len(self.intervals)

    def __getitem__(self, index):
        """Return (token ids as a LongTensor, (chrom, begin, end)) for one interval."""
        interval = self.intervals[index]
        chrom = interval[0]
        begin = interval[1]
        end = interval[2]

        # Read k - 1 bases past `end` so that a window of (end - begin) bases
        # produces (end - begin) k-mers; the slice is simply shorter when it
        # runs off the end of the sequence.
        sequence = self.dna_source[chrom][begin:end + self.k - 1].upper()
        k_mers = seq2kmer(sequence, self.k)

        # Truncation is requested explicitly instead of relying on the
        # tokenizer's implicit fallback when only `max_length` is given.
        encoded_k_mers = self.tokenizer.encode_plus(
            k_mers,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True,
        )["input_ids"]

        return torch.LongTensor(encoded_k_mers), (chrom, begin, end)

In [None]:
# Stem of the pickled label file loaded below.
ASSEMBLY = "G4_cut_small"
# Chromosome names used throughout the notebook: chr1 .. chr22 followed by chrX.
chroms = [f'chr{i}' for i in [*range(1, 23), 'X']]
# Label container loaded with joblib; later indexed as G4DNA[chrom][:].
G4DNA = load(f'{ASSEMBLY}.pkl')

In [None]:
def chrom_reader(chrom, data_dir='/content/drive/My Drive/DeepZ_data_creation/data/hg19_dna/'):
    """Load every stored piece of one chromosome and join them into one string.

    Arguments:
    chrom -- str, chromosome name; files whose name contains f"{chrom}_" are read.
    data_dir -- str, directory holding the pieces (defaults to the path that
                used to be hard-coded twice in this function).

    Returns:
    str -- the loaded pieces concatenated in sorted file-name order.
    """
    files = sorted(name for name in os.listdir(data_dir) if f"{chrom}_" in name)
    # NOTE(review): sorted() orders names lexicographically, so a suffix "10"
    # sorts before "2"; confirm the piece names are zero-padded or otherwise
    # sort into genomic order.
    return ''.join(load(os.path.join(data_dir, name)) for name in files)


# Full sequence of every chromosome, keyed by chromosome name.
DNA = {chrom:chrom_reader(chrom) for chrom in tqdm(chroms)}

100%|██████████| 23/23 [00:07<00:00,  3.24it/s]


In [None]:
%%capture
!pip install einops transformers==4.27 peft omegaconf torch evaluate accelerate numpy scikit-learn Pillow textaugment scipy
!pip uninstall triton

In [None]:
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification, AutoModelForMaskedLM, AutoModelForSequenceClassification, BertForTokenClassification
import importlib

In [None]:
# Fetch the tokenizer and the base model for the 'zhihan1996/DNABERT-2-117M' checkpoint.
# trust_remote_code=True allows the checkpoint's own modelling code to be loaded.
# NOTE(review): no `revision` is pinned for the remote code — consider pinning one.
tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNABERT-2-117M')
model = AutoModel.from_pretrained('zhihan1996/DNABERT-2-117M', trust_remote_code=True)
# Name of the Python module that defines the class of the model just loaded.
gena_module_name = model.__class__.__module__
# Look up the BertForSequenceClassification class defined in that same module.
cls = getattr(importlib.import_module(gena_module_name), 'BertForSequenceClassification')
# Rebind `model` to a 2-label classifier built from the same checkpoint; the base
# model loaded above is only used to locate the module.
# NOTE(review): final_prediction indexes the logits as [:, :, 1], which needs a 3-D
# (per-token) tensor — confirm that this sequence-classification head is the intended class.
model = cls.from_pretrained('zhihan1996/DNABERT-2-117M', num_labels=2, output_attentions=True)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to b

In [None]:
# One boolean mask per chromosome, as long as its sequence; True marks positions
# covered by an interval of the BED file read below.
G4_kouzine = {}
for chrom in DNA:
    G4_kouzine[chrom] = np.zeros(len(DNA[chrom]), dtype = bool)


with open("/content/EndoQuad_hg19.bed") as f:
    for idx, line in enumerate(f):
        if idx == 0:
            # The first line of the file is skipped.
            continue
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        # Only the first three columns are used; any further columns are ignored.
        chrom, start, end = fields[:3]
        if chrom in G4_kouzine:
            G4_kouzine[chrom][int(start):int(end)] = 1

# Persist the masks with joblib; the returned file list is the cell's output.
dump(G4_kouzine, 'endo.pkl')

['endo.pkl']

In [None]:
# Reload the per-chromosome boolean masks from 'endo.pkl' (the file name used by the dump call above).
G4 = load('endo.pkl')

In [None]:
width = 512

np.random.seed(42)

ints_in, ints_out = [], []

# Tile every chromosome with consecutive windows of `width` positions.
# A window goes to ints_in (label 1) when the mask is set anywhere inside it,
# otherwise to ints_out (label 0). Items are [chrom, begin, end, label].
for chrm in chroms:
    chrom_len = G4[chrm].shape[0]
    # NOTE(review): the range stops before chrom_len - width, so no window starts
    # at or beyond that position and the min() below never clips — confirm that
    # leaving out the end of each chromosome is intended.
    for st in trange(0, chrom_len - width, width):
        interval = [st, min(st + width, chrom_len)]
        begin, end = int(interval[0]), int(interval[1])
        if G4[chrm][begin:end].any():
            ints_in.append([chrm, begin, end, 1])
        else:
            ints_out.append([chrm, begin, end, 0])

print(len(ints_in))
print(len(ints_out))

# Keep references to the complete lists before the negatives are subsampled.
ints_in_full, ints_out_full = ints_in, ints_out

100%|██████████| 486817/486817 [00:01<00:00, 276975.29it/s]
100%|██████████| 474998/474998 [00:02<00:00, 214787.85it/s]
100%|██████████| 386762/386762 [00:01<00:00, 257370.65it/s]
100%|██████████| 373348/373348 [00:01<00:00, 204120.06it/s]
100%|██████████| 353350/353350 [00:01<00:00, 279614.31it/s]
100%|██████████| 334209/334209 [00:01<00:00, 190704.34it/s]
100%|██████████| 310817/310817 [00:01<00:00, 287314.54it/s]
100%|██████████| 285867/285867 [00:00<00:00, 288950.08it/s]
100%|██████████| 275807/275807 [00:00<00:00, 286630.30it/s]
100%|██████████| 264716/264716 [00:00<00:00, 280676.67it/s]
100%|██████████| 263684/263684 [00:01<00:00, 161043.16it/s]
100%|██████████| 261429/261429 [00:00<00:00, 280302.68it/s]
100%|██████████| 224941/224941 [00:00<00:00, 279753.67it/s]
100%|██████████| 209667/209667 [00:00<00:00, 276833.80it/s]
100%|██████████| 200256/200256 [00:00<00:00, 285379.27it/s]
100%|██████████| 176474/176474 [00:00<00:00, 287586.97it/s]
100%|██████████| 158584/158584 [00:00<00

184812
5745459





In [None]:
# Keep every positive window and draw twice as many negatives without replacement.
ints_in = ints_in_full
n_negatives = len(ints_in) * 2
sampled_idx = np.random.choice(len(ints_out_full), size=n_negatives, replace=False)
ints_out = [ints_out_full[i] for i in sampled_idx]
# Alternative kept from the original cell: use all negatives instead of a sample.
# ints_out = ints_out_full

print(len(ints_in))
print(len(ints_out))

184812
369624


In [None]:
# Positive windows followed by the selected negative windows.
equalized = ints_in + ints_out

In [None]:
# Stratification key "<label>_<chrom>", built from elements 3 and 0 of each item.
strata = [f"{elem[3]}_{elem[0]}" for elem in equalized]
# Five shuffled folds with a fixed seed; each entry of `divisions` is a
# (train indices, test indices) pair into `equalized`.
folds = StratifiedKFold(5, shuffle=True, random_state=42)
divisions = list(folds.split(equalized, strata))

In [None]:
# Persist the interval list together with the fold indices (third argument is the joblib compression level).
# NOTE(review): a later cell loads 'hg_divisions_endo.pkl', not this file name — confirm which file is intended.
dump([equalized, divisions], 'hg_divisions_g4.pkl', 3)

['hg_divisions_kouzine_g4.pkl']

In [None]:
# Step between the starts of consecutive prediction windows.
width = 128
# Number of leading positions ignored in every window except the first.
pad = 192
# Not referenced by final_prediction.
k_mer_pad = 5

def final_prediction(chrom):
    """Predict a per-position score for one chromosome and dump it to disk.

    The chromosome is covered by 512-position windows that start every
    `width` positions. Each batch is passed through the global `model` on
    the global `device`; a softmax is taken over the last axis of the
    logits and the value at class index 1 is kept. The first window
    (begin == 0) is written in full; every other window skips its first
    `pad` positions, and later windows overwrite earlier ones where they
    overlap.

    Arguments:
    chrom -- str, key into the global DNA mapping.

    Returns:
    None. The float32 score vector is written with joblib to
    'pred_DNABERT2_<model class name>_<chrom>'.

    NOTE(review): the indexing below needs logits of shape
    (batch, tokens, classes) with one token per window position — confirm the
    loaded model and tokenizer provide that.
    NOTE(review): later cells load files named
    'new_mod_hg_<model_num>_<chrom>_DNABERT2', which this function does not
    write — confirm how those files are produced.
    """
    window = 512
    chrom_len = len(DNA[chrom])

    prediction = np.zeros(chrom_len, dtype=np.float32)

    intervals = [[chrom, st, min(st + window, chrom_len)]
                 for st in range(0, chrom_len - window, width)]

    pred_dataset = PredDataset(chroms, DNA, intervals,
                               tokenizer)

    params = {'batch_size':32, 'num_workers':5, 'shuffle':False}
    load_predict = data.DataLoader(pred_dataset, **params)

    model.to(device)
    with torch.no_grad():
        # Each batch carries the token ids plus the collated (chrom, begin, end)
        # fields; a separate name avoids shadowing the `intervals` list above.
        for input_ids, (_, begins, ends) in tqdm(load_predict):
            input_ids = input_ids.to(device)
            logits = model(input_ids = input_ids)['logits']
            outputs = torch.softmax(logits, dim = -1).cpu().numpy()[:,:,1]
            # Convert the collated begin/end tensors to plain ints for slicing.
            for ind, (begin, end) in enumerate(zip(begins.tolist(), ends.tolist())):
                if begin == 0:
                    prediction[begin:end] = outputs[ind]
                else:
                    prediction[begin+pad:end] = outputs[ind, pad:]

    # Use the class name in the file name: formatting the model object itself
    # would embed its entire multi-line repr in the path.
    dump(prediction, f'pred_DNABERT2_{model.__class__.__name__}_{chrom}', 3)

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not fail on a runtime without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.empty_cache()
model.to(device)
# Switch the model to evaluation mode; as the last expression of the cell the
# returned module is also displayed.
model.eval()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4096, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele

In [None]:
# Run the prediction for each chromosome in turn, iterating over a copy of the list.
for chrom in list(chroms):
    print(f"BEGIN CHROM {chrom}")
    final_prediction(chrom)

In [None]:
# Load the interval list and the fold indices, rebinding `equalized` and `divisions`.
# NOTE(review): the dump cell earlier in this notebook writes 'hg_divisions_g4.pkl';
# confirm that 'hg_divisions_endo.pkl' exists and holds the splits the models were trained with.
equalized, divisions = load('hg_divisions_endo.pkl')

In [None]:
# Total number of positions over all chromosomes.
com_len = sum(len(DNA[chrom]) for chrom in chroms)

# sums[c][m]: sum of the stored vector of model m on chromosome c.
sums = []
for chrom in tqdm(chroms):
    loc_sum = [load(f"new_mod_hg_{model_num}_{chrom}_DNABERT2").sum()
               for model_num in range(5)]
    sums.append(loc_sum)

# Per-model total divided by the number of positions, i.e. the mean value per model.
multipliers = np.array(sums).sum(axis=0) / com_len

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [06:08<00:00, 17.56s/it]


In [None]:
# Group the test intervals of every fold by chromosome once, instead of
# rebuilding the train/test interval lists for each chromosome.
test_ints_by_chrom = {chrom: [] for chrom in chroms}
for MODEL_NUMBER in range(5):
    _train_inds, test_inds = divisions[MODEL_NUMBER]
    for i in test_inds:
        inter = equalized[i]
        if inter[0] in test_ints_by_chrom:
            test_ints_by_chrom[inter[0]].append((MODEL_NUMBER, inter))

for chrom in tqdm(chroms):
    # Stack the five per-model vectors of this chromosome: shape (5, positions).
    vecs = np.array([load(f"new_mod_hg_{model_num}_{chrom}_DNABERT2")
                     for model_num in range(5)])
    # Rescale each model by its own multiplier, then by the mean multiplier.
    res_vec = (vecs / multipliers[:, None]) * multipliers.mean()
    mean_vec = res_vec.mean(axis=0)

    # Inside an interval that belongs to a fold's test split, use that fold's
    # model alone instead of the average over all five.
    for model_num, inters in test_ints_by_chrom[chrom]:
        mean_vec[inters[1]: inters[2]] = res_vec[model_num, inters[1]: inters[2]]
    dump(mean_vec, f"new_mod_hg_{chrom}_DNABERT2", 3)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [21:42<00:00, 62.04s/it]


In [None]:
# (start, end) pairs per chromosome, read from the first three columns of 'endo.bed'.
blacklist = {c:[] for c in chroms}
with open("endo.bed") as f:
    for line in f:
        fields = line.split()
        if not fields:
            # Tolerate blank lines.
            continue
        chrom, start, end = fields[:3]
        # Ignore chromosomes outside `chroms` rather than failing with a KeyError.
        if chrom in blacklist:
            blacklist[chrom].append((int(start), int(end)))

In [None]:
# Zero the listed ranges in every merged vector and write it back to the same file.
for chrom in tqdm(chroms):
    vec_path = f"new_mod_hg_{chrom}_DNABERT2"
    vec = load(vec_path)
    for s_idx, e_idx in blacklist[chrom]:
        vec[s_idx:e_idx] = 0
    dump(vec, vec_path, 3)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [15:19<00:00, 43.77s/it]


In [None]:
# Gather the stored score vector and the integer labels of every chromosome,
# in the order of `chroms`.
# NOTE(review): labels are taken from G4DNA, not from the G4 masks loaded from
# 'endo.pkl' — confirm this is the intended ground truth.
all_pred, all_true = [], []
for chrom in tqdm(chroms):
    pred_vec = load(f"new_mod_hg_{chrom}_DNABERT2")
    true_vec = G4DNA[chrom][:].astype(int)
    all_pred.append(pred_vec)
    all_true.append(true_vec)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [02:38<00:00,  7.56s/it]


In [None]:
# ROC AUC over all chromosomes, using the concatenated labels and scores.
roc_auc_score(np.concatenate(all_true), np.concatenate(all_pred))

0.8362092912637998

In [None]:
# Per-class precision/recall/F1 with scores thresholded at 0.5, printed with 4 decimals.
print(sklearn.metrics.classification_report(np.concatenate(all_true), np.concatenate(all_pred)>0.5, digits=4))

              precision    recall  f1-score   support

           0     0.9998    0.9970    0.9984 2654081885
           1     0.0365    0.3707    0.0664    813333

    accuracy                         0.9968 2654895218
   macro avg     0.5181    0.6839    0.5324 2654895218
weighted avg     0.9995    0.9968    0.9981 2654895218

