Skip to content

Commit 07ee41c

Browse files
authored
Adding a way to clear GPU memory (#722)
1 parent 22d16cf commit 07ee41c

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- src/api.rs:599
-- pgml::api::clear_gpu_cache
-- NOTE(review): this function is STRICT, so Postgres returns NULL *without
-- calling it* whenever "memory_usage" IS NULL -- which is the declared
-- DEFAULT. Confirm that the zero-argument call path actually reaches the
-- Rust implementation; otherwise STRICT should be dropped at the generator.
-- NOTE(review): IMMUTABLE is also suspect -- the function has a side effect
-- (frees GPU memory) and returns different results over time; verify intent.
CREATE FUNCTION pgml."clear_gpu_cache"(
	"memory_usage" REAL DEFAULT NULL /* Option<f32> */
) RETURNS bool /* bool */
IMMUTABLE STRICT PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'clear_gpu_cache_wrapper';

pgml-extension/src/api.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,29 @@ pub fn embed_batch(
580580
crate::bindings::transformers::embed(transformer, inputs, &kwargs.0)
581581
}
582582

583+
584+
/// Clears the GPU cache.
585+
///
586+
/// # Arguments
587+
///
588+
/// * `memory_usage` - Optional parameter indicating the memory usage percentage (0.0 -> 1.0)
589+
///
590+
/// # Returns
591+
///
592+
/// Returns `true` if the GPU cache was successfully cleared, `false` otherwise.
593+
/// # Example
594+
///
595+
/// ```sql
596+
/// SELECT pgml.clear_gpu_cache(memory_usage => 0.5);
597+
/// ```
598+
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
599+
pub fn clear_gpu_cache(
600+
memory_usage: default!(Option<f32>, "NULL")
601+
) -> bool {
602+
let memory_usage: Option<f32> = memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
603+
crate::bindings::transformers::clear_gpu_cache(memory_usage)
604+
}
605+
583606
#[pg_extern(immutable, parallel_safe)]
584607
pub fn chunk(
585608
splitter: &str,

pgml-extension/src/bindings/transformers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ def embed(transformer, inputs, kwargs):
131131

132132
return model.encode(inputs, **kwargs)
133133

134+
def clear_gpu_cache(memory_usage=None):
    """Empty the CUDA caching allocator, optionally gated by a usage threshold.

    Args:
        memory_usage: Optional fraction (0.0 -> 1.0). When provided, the cache
            is cleared only if current GPU memory utilization is at or above
            this threshold. When omitted/None, the cache is always cleared.
            (Was annotated ``memory_usage: None`` -- an annotation, not a
            default -- so the bare Python call ``clear_gpu_cache()`` failed;
            a real default preserves the Rust caller's positional call.)

    Returns:
        True if the cache was cleared, False otherwise.

    Raises:
        PgMLException: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        # Typo fixed ("availables") and needless f-string prefix dropped.
        raise PgMLException("No GPU available")

    # NOTE(review): torch.cuda.memory_usage() reports the percent of time
    # global memory was being read/written (nvidia-smi utilization), not the
    # fraction of memory capacity in use -- confirm this is the intended
    # gauge for the threshold. `memory_usage` is a fraction, scaled to the
    # 0-100 range reported by torch.
    mem_used = torch.cuda.memory_usage()
    if not memory_usage or mem_used >= int(memory_usage * 100.0):
        torch.cuda.empty_cache()
        return True
    return False
144+
134145

135146
def load_dataset(name, subset, limit: None, kwargs: "{}"):
136147
kwargs = orjson.loads(kwargs)

pgml-extension/src/bindings/transformers.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,26 @@ pub fn load_dataset(
311311

312312
num_rows
313313
}
314+
315+
pub fn clear_gpu_cache(
316+
memory_usage: Option<f32>
317+
) -> bool {
318+
319+
Python::with_gil(|py| -> bool {
320+
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap().into();
321+
clear_gpu_cache
322+
.call1(
323+
py,
324+
PyTuple::new(
325+
py,
326+
&[
327+
memory_usage.into_py(py),
328+
],
329+
),
330+
)
331+
.unwrap()
332+
.extract(py)
333+
.unwrap()
334+
})
335+
}
336+

0 commit comments

Comments
 (0)