[RFC] Prototype pretrained models in torchtext #1136
`torchtext/experimental/models/__init__.py` — the new XLM-R factory functions are re-exported alongside the existing helper:

```python
from .xlmr_model import xlmr_base, xlmr_regular, xlmr_base_sentence_classifier, \
    xlmr_base_cross_lingual_mlm
from .utils import count_model_param

__all__ = ['xlmr_base', 'xlmr_regular',
           'xlmr_base_sentence_classifier', 'xlmr_base_cross_lingual_mlm',
           'count_model_param']
```
`torchtext/experimental/models/xlmr_model.py` (new file) — the XLM-R encoder and the two pretrained factory functions:

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, TransformerEncoder

from .xlmr_transform import load_xlmr_transform
from .utils import load_state_dict_from_url
from torchtext.experimental.modules import BertEmbedding, TransformerEncoderLayer


class XLMRModel(nn.Module):
    """XLM-R model: a transformer encoder plus an embedding layer."""

    def __init__(self, ntoken, embed_dim, nhead, feedforward_dim, nlayers, dropout=0.5):
        super(XLMRModel, self).__init__()
        self.xlmr_embed = BertEmbedding(ntoken, embed_dim, dropout)
        encoder_layers = TransformerEncoderLayer(embed_dim, nhead, feedforward_dim, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embed_dim = embed_dim

    def forward(self, src):
        src = self.xlmr_embed(src)
        output = self.transformer_encoder(src)
        return output


# [TODO] Add torch.hub support
# [TODO] Download file from manifold
# [TODO] check base model config
def xlmr_base():
    '''
    Examples:
        >>> from torchtext.experimental.models import xlmr_base
        >>> xlmr_base_model, xlmr_base_transform = xlmr_base()
        >>> xlmr_base_transform('this is an example')
        tensor([ 903, 83, 142, 27781])
    '''
    encoder = XLMRModel(250002, embed_dim=768, nhead=12, feedforward_dim=3072, nlayers=12, dropout=0.2)
    encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.base'], hash_value=SHA256['xlmr.base']))
    return encoder, load_xlmr_transform()


def xlmr_regular():
    '''
    Examples:
        >>> from torchtext.experimental.models import xlmr_regular
        >>> xlmr_regular_model, xlmr_regular_transform = xlmr_regular()
        >>> xlmr_regular_transform('this is an example')
        tensor([ 903, 83, 142, 27781])
    '''
    encoder = XLMRModel(250002, embed_dim=1024, nhead=16, feedforward_dim=4096, nlayers=24, dropout=0.2)
    encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.regular'], hash_value=SHA256['xlmr.regular']))
    return encoder, load_xlmr_transform()


PRETRAINED = {'xlmr.regular': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_regular-5626411f.pt',
              'xlmr.base': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base-4e52e3b8.pt'}
SHA256 = {'xlmr.regular': '5626411f5062c17b392725fcccd2fbc7f6df4b7d802279e4b65985bb01ed4480',
          'xlmr.base': '4e52e3b861231d9cd3cb974ed5166294cd93e966169007662fcee68135dc0602'}
```
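For context, a minimal sketch of how a returned (model, transform) pair composes end to end. The `torch.tensor(...)` wrapping and the explicit batch dimension are my assumptions: the transform's type hint says it returns `List[int]`, and the custom encoder layer below runs with `batch_first=True`.

```python
import torch
from torchtext.experimental.models import xlmr_base

model, transform = xlmr_base()
model.eval()  # disable dropout for inference

ids = torch.tensor(transform('this is an example')).unsqueeze(0)  # (N=1, S) token ids
with torch.no_grad():
    features = model(ids)  # (1, S, 768): one contextual vector per token
```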
Still in `xlmr_model.py`, the task-specific heads (flagged for a later move to the stl-text/models folder):

```python
##################################################################################
# This part will be moved to stl-text/models folder


###########################
# Sentence Classification
###########################
class SentenceClassificationHead(nn.Module):
    """Head for sentence-level classification."""

    def __init__(self, num_labels, embed_dim=768, dropout=0.2):
        super(SentenceClassificationHead, self).__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, num_labels)
        self.activation = nn.Tanh()

    def forward(self, input_features):
        x = input_features[:, 0, :]  # The first token is reserved for [CLS]
        x = self.dense(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


def sentence_classifier_head():
    classifier = SentenceClassificationHead(10, embed_dim=768, dropout=0.2)
    classifier.load_state_dict(load_state_dict_from_url(TASK_PRETRAINED['xlmr_base_sentence_classifier'],
                                                        hash_value=TASK_SHA256['xlmr_base_sentence_classifier']))
    return classifier


class TransformerEncoderSentenceClassification(nn.Module):
    def __init__(self, transformer_encoder, classifier_head):
        super(TransformerEncoderSentenceClassification, self).__init__()
        self.transformer_encoder = transformer_encoder
        self.classifier_head = classifier_head

    def forward(self, src):
        raise NotImplementedError("forward func has not been implemented yet.")


def xlmr_base_sentence_classifier():
    '''
    Examples:
        >>> from torchtext.experimental.models import xlmr_base_sentence_classifier
        >>> xlmr_sentence_classifier_model, xlmr_base_transform = xlmr_base_sentence_classifier()
        >>> xlmr_base_transform('this is an example')
        tensor([ 903, 83, 142, 27781])
    '''
    # Load pretrained XLM-R
    xlmr_model, xlmr_transform = xlmr_base()

    # Load classifier head
    sentence_classifier = sentence_classifier_head()
    return TransformerEncoderSentenceClassification(xlmr_model, sentence_classifier), xlmr_transform
```
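`TransformerEncoderSentenceClassification.forward` is still a stub; judging from the head's [CLS] pooling, the intended wiring is presumably what this sketch does by calling the submodules directly (my guess, not part of the PR):

```python
import torch
from torchtext.experimental.models import xlmr_base_sentence_classifier

classifier_model, transform = xlmr_base_sentence_classifier()
ids = torch.tensor(transform('this is an example')).unsqueeze(0)  # (1, S)

with torch.no_grad():
    features = classifier_model.transformer_encoder(ids)  # (1, S, 768)
    logits = classifier_model.classifier_head(features)   # (1, 10); the head pools token 0 itself
```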
The language-modeling head follows the same pattern, although its pretrained weights are not wired up yet:

```python
###########################
# Language Modeling
###########################
class CrossLingualMLMHead(nn.Module):
    """A cross-lingual MLM head."""

    def __init__(self, ntoken, embed_dim):
        super(CrossLingualMLMHead, self).__init__()
        self.mlm_span = Linear(embed_dim, embed_dim)
        self.activation = F.gelu
        self.norm_layer = LayerNorm(embed_dim, eps=1e-12)
        self.mlm_head = Linear(embed_dim, ntoken)

    def forward(self, src):
        output = self.mlm_span(src)
        output = self.activation(output)
        output = self.norm_layer(output)
        output = self.mlm_head(output)
        return output


def cross_lingual_mlm_head():
    classifier = CrossLingualMLMHead(250002, 768)
    # [TODO] Load the weight of LM head
    return classifier


class TransformerEncoderLanguageModeling(nn.Module):
    """A transformer encoder plus an LM head."""

    def __init__(self, transformer_encoder, lm_head):
        super(TransformerEncoderLanguageModeling, self).__init__()
        self.transformer_encoder = transformer_encoder
        self.lm_head = lm_head

    def forward(self, src):
        output = self.transformer_encoder(src)
        output = self.lm_head(output)
        return output


def xlmr_base_cross_lingual_mlm():
    '''
    Examples:
        >>> from torchtext.experimental.models import xlmr_base_cross_lingual_mlm
        >>> xlmr_lm_model, xlmr_base_transform = xlmr_base_cross_lingual_mlm()
        >>> xlmr_base_transform('this is an example')
        tensor([ 903, 83, 142, 27781])
    '''
    xlmr_model, xlmr_transform = xlmr_base()

    lm_head = cross_lingual_mlm_head()
    return TransformerEncoderLanguageModeling(xlmr_model, lm_head), xlmr_transform


TASK_PRETRAINED = {'xlmr_base_sentence_classifier': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base_sentence_classifier-7e3fbb3f.pt'}
TASK_SHA256 = {'xlmr_base_sentence_classifier': '7e3fbb3fac705df2be377d9e1cc198ce3a578172a17b1943e94fa2efe592f278'}
```
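Unlike the classifier, this composed module does implement `forward`, so a shape sketch is possible. Per the [TODO] above, the MLM head weights are still randomly initialized, so the logits are not meaningful yet; the tensor wrapping is the same assumption as before.

```python
import torch
from torchtext.experimental.models import xlmr_base_cross_lingual_mlm

lm_model, transform = xlmr_base_cross_lingual_mlm()
ids = torch.tensor(transform('this is an example')).unsqueeze(0)  # (1, S)

with torch.no_grad():
    logits = lm_model(ids)  # (1, S, 250002): per-token scores over the full vocabulary
```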
`torchtext/experimental/models/xlmr_transform.py` (new file) — the text-to-ids transform bundled with every factory function:

```python
import torch.nn as nn
from typing import List
from .utils import load_model_from_url


class XLMRTransform(nn.Module):
    """XLM-R encode transform."""

    def __init__(self, tokenizer, vocab):
        super(XLMRTransform, self).__init__()
        self.tokenizer = tokenizer
        self.vocab = vocab

    def forward(self, input_src: str) -> List[int]:
        return self.vocab(self.tokenizer(input_src))


def load_xlmr_transform():
    tokenizer = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_sentencepiece'],
                                    hash_value=TRANSFORM_SHA256['xlmr_sentencepiece'])
    vocab = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_vocab'],
                                hash_value=TRANSFORM_SHA256['xlmr_vocab'])
    return XLMRTransform(tokenizer, vocab)


TRANSFORM_PRETRAINED = {'xlmr_vocab': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_vocab-50081a8a.pt',
                        'xlmr_sentencepiece': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_sentencepiece-d4797664.pt'}
TRANSFORM_SHA256 = {'xlmr_vocab': '50081a8a69175ba2ed207eaf74f7055100aef3d8e737b3f0b26ee4a7c8fc781c',
                    'xlmr_sentencepiece': 'd47976646f6be0ae29b0f5d5b8a7b1d6381e46694a54e8d29dd51671f1471b33'}
```

Review thread on `load_xlmr_transform`:

Contributor (author): @datumbox Here is a pretrained transform case for the text domain.

Contributor: Sounds good. This should be supported fine for presets. The only question that I have is over the use of …

Contributor (author): Yes. For the model itself, we store weights. However, for those two transforms, they are not really …
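Because the transform is an nn.Module whose forward carries explicit type annotations, it looks designed with TorchScript in mind; a minimal sketch, assuming the serialized tokenizer and vocab artifacts are themselves scriptable modules:

```python
import torch

transform = load_xlmr_transform()
print(transform('this is an example'))  # token ids for the sentence

# the typed forward should make the end-to-end transform scriptable
scripted_transform = torch.jit.script(transform)
print(scripted_transform('this is an example'))  # matches the eager output above
```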
`torchtext/experimental/modules/__init__.py` — the building blocks the model file imports:

```python
from .transformer import TransformerEncoderLayer
from .embedding import PositionalEmbedding, BertEmbedding

__all__ = ['TransformerEncoderLayer',
           'PositionalEmbedding', 'BertEmbedding']
```
`torchtext/experimental/modules/embedding.py` (new file) — learned positional embeddings plus the BERT-style token embedding:

```python
import torch
import torch.nn as nn
from torch.nn import Dropout, LayerNorm


class PositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, max_len=514):
        super(PositionalEmbedding, self).__init__()
        self.pos_embedding = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        N, S = x.size()
        pos = torch.arange(S, dtype=torch.long,
                           device=x.device).unsqueeze(0).expand((N, S))
        return self.pos_embedding(pos)


class BertEmbedding(nn.Module):
    def __init__(self, ntoken, embed_dim=768, dropout=0.5):
        super(BertEmbedding, self).__init__()
        self.embed_dim = embed_dim
        self.ntoken = ntoken
        self.pos_embed = PositionalEmbedding(embed_dim)
        self.embed = nn.Embedding(ntoken, embed_dim)
        self.norm = LayerNorm(embed_dim)
        self.dropout = Dropout(dropout)

    def forward(self, src):
        src = self.embed(src) + self.pos_embed(src)
        return self.dropout(self.norm(src))
```
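A quick shape check of the embedding stack (my sketch, not part of the PR; note the learned positions cap the sequence length at max_len=514):

```python
import torch
from torchtext.experimental.modules import BertEmbedding

embed = BertEmbedding(ntoken=250002, embed_dim=768, dropout=0.2)
tokens = torch.randint(0, 250002, (2, 11))  # (N=2, S=11) batch of token ids
out = embed(tokens)
print(out.shape)  # torch.Size([2, 11, 768]): token + position, normalized, with dropout
```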
`torchtext/experimental/modules/transformer.py` (new file) — a batch-first encoder layer built from torchtext's attention container:

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Dropout, LayerNorm
from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct


class TransformerEncoderLayer(nn.Module):

    def __init__(self, embed_dim=768, nhead=12, feedforward_dim=3072,
                 dropout=0.2, activation=F.gelu):
        super(TransformerEncoderLayer, self).__init__()
        in_proj_container = InProjContainer(Linear(embed_dim, embed_dim),
                                            Linear(embed_dim, embed_dim),
                                            Linear(embed_dim, embed_dim))
        self.mha = MultiheadAttentionContainer(nhead, in_proj_container,
                                               ScaledDotProduct(), Linear(embed_dim, embed_dim), batch_first=True)
        self.linear1 = Linear(embed_dim, feedforward_dim)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(feedforward_dim, embed_dim)

        self.norm1 = LayerNorm(embed_dim)
        self.norm2 = LayerNorm(embed_dim)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.activation = activation
        # [TODO] Add init_weights()

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        attn_output, attn_output_weights = self.mha(src, src, src, attn_mask=src_mask)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
```

Review thread on the custom layer:

Contributor: Is there a reason we can't use torch.nn.TransformerEncoderLayer here?

Contributor (author): torch.nn.TransformerEncoderLayer uses the MHA in …
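The custom layer keeps torch.nn.TransformerEncoderLayer's forward signature, so the stock torch.nn.TransformerEncoder can stack it exactly as XLMRModel does above; a shape sketch under that assumption:

```python
import torch
from torch.nn import TransformerEncoder
from torchtext.experimental.modules import TransformerEncoderLayer

layer = TransformerEncoderLayer(embed_dim=768, nhead=12, feedforward_dim=3072, dropout=0.2)
encoder = TransformerEncoder(layer, num_layers=12)  # stock stacking of the custom layer

x = torch.randn(2, 11, 768)  # batch_first: (N, S, E)
y = encoder(x)
print(y.shape)  # torch.Size([2, 11, 768])
```

One detail worth flagging: the layer accepts src_key_padding_mask for signature compatibility but never forwards it to the attention container, so padding positions are currently attended to.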