run transformer dist to converge
panyx0718 committed Jun 20, 2018
1 parent 280c9e2 commit 936c80d
Showing 3 changed files with 34 additions and 1 deletion.
23 changes: 23 additions & 0 deletions run.sh
@@ -0,0 +1,23 @@
export PADDLE_PSERVERS=127.0.0.1
export POD_IP=127.0.0.1
export PADDLE_TRAINERS_NUM=2
export PADDLE_IS_LOCAL=0

export PADDLE_TRAINER_ID=0
export TRAINING_ROLE=PSERVER
export PADDLE_PORT=6177
sh run_transform.sh &> log/pserver_4.log &

sleep 60

export CUDA_VISIBLE_DEVICES=0,1
export TRAINING_ROLE=TRAINER
export PADDLE_TRAINER_ID=0
export PADDLE_PORT=6177
sh run_transform.sh &> log/trainer1.log &

export CUDA_VISIBLE_DEVICES=4,5
export TRAINING_ROLE=TRAINER
export PADDLE_TRAINER_ID=1
export PADDLE_PORT=6177
sh run_transform.sh
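
Note on the launch layout: run.sh starts one local parameter server on port 6177, waits 60 seconds, then launches two trainers pinned to GPU pairs 0,1 and 4,5; the roles are communicated to run_transform.sh purely through environment variables, and output is redirected into a log/ directory that must already exist. As a hypothetical sketch (this helper is illustrative, not the actual train.py logic), a Fluid-era entry point could recover its role from those variables like so:

import os

# Hypothetical helper; the variable names match the exports in run.sh above.
def distributed_role():
    role = os.environ.get("TRAINING_ROLE", "TRAINER")    # PSERVER or TRAINER
    trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", "0"))
    trainer_num = int(os.environ.get("PADDLE_TRAINERS_NUM", "1"))
    port = os.environ.get("PADDLE_PORT", "6177")
    pserver_ips = os.environ.get("PADDLE_PSERVERS", "127.0.0.1")
    # Build one host:port endpoint per parameter-server IP.
    endpoints = ",".join("%s:%s" % (ip, port) for ip in pserver_ips.split(","))
    return role, trainer_id, trainer_num, endpoints

print(distributed_role())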
3 changes: 3 additions & 0 deletions run_transform.sh
@@ -0,0 +1,3 @@
#!/bin/bash

python transformer_cloud/train.py --src_vocab_fpath nist06n/cn_30001.dict --trg_vocab_fpath nist06n/en_30001.dict --train_file_pattern 'nist06n/data/part-00' --batch_size 256 --use_token_batch True --special_token '_GO' '_EOS' '_UNK'
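
The command line hints at the shape of train.py's parser: --use_token_batch takes a Python literal True/False, and --special_token takes exactly three marks (begin, end, unknown, matching start_mark/end_mark/unk_mark in the diff below). A minimal sketch of a compatible parser, assuming ast.literal_eval for booleans as suggested by the import ast in train.py:

import argparse
import ast

# Sketch only: flag names come from the command above; defaults are guesses.
parser = argparse.ArgumentParser()
parser.add_argument("--src_vocab_fpath", type=str)
parser.add_argument("--trg_vocab_fpath", type=str)
parser.add_argument("--train_file_pattern", type=str)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--use_token_batch", type=ast.literal_eval, default=False)
parser.add_argument("--special_token", type=str, nargs=3,
                    default=["_GO", "_EOS", "_UNK"])
args = parser.parse_args()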
9 changes: 8 additions & 1 deletion transformer_cloud/train.py
@@ -1,4 +1,5 @@
import os
import sys
import time
import argparse
import ast
@@ -301,7 +302,7 @@ def split_data(data, num_part=dev_count):
else:
print "init fluid.framework.default_startup_program"
exe.run(fluid.framework.default_startup_program())

sys.stderr.write('1111\n')
train_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
@@ -316,6 +317,7 @@ def split_data(data, num_part=dev_count):
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False)
sys.stderr.write('1112222\n')

train_data = read_multiple(reader=train_data.batch_generator)
build_strategy = fluid.BuildStrategy()
@@ -384,16 +386,20 @@ def test(exe=test_exe):

return test

sys.stderr.write('2222\n')
if args.val_file_pattern is not None:
test = test_context()

data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
init = False
sys.stderr.write('start training!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n')
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
sys.stderr.write('pass %d\n' % pass_id)
for batch_id, data in enumerate(train_data()):
sys.stderr.write('batch %d\n' % batch_id)
feed_list = []
total_num_token = 0
#lr_rate = lr_scheduler.update_learning_rate()
@@ -418,6 +424,7 @@ def test(exe=test_exe):
"@GRAD"] = 1. / total_num_token if TrainTaskConfig.use_avg_cost else np.asarray(
[1.], dtype="float32")
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name], feed=feed_list)
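# bcast_params() (old Fluid ParallelExecutor API) appears to re-broadcast
# the parameters so the per-device copies stay in sync after the update.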
train_exe.bcast_params()
#outs = exe.run(train_progm,fetch_list=[sum_cost.name, token_num.name],feed=feed_list[0])
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
total_sum_cost = sum_cost_val.sum(
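
The last hunk is cut off mid-statement by the collapsed diff view; from the visible names, the loop aggregates the per-device fetches into running totals. An illustrative version of that aggregation, with made-up values:

import numpy as np

# Made-up per-device fetches standing in for outs[0] and outs[1] above.
sum_cost_val = np.array([512.3, 498.7])   # summed cost per device
token_num_val = np.array([2048., 1992.])  # token count per device
total_sum_cost = sum_cost_val.sum()
total_num_token = token_num_val.sum()
print("avg cost per token: %.4f" % (total_sum_cost / total_num_token))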
