forked from uwdata/errudite
/
predictor.py
153 lines (132 loc) · 4.66 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import List, Dict, Any
from ..utils import Registrable
class Predictor(Registrable):
    """A base class for predictors.

    A predictor runs prediction on raw texts and also on instances.
    It also saves the performance scores for the predictor.

    This is a subclass of ``errudite.utils.registrable.Registrable`` and all the
    actual predictor classes are registered under ``Predictor`` by their names.

    Parameters
    ----------
    name : str
        The name of the predictor.
    description : str
        A sentence describing the predictor.
    model : Any
        The executable model. It is stored on the instance as ``self.predictor``.
    perform_metrics : List[str]
        The names of the performance metrics.

    Attributes
    ----------
    perform : Dict[str, float]
        ``{ perform_name: the averaged performance score }``. Every metric in
        ``perform_metrics`` starts at 0 until ``evaluate_performance`` is called.
    """

    def __init__(self,
                 name: str,
                 description: str,
                 model: Any,
                 perform_metrics: List[str]):
        self.name: str = name
        self.description: str = description
        # The wrapped model is exposed under the attribute name ``predictor``.
        self.predictor: Any = model
        self.perform_metrics: List[str] = perform_metrics
        # Seed every tracked metric with 0 so ``serialize`` always has a value.
        self.perform: Dict[str, float] = {
            metric: 0 for metric in self.perform_metrics
        }

    def predict(self, **kwargs):
        """Run the prediction.

        Raises
        ------
        NotImplementedError
            Should be implemented in subclasses.
        """
        raise NotImplementedError

    def evaluate_performance(self, instances: List['Instance']) -> None:
        """Save the performance of the predictor.

        It iterates through the metric names in ``self.perform_metrics`` and
        averages the corresponding metrics found in the ``perform`` mapping of
        each instance's ``'prediction'`` entry for this predictor (looked up
        with ``instance.get_entry('prediction', self.name)``). The results are
        saved in ``self.perform``.

        Only instances whose ``vid`` equals 0 take part in the average. If no
        such instance exists, ``self.perform`` is left unchanged and a warning
        is logged.

        Parameters
        ----------
        instances : List[Instance]
            The list of instances, with predictions from this model already
            saved as part of their entries.

        Returns
        -------
        None
            The result is saved in ``self.perform``.
        """
        # Local import keeps this block self-contained with respect to the
        # module's import section.
        import logging

        selected = [instance for instance in instances if instance.vid == 0]
        if not selected:
            # Nothing to average over; avoid dividing by zero and keep the
            # previously stored scores.
            logging.getLogger(__name__).warning(
                "No instances with vid == 0 to evaluate for predictor %s; "
                "performance left unchanged.",
                self.name,
            )
            return
        n_total = len(selected)
        for metric in self.perform_metrics:
            self.perform[metric] = sum(
                instance.get_entry('prediction', self.name).perform[metric]
                for instance in selected
            ) / n_total

    def serialize(self) -> Dict[str, Any]:
        """Serialize the predictor into a json format, for sending over
        to the frontend.

        Returns
        -------
        Dict[str, Any]
            The serialized version, with the keys ``perform``, ``name`` and
            ``description``.
        """
        return {
            'perform': self.perform,
            'name': self.name,
            'description': self.description
        }

    def __repr__(self) -> str:
        """
        Override the print func by displaying the class name and the predictor name.
        """
        return f'{self.__class__.__name__} {self.name}'

    @classmethod
    def create_from_json(cls, raw: Dict[str, str]) -> 'Predictor':
        """
        Recreate the predictor from its serialized raw json.

        The class registered under ``raw["model_class"]`` is looked up through
        ``Predictor.by_name`` and called with the keyword arguments ``name``,
        ``description``, ``model_path`` and ``model_online_path``, so the
        registered class has to accept those keywords.

        Parameters
        ----------
        raw : Dict[str, str]
            The json version definition of the predictor. ``model_class`` is
            required; ``name``, ``description``, ``model_path`` and
            ``model_online_path`` are optional and default to ``None``.

        Returns
        -------
        Predictor
            The re-created predictor.

        Raises
        ------
        KeyError
            If ``raw`` has no ``model_class`` key.
        """
        # The lookup deliberately goes through ``Predictor`` rather than ``cls``
        # so it behaves the same when invoked on a subclass.
        predictor_class = Predictor.by_name(raw["model_class"])
        return predictor_class(
            name=raw.get("name"),
            description=raw.get("description"),
            model_path=raw.get("model_path"),
            model_online_path=raw.get("model_online_path"))

    @classmethod
    def model_predict(cls,
                      predictor: 'Predictor',
                      **targets) -> 'Label':
        """
        Define a class method that takes Target inputs, runs model predictions,
        and wraps the output prediction into Labels.

        Parameters
        ----------
        predictor : Predictor
            A predictor object, with the predict method implemented.
        targets : Target
            Targets in kwargs format.

        Returns
        -------
        Label
            The predicted output, with performance saved.

        Raises
        ------
        NotImplementedError
            This needs to be implemented per task.
        """
        raise NotImplementedError