Changes from all commits (36 commits):
2ccfaca  set up modules and models directory (Jan 30, 2021)
db27544  set up transformer encoder (Jan 31, 2021)
837c4d0  set up pretrain XLM-R model (Jan 31, 2021)
bea9ae4  add transform (Jan 31, 2021)
a78d286  download and archive files (Jan 31, 2021)
b570052  commit model (Jan 31, 2021)
bb8b842  move args to json file (Jan 31, 2021)
46f7ea3  udpate md5 (Jan 31, 2021)
cebf542  checkpoint (Jan 31, 2021)
f8305ea  checkpoint (Jan 31, 2021)
ea9720a  lint errors (Feb 1, 2021)
2a37567  checkpoint (Feb 1, 2021)
b0617ae  Add TODO things (Feb 2, 2021)
f98ff2b  wipe the forward func (Feb 2, 2021)
722eded  create factory func to build pretrained model (Feb 2, 2021)
4a610c7  update docs (Feb 2, 2021)
dea30f6  checkpoint (Feb 3, 2021)
ad6227d  checkpoint (Feb 3, 2021)
990e394  add sentence classification model (Feb 3, 2021)
b18672e  switch to models folder (Feb 5, 2021)
c5acb37  Merge branch 'master' into pretrained_prototype (Feb 5, 2021)
aadb315  change the name of the task class (Feb 8, 2021)
29f57fa  sync with master branch (Feb 10, 2021)
8472dc0  split model (Feb 10, 2021)
5116849  remove root directory (Feb 10, 2021)
9047f39  checkpoint (Feb 10, 2021)
02ea426  switch to torch.hub (Feb 10, 2021)
506cba8  add SHA256 check (Feb 10, 2021)
88b5d7a  remove load_args_from_json func (Feb 11, 2021)
0f7dfbc  switch to load_sentence_classifier_head (Feb 12, 2021)
6f1e6ab  add language modeling head (Feb 16, 2021)
f07f80c  activation in TransformerEncoderLayer (Feb 16, 2021)
94299fd  Merge branch 'master' into pretrained_prototype (Feb 17, 2021)
beafa3c  switch MHA to batch first (Feb 17, 2021)
64964e2  max length of PositionalEmbedding + Batch first (Feb 17, 2021)
e72ce94  update the pretrained model (Feb 17, 2021)
6 changes: 5 additions & 1 deletion torchtext/experimental/models/__init__.py
@@ -1,3 +1,7 @@
+from .xlmr_model import xlmr_base, xlmr_regular, xlmr_base_sentence_classifier, \
+    xlmr_base_cross_lingual_mlm
 from .utils import count_model_param

-__all__ = ["count_model_param"]
+__all__ = ['xlmr_base', 'xlmr_regular',
+           'xlmr_base_sentence_classifier', 'xlmr_base_cross_lingual_mlm',
+           'count_model_param']
17 changes: 17 additions & 0 deletions torchtext/experimental/models/utils.py
@@ -1,4 +1,5 @@
import torch
from torchtext.utils import download_from_url


def count_model_param(nn_model, unit=10**6):
@@ -20,3 +21,19 @@ def count_model_param(nn_model, unit=10**6):
model_parameters = filter(lambda p: p.requires_grad, nn_model.parameters())
params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
return params.item() / unit


def load_state_dict_from_url(url, overwrite=False, hash_value=None, hash_type="sha256"):
try:
if hash_value:
return torch.hub.load_state_dict_from_url(url, check_hash=True)
else:
return torch.hub.load_state_dict_from_url(url, check_hash=False)
except ImportError:
file_path = download_from_url(url, hash_value=hash_value, hash_type=hash_type)
return torch.load(file_path)


def load_model_from_url(url, overwrite=False, hash_value=None, hash_type="sha256"):
file_path = download_from_url(url, overwrite=overwrite, hash_value=hash_value, hash_type=hash_type)
return torch.load(file_path)
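
For illustration, a minimal sketch of how the two helpers differ (the URLs and the Linear module below are placeholders, not part of this diff): load_state_dict_from_url yields a state dict to be loaded into an already-constructed nn.Module, while load_model_from_url returns whatever whole object was serialized with torch.save.

# Hypothetical usage sketch; URLs and the module are stand-ins.
import torch.nn as nn
from torchtext.experimental.models.utils import load_state_dict_from_url, load_model_from_url

encoder = nn.Linear(768, 768)  # stand-in for a real encoder
state = load_state_dict_from_url('https://example.com/encoder-weights.pt')
encoder.load_state_dict(state)  # weights only; the class definition lives in code

transform = load_model_from_url('https://example.com/whole-transform.pt')
# 'transform' is the entire pickled object (e.g. a tokenizer), not a state dict.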
174 changes: 174 additions & 0 deletions torchtext/experimental/models/xlmr_model.py
@@ -0,0 +1,174 @@
import torch.nn as nn
from .xlmr_transform import load_xlmr_transform
from .utils import load_state_dict_from_url
from torch.nn import Linear, LayerNorm, TransformerEncoder
import torch.nn.functional as F
from torchtext.experimental.modules import BertEmbedding, TransformerEncoderLayer


class XLMRModel(nn.Module):
"""XLM-R model: a transformer encoder + embedding layer."""

def __init__(self, ntoken, embed_dim, nhead, feedforward_dim, nlayers, dropout=0.5):
super(XLMRModel, self).__init__()
self.xlmr_embed = BertEmbedding(ntoken, embed_dim, dropout)
encoder_layers = TransformerEncoderLayer(embed_dim, nhead, feedforward_dim, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.embed_dim = embed_dim

def forward(self, src):
src = self.xlmr_embed(src)
output = self.transformer_encoder(src)
return output


# [TODO] Add torch.hub support
# [TODO] Download file from manifold
# [TODO] check base model config
def xlmr_base():
'''
Examples:
>>> from torchtext.experimental.models import xlmr_base
>>> xlmr_base_model, xlmr_base_transform = xlmr_base()
>>> xlmr_base_transform('this is an example')
tensor([ 903, 83, 142, 27781])
'''
encoder = XLMRModel(250002, embed_dim=768, nhead=12, feedforward_dim=3072, nlayers=12, dropout=0.2)
encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.base'], hash_value=SHA256['xlmr.base']))
return encoder, load_xlmr_transform()


def xlmr_regular():
'''
Examples:
>>> from torchtext.experimental.models import xlmr_regular
>>> xlmr_regular_model, xlmr_regular_transform = xlmr_regular()
>>> xlmr_regular_transform('this is an example')
tensor([ 903, 83, 142, 27781])
'''
encoder = XLMRModel(250002, embed_dim=1024, nhead=16, feedforward_dim=4096, nlayers=24, dropout=0.2)
encoder.load_state_dict(load_state_dict_from_url(PRETRAINED['xlmr.regular'], hash_value=SHA256['xlmr.regular']))
return encoder, load_xlmr_transform()


PRETRAINED = {'xlmr.regular': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_regular-5626411f.pt',
'xlmr.base': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base-4e52e3b8.pt'}
SHA256 = {'xlmr.regular': '5626411f5062c17b392725fcccd2fbc7f6df4b7d802279e4b65985bb01ed4480',
'xlmr.base': '4e52e3b861231d9cd3cb974ed5166294cd93e966169007662fcee68135dc0602'}

##################################################################################
# This part will be moved to stl-text/models folder


###########################
# Sentence Classification
###########################
class SentenceClassificationHead(nn.Module):
"""Head for sentence-level classification."""

def __init__(self, num_labels, embed_dim=768, dropout=0.2):
super(SentenceClassificationHead, self).__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(embed_dim, num_labels)
self.activation = nn.Tanh()

def forward(self, input_features):
x = input_features[:, 0, :] # The first token is reserved for [CLS]
x = self.dense(x)
x = self.activation(x)
x = self.dropout(x)
x = self.out_proj(x)
return x


def sentence_classifier_head():
classifier = SentenceClassificationHead(10, embed_dim=768, dropout=0.2)
classifier.load_state_dict(load_state_dict_from_url(TASK_PRETRAINED['xlmr_base_sentence_classifier'],
hash_value=TASK_SHA256['xlmr_base_sentence_classifier']))
return classifier


class TransformerEncoderSentenceClassification(nn.Module):
def __init__(self, transformer_encoder, classifier_head):
super(TransformerEncoderSentenceClassification, self).__init__()
self.transformer_encoder = transformer_encoder
self.classifier_head = classifier_head

def forward(self, src):
raise NotImplementedError("forward func has not been implemented yet.")


def xlmr_base_sentence_classifier():
'''
Examples:
>>> from torchtext.experimental.models import xlmr_base_sentence_classifier
>>> xlmr_sentence_classifier_model, xlmr_base_transform = xlmr_base_sentence_classifier()
>>> xlmr_base_transform('this is an example')
tensor([ 903, 83, 142, 27781])
'''
# Load pretrained XLM-R
xlmr_model, xlmr_transform = xlmr_base()

# Load classifier head
sentence_classifier = sentence_classifier_head()
return TransformerEncoderSentenceClassification(xlmr_model, sentence_classifier), xlmr_transform


###########################
# Language Modeling
###########################
class CrossLingualMLMHead(nn.Module):
"""Contain a cross-lingual MLM head."""

def __init__(self, ntoken, embed_dim):
super(CrossLingualMLMHead, self).__init__()
self.mlm_span = Linear(embed_dim, embed_dim)
self.activation = F.gelu
self.norm_layer = LayerNorm(embed_dim, eps=1e-12)
self.mlm_head = Linear(embed_dim, ntoken)

def forward(self, src):
output = self.mlm_span(src)
output = self.activation(output)
output = self.norm_layer(output)
output = self.mlm_head(output)
return output


def cross_lingual_mlm_head():
classifier = CrossLingualMLMHead(250002, 768)
# [TODO] Load the weight of LM head
return classifier


class TransformerEncoderLanguageModeling(nn.Module):
"""Contain a transformer encoder plus LM head."""

def __init__(self, transformer_encoder, lm_head):
super(TransformerEncoderLanguageModeling, self).__init__()
self.transformer_encoder = transformer_encoder
self.lm_head = lm_head

def forward(self, src):
output = self.transformer_encoder(src)
output = self.lm_head(output)
return output


def xlmr_base_cross_lingual_mlm():
'''
Examples:
>>> from torchtext.experimental.models import xlmr_base_cross_lingual_mlm
>>> xlmr_lm_model, xlmr_base_transform = xlmr_base_cross_lingual_mlm()
>>> xlmr_base_transform('this is an example')
tensor([ 903, 83, 142, 27781])
'''
xlmr_model, xlmr_transform = xlmr_base()

lm_head = cross_lingual_mlm_head()
return TransformerEncoderLanguageModeling(xlmr_model, lm_head), xlmr_transform


TASK_PRETRAINED = {'xlmr_base_sentence_classifier': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_base_sentence_classifier-7e3fbb3f.pt'}
TASK_SHA256 = {'xlmr_base_sentence_classifier': '7e3fbb3fac705df2be377d9e1cc198ce3a578172a17b1943e94fa2efe592f278'}
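
Putting the factory functions above together, a minimal end-to-end sketch (assuming the pretrained checkpoints download successfully; the sentence and batch handling are illustrative only):

import torch
from torchtext.experimental.models import xlmr_base

model, transform = xlmr_base()
ids = transform('this is an example')        # token ids for the sentence
batch = torch.as_tensor(ids).unsqueeze(0)    # batch-first input: (N=1, S=len(ids))
model.eval()
with torch.no_grad():
    output = model(batch)                    # encoder output: (1, S, 768)

The sentence-classification and MLM wrappers follow the same pattern, though TransformerEncoderSentenceClassification.forward is intentionally left unimplemented at this stage of the diff.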
29 changes: 29 additions & 0 deletions torchtext/experimental/models/xlmr_transform.py
@@ -0,0 +1,29 @@
import torch.nn as nn
from typing import List
from .utils import load_model_from_url


class XLMRTransform(nn.Module):
"""XLM-R encode transform."""

def __init__(self, tokenizer, vocab):
super(XLMRTransform, self).__init__()
self.tokenizer = tokenizer
self.vocab = vocab

def forward(self, input_src: str) -> List[int]:
return self.vocab(self.tokenizer(input_src))


def load_xlmr_transform():
Review thread:
Contributor Author: @datumbox Here is a pretrained transform case for the text domain.
Contributor: Sounds good. This should be supported fine for presets. The only question I have is over the use of load_model_from_url. Do you typically store the whole model or only its weights? As far as I've seen, the latter is typically preferred, but maybe that does not work for you if the definition of the model lives in an external lib.
Contributor Author: Yes. For the model itself, we store the weights. However, those two transforms are not really nn.Modules, so we have to store the whole objects. Happy to hear a better solution.

tokenizer = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_sentencepiece'],
hash_value=TRANSFORM_SHA256['xlmr_sentencepiece'])
vocab = load_model_from_url(TRANSFORM_PRETRAINED['xlmr_vocab'],
hash_value=TRANSFORM_SHA256['xlmr_vocab'])
return XLMRTransform(tokenizer, vocab)


TRANSFORM_PRETRAINED = {'xlmr_vocab': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_vocab-50081a8a.pt',
'xlmr_sentencepiece': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_models/xlmr_sentencepiece-d4797664.pt'}
TRANSFORM_SHA256 = {'xlmr_vocab': '50081a8a69175ba2ed207eaf74f7055100aef3d8e737b3f0b26ee4a7c8fc781c',
'xlmr_sentencepiece': 'd47976646f6be0ae29b0f5d5b8a7b1d6381e46694a54e8d29dd51671f1471b33'}
5 changes: 5 additions & 0 deletions torchtext/experimental/modules/__init__.py
@@ -0,0 +1,5 @@
from .transformer import TransformerEncoderLayer
from .embedding import PositionalEmbedding, BertEmbedding

__all__ = ['TransformerEncoderLayer',
'PositionalEmbedding', 'BertEmbedding']
30 changes: 30 additions & 0 deletions torchtext/experimental/modules/embedding.py
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from torch.nn import Dropout, LayerNorm


class PositionalEmbedding(nn.Module):
def __init__(self, embed_dim, max_len=514):
super(PositionalEmbedding, self).__init__()
self.pos_embedding = nn.Embedding(max_len, embed_dim)

def forward(self, x):
N, S = x.size()
pos = torch.arange(S, dtype=torch.long,
device=x.device).unsqueeze(0).expand((N, S))
return self.pos_embedding(pos)


class BertEmbedding(nn.Module):
def __init__(self, ntoken, embed_dim=768, dropout=0.5):
super(BertEmbedding, self).__init__()
self.embed_dim = embed_dim
self.ntoken = ntoken
self.pos_embed = PositionalEmbedding(embed_dim)
self.embed = nn.Embedding(ntoken, embed_dim)
self.norm = LayerNorm(embed_dim)
self.dropout = Dropout(dropout)

def forward(self, src):
src = self.embed(src) + self.pos_embed(src)
return self.dropout(self.norm(src))
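
A quick shape check for these two modules (sizes chosen to match the XLM-R base configuration; the random token ids are illustrative):

import torch
from torchtext.experimental.modules import BertEmbedding, PositionalEmbedding

tokens = torch.randint(0, 250002, (2, 16))     # (N=2, S=16) batch of token ids
pos_embed = PositionalEmbedding(embed_dim=768)
bert_embed = BertEmbedding(ntoken=250002, embed_dim=768, dropout=0.1)

print(pos_embed(tokens).shape)   # torch.Size([2, 16, 768]): one vector per position
print(bert_embed(tokens).shape)  # torch.Size([2, 16, 768]): token + positional, normalized, with dropout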
34 changes: 34 additions & 0 deletions torchtext/experimental/modules/transformer.py
@@ -0,0 +1,34 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Dropout, LayerNorm
from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct


class TransformerEncoderLayer(nn.Module):
Review thread:
Contributor: Is there a reason we can't use torch.nn.TransformerEncoderLayer here?
Contributor Author (zhangguanheng66, Feb 4, 2021): torch.nn.TransformerEncoderLayer uses the MHA in torch.nn. Here we use the MHA container in torchtext.

def __init__(self, embed_dim=768, nhead=12, feedforward_dim=3072,
dropout=0.2, activation=F.gelu):
super(TransformerEncoderLayer, self).__init__()
in_proj_container = InProjContainer(Linear(embed_dim, embed_dim),
Linear(embed_dim, embed_dim),
Linear(embed_dim, embed_dim))
self.mha = MultiheadAttentionContainer(nhead, in_proj_container,
ScaledDotProduct(), Linear(embed_dim, embed_dim), batch_first=True)
self.linear1 = Linear(embed_dim, feedforward_dim)
self.dropout = Dropout(dropout)
self.linear2 = Linear(feedforward_dim, embed_dim)

self.norm1 = LayerNorm(embed_dim)
self.norm2 = LayerNorm(embed_dim)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.activation = activation
# [TODO] Add init_weights()

def forward(self, src, src_mask=None, src_key_padding_mask=None):
attn_output, attn_output_weights = self.mha(src, src, src, attn_mask=src_mask)
src = src + self.dropout1(attn_output)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
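
Following the thread above: torch.nn.TransformerEncoderLayer wraps torch.nn.MultiheadAttention, which in the PyTorch release current at the time of this diff expects sequence-first inputs, whereas this layer wires up torchtext's MultiheadAttentionContainer with batch_first=True. A minimal sketch of stacking it the same way XLMRModel does, with illustrative sizes:

import torch
from torch.nn import TransformerEncoder
from torchtext.experimental.modules import TransformerEncoderLayer

layer = TransformerEncoderLayer(embed_dim=768, nhead=12, feedforward_dim=3072, dropout=0.2)
encoder = TransformerEncoder(layer, num_layers=2)

src = torch.rand(2, 16, 768)   # batch-first input: (N, S, E)
out = encoder(src)             # same shape: (2, 16, 768)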