# 1. Setup

In [1]:
import os

import torch
from transformers import BertConfig, BertForSequenceClassification, DNATokenizer
from transformers import glue_convert_examples_to_features as convert_examples_to_features

# 2. Get DNABERT model

In [2]:
# Build the DNABERT config, tokenizer and fine-tuned sequence classifier.
# The (config, model, tokenizer) class triple mirrors the `MODEL_CLASSES` lookup (FZZ).
config_class, model_class, tokenizer_class = BertConfig, BertForSequenceClassification, DNATokenizer
label_list = ['0','1']

# Fine-tuned checkpoint directory. Set DNABERT_MODEL_DIR to point elsewhere instead of
# editing the hard-coded cluster path (which was previously duplicated below).
MODEL_DIR = os.environ.get('DNABERT_MODEL_DIR', '/mnt/ceph/users/zzhang/DNABERT/myExp/6-new-12w-0')
# Sequence length used to derive `config.split`; keep in sync with `max_length` in the feature cell.
MAX_SEQ_LENGTH = 100

config = config_class.from_pretrained(
    MODEL_DIR,
    num_labels=len(label_list),
    finetuning_task='dnaprom',
    cache_dir=None,
)

config.hidden_dropout_prob = 0.1
config.attention_probs_dropout_prob = 0.1
# Number of whole 512-token chunks: 0 for MAX_SEQ_LENGTH = 100 (same value as the old int(100/512)).
# NOTE(review): `split` and the `rnn*` settings look like they are only read by DNABERT's
# long-sequence / RNN heads — confirm whether BertForSequenceClassification uses them at all.
config.split = MAX_SEQ_LENGTH // 512
config.rnn = 'lstm'
config.num_rnn_layer = 2
config.rnn_dropout = 0
config.rnn_hidden = 768
# Return per-layer hidden states so they can be inspected in later cells (FZZ).
config.output_hidden_states = True

tokenizer = tokenizer_class.from_pretrained(
    'dna6',
    do_lower_case=False,
    cache_dir=None,
)
model = model_class.from_pretrained(
    MODEL_DIR,
    from_tf=False,
    config=config,
    cache_dir=None,
)

# Switch to inference mode (disables dropout); as the last expression it also displays the model repr.
model.eval()

<class 'transformers.tokenization_dna.DNATokenizer'>


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4101, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [3]:
from transformers.data.processors.glue import DnaPromProcessor

# DnaPromProcessor looks for "train.tsv" / "dev.tsv" under `data_dir` and wraps each row
# as an `InputExample`; the fields that matter here are text_a and text_b.
# TODO: DeepSEA data would need its own Processor class.
dev_processor = DnaPromProcessor()
examples = dev_processor.get_dev_examples(data_dir='sample_data/ft/6/')

# Feature-conversion settings.
max_length = 100
pad_on_left = False  # left padding is only needed for XLNet
pad_token_segment_id = 0
output_mode = 'classification'
pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]

# `InputExample` -> `InputFeatures` (token ids, attention mask, segment ids, label).
feature_kwargs = dict(
    label_list=label_list,
    max_length=max_length,
    output_mode=output_mode,
    pad_on_left=pad_on_left,
    pad_token=pad_token,
    pad_token_segment_id=pad_token_segment_id,
)
features = convert_examples_to_features(examples, tokenizer, **feature_kwargs)

In [4]:
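# BertForSequenceClassification is the model class used for fine-tuning classification/regression tasks.
# Its `model.bert` submodule is the bare BertModel encoder (see the repr above); the next cells call it directly.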

In [7]:
# Peek at the first dev example (guid, label, text_a as space-separated k-mers, text_b).
first_example = examples[0]
first_example

{
  "guid": "dev-1",
  "label": "0",
  "text_a": "GTGGGG TGGGGA GGGGAG GGGAGG GGAGGG GAGGGA AGGGAG GGGAGG GGAGGC GAGGCC AGGCCG GGCCGC GCCGCC CCGCCC CGCCCC GCCCCA CCCCAC CCCACT CCACTG CACTGC ACTGCA CTGCAG TGCAGG GCAGGT CAGGTG AGGTGG GGTGGG GTGGGC TGGGCC GGGCCT GGCCTG GCCTGT CCTGTA CTGTAG TGTAGC GTAGCA TAGCAG AGCAGC GCAGCT CAGCTG AGCTGC GCTGCA CTGCAC TGCACC GCACCT CACCTG ACCTGA CCTGAG CTGAGG TGAGGC GAGGCA AGGCAG GGCAGG GCAGGG CAGGGC AGGGCT GGGCTG GGCTGG GCTGGC CTGGCA TGGCAG GGCAGC GCAGCC CAGCCC AGCCCC GCCCCT CCCCTG CCCTGT CCTGTG CTGTGG TGTGGG GTGGGG TGGGGA GGGGAG GGGAGG GGAGGG GAGGGA AGGGAG GGGAGG GGAGGC GAGGCC AGGCCG GGCCGC GCCGCC CCGCCC CGCCCC GCCCCA CCCCAC CCCACT CCACTG CACTGC ACTGCA CTGCAG TGCAGG GCAGGT CAGGTG",
  "text_b": null
}

In [5]:
# Run only the BERT encoder (`model.bert`, no classification head) on the first dev feature.
# Wrapping each list in an outer list adds the batch dimension (batch size 1).
inputs = features[0]
input_ids = torch.tensor([inputs.input_ids])
attention_mask = torch.tensor([inputs.attention_mask])
# Inference only: torch.no_grad() avoids building the autograd graph and keeping activations for backward.
with torch.no_grad():
    outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)

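Annotation — `outputs` is the tuple `(sequence_output, pooled_output, hidden_states)`. The hidden states in `outputs[2]` are indexed by:
- Layer (13 tensors: the embedding output plus one per encoder layer)
- Batch (1 sequence)
- Token position (100 positions — the input is padded/truncated to `max_length = 100`)
- Hidden unit / feature (768 features)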


In [6]:
# Inspect the encoder output tuple: its length, the shapes of its first two entries,
# then how many hidden-state tensors it holds and the shape of the first one.
print(len(outputs))
for position in (0, 1):
    print(outputs[position].shape)
hidden_states = outputs[2]
print(len(hidden_states))
print(hidden_states[0].shape)

3
torch.Size([1, 100, 768])
torch.Size([1, 768])
13
torch.Size([1, 100, 768])
