Skip to content

Commit

Permalink
added generic weighted average function
Browse files Browse the repository at this point in the history
  • Loading branch information
GeetDsa committed Apr 9, 2020
1 parent 2fa096d commit 48c74b8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
[Unreleased]: https://github.com/sepandhaghighi/pycm/compare/v2.6...dev

### Added
- Weighted average F1
- Weighted average for various class stats
[2.6]: https://github.com/sepandhaghighi/pycm/compare/v2.5...v2.6
[2.5]: https://github.com/sepandhaghighi/pycm/compare/v2.4...v2.5
[2.4]: https://github.com/sepandhaghighi/pycm/compare/v2.3...v2.4
Expand Down
3 changes: 3 additions & 0 deletions Test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,7 @@
0.4
>>> cm3.LambdaB
0.35714285714285715
>>> cm.weighted_average("F2")
0.5684946392350044
"""
23 changes: 23 additions & 0 deletions pycm/pycm_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,26 @@ def relabel(self, mapping):
self.FP = self.class_stat["FP"]
self.FN = self.class_stat["FN"]
__class_stat_init__(self)
def weighted_average(self, item, weights=None):
"""
Calculate Weighted average of input parameter.
:param item: a class item for which weighted average has to be calculated
:type item1:str
:param weights: Explicitly passes weights or Support or P
:type item2:dict
:return: weighted average of the input parameter
"""
if not weights:
""" Set default weights to the Support or P"""
weights = list(self.class_stat["P"].values())
try:
item_values = list(self.class_stat[item].values())
except KeyError:
return "Error: Invalid item {}".format(item)
try:
return numpy.average(item_values,weights=weights)
except Exception:
return "None"

16 changes: 0 additions & 16 deletions pycm/pycm_overall_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,22 +659,6 @@ def macro_calc(item):
except Exception:
return "None"

def weighted_calc(item1,item2):
"""
Calculate PPV_Macro and TPR_Macro.
:param item1: F1
:type item1:dict
:param item2: P or Support
:type item2:dict
:return: weighted average F1
"""
try:
return numpy.average(list(item1.values()),weights=list(item2.values()))
except Exception:
return "None"

def PC_AC1_calc(P, TOP, POP):
"""
Calculate percent chance agreement for Gwet's AC1.
Expand Down

0 comments on commit 48c74b8

Please sign in to comment.