Skip to content
Permalink
Browse files

final touches

  • Loading branch information...
sethtroisi committed Apr 8, 2019
1 parent 1270281 commit a141c859516fac82ea14efde7661ff4d325255fd
Showing with 45 additions and 28 deletions.
  1. +45 −28 ratings/cbt_ratings.py
@@ -205,7 +205,6 @@ def setup_models(models_table):
model_ids, model_runs = read_models(db)

assert len(model_ids) == cbt_models
print()
return model_ids, model_runs


@@ -309,24 +308,32 @@ def sync(eval_games_table, model_ids, model_runs):
print("{:<10}".format(count),s )


def compute_ratings(data=None):
""" Returns the tuples of (model_id, rating, sigma)
N.B. that `model_id` here is NOT the model number in the run
def compute_ratings(model_ids, data=None):
""" Calculate ratings from win records
'data' is tuples of (winner, loser) model_ids (not model numbers)
Args:
model_ids: dictionary of {(run, model_name): model_id}
data: list of tuples of (winner_id, loser_id)
Returns:
dictionary {(run, model_name): (rating, variance)}
"""
if data is None:
with sqlite3.connect("cbt_ratings.db") as db:
data = db.execute("select model_winner, model_loser from wins").fetchall()
model_ids = set([d[0] for d in data]).union(set([d[1] for d in data]))
query = "select model_winner, model_loser from wins"
data = db.execute(query).fetchall()

data_ids = sorted(set(np.array(data).flatten()))

# Map model_ids to a contiguous range.
ordered = sorted(model_ids)
# Map data_ids to a contiguous range.
new_id = {}
for i, m in enumerate(ordered):
for i, m in enumerate(data_ids):
new_id[m] = i

# A function to rewrite the model_ids in our pairs
# Create inverse model_ids lookup
model_names = {v:k for k,v in model_ids.items()}

# A function to rewrite the data_ids in our pairs
def ilsr_data(d):
p1, p2 = d
p1 = new_id[p1]
@@ -335,7 +342,7 @@ def ilsr_data(d):

pairs = list(map(ilsr_data, data))
ilsr_param = choix.ilsr_pairwise(
len(ordered),
len(data_ids),
pairs,
alpha=0.0001,
max_iter=800)
@@ -346,20 +353,25 @@ def ilsr_data(d):
# Elo conversion
elo_mult = 400 / math.log(10)

# Used to make all ratings positive.
min_rating = min(ilsr_param)
ratings = {}

for model_id, param, err in zip(ordered, ilsr_param, std_err):
ratings[model_id] = (elo_mult * (param - min_rating), elo_mult * err)
ratings = {}
for m_id, param, err in zip(data_ids, ilsr_param, std_err):
model_rating = (elo_mult * (param - min_rating), elo_mult * err)
ratings[model_names[m_id]] = model_rating

return ratings


def top_n(n=10):
with sqlite3.connect('cbt_ratings.db') as db:
model_ids, model_runs = read_models(db)

data = wins_subset()
r = compute_ratings(data)
return [(model_num_for(k), v) for v, k in
sorted([(v, k) for k, v in r.items()])[-n:][::-1]]
r = compute_ratings(model_ids, data)
top_models = sorted(ratings.items(), key=lambda k: k[::-1])
return top_models[-n:][::-1]


def wins_subset(run=None):
@@ -377,8 +389,8 @@ def wins_subset(run=None):
# No run is for cross eval, don't allow games from same run
data = db.execute("""
select model_winner, model_loser from wins
# """)
("""
""")
("""
join models m1 join models m2 where
m1.id = model_winner AND
m2.id = model_loser AND
@@ -408,25 +420,30 @@ def main():
model_ids, model_runs = read_models(db)

data = wins_subset()

print("DB has", len(data), "games")
if not data:
return

ratings = compute_ratings(data)
for v, k in sorted([(v, k) for k, v in ratings.items()], reverse=True)[:20]:
print("Top model({}) {}: {}".format(k, k, v))
ratings = compute_ratings(model_ids, data)
top_models = sorted(ratings.items(), key=lambda k: k[::-1])
print()
print("Best models")
for k, v in top_models[-20:][::-1]
print("{:>30}: {}".format("/".join(k), v))

# Stats on recent models
run = 'v' + str(max(int(r[1:]) for r, m in model_ids))
print()
print("Recent ratings for", run)
for m in sorted(m for r, m in model_ids if r == run)[-20:]:
m_id = model_ids[(run, m)]
if m_id in ratings:
rating, sigma = ratings[m_id]
print("{:>30}: {:.2f} ({:.3f})".format(m, rating, sigma))
name = run + "/" + m
rating = ratings.get((run, m))
if rating:
rating, sigma = rating
print("{:>30}: {:.2f} ({:.3f})".format(name, rating, sigma))
else:
print("{:>30}: ({}) not found".format(m, m_id))
print("{:>30}: not found".format(name))


if __name__ == '__main__':

0 comments on commit a141c85

Please sign in to comment.
You can’t perform that action at this time.