Commit: Refactor losses

tkipf committed May 12, 2019
1 parent 0a41143 commit 246fcac
Showing 3 changed files with 32 additions and 23 deletions.
modules.py (13 changes: 5 additions & 8 deletions)
@@ -15,15 +15,11 @@ class CompILE(nn.Module):
         max_num_segments: Maximum number of segments to predict.
         temp_b: Gumbel softmax temperature for boundary variables (b).
         temp_z: Temperature for latents (z), only if latent_dist='concrete'.
-        beta_b: Scaling factor for KL term of boundary variables (b).
-        beta_z: Scaling factor for KL term of latents (z).
-        prior_rate: Rate (lambda) for Poisson prior.
         latent_dist: Whether to use Gaussian latents ('gaussian') or concrete /
             Gumbel softmax latents ('concrete').
     """
     def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
-                 temp_b=1., temp_z=1., beta_b=.1, beta_z=.1, prior_rate=3.,
-                 latent_dist='gaussian'):
+                 temp_b=1., temp_z=1., latent_dist='gaussian'):
         super(CompILE, self).__init__()

         self.input_dim = input_dim
@@ -32,9 +28,6 @@ def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
         self.max_num_segments = max_num_segments
         self.temp_b = temp_b
         self.temp_z = temp_z
-        self.beta_b = beta_b
-        self.beta_z = beta_z
-        self.prior_rate = prior_rate
         self.latent_dist = latent_dist

         self.embed = nn.Embedding(input_dim, hidden_dim)
@@ -47,6 +40,8 @@ def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
             self.head_z_2 = nn.Linear(hidden_dim, latent_dim * 2)
         elif latent_dist == 'concrete':
             self.head_z_2 = nn.Linear(hidden_dim, latent_dim)
+        else:
+            raise ValueError('Invalid argument for `latent_dist`.')

         self.head_b_1 = nn.Linear(hidden_dim, hidden_dim)  # Boundaries (b).
         self.head_b_2 = nn.Linear(hidden_dim, 1)
@@ -114,6 +109,8 @@ def get_latents(self, encodings, probs_b, evaluate=False):
             else:
                 sample_z_idx = torch.argmax(logits_z, dim=1)
                 sample_z = utils.to_one_hot(sample_z_idx, logits_z.size(1))
+        else:
+            raise ValueError('Invalid argument for `latent_dist`.')

         return logits_z, sample_z

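Note: after this change the loss hyperparameters no longer live on the model. A minimal sketch of how the slimmed-down constructor might now be called (the concrete values below are illustrative, not taken from this commit):

    from modules import CompILE

    # beta_b, beta_z and prior_rate are no longer constructor arguments;
    # they are now passed to (or defaulted by) utils.get_losses instead.
    model = CompILE(
        input_dim=11,            # e.g. num_symbols + 1 to account for an EOS token
        hidden_dim=64,
        latent_dim=32,
        max_num_segments=3,
        temp_b=1.,
        temp_z=1.,
        latent_dist='gaussian')  # 'concrete' also works; anything else now raises ValueError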
train.py (4 changes: 2 additions & 2 deletions)
@@ -66,7 +66,7 @@
     # Run forward pass.
     model.train()
     outputs = model.forward(inputs, lengths)
-    loss, nll, kl_z, kl_b = utils.get_losses(model, inputs, outputs)
+    loss, nll, kl_z, kl_b = utils.get_losses(inputs, outputs, args)

     loss.backward()
     optimizer.step()
@@ -75,7 +75,7 @@
     # Run evaluation.
     model.eval()
     outputs = model.forward(inputs, lengths, evaluate=True)
-    acc, rec = utils.get_reconstruction_accuracy(model, inputs, outputs)
+    acc, rec = utils.get_reconstruction_accuracy(inputs, outputs, args)

     # Accumulate metrics.
     batch_acc += acc.item()
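Since the scaling factors now live on utils.get_losses as keyword defaults, a run that wants non-default values would pass them at the call site rather than at model construction. A hedged sketch (the keyword values here are illustrative):

    # Defaults (beta_b=.1, beta_z=.1, prior_rate=3.) apply when omitted.
    loss, nll, kl_z, kl_b = utils.get_losses(
        inputs, outputs, args, beta_b=.05, beta_z=.2, prior_rate=2.)

The `args` namespace is assumed to provide num_segments, latent_dist, latent_dim and num_symbols, since those are the fields `get_losses` and `get_reconstruction_accuracy` now read in place of the old model attributes.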
utils.py (38 changes: 25 additions & 13 deletions)
@@ -99,46 +99,58 @@ def get_segment_probs(all_b_samples, all_masks, segment_id):
     return neg_cumsum


-def get_losses(model, inputs, outputs):
-    """Get losses (NLL, KL divergences and neg. ELBO)."""
-    targets = inputs.view(-1)
+def get_losses(inputs, outputs, args, beta_b=.1, beta_z=.1, prior_rate=3.):
+    """Get losses (NLL, KL divergences and neg. ELBO).
+
+    Args:
+        inputs: Padded input sequences.
+        outputs: CompILE model output tuple.
+        args: Parsed arguments (`argparse.Namespace`).
+        beta_b: Scaling factor for KL term of boundary variables (b).
+        beta_z: Scaling factor for KL term of latents (z).
+        prior_rate: Rate (lambda) for Poisson prior.
+    """
+
+    targets = inputs.view(-1)
     all_encs, all_recs, all_masks, all_b, all_z = outputs
+    input_dim = args.num_symbols + 1

     nll = 0.
     kl_z = 0.
-    for seg_id in range(model.max_num_segments):
+    for seg_id in range(args.num_segments):
         seg_prob = get_segment_probs(
             all_b['samples'], all_masks, seg_id)
-        preds = all_recs[seg_id].view(-1, model.input_dim)
+        preds = all_recs[seg_id].view(-1, input_dim)
         seg_loss = F.cross_entropy(
             preds, targets, reduction='none').view(-1, inputs.size(1))

         # Ignore EOS token (last sequence element) in loss.
         nll += (seg_loss[:, :-1] * seg_prob[:, :-1]).sum(1).mean(0)

         # KL divergence on z.
-        if model.latent_dist == 'gaussian':
+        if args.latent_dist == 'gaussian':
             mu, log_var = torch.split(
-                all_z['logits'][seg_id], model.latent_dim, dim=1)
+                all_z['logits'][seg_id], args.latent_dim, dim=1)
             kl_z += kl_gaussian(mu, log_var).mean(0)
-        elif model.latent_dist == 'concrete':
+        elif args.latent_dist == 'concrete':
             kl_z += kl_categorical_uniform(
                 F.softmax(all_z['logits'][seg_id], dim=-1)).mean(0)
+        else:
+            raise ValueError('Invalid argument for `latent_dist`.')

     # KL divergence on b (first segment only, ignore first time step).
     # TODO(tkipf): Implement alternative prior on soft segment length.
     probs_b = F.softmax(all_b['logits'][0], dim=-1)
     log_prior_b = poisson_categorical_log_prior(
-        probs_b.size(1), model.prior_rate, device=inputs.device)
-    kl_b = model.max_num_segments * kl_categorical(
+        probs_b.size(1), prior_rate, device=inputs.device)
+    kl_b = args.num_segments * kl_categorical(
         probs_b[:, 1:], log_prior_b[:, 1:]).mean(0)

-    loss = nll + model.beta_z * kl_z + model.beta_b * kl_b
+    loss = nll + beta_z * kl_z + beta_b * kl_b
     return loss, nll, kl_z, kl_b


-def get_reconstruction_accuracy(model, inputs, outputs):
+def get_reconstruction_accuracy(inputs, outputs, args):
     """Calculate reconstruction accuracy (averaged over sequence length)."""

     all_encs, all_recs, all_masks, all_b, all_z = outputs
@@ -150,7 +162,7 @@ def get_reconstruction_accuracy(model, inputs, outputs):
     for sample_idx in range(batch_size):
         prev_boundary_pos = 0
         rec_seq_parts = []
-        for seg_id in range(model.max_num_segments):
+        for seg_id in range(args.num_segments):
             boundary_pos = torch.argmax(
                 all_b['samples'][seg_id], dim=-1)[sample_idx]
             if prev_boundary_pos > boundary_pos:
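The refactored objective is unchanged: loss = nll + beta_z * kl_z + beta_b * kl_b, i.e. the negative ELBO with separately scaled KL terms; only the hyperparameters now arrive as function arguments instead of model attributes. For orientation, a sketch of the KL helpers referenced above, written from their standard closed forms; the repository's actual implementations may differ in details such as epsilon handling:

    import math
    import torch

    EPS = 1e-17  # assumed small constant to keep log() finite

    def kl_gaussian(mu, log_var):
        # KL(N(mu, sigma^2) || N(0, 1)) per sample, summed over latent dims:
        # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
        return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)

    def kl_categorical_uniform(preds):
        # KL(q || uniform over K categories) = sum_k q_k log q_k + log K.
        kl_div = preds * torch.log(preds + EPS)
        return kl_div.sum(1) + math.log(preds.size(1))

    def kl_categorical(preds, log_prior):
        # KL between a categorical q and a prior given as log-probabilities.
        kl_div = preds * (torch.log(preds + EPS) - log_prior)
        return kl_div.sum(1)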
