.. _framework:

====================
RETURNN as Framework
====================

Install RETURNN via ``pip install returnn`` (`PyPI entry <https://pypi.org/project/returnn/>`__).
Then :code:`import returnn` should work.
See `demo-returnn-as-framework.py <https://github.com/rwth-i6/returnn/blob/master/demos/demo-returnn-as-framework.py>`__ as a full example.

At the highest level, you can write code like this:

.. code-block:: python

    from returnn.TFEngine import Engine
    from returnn.Dataset import init_dataset
    from returnn.Config import get_global_config

    config = get_global_config(auto_create=True)
    config.update(dict(
        # ... your settings, e.g. the network definition
    ))

    engine = Engine(config)

    train_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 1000, "name": "train"})
    dev_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})

    engine.init_train_from_config(train_data=train_data, dev_data=dev_data)
    engine.train()  # run the actual training loop
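
The datasets above are described by plain dicts: the ``"class"`` entry selects the dataset class by name, and, roughly speaking, the remaining entries become options for that class. The following toy sketch illustrates this dispatch pattern in plain Python. ``ToyDataset``, ``DATASET_REGISTRY`` and ``toy_init_dataset`` are hypothetical names made up for this illustration; this is not RETURNN's implementation of ``init_dataset``.

.. code-block:: python

    from typing import Any, Dict, Type


    class ToyDataset:
        """Hypothetical stand-in for a dataset class (not part of RETURNN)."""

        def __init__(self, name: str, num_seqs: int, fixed_random_seed: int = 0):
            self.name = name
            self.num_seqs = num_seqs
            self.fixed_random_seed = fixed_random_seed


    # Hypothetical registry mapping class names to dataset classes.
    DATASET_REGISTRY: Dict[str, Type[ToyDataset]] = {"ToyDataset": ToyDataset}


    def toy_init_dataset(opts: Dict[str, Any]) -> ToyDataset:
        """Hypothetical helper: build a dataset from a dict whose "class" entry selects the class."""
        opts = dict(opts)  # copy, so the caller's dict is left unchanged
        class_name = opts.pop("class")
        dataset_class = DATASET_REGISTRY[class_name]
        # All remaining entries become constructor keyword arguments.
        return dataset_class(**opts)


    train = toy_init_dataset({"class": "ToyDataset", "num_seqs": 1000, "name": "train"})
    dev = toy_init_dataset({"class": "ToyDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})
    print(train.name, train.num_seqs)  # train 1000
    print(dev.name, dev.fixed_random_seed)  # dev 1

Describing a dataset as plain data like this means the same description can also be written into a config file instead of being constructed in code.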

Or you can go lower level and construct the computation graph yourself:
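
.. code-block:: python

    import tensorflow as tf

    from returnn.Config import get_global_config
    from returnn.TFNetwork import TFNetwork

    config = get_global_config(auto_create=True)

    net = TFNetwork(config=config, train_flag=True)
    net.construct_from_dict({
        # ... your network definition (layer dicts)
    })

    fetches = net.get_fetches_dict()

    with tf.compat.v1.Session() as session:
        # initialize the network parameters before running the graph
        session.run(tf.compat.v1.global_variables_initializer())
        results = session.run(fetches, feed_dict={
            # ...
            # you could use FeedDictDataProvider
        })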

Or you can go even lower level and use only individual parts from ``TFUtil``, ``TFNativeOp``, etc.:

.. code-block:: python

    from returnn.TFNativeOp import ctc_loss
    from returnn.TFNativeOp import edit_distance
    from returnn.TFNativeOp import NativeLstm2
    from returnn.TFUtil import ctc_greedy_decode
    from returnn.TFUtil import get_available_gpu_min_compute_capability
    from returnn.TFUtil import safe_log
    from returnn.TFUtil import reuse_name_scope
    from returnn.TFUtil import dimshuffle
    # ...
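
``ctc_greedy_decode`` and ``edit_distance`` are typically combined to measure a label error rate: decode the per-frame network outputs, then count the edits needed to turn the hypothesis into the reference. The plain-Python sketch below only illustrates these two computations on small lists. It is not RETURNN's implementation (the RETURNN functions operate on TensorFlow tensors; see their docstrings for the exact signatures), and ``greedy_ctc_decode``, ``levenshtein`` and the explicit ``blank`` parameter are hypothetical names chosen here, since blank-index conventions differ between toolkits.

.. code-block:: python

    from typing import List, Sequence


    def greedy_ctc_decode(frame_scores: Sequence[Sequence[float]], blank: int) -> List[int]:
        """Hypothetical helper: best-path CTC decoding (argmax per frame, merge repeats, drop blanks)."""
        best_path = [max(range(len(scores)), key=lambda k: scores[k]) for scores in frame_scores]
        labels: List[int] = []
        prev = None
        for label in best_path:
            # Emit a label only if it differs from the previous frame and is not blank.
            if label != prev and label != blank:
                labels.append(label)
            prev = label
        return labels


    def levenshtein(hyp: Sequence[int], ref: Sequence[int]) -> int:
        """Hypothetical helper: minimum number of insertions, deletions and substitutions."""
        prev_row = list(range(len(ref) + 1))  # distances from the empty hypothesis
        for i, h in enumerate(hyp, start=1):
            row = [i]
            for j, r in enumerate(ref, start=1):
                row.append(min(
                    prev_row[j] + 1,              # delete h
                    row[j - 1] + 1,               # insert r
                    prev_row[j - 1] + (h != r),   # substitute (free if equal)
                ))
            prev_row = row
        return prev_row[-1]


    # Toy per-frame scores over 3 classes; class 2 is used as blank here.
    scores = [
        [0.8, 0.1, 0.1],  # 0
        [0.7, 0.2, 0.1],  # 0 (repeat, merged)
        [0.1, 0.1, 0.8],  # blank
        [0.6, 0.3, 0.1],  # 0 (new label after blank)
        [0.1, 0.9, 0.0],  # 1
    ]
    hyp = greedy_ctc_decode(scores, blank=2)
    print(hyp)  # [0, 0, 1]
    ref = [0, 1]
    errors = levenshtein(hyp, ref)
    print(errors, errors / len(ref))  # 1 0.5

Greedy (best-path) decoding picks the single most likely class per frame, which is cheap but only approximates the most probable label sequence; a prefix beam search usually gives better results at a higher cost.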