
v0.12.0

@parmeet released this 10 Mar 18:31
d7a34d6

Highlights

In this release, we have revamped the library to provide a more comprehensive experience for users doing NLP modeling with TorchText and PyTorch.

  • Migrated datasets to build on top of TorchData DataPipes
  • Added support for RoBERTa and XLM-RoBERTa pre-trained models
  • Added support for scriptable tokenizers
  • Added support for composable transforms and functionals

Datasets

TorchText has modernized its datasets by migrating from older-style Iterable Datasets to TorchData's DataPipes. TorchData is a library that provides modular, composable primitives, allowing users to load and transform data in performant data pipelines. These DataPipes work out of the box with the PyTorch DataLoader and will enable new functionality such as auto-sharding. Users can now easily manipulate and pre-process data with user-defined functions and transformations in a functional programming style. Datasets backed by DataPipes also support standard flow-control operations such as batching, collation, shuffling and bucketizing. Collectively, DataPipes provide a comprehensive, Pythonic and flexible way to handle the data pre-processing and tensorization needed for model training.

from functools import partial
import torchtext.functional as F
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchtext.datasets import SST2

# Tokenizer to split input text into tokens
encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"
tokenizer = T.GPT2BPETokenizer(encoder_json_path, vocab_bpe_path)
# Vocabulary that converts tokens to IDs
vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt"
vocab = T.VocabTransform(load_state_dict_from_url(vocab_path))
# Add BOS token to the beginning of sentence
add_bos = T.AddToken(token=0, begin=True)
# Add EOS token to the end of sentence
add_eos = T.AddToken(token=2, begin=False)

# Create SST2 dataset datapipe and apply pre-processing
batch_size = 32
train_dp = SST2(split="train")
train_dp = train_dp.batch(batch_size).rows2columnar(["text", "label"])
train_dp = train_dp.map(tokenizer, input_col="text", output_col="tokens")
train_dp = train_dp.map(partial(F.truncate, max_seq_len=254), input_col="tokens")
train_dp = train_dp.map(vocab, input_col="tokens")
train_dp = train_dp.map(add_bos, input_col="tokens")
train_dp = train_dp.map(add_eos, input_col="tokens")
train_dp = train_dp.map(partial(F.to_tensor, padding_value=1), input_col="tokens")
train_dp = train_dp.map(F.to_tensor, input_col="label")
# create DataLoader
dl = DataLoader(train_dp, batch_size=None)
batch = next(iter(dl))
model_input = batch["tokens"]
target = batch["label"]

TorchData is required in order to use these datasets. Please install it by following the instructions at https://github.com/pytorch/data.
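To make the data layout in the SST-2 pipeline above more concrete, here is a plain-Python approximation of what `batch(...)`, `rows2columnar([...])` and `map(..., input_col=..., output_col=...)` do to each element. It is a sketch under stated assumptions, not the TorchData implementation: the helpers `batch`, `rows2columnar` and `map_col`, the toy whitespace tokenizer and the sample rows are all hypothetical stand-ins.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence


# Hypothetical stand-ins for DataPipe operations, written as plain generators.
def batch(rows: Iterable[Any], batch_size: int) -> Iterator[List[Any]]:
    """Group consecutive rows into lists of at most `batch_size` (last one may be shorter)."""
    buf: List[Any] = []
    for row in rows:
        buf.append(row)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf


def rows2columnar(
    batches: Iterable[List[Sequence[Any]]], column_names: List[str]
) -> Iterator[Dict[str, List[Any]]]:
    """Turn each batch of row tuples into a dict mapping column name -> list of values."""
    for rows in batches:
        yield {name: [row[i] for row in rows] for i, name in enumerate(column_names)}


def map_col(
    batches: Iterable[Dict[str, Any]],
    fn: Callable[[Any], Any],
    input_col: str,
    output_col: Optional[str] = None,
) -> Iterator[Dict[str, Any]]:
    """Apply `fn` to one column; store it in `output_col`, or overwrite `input_col` if None."""
    for data in batches:
        data = dict(data)  # shallow copy so the incoming dict is not mutated
        target = output_col if output_col is not None else input_col
        data[target] = fn(data[input_col])
        yield data


def toy_tokenizer(texts: List[str]) -> List[List[str]]:
    # Hypothetical batched tokenizer: whitespace split for each sentence.
    return [t.split() for t in texts]


rows = [("a great movie", 1), ("not good", 0), ("dull", 0)]
pipe = rows2columnar(batch(rows, 2), ["text", "label"])
pipe = map_col(pipe, toy_tokenizer, input_col="text", output_col="tokens")
batches = list(pipe)
```

Each element flowing through such a pipeline is a dict of columns, which is why the later steps in the SST-2 example keep referring to `input_col="tokens"` and why the `DataLoader` output can be indexed as `batch["tokens"]` and `batch["label"]`.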

Models

We have added support for pre-trained RoBERTa and XLM-R models. The models are TorchScript-able and can therefore be used in production. The modeling APIs let users attach custom task-specific heads to pre-trained encoders. The API also comes with data pre-processing transforms that match the pre-trained weights and model configuration.

import torch, torchtext
from torchtext.functional import to_tensor
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()
transform = xlmr_base.transform()
input_batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(input_batch), padding_value=1)
output = model(model_input)
print(output.shape)  # torch.Size([2, 6, 768])

# add classification head
import torch.nn as nn
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.output_layer = nn.Linear(input_dim, num_classes)

    def forward(self, features):
        # Take the features of the first token (<s>, which serves as the CLS token)
        x = features[:, 0, :]
        return self.output_layer(x)

binary_classifier = xlmr_base.get_model(head=ClassificationHead(input_dim=768, num_classes=2)) 
output = binary_classifier(model_input)
print(output.shape)  # torch.Size([2, 2])
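`to_tensor(..., padding_value=1)` turns a batch of variable-length token-ID lists into a rectangular tensor; 1 is used because it is the padding index of the RoBERTa/XLM-R vocabularies in these examples. The plain-Python sketch below approximates that right-padding and also derives a boolean padding mask of the kind a transformer encoder can use to ignore padded positions. The helper names `pad_batch` and `padding_mask`, the token IDs, and the mask convention (`True` marks padding) are assumptions for illustration; check the encoder's documentation for the exact mask semantics it expects.

```python
from typing import List

PADDING_IDX = 1  # padding index used for RoBERTa/XLM-R in the examples above


def pad_batch(batch: List[List[int]], padding_value: int = PADDING_IDX) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [padding_value] * (max_len - len(seq)) for seq in batch]


def padding_mask(padded: List[List[int]], padding_idx: int = PADDING_IDX) -> List[List[bool]]:
    """True where a position holds padding and should be ignored by attention."""
    return [[tok == padding_idx for tok in seq] for seq in padded]


# Hypothetical token IDs: each row starts with BOS (0) and ends with EOS (2).
ids = [[0, 500, 600, 2], [0, 700, 800, 900, 40, 2]]
padded = pad_batch(ids)
mask = padding_mask(padded)
```

Deriving the mask from the padding index, rather than passing sequence lengths around, keeps the model input a single tensor of token IDs.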

Transforms and tokenizers

We have revamped our transforms to provide composable building blocks for text pre-processing. They support both batched and non-batched inputs. Furthermore, we have added support for a number of commonly used tokenizers, including SentencePiece, GPT-2 BPE and CLIP.

import torchtext.transforms as T
from torch.hub import load_state_dict_from_url

padding_idx = 1
bos_idx = 0
eos_idx = 2
max_seq_len = 256
xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

text_transform = T.Sequential(
    T.SentencePieceTokenizer(xlmr_spm_model_path),
    T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
    T.Truncate(max_seq_len - 2),
    T.AddToken(token=bos_idx, begin=True),
    T.AddToken(token=eos_idx, begin=False),
)

text_transform(["Hello World", "How are you"])
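The transforms above accept either a single sequence or a batch of sequences. The plain-Python sketch below imitates that behaviour for truncation and token insertion and chains the steps the way `T.Sequential` does. It also shows why the example truncates to `max_seq_len - 2`: that leaves room for the BOS and EOS tokens, so the final length never exceeds `max_seq_len`. The functions `truncate`, `add_token` and `sequential` are hypothetical illustrations, not TorchText's TorchScript implementations, and the token IDs are made up.

```python
from functools import partial
from typing import Any, Callable, List


def _is_batched(x: List[Any]) -> bool:
    # Treat the input as a batch if its first element is itself a list.
    return len(x) > 0 and isinstance(x[0], list)


def truncate(x: List[Any], max_seq_len: int) -> List[Any]:
    if _is_batched(x):
        return [seq[:max_seq_len] for seq in x]
    return x[:max_seq_len]


def add_token(x: List[Any], token: Any, begin: bool) -> List[Any]:
    if _is_batched(x):
        return [add_token(seq, token, begin) for seq in x]
    return [token] + x if begin else x + [token]


def sequential(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Compose transforms left to right, like a Sequential container."""
    def run(x: Any) -> Any:
        for fn in fns:
            x = fn(x)
        return x
    return run


bos_idx, eos_idx, max_seq_len = 0, 2, 6
pipeline = sequential(
    partial(truncate, max_seq_len=max_seq_len - 2),  # reserve two slots for BOS/EOS
    partial(add_token, token=bos_idx, begin=True),
    partial(add_token, token=eos_idx, begin=False),
)

single = pipeline([10, 11, 12, 13, 14, 15, 16])          # non-batched input
batched = pipeline([[10, 11], [20, 21, 22, 23, 24]])     # batched input
```

Because each step has the same input and output shape conventions, steps can be reordered or swapped without changing the surrounding code.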

Tutorial

We have added an end-to-end tutorial that performs SST-2 binary text classification with the pre-trained XLM-R base architecture and demonstrates the usage of the new datasets, transforms and models.

Backward Incompatible changes

In this release, we have removed the legacy folder, which provided access to legacy datasets and abstractions. For additional information, please refer to the corresponding GitHub issue (#1422) and PR (#1437).

New Features

Models

  • Added XLMR Base and Large pre-trained models and corresponding transformations (#1407)
  • Added option to specify whether to load pre-trained weights (#1424)
  • Added option for freezing encoder weights (#1428)
  • Enabled optional return of all states in transformer encoder (#1430)
  • Added support for RobertaModel to accept model configuration (#1431)
  • Allowed inferred scaling in MultiheadSelfAttention for head_dim != 64 (#1432)
  • Added attention mask to transformer encoder modules (#1435)
  • Added builder method in Model Bundler to facilitate model creation with user-defined configuration and checkpoint (#1442)
  • Cleaned up Model API (#1452)
  • Fixed bool attention mask in transformer encoder (#1454)
  • Removed XLMR transform class in favor of Sequential for composing model transforms (#1482)
  • Added support for pre-trained RoBERTa encoder for base and large architectures (#1491)
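For the MultiheadSelfAttention change (#1432): the usual convention in scaled dot-product attention is to multiply query-key scores by 1/sqrt(head_dim), where head_dim = embed_dim / num_heads. Inferring the factor this way, rather than assuming head_dim == 64 (a fixed 0.125), is presumably what the change allows. The sketch below only illustrates that standard formula under this assumption; `infer_scaling` is a hypothetical helper, not TorchText code.

```python
import math


def infer_scaling(embed_dim: int, num_heads: int) -> float:
    """Return 1/sqrt(head_dim), the standard scaled dot-product attention factor."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    return 1.0 / math.sqrt(head_dim)


base = infer_scaling(768, 12)   # head_dim 64 -> 0.125
small = infer_scaling(512, 16)  # head_dim 32 -> about 0.1768
```

XLM-R base (768 hidden units, 12 heads) and large (1024, 16 heads) both have head_dim 64, so an inferred factor mainly matters for custom configurations with other head sizes.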

Transforms, Tokenizers, Ops

  • Added ToTensor and LabelToIndex transformations (#1415)
  • Added Truncate Transform (#1458)
  • Updated input annotation type to Any to support TorchScript scriptability when composing transforms (#1453)
  • Added AddToken transform (#1463)
  • Added GPT-2 BPE pre-tokenizer operator leveraging re2 regex library (#1459)
  • Added Torchscriptable GPT-2 BPE Tokenizer for RoBERTa models (#1462)
  • Migrated GPT-2 BPE tokenizer logic to C++ (#1469)
  • Fixed optionality of default argument in to_tensor (#1475)
  • Added scriptable Sequential transform (#1481)
  • Removed optionality of dtype in ToTensor (#1492)
  • Fixed max sequence length for XLMR transform (#1495)
  • Added max_tokens kwarg to vocab factory (#1525)
  • Refactored vocab factory method to accept special tokens as a keyword argument (#1436)
  • Implemented ClipTokenizer that builds on top of GPT2BPETokenizer (#1541)
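The `max_tokens` and special-token options on the vocab factory (#1525, #1436) bound the vocabulary size. The plain-Python sketch below roughly mimics how `torchtext.vocab.build_vocab_from_iterator` is documented to treat them: only the `max_tokens - len(specials)` most frequent tokens are kept, so the specials count towards the budget, and the specials are placed first. The helper `build_vocab_sketch` is hypothetical, breaks frequency ties alphabetically by its own choice, and returns a plain dict rather than a TorchText `Vocab`; consult the API reference for the exact behaviour.

```python
from collections import Counter
from typing import Dict, Iterable, List, Optional


def build_vocab_sketch(
    token_lists: Iterable[List[str]],
    min_freq: int = 1,
    specials: Optional[List[str]] = None,
    max_tokens: Optional[int] = None,
) -> Dict[str, int]:
    specials = specials or []
    counter = Counter(tok for tokens in token_lists for tok in tokens)
    # Sort alphabetically, then by descending frequency (stable sort keeps ties alphabetical).
    ranked = sorted(counter.items(), key=lambda kv: kv[0])
    ranked.sort(key=lambda kv: kv[1], reverse=True)
    if max_tokens is not None:
        if len(specials) >= max_tokens:
            raise ValueError("max_tokens must exceed the number of specials")
        ranked = ranked[: max_tokens - len(specials)]
    kept = [tok for tok, freq in ranked if freq >= min_freq and tok not in specials]
    # Specials first, then regular tokens in frequency order.
    return {tok: idx for idx, tok in enumerate(specials + kept)}


corpus = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "end"]]
vocab = build_vocab_sketch(corpus, specials=["<unk>"], max_tokens=3)
```

With `max_tokens=3` and one special token, only the two most frequent tokens ("the" and "sat") survive alongside `<unk>`.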

Datasets

Migration of datasets on top of datapipes

Newly added datasets

Misc

  • Fix split filter logic in AmazonReviewPolarity (#1505)
  • Used os.path.join for consistency (#1506)
  • Fixed dataset test failures due to incorrect caching mode in AG_NEWS (#1517)
  • Added caching for extraction datapipe for AmazonReviewPolarity (#1527)
  • Added caching for extraction datapipe for Yahoo (#1528)
  • Added caching for extraction datapipe for Yelp Full (#1529)
  • Added caching for extraction datapipe for Yelp Polarity (#1530)
  • Added caching for extraction datapipe for DBPedia (#1571)
  • Added caching for extraction datapipe for SogouNews and AmazonReviewFull (#1594)
  • Fixed issues with extraction caching (#1550, #1551, #1552)
  • Updated Conll2000Chunking dataset to be consistent with other datasets (#1590)
  • [BC-breaking] Removed unnecessary split argument from datasets (#1591)

Improvements

Testing

Revamp TorchText dataset testing to use mocked data

Others

  • Fixed attention mask testing (#1439)
  • Fixed CircleCI download failures on windows for XLM-R unit tests (#1441)
  • Added unit tests for testing model training (#1449)
  • Parameterized XLMR and Roberta model integration tests (#1496)
  • Removed redundant get asset functions from parameterized_utils file (#1501)
  • Parameterized jit and non-jit model integration tests (#1502)
  • Fixed cache logic to work with datapipes (#1522)
  • Converted get_mock_dataset fn in AmazonReviewPolarity to be private (#1543)
  • Removed unused TEST_MODELS_PARAMETERIZED_ARGS constant from model test (#1544)
  • Removed real dataset caching and testing in favor of mocked dataset testing (#1587)
  • Fixed platform-dependent expectation for Multi30k mocked test (#1593)
  • Fixed Conll2000Chunking Test (#1595)
  • Updated IWSLT testing to start from compressed file (#1596)
  • Used Unicode strings to test UTF-8 handling for all non-IWSLT dataset tests (#1599)
  • Parameterized tests for similar datasets (#1600)
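The testing changes above replace downloads of real datasets with small mocked files. As a hypothetical illustration of that pattern (the `_get_mock_dataset` and `read_tsv` helpers and the file layout below are assumptions, not the functions in TorchText's test suite), a test can write a few random rows, including non-ASCII text to exercise UTF-8 handling, to a temporary directory and check that the reader returns exactly those rows.

```python
import csv
import os
import random
import string
import tempfile
from typing import List, Tuple


def _get_mock_dataset(root_dir: str, num_rows: int = 5, seed: int = 0) -> List[Tuple[str, int]]:
    """Write a tiny tab-separated file of (text, label) rows and return the expected rows."""
    rng = random.Random(seed)
    rows = []
    for _ in range(num_rows):
        word = "".join(rng.choice(string.ascii_lowercase) for _ in range(6))
        rows.append((f"{word} café ünïcode", rng.randint(0, 1)))
    path = os.path.join(root_dir, "train.tsv")
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["sentence", "label"])
        writer.writerows(rows)
    return rows


def read_tsv(path: str) -> List[Tuple[str, int]]:
    """Hypothetical reader under test: skip the header and parse the label as int."""
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader)  # skip header row
        return [(text, int(label)) for text, label in reader]


with tempfile.TemporaryDirectory() as tmp:
    expected = _get_mock_dataset(tmp)
    actual = read_tsv(os.path.join(tmp, "train.tsv"))
```

Seeding the generator keeps the mocked data reproducible across runs, while the temporary directory keeps tests independent of network access and of any cached real data.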

Examples

  • Added a non-distributed training example for SST-2 binary text classification using the XLM-RoBERTa model (#1468)

Documentation

Dataset Documentation

  • Updated docs for text classification and language modeling datasets (#1603)
  • Updated docs for Machine Translation, Sequence Tagging, Question Answer, Unsupervised Learning datasets (#1597)
  • Updated docs for CC100 and SST2 (#1604)
  • Updated Sphinx version and added rst files for models, transforms and functionals (#1434)
  • Removed experimental documentation (#1457)
  • Fixed links in README (#1461)
  • Added Sphinx-based tutorial for SST-2 binary classification task using XLM-R model (#1468)
  • Pointed to pytorch.org docs instead of the outdated Read the Docs link (#1480)
  • Added documentation describing XLM-R, the datasets it was trained on, and relevant license information (#1497)
  • Fixed CI doc build (#1504)
  • Removed example using next(...) from README (#1516)

Misc

  • Hid symbols when building third-party code (#1467)
  • Added .DS_Store files to .gitignore (#1470)
  • Removed Python 3.6 support as it has reached EOL (#1484)
  • Added .gitattributes file to hide generated CircleCI files in PRs (#1485)
  • Switched from FileLoader to FileOpener (#1488)
  • Updated python_requires in setup.py to reflect support for non-EOL Python versions (#1521)
  • Added auto-formatters (#1545)
  • Fixed typo in torchtext/vocab/vocab_factory.py (#1565)
  • Formatted datasets and tests (#1601, #1602)