from __future__ import absolute_import

import copy
import os
import tempfile
import zipfile

import numpy as np
import pandas as pd
import pyarrow as pa
import torch
from sklearn.preprocessing import StandardScaler

from pysurvival import utils
from pysurvival.utils._functions import _get_time_buckets


class BaseModel(object):
    """ Base class for all estimators in pysurvival. It should not be used on
        its own.
    """

    def __init__(self, auto_scaler=True):

        # Creating a scikit-learn scaler
        self.auto_scaler = auto_scaler
        if self.auto_scaler:
            self.scaler = StandardScaler()
        else:
            self.scaler = None

        # Creating a placeholder for the time axis
        self.times = [0.]

        # Creating the model's name
        self.__repr__()

    def __repr__(self):
        """ Creates the representation of the object """
        self.name = self.__class__.__name__
        return self.name

    def save(self, path_file):
        """ Save the model components:
            * the parameters of the model (parameters)
            * the PyTorch model itself (model), if it exists
            and compress them into a zip file.

            Parameters
            ----------
            * path_file, str
                address of the file where the model will be saved
        """

        # Ensuring the file has the proper name
        folder_name = os.path.dirname(path_file) + '/'
        file_name = os.path.basename(path_file)
        if not file_name.endswith('.zip'):
            file_name += '.zip'

        # Checking if the folder is accessible
        if not os.access(folder_name, os.W_OK):
            error_msg = '{} is not an accessible directory.'.format(folder_name)
            raise OSError(error_msg)

        # Listing the elements to save
        elements_to_save = []

        # Changing the format of the scaler parameters, if a scaler exists
        temp_scaler = copy.deepcopy(self.__dict__.get('scaler'))
        if temp_scaler is not None:
            self.__dict__['scaler'] = temp_scaler.__dict__

        # Gathering the model parameters (everything except the torch model)
        parameters_to_save = {}
        for k in self.__dict__:
            if k != 'model':
                parameters_to_save[k] = self.__dict__[k]

        # Serializing the parameters
        elements_to_save.append('parameters')
        with open('parameters', 'wb') as f:
            serialized_to_save = pa.serialize(parameters_to_save)
            f.write(serialized_to_save.to_buffer())

        # Saving the torch model, if it exists
        if 'model' in self.__dict__.keys():
            elements_to_save.append('model')
            torch.save(self.model, 'model')

        # Compressing the elements to save into a zip file
        full_path = folder_name + file_name
        print('Saving the model to disk as {}'.format(full_path))
        with zipfile.ZipFile(full_path, 'w') as myzip:
            for f in elements_to_save:
                myzip.write(f)

        # Erasing the temp files
        for temp_file in elements_to_save:
            os.remove(temp_file)

        # Restoring the scaler
        if temp_scaler is not None:
            self.scaler = StandardScaler()
            self.__dict__['scaler'] = copy.deepcopy(temp_scaler)
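
    # Usage sketch (illustrative, not from the library's documentation):
    # BaseModel is a base class, so assume `model` is a fitted pysurvival
    # estimator and that the path below is hypothetical.
    #
    #   model.save('/tmp/my_model')   # writes '/tmp/my_model.zip' to disk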

    def load(self, path_file):
        """ Load the model components from a .zip file:
            * the parameters of the model (parameters)
            * the PyTorch model itself (model), if it exists

            Parameters
            ----------
            * path_file, str
                address of the file the model will be loaded from
        """

        # Ensuring the file has the proper name
        folder_name = os.path.dirname(path_file) + '/'
        file_name = os.path.basename(path_file)
        if not file_name.endswith('.zip'):
            file_name += '.zip'

        # Opening the '.zip' file
        full_path = folder_name + file_name
        print('Loading the model from {}'.format(full_path))

        # Creating a temp folder
        temp_folder = tempfile.mkdtemp() + '/'

        # Unzipping the files into the temp folder
        with zipfile.ZipFile(path_file, 'r') as zip_ref:
            zip_ref.extractall(temp_folder)
        input_zip = zipfile.ZipFile(path_file)

        # Loading the files
        elements_to_load = []
        for file_name in input_zip.namelist():

            # Loading the parameters
            if 'parameters' in file_name.lower():
                content = input_zip.read('parameters')
                self.__dict__ = copy.deepcopy(pa.deserialize(content))
                elements_to_load.append(temp_folder + 'parameters')

                # If a scaler was available, loading it too
                temp_scaler = copy.deepcopy(self.__dict__.get('scaler'))
                if temp_scaler is not None:
                    self.scaler = StandardScaler()
                    self.scaler.__dict__ = temp_scaler

            # Loading the PyTorch model
            if 'model' in file_name.lower():
                model = torch.load(temp_folder + 'model')
                self.model = model
                elements_to_load.append(temp_folder + 'model')

        # Erasing the temp files
        for temp_file in elements_to_load:
            os.remove(temp_file)
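
    # Usage sketch of a save/load round trip (illustrative; `SomeEstimator`
    # stands for any concrete pysurvival model class):
    #
    #   model.save('/tmp/my_model')          # -> '/tmp/my_model.zip'
    #   restored = SomeEstimator()
    #   restored.load('/tmp/my_model.zip')   # restores parameters + torch model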

    def get_time_buckets(self, extra_timepoint=False):
        """ Creates the time buckets based on the time axis, such that
            the k-th time bucket is [t(k-1), t(k)] on the time axis.
        """

        # Checking if the time axis has already been created
        if self.times is None or len(self.times) <= 1:
            error = 'The time axis needs to be created before'
            error += ' using the method get_time_buckets.'
            raise AttributeError(error)

        # Creating the base time buckets
        time_buckets = _get_time_buckets(self.times)

        # Adding an additional bucket if specified
        if extra_timepoint:
            time_buckets += [(time_buckets[-1][1], time_buckets[-1][1] * 1.01)]
        self.time_buckets = time_buckets
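
    # Worked example (assuming _get_time_buckets pairs consecutive entries
    # of the time axis, as the docstring above describes):
    #
    #   self.times = [0., 1., 2., 3.]
    #   self.get_time_buckets()
    #   # self.time_buckets == [(0., 1.), (1., 2.), (2., 3.)]
    #   # With extra_timepoint=True, (3., 3.03), i.e. (t_max, t_max*1.01),
    #   # is appended as a final bucket.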

    def predict_hazard(self, x, t=None, **kwargs):
        """ Predicts the hazard function h(t, x)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            * `t`: **double** *(default=None)* --
                time at which the prediction should be performed.
                If None, the function is returned for all available t.

            Returns
            -------
            * `hazard`: **numpy.ndarray** --
                array-like representing the prediction of the hazard function
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating hazard, density, survival
        # (self.predict is implemented by the subclasses)
        hazard, density, survival = self.predict(x, t, **kwargs)
        return hazard

    def predict_density(self, x, t=None, **kwargs):
        """ Predicts the density function d(t, x)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            * `t`: **double** *(default=None)* --
                time at which the prediction should be performed.
                If None, the function is returned for all available t.

            Returns
            -------
            * `density`: **numpy.ndarray** --
                array-like representing the prediction of the density function
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating hazard, density, survival
        hazard, density, survival = self.predict(x, t, **kwargs)
        return density

    def predict_survival(self, x, t=None, **kwargs):
        """ Predicts the survival function S(t, x)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            * `t`: **double** *(default=None)* --
                time at which the prediction should be performed.
                If None, the function is returned for all available t.

            Returns
            -------
            * `survival`: **numpy.ndarray** --
                array-like representing the prediction of the survival function
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating hazard, density, survival
        hazard, density, survival = self.predict(x, t, **kwargs)
        return survival

    def predict_cdf(self, x, t=None, **kwargs):
        """ Predicts the cumulative distribution function F(t, x)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            * `t`: **double** *(default=None)* --
                time at which the prediction should be performed.
                If None, the function is returned for all available t.

            Returns
            -------
            * `cdf`: **numpy.ndarray** --
                array-like representing the prediction of the cumulative
                distribution function
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating survival and cdf
        survival = self.predict_survival(x, t, **kwargs)
        cdf = 1. - survival
        return cdf
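
    # Note: F(t, x) is the complement of the survival function, e.g. a
    # survival curve [[0.9, 0.7, 0.5]] yields the cdf [[0.1, 0.3, 0.5]].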

    def predict_cumulative_hazard(self, x, t=None, **kwargs):
        """ Predicts the cumulative hazard function H(t, x)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            * `t`: **double** *(default=None)* --
                time at which the prediction should be performed.
                If None, the function is returned for all available t.

            Returns
            -------
            * `cumulative_hazard`: **numpy.ndarray** --
                array-like representing the prediction of the cumulative
                hazard function
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating the cumulative hazard by summing the hazard
        # along the time axis (axis 1)
        hazard = self.predict_hazard(x, t, **kwargs)
        cumulative_hazard = np.cumsum(hazard, 1)
        return cumulative_hazard
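
    # Worked example of the cumulative sum over the time axis: for a single
    # datapoint with hazard [[0.1, 0.2, 0.3]], the cumulative hazard is
    # [[0.1, 0.3, 0.6]].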

    def predict_risk(self, x, **kwargs):
        """ Predicts the risk score/mortality function for all t,
            R(x) = sum( cumsum(hazard(t, x)) ),
            following Random Survival Forests (Ishwaran et al.,
            https://arxiv.org/pdf/0811.1645.pdf)

            Parameters
            ----------
            * `x` : **array-like** *shape=(n_samples, n_features)* --
                array-like representing the datapoints.
                x should not be standardized beforehand; the model
                will take care of it.

            Returns
            -------
            * `risk_score`: **numpy.ndarray** --
                array-like representing the prediction of the risk score
        """

        # Checking if the data has the right format
        x = utils.check_data(x)

        # Calculating the cumulative hazard and summing it over time
        cumulative_hazard = self.predict_cumulative_hazard(x, None, **kwargs)
        risk_score = np.sum(cumulative_hazard, 1)
        return risk_score
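
    # Worked example: with hazard [[0.1, 0.2, 0.3]] for one datapoint, the
    # cumulative hazard is [[0.1, 0.3, 0.6]] and the risk score is
    # 0.1 + 0.3 + 0.6 = 1.0.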