
Commit 100d52c

Merge pull request #7 from TissueC/master
modifying models for cotk version_up
aaa123git committed Jun 19, 2020
2 parents 88b8080 + 24b123b commit 100d52c
Showing 5 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -11,7 +11,7 @@ install:
- pip install -r requirements.txt
- pip install pytest --upgrade
- pip install pytest-dependency pytest-mock requests-mock pytest>=3.6.0 "coverage<5.0" pytest-cov==2.4.0 python-coveralls
- - cd .. && git clone https://github.com/thu-coai/cotk.git && pip install ./cotk --progress-bar off && cd CVAE-tensorflow
+ - cd .. && git clone -b version_up https://github.com/thu-coai/cotk.git && pip install ./cotk --progress-bar off && cd CVAE-tensorflow

script:
- pytest test_CVAE_tensorflow.py --cov=. --cov-report term-missing
9 changes: 4 additions & 5 deletions main.py
@@ -50,7 +50,7 @@ def main(args):
wordvec_class = Glove
if args.cache:
data = try_cache(data_class, (args.datapath,), args.cache_dir)
- vocab = data.vocab_list
+ vocab = data.frequent_vocab_list
embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
(args.wvpath, args.word_embedding_size, vocab),
args.cache_dir, wordvec_class.__name__)
@@ -59,16 +59,15 @@ def main(args):
args.cache_dir, wordvec_class.__name__)
else:
data = data_class(args.datapath,
- min_vocab_times=args.min_vocab_times,
+ min_frequent_vocab_times=args.min_frequent_vocab_times,
max_sent_length=args.max_sent_length,
max_turn_length=args.max_turn_length)
wv = wordvec_class(args.wvpath)
- vocab = data.vocab_list
+ vocab = data.frequent_vocab_list #dim:9508
embed = wv.load_matrix(args.word_embedding_size, vocab)
word2vec = wv.load_dict(vocab)

embed = np.array(embed, dtype = np.float32)

with tf.Session(config=config) as sess:
model = create_model(sess, data, args, embed)
if args.mode == "train":
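
Not part of the diff: a minimal sketch of how the renamed cotk dataloader fields in main.py are used under the version_up API. The SwitchboardCorpus import path and the datapath below are assumptions drawn from this repository; only the attribute and keyword names come from the change itself.

from cotk.dataloader import SwitchboardCorpus   # assumed import path

# Build the loader with the renamed keyword argument.
data = SwitchboardCorpus("./data/switchboard",      # placeholder datapath
        min_frequent_vocab_times=5,                 # was: min_vocab_times
        max_sent_length=50,
        max_turn_length=1000)

# The frequent-vocabulary attributes replace the old vocab_list / vocab_size pair.
vocab = data.frequent_vocab_list                    # was: data.vocab_list
print(len(vocab), data.frequent_vocab_size)         # was: data.vocab_size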
32 changes: 16 additions & 16 deletions model.py
@@ -4,7 +4,7 @@

from utils import SummaryHelper
from utils.basic_decoder import MyBasicDecoder

+ from cotk._utils import trim_before_target

class CVAEModel(object):
def __init__(self, data, args, embed):
@@ -39,7 +39,7 @@ def __init__(self, data, args, embed):
# build the embedding table and embedding input
if embed is None:
# initialize the embedding randomly
- self.word_embed = tf.get_variable('word_embed', [data.vocab_size, args.word_embedding_size], tf.float32)
+ self.word_embed = tf.get_variable('word_embed', [data.frequent_vocab_size, args.word_embedding_size], tf.float32)
else:
# initialize the embedding by pre-trained word vectors
self.word_embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
@@ -95,7 +95,7 @@ def __init__(self, data, args, embed):
with tf.name_scope("decode"):
# get output projection function
dec_init_fn = tf.layers.Dense(args.dh_size, use_bias=True)
- output_fn = tf.layers.Dense(data.vocab_size, use_bias=True)
+ output_fn = tf.layers.Dense(data.frequent_vocab_size, use_bias=True)

with tf.name_scope("training"):
decoder_input = responses_dec_input
@@ -120,7 +120,7 @@ def __init__(self, data, args, embed):
self.anneal_KL_loss = self.KL_weight * self.KL_loss

bow_logits = tf.layers.dense(tf.layers.dense(dec_init_fn_input, 400, activation=tf.tanh),\
- data.vocab_size)
+ data.frequent_vocab_size)
tile_bow_logits = tf.tile(tf.expand_dims(bow_logits, 1), [1, decoder_len, 1])
bow_loss = self.decoder_mask * tf.nn.sparse_softmax_cross_entropy_with_logits(\
logits=tile_bow_logits,\
@@ -145,7 +145,7 @@ def __init__(self, data, args, embed):
scope="decoder_rnn")
self.decoder_distribution = infer_outputs.rnn_output
self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
- [2, data.vocab_size - 2], 2)[1], 2) + 2 # for removing UNK
+ [2, data.frequent_vocab_size - 2], 2)[1], 2) + 2 # for removing UNK

# calculate the gradient of parameters and update
self.params = [k for k in tf.trainable_variables() if args.name in k.name]
@@ -223,7 +223,7 @@ def _pad_batch(self, raw_batch):
for i, speaker in enumerate(raw_batch['posts_length']):
batch['contexts_length'].append(len(raw_batch['posts_length'][i]))

- if raw_batch['posts_length'][i].size > 0:
+ if len(raw_batch['posts_length'][i]) > 0:
max_post_len = max(max_post_len, max(raw_batch['posts_length'][i]))
batch['posts_length'].append(np.concatenate([raw_batch['posts_length'][i],
np.array([0] * (max_cxt_size - len(raw_batch['posts_length'][i])))], 0))
@@ -237,18 +237,18 @@ def _cut_batch_data(self, batch_data, start, end):
invoked by ^SwitchboardCorpus.split_session^
'''
raw_batch = {'posts_length': [], 'responses_length': []}
- for i in range(len(batch_data['turn_length'])):
+ for i in range(len(batch_data['session_turn_length'])):
raw_batch['posts_length'].append( \
- batch_data['sent_length'][i][start: end - 1])
- turn_len = len(batch_data['sent_length'][i])
+ batch_data['session_sent_length'][i][start: end - 1])
+ turn_len = len(batch_data['session_sent_length'][i])
if end - 1 < turn_len:
raw_batch['responses_length'].append( \
- batch_data['sent_length'][i][end - 1])
+ batch_data['session_sent_length'][i][end - 1])
else:
raw_batch['responses_length'].append(1)

- raw_batch['contexts'] = batch_data['sent'][:, start: end - 1]
- raw_batch['responses'] = batch_data['sent'][:, end - 1]
+ raw_batch['contexts'] = batch_data['session'][:, start: end - 1]
+ raw_batch['responses'] = batch_data['session'][:, end - 1]
return self._pad_batch(raw_batch)

def split_session(self, batch_data, session_window, inference=False):
@@ -275,7 +275,7 @@ def split_session(self, batch_data, session_window, inference=False):
Size: ^[batch_size]^
'''
- max_turn = np.max(batch_data['turn_length'])
+ max_turn = np.max(batch_data['session_turn_length'])
ends = list(range(2, max_turn + 1))
if not inference:
np.random.shuffle(ends)
@@ -301,7 +301,7 @@ def multi_reference_batches(self, data, batch_size):
batch_data = data.get_next_batch('multi_ref')
while batch_data is not None:
batch = self._cut_batch_data(batch_data,\
- 0, np.max(batch_data['turn_length']))
+ 0, np.max(batch_data['session_turn_length']))
batch['candidate_allvocabs'] = batch_data['candidate_allvocabs']
yield batch
batch_data = data.get_next_batch('multi_ref')
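
The key renames above reflect the batch fields returned by the version_up dataloader: 'session', 'session_turn_length' and 'session_sent_length' replace the old 'sent', 'turn_length' and 'sent_length'. A hedged sketch (not part of the diff) of reading a multi_ref batch with the new keys; the restart call and its arguments are assumptions, while the key names come from this change.

# data: a cotk SwitchboardCorpus-style dataloader, as in the sketch after the main.py diff above
data.restart('multi_ref', batch_size=3, shuffle=False)    # assumed signature
batch_data = data.get_next_batch('multi_ref')
while batch_data is not None:
    sessions = batch_data['session']                      # was: batch_data['sent']
    turn_lens = batch_data['session_turn_length']         # was: batch_data['turn_length']
    sent_lens = batch_data['session_sent_length']         # was: batch_data['sent_length']
    batch_data = data.get_next_batch('multi_ref')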
@@ -437,7 +437,7 @@ def padding(matrix, pad_go_id=False):
cnt = 0
start_time = time.time()
while batched_data != None:
- conv_data = [{'contexts': [], 'responses': [], 'generations': []} for _ in range(len(batched_data['turn_length']))]
+ conv_data = [{'contexts': [], 'responses': [], 'generations': []} for _ in range(len(batched_data['session_turn_length']))]
for cut_batch_data in self.split_session(batched_data, args.session_window, inference=True):
eval_out = self.step_decoder(sess, cut_batch_data, forward_only=True)
decoder_loss, gen_prob = eval_out[:6], eval_out[-1]
@@ -513,7 +513,7 @@ def test_multi_ref(self, sess, data, word2vec, args):
responses.append([])
# if data.eos_id in resp:
# resp = resp[:resp.index(data.eos_id)]
- resp = data.trim(resp)
+ resp = trim_before_target(resp,data.eos_id)
if len(resp) == 0:
resp = [data.unk_id]
responses[rid].append(resp + [data.eos_id])
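
Here data.trim is replaced by cotk._utils.trim_before_target which, judging from the commented-out lines above, truncates the generated ids at the first eos token. A pure-Python stand-in (the helper name is hypothetical, not cotk API) illustrating that behavior:

def trim_before_eos(ids, eos_id):
    # keep only the tokens generated before the first eos_id
    return ids[:ids.index(eos_id)] if eos_id in ids else ids

assert trim_before_eos([45, 12, 3, 2, 0, 0], 2) == [45, 12, 3]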
2 changes: 1 addition & 1 deletion run.py
@@ -76,7 +76,7 @@ def run(*argv):
args.batch_size = 3
args.grad_clip = 5.0
args.show_sample = [0]
- args.min_vocab_times = 5
+ args.min_frequent_vocab_times = 5
args.max_sent_length = 50
args.max_turn_length = 1000
args.checkpoint_steps = 1
1 change: 1 addition & 0 deletions test_CVAE_tensorflow.py
@@ -48,6 +48,7 @@ def modify_args(args):
args.restore = None
args.wvclass = 'Glove'
args.wvpath = path + '/tests/wordvector/dummy_glove/300d'
+ args.word_embedding_size=300 #must be the same as the dim of wvpath
args.out_dir = cwd + '/output_test'
args.log_dir = cwd + '/tensorboard_test'
args.model_dir = cwd + '/model_test'
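
The added word_embedding_size must equal the dimensionality of the vectors under wvpath (a 300d dummy Glove in the tests). A minimal sketch, assuming cotk's Glove wordvector class as used in main.py; the import path, glove_dir and the placeholder vocab are assumptions:

from cotk.wordvector import Glove   # assumed import path

glove_dir = './tests/wordvector/dummy_glove/300d'     # placeholder, mirrors args.wvpath
vocab = ['<pad>', '<unk>', 'hello']                   # placeholder; normally data.frequent_vocab_list
wv = Glove(glove_dir)
embed = wv.load_matrix(300, vocab)                    # 300 must match the Glove dimension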
