This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Changes from all commits
Commits
21 commits
ebc09f0
Add __init__.py to visualization folder.
a-googler Aug 26, 2017
65636df
modify the hacked-up batching scheme to prevent excessively-long shuf…
nshazeer Aug 28, 2017
814a472
Make an outline for docs.
Aug 28, 2017
6d00c8b
For the experts: Remove padding, add summaries and better params.
a-googler Aug 28, 2017
ab90f71
Fix setup.py for visualization
Aug 29, 2017
684f0d0
Migrate En-De BPE translation to Problem, add UNK option in TokenText…
Aug 29, 2017
da6643d
Use more robust method for showing the visualizations.
a-googler Aug 29, 2017
03a3861
Small simplification in vis notebook.
a-googler Aug 29, 2017
4cc039a
Add a cyclic linear learning rate scheme, play with VAE.
Aug 29, 2017
e742509
Bug fix, no access to targets during decoding. Move to correct place
Aug 29, 2017
ee3296f
Fix decode_from_dataset so that it decodes from multiple batches again
Aug 29, 2017
a2cf057
Add edit distance as metric as additional evaluation criteria.
a-googler Aug 29, 2017
a8ee62a
Add IMDB sentiment classification dataset
Aug 29, 2017
1fc6766
Create a Problem class for the lm1b dataset.
a-googler Aug 29, 2017
357c9d4
Adding example problem to T2T documentation
katelee168 Aug 29, 2017
ffe2386
Added optional memory-efficient versions of conv-hidden-relu and self…
nshazeer Aug 29, 2017
5bf1e82
v1.2.1
Aug 29, 2017
8353ef2
Finish LM1B transfer to Problem, add CNN+DailyMail dataset, style cor…
Aug 29, 2017
b54b711
fix docs
Aug 29, 2017
a3be70a
Correct cyclic lr scheme, docs, play with AE.
Aug 29, 2017
f715f85
Separate CLI t2t_decoder
Aug 29, 2017
19 changes: 14 additions & 5 deletions README.md
@@ -29,15 +29,26 @@ You can chat with us and other users on
with T2T announcements.

Here is a one-command version that installs tensor2tensor, downloads the data,
-trains an English-German translation model, and lets you use it interactively:
+trains an English-German translation model, and evaluates it:
```
pip install tensor2tensor && t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
-  --output_dir=~/t2t_train/base \
+  --output_dir=~/t2t_train/base
```

You can decode from the model interactively:

```
t2t-decoder \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base
--decode_interactive
```

@@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE
BEAM_SIZE=4
ALPHA=0.6

-t2t-trainer \
+t2t-decoder \
--data_dir=$DATA_DIR \
--problems=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
-  --train_steps=0 \
-  --eval_steps=0 \
--decode_beam_size=$BEAM_SIZE \
--decode_alpha=$ALPHA \
--decode_from_file=$DECODE_FILE
34 changes: 34 additions & 0 deletions docs/example_life.md
@@ -0,0 +1,34 @@
# T2T: Life of an Example

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

This document shows how a training example passes through the T2T pipeline,
and how all its parts are connected to work together.

## The Life of an Example

A training example passes through the following stages in T2T:
* raw input (text from the command line or a file)
* encoded input: the `encode` function of [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) turns the raw input into a sparse tensor, usually a vector of `tf.int32`s
* batched input: the [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) applies [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188), then groups the inputs by length and assembles them into batches
* dense input: the `bottom` function of a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) maps the sparse input to dense tensors
* dense output of [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
* sparse output: the `top` function of the [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) maps the dense output back to sparse form
* if decoding, the `decode` function of [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) turns the sparse output back into text for display

We go into these phases step by step below.
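
Before going step by step, here is a toy, self-contained mimic of that flow in plain TensorFlow. None of the names below are T2T APIs; the snippet only illustrates the sparse-to-dense-to-sparse shape of the pipeline:

```python
import tensorflow as tf

vocab_size, hidden_size = 256, 32

ids = tf.constant([[72, 105, 33]])                 # "encoded input": a batch of tf.int32 ids
embed = tf.get_variable("embed", [vocab_size, hidden_size])
dense_in = tf.gather(embed, ids)                   # like Modality.bottom: sparse -> dense
body_out = tf.layers.dense(
    dense_in, hidden_size, activation=tf.nn.relu)  # like model_fn_body: dense -> dense
logits = tf.layers.dense(body_out, vocab_size)     # like Modality.top: dense -> vocab logits
out_ids = tf.argmax(logits, axis=-1)               # back to sparse ids, ready to decode

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(out_ids))                         # shape [1, 3]: one id per input position
```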

## Feature Encoders

TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173), a dict of encoders, each providing `encode` and `decode` functions.
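
In the meantime, a concrete example: `ByteTextEncoder` round-trips text through byte ids (the printed ids assume the default of two reserved ids, PAD and EOS):

```python
from tensor2tensor.data_generators import text_encoder

enc = text_encoder.ByteTextEncoder()
ids = enc.encode("Hi!")  # one id per byte, shifted past the reserved ids
print(ids)               # [74, 107, 35] with the default two reserved ids
print(enc.decode(ids))   # "Hi!"
```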

## Modalities

TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30), which provides `bottom` and `top`, plus sharded versions and target-side variants.
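
In the meantime, a minimal sketch of a custom symbol-style modality, assuming the base constructor stores `self._model_hparams` and `self._vocab_size` (check `modality.py` for the exact interface):

```python
import tensorflow as tf

from tensor2tensor.utils import modality


class MySymbolModality(modality.Modality):
  """Sketch: embed sparse ids on the way in, project to logits on the way out."""

  def bottom(self, x):
    # Sparse int ids -> dense embeddings of size hidden_size.
    embeddings = tf.get_variable(
        "embeddings", [self._vocab_size, self._model_hparams.hidden_size])
    return tf.gather(embeddings, x)

  def top(self, body_output, _):
    # Dense body output -> logits over the vocabulary.
    return tf.layers.dense(body_output, self._vocab_size, name="logits")
```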
35 changes: 23 additions & 12 deletions docs/index.md
@@ -1,11 +1,4 @@
-# T2T: Tensor2Tensor Transformers
-
-Check us out on
-<a href=https://github.com/tensorflow/tensor2tensor>
-GitHub
-<img src="https://github.com/favicon.ico" width="16">
-</a>
-.
+# Tensor2Tensor Docs Index

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
@@ -16,8 +9,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

-See our
-[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/README.md)
-for documentation.
-
-More documentation and tutorials coming soon...
Welcome to Tensor2Tensor!

Tensor2Tensor, or T2T for short, is a library we use to create,
investigate and deploy deep learning models. This page hosts our
documentation, from basic tutorials to full code documentation.

## Basics

* [Walkthrough: Install and Run](walkthrough.md)
* [Tutorial: Train on Your Data](new_problem.md)
* [Tutorial: Create Your Own Model](new_model.md)

## Deep Dive

* [Life of an Example](example_life.md): how all parts of T2T are connected and work together
* [Distributed Training](distributed_training.md)

## Code documentation

See our
[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md)
for now; full code docs are coming.
16 changes: 16 additions & 0 deletions docs/new_model.md
@@ -0,0 +1,16 @@
# T2T: Create Your Own Model

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Here we show how to create your own model in T2T.

## The T2TModel class

TODO: complete.
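
Until it is, here is a minimal sketch of what a registered model could look like, assuming hparams are reachable as `self._hparams` (check `t2t_model.py` for the exact interface):

```python
import tensorflow as tf

from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model


@registry.register_model
class MySimpleModel(t2t_model.T2TModel):
  """Toy body: a single ReLU layer over the already-embedded inputs."""

  def model_fn_body(self, features):
    # By the time we get here, the input modality's `bottom` has already
    # turned the sparse features into dense tensors (see example_life.md).
    inputs = features["inputs"]
    return tf.layers.dense(
        inputs, self._hparams.hidden_size, activation=tf.nn.relu)
```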
240 changes: 240 additions & 0 deletions docs/new_problem.md
@@ -0,0 +1,240 @@
# T2T: Train on Your Own Data

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Let's add a new dataset together and train the transformer model. We'll be learning to define English words by training the transformer to "translate" between English words and their definitions at the character level.

# About the Problem

For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`.

Since many text2text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`).

For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `Text2TextProblem`. Let's also register it while we're at it so we can specify the problem through flags; the registered name is the snake_cased class name, `word2def`.

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""
  pass  # We'll fill in the required methods and properties below.
```

We need to implement the following methods and properties from `Text2TextProblem` in our new class:
* is_character_level
* targeted_vocab_size
* generator
* input_space_id
* target_space_id
* num_shards
* vocab_name
* use_subword_tokenizer

Let's tackle them one by one:

**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**:

SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These include EN_CHR (English characters), EN_TOK (English tokens), AUDIO_WAV (audio waveforms), IMAGE, and DNA (genetic bases). The complete list can be found in `data_generators/problem.py` in the class `SpaceID`.

Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to `True` and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need a `targeted_vocab_size`, and `use_subword_tokenizer` should simply return `False`.

**vocab_name**:

`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`.

**num_shards**:

The number of shards to break data files into.

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```

**generator**:

We're almost done. `generator` generates the training and evaluation data and stores it in files like "word2def_train.lang1" in your DATA_DIR. Thankfully, several commonly used generators like `character_generator` and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write:
```python
def generator(self, data_dir, tmp_dir, train):
  character_vocab = text_encoder.ByteTextEncoder()
  datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
  return character_generator(datasets[0], datasets[1], character_vocab, EOS)
```
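
To make the generator's output concrete, here is a sketch of the kind of dict it yields for one word/definition pair, built by hand with the same encoder:

```python
from tensor2tensor.data_generators import text_encoder

vocab = text_encoder.ByteTextEncoder()
EOS = text_encoder.EOS_ID  # == 1

# Mirrors the dicts character_generator yields, one per aligned line pair.
example = {"inputs": vocab.encode("dog") + [EOS],
           "targets": vocab.encode("a domesticated canine") + [EOS]}
print(example)
```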

Now our `word2def.py` file looks like this:

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  def generator(self, data_dir, tmp_dir, train):
    character_vocab = text_encoder.ByteTextEncoder()
    datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
    return character_generator(datasets[0], datasets[1], character_vocab, EOS)

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```

## Data
Now we need to tell Tensor2Tensor where our data is located.

I've gone ahead and split all words into a train and test set and saved them in files called `words.train.txt`, `words.test.txt`,
`definitions.train.txt`, and `definitions.test.txt` in a directory called `LOCATION_OF_DATA/`. Let's tell T2T where these files are:

```python
# English Word2def datasets.
# LOCATION_OF_DATA is the path to the directory above, as a string.
_WORD2DEF_TRAIN_DATASETS = [
    LOCATION_OF_DATA + "words.train.txt",
    LOCATION_OF_DATA + "definitions.train.txt",
]
_WORD2DEF_TEST_DATASETS = [
    LOCATION_OF_DATA + "words.test.txt",
    LOCATION_OF_DATA + "definitions.test.txt",
]
```

## Putting it all together

With the correct imports, our complete `word2def.py` file now looks like this:
```python
""" Problem definition for word to dictionary definition.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile # do we need this import

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators.wmt import character_generator

from tensor2tensor.utils import registry

import tensorflow as tf

FLAGS = tf.flags.FLAGS

# English Word2def datasets
_WORD2DEF_TRAIN_DATASETS = [
LOCATION_OF_DATA+'words_train.txt',
LOCATION_OF_DATA+'definitions_train.txt'
]

_WORD2DEF_TEST_DATASETS = [
LOCATION_OF_DATA+'words_test.txt',
LOCATION_OF_DATA+'definitions_test.txt'
]

@registry.register_problem()
class Word2def(problem.Text2TextProblem):
"""Problem spec for English word to dictionary definition."""
@property
def is_character_level(self):
return True

@property
def vocab_name(self):
return "vocab.word2def.en"

def generator(self, data_dir, tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
tag = "train" if train else "dev"
return character_generator(datasets[0], datasets[1], character_vocab, EOS)

@property
def input_space_id(self):
return problem.SpaceID.EN_CHR

@property
def target_space_id(self):
return problem.SpaceID.EN_CHR

@property
def num_shards(self):
return 100

@property
def use_subword_tokenizer(self):
return False

```

# Hyperparameters
All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize your hyperparameters, add another method to the file `problem_hparams.py`.
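
Alternatively, here is a sketch of registering a fresh hparams set through the registry, assuming you start from the Transformer's base single-GPU config:

```python
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_word2def():
  """Sketch: start from the base single-GPU config and tweak it."""
  hparams = transformer.transformer_base_single_gpu()
  hparams.batch_size = 1024  # A hypothetical tweak; pick what fits your setup.
  return hparams
```

You could then pass `--hparams_set=transformer_word2def` to the trainer.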

# Run the problem
Now that we've gotten our problem set up, let's train a model and generate definitions.

We specify our problem name, the model, and hparams.
```bash
PROBLEM=word2def
MODEL=transformer
HPARAMS=transformer_base_single_gpu
```

The rest of the steps are as given in the [walkthrough](walkthrough.md).
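
Concretely, mirroring the one-command invocation from the README (the data and output directories here are just examples):

```bash
t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=~/t2t_train/word2def
```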


What if we wanted to train a model to generate words given definitions? In T2T, we can change the problem name to be `PROBLEM=word2def_rev`.

All done. Let us know what definitions your model generated.