diff --git a/torchtext/experimental/models/__init__.py b/torchtext/experimental/models/__init__.py
index 2065636d77..6bd7bdc37a 100644
--- a/torchtext/experimental/models/__init__.py
+++ b/torchtext/experimental/models/__init__.py
@@ -1,3 +1,7 @@
+from .xlmr_model import xlmr_base, xlmr_regular, xlmr_base_sentence_classifier, \
+    xlmr_base_cross_lingual_mlm
 from .utils import count_model_param
 
-__all__ = ["count_model_param"]
+__all__ = ['xlmr_base', 'xlmr_regular',
+           'xlmr_base_sentence_classifier', 'xlmr_base_cross_lingual_mlm',
+           'count_model_param']
diff --git a/torchtext/experimental/models/utils.py b/torchtext/experimental/models/utils.py
index b24ddd0404..d5b508f9c8 100644
--- a/torchtext/experimental/models/utils.py
+++ b/torchtext/experimental/models/utils.py
@@ -1,4 +1,5 @@
 import torch
+from torchtext.utils import download_from_url
 
 
 def count_model_param(nn_model, unit=10**6):
@@ -20,3 +21,18 @@ def count_model_param(nn_model, unit=10**6):
     model_parameters = filter(lambda p: p.requires_grad, nn_model.parameters())
     params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
     return params.item() / unit
+
+
+def load_state_dict_from_url(url, overwrite=False, hash_value=None, hash_type="sha256"):
+    try:
+        # torch.hub verifies the hash embedded in the file name when check_hash=True.
+        return torch.hub.load_state_dict_from_url(url, check_hash=hash_value is not None)
+    except ImportError:
+        # Fall back to torchtext's own downloader when torch.hub is unavailable.
+        file_path = download_from_url(url, overwrite=overwrite, hash_value=hash_value, hash_type=hash_type)
+        return torch.load(file_path)
+
+
+def load_model_from_url(url, overwrite=False, hash_value=None, hash_type="sha256"):
+    file_path = download_from_url(url, overwrite=overwrite, hash_value=hash_value, hash_type=hash_type)
+    return torch.load(file_path)
diff --git a/torchtext/experimental/models/xlmr_model.py b/torchtext/experimental/models/xlmr_model.py
new file mode 100644
index 0000000000..054b832f37
--- /dev/null
+++ b/torchtext/experimental/models/xlmr_model.py
@@ -0,0 +1,174 @@
+import torch.nn as nn
+from .xlmr_transform import load_xlmr_transform
+from .utils import load_state_dict_from_url
+from torch.nn import Linear, LayerNorm, TransformerEncoder
+import torch.nn.functional as F
+from torchtext.experimental.modules import BertEmbedding, TransformerEncoderLayer
+
+
+class XLMRModel(nn.Module):
+    """XLM-R model: an embedding layer followed by a transformer encoder."""
+
+    def __init__(self, ntoken, embed_dim, nhead, feedforward_dim, nlayers, dropout=0.5):
+        super(XLMRModel, self).__init__()
+        self.xlmr_embed = BertEmbedding(ntoken, embed_dim, dropout)
+        encoder_layers = TransformerEncoderLayer(embed_dim, nhead, feedforward_dim, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+        self.embed_dim = embed_dim
+
+    def forward(self, src):
+        src = self.xlmr_embed(src)
+        output = self.transformer_encoder(src)
+        return output
+
+
+# [TODO] Add torch.hub support
+# [TODO] Download file from manifold
+# [TODO] Check base model config
+def xlmr_base():
+    '''
+    Examples:
+        >>> from torchtext.experimental.models import xlmr_base
+        >>> xlmr_base_model, xlmr_base_transform = xlmr_base()
+        >>> xlmr_base_transform('this is an example')
+        tensor([ 903, 83, 142, 27781])
+    '''
+    encoder = XLMRModel(250002, embed_dim=768, nhead=12, feedforward_dim=3072, nlayers=12, dropout=0.2)
+    encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.base'], hash_value=SHA256['xlmr.base']))
+    return encoder, load_xlmr_transform()
+
+
+def xlmr_regular():
+    '''
+    Examples:
+        >>> from torchtext.experimental.models import xlmr_regular
+        >>> xlmr_regular_model, xlmr_regular_transform = xlmr_regular()
+        >>> xlmr_regular_transform('this is an example')
+        tensor([ 903, 83, 142, 27781])
+    '''
+    encoder = XLMRModel(250002, embed_dim=1024, nhead=16, feedforward_dim=4096, nlayers=24, dropout=0.2)
+    encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.regular'], hash_value=SHA256['xlmr.regular']))
+    return encoder, load_xlmr_transform()
+
+
+PRETRAINED = {'xlmr.regular': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_regular-5626411f.pt',
+              'xlmr.base': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base-4e52e3b8.pt'}
+SHA256 = {'xlmr.regular': '5626411f5062c17b392725fcccd2fbc7f6df4b7d802279e4b65985bb01ed4480',
+          'xlmr.base': '4e52e3b861231d9cd3cb974ed5166294cd93e966169007662fcee68135dc0602'}
+
+##################################################################################
+# This part will be moved to the stl-text/models folder
+
+
+###########################
+# Sentence Classification
+###########################
+class SentenceClassificationHead(nn.Module):
+    """Head for sentence-level classification."""
+
+    def __init__(self, num_labels, embed_dim=768, dropout=0.2):
+        super(SentenceClassificationHead, self).__init__()
+        self.dense = nn.Linear(embed_dim, embed_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.out_proj = nn.Linear(embed_dim, num_labels)
+        self.activation = nn.Tanh()
+
+    def forward(self, input_features):
+        x = input_features[:, 0, :]  # The first token is reserved for [CLS]
+        x = self.dense(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+def sentence_classifier_head():
+    classifier = SentenceClassificationHead(10, embed_dim=768, dropout=0.2)  # num_labels must match the checkpoint
+    classifier.load_state_dict(load_state_dict_from_url(TASK_PRETRAINED['xlmr_base_sentence_classifier'],
+                                                        hash_value=TASK_SHA256['xlmr_base_sentence_classifier']))
+    return classifier
+
+
+class TransformerEncoderSentenceClassification(nn.Module):
+    def __init__(self, transformer_encoder, classifier_head):
+        super(TransformerEncoderSentenceClassification, self).__init__()
+        self.transformer_encoder = transformer_encoder
+        self.classifier_head = classifier_head
+
+    def forward(self, src):
+        raise NotImplementedError("The forward method has not been implemented yet.")
+
+
+def xlmr_base_sentence_classifier():
+    '''
+    Examples:
+        >>> from torchtext.experimental.models import xlmr_base_sentence_classifier
+        >>> xlmr_sentence_classifier_model, xlmr_base_transform = xlmr_base_sentence_classifier()
+        >>> xlmr_base_transform('this is an example')
+        tensor([ 903, 83, 142, 27781])
+    '''
+    # Load pretrained XLM-R
+    xlmr_model, xlmr_transform = xlmr_base()
+
+    # Load classifier head
+    sentence_classifier = sentence_classifier_head()
+    return TransformerEncoderSentenceClassification(xlmr_model, sentence_classifier), xlmr_transform
+
+
+###########################
+# Language Modeling
+###########################
+class CrossLingualMLMHead(nn.Module):
+    """Cross-lingual masked-LM head."""
+
+    def __init__(self, ntoken, embed_dim):
+        super(CrossLingualMLMHead, self).__init__()
+        self.mlm_span = Linear(embed_dim, embed_dim)
+        self.activation = F.gelu
+        self.norm_layer = LayerNorm(embed_dim, eps=1e-12)
+        self.mlm_head = Linear(embed_dim, ntoken)
+
+    def forward(self, src):
+        output = self.mlm_span(src)
+        output = self.activation(output)
+        output = self.norm_layer(output)
+        output = self.mlm_head(output)
+        return output
+
+
+def cross_lingual_mlm_head():
+    classifier = CrossLingualMLMHead(250002, 768)
+    # [TODO] Load the weight of LM head
+    return classifier
+
+
+class TransformerEncoderLanguageModeling(nn.Module):
+    """A transformer encoder followed by a language-modeling head."""
+
+    def __init__(self, transformer_encoder, lm_head):
+        super(TransformerEncoderLanguageModeling, self).__init__()
+        self.transformer_encoder = transformer_encoder
+        self.lm_head = lm_head
+
+    def forward(self, src):
+        output = self.transformer_encoder(src)
+        output = self.lm_head(output)
+        return output
+
+
+def xlmr_base_cross_lingual_mlm():
+    '''
+    Examples:
+        >>> from torchtext.experimental.models import xlmr_base_cross_lingual_mlm
+        >>> xlmr_lm_model, xlmr_base_transform = xlmr_base_cross_lingual_mlm()
+        >>> xlmr_base_transform('this is an example')
+        tensor([ 903, 83, 142, 27781])
+    '''
+    xlmr_model, xlmr_transform = xlmr_base()
+
+    lm_head = cross_lingual_mlm_head()
+    return TransformerEncoderLanguageModeling(xlmr_model, lm_head), xlmr_transform
+
+
+TASK_PRETRAINED = {'xlmr_base_sentence_classifier': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base_sentence_classifier-7e3fbb3f.pt'}
+TASK_SHA256 = {'xlmr_base_sentence_classifier': '7e3fbb3fac705df2be377d9e1cc198ce3a578172a17b1943e94fa2efe592f278'}
diff --git a/torchtext/experimental/models/xlmr_transform.py b/torchtext/experimental/models/xlmr_transform.py
new file mode 100644
index 0000000000..34844b2be5
--- /dev/null
+++ b/torchtext/experimental/models/xlmr_transform.py
@@ -0,0 +1,29 @@
+import torch.nn as nn
+from typing import List
+from .utils import load_model_from_url
+
+
+class XLMRTransform(nn.Module):
+    """XLM-R encode transform: raw text in, token ids out."""
+
+    def __init__(self, tokenizer, vocab):
+        super(XLMRTransform, self).__init__()
+        self.tokenizer = tokenizer
+        self.vocab = vocab
+
+    def forward(self, input_src: str) -> List[int]:
+        return self.vocab(self.tokenizer(input_src))
+
+
+def load_xlmr_transform():
+    tokenizer = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_sentencepiece'],
+                                    hash_value=TRANSFORM_SHA256['xlmr_sentencepiece'])
+    vocab = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_vocab'],
+                                hash_value=TRANSFORM_SHA256['xlmr_vocab'])
+    return XLMRTransform(tokenizer, vocab)
+
+
+TRANSFORM_PRETRAINED = {'xlmr_vocab': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_vocab-50081a8a.pt',
+                        'xlmr_sentencepiece': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_sentencepiece-d4797664.pt'}
+TRANSFORM_SHA256 = {'xlmr_vocab': '50081a8a69175ba2ed207eaf74f7055100aef3d8e737b3f0b26ee4a7c8fc781c',
+                    'xlmr_sentencepiece': 'd47976646f6be0ae29b0f5d5b8a7b1d6381e46694a54e8d29dd51671f1471b33'}
diff --git a/torchtext/experimental/modules/__init__.py b/torchtext/experimental/modules/__init__.py
new file mode 100644
index 0000000000..a622eab793
--- /dev/null
+++ b/torchtext/experimental/modules/__init__.py
@@ -0,0 +1,5 @@
+from .transformer import TransformerEncoderLayer
+from .embedding import PositionalEmbedding, BertEmbedding
+
+__all__ = ['TransformerEncoderLayer',
+           'PositionalEmbedding', 'BertEmbedding']
diff --git a/torchtext/experimental/modules/embedding.py b/torchtext/experimental/modules/embedding.py
new file mode 100644
index 0000000000..a9c5196306
--- /dev/null
+++ b/torchtext/experimental/modules/embedding.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+from torch.nn import Dropout, LayerNorm
+
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, embed_dim, max_len=514):
+        super(PositionalEmbedding, self).__init__()
+        self.pos_embedding = nn.Embedding(max_len, embed_dim)
+
+    def forward(self, x):
+        N, S = x.size()
+        pos = torch.arange(S, dtype=torch.long,
+                           device=x.device).unsqueeze(0).expand((N, S))
+        return self.pos_embedding(pos)
+
+
+class BertEmbedding(nn.Module):
+    def __init__(self, ntoken, embed_dim=768, dropout=0.5):
+        super(BertEmbedding, self).__init__()
+        self.embed_dim = embed_dim
+        self.ntoken = ntoken
+        self.pos_embed = PositionalEmbedding(embed_dim)
+        self.embed = nn.Embedding(ntoken, embed_dim)
+        self.norm = LayerNorm(embed_dim)
+        self.dropout = Dropout(dropout)
+
+    def forward(self, src):
+        # Both lookups consume the raw token indices of shape (batch, seq_len).
+        src = self.embed(src) + self.pos_embed(src)
+        return self.dropout(self.norm(src))
diff --git a/torchtext/experimental/modules/transformer.py b/torchtext/experimental/modules/transformer.py
new file mode 100644
index 0000000000..71c5b2f176
--- /dev/null
+++ b/torchtext/experimental/modules/transformer.py
@@ -0,0 +1,35 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Linear, Dropout, LayerNorm
+from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, embed_dim=768, nhead=12, feedforward_dim=3072,
+                 dropout=0.2, activation=F.gelu):
+        super(TransformerEncoderLayer, self).__init__()
+        in_proj_container = InProjContainer(Linear(embed_dim, embed_dim),
+                                            Linear(embed_dim, embed_dim),
+                                            Linear(embed_dim, embed_dim))
+        self.mha = MultiheadAttentionContainer(nhead, in_proj_container,
+                                               ScaledDotProduct(), Linear(embed_dim, embed_dim), batch_first=True)
+        self.linear1 = Linear(embed_dim, feedforward_dim)
+        self.dropout = Dropout(dropout)
+        self.linear2 = Linear(feedforward_dim, embed_dim)
+
+        self.norm1 = LayerNorm(embed_dim)
+        self.norm2 = LayerNorm(embed_dim)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        # [TODO] Add init_weights()
+
+    def forward(self, src, src_mask=None, src_key_padding_mask=None):
+        # src_key_padding_mask is accepted for API compatibility but not applied yet.
+        attn_output, attn_output_weights = self.mha(src, src, src, attn_mask=src_mask)
+        src = src + self.dropout1(attn_output)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
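
A quick way to smoke-test this patch end to end (a sketch, not part of the diff; it assumes the pretrained-weight URLs above are reachable and this branch is installed):

import torch
from torchtext.experimental.models import xlmr_base

# Download the pretrained encoder and its text transform (network access required).
model, transform = xlmr_base()
model.eval()

# The transform maps a raw string to token ids; add a batch dimension,
# since the encoder layers are built with batch_first=True.
ids = torch.tensor(transform('this is an example')).unsqueeze(0)
with torch.no_grad():
    output = model(ids)
print(output.shape)  # expected: torch.Size([1, 4, 768])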