Resolve changes in tensor/variable API with PyTorch master (#824)
neerajprad authored and fritzo committed Feb 27, 2018
1 parent 29654dc commit a72f9ac
Showing 49 changed files with 149 additions and 161 deletions.
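Nearly every hunk below applies the same two idioms from the PyTorch tensor/variable merge. As a quick orientation, here is a minimal before/after sketch of my own (not part of the commit), assuming PyTorch 0.4-style semantics:

```python
import torch

x = torch.ones(3, requires_grad=True)
loss = (x * 2).sum()

# Before the merge: unwrap the Variable via .data and index in.
# value = loss.data[0]

# After: .item() converts any one-element tensor to a Python scalar.
value = loss.item()

# Before: x.data.cpu().numpy()
# After: leave the autograd graph explicitly before converting.
array = x.detach().cpu().numpy()
```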
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ pip install .

For recent features you can install Pyro from source.

-First install a recent PyTorch, currently PyTorch commit `853dba8`.
+First install a recent PyTorch, currently PyTorch commit `05269b5`.
```sh
git clone git@github.com:pytorch/pytorch
cd pytorch
2 changes: 1 addition & 1 deletion examples/air/viz.py
@@ -35,7 +35,7 @@ def draw_one(imgarr, z_arr):
# Note that this clipping makes the visualisation somewhat
# misleading, as it incorrectly suggests objects occlude one
# another.
-clipped = np.clip(imgarr.data.cpu().numpy(), 0, 1)
+clipped = np.clip(imgarr.detach().cpu().numpy(), 0, 1)
img = arr2img(clipped).convert('RGB')
draw = ImageDraw.Draw(img)
for k, z in enumerate(z_arr):
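A side note of mine on why `.detach()` replaces `.data` here: both return a tensor sharing the same storage, but a detached tensor remains visible to autograd's version counter, so an in-place edit that would corrupt gradients fails loudly instead of silently. A minimal sketch, assuming current PyTorch behavior:

```python
import torch

w = torch.ones(2, requires_grad=True)
y = w.exp()          # exp() saves its output for the backward pass

y.detach().zero_()   # the in-place edit bumps y's version counter
try:
    y.sum().backward()   # autograd detects the stale saved output
except RuntimeError as err:
    print("caught:", err)

# y.data.zero_() would bypass the check and silently yield wrong
# gradients instead of raising.
```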
2 changes: 1 addition & 1 deletion examples/contrib/named/mixture.py
@@ -65,7 +65,7 @@ def main(args):

print('Parameters:')
for name in sorted(pyro.get_param_store().get_all_param_names()):
-print('{} = {}'.format(name, pyro.param(name).data.cpu().numpy()))
+print('{} = {}'.format(name, pyro.param(name).detach().cpu().numpy()))


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion examples/contrib/named/tree_data.py
@@ -96,7 +96,7 @@ def main(args):

print('Parameters:')
for name in sorted(pyro.get_param_store().get_all_param_names()):
-print('{} = {}'.format(name, pyro.param(name).data.cpu().numpy()))
+print('{} = {}'.format(name, pyro.param(name).detach().cpu().numpy()))


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion examples/dmm/dmm.py
@@ -267,7 +267,7 @@ def main(args):
val_seq_lengths = data['valid']['sequence_lengths']
val_data_sequences = data['valid']['sequences']
N_train_data = len(training_seq_lengths)
-N_train_time_slices = np.sum(training_seq_lengths)
+N_train_time_slices = float(np.sum(training_seq_lengths))
N_mini_batches = int(N_train_data / args.mini_batch_size +
int(N_train_data % args.mini_batch_size > 0))

14 changes: 7 additions & 7 deletions examples/inclined_plane.py
@@ -38,10 +38,10 @@ def simulate(mu, length=2.0, phi=np.pi / 6.0, dt=0.005, noise_sigma=None):
acceleration = Variable(torch.Tensor([little_g * np.sin(phi)])) - \
Variable(torch.Tensor([little_g * np.cos(phi)])) * mu

-if acceleration.data[0] <= 0.0: # the box doesn't slide if the friction is too large
+if acceleration.item() <= 0.0: # the box doesn't slide if the friction is too large
return Variable(torch.Tensor([1.0e5])) # return a very large time instead of infinity

-while displacement.data[0] < length: # otherwise slide to the end of the inclined plane
+while displacement.item() < length: # otherwise slide to the end of the inclined plane
displacement += velocity * dt
velocity += acceleration * dt
T += dt
@@ -67,7 +67,7 @@ def analytic_T(mu, length=2.0, phi=np.pi / 6.0):
torch.manual_seed(2)
observed_data = torch.cat([simulate(Variable(torch.Tensor([mu0])), noise_sigma=time_measurement_sigma)
for _ in range(N_obs)])
-observed_mean = np.mean([T.data[0] for T in observed_data])
+observed_mean = np.mean([T.item() for T in observed_data])


# define model with uniform prior on mu and gaussian noise on the descent time
@@ -100,8 +100,8 @@ def main(args):
posterior_std_dev = torch.std(torch.cat(posterior_samples), 0)

# report results
-inferred_mu = posterior_mean.data[0]
-inferred_mu_uncertainty = posterior_std_dev.data[0]
+inferred_mu = posterior_mean.item()
+inferred_mu_uncertainty = posterior_std_dev.item()
print("the coefficient of friction inferred by pyro is %.3f +- %.3f" %
(inferred_mu, inferred_mu_uncertainty))

@@ -111,9 +111,9 @@ def main(args):
# but will be systematically off from the third number
print("the mean observed descent time in the dataset is: %.4f seconds" % observed_mean)
print("the (forward) simulated descent time for the inferred (mean) mu is: %.4f seconds" %
-simulate(posterior_mean).data[0])
+simulate(posterior_mean).item())
print(("disregarding measurement noise, elementary calculus gives the descent time\n" +
"for the inferred (mean) mu as: %.4f seconds") % analytic_T(posterior_mean.data[0]))
"for the inferred (mean) mu as: %.4f seconds") % analytic_T(posterior_mean.item()))

"""
################## EXERCISE ###################
2 changes: 1 addition & 1 deletion examples/ss_vae_M2.py
@@ -276,7 +276,7 @@ def get_accuracy(data_loader, classifier_fn, batch_size):
for pred, act in zip(predictions, actuals):
for i in range(pred.size(0)):
v = torch.sum(pred[i] == act[i])
-accurate_preds += (v.data[0] == 10)
+accurate_preds += (v.item() == 10)

# calculate the accuracy between 0 and 1
accuracy = (accurate_preds * 1.0) / (len(predictions) * batch_size)
4 changes: 2 additions & 2 deletions examples/utils/vae_plots.py
@@ -80,9 +80,9 @@ def plot_tsne(z_mu, classes, name):
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
model_tsne = TSNE(n_components=2, random_state=0)
-z_states = z_mu.data.cpu().numpy()
+z_states = z_mu.detach().cpu().numpy()
z_embed = model_tsne.fit_transform(z_states)
-classes = classes.data.cpu().numpy()
+classes = classes.detach().cpu().numpy()
fig666 = plt.figure()
for ic in range(10):
ind_vec = np.zeros_like(classes)
4 changes: 2 additions & 2 deletions examples/vae.py
@@ -191,9 +191,9 @@ def main(args):
for index in reco_indices:
test_img = x[index, :]
reco_img = vae.reconstruct_img(test_img)
-vis.image(test_img.contiguous().view(28, 28).data.cpu().numpy(),
+vis.image(test_img.contiguous().view(28, 28).detach().cpu().numpy(),
opts={'caption': 'test image'})
-vis.image(reco_img.contiguous().view(28, 28).data.cpu().numpy(),
+vis.image(reco_img.contiguous().view(28, 28).detach().cpu().numpy(),
opts={'caption': 'reconstructed image'})

# report test diagnostics
4 changes: 2 additions & 2 deletions examples/vae_comparison.py
@@ -134,7 +134,7 @@ def test(self, epoch):
n = min(x.size(0), 8)
comparison = torch.cat([x[:n],
recon_x.view(self.args.batch_size, 1, 28, 28)[:n]])
-save_image(comparison.data.cpu(),
+save_image(comparison.detach().cpu(),
os.path.join(OUTPUT_DIR, 'reconstruction_' + str(epoch) + '.png'),
nrow=n)

@@ -166,7 +166,7 @@ def compute_loss_and_gradient(self, x):
if self.mode == TRAIN:
loss.backward()
self.optimizer.step()
-return loss.data[0]
+return loss.item()

def initialize_optimizer(self, lr=1e-3):
model_params = itertools.chain(self.vae_encoder.parameters(), self.vae_decoder.parameters())
8 changes: 4 additions & 4 deletions pyro/__init__.py
@@ -268,18 +268,18 @@ def irange(name, size, subsample_size=None, subsample=None, use_cuda=None):
See `SVI Part II <http://pyro.ai/examples/svi_part_ii.html>`_ for an extended discussion.
"""
subsample, scale = _subsample(name, size, subsample_size, subsample, use_cuda)
-if isinstance(subsample, Variable):
-subsample = subsample.data
if not am_i_wrapped():
for i in subsample:
-yield i
+yield i.item() if isinstance(i, Variable) else i
else:
indep_context = poutine.indep(name, vectorized=False)
with poutine.scale(None, scale):
for i in subsample:
indep_context.next_context()
with indep_context:
-yield i
+# convert to python numeric type as functions like torch.ones(*args)
+# do not work with dim 0 torch.Tensor instances.
+yield i.item() if isinstance(i, Variable) else i


def map_data(name, data, fn, batch_size=None, batch_dim=0, use_cuda=None):
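The comment added inside `irange` carries the reasoning for this hunk: iterating over a 1-D tensor yields 0-dim tensors, and size-taking helpers such as `torch.ones(*args)` want plain Python ints. A small sketch of mine (using `torch.is_tensor` in place of the `Variable` check above):

```python
import torch

subsample = torch.randperm(10)[:3]   # e.g. tensor([7, 2, 5])

for i in subsample:
    # i is a 0-dim tensor; .item() recovers the plain Python int
    # that torch.ones(*args) expects as a size argument.
    idx = i.item() if torch.is_tensor(i) else i
    print(idx, torch.ones(idx).shape)
```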
4 changes: 2 additions & 2 deletions pyro/infer/abstract_infer.py
@@ -74,7 +74,7 @@ def sample(self, *args, **kwargs):
if sample_shape:
raise ValueError("Arbitrary `sample_shape` not supported by Histogram class.")
d, values = self._dist_and_values(*args, **kwargs)
-ix = d.sample().data[0]
+ix = d.sample()
return values[ix]

__call__ = sample
@@ -151,4 +151,4 @@ def __call__(self, *args, **kwargs):
if not isinstance(logits, torch.autograd.Variable):
logits = Variable(logits)
ix = dist.Categorical(logits=logits).sample()
-return traces[ix.data[0]]
+return traces[ix]
2 changes: 1 addition & 1 deletion pyro/infer/mcmc/hmc.py
@@ -99,7 +99,7 @@ def sample(self, trace):
energy_current = self._energy(z, r)
delta_energy = energy_proposal - energy_current
rand = pyro.sample('rand_t='.format(self._t), dist.Uniform(ng_zeros(1), ng_ones(1)))
-if rand.log().data[0] < -delta_energy.data[0]:
+if rand.log() < -delta_energy:
self._accept_cnt += 1
z = z_new
self._t += 1
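Dropping `.data[0]` from the Metropolis test relies on a merged-API detail: comparing two one-element tensors returns a one-element tensor, and Python's `if` can take its truth value directly. A sketch of mine under that assumption (`torch.tensor` is the post-merge constructor, not code from this commit):

```python
import torch

rand = torch.rand(1)
delta_energy = torch.tensor([0.25])

# bool() of a one-element tensor is well defined, so the comparison
# can be used in a condition without unwrapping.
if rand.log() < -delta_energy:
    print("accept")
else:
    print("reject")
```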
8 changes: 4 additions & 4 deletions pyro/infer/mcmc/nuts.py
@@ -83,7 +83,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction):
z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
z, r, self._potential_energy, step_size, z_grads=z_grads)
energy = potential_energy + self._kinetic_energy(r_new)
-dE = (log_slice + energy).data[0]
+dE = log_slice + energy

# As a part of the slice sampling process (see below), along the trajectory
# we eliminate states which p(z, r) < u, or dE < 0.
@@ -131,7 +131,7 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth):
other_half_tree_prob = other_half_tree.size / tree_size
is_other_half_tree = pyro.sample("is_other_halftree",
dist.Bernoulli(ps=ng_ones(1) * other_half_tree_prob))
-if int(is_other_half_tree.data[0]) == 1:
+if int(is_other_half_tree.item()) == 1:
z_proposal = other_half_tree.z_proposal

# leaves of the full tree are determined by the direction
@@ -190,7 +190,7 @@ def sample(self, trace):
for tree_depth in range(self.max_tree_depth + 1):
direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
dist.Bernoulli(ps=ng_ones(1) * 0.5))
-direction = int(direction.data[0])
+direction = int(direction.item())
if direction == 1: # go to the right, start from the right leaf of current tree
new_tree = self._build_tree(z_right, r_right, z_right_grads,
log_slice, direction, tree_depth)
@@ -210,7 +210,7 @@ def sample(self, trace):

accepted_prob = pyro.sample("acceptedprob_t={}_treedepth={}".format(self._t, tree_depth),
dist.Uniform(ng_zeros(1), ng_ones(1)))
-if accepted_prob.data[0] < new_tree.size / tree_size:
+if accepted_prob < new_tree.size / tree_size:
is_accepted = True
z = new_tree.z_proposal

2 changes: 1 addition & 1 deletion pyro/infer/svi.py
@@ -54,7 +54,7 @@ def __init__(self,
self._loss = copy.copy(loss)

def new_loss(model, guide, *args, **kwargs):
-return self._loss(model, guide, *args, **kwargs).data[0]
+return self._loss(model, guide, *args, **kwargs).item()

self.loss = new_loss

2 changes: 1 addition & 1 deletion pyro/nn/auto_reg_nn.py
@@ -87,7 +87,7 @@ def __init__(self, input_dim, hidden_dim, output_dim_multiplier=1,

for k in range(hidden_dim):
# fill in mask1
-m_k = self.mask_encoding[k]
+m_k = self.mask_encoding[k].item()
slice_k = torch.cat([torch.ones(m_k), torch.zeros(input_dim - m_k)])
for j in range(input_dim):
self.mask1[k, self.permutation[j]] = slice_k[j]
2 changes: 1 addition & 1 deletion pyro/poutine/trace.py
@@ -12,7 +12,7 @@

def _warn_if_nan(name, value):
if isinstance(value, Variable):
-value = value.data[0]
+value = value.item()
if is_nan(value):
warnings.warn("Encountered NAN log_pdf at site '{}'".format(name))
if is_inf(value) and value > 0:
14 changes: 1 addition & 13 deletions tests/common.py
@@ -124,16 +124,6 @@ def is_iterable(obj):
return False


-def _unwrap_variables(x, y):
-if isinstance(x, Variable) and isinstance(y, Variable):
-return x.data, y.data
-elif isinstance(x, Variable) or isinstance(y, Variable):
-raise AssertionError(
-"cannot compare {} and {}".format(
-type(x), type(y)))
-return x, y


def assert_tensors_equal(a, b, prec=1e-5, msg=''):
assert a.size() == b.size(), msg
if prec == 0:
@@ -148,7 +138,7 @@ def assert_tensors_equal(a, b, prec=1e-5, msg=''):
diff[nan_mask] = 0
if diff.is_signed():
diff = diff.abs()
-max_err = diff.max()
+max_err = diff.max().item()
assert max_err < prec, msg


@@ -180,8 +170,6 @@ def _safe_coalesce(t):
# TODO Split this into assert_equal() and assert_close() or assert_almost_equal().
# TODO Use atol and rtol instead of prec
def assert_equal(x, y, prec=1e-5, msg=''):
-x, y = _unwrap_variables(x, y)
-
if torch.is_tensor(x) and torch.is_tensor(y):
assert_equal(x.is_sparse, y.is_sparse, prec, msg)
if x.is_sparse:
2 changes: 1 addition & 1 deletion tests/contrib/gp/test_kernels.py
@@ -16,7 +16,7 @@ def test_forward_rbf():
assert K.dim() == 2
assert K.size(0) == 2
assert K.size(1) == 2
-assert_equal(K.data.sum(), 0.30531)
+assert_equal(K.sum().item(), 0.30531)


def test_Kdiag():
4 changes: 2 additions & 2 deletions tests/contrib/gp/test_models.py
@@ -26,7 +26,7 @@ def test_forward_gpr():
assert cov.size(0) == 2
assert cov.size(1) == 2
assert_equal(loc, y)
-assert_equal(cov.data.abs().sum(), 0)
+assert_equal(cov.abs().sum().item(), 0)


def test_forward_sgpr():
@@ -48,7 +48,7 @@ def test_forward_sgpr():
assert cov.size(0) == 2
assert cov.size(1) == 2
assert_equal(loc, y)
-assert_equal(cov.data.abs().sum(), 0)
+assert_equal(cov.abs().sum().item(), 0)


def test_forward_sgpr_vs_gpr():
2 changes: 1 addition & 1 deletion tests/distributions/dist_fixture.py
@@ -75,7 +75,7 @@ def _convert_logits_to_ps(self, dist_params):
logits = Variable(torch.Tensor(dist_params.pop('logits')))
is_multidimensional = self.get_test_distribution_name() != 'Bernoulli'
ps, _ = get_probs_and_logits(logits=logits, is_multidimensional=is_multidimensional)
-dist_params['ps'] = list(ps.data.cpu().numpy())
+dist_params['ps'] = list(ps.detach().cpu().numpy())
return dist_params

def get_scipy_logpdf(self, idx):
8 changes: 4 additions & 4 deletions tests/distributions/test_categorical.py
@@ -37,16 +37,16 @@ def setUp(self):
self.support = torch.Tensor([[0, 0], [1, 1], [2, 2]])

def test_log_pdf(self):
-log_px_torch = dist.Categorical(self.ps).log_prob(self.test_data).sum().data[0]
-log_px_np = float(sp.multinomial.logpmf(np.array([0, 0, 1]), 1, self.ps.data.cpu().numpy()))
+log_px_torch = dist.Categorical(self.ps).log_prob(self.test_data).sum().item()
+log_px_np = float(sp.multinomial.logpmf(np.array([0, 0, 1]), 1, self.ps.detach().cpu().numpy()))
assert_equal(log_px_torch, log_px_np, prec=1e-4)

def test_mean_and_var(self):
-torch_samples = [dist.Categorical(self.ps).sample().data.cpu().numpy()
+torch_samples = [dist.Categorical(self.ps).sample().detach().cpu().numpy()
for _ in range(self.n_samples)]
_, counts = np.unique(torch_samples, return_counts=True)
computed_mean = float(counts[0]) / self.n_samples
-assert_equal(computed_mean, self.analytic_mean.data.cpu().numpy()[0], prec=0.05)
+assert_equal(computed_mean, self.analytic_mean.detach().cpu().numpy()[0], prec=0.05)

def test_support_non_vectorized(self):
s = dist.Categorical(self.d_ps[0].squeeze(0)).enumerate_support()
10 changes: 5 additions & 5 deletions tests/distributions/test_delta.py
@@ -26,21 +26,21 @@ def setUp(self):
self.n_samples = 10

def test_log_pdf(self):
-log_px_torch = dist.Delta(self.v).log_prob(self.test_data).sum().data[0]
-assert_equal(log_px_torch, 0)
+log_px_torch = dist.Delta(self.v).log_prob(self.test_data).sum()
+assert_equal(log_px_torch.item(), 0)

def test_batch_log_prob(self):
log_px_torch = dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_1).data
-assert_equal(torch.sum(log_px_torch), 0)
+assert_equal(log_px_torch.sum().item(), 0)
log_px_torch = dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_2).data
-assert_equal(torch.sum(log_px_torch), float('-inf'))
+assert_equal(log_px_torch.sum().item(), float('-inf'))

def test_batch_log_prob_shape(self):
assert dist.Delta(self.vs).log_prob(self.batch_test_data_3).size() == (4, 1)
assert dist.Delta(self.v).log_prob(self.batch_test_data_3).size() == (4, 1)

def test_mean_and_var(self):
-torch_samples = [dist.Delta(self.v).sample().data.cpu().numpy()
+torch_samples = [dist.Delta(self.v).sample().detach().cpu().numpy()
for _ in range(self.n_samples)]
torch_mean = np.mean(torch_samples)
torch_var = np.var(torch_samples)
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
@@ -10,7 +10,7 @@


def _unwrap_variable(x):
-return x.data.cpu().numpy()
+return x.detach().cpu().numpy()


def _log_prob_shape(dist, x_size=torch.Size()):
2 changes: 1 addition & 1 deletion tests/distributions/test_iaf.py
@@ -75,7 +75,7 @@ def nonzero(x):
epsilon_vector = torch.zeros(1, input_dim)
epsilon_vector[0, j] = self.epsilon
delta = (arn(x + Variable(epsilon_vector)) - arn(x)) / self.epsilon
-jacobian[j, k] = float(delta[0, k + output_index * input_dim].data.cpu().numpy()[0])
+jacobian[j, k] = float(delta[0, k + output_index * input_dim])

permutation = arn.get_permutation()
permuted_jacobian = jacobian.clone()
2 changes: 1 addition & 1 deletion tests/distributions/test_one_hot_categorical.py
@@ -102,4 +102,4 @@ def test_batch_log_dims(dim, ps):
ps = modify_params_using_dims(ps, dim)
support = dist.OneHotCategorical(ps).enumerate_support()
log_prob = dist.OneHotCategorical(ps).log_prob(support)
-assert_equal(log_prob.data.cpu().numpy(), expected_log_pdf)
+assert_equal(log_prob.detach().cpu().numpy(), expected_log_pdf)
