# 1. Setup

In [1]:
from transformers import BertConfig, BertForSequenceClassification, DNATokenizer
from transformers import glue_convert_examples_to_features as convert_examples_to_features
import torch
import numpy as np
from tqdm import tqdm, trange

# 2. Get DNABERT model

In [2]:
# FZZ: get model from MODEL_CLASSES
config_class, model_class, tokenizer_class = BertConfig, BertForSequenceClassification, DNATokenizer
label_list = ['0', '1']

# DNABERT checkpoint directory; config and weights are both loaded from here.
MODEL_DIR = '/mnt/ceph/users/zzhang/DNABERT/myExp/6-new-12w-0'

config = config_class.from_pretrained(
    MODEL_DIR,
    num_labels=len(label_list),
    finetuning_task='dnaprom',
    cache_dir=None,
)

config.hidden_dropout_prob = 0.1
config.attention_probs_dropout_prob = 0.1
# NOTE(review): the split/rnn settings below look copied from DNABERT's fine-tuning
# script; presumably they are only read by the long-sequence / RNN model variants and
# not by BertForSequenceClassification -- confirm before relying on them.
config.split = int(100/512)  # evaluates to 0
config.rnn = 'lstm'
config.num_rnn_layer = 2
config.rnn_dropout = 0
config.rnn_hidden = 768
config.output_hidden_states = True  # FZZ: hidden_states are consumed by the embedding-extraction cell

tokenizer = tokenizer_class.from_pretrained(
    'dna6',
    do_lower_case=False,
    cache_dir=None,
)
model = model_class.from_pretrained(
    MODEL_DIR,
    from_tf=False,
    config=config,
    cache_dir=None,
)

# The model is only used for inference in this notebook: switch off dropout
# (hidden/attention dropout are 0.1 above) so extracted embeddings are deterministic.
_ = model.eval()

<class 'transformers.tokenization_dna.DNATokenizer'>


Downloading:   0%|          | 0.00/8.10k [00:00<?, ?B/s]

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print(device)

cuda


In [4]:
import h5py
import os

# Directory containing the DeepSEA train.h5 / val.h5 files.
data_dir = "/mnt/ceph/users/zzhang/workspace_src/AMBER/examples/data/zero_shot_deepsea"
train_h5_path = os.path.join(data_dir, "train.h5")
store = h5py.File(train_h5_path, 'r')  # read-only handle; closed by the next cell

In [5]:
# Pull a single one-hot example, then release the HDF5 handle.
d = store['x'][101]
store.close()
# Decode with the DeepSEA column order (A, G, C, T); the string should match hg19 exactly.
''.join('AGCT'[col] for col in d.argmax(axis=1))

'TGCTGCTTTTTCCCCTTAGCCCTGGGCGAGGTCATCATAGAGGGGGAGTGGCAATGGCTCACAAGGTACTAGTGGAACCCCAGTAAGTTATCTCAGAGCCCGCTTAGAACACAAGTGCTACGTCCCCCAAAAGCTTTGCAATGAGTATCTGATGGGAACAAACTCAGTCAAGGACAGGCCCAGGTTGGGGCTTGCAGGCTGCAGATTCAGAATTGTTTATGAGATGGGAGCCATACTTTCTAACAACAAGACCTGAATTTCTCAATTTAATCCAAGTCGTGACTTAAGTTAGCGCCCTTCCGTTCCTCTATTACATTTCTGTTCGGCATGGATCAAATTGCCTACAAGGTGGAACAGATTTCAACTGCAATCTCTGAACCAGAAAATTCACTTATTCTCATGAAAGTTTGTAATCTTTGGAGAGTTGCTTAAACACTTAAAACCATCTTTCCTCTTTCTATACTCCAAACTTACCTGCTGCAATTTCTTGCTAAGAAGCAAAGTGCTATTTGCCTATTCCTATCTCTCTTTACCATCAGACACTCCTTAAGTTAAGAGCTAGATAATTCGCTCAGCCTCAGGCCAGGCCGAGCCTCACTCTAGAAGTCACATTCCTGAGGTGTAGGGGGTCAAAATGCCTCTCATTGTTCAGAAGCAGGTGAGGGGCCAGCCAGGGCACATCCTGCTCTCCAGGCTTGGTTCAGATAACTGTCAGCCCAGTTTTCAAGAGCACACACCAAAAATGCACCAAAGCTTACATCCATACAAACACCCGCACATGGATGTTTATGGAAGCTTATTTGTTTTTATTCATAATCACCCAAACTCAGAATCAACCAAGATGTCCTTCAGTAGATGAATGGATAAACTGTGGTGTGTCCAGGCAGTGGAATATTATTCAACGCAAAAAGAAATGAGCTATCAAGGCATGAAAAAATATGGGGGAACTTTAAATGCATAAATGAGTGAAAGAAGCCAGTCTGAAAAGGCTACACCCCG

In [6]:
from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures


class DeepSEA919Processor(DataProcessor):
    """Processor for the 2015 DeepSEA 919 multi-task data.

    One-hot sequences are read from ``train.h5`` / ``val.h5`` in ``data_dir``,
    decoded and centre-cropped to at most 512 bp, then turned into
    space-separated overlapping 6-mers for the DNABERT tokenizer. Only the
    first column of each label vector ``y`` is used, as a "0"/"1" label.
    """

    def get_labels(self):
        return ["0", "1"]

    def get_train_examples(self, data_dir):
        print("LOOKING AT {}".format(os.path.join(data_dir, "train.h5")))
        return self._create_examples(self._read_h5(os.path.join(data_dir, "train.h5")), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(self._read_h5(os.path.join(data_dir, "val.h5")), "dev")

    def _read_h5(self, fp):
        """Return an iterator of ``(sequence_string, label_vector)`` pairs from ``fp``.

        Both datasets are fully read into memory before the file is closed,
        so the returned iterator stays valid after the ``with`` block exits.
        """
        with h5py.File(fp, 'r') as store:
            # DeepSEA's one-hot column order is A, G, C, T (checked against hg19 above).
            seqs = [self._matrix_to_seq(x, letteridx='AGCT') for x in store['x'][()]]
            labels = store['y'][()]
        return zip(seqs, labels)

    @staticmethod
    def _matrix_to_seq(d, letteridx='ACGT'):
        """Decode a one-hot matrix (one row per position, presumably 4 columns) into a string.

        Sequences longer than ``MAX_LEN`` are cropped to their central ``MAX_LEN`` bases.
        """
        MAX_LEN = 512
        s = ''.join([letteridx[x] for x in d.argmax(axis=1)])
        slen = len(s)
        if slen > MAX_LEN:
            s = s[(slen//2-MAX_LEN//2) : (slen//2+MAX_LEN//2)]
        return s

    @staticmethod
    def _seq_to_kmers(seq, k):
        """Return every overlapping k-mer of ``seq`` joined by single spaces.

        A sequence of length L has L - k + 1 k-mers; sequences shorter than k give "".
        """
        return ' '.join(seq[j:j + k] for j in range(len(seq) - k + 1))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        ``lines`` yields ``(sequence_string, label_vector)`` pairs; only the
        first label of each vector is kept.
        """
        examples = []
        TOKEN_SIZE = 6
        for (i, (seq, labels)) in tqdm(enumerate(lines)):
            guid = "%s-%s" % (set_type, i)
            # Includes the final k-mer (L - k + 1 tokens in total).
            text_a = self._seq_to_kmers(seq, TOKEN_SIZE)
            # int() first so float or bool label arrays still map onto "0" / "1".
            label = str(int(labels[0]))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


In [7]:
# Convert every training sequence into a k-mer InputExample (slow: ~4.4M rows).
processor = DeepSEA919Processor()
examples = processor.get_train_examples(data_dir)
print("{}".format(len(examples)))

LOOKING AT /mnt/ceph/users/zzhang/workspace_src/AMBER/examples/data/zero_shot_deepsea/train.h5


4400000it [05:35, 13108.36it/s]

4400000





In [8]:
# Tokenisation / padding settings for the glue-style feature conversion.
max_length = 512
pad_on_left = False  # pad on the right (left padding is only needed for XLNet)
pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
pad_token_segment_id = 0
output_mode = 'classification'

# InputExample -> InputFeatures: adds token ids, attention masks and token type ids.
features = convert_examples_to_features(
    examples,
    tokenizer,
    label_list=label_list,
    max_length=max_length,
    output_mode=output_mode,
    pad_on_left=pad_on_left,
    pad_token=pad_token,
    pad_token_segment_id=pad_token_segment_id,
)

In [9]:
def _long_tensor(field):
    """Collect ``field`` from every InputFeatures object into one int64 tensor."""
    return torch.tensor([getattr(feat, field) for feat in features], dtype=torch.long)


all_input_ids = _long_tensor("input_ids")
all_attention_mask = _long_tensor("attention_mask")
all_token_type_ids = _long_tensor("token_type_ids")

In [10]:
BATCH_SIZE = 16

dataset = torch.utils.data.TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)
# shuffle=False (the default) keeps rows in file order, so output rows line up with train.h5.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

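Output of the full classifier = (loss, logits, hidden_states). The cell below instead calls the encoder directly via `model.bert(...)` and takes element `[2]` of its output, which holds `hidden_states` (returned because `config.output_hidden_states = True`).

`hidden_states` is a tuple of 13 tensors: the embedding-layer output followed by one tensor per encoder layer. Each tensor is indexed as follows:

Annotation:
- The batch index (up to 16 sequences per batch, the `DataLoader` batch size)
- The word / token index (512 positions, i.e. `max_length`)
- The hidden unit / feature index (768 features)

Stacking the tuple along `dim=1` therefore gives a tensor of shape `(batch, 13 layers, 512 tokens, 768 features)`.

See also: https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/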


In [11]:
def write_out(f, data):
    """Write each row of ``data`` to the text file ``f`` as comma-separated floats.

    Every value is formatted with 5 decimal places; each row ends with a newline.
    """
    f.writelines(",".join("%.5f" % value for value in row) + "\n" for row in data)

In [12]:
# One 768-d vector per sequence: average the last N_POOL_LAYERS hidden layers,
# then mean-pool over non-padding tokens. Rows are written in dataloader
# (= train.h5) order.
N_POOL_LAYERS = 4

# Inference only: disable dropout so embeddings are deterministic.
_ = model.eval()
data_iterator = tqdm(dataloader)

# no_grad: no autograd graph is needed, which saves GPU memory and time.
with open('deepsea_embedding.csv', 'w') as f, torch.no_grad():
    for batch in data_iterator:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        hidden_states = model.bert(input_ids=input_ids, attention_mask=attention_mask)[2]
        # Only stack the layers actually used: batch, N_POOL_LAYERS, seq_len, 768
        token_emb = torch.stack(hidden_states[-N_POOL_LAYERS:], dim=1)
        token_emb = token_emb.mean(dim=1)  # batch, seq_len, 768
        # Exclude padding positions from the per-sequence average.
        mask = attention_mask.unsqueeze(-1).to(token_emb.dtype)  # batch, seq_len, 1
        seq_emb = (token_emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # batch, 768
        write_out(f, seq_emb.cpu().numpy())

100%|██████████| 275000/275000 [13:09:47<00:00,  5.80it/s]  


In [13]:
# NOTE(review): this appears to be an earlier, one-example-at-a-time version of the
# batched embedding loop above; it is kept commented out for reference and is a
# candidate for deletion.
#outputs = []
#for i in tqdm(range(len(features))):
#    inputs = features[i]
#    outputs.append(model.bert(input_ids=torch.tensor([inputs.input_ids]), attention_mask = torch.tensor([inputs.attention_mask])))

Reload the saved embeddings and check that their row order still aligns with the input examples: