In [51]:
import os
import torch
from collections import OrderedDict
from transformers import BartTokenizer, PretrainedConfig, BartForConditionalGeneration
from memformers.models.membart import MemBartForConditionalGeneration, MemBartModel

from memformers.models.membart.utils import get_model_config


In [54]:
# Build the model configuration: whatever get_model_config("membart-base.yaml")
# returns is handed to PretrainedConfig.from_dict to obtain a config object.
model_config = PretrainedConfig.from_dict(get_model_config("membart-base.yaml"))


In [55]:
def process_weights(state_dict):
    """Return a renamed copy of ``state_dict``.

    Every occurrence of the substring ``"recurrent_training_cell.cell."`` is
    removed from each key. Values are carried over untouched and the
    iteration order of the input is preserved.

    Args:
        state_dict: Mapping from parameter names to their values.

    Returns:
        OrderedDict: A new mapping with the cleaned-up keys; the input mapping
        itself is not modified.
    """
    wrapper_prefix = "recurrent_training_cell.cell."
    return OrderedDict(
        (name.replace(wrapper_prefix, ""), tensor) for name, tensor in state_dict.items()
    )


In [56]:
# Read the training checkpoint from disk, mapping every stored tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only point this at a checkpoint
# that comes from a trusted source.
state_dict = torch.load("./data/base/iter_156335_model_state.pth", map_location="cpu")


In [57]:
# Load the published qywu/membart-base checkpoint into MemBartModel.
# NOTE(review): `model` is rebound by the next cell before this instance is used
# anywhere, so this cell looks like a leftover check — confirm whether it is needed.
model = MemBartModel.from_pretrained("qywu/membart-base")


Some weights of the model checkpoint at qywu/membart-base were not used when initializing MemBartModel: ['final_logits_bias', 'lm_head.weight']
- This IS expected if you are initializing MemBartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MemBartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [58]:
# Instantiate MemBartForConditionalGeneration from the config only; the local
# checkpoint is loaded into it by the load_state_dict call in the following cell.
model = MemBartForConditionalGeneration(model_config)


In [59]:
# Rename the checkpoint keys with process_weights before loading them.
state_dict = process_weights(state_dict)
# strict=False: key mismatches are returned (and printed here) instead of raising.
print(model.load_state_dict(state_dict, strict=False))
# NOTE(review): presumably tie_weights() is what fills in the lm_head.weight that
# the load reports as missing — confirm.
model.tie_weights()


_IncompatibleKeys(missing_keys=['final_logits_bias', 'lm_head.weight'], unexpected_keys=[])


In [60]:
# Tokenizer from facebook/bart-base; the cells below use it to encode the prompts
# and to decode the generated sequences.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")


In [61]:
# Switch the model to evaluation mode before running the inference cells below.
model = model.eval()


In [65]:
# Sentence encoded first (t = 0); the memory states returned by that encoder pass
# are kept for the next step.
text1 = """
Barack Obama served as the 44th President of the United States.
"""

# Same sentence with the subject replaced by <mask>; encoded at t = 1 and used as
# the input for generation.
text2 = """
<mask> served as the 44th President of the United States.
"""


In [66]:
# Fresh memory to start from.
# NOTE(review): presumably the argument is the batch size — confirm.
memory_states = model.construct_memory(1)

# t = 0
# memory generation step: encode text1 and keep the memory the encoder returns.
# This is inference only, so run it without autograd; otherwise memory_states
# would carry the encoder's graph along into the following cells.
with torch.no_grad():
    input_ids = torch.LongTensor([tokenizer.encode(text1, add_special_tokens=True)])
    encoder_outputs = model.model.encoder(input_ids=input_ids, memory_states=memory_states, attention_mask=None)

memory_states = encoder_outputs.memory_states

# input_ids = torch.LongTensor([tokenizer.encode(text2,
#                                                add_special_tokens=True)])

# encoder_outputs = model.model.encoder(input_ids=input_ids,
#                     memory_states=memory_states,
#                     attention_mask=None)

# memory_states = encoder_outputs.memory_states

# # input_ids = torch.LongTensor([tokenizer.encode("g g h h i i",
# #                                                add_special_tokens=True)])

# # encoder_outputs = model.model.encoder(input_ids=input_ids,
# #                     memory_states=memory_states,
# #                     attention_mask=None)

# # memory_states = encoder_outputs.memory_states


In [68]:
# t = 1
# with the memory states produced at t = 0: the masked prompt (text2) is encoded
# together with `memory_states`, and generation runs from those encoder outputs.
# Inference only, so the encoder pass is run without autograd.
with torch.no_grad():
    input_ids2 = torch.LongTensor([tokenizer.encode(text2, add_special_tokens=True)])

    encoder_outputs2 = model.model.encoder(input_ids=input_ids2, memory_states=memory_states, attention_mask=None)

# Deterministic beam search (4 beams, no sampling) over at most 64 tokens.
outputs = model.generate(
    encoder_outputs=encoder_outputs2,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=64,
    num_beams=4,
    do_sample=False,
    return_dict_in_generate=True,
)

# Decode the first returned sequence, keeping special tokens visible.
tokenizer.decode(outputs.sequences[0])


'<s><s> Barack Obama served as the 44th President of the United States.\n</s>'

In [36]:
from transformers import T5Tokenizer, T5ForConditionalGeneration


In [37]:
# Load the stock facebook/bart-base sequence-to-sequence model.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")


In [38]:
# Prompt shared by the BART and T5 generation cells below: two facts followed by a
# sentence whose subject is replaced with the placeholder <extra_id_0>.
text = """
Cindy has a book called "A"

Mary has a book called "B"

<extra_id_0> has a book called "B"
"""


In [39]:
# Encode the prompt as a batch containing a single sequence of token ids.
input_ids = torch.tensor(
    [tokenizer.encode(text, add_special_tokens=True)],
    dtype=torch.long,
)

# Deterministic beam search with the BART model, capped at 32 tokens.
outputs = bart_model.generate(
    input_ids,
    num_beams=4,
    do_sample=False,
    max_length=32,
    decoder_start_token_id=tokenizer.bos_token_id,
    return_dict_in_generate=True,
)

# Show the first returned sequence, special tokens included.
tokenizer.decode(outputs.sequences[0])


'<s><s>Cindy has a book called "A" and "B" in her book.Mary has a phone number.<extra_id_</s>'

In [40]:
# NOTE(review): this rebinds `tokenizer` from the BART tokenizer to the t5-base one;
# any later cell that reads `tokenizer` gets the T5 tokenizer until the name is
# reassigned again.
tokenizer = T5Tokenizer.from_pretrained("t5-base")


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [96]:
# Load the pretrained t5-base model that is run on the same prompt in the next cell.
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")


In [97]:
# Encode the prompt (with the T5 tokenizer bound to `tokenizer` at this point) as a
# batch containing a single sequence.
input_ids = torch.LongTensor([tokenizer.encode(text, add_special_tokens=True)])

# The T5 tokenizer defines no BOS token, so `tokenizer.bos_token_id` is None here
# and generate() silently fell back to the model default. Pass the start token the
# model is actually configured with instead of a value copied from the BART cell.
outputs = t5_model.generate(
    input_ids,
    decoder_start_token_id=t5_model.config.decoder_start_token_id,
    max_length=32,
    num_beams=4,
    do_sample=False,
    return_dict_in_generate=True,
)

# Show the first returned sequence, special tokens included.
tokenizer.decode(outputs.sequences[0])


'<pad><extra_id_0> Mary<extra_id_1> Cindy has a book called "A" Mary has a book called "A" Mary<extra_id_2> Mary<extra_id_3> Cindy<extra_id_4>Cindy has a'

In [45]:
# upload a model to huggingface.co/models
# notebook_login() authenticates this session with the Hugging Face Hub before the
# push_to_hub calls below.

from huggingface_hub import notebook_login

notebook_login()


Token is valid.
Your token has been saved in your configured git credential helpers (cache).
Your token has been saved to /home/qywu/.cache/huggingface/token
Login successful


In [46]:
# Publish the current `model` to the qywu/membart-base repository on the Hub.
model.push_to_hub("qywu/membart-base")


pytorch_model.bin: 100%|██████████| 731M/731M [01:02<00:00, 11.8MB/s] 
Upload 1 LFS files: 100%|██████████| 1/1 [01:02<00:00, 62.17s/it]


CommitInfo(commit_url='https://huggingface.co/qywu/membart-base/commit/844d9fec91e004f9fad0708ff3a9b70074f91535', commit_message='Upload MemBartForConditionalGeneration', commit_description='', oid='844d9fec91e004f9fad0708ff3a9b70074f91535', pr_url=None, pr_revision=None, pr_num=None)

In [49]:
# Publish the tokenizer that goes with the MemBART model. `tokenizer` was rebound to
# the T5 tokenizer in an earlier cell, so pushing that name from a top-to-bottom run
# would upload the wrong vocabulary; load the BART tokenizer explicitly instead.
BartTokenizer.from_pretrained("facebook/bart-base").push_to_hub("qywu/membart-base")


CommitInfo(commit_url='https://huggingface.co/qywu/membart-base/commit/ce6654163ae033f54791fad168fc60605cafb10e', commit_message='Upload tokenizer', commit_description='', oid='ce6654163ae033f54791fad168fc60605cafb10e', pr_url=None, pr_revision=None, pr_num=None)

In [73]:
# Rebind `tokenizer` once more, this time to the facebook/bart-large tokenizer.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 10.6MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 8.43MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 26.0/26.0 [00:00<00:00, 12.8kB/s]


In [78]:
# Display the special tokens defined by the bart-large tokenizer loaded above.
tokenizer.special_tokens_map

{'bos_token': '<s>',
 'eos_token': '</s>',
 'unk_token': '<unk>',
 'sep_token': '</s>',
 'pad_token': '<pad>',
 'cls_token': '<s>',
 'mask_token': '<mask>'}