Skip to content
This repository
Browse code

Adds a cached results proxy, better performance

  • Loading branch information...
commit dd3234e450bbf97d0e77379f482bfa888273b727 1 parent 64e701b
Darren Dale authored December 17, 2011
97  praxes/dispatch/mptaskmanager.py
@@ -7,7 +7,7 @@
7 7
 import hashlib
8 8
 #import logging
9 9
 import multiprocessing
10  
-import threading
  10
+from threading import RLock
11 11
 import time
12 12
 
13 13
 import numpy as np
@@ -19,7 +19,22 @@
19 19
 DEBUG = False
20 20
 
21 21
 
22  
-class TaskManager(threading.Thread):
  22
+class QRLock(QtCore.QMutex):
  23
+
  24
+    def __init__(self):
  25
+        QtCore.QMutex.__init__(self, QtCore.QMutex.Recursive)
  26
+
  27
+    def __enter__(self):
  28
+        self.lock()
  29
+        return self
  30
+
  31
+    def __exit__(self, type, value, traceback):
  32
+        self.unlock()
  33
+
  34
+
  35
+class TaskManager(QtCore.QThread):
  36
+
  37
+    progress_report = QtCore.pyqtSignal(dict)
23 38
 
24 39
     @property
25 40
     def job_server(self):
@@ -49,15 +64,6 @@ def n_processed(self, val):
49 64
             self._n_processed = copy.copy(val)
50 65
 
51 66
     @property
52  
-    def n_submitted(self):
53  
-        with self.lock:
54  
-            return copy.copy(self._n_submitted)
55  
-    @n_submitted.setter
56  
-    def n_submitted(self, val):
57  
-        with self.lock:
58  
-            self._n_submitted = copy.copy(val)
59  
-
60  
-    @property
61 67
     def scan(self):
62 68
         return self._scan
63 69
 
@@ -70,18 +76,19 @@ def stopped(self, val):
70 76
         with self.lock:
71 77
             self.__stopped = copy.copy(val)
72 78
 
73  
-    def __init__(self, scan, progress_queue, **kwargs):
  79
+    def __init__(self, scan, results, **kwargs):
74 80
         super(TaskManager, self).__init__()
75 81
 
76  
-        self.__lock = threading.RLock()
  82
+        self.__lock = QRLock()
77 83
 
78 84
         self._scan = scan
79 85
         self._n_points = scan.entry.npoints
80 86
         self._n_cpus = kwargs.get(
81 87
             'n_local_processes', multiprocessing.cpu_count()
82 88
             )
  89
+        self._available_workers = 3 * self._n_cpus
83 90
 
84  
-        self.progress_queue = progress_queue
  91
+        self._results = results
85 92
         self.job_queue = []
86 93
 
87 94
         self._job_server = self.create_pool()
@@ -90,7 +97,6 @@ def __init__(self, scan, progress_queue, **kwargs):
90 97
 
91 98
         self._next_index = 0
92 99
         self._n_processed = 0
93  
-        self._n_submitted = 0
94 100
 
95 101
     def __iter__(self):
96 102
         return self
@@ -108,53 +114,44 @@ def next(self):
108 114
     def create_pool(self):
109 115
         raise NotImplementedError()
110 116
 
111  
-    def flush(self):
112  
-        while True:
113  
-            try:
114  
-                job = self.job_queue.pop(0)
115  
-                job.wait()
116  
-                #if res is not None:
117  
-                #    self.update_records(res)
118  
-            except IndexError:
119  
-                break
120  
-        self.n_submitted = 0
121  
-        self.queue_results()
  117
+    def _data_processed(self, data):
  118
+        with self.lock:
  119
+            self._available_workers += 1
  120
+
  121
+        self.update_records(data)
  122
+
  123
+        stats = {'n_processed': self.n_processed}
  124
+        self.progress_report.emit(stats)
122 125
 
123 126
     def process_data(self):
124  
-        for item in self:
125  
-            if self.stopped:
  127
+        while not self.stopped:
  128
+            with self.lock:
  129
+                if self._available_workers:
  130
+                    self._available_workers -= 1
  131
+                else:
  132
+                    time.sleep(0.01)
  133
+                    continue
  134
+
  135
+            try:
  136
+                item = self.next()
  137
+            except StopIteration:
  138
+                self.job_server.close()
  139
+                self.job_server.join()
126 140
                 break
127 141
 
128 142
             if item is None:
129 143
                 # next data point is not yet available
130  
-                if self.n_submitted > 0:
131  
-                    self.flush()
132  
-                else:
133  
-                    time.sleep(0.1)
  144
+                time.sleep(0.1)
134 145
                 continue
135 146
 
136  
-            if item:
  147
+            if item: # could be zero if masked
137 148
                 f, args = item
138  
-                job = self.job_server.apply_async(
139  
-                    f, args, callback=self.update_records
  149
+                self.job_server.apply_async(
  150
+                    f, args, callback=self._data_processed
140 151
                     )
141  
-                self.job_queue.append(job)
  152
+                #self.job_queue.append(job)
142 153
 
143 154
             self.n_processed += 1
144  
-            self.n_submitted += 1
145  
-
146  
-            if self.n_submitted >= self.n_cpus*3:
147  
-                self.flush()
148  
-
149  
-        self.job_server.close()
150  
-        self.job_server.join()
151  
-        if self.n_submitted > 0:
152  
-            self.flush()
153  
-
154  
-    def queue_results(self):
155  
-        stats = {}
156  
-        stats['n_processed'] = self.n_processed
157  
-        self.progress_queue.put(stats)
158 155
 
159 156
     def run(self):
160 157
         self.process_data()
81  praxes/fluorescence/mcaanalysiswindow.py
@@ -17,9 +17,9 @@
17 17
 from praxes.frontend.analysiswindow import AnalysisWindow
18 18
 from .ui.ui_mcaanalysiswindow import Ui_McaAnalysisWindow
19 19
 from .elementsview import ElementsView
  20
+from .results import XRFMapResultProxy
20 21
 from praxes.io import phynx
21 22
 
22  
-
23 23
 #logger = logging.getLogger(__file__)
24 24
 
25 25
 
@@ -52,6 +52,8 @@ def __init__(self, scan_data, parent=None):
52 52
                         % scan_data.__class__.__name__
53 53
                     )
54 54
             self._n_points = scan_data.entry.npoints
  55
+            self._dirty = False
  56
+            self._results = XRFMapResultProxy(self.scan_data)
55 57
 
56 58
             pymcaConfig = self.scan_data.pymca_config
57 59
             self.setupUi(self)
@@ -108,7 +110,7 @@ def __init__(self, scan_data, parent=None):
108 110
 
109 111
             self.progress_queue = Queue.Queue()
110 112
             self.timer = QtCore.QTimer(self)
111  
-            self.timer.timeout.connect(self.update)
  113
+            self.timer.timeout.connect(self.elementMapUpdated)
112 114
 
113 115
             self.elementsView.updateFigure()
114 116
 
@@ -230,7 +232,7 @@ def closeEvent(self, event):
230 232
             if res == QtGui.QMessageBox.Yes:
231 233
                 if self.analysisThread:
232 234
                     self.analysisThread.stop()
233  
-                    self.analysisThread.join()
  235
+                    self.analysisThread.wait()
234 236
                     #QtGui.qApp.processEvents()
235 237
             else:
236 238
                 event.ignore()
@@ -251,7 +253,9 @@ def configurePymca(self):
251 253
             self.statusbar.clearMessage()
252 254
 
253 255
     def elementMapUpdated(self):
254  
-        self.elementsView.updateFigure(self.getElementMap())
  256
+        if self._dirty:
  257
+            self._dirty = False
  258
+            self.elementsView.updateFigure(self.getElementMap())
255 259
         #QtGui.qApp.processEvents()
256 260
 
257 261
     def getElementMap(self, mapType=None, element=None):
@@ -259,13 +263,10 @@ def getElementMap(self, mapType=None, element=None):
259 263
         if mapType is None: mapType = self.mapType
260 264
 
261 265
         if mapType and element:
262  
-            with self.scan_data:
263  
-                try:
264  
-                    entry = '%s_%s'%(element, mapType)
265  
-                    return self.scan_data['element_maps'][entry].map
266  
-                except KeyError:
267  
-                    return np.zeros(self.scan_data.entry.acquisition_shape)
268  
-
  266
+            try:
  267
+                return self._results.get(element, mapType)
  268
+            except KeyError:
  269
+                return np.zeros(self.scan_data.entry.acquisition_shape)
269 270
         else:
270 271
             with self.scan_data:
271 272
                 return np.zeros(
@@ -273,27 +274,12 @@ def getElementMap(self, mapType=None, element=None):
273 274
                     )
274 275
 
275 276
     def initializeElementMaps(self, elements):
276  
-        with self.scan_data:
277  
-            if 'element_maps' in self.scan_data.entry.measurement:
278  
-                del self.scan_data['element_maps']
279  
-
280  
-            elementMaps = self.scan_data.create_group(
281  
-                'element_maps', type='ElementMaps'
  277
+        self._results = XRFMapResultProxy(
  278
+            self.scan_data,
  279
+            elements,
  280
+            self.scan_data.entry.npoints
282 281
             )
283 282
 
284  
-            for mapType, cls in [
285  
-                ('fit', 'Fit'),
286  
-                ('fit_error', 'FitError'),
287  
-                ('mass_fraction', 'MassFraction')
288  
-            ]:
289  
-                for element in elements:
290  
-                    entry = '%s_%s'%(element, mapType)
291  
-                    elementMaps.create_dataset(
292  
-                        entry,
293  
-                        type=cls,
294  
-                        data=np.zeros(self.scan_data.entry.npoints, 'f')
295  
-                    )
296  
-
297 283
     def processAverageSpectrum(self, indices=None):
298 284
         with self.scan_data:
299 285
             if indices is None:
@@ -353,6 +339,7 @@ def processComplete(self):
353 339
 
354 340
         self.setMenuToolsActionsEnabled(True)
355 341
         self.elementMapUpdated()
  342
+        self._results.flush()
356 343
 
357 344
     def processData(self):
358 345
         if sys.platform.startswith('win'):
@@ -368,20 +355,18 @@ def processData(self):
368 355
         settings.beginGroup('JobServers')
369 356
         n_local_processes, ok = settings.value(
370 357
             'LocalProcesses', QtCore.QVariant(1)
371  
-        ).toInt()
  358
+            ).toInt()
372 359
 
373 360
         thread = XfsTaskManager(
374 361
             self.scan_data,
375 362
             copy.deepcopy(self.pymcaConfig),
376  
-            self.progress_queue,
  363
+            self._results,
377 364
             n_local_processes=n_local_processes
378  
-        )
  365
+            )
379 366
 
380  
-#        self.thread.dataProcessed.connect(self.elementMapUpdated)
381  
-#        self.thread.jobStats.connect(self.jobStats.updateTable)
382  
-#        self.thread.finished.connect(self.processComplete)
383  
-#        self.thread.percentComplete.connect(self.progressBar.setValue)
384  
-        self.actionAbort.triggered.connect(self.abort) #thread.stop
  367
+        thread.progress_report.connect(self.update)
  368
+        thread.finished.connect(self.processComplete)
  369
+        self.actionAbort.triggered.connect(thread.stop) #thread.stop
385 370
 
386 371
         self.statusbar.showMessage('Analyzing spectra ...')
387 372
         self.statusbar.addPermanentWidget(self.progressBar)
@@ -392,24 +377,10 @@ def processData(self):
392 377
         thread.start()
393 378
         self.timer.start(1000)
394 379
 
395  
-    def abort(self):
396  
-        if self.analysisThread is not None:
397  
-            self.analysisThread.stop()
398  
-        self.processComplete()
399  
-
400  
-    def update(self):
401  
-        item = None
402  
-        while True:
403  
-            try:
404  
-                item = self.progress_queue.get(False)
405  
-            except Queue.Empty:
406  
-                break
407  
-        if item is None:
408  
-            return
409  
-
410  
-        self.elementMapUpdated()
  380
+    def update(self, report):
  381
+        self._dirty = True
411 382
 
412  
-        n_processed = item.pop('n_processed')
  383
+        n_processed = report.pop('n_processed')
413 384
         with self.scan_data:
414 385
             n_points = self.scan_data.entry.npoints
415 386
         progress = int((100.0 * n_processed) / n_points)
22  praxes/fluorescence/mptaskmanager.py
@@ -91,10 +91,10 @@ def mass_fraction_tool(self, val):
91 91
         with self.lock:
92 92
             self._mass_fraction_tool = val
93 93
 
94  
-    def __init__(self, scan, config, progress_queue, **kwargs):
  94
+    def __init__(self, scan, config, results, **kwargs):
95 95
         # needs to be set before the super call:
96 96
         self._config = config
97  
-        super(XfsTaskManager, self).__init__(scan, progress_queue, **kwargs)
  97
+        super(XfsTaskManager, self).__init__(scan, results, **kwargs)
98 98
 
99 99
         with scan:
100 100
             self._measurement = scan.entry.measurement
@@ -149,18 +149,6 @@ def next(self):
149 149
 
150 150
         return analyze_spectrum, (i, spectrum, monitor)
151 151
 
152  
-    def update_element_map(self, element, map_type, index, val):
153  
-        with self.scan:
154  
-            try:
155  
-                entry = '%s_%s'%(element, map_type)
156  
-                self.scan['element_maps'][entry][index] = val
157  
-            except ValueError:
158  
-                print "index %d out of range for %s" % (index, entry)
159  
-            except KeyError:
160  
-                print "%s not found in element_maps" % entry
161  
-            except TypeError:
162  
-                print entry, index, val
163  
-
164 152
     def update_records(self, data):
165 153
         if data is None:
166 154
             return
@@ -177,13 +165,13 @@ def update_records(self, data):
177 165
             else:
178 166
                 sigma_area = np.nan
179 167
 
180  
-            self.update_element_map(g, 'fit', index, fit_area)
181  
-            self.update_element_map(g, 'fit_error', index, sigma_area)
  168
+            self._results.update_fit(g, index, fit_area)
  169
+            self._results.update_fit_error(g, index, sigma_area)
182 170
 
183 171
         try:
184 172
             mass_fractions = result['concentrations']['mass fraction']
185 173
             for key, val in mass_fractions.items():
186 174
                 k = key.replace(' ', '_')
187  
-                self.update_element_map(k, 'mass_fraction', index, val)
  175
+                self._results.update_mass_fraction(k, index, val)
188 176
         except KeyError:
189 177
             pass
60  praxes/fluorescence/results.py
... ...
@@ -0,0 +1,60 @@
  1
+from threading import RLock
  2
+
  3
+import numpy as np
  4
+
  5
+
  6
+class XRFMapResultProxy(object):
  7
+
  8
+    def __init__(self, storage, elements=None, shape=None):
  9
+        self._lock = RLock()
  10
+        self._storage = storage
  11
+        self._cache = {}
  12
+
  13
+        if elements and shape:
  14
+            # we are overwriting an existing result:
  15
+            with self._storage:
  16
+                if 'element_maps' in self._storage:
  17
+                    del self._storage['element_maps']
  18
+
  19
+                element_maps = self._storage.create_group(
  20
+                    'element_maps', type='ElementMaps'
  21
+                    )
  22
+
  23
+                for map_type, cls in [
  24
+                    ('fit', 'Fit'),
  25
+                    ('fit_error', 'FitError'),
  26
+                    ('mass_fraction', 'MassFraction')
  27
+                    ]:
  28
+                    for element in elements:
  29
+                        data = np.zeros(shape, 'f')
  30
+                        entry = '%s_%s'%(element, map_type)
  31
+                        element_maps.create_dataset(entry, type=cls, data=data)
  32
+
  33
+        with self._storage:
  34
+            shape = self._storage.entry.acquisition_shape
  35
+            for k, v in self._storage['element_maps'].items():
  36
+                self._cache[k] = v[()].reshape(shape)
  37
+
  38
+    def update_fit(self, element, index, value):
  39
+        with self._lock:
  40
+            self._cache['%s_fit' % element].flat[index] = value
  41
+
  42
+    def update_fit_error(self, element, index, value):
  43
+        with self._lock:
  44
+            self._cache['%s_fit_error' % element].flat[index] = value
  45
+
  46
+    def update_mass_fraction(self, element, index, value):
  47
+        with self._lock:
  48
+            self._cache['%s_mass_fraction' % element].flat[index] = value
  49
+
  50
+    def flush(self):
  51
+        with self._lock:
  52
+            with self._storage:
  53
+                maps = self._storage['element_maps']
  54
+                for k, v in self._cache.items():
  55
+                    maps[k][()] = v.flatten()
  56
+                self._storage.file.flush()
  57
+
  58
+    def get(self, element, map_type):
  59
+        with self._lock:
  60
+            return self._cache['%s_%s' % (element, map_type)].copy()

0 notes on commit dd3234e

Please sign in to comment.
Something went wrong with that request. Please try again.