@@ -65,9 +65,12 @@ def run_gbq_flu_model(model_config, run_config):
6565 df_train , df_test , feat_names )
6666
6767 # save
68- model_dir = run_config .output_root / f'UMass-{ model_config .model_name } '
69- model_dir .mkdir (parents = True , exist_ok = True )
70- preds_df .to_csv (model_dir / f'{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv' , index = False )
68+ save_path = _build_save_path (
69+ root = run_config .output_root ,
70+ run_config = run_config ,
71+ model_config = model_config
72+ )
73+ preds_df .to_csv (save_path , index = False )
7174
7275
7376def _train_gbq_and_predict (model_config , run_config ,
@@ -181,6 +184,8 @@ def _get_test_quantile_predictions(model_config, run_config,
181184
182185 train_seasons = df_train ['season' ].unique ()
183186
187+ feat_importance = list ()
188+
184189 for b in tqdm (range (model_config .num_bags ), 'Bag number' ):
185190 # get indices of observations that are in bag
186191 bag_seasons = rng .choice (
@@ -197,10 +202,29 @@ def _get_test_quantile_predictions(model_config, run_config,
197202 alpha = q_level ,
198203 random_state = lgb_seeds [b , q_ind ])
199204 model .fit (X = x_train .loc [bag_obs_inds , :], y = y_train .loc [bag_obs_inds ])
205+
206+ feat_importance .append (
207+ pd .DataFrame ({
208+ 'feat' : x_train .columns ,
209+ 'importance' : model .feature_importances_ ,
210+ 'b' : b ,
211+ 'q_level' : q_level
212+ })
213+ )
200214
201215 # test set predictions
202216 test_preds_by_bag [:, b , q_ind ] = model .predict (X = x_test )
203217
218+ # combine and save feature importance scores
219+ if run_config .save_feat_importance :
220+ feat_importance = pd .concat (feat_importance , axis = 0 )
221+ save_path = _build_save_path (
222+ root = run_config .artifact_store_root ,
223+ run_config = run_config ,
224+ model_config = model_config ,
225+ subdir = 'feat_importance' )
226+ feat_importance .to_csv (save_path , index = False )
227+
204228 # combined predictions across bags: median
205229 test_pred_qs = np .median (test_preds_by_bag , axis = 1 )
206230
@@ -247,3 +271,11 @@ def _quantile_noncrossing(preds_df, gcols):
247271 .reset_index ()
248272
249273 return preds_df
274+
275+
276+ def _build_save_path (root , run_config , model_config , subdir = None ):
277+ save_dir = root / f'UMass-{ model_config .model_name } '
278+ if subdir is not None :
279+ save_dir = save_dir / subdir
280+ save_dir .mkdir (parents = True , exist_ok = True )
281+ return save_dir / f'{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv'
0 commit comments