# BioGPT for causal inference

In [1]:
from transformers import pipeline, set_seed
from transformers import BioGptTokenizer, BioGptForCausalLM
import torch

model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
set_seed(42)

In [3]:
generator("COVID-19 is", max_length=20, num_return_sequences=5, do_sample=True)

[{'generated_text': 'COVID-19 is a disease that spreads worldwide and is currently found in a growing proportion of the population'},
 {'generated_text': 'COVID-19 is one of the largest viral epidemics in the world.'},
 {'generated_text': 'COVID-19 is a common condition affecting an estimated 1.1 million people in the United States alone.'},
 {'generated_text': 'COVID-19 is a pandemic, the incidence has been increased in a manner similar to that in other'},
 {'generated_text': 'COVID-19 is transmitted via droplets, air-borne, or airborne transmission.'}]

# play around

In [3]:
with open("data/KD-DTI/raw/relis_test.x") as f:
    test_data = f.readlines()

test_line = test_data[0].lower()

In [26]:
# test_line_token = tokenizer.encode(test_line, return_tensors='pt')
a = 'COVID-19 is'
test_line_token = tokenizer.encode(a, add_special_tokens=False, return_tensors="pt")
# test_line_token = torch.cat((test_line_token, torch.tensor([[2]])), dim=1)

In [27]:
test_line_token.shape

torch.Size([1, 4])

In [22]:
test_line_token[0][-1]

tensor(4)

In [28]:
out = model(test_line_token)

In [29]:
out.logits.shape

torch.Size([1, 4, 42384])

In [40]:
test_line_token.shape

torch.Size([1, 4])

In [53]:
# concatenate the last token of the torch.argmax(out.logits, dim=2)[0] with the test_line_token[0]
input2 = torch.cat((test_line_token[0], torch.tensor([torch.argmax(out.logits, dim=2)[0][-1]])), dim=0)

In [55]:
input2

tensor([4805,    9,  656,   21,   14])

In [17]:
out2 = model(input2)

In [59]:
tokenizer.convert_ids_to_tokens(int(torch.argmax(out2.logits, dim=2)[0][-1]))

'</s>'

In [53]:
# out.past_key_values: (num_layers, 2, batch_size, num_heads, sequence_length, embed_size_per_head)
len(out.past_key_values[0][0][0][0][0])

64

In [4]:
text = "can not you see the star."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

In [5]:
output.logits.shape

torch.Size([1, 8, 42384])

In [6]:
encoded_input['input_ids'].shape

torch.Size([1, 8])

In [7]:
tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][-1].tolist())

['</s>',
 'can</w>',
 'not</w>',
 'you</w>',
 'see</w>',
 'the</w>',
 'star</w>',
 '.</w>']

In [8]:
output_token_ids = torch.argmax(output.logits[0, :, :], 1)

tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(output_token_ids.tolist()))

'The be be get the difference. </s>'

In [9]:
sentence = "COVID-19 is"
inputs = tokenizer(sentence, return_tensors="pt")

set_seed(42)

with torch.no_grad():
    beam_output = model.generate(**inputs,
                                 min_length=100,
                                 max_length=1024,
                                 num_beams=5,
                                 early_stopping=True
                                )
tokenizer.decode(beam_output[0], skip_special_tokens=True)

'COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK), and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and more than 800,000 deaths.'

# The structure of the model

In [10]:
model

BioGptForCausalLM(
  (biogpt): BioGptModel(
    (embed_tokens): Embedding(42384, 1024, padding_idx=1)
    (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
    (layers): ModuleList(
      (0): BioGptDecoderLayer(
        (self_attn): BioGptAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (1): BioGptDecoderLayer(
        (self_attn): BioGptAtten

In [11]:
import torch
from fairseq.models.transformer_lm import TransformerLanguageModel
m = TransformerLanguageModel.from_pretrained(
        "checkpoints/Pre-trained-BioGPT", 
        "checkpoint.pt", 
        "data",
        tokenizer='moses', 
        bpe='fastbpe', 
        bpe_codes="data/bpecodes",
        min_len=100,
        max_len_b=1024)
m.cuda()
src_tokens = m.encode("COVID-19 is")
generate = m.generate([src_tokens], beam=5)[0]
output = m.decode(generate[0]["tokens"])
print(output)

2023-03-23 03:23:40 | INFO | fairseq.file_utils | loading archive file checkpoints/Pre-trained-BioGPT
2023-03-23 03:23:40 | INFO | fairseq.file_utils | loading archive file data
2023-03-23 03:23:42 | INFO | fairseq.tasks.language_modeling | dictionary: 42384 types
2023-03-23 03:23:45 | INFO | fairseq.models.fairseq_model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': None, 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': False, 'suppress_crash

COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK), and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and more than 800,000 deaths.


# KD-DTI

In [4]:
import torch
from src.transformer_lm_prompt import TransformerLanguageModelPrompt
m = TransformerLanguageModelPrompt.from_pretrained(
        "checkpoints/RE-DTI-BioGPT", 
        "checkpoint_avg.pt", 
        "data/KD-DTI/relis-bin",
        tokenizer='moses', 
        bpe='fastbpe', 
        bpe_codes="data/bpecodes",
        max_len_b=1024,
        beam=1)
m.cuda()
src_text="" # input text, e.g., a PubMed abstract
src_tokens = m.encode(src_text)
generate = m.generate([src_tokens], beam=5)[0]
output = m.decode(generate[0]["tokens"])
print(output)

2023-04-01 02:57:54 | INFO | fairseq.file_utils | loading archive file checkpoints/RE-DTI-BioGPT
2023-04-01 02:57:54 | INFO | fairseq.file_utils | loading archive file data/KD-DTI/relis-bin
2023-04-01 02:58:20 | INFO | src.language_modeling_prompt | dictionary: 42384 types
2023-04-01 02:58:23 | INFO | fairseq.models.fairseq_model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': False, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': '../../src', 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': False

learned1 learned2 learned3 learned4 learned5 learned6 learned7 learned8 learned9 the interaction between cetirizine and histamine h1 receptor is antagonist; the interaction between diphenhydramine and histamine h1 receptor is antagonist.


The forward() func.

In [15]:
import inspect

print(inspect.getsource(model.forward))

    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *opti

# The number of the parameters

In [14]:
total_params = sum(p.numel() for p in model.parameters())
total_params

346763264

# The size of the model

In [12]:
def get_model_size(model):
    with torch.no_grad():
        total_params = sum(p.numel() for p in model.parameters())
        total_size = total_params * 4 # assuming float32
        for name, buf in model.named_buffers():
            total_size += buf.numel() * 4
        for name, buf in model.named_buffers():
            total_size += buf.numel() * 4
        for name, param in model.named_parameters():
            total_size += param.numel() * 4
    return total_size / (1024 ** 2)

In [13]:
get_model_size(model)

2645.59375

# Pre-process DTI dataset

In [1]:
# open the train, valid, and test data in data/KD-DTI/raw
import json
train_data = None
valid_data = None
test_data = None
for data in ["train", "valid", "test"]:
    with open("data/KD-DTI/raw/{}.json".format(data), "r") as f:
        if data == 'train':
            train_data = json.load(f)
            with open(f"data/KD-DTI/raw/original_{data}.json", "w") as new:
                json.dump(train_data, new)
        elif data == 'valid':
            valid_data = json.load(f)
            with open(f"data/KD-DTI/raw/original_{data}.json", "w") as new:
                json.dump(valid_data, new)
        else:
            test_data = json.load(f)
            with open(f"data/KD-DTI/raw/original_{data}.json", "w") as new:
                json.dump(test_data, new)
        


In [2]:
# open the csv file /home/tian/Projects/MyReaserch/DBid_to_names/drugbank vocabulary.csv, then read them into a dictionary. Ignore the first row, and take the colunm 0 as keys, 
# colum 2 as values.
import csv
with open('/home/tian/Projects/MyReaserch/DBid_to_names/drugbank vocabulary.csv', newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    next(reader)
    drugbank_id_to_name = {row[0]: row[2] for row in reader}

drugbank_id_to_name

{'DB00001': 'Lepirudin',
 'DB00002': 'Cetuximab',
 'DB00003': 'Dornase alfa',
 'DB00004': 'Denileukin diftitox',
 'DB00005': 'Etanercept',
 'DB00006': 'Bivalirudin',
 'DB00007': 'Leuprolide',
 'DB00008': 'Peginterferon alfa-2a',
 'DB00009': 'Alteplase',
 'DB00010': 'Sermorelin',
 'DB00011': 'Interferon alfa-n1',
 'DB00012': 'Darbepoetin alfa',
 'DB00013': 'Urokinase',
 'DB00014': 'Goserelin',
 'DB00015': 'Reteplase',
 'DB00016': 'Erythropoietin',
 'DB00017': 'Salmon calcitonin',
 'DB00018': 'Interferon alfa-n3',
 'DB00019': 'Pegfilgrastim',
 'DB00020': 'Sargramostim',
 'DB00022': 'Peginterferon alfa-2b',
 'DB00023': 'Asparaginase Escherichia coli',
 'DB00024': 'Thyrotropin alfa',
 'DB00025': 'Antihemophilic factor, human recombinant',
 'DB00026': 'Anakinra',
 'DB00027': 'Gramicidin D',
 'DB00028': 'Human immunoglobulin G',
 'DB00029': 'Anistreplase',
 'DB00030': 'Insulin human',
 'DB00031': 'Tenecteplase',
 'DB00032': 'Menotropins',
 'DB00033': 'Interferon gamma-1b',
 'DB00034': 'Inter

In [3]:
# open the json file /home/tian/Projects/MyReaserch/DBid_to_names/target vocabulary.json, make a dictionary with the key as the target id, and the value as the target name.
import json
with open('/home/tian/Projects/MyReaserch/DBid_to_names/target vocabulary.json') as f:
    target_id_to_name = json.load(f)

print(len(target_id_to_name))
# If the value of target_id_to_name is "This page doesn't exist. What a pain.", then ignore that line.

target_id_to_name = {k: v for k, v in target_id_to_name.items() if v != "This page doesn't exist. What a pain."}
print(f'After filtering: {len(target_id_to_name)}')



10215
After filtering: 5838


In [4]:
# for train_data, valid_data, and test_data, replace the drugbank id with the drug name, and the target id with the target name.
# After replacing, save the data into a new json file, respectively.
for data in [train_data, valid_data, test_data]:
    for item in data:
        for triple in data[item]['triples']:
            for drugs_and_targets in triple:
                if drugs_and_targets not in ['drug', 'target']:
                    continue
                else:
                    # print(triple[drugs_and_targets])
                    if not type(triple[drugs_and_targets]) is str:
                        raise TypeError(data[item])
                    if '####' in triple[drugs_and_targets]:
                        id = triple[drugs_and_targets].split('####')[1]
                        print(id)
                        if drugs_and_targets == 'drug':
                            triple[drugs_and_targets] = drugbank_id_to_name[id]
                        else:
                            triple[drugs_and_targets] = target_id_to_name[id]
    # if id is not None:
    #         break    
    
with open("data/KD-DTI/raw/train.json", "w") as f:
    json.dump(train_data, f)
with open("data/KD-DTI/raw/valid.json", "w") as f:
    json.dump(valid_data, f)
with open("data/KD-DTI/raw/test.json", "w") as f:
    json.dump(test_data, f)


DB01090
BE0003584
DB01090
BE0000411
DB01090
BE0003585
DB00081
BE0000066
DB00640
BE0000354
DB00640
BE0000241
DB00640
BE0000924
DB04908
BE0000451
DB04908
BE0000389
DB04908
BE0000291
DB04908
BE0000389
DB00115
BE0000777
DB00930
BE0004809
DB14089
BE0004813
DB00939
BE0000017
DB00991
BE0000017
DB00991
BE0000262
DB15495
BE0009786
DB15495
BE0009007
DB15496
BE0009786
DB15496
BE0009007
DB00115
BE0000320
DB00212
BE0000270
DB00857
BE0004670
DB00315
BE0000797
DB00311
BE0000267
DB00712
BE0000017
DB00712
BE0000262
DB06791
BE0003528
DB06791
BE0002147
DB00611
BE0000632
DB01207
BE0000759
DB01016
BE0000119
DB00997
BE0000742
DB00773
BE0002425
DB00315
BE0000797
DB00315
BE0000659
DB00315
BE0000460
DB00315
BE0000476
DB00315
BE0000291
DB00752
BE0002196
DB00752
BE0002198
DB00635
BE0000794
DB00688
BE0004520
DB01069
BE0000442
DB00427
BE0000442
DB00568
BE0000442
DB06372
BE0001051
DB06372
BE0004734
DB00115
BE0002176
DB00594
BE0000895
DB01262
BE0000892
DB00677
BE0002180
DB11726
BE0002341
DB00363
BE0000146
DB01427
BE

# Try DTI inference manually

In [1]:
import torch

model = torch.load("checkpoints/RE-DTI-BioGPT/checkpoint_avg.pt")

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
model.keys()

dict_keys(['args', 'cfg', 'model', 'criterion', 'optimizer_history', 'task_state', 'extra_state', 'last_optimizer_state'])

# manually KD-DTI

In [1]:
with open("data/KD-DTI/raw/relis_test.x") as f:
    test_data = f.readlines()

# test_data = [line.strip().split(" ") for line in test_data]


In [2]:
with open("data/KD-DTI/raw/relis_test.y") as f:
    test_data_y = f.readlines()

In [3]:
def find_the_max_length(data):
    max_length = 0
    whichone = 0
    for i, line in enumerate(data):
        if len(line) > max_length:
            max_length = len(line)
            whichone = i
    return max_length, whichone

find_the_max_length(test_data)

(9839, 1158)

In [4]:
import torch
from src.transformer_lm_prompt import TransformerLanguageModelPrompt
m = TransformerLanguageModelPrompt.from_pretrained(
        "checkpoints/RE-DTI-BioGPT", 
        "checkpoint_avg.pt", 
        "data/KD-DTI/relis-bin",
        tokenizer='moses', 
        bpe='fastbpe', 
        bpe_codes="data/bpecodes",
        max_len_b=1024,
        beam=5)
m.cuda()

  from .autonotebook import tqdm as notebook_tqdm
2023-04-07 21:20:11 | INFO | fairseq.file_utils | loading archive file checkpoints/RE-DTI-BioGPT
2023-04-07 21:20:11 | INFO | fairseq.file_utils | loading archive file data/KD-DTI/relis-bin
2023-04-07 21:20:13 | INFO | src.language_modeling_prompt | dictionary: 42384 types
2023-04-07 21:20:16 | INFO | fairseq.models.fairseq_model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': False, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': '../../src', 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_pat

GeneratorHubInterface(
  (models): ModuleList(
    (0): TransformerLanguageModelPrompt(
      (decoder): TransformerDecoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(42393, 1024, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0): TransformerDecoderLayerBase(
            (dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (activation_dropout_module): FairseqDropout()
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine

In [145]:
item = 20
test_x = test_data[item]
test_y = test_data_y[item]
src_text = test_x
src_tokens = m.encode(src_text)
"""
    def encode(self, sentence: str) -> torch.LongTensor:
        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)
        return self.binarize(sentence)"""
generate = m.generate([src_tokens], beam=5)[0]
output = m.decode(generate[0]["tokens"])
print(output)

inhibitory effects of a series of 7-substituted-indazoles toward nitric oxide synthases: particular potency of 1h-indazole-7-carbonitrile. a series of new 7-monosubstituted and 3,7-disubstituted indazoles have been prepared and evaluated as inhibitors of nitric oxide synthases (nos). 1h-indazole-7-carbonitrile (6) was found equipotent to 7-nitro-1h-indazole (1) and demonstrated preference for constitutive nos over inducible nos. by contrast, 1h-indazole-7-carboxamide (8) was slightly less potent but demonstrated a surprising selectivity for the neuronal nos. further substitution of 6 by a br-atom at carbon-3 of the heterocycle enhanced 10-fold the inhibitory effects. inhibition of no formation by 6 appeared to be competitive versus both substrate and the cofactor (6r) -5,6,7,8-tetrahydro-l-biopterin (h (4) b). in close analogies with 1, compound 6 strongly inhibited the nadph oxidase activity of nnos and induced a spin state transition of the heme-fe (iii). our results are explained wi

In [10]:
print(output.split("learned9 ", 1)[1])
print(test_y)

the interaction between indazole-7-carbonitrile and nitric-oxide synthase brain (nos1) is inhibitor; the interaction between indazole-7-carbonitrile and nitric-oxide synthase endothelial (nos3) is inhibitor; the interaction between indazole-7-carbonitrile and nitric-oxide synthase inducible (nos2) is inhibitor.
the interaction between 3-bromo-1h-indazole-7-carbonitrile and nitric-oxide synthase brain (nos1) is inhibitor.



In [8]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

count_parameters(m)

346772480

In [171]:
test_data_original = "Inhibition of rat brain monoamine oxidase activities by psoralen and isopsoralen: implications for the treatment of affective disorders. Psoralen and isopsoralen, furocoumarins isolated from the plant Psoralea corylifolia L., were demonstrated to exhibit in vitro inhibitory actions on monoamine oxidase (MAO) activities in rat brain mitochondria, preferentially inhibiting MAO-A activity over MAO-B activity. This inhibition of enzyme activities was found to be dose-dependent and reversible. For MAO-A, the IC50 values are 15.2 +/- 1.3 microM psoralen and 9.0 +/- 0.6 microM isopsoralen. For MAO-B, the IC50 values are 61.8 +/- 4.3 microM psoralen and 12.8 +/- 0.5 microM isopsoralen. Lineweaver-Burk transformation of the inhibition data indicates that inhibition by both psoralen and isopsoralen is non-competitive for MAO-A. The Ki values were calculated to be 14.0 microM for psoralen and 6.5 microM for isopsoralen. On the other hand, inhibition by both psoralen and isopsoralen is competitive for MAO-B. The Ki values were calculated to be 58.1 microM for psoralen and 10.8 microM for isopsoralen. These inhibitory actions of psoralen and isopsoralen on rat brain mitochondrial MAO activities are discussed in relation to their toxicities and their potential applications to treat affective disorders."
test_data_original = test_data_original.lower()
test_data_tokens = m.encode(test_data_original)
test_data_tokens

tensor([  468,     5,   366,   251, 12156,  3183,   619,    23, 34292, 25183,
            8, 15443, 10573, 25183,    20,  1468,    16,     6,    53,     5,
         6042,  1418,  1170,  5207, 17277, 34292, 25183,     8, 15443, 10573,
        25183,     7, 35045, 26361,  1578,   451,    29,     6,  1041, 34292,
          846,  5263,  4032,  1511, 23484,   688,     4,     7,    19,   301,
           13,  2171,    10,   307,  1166,  2247,    25, 12156,  3183,    12,
         3733,   399,    11,   619,    10,   366,   251,  2754,     7,  5899,
         3313,  3733,   399,     9,    14,    79,   222,  3733,   399,     9,
          787, 14960,  1432,  1066,     4,    35,   468,     5,   439,   619,
           17,    95,    13,    33,   228,     9,   267,     8, 14446,  1618,
        12024,     4,    16,  3733,   399,     9,    14,     7,     6,  1356,
          297,   358,    31, 17239,    51,    26,    81,  2865,  6831, 34292,
        25183,     8,  7160,    51,    26,    81,  3010,  6831, 

get the specific layer outputs

In [12]:
m.state_dict().keys()
m.models[0].decoder

odict_keys(['_float_tensor', 'models.0.decoder.version', 'models.0.decoder.embed_tokens.weight', 'models.0.decoder.embed_positions.weight', 'models.0.decoder.layers.0.self_attn.k_proj.weight', 'models.0.decoder.layers.0.self_attn.k_proj.bias', 'models.0.decoder.layers.0.self_attn.v_proj.weight', 'models.0.decoder.layers.0.self_attn.v_proj.bias', 'models.0.decoder.layers.0.self_attn.q_proj.weight', 'models.0.decoder.layers.0.self_attn.q_proj.bias', 'models.0.decoder.layers.0.self_attn.out_proj.weight', 'models.0.decoder.layers.0.self_attn.out_proj.bias', 'models.0.decoder.layers.0.self_attn_layer_norm.weight', 'models.0.decoder.layers.0.self_attn_layer_norm.bias', 'models.0.decoder.layers.0.fc1.weight', 'models.0.decoder.layers.0.fc1.bias', 'models.0.decoder.layers.0.fc2.weight', 'models.0.decoder.layers.0.fc2.bias', 'models.0.decoder.layers.0.final_layer_norm.weight', 'models.0.decoder.layers.0.final_layer_norm.bias', 'models.0.decoder.layers.1.self_attn.k_proj.weight', 'models.0.decod

In [172]:
prefix = torch.arange(42384, 42393)
test_data_tokens_prefix = torch.cat((test_data_tokens, prefix), dim=-1)
with torch.no_grad():
    out = m.models[0].decoder(test_data_tokens_prefix.unsqueeze(0))
# out: length = 2, 0 is the output sentence, 1 is the outputs of every layer
# out[1]:{'attn': None, 'inner_states': (length = 25)}
# out[1]['inner_states']: 25x 304x 1x 1024

In [179]:
torch.argmax(out[0][0][-1])

tensor(6)

In [204]:
# aa is the output of the last hidden layer
aa = m.models[0].decoder.output_projection(m.models[0].decoder.layer_norm(out[1]['inner_states'][-1]))
aa = aa.reshape(1, 304, -1)
aa == out[0]

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [181]:
import inspect

print(inspect.getsource(m.models[0].decoder.forward))

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying outpu

See the comparation of the outputs in:

examples/RE-DTI/generate_checkpoint_avg.pt.detok.extracted.eval_res.json