Refactor model evaluation
tkipf committed May 12, 2019
1 parent 246fcac commit b88b174
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions modules.py
@@ -62,7 +62,7 @@ def masked_encode(self, inputs, mask):
             outputs.append(hidden[0])
         return torch.stack(outputs, dim=1)

-    def get_boundaries(self, encodings, segment_id, lengths, evaluate=False):
+    def get_boundaries(self, encodings, segment_id, lengths):
         """Get boundaries (b) for a single segment in batch."""
         if segment_id == self.max_num_segments - 1:
             # Last boundary is always placed on last sequence element.
@@ -77,7 +77,7 @@ def get_boundaries(self, encodings, segment_id, lengths, evaluate=False):
                 encodings.size(0), 1, device=encodings.device) * utils.NEG_INF
             # TODO(tkipf): Mask out padded positions with large neg. value.
             logits_b = torch.cat([neg_inf, logits_b[:, 1:]], dim=1)
-            if not evaluate:
+            if self.training:
                 sample_b = utils.gumbel_softmax_sample(
                     logits_b, temp=self.temp_b)
             else:
@@ -86,7 +86,7 @@ def get_boundaries(self, encodings, segment_id, lengths, evaluate=False):

         return logits_b, sample_b

-    def get_latents(self, encodings, probs_b, evaluate=False):
+    def get_latents(self, encodings, probs_b):
         """Read out latents (z) form input encodings for a single segment."""
         readout_mask = probs_b[:, 1:, None]  # Offset readout by 1 to left.
         readout = (encodings[:, :-1] * readout_mask).sum(1)
@@ -95,15 +95,15 @@ def get_latents(self, encodings, probs_b, evaluate=False):

         # Gaussian latents.
         if self.latent_dist == 'gaussian':
-            if not evaluate:
+            if self.training:
                 mu, log_var = torch.split(logits_z, self.latent_dim, dim=1)
                 sample_z = utils.gaussian_sample(mu, log_var)
             else:
                 sample_z = logits_z[:, :self.latent_dim]

         # Concrete / Gumbel softmax latents.
         elif self.latent_dist == 'concrete':
-            if not evaluate:
+            if self.training:
                 sample_z = utils.gumbel_softmax_sample(
                     logits_z, temp=self.temp_z)
             else:
@@ -131,7 +131,7 @@ def get_next_masks(self, all_b_samples):
         else:
             return None

-    def forward(self, inputs, lengths, evaluate=False):
+    def forward(self, inputs, lengths):

         # Embed inputs.
         embeddings = self.embed(inputs)
@@ -153,13 +153,13 @@

             # Get boundaries (b) for current segment.
             logits_b, sample_b = self.get_boundaries(
-                encodings, seg_id, lengths, evaluate)
+                encodings, seg_id, lengths)
             all_b['logits'].append(logits_b)
             all_b['samples'].append(sample_b)

             # Get latents (z) for current segment.
             logits_z, sample_z = self.get_latents(
-                encodings, sample_b, evaluate)
+                encodings, sample_b)
             all_z['logits'].append(logits_z)
             all_z['samples'].append(sample_z)

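Every hunk above makes the same substitution: the hand-threaded evaluate=False argument is dropped in favour of PyTorch's built-in self.training flag, which model.train() and model.eval() set recursively on every submodule. A minimal sketch of that mechanism, assuming a toy module (BoundaryHead, its single linear head, and the plain F.gumbel_softmax used as a stand-in for utils.gumbel_softmax_sample are illustrative, not part of modules.py):

import torch
import torch.nn.functional as F
from torch import nn


class BoundaryHead(nn.Module):
    """Toy head: sample a relaxed one-hot in training mode, take the argmax in eval mode."""

    def __init__(self, dim):
        super().__init__()
        self.head = nn.Linear(dim, dim)

    def forward(self, encodings):
        logits = self.head(encodings)
        if self.training:  # toggled by .train() / .eval(); no extra keyword argument needed
            # Stochastic relaxation during training (stand-in for utils.gumbel_softmax_sample).
            return F.gumbel_softmax(logits, tau=1.0, dim=-1)
        # Deterministic argmax -> one-hot at evaluation time.
        return F.one_hot(logits.argmax(dim=-1), logits.size(-1)).float()


head = BoundaryHead(dim=8)
x = torch.randn(4, 8)
head.train()   # self.training is True  -> stochastic sample
y_train = head(x)
head.eval()    # self.training is False -> deterministic one-hot
y_eval = head(x)

Relying on the flag keeps the method signatures identical in training and evaluation, so callers switch modes once instead of threading a keyword through get_boundaries, get_latents, and forward.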
2 changes: 1 addition & 1 deletion train.py
@@ -74,7 +74,7 @@
     if step % args.log_interval == 0:
         # Run evaluation.
         model.eval()
-        outputs = model.forward(inputs, lengths, evaluate=True)
+        outputs = model.forward(inputs, lengths)
         acc, rec = utils.get_reconstruction_accuracy(inputs, outputs, args)

         # Accumulate metrics.
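On the training-loop side, calling model.eval() before the forward pass is now what switches the model into its deterministic branches, so the extra evaluate=True argument is redundant. A self-contained sketch of the pattern, assuming a toy model (the nn.Sequential with Dropout, the torch.no_grad() context, and the final model.train() call are illustrative assumptions, not taken from train.py):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))  # Dropout is one module whose behaviour depends on .training
inputs = torch.randn(4, 8)

model.eval()            # sets .training = False on every submodule, which the refactored forward relies on
with torch.no_grad():   # assumption: evaluation usually also skips autograd bookkeeping
    outputs = model(inputs)
model.train()           # assumption: the training loop switches back before the next gradient step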
