In [1]:
from google.cloud import storage

from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry
from tensor2tensor.utils import metrics

import tensorflow as tf

import numpy as np

import os

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])





  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.






In [2]:
# Helper for downloading a file from GCS (https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Download a single blob from a GCS bucket to a local file.

    Args:
        bucket_name: Name of the bucket to read from.
        source_blob_name: Name (path) of the blob inside the bucket.
        destination_file_name: Local path the blob is written to. Missing
            parent directories are created.
    """
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(source_blob_name)

    # Make sure the destination directory exists so the download also works
    # on a fresh machine; a bare file name has no parent directory to create.
    parent_dir = os.path.dirname(destination_file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    blob.download_to_filename(destination_file_name)

    print('Blob {} downloaded to {}.'.format(
        source_blob_name,
        destination_file_name))

In [4]:
# Based on the GCS sample for listing objects:
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
    """Return the names of blobs under `prefix` whose name contains `search_path`.

    Args:
        bucket_name: Name of the bucket to list.
        prefix: Prefix passed to the listing call to narrow the blobs returned.
        search_path: Substring a blob name must contain to be kept.

    Returns:
        List of matching blob names, in listing order.
    """
    client = storage.Client()

    # Note: Client.list_blobs requires at least package version 1.17.0.
    matching_names = []
    for candidate in client.list_blobs(bucket_name, prefix=prefix, delimiter=None):
        if search_path in candidate.name:
            matching_names.append(candidate.name)

    return matching_names

In [11]:
# Run configuration, read from environment variables. A missing variable
# raises KeyError here.
# BUDGET_NAME is passed as the `bucket_name` argument of the GCS helpers.
# NOTE(review): the name looks like a typo of "BUCKET_NAME" — confirm before
# renaming, since the environment variable name would have to change too.
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']  # problem name given to problems.problem / create_hparams
DATA_DIR = os.environ['DATA_DIR']  # used both as GCS prefix for the vocab file and as T2T data_dir
TRAIN_DIR = os.environ['TRAIN_DIR']  # directory holding the `checkpoint` file and checkpoint data
HPARAMS = os.environ['HPARAMS']  # hparams set name given to trainer_lib.create_hparams
MODEL = os.environ['MODEL']  # model name looked up via registry.model

In [12]:
# The `checkpoint` file uses the same relative path inside the bucket
# (source) and on the local disk (destination).
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')

In [13]:
# Download the `checkpoint` file
download_blob(BUDGET_NAME, src_file_name, dist_file_name)

Blob training/transformer_ende/checkpoint downloaded to training/transformer_ende/checkpoint.


In [14]:
import re

# The latest checkpoint name is read from the first line of the downloaded
# `checkpoint` file, which is expected to look like
# `model_checkpoint_path: "<name>"`.
with open(dist_file_name) as f:
    first_line = f.readline()

ckpt_match = re.search('model_checkpoint_path: "(.*?)"', first_line)
if ckpt_match is None:
    # Fail with an explicit message instead of an opaque IndexError when the
    # file is empty or does not start with the expected entry.
    raise ValueError(
        'No model_checkpoint_path entry found on the first line of {}'.format(
            dist_file_name))

ckpt_name = ckpt_match.group(1)
ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)

In [86]:
# Names of all blobs under TRAIN_DIR whose name contains ckpt_path.
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)

In [87]:
# Download the full set of checkpoint files, reusing each blob name as the
# local destination path.
for ckpt_file in ckpt_file_list:
    download_blob(BUDGET_NAME, ckpt_file, ckpt_file)

Blob training/transformer_ende/model.ckpt-40000.data-00000-of-00001 downloaded to training/transformer_ende/model.ckpt-40000.data-00000-of-00001.
Blob training/transformer_ende/model.ckpt-40000.index downloaded to training/transformer_ende/model.ckpt-40000.index.
Blob training/transformer_ende/model.ckpt-40000.meta downloaded to training/transformer_ende/model.ckpt-40000.meta.


In [15]:
# First blob under DATA_DIR whose name contains "<DATA_DIR>/vocab".
# Raises IndexError if no blob matches.
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]

In [92]:
# Download the vocabulary file, reusing the blob name as the local path.
download_blob(BUDGET_NAME, vocab_file, vocab_file)

Blob transformer/vocab.translate_jpen.8192.subwords downloaded to transformer/vocab.translate_jpen.8192.subwords.


In [16]:
# Turn on eager execution through the tf.contrib.eager API.
# NOTE(review): the import-time log says tf.contrib will not be included in
# TensorFlow 2.0, so this presumably requires TF 1.x — confirm the pinned version.
tfe = tf.contrib.eager
tfe.enable_eager_execution()
# Short alias for the estimator mode keys (Modes.PREDICT is used when building the model).
Modes = tf.estimator.ModeKeys

In [18]:
import pickle

import numpy as np

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

# The class name must be identical to the PROBLEM class defined for
# preprocessing and training.
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
    """Text-to-text problem registered with the tensor2tensor registry."""

    @property
    def approx_vocab_size(self):
        """Approximate vocabulary size: 2**13 (8192)."""
        return 2**13

In [19]:
# Look up the problem named by PROBLEM.
# NOTE(review): the variable is called `enfr_problem`, but the registered
# problem class in this notebook is Translate_JPEN — the name may be a leftover.
enfr_problem = problems.problem(PROBLEM)

In [20]:
# Get the encoders from the problem; `encoders["inputs"]` is used by
# encode() and decode().
encoders = enfr_problem.feature_encoders(DATA_DIR)





In [21]:
from functools import wraps
import time

def stop_watch(func):
    """Decorator that reports the start, elapsed time and completion of a call.

    Args:
        func: Callable to wrap.

    Returns:
        A wrapper with the same name/metadata as `func` that prints
        "<name> started ...", "elapsed_time:<seconds>" and "<name> completed"
        around the call and returns the wrapped function's result.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is a monotonic clock, so the measured duration is not
        # distorted by system clock adjustments the way time.time() can be.
        start = time.perf_counter()
        print(f'{func.__name__} started ...')
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start
        print(f'elapsed_time:{elapsed_time}')
        print(f'{func.__name__} completed')
        return result
    return wrapper

In [22]:
@stop_watch
def translate(inputs):
    """Translate one input string with the loaded model.

    The string is turned into features by `encode`, run through
    `translate_model.infer` while variables are restored from `ckpt_path`,
    and the "outputs" entry of the result is converted back to text by
    `decode`.
    """
    features = encode(inputs)
    with tfe.restore_variables_on_create(ckpt_path):
        inference = translate_model.infer(features=features)
        output_ids = inference["outputs"]
    return decode(output_ids)

def encode(input_str, output_str=None):
    """Build the inference features dict for one input string.

    Args:
        input_str: Text to encode with `encoders["inputs"]`.
        output_str: Unused; kept for signature compatibility.

    Returns:
        Dict with key "inputs" holding the encoded ids, followed by the EOS
        id, reshaped to [1, -1, 1].
    """
    eos_id = 1
    token_ids = encoders["inputs"].encode(input_str)
    token_ids = token_ids + [eos_id]
    # The ids are reshaped to a 3-D tensor of shape [1, -1, 1].
    return {"inputs": tf.reshape(token_ids, [1, -1, 1])}

def decode(integers):
    """Convert model output token ids into a string.

    Args:
        integers: Token ids (array-like; singleton dimensions are removed).
            Everything from the first id `1` (the EOS id appended by
            `encode`) onward is dropped.

    Returns:
        The text produced by `encoders["inputs"].decode` for the kept ids.
    """
    # np.squeeze collapses a single-token output to a 0-d array, which cannot
    # be iterated; atleast_1d keeps it a sequence.
    integers = list(np.atleast_1d(np.squeeze(integers)))
    if 1 in integers:
        integers = integers[:integers.index(1)]
    # Hand over a 1-D array rather than squeezing again, so a single
    # remaining id still reaches the decoder as a sequence.
    return encoders["inputs"].decode(np.asarray(integers))

In [24]:
# Build the hyperparameters for the configured hparams set / problem and
# instantiate the registered model in PREDICT mode.
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)


INFO:tensorflow:Setting T2TModel mode to 'infer'
INFO:tensorflow:Setting hparams.dropout to 0.0
INFO:tensorflow:Setting hparams.label_smoothing to 0.0
INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0
INFO:tensorflow:Setting hparams.symbol_dropout to 0.0
INFO:tensorflow:Setting hparams.attention_dropout to 0.0
INFO:tensorflow:Setting hparams.relu_dropout to 0.0




In [26]:
# Smoke test: translate one sample sentence and print the result.
inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)

translate started ...
:::MLPv0.5.0 transformer 1580438305.780114651 (/usr/local/lib/python3.7/site-packages/tensor2tensor/utils/expert_utils.py:231) model_hp_layer_postprocess_dropout: 0.0
:::MLPv0.5.0 transformer 1580438305.916484118 (/usr/local/lib/python3.7/site-packages/tensor2tensor/models/transformer.py:101) model_hp_hidden_layers: 6
:::MLPv0.5.0 transformer 1580438306.014828682 (/usr/local/lib/python3.7/site-packages/tensor2tensor/models/transformer.py:101) model_hp_attention_num_heads: 8
:::MLPv0.5.0 transformer 1580438306.103340864 (/usr/local/lib/python3.7/site-packages/tensor2tensor/models/transformer.py:101) model_hp_attention_dropout: 0.0
:::MLPv0.5.0 transformer 1580438306.295593739 (/usr/local/lib/python3.7/site-packages/tensor2tensor/layers/transformer_layers.py:182) model_hp_ffn_filter: {"filter_size": 2048, "use_bias": "True", "activation": "relu"}
:::MLPv0.5.0 transformer 1580438306.394455194 (/usr/local/lib/python3.7/site-packages/tensor2tensor/layers/transformer_la