In [1]:
import os
import gc
import pickle
import argparse
import json
import shutil

import numpy as np
import pandas as pd

from tqdm import tqdm

import pysam

import torch
from torch.utils.data import DataLoader, IterableDataset
from transformers import PreTrainedTokenizerFast, DataCollatorForLanguageModeling

from DNABERT2.configuration_bert import BertConfig
from DNABERT2.bert_layers import BertForMaskedLM

import helpers.misc as misc                #miscellaneous functions
import helpers.train_eval as train_eval    #train and evaluation

%load_ext autoreload
%autoreload 2

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f'\nCUDA device: {torch.cuda.get_device_name(0)}\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
    #raise Exception('CUDA is not found')


CUDA device: Tesla V100S-PCIE-32GB



In [4]:
datadir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [5]:
class SeqDataset(IterableDataset):
    
    def __init__(self, fasta_fa, seq_df):
        
        if fasta_fa:
            self.fasta = pysam.FastaFile(fasta_fa)
        else:
             self.fasta = None

        self.seq_df = seq_df
        self.start = 0
        self.end = len(self.seq_df)
        
    def __len__(self):
        return len(self.seq_df)
                
    def __iter__(self):
        
        #worker_total_num = torch.utils.data.get_worker_info().num_workers
        #worker_id = torch.utils.data.get_worker_info().id
        
        for seq_idx in range(self.start,self.end):
            
            if self.fasta:
                seq = self.fasta.fetch(self.seq_df.iloc[seq_idx].seq_name).upper()
            else:
                seq = self.seq_df.iloc[seq_idx].seq.upper()
    
            #species_label = self.seq_df.iloc[idx].species_label
            
            seq = seq.replace('-','')

            tokenized_seq = tokenizer(seq, add_special_tokens=False)['input_ids']

            #N_tokens_overlap=np.random.randint(low=0,high=input_params.max_overlap_tokens)

            tokenized_chunks, _ = misc.get_chunks(tokenized_seq, 
                                                   N_tokens_chunk=input_params.max_tokens, 
                                                   N_tokens_overlap=input_params.max_overlap_tokens,
                                                   tokenizer_cls_token_id=tokenizer.cls_token_id,
                                                   tokenizer_eos_token_id=tokenizer.sep_token_id,
                                                   tokenizer_pad_token_id=tokenizer.pad_token_id,
                                                   padding=True)

            for tokenized_chunk in tokenized_chunks:

                attention_mask = [1 if token_id!=tokenizer.pad_token_id else 0 for token_id in tokenized_chunk]
                
                tokenized_chunk = {'input_ids':tokenized_chunk,
                                   'seq_idx':seq_idx,
                                   'token_type_ids':[0]*len(tokenized_chunk), 
                                   'attention_mask':attention_mask}
                
                yield tokenized_chunk
                        
    def close(self):
        self.fasta.close()

In [115]:
input_params = misc.dotdict({})

input_params.max_overlap_tokens = 25
input_params.max_tokens = 127
 
input_params.train_chunks = 8
input_params.batch_size = 16
input_params.weight_decay = 1e-5
input_params.max_lr = 1e-4
 
input_params.fasta = datadir + 'fasta/241_mammals.shuffled.fa'
#input_params.fasta = datadir + 'fasta/Homo_sapiens_rna.fa'
#input_params.fasta = datadir + 'fasta/Homo_sapiens_dna_fwd.fa'

input_params.output_dir = './test/'
input_params.tot_epochs = 10
input_params.val_fraction = 0.1
input_params.validate_every = 1
input_params.save_at = [2]
input_params.save_at = ['-1']

input_params.save_at = misc.list2range(input_params.save_at)

In [116]:
seq_df = pd.read_csv(input_params.fasta + '.fai', header=None, sep='\t', usecols=[0,1], names=['seq_name','seq_len'])

#seq_df = seq_df.iloc[:3000]

#seq_df['n_chunks'] = seq_df.seq_len.apply(lambda x:np.ceil(x / (input_params.max_tokens-2-input_params.overlap_bp))).astype(int)

seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')).apply(lambda x:x[1] if len(x)==2 else 'Homo_sapiens')
#seq_df['species_name']  = 'Homo_sapiens'

all_species = sorted(seq_df.species_name.unique())

if not input_params.species_agnostic:
    species_encoding = {species:idx for idx,species in enumerate(all_species)}
else:
    species_encoding = {species:0 for species in all_species}
    
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

#seq_df = seq_df.sample(frac = 1., random_state = 1) #DO NOT SHUFFLE, otherwise too slow

0it [02:54, ?it/s]
acc: 0.000343, masked acc: 0.00068, loss: 8.439:   0%|█                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             | 752/425668 [01:45<16:29:31,  7.16it/s]
acc: 0.000642, masked acc: 0.000783, loss: 8.441:   0%|▏                                                                                                                                                                                                                                

In [117]:
new_tokens = ["NNNNNN"]

tokenizer = PreTrainedTokenizerFast(tokenizer_file="./DNABERT2/tokenizer.json",
mask_token = '[MASK]', pad_token = '[PAD]', sep_token = '[SEP]', cls_token = '[CLS]', unk_token = '[UNK]',)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

tokenizer.add_tokens(new_tokens)

1

In [118]:
config = BertConfig.from_pretrained('./DNABERT2/config.json')

config.vocab_size = config.vocab_size + len(new_tokens)

config.max_position_embeddings = max(input_params.max_tokens,512)
config.alibi_starting_size = max(input_params.max_tokens,512)

model = BertForMaskedLM(config).to(device)

model.resize_token_embeddings(len(tokenizer))

Embedding(4097, 768, padding_idx=0)

In [119]:
#model_dir = './DNABERT2/'
#model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert2-3utr/checkpoints/epoch_18/'
#
#tokenizer = PreTrainedTokenizerFast(tokenizer_file=model_dir + "tokenizer.json",
#mask_token = '[MASK]', pad_token = '[PAD]', sep_token = '[SEP]', cls_token = '[CLS]', unk_token = '[UNK]',)
#config = BertConfig.from_pretrained(model_dir + 'config.json')
#model = BertForMaskedLM(config).to(device)
#
#collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

#model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert2-3utr-old/dnabert2-3utr/checkpoints/epoch_20/'

#test_dataset = SeqDataset(input_params.fasta, seq_df)
#test_dataloader = DataLoader(dataset = test_dataset, batch_size = 64, 
#                             num_workers = 1,  worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)
#
#val_metrics =  train_eval.model_eval(model, None, test_dataloader, device,
#                    silent = False)

In [120]:
collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

In [121]:
N_train = int(len(seq_df)*(1-input_params.val_fraction))       
train_df, test_df = seq_df.iloc[:N_train], seq_df.iloc[N_train:]

train_chunk = np.repeat(list(range(input_params.train_chunks)),repeats = N_train // input_params.train_chunks + 1 )
train_df['train_chunk'] = train_chunk[:N_train]

train_dataset = SeqDataset(input_params.fasta, train_df)
train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, 
                              num_workers = 1, worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)

test_dataset = SeqDataset(input_params.fasta, test_df)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, 
                             num_workers = 1,  worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['train_chunk'] = train_chunk[:N_train]


In [122]:
epoch=4
test_dataset.seq_df = train_df[train_df.train_chunk == (epoch-1) % input_params.train_chunks]
test_dataset.end = len(test_dataset.seq_df)

In [123]:
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 256, 
                             num_workers = 1,  worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)

In [None]:
val_metrics =  train_eval.model_eval(model, None, test_dataloader, device,
                    silent = False)

  0%|                                                                                                         …

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
#from transformers import AutoTokenizer
#
#pretrained_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/dnabert2/DNABERT-2-117M/'
#tokenizer = AutoTokenizer.from_pretrained(pretrained_dir,trust_remote_code=True)
#model = BertForMaskedLM.from_pretrained(pretrained_dir).to(device)
#
#val_metrics =  train_eval.model_eval(model, None, test_dataloader, device,
#                    silent = False)

In [14]:
gc.collect()
torch.cuda.empty_cache()

In [15]:
model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(model_params, betas=(0.9,0.98), eps=1e-6,
                             lr = input_params.max_lr, 
                             weight_decay = input_params.weight_decay)

In [16]:
#input_params.checkpoint_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/dnabert2-3utr/single_gpu/checkpoints/epoch_4/'

In [17]:
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr = 0, max_lr = input_params.max_lr,
                                             step_size_up = 30000, step_size_down = 470000, cycle_momentum=False,
                                             last_epoch = -1)

In [18]:
last_epoch = 0

if input_params.checkpoint_dir:

    model = BertForMaskedLM.from_pretrained(input_params.checkpoint_dir).to(device)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(input_params.checkpoint_dir)

    tokenizer = PreTrainedTokenizerFast.from_pretrained(input_params.checkpoint_dir)

    if os.path.isfile(input_params.checkpoint_dir + '/optimizer.pt'):
            optimizer.load_state_dict(torch.load(input_params.checkpoint_dir + '/optimizer.pt'))
            scheduler.load_state_dict(torch.load(input_params.checkpoint_dir + '/scheduler.pt'))

    last_epoch = int(input_params.checkpoint_dir.rstrip('/').split('_')[-1]) #infer previous epoch from input_params.checkpoint_dir

In [19]:
def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

In [20]:
from IPython.display import clear_output

clear_output()

#from utils.misc import print    #print function that displays time

if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        train_dataset.seq_df = train_df[train_df.train_chunk == (epoch-1) % input_params.train_chunks]
        train_dataset.end = len(train_dataset.seq_df)
        
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device, scheduler=scheduler,
                            silent = False)
        
        print(f'epoch {epoch} - train ({scheduler.last_epoch+1} iterations), {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at or -1 in input_params.save_at: #save model weights

            output_dir = misc.save_model_weights(model, optimizer, scheduler, weights_dir, epoch, input_params.save_at)
            _ = os.system('cp ./DNABERT2/*.py ' + output_dir) 
            
        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')
            

EPOCH 1: Training...
using train samples: [0, 425667]


acc: 0.000153, masked acc: 0.00104, loss: 8.257:   1%|████▋                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         | 3301/425668 [03:13<6:53:41, 17.02it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB (GPU 0; 19.50 GiB total capacity; 17.54 GiB already allocated; 281.88 MiB free; 18.58 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF