Iterate through all of dataloader each epoch. Compute epoch energy as the average of the batch energies.
belsten authored and belsten committed Apr 16, 2024
1 parent b881971 commit 0934e5c
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions sparsecoding/models.py
@@ -104,29 +104,18 @@ def learn_dictionary(self, dataset, n_epoch, batch_size):
         losses = []
 
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-        iterloader = iter(dataloader)
-        for i in range(n_epoch):
-            try:
-                batch = next(iterloader)
-            except StopIteration:
-                dataloader = DataLoader(dataset, batch_size=batch_size,
-                                        shuffle=True)
-                iterloader = iter(dataloader)
-                batch = next(iterloader)
-
-            # infer coefficients
-            a = self.inference_method.infer(batch, self.dictionary)
-
-            # update dictionary
-            self.update_dictionary(batch, a)
-
-            # normalize dictionary
-            self.normalize_dictionary()
-
-            # compute current loss
-            loss = self.compute_loss(batch, a)
-
-            losses.append(loss)
+        for _ in range(n_epoch):
+            loss = 0.0
+            for batch in dataloader:
+                # infer coefficients
+                a = self.inference_method.infer(batch, self.dictionary)
+                # update dictionary
+                self.update_dictionary(batch, a)
+                # normalize dictionary
+                self.normalize_dictionary()
+                # compute current loss
+                loss += self.compute_loss(batch, a)
+            losses.append(loss/len(dataloader))
         return np.asarray(losses)
 
     def compute_loss(self, data, a):
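For reference, below is a minimal, self-contained sketch of the epoch-averaging pattern this commit introduces: every batch in the dataloader is visited once per epoch, and the epoch loss is the mean of the per-batch losses. The toy dataset and the placeholder per-batch "energy" are assumptions for illustration only; the actual method calls the library's inference_method.infer, update_dictionary, normalize_dictionary, and compute_loss as shown in the diff above.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data standing in for the library's dataset (illustrative assumption).
dataset = TensorDataset(torch.randn(256, 10))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

n_epoch = 5
losses = []
for _ in range(n_epoch):
    loss = 0.0
    for (batch,) in dataloader:
        # Placeholder batch "energy"; the real code uses self.compute_loss(batch, a).
        loss += batch.pow(2).mean().item()
    # Epoch energy = average of the batch energies.
    losses.append(loss / len(dataloader))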
