Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

added sofia predictor

  • Loading branch information...
commit 63ac6a251f36890a25cb0dc2b7a0251369c8c07e 1 parent 9c5ed07
@rrenaud rrenaud authored
Showing with 49 additions and 0 deletions.
  1. +1 −0  data/logreg-peg.model
  2. +48 −0 sofia_predict.py
View
1  data/logreg-peg.model
1 addition, 0 deletions not shown
View
48 sofia_predict.py
@@ -0,0 +1,48 @@
+import subprocess
+
+import pymongo
+
+import game
+import game_state_features
+import math
+
+def log_odds_to_prob(logodds):
+ odds = math.exp(logodds)
+ return odds / (1 + odds)
+
+def encode_features_for_all_turns(game_val):
+ return [game_state_features.state_to_features(game_val, game_state)
+ for game_state in game_val.game_state_iterator()]
+
+class SofiaWinPredictor:
+ def __init__(self, model_name, prediction_type='logistic',
+ hash_mask_bits=None):
+ args = ['sofia-ml',
+ '--model_in', model_name,
+ '--test_stream',
+ '--prediction_type', prediction_type]
+ if hash_mask_bits is not None:
+ args.extend(['--hash_mask_bits', str(hash_mask_bits)])
+ p = subprocess.PIPE
+ self.sofia_proc = subprocess.Popen(args, stdin=p, stdout=p)
+
+ def predict_all_turns(self, game_val):
+ all_turns_features = encode_features_for_all_turns(game_val)
+ for features in all_turns_features:
+ game_state_features.output_libsvm_state(features,
+ self.sofia_proc.stdin)
+ ret = []
+ for _ in all_turns_features:
+ line = self.sofia_proc.stdout.readline()
+ # print line,
+ ret.append(log_odds_to_prob(float(line)))
+ return ret
+
+def main():
+ c = pymongo.Connection()
+ g = game.Game(c.test.games.find_one())
+ predictor = SofiaWinPredictor('data/logreg-peg.model')
+ print predictor.predict_all_turns(g)
+
+if __name__ == '__main__':
+ main()
Please sign in to comment.
Something went wrong with that request. Please try again.