Skip to content

Commit

Permalink
🎨 flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lalmei committed Feb 17, 2021
1 parent f24ea53 commit 0e5b453
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
34 changes: 18 additions & 16 deletions src/whylogs/core/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class ConfusionMatrix:

"""
Confusion Matrix Class to hold labels and matrix data
Expand Down Expand Up @@ -52,7 +51,8 @@ def add(self, predictions: List[Union[str, int, bool]],
scores (List[float]):
Raises:
NotImplementedError: in case targets do not fall into binary or multiclass suport
NotImplementedError: in case targets do not fall into binary or
multiclass suport
ValueError: incase missing validation or predictions
"""

Expand All @@ -79,18 +79,17 @@ def add(self, predictions: List[Union[str, int, bool]],
self.confusion_matrix[prediction_indx[ind],
targets_indx[ind]].track(scores[ind])

def merge(self, other_cm: ConfusionMatrix)-> ConfusionMatrix:
"""
merge two seperate confusion matrix which may or may not overlap in labels
Args:
other_cm (ConfusionMatrix): confusion_matrix to merge with self
def merge(self, other_cm):
"""
merges two seperate confusion matrix which may or may not overlap in labels
Returns:
ConfusionMatrix: merged confusion_matrix
"""
Args:
other_cm (ConfusionMatrix): confusion_matrix to merge with self
Returns:
ConfusionMatrix: merged confusion_matrix
"""

if self.labels is None or self.labels == []:
if self.labels is None or self.labels == []:
return other_cm
if other_cm.labels is None or other_cm.labels == []:
return self
Expand All @@ -105,10 +104,13 @@ def merge(self, other_cm: ConfusionMatrix)-> ConfusionMatrix:
return conf_Matrix

def to_protobuf(self, ):
return ScoreMatrixMessage(labels=self.labels, prediction_field=self.prediction_field,
return ScoreMatrixMessage(labels=self.labels,
prediction_field=self.prediction_field,
target_field=self.target_field,
score_field=self.score_field,
scores=[nt.to_protobuf() if nt else NumberTracker.to_protobuf(NumberTracker()) for nt in np.ravel(self.confusion_matrix)])
scores=[nt.to_protobuf() if nt else NumberTracker.to_protobuf(
NumberTracker()) for nt in np.ravel(
self.confusion_matrix)])

@classmethod
def from_protobuf(self, message,):
Expand All @@ -126,10 +128,10 @@ def from_protobuf(self, message,):
return CM_instance


def _merge_CM(old_conf_Matrix:ConfusionMatrix, new_conf_Matrix:ConfusionMatrix):
def _merge_CM(old_conf_Matrix: ConfusionMatrix, new_conf_Matrix: ConfusionMatrix):
"""
Merges two confusion_matrix since distinc or overlaping labels
Args:
old_conf_Matrix (ConfusionMatrix)
new_conf_Matrix (ConfusionMatrix): Will be overridden
Expand Down
19 changes: 10 additions & 9 deletions src/whylogs/core/metrics/model_metrics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import List, Union

from whylogs.core.metrics.confusion_matrix import ConfusionMatrix
from whylogs.proto import ModelMetricsMessage


class ModelMetrics:

"""
Container class for Modelmetrics
Container class for Modelmetrics
Attributes:
confusion_matrix (ConfusionMatrix): ConfusionMatrix which keeps it track of counts with numbertracker
confusion_matrix (ConfusionMatrix): ConfusionMatrix which keeps it track of counts with numbertracker
"""

def __init__(self, confusion_matrix: ConfusionMatrix = ConfusionMatrix()):
Expand All @@ -30,12 +31,12 @@ def compute_confusion_matrix(self, predictions: List[Union[str, int, bool]],
computes the confusion_matrix, if one is already present merges to old one.
Args:
predictions (List[Union[str, int, bool]]): Description
targets (List[Union[str, int, bool]]): Description
scores (List[float], optional): Description
target_field (str, optional): Description
prediction_field (str, optional): Description
score_field (str, optional): Description
predictions (List[Union[str, int, bool]]):
targets (List[Union[str, int, bool]]):
scores (List[float], optional):
target_field (str, optional):
prediction_field (str, optional):
score_field (str, optional):
"""
labels = sorted(list(set(targets+predictions)))
confusion_matrix = ConfusionMatrix(labels,
Expand Down

0 comments on commit 0e5b453

Please sign in to comment.