Skip to content

Commit

Permalink
Merge pull request #265 from theislab/memfix
Browse files Browse the repository at this point in the history
Fix memory leak and forest classification for 2 cell types case
  • Loading branch information
LouisK92 committed May 10, 2023
2 parents a6f76d2 + fdaf25d commit 4fbf068
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 11 additions & 0 deletions spapros/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import gc
import pickle
import warnings
from enum import Enum
Expand Down Expand Up @@ -1912,6 +1913,11 @@ def single_forest_classifications(
ct_trees[ct].append(ct_trees_i[ct])
if progress and verbose:
progress.advance(forest_task)
# garbage collection
del X_train
del y_train
del cts_train
gc.collect()
# Get feature importances
importances = {
ct: pd.DataFrame(index=a.var.index, columns=[str(i) for i in range(n_trees)], dtype="float64")
Expand Down Expand Up @@ -1939,6 +1945,11 @@ def single_forest_classifications(
cts_test=cts_test,
masks=masks_test,
)
#garbage collection
del X_test
del y_test
del cts_test
gc.collect()

# Sort results
if sort_by_tree_performance:
Expand Down
5 changes: 2 additions & 3 deletions spapros/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,18 +1000,17 @@ def xgboost_forest_classification(
n_classes = len(np.unique(train_y))
clf = XGBClassifier(
max_depth=max_depth,
num_class=n_classes,
num_class=n_classes if n_classes > 2 else None,
n_estimators=250,
objective="multi:softmax" if n_classes > 2 else "binary:logistic",
early_stopping_rounds=5,
eval_metric="mlogloss",
eval_metric="mlogloss" if n_classes > 2 else "logloss",
learning_rate=lr,
colsample_bytree=colsample_bytree,
min_child_weight=min_child_weight,
gamma=gamma,
booster="gbtree", # TODO: compare with 'dart',rate_drop= 0.1
random_state=seed,
use_label_encoder=False, # To get rid of deprecation warning we convert labels into ints
n_jobs=n_jobs,
)
clf.fit(
Expand Down

0 comments on commit 4fbf068

Please sign in to comment.