Skip to content

Commit ab82656

Browse files
authored
Merge pull request #36 from reichlab/feat_importance
feature importance for one example model run
2 parents 4883b52 + 0126982 commit ab82656

File tree

4 files changed

+262261
-4
lines changed

4 files changed

+262261
-4
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# This script executes one retrospective run of the main gbq_qr model
# and saves information about feature importances.
#
# This script should be run with code/gbq as the working directory:
# python retrospective-experiments/gbq_qr_feat_importance.py

import subprocess
import sys

ref_date = '2024-01-06'
output_root = '../../retrospective-hub/model-output'
artifact_store_root = '../../retrospective-hub/model-artifacts'

# Build the invocation as an argument list rather than a shell string so each
# value reaches gbq.py verbatim (no shell quoting or word-splitting).
# sys.executable runs gbq.py with the same interpreter as this script.
command = [
    sys.executable, 'gbq.py',
    '--ref_date', ref_date,
    '--output_root', output_root,
    '--artifact_store_root', artifact_store_root,
    '--save_feat_importance',
]


def main():
    """Run gbq.py once for ``ref_date`` with feature importances saved.

    Raises:
        subprocess.CalledProcessError: if gbq.py exits with a non-zero
            status, so a failed run is not silently ignored.
    """
    subprocess.run(command, check=True)


if __name__ == '__main__':
    main()

code/gbq/run.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ def run_gbq_flu_model(model_config, run_config):
6565
df_train, df_test, feat_names)
6666

6767
# save
68-
model_dir = run_config.output_root / f'UMass-{model_config.model_name}'
69-
model_dir.mkdir(parents=True, exist_ok=True)
70-
preds_df.to_csv(model_dir / f'{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv', index=False)
68+
save_path = _build_save_path(
69+
root=run_config.output_root,
70+
run_config=run_config,
71+
model_config=model_config
72+
)
73+
preds_df.to_csv(save_path, index=False)
7174

7275

7376
def _train_gbq_and_predict(model_config, run_config,
@@ -181,6 +184,8 @@ def _get_test_quantile_predictions(model_config, run_config,
181184

182185
train_seasons = df_train['season'].unique()
183186

187+
feat_importance = list()
188+
184189
for b in tqdm(range(model_config.num_bags), 'Bag number'):
185190
# get indices of observations that are in bag
186191
bag_seasons = rng.choice(
@@ -197,10 +202,29 @@ def _get_test_quantile_predictions(model_config, run_config,
197202
alpha=q_level,
198203
random_state=lgb_seeds[b, q_ind])
199204
model.fit(X=x_train.loc[bag_obs_inds, :], y=y_train.loc[bag_obs_inds])
205+
206+
feat_importance.append(
207+
pd.DataFrame({
208+
'feat': x_train.columns,
209+
'importance': model.feature_importances_,
210+
'b': b,
211+
'q_level': q_level
212+
})
213+
)
200214

201215
# test set predictions
202216
test_preds_by_bag[:, b, q_ind] = model.predict(X=x_test)
203217

218+
# combine and save feature importance scores
219+
if run_config.save_feat_importance:
220+
feat_importance = pd.concat(feat_importance, axis=0)
221+
save_path = _build_save_path(
222+
root=run_config.artifact_store_root,
223+
run_config=run_config,
224+
model_config=model_config,
225+
subdir='feat_importance')
226+
feat_importance.to_csv(save_path, index=False)
227+
204228
# combined predictions across bags: median
205229
test_pred_qs = np.median(test_preds_by_bag, axis=1)
206230

@@ -247,3 +271,11 @@ def _quantile_noncrossing(preds_df, gcols):
247271
.reset_index()
248272

249273
return preds_df
274+
275+
276+
def _build_save_path(root, run_config, model_config, subdir=None):
    """Return the CSV path for a model run, creating its directory.

    The directory is ``root / 'UMass-<model_name>'``, with ``subdir``
    appended when it is not None. It is created (including parents) if it
    does not already exist.

    Args:
        root: base directory; must support ``/`` path joining
            (e.g. a ``pathlib.Path``).
        run_config: object with a ``ref_date`` attribute.
        model_config: object with a ``model_name`` attribute.
        subdir: optional sub-directory name under the model directory.

    Returns:
        Path to ``<ref_date>-UMass-<model_name>.csv`` inside that directory.
    """
    model_label = f'UMass-{model_config.model_name}'
    base_dir = root / model_label
    target_dir = base_dir if subdir is None else base_dir / subdir
    target_dir.mkdir(parents=True, exist_ok=True)

    file_name = f'{str(run_config.ref_date)}-{model_label}.csv'
    return target_dir / file_name

code/gbq/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def parse_args():
3232

3333
run_config = SimpleNamespace(
3434
ref_date=ref_date,
35-
output_root=args.output_root
35+
output_root=args.output_root,
36+
artifact_store_root=args.artifact_store_root,
37+
save_feat_importance=args.save_feat_importance
3638
)
3739

3840
if args.short_run:
@@ -79,6 +81,13 @@ def _make_parser():
7981
help='Path to a directory in which model outputs are saved',
8082
type=lambda s: Path(s),
8183
default=Path('../../submissions-hub/model-output'))
84+
parser.add_argument('--artifact_store_root',
85+
help='Path to a directory in which artifacts related to model runs are saved',
86+
type=lambda s: Path(s),
87+
default=Path('../../submissions-hub/model-artifacts'))
88+
parser.add_argument('--save_feat_importance',
89+
help='Flag to save feature importances',
90+
action='store_true')
8291

8392
return parser
8493

0 commit comments

Comments
 (0)