In [2]:
import more_itertools as mit
import numpy as np
import torch
import math
import timeit

In [3]:
state = (-1, 1, -1, 1, 0, 0, 0, 0, 1)

In [4]:
def line_py_gt(state):
    for i in (0, 1, 2):
        yield state[i], state[i + 3], state[i + 6]
    for i in (0, 3, 6):
        yield state[i], state[i + 1], state[i + 2]
    for i in (0,):
        yield state[i], state[i + 4], state[i + 8]
    for i in (2,):
        yield state[i], state[i + 2], state[i + 4]

In [5]:
%%timeit -n 100000
list(line_py_gt(state))

957 ns ± 141 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [6]:
s = 3
sm1 = 2
s2 = 9
s2m1 = 8
sp1 = 4
def line_py(state):
    for i in range(0, s2, s):
        yield state[i:i+s]
    for i in range(s):
        yield state[i::s]
    yield state[::sp1]
    yield state[sm1:s2m1:sm1]

In [7]:
%%timeit -n 100000
list(line_py(state))

898 ns ± 65.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [8]:
index = list(line_py(list(range(9))))
def line_py_indexed(state):
    for i, j, k in index:
        yield state[i], state[j], state[k]

In [9]:
%%timeit -n 100000
list(line_py_indexed(state))

776 ns ± 94.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
def line_np(state):
    raw_state = np.array(state).reshape(-1, 3)
    rel_state = 2 * (raw_state == 1) - (raw_state != 0)
    return np.concatenate([
        rel_state, rel_state.T,
        np.stack([np.diag(rel_state), np.diag(np.fliplr(rel_state))])
    ], axis=0)

In [11]:
%%timeit -n 10000
list(line_np(state))

10.3 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [14]:
def line_torch(state):
    raw_state = torch.tensor(state).view(-1, 3)
    rel_state = 2 * (raw_state == 1) - 1 * (raw_state != 0)
    return torch.cat([
        rel_state, rel_state.T,
        torch.stack([rel_state.diag(), rel_state.fliplr().diag()])
    ], dim=0)

In [15]:
%%timeit -n 10000
list(line_torch(state))

35.8 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
def line_test(state):
    curr_state = np.array(state).reshape(3, 3)
    temp_state = curr_state.copy()

    curr_state[temp_state == 1] = 1
    curr_state[(temp_state != 1) & (temp_state != 0)] = -1

    return list(curr_state[:]) + list(curr_state.T[:]) + [np.diag(curr_state)] + [np.diag(np.fliplr(curr_state))]

In [17]:
%%timeit -n 10000
list(line_test(state))

7 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Benchmark get_v

In [18]:
lines = list(line_py_gt(state))

In [19]:
def get_v_py_gt(lines):
    p1 = p2 = a1 = a2 = 0
    for line in lines:
        p_count = line.count(1)
        a_count = line.count(-1)
        if p_count == 2 and a_count == 0:
            p2 += 1
        if p_count == 1 and a_count == 0:
            p1 += 1
        if p_count == 0 and a_count == 2:
            a2 += 1
        if p_count == 0 and a_count == 1:
            a1 += 1
    return 3 * p2 + p1 - (3 * a2 + a1)

In [20]:
%%timeit -n 100000
get_v_py_gt(lines)

1.32 µs ± 90.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [21]:
def get_v_fast(lines):
    score = 0
    for line in lines:
        p_count = line.count(1)
        a_count = line.count(-1)
        if a_count == 0:
            if p_count == 2:
                score += 3
            elif p_count == 1:
                score += 1
        if p_count == 0:
            if a_count == 2:
                score -= 3
            elif a_count == 1:
                score -= 1
    return score

In [22]:
%%timeit -n 100000
get_v_fast(lines)

1.06 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [23]:
def get_v_math(lines):
    score = 0
    for line in lines:
        p = math.prod(line)
        s = sum(line)
        abs_s = abs(s)
        if abs_s == 2:
            score += 3 * s / 2
        if p == 0 and abs_s == 1:
            score += s
    return score

In [24]:
%%timeit -n 100000
get_v_math(lines)

1.57 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [25]:
def get_v_np(lines):
    lines = np.array(lines)
    sum_line = lines.sum(axis=1)
    abs_sum_line = np.abs(sum_line)
    prod_line = lines.prod(axis=1)
    return np.sum(
        + 3 * sum_line * (abs_sum_line == 2) / 2
        + 1 * sum_line * ((prod_line == 0) & (abs_sum_line == 1))
    )

In [26]:
%%timeit -n 10000
get_v_np(lines)

12 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
def get_v_test(lines):
    a1 = a2 = p1 = p2 = 0
    for line in lines:
        unique, counts = np.unique(line, return_counts=True)
        counts_dict = dict(zip(unique, counts))
        if 1 in counts_dict and -1 not in counts_dict:
            if counts_dict[1] == 1:
                p1 += 1
            elif counts_dict[1] == 2:
                p2 += 1
        if -1 in counts_dict and 1 not in counts_dict:
            if counts_dict[-1] == 1:
                a1 += 1
            elif counts_dict[-1] == 2:
                a2 += 1

    return p1 + p2 * 3 - a1 - a2 * 3

In [28]:
%%timeit -n 10000
get_v_test(lines)

51.1 µs ± 2.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
