Skip to content

Commit

Permalink
FIX assert in EM algorithm to handle case where loss converges
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelchughes committed Apr 1, 2020
1 parent 222e9ee commit 6817272
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cp3/src/GMM_PenalizedMLEstimator_EM.py
Expand Up @@ -300,7 +300,7 @@ def fit(self, x_ND, x_valid_ND=None, verbose=True):
# Verify the loss goes down after the M step
loss_m = self.calc_EM_loss(r_NK, x_ND)
self.history['train_loss_em'].append(loss_m)
## TODO this should pass: assert loss_m + 1e-9 <= loss_e
## TODO this should pass: assert loss_m <= loss_e + 1e-9 (UPDATED 2020-04-01)

if verbose:
print("iter %4d / %4d after %9.1f sec | train loss % 9.6f %s" % (
Expand Down
7 changes: 3 additions & 4 deletions cp3/src/GMM_PenalizedMLEstimator_LBFGS.py
Expand Up @@ -526,8 +526,6 @@ def callback_update_history(cur_param_vec):
optimal_param_vec = result.x
self.log_pi_K, self.mu_KD, self.stddev_KD = self.to_common_parameters_from_flat_array(optimal_param_vec)



def write_history_to_csv(self, csv_path):
''' Write history of training to comma separated value (CSV) file
Expand All @@ -545,6 +543,7 @@ def write_history_to_csv(self, csv_path):
None.
'''
df = pd.DataFrame()
for key in self.history:
df[key] = self.history[key]
cur_list = self.history[key]
if df.shape[0] == 0 or df.shape[0] == len(cur_list):
df[key] = cur_list
df.to_csv(csv_path, index=False)

0 comments on commit 6817272

Please sign in to comment.