In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import glob, random, time, os, zlib
import model

# --- Dataset layout ---------------------------------------------------------
# Directory containing the compressed chunk files (features_NNN.z, moves_NNN.z)
# and the test_features.z / test_moves.z pair.
DATA_ROOT = "build2/"
# Number of training chunks on disk; load_next_chunk() wraps around after this many.
TOTAL_CHUNK_COUNT = 12
# Size of the last axis of every 8x8 feature array.
# NOTE(review): 6 + 6 + 1 presumably means six piece planes per side plus one
# extra plane -- confirm against the code that generates the data.
FEATURE_COUNT = 6 + 6 + 1

# --- Sample sizes -----------------------------------------------------------
# Samples drawn per training step.
MINIBATCH_SIZE = 256
# Leading samples of the current training chunk used for the in-sample loss.
IN_SAMPLE_SIZE = 1000
# Random samples drawn from the test files for the cross-validation loss.
CROSS_VAL_SIZE = 3000

def to_hms(x):
    """Format a duration in seconds as an "H:MM:SS" string.

    The input is truncated to an integer with int() first. Hours are
    space-padded to a width of at least two characters; minutes and seconds
    are zero-padded to two digits.
    """
    total_minutes, seconds = divmod(int(x), 60)
    hours, minutes = divmod(total_minutes, 60)
    return "%2i:%02i:%02i" % (hours, minutes, seconds)

# For some reason some Python versions basically explode on .decode("zlib") for large strings.
# We can bypass by just decoding it in blocks ourselves and assembling them.
def stream_decompress(s, block_size=2**23):
    """Decompress a zlib-compressed byte string block by block.

    Args:
        s: the complete zlib-compressed payload as a byte string.
        block_size: number of compressed bytes fed to the decompressor per
            step (default 8 MiB). Must be positive.

    Returns:
        The decompressed payload as a single byte string. Empty input yields
        an empty byte string.

    Raises:
        zlib.error: if the payload is not valid zlib data.
    """
    decomp = zlib.decompressobj()
    results = [
        decomp.decompress(s[start:start + block_size])
        for start in range(0, len(s), block_size)
    ]
    # flush() returns whatever output the decompressor is still buffering.
    results.append(decomp.flush())
    # Join with a *bytes* separator: the pieces are bytes on Python 3, where
    # "".join(...) raises TypeError. On Python 2, b"" is the same object type
    # as "", so behavior there is unchanged.
    return b"".join(results)

def load_chunk(features, moves):
    """Load one pair of zlib-compressed int8 arrays (features and moves) from disk.

    Args:
        features: path of the compressed feature file.
        moves: path of the compressed move file.

    Returns:
        A dict with two int8 arrays:
          "features": shape (N, 8, 8, FEATURE_COUNT)
          "moves":    shape (N, 2, 8, 8)
        Both arrays are read-only views over the decompressed buffers.

    Raises:
        AssertionError: if the two files hold a different number of samples.
        ValueError: if a file's size does not fit the expected sample shape.
    """
    def load_flat_array(path, shape):
        # The files are binary: open them in "rb" mode so no newline
        # translation or text decoding is applied to the compressed bytes.
        with open(path, "rb") as f:
            data = f.read()
        # Rebinding `data` drops the compressed copy as soon as possible.
        data = stream_decompress(data)
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # It wraps the decompressed bytes without copying them, which avoids
        # briefly holding a second full copy of the chunk; the resulting array
        # is therefore read-only.
        return np.frombuffer(data, dtype=np.int8).reshape(shape)

    feature_array = load_flat_array(features, (-1, 8, 8, FEATURE_COUNT))
    move_array = load_flat_array(moves, (-1, 8, 8, 2))
    # Move each sample to be of shape (2, 8, 8) so we can use tf.nn.softmax_cross_entropy_with_logits_v2.
    move_array = np.moveaxis(move_array, -1, 1)
    assert len(feature_array) == len(move_array), (
        "feature/move sample counts differ: %i vs %i"
        % (len(feature_array), len(move_array))
    )
    return {"features": feature_array, "moves": move_array}

# Views into the extremely large dataset.
next_chunk_index = 0   # index of the chunk the next load_next_chunk() call reads
chunk = None           # currently loaded training chunk (dict returned by load_chunk)
in_sample_test = None  # first IN_SAMPLE_SIZE samples of `chunk`

def load_next_chunk():
    """Replace the global training chunk with the next one from DATA_ROOT.

    Reads features_NNN.z / moves_NNN.z for the current `next_chunk_index`,
    stores the result in the global `chunk`, advances `next_chunk_index`
    (wrapping at TOTAL_CHUNK_COUNT), and rebuilds the global `in_sample_test`
    from the first IN_SAMPLE_SIZE samples of the new chunk.

    If loading fails or is interrupted, `chunk` and `in_sample_test` are left
    as None and `next_chunk_index` is not advanced, so the call can simply be
    retried.
    """
    global next_chunk_index, chunk, in_sample_test
    print("    >>> Loading chunk: %i" % next_chunk_index)
    # Free the memory from the previous chunk FIRST, if we have one loaded.
    # This is necessary to avoid running out of memory.
    # Rebind to None instead of using `del`: deleting a global removes the
    # name itself, so a failed or interrupted load would leave `chunk`
    # undefined and every later call would die with NameError.
    chunk = None
    in_sample_test = None
    start = time.time()
    chunk = load_chunk(
        os.path.join(DATA_ROOT, "features_%03i.z" % next_chunk_index),
        os.path.join(DATA_ROOT, "moves_%03i.z" % next_chunk_index),
    )
    next_chunk_index = (next_chunk_index + 1) % TOTAL_CHUNK_COUNT
    # Plain slices of the chunk arrays: views, not copies.
    in_sample_test = {
        "features": chunk["features"][:IN_SAMPLE_SIZE],
        "moves":    chunk["moves"][:IN_SAMPLE_SIZE],
    }
    stop = time.time()
    print("    >>> (In %f) Samples: %i" % (stop - start, len(chunk["features"])))

def get_random_subset(samples, n):
    """Draw `n` distinct random samples, keeping features and moves paired.

    Args:
        samples: dict with equally indexed sequences under "features" and
            "moves".
        n: number of samples to draw (without replacement).

    Returns:
        A dict with the same two keys, each holding a list of the `n`
        selected entries in matching order.

    Raises:
        ValueError: if `n` exceeds the number of available samples (raised by
            random.sample).
    """
    # Pick a lazy integer range on both major Python versions: `xrange` does
    # not exist on Python 3, and Python 2's `range` would build a list with
    # one entry per sample on every call.
    try:
        lazy_range = xrange
    except NameError:
        lazy_range = range
    indices = random.sample(lazy_range(len(samples["features"])), n)
    return {
        "features": [samples["features"][i] for i in indices],
        "moves": [samples["moves"][i] for i in indices],
    }

In [2]:
# Load the first training chunk (also populates in_sample_test).
load_next_chunk()

# Load the test_features.z / test_moves.z set and keep a random subset of
# CROSS_VAL_SIZE samples from it for the cross-validation loss.
cross_val_full = load_chunk(
    os.path.join(DATA_ROOT, "test_features.z"),
    os.path.join(DATA_ROOT, "test_moves.z"),
)
cross_val_views = get_random_subset(cross_val_full, CROSS_VAL_SIZE)
# Every selected sample is a NumPy view whose base is the full test array, so
# keeping the views alive would keep the entire test set in memory. Copy the
# selected samples, then drop the full arrays.
cross_val = {
    key: [sample.copy() for sample in selected]
    for key, selected in cross_val_views.items()
}
del cross_val_full, cross_val_views

    >>> Loading chunk: 0
    >>> (In 12.563726) Samples: 6764463


In [3]:
# Build the network and report its size.
net = model.ChessNet()
print("Total network parameters: %s" % (net.total_parameters,))

# InteractiveSession installs itself as the default TensorFlow session.
# NOTE(review): presumably the ChessNet methods rely on that default session --
# confirm in model.py.
sess = tf.InteractiveSession()
# tf.initialize_all_variables() is deprecated; TensorFlow's own warning asks
# for tf.global_variables_initializer() instead.
sess.run(tf.global_variables_initializer())

# Training bookkeeping shared with the training cell.
total_training_steps = 0   # minibatch updates performed so far
loss_plot = []             # (step, cross-validation loss) pairs
in_sample_loss_plot = []   # (step, in-sample loss) pairs

Total network parameters: 2797312
Instructions for updating:
Use `tf.global_variables_initializer` instead.


In [4]:
# Training-loop settings.
OVERALL_STEPS = 10000         # evaluate-then-train rounds to run
TRAIN_STEPS_PER_ROUND = 500   # minibatch updates between two evaluations
ROUNDS_PER_CHUNK = 5          # rounds spent on a chunk before swapping it out
TRAINING_LOG_PATH = "/home/snp/chess_training_log"

total_work = 0.0              # seconds spent inside net.train() only
start_time = time.time()
best_loss = float("inf")      # best cross-validation loss seen in this run


def lr_schedule(step):
    """Learning rate after `step` updates: starts at 0.01 and halves every 80,000 steps."""
    return 0.01 * 0.5**(step / 8e4)


for overall_step in range(OVERALL_STEPS):
    # --- Evaluate and log ---------------------------------------------------
    lr = lr_schedule(total_training_steps)
    elapsed = time.time() - start_time
    # in_sample_test is rebuilt by load_next_chunk(), so right after a chunk
    # swap this is measured on samples from the freshly loaded chunk.
    in_sample_loss = net.get_loss(in_sample_test)
    loss = net.get_loss(cross_val)
    accuracy = net.get_accuracy(cross_val)

    # Highlight the line in red (ANSI escape codes) when the cross-validation
    # loss is a new best.
    color_pair = "", ""
    if loss < best_loss:
        color_pair = "\x1b[31m", "\x1b[0m"
    message = "%s%6i [%s - %s] Loss: %.6f  In-sample loss: %.6f  Accuracy: %.3f  lr = %f%s" % (
        color_pair[0],
        total_training_steps,
        to_hms(elapsed),
        to_hms(total_work),
        loss,
        in_sample_loss,
        accuracy * 100,
        lr,
        color_pair[1]
    )
    print(message)
    # f.write() works on Python 2 and 3 alike, unlike `print >>f`.
    with open(TRAINING_LOG_PATH, "a") as f:
        f.write(message + "\n")
    loss_plot.append((total_training_steps, loss))
    in_sample_loss_plot.append((total_training_steps, in_sample_loss))
    best_loss = min(best_loss, loss)

    # --- Train --------------------------------------------------------------
    for _ in range(TRAIN_STEPS_PER_ROUND):
        minibatch = get_random_subset(chunk, MINIBATCH_SIZE)
        # Only the time spent in net.train() counts as work; sampling does not.
        working = time.time()
        net.train(minibatch, lr)
        total_work += time.time() - working
        # Try really hard to not keep any views around!
        del minibatch
        total_training_steps += 1

    # Periodically swap out the data for fresh training data.
    if (overall_step + 1) % ROUNDS_PER_CHUNK == 0:
        load_next_chunk()
    # TODO: periodic checkpointing (a `save_model()` call every 20 rounds) was
    # commented out here; `save_model` is not defined in the visible cells.

[31m     0 [ 0:00:00 -  0:00:00] Loss: 4.245204  In-sample loss: 4.241312  Accuracy: 0.000  lr = 0.010000[0m
[31m   500 [ 0:00:25 -  0:00:23] Loss: 3.501557  In-sample loss: 3.587129  Accuracy: 6.467  lr = 0.009957[0m
[31m  1000 [ 0:00:49 -  0:00:46] Loss: 3.212693  In-sample loss: 3.320307  Accuracy: 6.800  lr = 0.009914[0m
[31m  1500 [ 0:01:12 -  0:01:08] Loss: 3.013221  In-sample loss: 3.118905  Accuracy: 7.233  lr = 0.009871[0m
[31m  2000 [ 0:01:36 -  0:01:31] Loss: 2.872252  In-sample loss: 2.964665  Accuracy: 9.667  lr = 0.009828[0m
    >>> Loading chunk: 1
    >>> (In 12.591915) Samples: 6736364
[31m  2500 [ 0:02:12 -  0:01:54] Loss: 2.797369  In-sample loss: 2.764048  Accuracy: 10.333  lr = 0.009786[0m
[31m  3000 [ 0:02:35 -  0:02:17] Loss: 2.781773  In-sample loss: 2.717456  Accuracy: 10.300  lr = 0.009743[0m
[31m  3500 [ 0:02:58 -  0:02:39] Loss: 2.727601  In-sample loss: 2.628811  Accuracy: 10.300  lr = 0.009701[0m
[31m  4000 [ 0:03:21 -  0:03:02] Loss: 2.67

[31m 34500 [ 0:29:24 -  0:26:03] Loss: 2.215088  In-sample loss: 2.162228  Accuracy: 18.167  lr = 0.007416[0m
    >>> Loading chunk: 2
    >>> (In 12.348594) Samples: 6739494
 35000 [ 0:30:00 -  0:26:26] Loss: 2.249152  In-sample loss: 2.159724  Accuracy: 18.167  lr = 0.007384
[31m 35500 [ 0:30:23 -  0:26:48] Loss: 2.206133  In-sample loss: 2.131828  Accuracy: 19.800  lr = 0.007352[0m
 36000 [ 0:30:46 -  0:27:11] Loss: 2.224047  In-sample loss: 2.110447  Accuracy: 18.767  lr = 0.007320
 36500 [ 0:31:09 -  0:27:34] Loss: 2.222814  In-sample loss: 2.105618  Accuracy: 18.767  lr = 0.007289
 37000 [ 0:31:32 -  0:27:56] Loss: 2.215075  In-sample loss: 2.122285  Accuracy: 19.267  lr = 0.007257
    >>> Loading chunk: 3
    >>> (In 12.439680) Samples: 6737983
 37500 [ 0:32:08 -  0:28:19] Loss: 2.288257  In-sample loss: 2.324234  Accuracy: 18.633  lr = 0.007226
[31m 38000 [ 0:32:31 -  0:28:41] Loss: 2.201880  In-sample loss: 2.235601  Accuracy: 19.900  lr = 0.007195[0m
 38500 [ 0:32:54 - 

KeyboardInterrupt: 