clean up the cuda-ification of pytorch things
nfoti committed Feb 6, 2018
1 parent 367e10c commit 8ca495d
Showing 3 changed files with 55 additions and 59 deletions.
8 changes: 3 additions & 5 deletions lib/distributions.py
@@ -25,17 +25,17 @@
# )

class Normal(object):
-  def __init__(self, mu, sigma, use_cuda=False):
+  def __init__(self, mu, sigma):
    assert mu.size() == sigma.size()
    self.mu = mu
    self.sigma = sigma
-    self.dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

  def size(self, *args, **kwargs):
    return self.mu.size(*args, **kwargs)

  def sample(self):
-    return self.mu + self.sigma * Variable(torch.randn(self.mu.size()).type_as(self.mu.data))
+    eps = torch.randn(self.mu.size()).type_as(self.mu.data)
+    return self.mu + self.sigma * Variable(eps)

  def logprob(self, x):
    return torch.sum(
@@ -93,5 +93,3 @@ def DistributionCat(distributions, dim=-1):
    )
  else:
    return _DistributionCat(distributions, dim=-1)
-
-# TODO DistributionExpand
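With the `use_cuda` flag gone from `Normal`, `sample` infers its device from `self.mu` via `type_as`, so the same distribution object works on CPU or GPU depending only on where its parameters live. A minimal standalone sketch of that pattern (illustration only, not code from this commit):

import torch
from torch.autograd import Variable

# Parameters may live on the CPU or the GPU; sampling follows them automatically.
mu = Variable(torch.zeros(4, 3))
sigma = Variable(torch.ones(4, 3))
if torch.cuda.is_available():
  mu, sigma = mu.cuda(), sigma.cuda()

# Reparameterized draw: eps is created with the same tensor type/device as mu.
eps = torch.randn(mu.size()).type_as(mu.data)
sample = mu + sigma * Variable(eps)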
21 changes: 11 additions & 10 deletions lib/models.py
@@ -22,8 +22,8 @@ def parameters(self):
    )

  def cuda(self):
-    self.mu_net.cuda()
-    self.sigma_net.cuda()
+    self.mu_net.cuda()
+    self.sigma_net.cuda()


class FirstLayerSparseDecoder(object):
@@ -75,21 +75,18 @@ class BayesianGroupLassoGenerator(object):
  generator is assumed to output a distribution as opposed to a Tensor in the
  `FirstLayerSparseDecoder` model."""

-  def __init__(self, group_generators, group_input_dim, dim_z, use_cuda=False):
+  def __init__(self, group_generators, group_input_dim, dim_z):
    self.group_generators = group_generators
    self.group_input_dim = group_input_dim
    self.dim_z = dim_z
    self.num_groups = len(group_generators)

    # Starting this off with reasonably large values is helpful so that proximal
    # gradient descent doesn't prematurely kill them.
-    Ws_tnsr = torch.randn(self.num_groups, self.dim_z, self.group_input_dim)
-    if use_cuda:
-      Ws_tnsr = Ws_tnsr.cuda()
-      for gen in self.group_generators:
-        gen.cuda()
-
-    self.Ws = Variable(Ws_tnsr, requires_grad=True)
+    self.Ws = Variable(
+      torch.randn(self.num_groups, self.dim_z, self.group_input_dim),
+      requires_grad=True
+    )

  def __call__(self, z):
    return DistributionCat(
@@ -116,3 +113,7 @@ def proximal_step(self, t):
  def group_lasso_penalty(self):
    return torch.sum(torch.sqrt(torch.sum(torch.pow(self.Ws, 2), dim=2)))

+  def cuda(self):
+    self.Ws = Variable(self.Ws.data.cuda(), requires_grad=True)
+    for gen in self.group_generators:
+      gen.cuda()
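The comment above about proximal gradient descent refers to the group-wise shrinkage applied to `Ws`: rows whose group norm (the quantity summed by `group_lasso_penalty`) fall below the step threshold get zeroed, which is why `Ws` is initialized with reasonably large values. The body of `proximal_step` is not shown in this diff; a sketch of the standard group soft-thresholding update consistent with that penalty (an assumption, not necessarily the repository's exact implementation):

import torch

def group_soft_threshold(Ws, t):
  # Ws has shape (num_groups, dim_z, group_input_dim). The norm is taken over
  # dim=2, matching group_lasso_penalty above; each row is shrunk by t and
  # zeroed entirely once its norm drops below t.
  norms = torch.sqrt(torch.sum(torch.pow(Ws, 2), dim=2, keepdim=True))
  scale = torch.clamp(norms - t, min=0.0) / (norms + 1e-12)
  return Ws * scale

In the model this would presumably act on `self.Ws.data` in place after each gradient step on the reconstruction term.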
85 changes: 41 additions & 44 deletions meg_oi_vae.py
@@ -1,4 +1,3 @@
-
import pickle as pkl
import mne
import numpy as np
@@ -45,7 +44,7 @@
    cond=cond)
# compute inverse solutions
print("applying inverse operators")
-ep_inv = apply_inverse_epochs(epochs, inv, lambda2=1/9, method='MNE')
+ep_inv = apply_inverse_epochs(epochs, inv, lambda2=1.0 / 9, method='MNE')
ep_inv_ndarray = np.array([np.ascontiguousarray(ep.data.T) for ep in ep_inv])

# get vertex indices for each label
@@ -59,26 +58,25 @@
offsets = {'lh': 0, 'rh': n_lhverts}

for li, lab in enumerate(labels):
-
-  if isinstance(lab, mne.Label):
-    comp_labs = [lab]
-  elif isinstance(lab, mne.BiHemiLabel):
-    comp_labs = [lab.lh, lab.rh]
-
-  for clab in comp_labs:
-    hemi = clab.hemi
-    hi = 0 if hemi == 'lh' else 1
-
-    lverts = clab.get_vertices_used(vertices=src[hi]['vertno'])
-
-    # gets the indices in the source space vertex array, not the huge
-    # array.
-    # use `src[hi]['vertno'][lverts]` to get surface vertex indices to
-    # plot.
-    lverts = np.searchsorted(src[hi]['vertno'], lverts)
-    lverts += offsets[hemi]
-    vertidx.extend(lverts)
-    roiidx.extend(li*np.ones(lverts.size, dtype=np.int))
+  if isinstance(lab, mne.Label):
+    comp_labs = [lab]
+  elif isinstance(lab, mne.BiHemiLabel):
+    comp_labs = [lab.lh, lab.rh]
+
+  for clab in comp_labs:
+    hemi = clab.hemi
+    hi = 0 if hemi == 'lh' else 1
+
+    lverts = clab.get_vertices_used(vertices=src[hi]['vertno'])
+
+    # gets the indices in the source space vertex array, not the huge
+    # array.
+    # use `src[hi]['vertno'][lverts]` to get surface vertex indices to
+    # plot.
+    lverts = np.searchsorted(src[hi]['vertno'], lverts)
+    lverts += offsets[hemi]
+    vertidx.extend(lverts)
+    roiidx.extend(li*np.ones(lverts.size, dtype=np.int))

num_labels = len(labels)
M = n_verts
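Two notes on the hunks above. The `lambda2=1/9` to `lambda2=1.0 / 9` change matters under Python 2, where `1/9` is integer division and evaluates to 0. And the label loop translates each label's surface vertex numbers into row indices of the source estimate via `np.searchsorted`, plus a right-hemisphere offset; a tiny illustration with made-up vertex numbers (not data from this script):

import numpy as np

# Sorted source-space vertex numbers for one hemisphere, as in src[hi]['vertno'].
vertno = np.array([3, 10, 42, 57, 90])

# Surface vertex numbers actually used by a label.
label_verts = np.array([10, 57])

# searchsorted maps surface vertex numbers to positions in the source estimate.
rows = np.searchsorted(vertno, label_verts)  # -> array([1, 3])

# A right-hemisphere label is additionally shifted past the left-hemisphere rows.
n_lhverts = len(vertno)
rows_rh = rows + n_lhverts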
@@ -121,20 +119,22 @@
group_input_dim = 1

prior_theta_scale = 1.
-lam = 0.
+lam = 1
lam_adjustment = 1.

-num_epochs = 100
+num_epochs = 200
mc_samples = 1
-batch_size = 1024
+batch_size = 4096

ep_inv_tnsr = torch.from_numpy(ep_inv_stacked)

dataloader = torch.utils.data.DataLoader(
-  torch.utils.data.TensorDataset(ep_inv_tnsr.cuda(),
-                                 torch.zeros(ep_inv_tnsr.size(0)).cuda()),
-  batch_size=batch_size,
-  shuffle=True
+  torch.utils.data.TensorDataset(
+    ep_inv_tnsr.cuda(),
+    torch.zeros(ep_inv_tnsr.size(0))
+  ),
+  batch_size=batch_size,
+  shuffle=True
)

# This value adjusts the impact of our learned variances in the sigma_net of
@@ -144,7 +144,6 @@
# some number of iterations.
stddev_multiple = 0.1

-print('making neural nets')
inference_net = NormalNet(
  mu_net=torch.nn.Sequential(
    # inference_net_base,
@@ -168,10 +167,6 @@
  )
)

-if use_cuda:
-  inference_net.cuda()
-
-
def make_group_generator(group_output_dim):
  # Note that this Variable is NOT going to show up in `net.parameters()` and
  # therefore it is implicitly free from the ridge penalty/p(theta) prior.
@@ -190,26 +185,29 @@ def make_group_generator(group_output_dim):
generative_net = BayesianGroupLassoGenerator(
  group_generators=[make_group_generator(gs) for gs in group_output_dims],
  group_input_dim=group_input_dim,
-  dim_z=dim_z, use_cuda=use_cuda
+  dim_z=dim_z
)

-mu = torch.zeros(1, dim_z)
-sigma = torch.ones(1, dim_z)
-if use_cuda:
-  mu = mu.cuda()
-  sigma = sigma.cuda()
+prior_z = Normal(
+  Variable(torch.zeros(1, dim_z)),
+  Variable(torch.ones(1, dim_z))
+)

-prior_z = Normal(Variable(mu), Variable(sigma), use_cuda=True)
+if use_cuda:
+  inference_net.cuda()
+  generative_net.cuda()
+  prior_z.mu = prior_z.mu.cuda()
+  prior_z.sigma = prior_z.sigma.cuda()

-lr = 1e-5
+lr = 1e-3
optimizer = torch.optim.Adam([
  {'params': inference_net.parameters(), 'lr': lr},
  # {'params': [inference_net_log_stddev], 'lr': lr},
  {'params': generative_net.group_generators_parameters(), 'lr': lr},
  {'params': [gen.sigma_net.extra_args[0] for gen in generative_net.group_generators], 'lr': lr}
])

-Ws_lr = 1e-5
+Ws_lr = 1e-6
optimizer_Ws = torch.optim.SGD([
  {'params': [generative_net.Ws], 'lr': Ws_lr, 'momentum': 0}
])
@@ -228,7 +226,6 @@ def make_group_generator(group_output_dim):
elbo_per_iter = []
iteration = 0
for epoch in range(num_epochs):
-
  for Xbatch, _ in dataloader:
    if iteration > 1000:
      stddev_multiple = 1
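A side effect of the new `BayesianGroupLassoGenerator.cuda()` is that it rebinds `self.Ws` to a fresh Variable, so any optimizer has to be constructed after the move or it will keep updating the old CPU tensor; the script keeps that ordering by running the `if use_cuda:` block before building `optimizer_Ws`. A self-contained sketch of the pitfall (hypothetical shapes and learning rate):

import torch
from torch.autograd import Variable

# A parameter that will later be rebound to a new Variable, the way
# BayesianGroupLassoGenerator.cuda() rebinds self.Ws.
Ws = Variable(torch.randn(3, 2, 1), requires_grad=True)

# Move first: an optimizer created before this point would still reference the
# old CPU tensor, not the rebound GPU one.
if torch.cuda.is_available():
  Ws = Variable(Ws.data.cuda(), requires_grad=True)

# Only now hand the (possibly new) Variable to the optimizer.
optimizer_Ws = torch.optim.SGD([
  {'params': [Ws], 'lr': 1e-6, 'momentum': 0}
])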
