Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Nov 24, 2023
1 parent 2dd2066 commit 414b105
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions astra/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def train_fn(
((iter_losses, epoch_losses), state_dict_history)
"""

device = get_model_device(model)

def get_batch():
if input is None and output is None:
yield None, None
Expand All @@ -81,23 +83,21 @@ def get_batch():

for i in iterable:
if input is not None:
batch_input = input[idx[i : i + inner_batch_size]]
batch_input = input[idx[i : i + inner_batch_size]].to(device)
else:
batch_input = None

if output is not None:
batch_output = output[idx[i : i + inner_batch_size]]
batch_output = output[idx[i : i + inner_batch_size]].to(device)
else:
batch_output = None

yield batch_input, batch_output

device = get_model_device(model)

def one_step(batch_input, batch_output):
optimizer.zero_grad()
model_output = model(batch_input.to(device), **model_kwargs)
loss = loss_fn(model_output, batch_output.to(device), **loss_fn_kwargs)
model_output = model(batch_input, **model_kwargs)
loss = loss_fn(model_output, batch_output, **loss_fn_kwargs)
loss.backward()
optimizer.step()
return loss.item()
Expand Down

0 comments on commit 414b105

Please sign in to comment.