Commit

make it easier to switch dtypes
samuelstanton committed Oct 21, 2022
1 parent a8dbc8a commit 67b59a6
Showing 6 changed files with 24 additions and 21 deletions.
3 changes: 2 additions & 1 deletion lambo/acquisitions/monte_carlo.py
@@ -43,9 +43,10 @@ def __init__(
     ):
         model.eval()
         mocked_features = model.get_features(X_baseline, model.bs)
+        ref_point = ref_point.to(mocked_features)
         # for string kernels
         if mocked_features.ndim > 2:
-            mocked_features = mocked_features[..., 0].to(ref_point) # doint let this fail
+            mocked_features = mocked_features[..., 0] # don't let this fail
 
         super().__init__(
             model=model,
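The added line relies on `Tensor.to` accepting another tensor as its argument: `ref_point` is converted to the dtype *and* device of `mocked_features` in one call, which is why the deleted `.to(ref_point)` on the features is no longer needed. A minimal standalone sketch of that semantic, using stand-in tensors rather than the actual lambo objects:

import torch

mocked_features = torch.randn(8, 16, dtype=torch.double)  # stand-in features
ref_point = torch.zeros(2)                                 # float32 by default
ref_point = ref_point.to(mocked_features)                  # now float64, same device
assert ref_point.dtype == mocked_features.dtype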
5 changes: 3 additions & 2 deletions lambo/models/base_surrogate.py
@@ -10,6 +10,7 @@
 
 class BaseSurrogate(torch.nn.Module):
     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float
 
     def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None):
         # convert from string to LongTensor of token indexes
@@ -36,9 +37,9 @@ def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None):
 
     def _get_datasets(self, X_train, X_test, Y_train, Y_test):
         if isinstance(Y_train, np.ndarray):
-            Y_train = torch.from_numpy(Y_train).float()
+            Y_train = torch.from_numpy(Y_train).to(self.dtype)
         if isinstance(Y_test, np.ndarray):
-            Y_test = torch.from_numpy(Y_test).float()
+            Y_test = torch.from_numpy(Y_test).to(self.dtype)
 
         train_dataset = dataset_util.TransformTensorDataset(
             [X_train, Y_train], self.train_transform
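The new class-level `dtype` attribute is what the rest of the diff builds on: every conversion now reads `self.dtype` instead of hard-coding `.float()`, so the whole surrogate stack can be moved to another precision by changing one line. A minimal sketch of the pattern, with a simplified stand-in class and a hypothetical `_to_target` helper rather than the actual `BaseSurrogate`:

import numpy as np
import torch

class Surrogate:
    # one place to decide precision for the whole class hierarchy
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float

    def _to_target(self, y):
        # mirrors _get_datasets: accept numpy or torch, emit self.dtype
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)
        return y.to(self.dtype)

# switching everything to double precision is now a one-line change
Surrogate.dtype = torch.double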
6 changes: 3 additions & 3 deletions lambo/models/deep_ensemble.py
@@ -60,9 +60,9 @@ def __init__(self, tokenizer, model, model_kwargs, lr, bs, weight_decay,
     def fit(self, X_train, Y_train, X_val, Y_val, X_test, Y_test, reset=False, log_prefix="deep_ens", **kwargs):
         super().fit(X_train, Y_train)
         if isinstance(Y_train, np.ndarray):
-            Y_train = torch.from_numpy(Y_train).float()
+            Y_train = torch.from_numpy(Y_train).to(self.dtype)
         if isinstance(Y_test, np.ndarray):
-            Y_test = torch.from_numpy(Y_test).float()
+            Y_test = torch.from_numpy(Y_test).to(self.dtype)
 
         print(f'{X_train.shape[0]} train, {X_val.shape[0]} val, {X_test.shape[0]} test')
 
@@ -135,7 +135,7 @@ def forward(self, X, Y=None, bs=128, num_samples=None, **kwargs):
 
         if Y is not None:
             if isinstance(Y, np.ndarray):
-                Y = torch.from_numpy(Y).float()
+                Y = torch.from_numpy(Y).to(self.dtype)
 
         dataset = dataset_util.TransformTensorDataset([X, Y], self.test_transform)
         loader = torch.utils.data.DataLoader(
24 changes: 12 additions & 12 deletions lambo/models/gp_models.py
@@ -218,14 +218,14 @@ def __init__(self, feature_dim, out_dim, encoder, likelihood=None, covar_module=
         BaseGPSurrogate.__init__(self, encoder=encoder, *args, **kwargs)
 
         # initialize GP
-        dummy_X = torch.randn(2, feature_dim).to(self.device)
-        dummy_Y = torch.randn(2, out_dim).to(self.device)
-        covar_module = covar_module if covar_module is None else covar_module.to(self.device)
+        dummy_X = torch.randn(2, feature_dim).to(self.device, self.dtype)
+        dummy_Y = torch.randn(2, out_dim).to(self.device, self.dtype)
+        covar_module = covar_module if covar_module is None else covar_module.to(self.device, self.dtype)
         SingleTaskGP.__init__(
             self, dummy_X, dummy_Y, likelihood, covar_module, outcome_transform, input_transform
         )
         self.likelihood.initialize(noise=self.task_noise_init)
-        self.encoder = encoder.to(self.device)
+        self.encoder = encoder.to(self.device, self.dtype)
 
     def clear_cache(self):
         self.train()
@@ -273,15 +273,15 @@ def __init__(self, feature_dim, out_dim, encoder, likelihood=None, covar_module=
         BaseGPSurrogate.__init__(self, encoder=encoder, *args, **kwargs)
 
         # initialize GP
-        dummy_X = torch.randn(2, feature_dim).to(self.device)
-        dummy_Y = torch.randn(2, out_dim).to(self.device)
-        covar_module = covar_module if covar_module is None else covar_module.to(self.device)
+        dummy_X = torch.randn(2, feature_dim).to(self.device, self.dtype)
+        dummy_Y = torch.randn(2, out_dim).to(self.device, self.dtype)
+        covar_module = covar_module if covar_module is None else covar_module.to(self.device, self.dtype)
         KroneckerMultiTaskGP.__init__(
             self, dummy_X, dummy_Y, likelihood, covar_module=covar_module, outcome_transform=outcome_transform,
             input_transform=input_transform, *args, **kwargs
         )
         self.likelihood.initialize(task_noises=self.task_noise_init)
-        self.encoder = encoder.to(self.device)
+        self.encoder = encoder.to(self.device, self.dtype)
 
     def forward(self, X):
         features = self.get_features(X, self.bs) if isinstance(X, np.ndarray) else X
@@ -351,15 +351,15 @@ def __init__(self, feature_dim, out_dim, num_inducing_points, encoder, noise_con
         likelihood.initialize(task_noises=self.task_noise_init)
 
         # initialize GP
-        dummy_X = 2 * (torch.rand(num_inducing_points, feature_dim).to(self.device) - 0.5)
-        dummy_Y = torch.randn(num_inducing_points, out_dim).to(self.device)
-        covar_module = covar_module if covar_module is None else covar_module.to(self.device)
+        dummy_X = 2 * (torch.rand(num_inducing_points, feature_dim).to(self.device, self.dtype) - 0.5)
+        dummy_Y = torch.randn(num_inducing_points, out_dim).to(self.device, self.dtype)
+        covar_module = covar_module if covar_module is None else covar_module.to(self.device, self.dtype)
 
         self.base_cls = SingleTaskVariationalGP
         self.base_cls.__init__(self, dummy_X, dummy_Y, likelihood, out_dim, learn_inducing_points,
                                covar_module=covar_module, inducing_points=dummy_X,
                                outcome_transform=outcome_transform, input_transform=input_transform)
-        self.encoder = encoder.to(self.device)
+        self.encoder = encoder.to(self.device, self.dtype)
         self.mll_beta = mll_beta
 
     def clear_cache(self):
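All twelve replacements in this file use the two-argument form of `.to`, which moves and casts in a single call; `nn.Module.to` accepts `(device, dtype)` the same way, so the encoder's parameters are cast along with the dummy tensors. A quick CPU-only check of both behaviors:

import torch

x = torch.randn(2, 4).to(torch.device('cpu'), torch.double)
assert x.dtype == torch.double

# Module.to casts parameters in place as well
encoder = torch.nn.Linear(4, 3).to(torch.device('cpu'), torch.double)
assert encoder.weight.dtype == torch.double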
4 changes: 2 additions & 2 deletions lambo/models/shared_elements.py
@@ -50,10 +50,10 @@ def forward(self, x):
 
 
 def pool_features(tokens, token_features, ignore_idxs):
-    mask = torch.ones_like(tokens).float()
+    mask = torch.ones_like(tokens, dtype=torch.float)
     for idx in ignore_idxs:
         mask *= tokens.ne(idx)
-    mask = mask.unsqueeze(-1).float()
+    mask = mask.unsqueeze(-1).to(token_features)
     pooled_features = (mask * token_features).sum(-2) / (mask.sum(-2) + 1e-6)
 
     return pooled_features
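Because `tokens` is typically an integer id tensor, `torch.ones_like` needs the `dtype` override to produce a float mask, and the later `.to(token_features)` then follows whatever precision the features carry instead of pinning the result to float32. A self-contained sketch of the updated pooling, where treating id 0 as padding is an assumption for illustration:

import torch

tokens = torch.tensor([[5, 2, 0], [7, 0, 0]])         # integer token ids
token_features = torch.randn(2, 3, 4, dtype=torch.double)

mask = torch.ones_like(tokens, dtype=torch.float)     # float mask despite int tokens
mask *= tokens.ne(0)                                  # zero out the (assumed) padding id
mask = mask.unsqueeze(-1).to(token_features)          # match features' dtype and device
pooled = (mask * token_features).sum(-2) / (mask.sum(-2) + 1e-6)  # shape (2, 4)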
3 changes: 2 additions & 1 deletion lambo/optimizers/lambo.py
@@ -270,7 +270,8 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref
 
         # initialize latent token-choice decision variables
         opt_params = torch.empty(
-            *opt_features.shape, requires_grad=self.optimize_latent, device=self.surrogate_model.device
+            *opt_features.shape, requires_grad=self.optimize_latent, device=self.surrogate_model.device,
+            dtype=self.surrogate_model.dtype
         )
         opt_params.copy_(opt_features)
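The decision variables are now allocated with an explicit dtype taken from the surrogate, so they match the features they are copied from. A rough standalone version of the allocate-then-copy step, with made-up shapes and the copy wrapped in `no_grad` here (which the surrounding lambo code may handle differently):

import torch

opt_features = torch.randn(4, 8, dtype=torch.double)  # stand-in latent features
opt_params = torch.empty(
    *opt_features.shape, requires_grad=True,
    device=opt_features.device, dtype=opt_features.dtype,
)
with torch.no_grad():
    opt_params.copy_(opt_features)  # initialize values without tracking the copy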
