Add XLMR Base and Large pre-trained models and corresponding transformations (#1406)
parmeet committed Oct 19, 2021
1 parent 0930843 commit 1fb2aed
Showing 15 changed files with 883 additions and 2 deletions.
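(A minimal usage sketch for the new bundles, assembled from the RobertaModelBundle docstring and the tests in this diff; the output shape and the transform.pad_idx attribute are taken from that docstring, not verified independently.)

import torch
import torchtext
from torchtext.functional import to_tensor

# Grab the pre-trained XLM-R base encoder bundle added in this commit.
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model().eval()
transform = xlmr_base.transform()

# The transform maps raw strings to token ids; to_tensor pads the batch.
input_batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
output = model(model_input)   # per the docstring: torch.Size([2, 6, 768])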
Binary file added test/asset/xlmr.base.output.pt
Binary file added test/asset/xlmr.large.output.pt
70 changes: 70 additions & 0 deletions test/models/test_models.py
@@ -0,0 +1,70 @@
import torchtext
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestModels(TorchtextTestCase):
    def test_xlmr_base_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_base_jit_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_large_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_large_jit_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_transform(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        test_text = "XLMR base Model Comparison"
        actual = transform([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)

    def test_xlmr_transform_jit(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        transform_jit = torch.jit.script(transform)
        test_text = "XLMR base Model Comparison"
        actual = transform_jit([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)
58 changes: 58 additions & 0 deletions test/test_functional.py
@@ -0,0 +1,58 @@
import torch
from torchtext import functional
from .common.torchtext_test_case import TorchtextTestCase


class TestFunctional(TorchtextTestCase):
    def test_to_tensor(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        actual = functional.to_tensor(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_to_tensor_jit(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        to_tensor_jit = torch.jit.script(functional.to_tensor)
        actual = to_tensor_jit(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_truncate(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        actual = functional.truncate(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_truncate_jit(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        truncate_jit = torch.jit.script(functional.truncate)
        actual = truncate_jit(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_add_token(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        actual = functional.add_token(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = functional.add_token(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)

    def test_add_token_jit(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        add_token_jit = torch.jit.script(functional.add_token)
        actual = add_token_jit(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = add_token_jit(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)
33 changes: 33 additions & 0 deletions test/test_transforms.py
@@ -0,0 +1,33 @@
import torch
from torchtext import transforms
from torchtext.vocab import vocab
from collections import OrderedDict

from .common.torchtext_test_case import TorchtextTestCase
from .common.assets import get_asset_path


class TestTransforms(TorchtextTestCase):
    def test_spmtokenizer_transform(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        actual = transform(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_spmtokenizer_transform_jit(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        transform_jit = torch.jit.script(transform)
        actual = transform_jit(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_vocab_transform(self):
        vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
        transform = transforms.VocabTransform(vocab_obj)
        actual = transform([['a', 'b', 'c']])
        expected = [[0, 1, 2]]
        self.assertEqual(actual, expected)
10 changes: 10 additions & 0 deletions torchtext/__init__.py
@@ -1,8 +1,15 @@
import os
_TEXT_BUCKET = 'https://download.pytorch.org/models/text'
_CACHE_DIR = os.path.expanduser('~/.torchtext/cache')

from . import data
from . import nn
from . import datasets
from . import utils
from . import vocab
from . import transforms
from . import functional
from . import models
from . import experimental
from . import legacy
from ._extension import _init_extension
@@ -18,6 +25,9 @@
'datasets',
'utils',
'vocab',
'transforms',
'functional',
'models',
'experimental',
'legacy']

6 changes: 4 additions & 2 deletions torchtext/data/datasets_utils.py
@@ -15,6 +15,8 @@
    import defusedxml.ElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET

from torchtext import _CACHE_DIR
"""
These functions and classes are meant solely for use in torchtext.datasets and not
for public consumption yet.
@@ -213,7 +215,7 @@ def _wrap_split_argument_with_fn(fn, splits):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

@functools.wraps(fn)
def new_fn(root=os.path.expanduser('~/.torchtext/cache'), split=splits, **kwargs):
def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
result = []
for item in _check_default_set(split, splits, fn.__name__):
result.append(fn(root, item, **kwargs))
@@ -250,7 +252,7 @@ def decorator(func):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

@functools.wraps(func)
def wrapper(root=os.path.expanduser('~/.torchtext/cache'), *args, **kwargs):
def wrapper(root=_CACHE_DIR, *args, **kwargs):
new_root = os.path.join(root, dataset_name)
if not os.path.exists(new_root):
os.makedirs(new_root)
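(For context, a short sketch of what this change means for callers; it is not part of the diff, and AG_NEWS is just one illustrative dataset wrapped by these decorators.)

import torchtext.datasets as datasets
from torchtext import _CACHE_DIR   # defaults to ~/.torchtext/cache

# Dataset entry points now default their `root` argument to _CACHE_DIR,
# and each dataset is stored under root/<dataset_name>.
train_iter = datasets.AG_NEWS(split='train')                # uses _CACHE_DIR/AG_NEWS
train_iter = datasets.AG_NEWS(root='/tmp', split='train')   # uses /tmp/AG_NEWS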
45 changes: 45 additions & 0 deletions torchtext/functional.py
@@ -0,0 +1,45 @@
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional

__all__ = [
'to_tensor',
'truncate',
'add_token',
]


def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
    if padding_value is None:
        output = torch.tensor(input, dtype=torch.long)
        return output
    else:
        output = pad_sequence(
            [torch.tensor(ids, dtype=torch.long) for ids in input],
            batch_first=True,
            padding_value=float(padding_value)
        )
        return output


def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
    output: List[List[int]] = []

    for ids in input:
        output.append(ids[:max_seq_len])

    return output


def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
    output: List[List[int]] = []

    if begin:
        for ids in input:
            output.append([token_id] + ids)
    else:
        for ids in input:
            output.append(ids + [token_id])

    return output
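(For context, a short sketch, not part of the diff, of how these three helpers compose into the usual truncate / add-special-tokens / pad pipeline; the ids 0 and 2 and padding value 1 are illustrative stand-ins for BOS, EOS and pad.)

from torchtext import functional

token_ids = [[5, 6, 7, 8, 9], [5, 6]]                                   # e.g. tokenizer + vocab output
token_ids = functional.truncate(token_ids, max_seq_len=4)
token_ids = functional.add_token(token_ids, token_id=0)                 # prepend BOS
token_ids = functional.add_token(token_ids, token_id=2, begin=False)    # append EOS
batch = functional.to_tensor(token_ids, padding_value=1)                # pad into a LongTensor
# batch.shape == torch.Size([2, 6])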
1 change: 1 addition & 0 deletions torchtext/models/__init__.py
@@ -0,0 +1 @@
from .roberta import * # noqa: F401, F403
18 changes: 18 additions & 0 deletions torchtext/models/roberta/__init__.py
@@ -0,0 +1,18 @@
from .model import (
    RobertaEncoderParams,
    RobertaClassificationHead,
)

from .bundler import (
    RobertaModelBundle,
    XLMR_BASE_ENCODER,
    XLMR_LARGE_ENCODER,
)

__all__ = [
    "RobertaEncoderParams",
    "RobertaClassificationHead",
    "RobertaModelBundle",
    "XLMR_BASE_ENCODER",
    "XLMR_LARGE_ENCODER",
]
99 changes: 99 additions & 0 deletions torchtext/models/roberta/bundler.py
@@ -0,0 +1,99 @@

import os
from dataclasses import dataclass
from functools import partial

from typing import Optional, Callable
from torch.hub import load_state_dict_from_url
from torch.nn import Module
import logging

logger = logging.getLogger(__name__)

from .model import (
    RobertaEncoderParams,
    RobertaModel,
    _get_model,
)

from .transforms import get_xlmr_transform

from torchtext import _TEXT_BUCKET


@dataclass
class RobertaModelBundle:
    """
    Example - Pretrained encoder
    >>> import torch, torchtext
    >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
    >>> model = xlmr_base.get_model()
    >>> transform = xlmr_base.transform()
    >>> model_input = torch.tensor(transform(["Hello World"]))
    >>> output = model(model_input)
    >>> output.shape
    torch.Size([1, 4, 768])
    >>> input_batch = ["Hello world", "How are you!"]
    >>> from torchtext.functional import to_tensor
    >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
    >>> output = model(model_input)
    >>> output.shape
    torch.Size([2, 6, 768])

    Example - Pretrained encoder attached to an un-initialized classification head
    >>> import torch, torchtext
    >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
    >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
    >>> classification_model = xlmr_large.get_model(head=classifier_head)
    >>> transform = xlmr_large.transform()
    >>> model_input = torch.tensor(transform(["Hello World"]))
    >>> output = classification_model(model_input)
    >>> output.shape
    torch.Size([1, 2])
    """

    _params: RobertaEncoderParams
    _path: Optional[str] = None
    _head: Optional[Module] = None
    transform: Optional[Callable] = None

    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
        if head is not None:
            input_head = head
            if self._head is not None:
                logger.info("A custom head module was provided, discarding the default head module.")
        else:
            input_head = self._head

        model = _get_model(self._params, input_head)

        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
        if input_head is not None:
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=True)
        return model

    @property
    def params(self) -> RobertaEncoderParams:
        return self._params


XLMR_BASE_ENCODER = RobertaModelBundle(
    _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
    _params=RobertaEncoderParams(vocab_size=250002),
    transform=partial(
        get_xlmr_transform,
        vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
        spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
    )
)

XLMR_LARGE_ENCODER = RobertaModelBundle(
    _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
    _params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
    transform=partial(
        get_xlmr_transform,
        vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
        spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
    )
)
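(A brief note, not part of the diff: get_model forwards dl_kwargs verbatim to torch.hub.load_state_dict_from_url, so that function's standard keyword arguments such as progress or model_dir can be passed through. A minimal sketch:)

import torchtext

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
# dl_kwargs is handed to torch.hub.load_state_dict_from_url unchanged.
model = xlmr_base.get_model(dl_kwargs={"progress": False})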
