Skip to content

Commit

Permalink
Merge pull request #2 from frankwhzhang/master
Browse files Browse the repository at this point in the history
fix gru4rec demo
  • Loading branch information
frankwhzhang committed Oct 8, 2019
2 parents 649894b + 1af51cc commit c9d4647
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
3 changes: 2 additions & 1 deletion paddle_fl/core/trainer/fl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ def run(self, feed, fetch):
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
self._logger.debug("begin to run current step")
self.exe.run(self._main_program,
loss = self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
self.cur_step += 1
return loss

def stop(self):
return False
Expand Down
11 changes: 5 additions & 6 deletions paddle_fl/examples/gru4rec_demo/fl_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self):

def gru4rec_network(self,
vocab_size=37483,
hid_size=10,
hid_size=100,
init_low_bound=-0.04,
init_high_bound=0.04):
""" network definition """
Expand All @@ -29,7 +29,6 @@ def gru4rec_network(self,
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
#is_distributed=True,
is_sparse=False)
fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3,
Expand All @@ -54,7 +53,7 @@ def gru4rec_network(self,
learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(
input=self.fc, label=self.dst_wordseq)
acc = fluid.layers.accuracy(
self.acc = fluid.layers.accuracy(
input=self.fc, label=self.dst_wordseq, k=20)
self.loss = fluid.layers.mean(x=cost)
self.startup_program = fluid.default_startup_program()
Expand All @@ -70,17 +69,17 @@ def gru4rec_network(self,
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
[model.src_wordseq.name, model.dst_wordseq.name], [model.fc.name])
[model.src_wordseq.name, model.dst_wordseq.name], [model.loss.name, model.acc.name])

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
build_strategy.inner_step = 10
build_strategy.inner_step = 1
strategy = build_strategy.create_fl_strategy()

# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
strategy, server_endpoints=endpoints, worker_num=4, output=output)
# fl_job_config will be dispatched to workers
20 changes: 15 additions & 5 deletions paddle_fl/examples/gru4rec_demo/fl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,31 @@

trainer_id = int(sys.argv[1]) # trainer id for each guest
place = fluid.CPUPlace()
train_file_dir = "mid_data/node1/0/"
train_file_dir = "mid_data/node4/%d/" % trainer_id
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.start()

r = Gru4rec_Reader()
train_reader = r.reader(train_file_dir, place)
train_reader = r.reader(train_file_dir, place, batch_size = 125)

output_folder = "model_node4"
step_i = 0
while not trainer.stop():
step_i += 1
print("batch %d start train" % (step_i))
for data in train_reader():
print(data)
trainer.run(feed=data,
fetch=[])
#print(np.array(data['src_wordseq']))
ret_avg_cost = trainer.run(feed=data,
fetch=["mean_0.tmp_0"])
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
print("ppl:%.3f" % (newest_ppl))
save_dir = (output_folder + "/epoch_%d") % step_i
if trainer_id == 0:
print("start save")
trainer.save_inference_program(save_dir)
if step_i >= 40:
break
4 changes: 4 additions & 0 deletions paddle_fl/examples/gru4rec_demo/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ sleep 2
python -u fl_trainer.py 0 >trainer0.log &
sleep 2
python -u fl_trainer.py 1 >trainer1.log &
sleep 2
python -u fl_trainer.py 2 >trainer2.log &
sleep 2
python -u fl_trainer.py 3 >trainer3.log &

0 comments on commit c9d4647

Please sign in to comment.