In [1]:
import os
import gc
import pickle
import argparse
import json

import numpy as np
import pandas as pd

from tqdm import tqdm

import pysam

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset
from transformers import BertConfig, BertForMaskedLM

import helpers.misc as misc                #miscellaneous functions
import helpers.train_eval as train_eval    #train and evaluation
from DNABERT.src.transformers.tokenization_dna import DNATokenizer

%load_ext autoreload
%autoreload 2

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('\nCUDA device: GPU\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
    #raise Exception('CUDA is not found')


CUDA device: GPU



In [4]:
datadir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [5]:
def kmers_stride1(seq, k=6):
    # splits a sequence into overlapping k-mers
    return [seq[i:i + k] for i in range(0, len(seq)-k+1)]  

In [6]:
class SeqDataset(IterableDataset):
    
    def __init__(self, fasta_fa, seq_df):
        
        if fasta_fa:
            self.fasta = pysam.FastaFile(fasta_fa)
        else:
             self.fasta = None

        self.seq_df = seq_df
        self.start = 0
        self.end = len(self.seq_df)
        
    def __len__(self):
        return len(self.seq_df)
                
    def __iter__(self):
        
        #worker_total_num = torch.utils.data.get_worker_info().num_workers
        #worker_id = torch.utils.data.get_worker_info().id
        
        for seq_idx in range(self.start,self.end):
            
            if self.fasta:
                seq = self.fasta.fetch(self.seq_df.iloc[seq_idx].seq_name).upper()
            else:
                seq = self.seq_df.iloc[seq_idx].seq.upper()
    
            #species_label = self.seq_df.iloc[idx].species_label
            
            seq = seq.replace('-','')
            
            if len(seq)<6:
                    continue
                
            k_merized_seq = kmers_stride1(seq)

            tokenized_seq = tokenizer.encode_plus(k_merized_seq, add_special_tokens=False)['input_ids']

            #N_tokens_overlap=np.random.randint(low=0,high=input_params.max_overlap_tokens),
            
            tokenized_chunks, _ = misc.get_chunks(tokenized_seq, 
                                                   N_tokens_chunk=input_params.max_tokens, 
                                                   N_tokens_overlap=input_params.max_overlap_tokens,
                                                   tokenizer_cls_token_id=tokenizer.cls_token_id,
                                                   tokenizer_eos_token_id=tokenizer.sep_token_id,
                                                   tokenizer_pad_token_id=None,
                                                   padding=False)

            for tokenized_chunk in tokenized_chunks:
                tokenized_chunk = torch.LongTensor(tokenized_chunk)
                yield tokenized_chunk,seq_idx
                        
    def close(self):
        self.fasta.close()

In [7]:
input_params = misc.dotdict({})

input_params.max_tokens = 512
input_params.max_overlap_tokens = 128
input_params.mlm_probability = 0.15

input_params.train_chunks = 64
input_params.batch_size = 16
input_params.weight_decay = 0.01
input_params.max_lr = 4e-4
input_params.step_size_up = 10000
input_params.step_size_down = 200000

input_params.fasta = datadir + 'fasta/241_mammals.shuffled.fa'

input_params.output_dir = './test/'
input_params.tot_epochs = 30
input_params.val_fraction = 0.02
input_params.validate_every = 16
input_params.save_at = ['-1']

input_params.save_at = misc.list2range(input_params.save_at)

In [8]:
seq_df = pd.read_csv(input_params.fasta + '.fai', header=None, sep='\t', usecols=[0,1], names=['seq_name','seq_len'])

#seq_df = seq_df.iloc[:3000]

seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')).apply(lambda x:x[1] if len(x)==2 else 'Homo_sapiens')

all_species = sorted(seq_df.species_name.unique())

if not input_params.species_agnostic:
    species_encoding = {species:idx for idx,species in enumerate(all_species)}
else:
    species_encoding = {species:0 for species in all_species}
    
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

#seq_df = seq_df.sample(frac = 1., random_state = 1) #DO NOT SHUFFLE, otherwise too slow

In [9]:
tokenizer = DNATokenizer(vocab_file='./DNABERT/src/transformers/dnabert-config/bert-config-6/vocab.txt',
                        max_len=input_params.max_tokens)

In [10]:
#k_merized_chunk = kmers_stride1('AAAAAA')
#tokenized_chunk = tokenizer.encode_plus(k_merized_chunk,add_special_tokens=1)
#tokenized_chunk

In [11]:
config = BertConfig.from_pretrained('./DNABERT/src/transformers/dnabert-config/bert-config-6/config.json')

model = BertForMaskedLM(config).to(device)

In [12]:
#collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

In [13]:
def collate_fn(batch):
    examples, seq_idx = zip(*batch)
    seq_padded = pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
    return misc.mask_tokens(seq_padded, tokenizer, input_params.mlm_probability), seq_idx

In [14]:
N_train = int(len(seq_df)*(1-input_params.val_fraction))       
train_df, test_df = seq_df.iloc[:N_train], seq_df.iloc[N_train:]

train_chunk = np.repeat(list(range(input_params.train_chunks)),repeats = N_train // input_params.train_chunks + 1 )
train_df['train_chunk'] = train_chunk[:N_train]

train_dataset = SeqDataset(input_params.fasta, train_df)
train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, 
                              num_workers = 1, worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)

test_dataset = SeqDataset(input_params.fasta, test_df)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, 
                             num_workers = 1,  worker_init_fn=misc.worker_init_fn, collate_fn = collate_fn, shuffle = False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['train_chunk'] = train_chunk[:N_train]


In [15]:
gc.collect()
torch.cuda.empty_cache()

In [16]:
#pretrained_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/dnabert/default/6-new-12w-0/'
#model = BertForMaskedLM.from_pretrained(pretrained_dir).to(device)
#
#val_metrics =  train_eval.model_eval(model, None, test_dataloader, device,
#                    silent = False)

In [17]:
#pretrained_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert-3utr/scheduler-1/checkpoints/epoch_7/'
#model = BertForMaskedLM.from_pretrained(pretrained_dir).to(device)
#
#val_metrics =  train_eval.model_eval(model, None, test_dataloader, device,
#                    silent = False)

In [18]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": input_params.weight_decay,
    },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=input_params.max_lr, eps=1e-6, betas=(0.9,0.98))

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr = 0, max_lr = input_params.max_lr,
                                             step_size_up = input_params.step_size_up, step_size_down = input_params.step_size_down, cycle_momentum=False)

In [19]:
last_epoch = 0

if input_params.checkpoint_dir:

    model = BertForMaskedLM.from_pretrained(input_params.checkpoint_dir).to(device)
    
    if os.path.isfile(input_params.checkpoint_dir + '/opimizer.pt'):
            optimizer.load_state_dict(torch.load(input_params.checkpoint_dir + '/opimizer.pt'))
            scheduler.load_state_dict(torch.load(input_params.checkpoint_dir + '/scheduler.pt'))

weights_dir = os.path.join(input_params.output_dir, 'checkpoints') #dir to save model weights at save_at epochs

if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

In [20]:
#val_metrics =  train_eval.model_eval(model, optimizer, test_dataloader, device,
#                    silent = False)

In [21]:
def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

In [None]:
from IPython.display import clear_output

clear_output()

#from utils.misc import print    #print function that displays time

if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        train_dataset.seq_df = train_df[train_df.train_chunk == (epoch-1) % input_params.train_chunks]
        train_dataset.end = len(train_dataset.seq_df)
        
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')
        
        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device, scheduler=scheduler,
                            silent = False)
        
        print(f'epoch {epoch} - train ({scheduler.last_epoch+1} iterations), {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at or -1 in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, scheduler, weights_dir, epoch, input_params.save_at)

        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')
            