Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent corruption of memory cache #1136

Merged
merged 3 commits into from Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/enforce-labels.yml
Expand Up @@ -7,7 +7,7 @@ jobs:
enforce-label:
runs-on: ubuntu-latest
steps:
- uses: yogevbd/enforce-label-action@2.1.0
- uses: yogevbd/enforce-label-action@2.2.1
with:
REQUIRED_LABELS_ANY: "pr:change,pr:deprecation,pr:fix,pr:new-feature,pr:removal"
REQUIRED_LABELS_ANY_DESCRIPTION: "Select at least one label with a 'pr:' prefix for this pull request"
Expand Down
12 changes: 3 additions & 9 deletions src/pymor/core/cache.py
Expand Up @@ -63,6 +63,7 @@

import atexit
from collections import OrderedDict
from copy import deepcopy
import functools
import getpass
import hashlib
Expand Down Expand Up @@ -150,22 +151,15 @@ def get(self, key):
if value is self.NO_VALUE:
return False, None
else:
from pymor.vectorarrays.interface import VectorArray
if isinstance(value, VectorArray):
value = value.copy()
return True, value
return True, deepcopy(value)

def set(self, key, value):
if key in self._cache:
getLogger('pymor.core.cache.MemoryRegion').warning('Key already present in cache region, ignoring.')
return
if len(self._cache) == self.max_keys:
self._cache.popitem(last=False)

import numpy as np
if isinstance(value, np.ndarray):
value.setflags(write=False)
self._cache[key] = value
self._cache[key] = deepcopy(value)

def clear(self):
self._cache = OrderedDict()
Expand Down
20 changes: 20 additions & 0 deletions src/pymortests/cache.py
Expand Up @@ -7,7 +7,11 @@
from uuid import uuid4
from datetime import datetime, timedelta

import numpy as np

from pymor.core import cache
from pymor.models.basic import StationaryModel
from pymor.operators.numpy import NumpyMatrixOperator
from pymortests.base import runmodule


Expand Down Expand Up @@ -86,5 +90,21 @@ def test_region_api():
assert backend.get('mykey') == (True, 1)


def test_memory_region_safety():

op = NumpyMatrixOperator(np.eye(1))
rhs = op.range.make_array(np.array([1]))
m = StationaryModel(op, rhs)
m.enable_caching('memory')

U = m.solve()
del U[:]
U = m.solve()
assert len(U) == 1
del U[:]
U = m.solve()
assert len(U) == 1


if __name__ == "__main__":
runmodule(filename=__file__)