Persist trainExamples during sessions (#35)
* Clear resources during long training

* Coach.py: recreate NNet after each iteration to clear resources
(especially in Keras+TF environment)
* Coach.py: free mcts variables after they are not used
* NeuralNet.py: abstract method for recreating models
* Keras implementation of Othello: reference method implementations for
recreating models

* Improvements

* Pythonic indentation in TicTacToePlayers
* Nice printing of the board

* sync from upstream

* Persist trainExamples during sessions

* Coach.py holds a history of trainExamples from the last N iterations, as
stated in the AlphaGo paper (a short sketch of this mechanism appears just before the file diffs below)
* Coach.py is able to save/load examples to/from file

* remove recreate()

* remove destroy()

* remove recreate()

* empty line

* empty line

* Unused arguments
evg-tyurin authored and suragnair committed Feb 12, 2018
1 parent 2a37807 commit 1fb1255
Showing 7 changed files with 110 additions and 37 deletions.
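
The core of the change — keep self-play examples from the last N iterations, persist the whole window to disk, and reload it in a later session — is easiest to see in isolation. Below is a minimal standalone sketch of that idea (toy data, made-up file name and constants; the real logic is in the Coach.py diff further down):

    from collections import deque
    from pickle import Pickler, Unpickler
    from random import shuffle

    WINDOW = 3          # plays the role of args.numItersForTrainExamplesHistory
    MAX_PER_ITER = 10   # plays the role of args.maxlenOfQueue

    history = []                                    # one deque of examples per iteration
    for it in range(5):
        # stand-in for the (board, pi, value) examples produced by self-play
        examples = deque(((it, n) for n in range(4)), maxlen=MAX_PER_ITER)
        history.append(examples)
        if len(history) > WINDOW:                   # drop the oldest iteration beyond the window
            history.pop(0)

    with open("toy.examples", "wb") as f:           # persist the window, as saveTrainExamples does
        Pickler(f).dump(history)

    with open("toy.examples", "rb") as f:           # restore it in a later session, as loadTrainExamples does
        history = Unpickler(f).load()

    trainExamples = []                              # flatten and shuffle before training
    for e in history:
        trainExamples.extend(e)
    shuffle(trainExamples)
    print(len(history), len(trainExamples))         # 3 iterations kept, 12 toy examples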
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
.DS_Store

/temp/
/.project
/.pydevproject
100 changes: 75 additions & 25 deletions Coach.py
@@ -3,7 +3,10 @@
from MCTS import MCTS
import numpy as np
from pytorch_classification.utils import Bar, AverageMeter
import time
import time, os, sys
from pickle import Pickler, Unpickler
from random import shuffle


class Coach():
"""
@@ -12,11 +15,12 @@ class Coach():
"""
def __init__(self, game, nnet, args):
self.game = game
self.board = game.getInitBoard()
self.nnet = nnet
self.pnet = self.nnet.__class__(self.game) # the competitor network
self.args = args
self.mcts = MCTS(self.game, self.nnet, self.args)
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples()

def executeEpisode(self):
"""
@@ -35,13 +39,13 @@ def executeEpisode(self):
the player eventually won the game, else -1.
"""
trainExamples = []
self.board = self.game.getInitBoard()
board = self.game.getInitBoard()
self.curPlayer = 1
episodeStep = 0

while True:
episodeStep += 1
canonicalBoard = self.game.getCanonicalForm(self.board,self.curPlayer)
canonicalBoard = self.game.getCanonicalForm(board,self.curPlayer)
temp = int(episodeStep < self.args.tempThreshold)

pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
@@ -50,9 +54,9 @@ def executeEpisode(self):
trainExamples.append([b, self.curPlayer, p, None])

action = np.random.choice(len(pi), p=pi)
self.board, self.curPlayer = self.game.getNextState(self.board, self.curPlayer, action)
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)

r = self.game.getGameEnded(self.board, self.curPlayer)
r = self.game.getGameEnded(board, self.curPlayer)

if r!=0:
return [(x[0],x[2],r*((-1)**(x[1]!=self.curPlayer))) for x in trainExamples]
@@ -66,25 +70,44 @@ def learn(self):
only if it wins >= updateThreshold fraction of games.
"""

trainExamples = deque([], maxlen=self.args.maxlenOfQueue)
for i in range(self.args.numIters):
for i in range(1, self.args.numIters+1):
# bookkeeping
print('------ITER ' + str(i+1) + '------')
eps_time = AverageMeter()
bar = Bar('Self Play', max=self.args.numEps)
end = time.time()

for eps in range(self.args.numEps):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
trainExamples += self.executeEpisode()

# bookkeeping + plot progress
eps_time.update(time.time() - end)
print('------ITER ' + str(i) + '------')
# examples of the iteration
if not self.skipFirstSelfPlay or i>1:
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

eps_time = AverageMeter()
bar = Bar('Self Play', max=self.args.numEps)
end = time.time()
bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(eps=eps+1, maxeps=self.args.numEps, et=eps_time.avg,
total=bar.elapsed_td, eta=bar.eta_td)
bar.next()
bar.finish()

for eps in range(self.args.numEps):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
iterationTrainExamples += self.executeEpisode()

# bookkeeping + plot progress
eps_time.update(time.time() - end)
end = time.time()
bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(eps=eps+1, maxeps=self.args.numEps, et=eps_time.avg,
total=bar.elapsed_td, eta=bar.eta_td)
bar.next()
bar.finish()

# save the iteration examples to the history
self.trainExamplesHistory.append(iterationTrainExamples)

if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
self.trainExamplesHistory.pop(0)
# backup history to a file
# NB! the examples were collected using the model from the previous iteration, so (i-1)
self.saveTrainExamples(i-1)

# shuffle examples before training
trainExamples = []
for e in self.trainExamplesHistory:
trainExamples.extend(e)
shuffle(trainExamples)

# training new network, keeping a copy of the old one
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
@@ -103,8 +126,35 @@ def learn(self):
if pwins+nwins > 0 and float(nwins)/(pwins+nwins) < self.args.updateThreshold:
print('REJECTING NEW MODEL')
self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')

else:
print('ACCEPTING NEW MODEL')
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='checkpoint_' + str(i) + '.pth.tar')
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i))
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')

def getCheckpointFile(self, iteration):
return 'checkpoint_' + str(iteration) + '.pth.tar'

def saveTrainExamples(self, iteration):
folder = self.args.checkpoint
if not os.path.exists(folder):
os.makedirs(folder)
filename = os.path.join(folder, self.getCheckpointFile(iteration)+".examples")
with open(filename, "wb+") as f:
Pickler(f).dump(self.trainExamplesHistory)
f.closed

def loadTrainExamples(self):
modelFile = os.path.join(self.args.load_folder_file[0], self.args.load_folder_file[1])
examplesFile = modelFile+".examples"
if not os.path.isfile(examplesFile):
print(examplesFile)
r = input("File with trainExamples not found. Continue? [y|n]")
if r != "y":
sys.exit()
else:
print("File with trainExamples found. Read it.")
with open(examplesFile, "rb") as f:
self.trainExamplesHistory = Unpickler(f).load()
f.closed
# examples based on the model were already collected (loaded)
self.skipFirstSelfPlay = True
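
For reference, a .examples file written by saveTrainExamples above is simply a pickled list holding one deque of (board, pi, value) examples per retained iteration, so it can be inspected independently of training. A hypothetical snippet (the path is only an example, matching the './temp/' checkpoint folder and the checkpoint_<i>.pth.tar naming used in this commit):

    from pickle import Unpickler

    # Hypothetical path: checkpoint folder './temp/', iteration 4
    path = "./temp/checkpoint_4.pth.tar.examples"
    with open(path, "rb") as f:
        history = Unpickler(f).load()       # list of per-iteration deques

    print("iterations retained:", len(history))
    print("examples in most recent iteration:", len(history[-1]))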
3 changes: 3 additions & 0 deletions README.md
@@ -33,3 +33,6 @@ While the current code is fairly functional, we could benefit from the following
* [MBoss](https://github.com/1424667164) contributed rules and a model for GoBang.

Thanks to [pytorch-classification](https://github.com/bearpaw/pytorch-classification) and [progress](https://github.com/verigak/progress).

### AlphaGo / AlphaZero Events
* February 8, 2018 - [Solving Alpha Go Zero + TensorFlow, Kubernetes-based Serverless AI Models on GPU](https://www.meetup.com/Advanced-Spark-and-TensorFlow-Meetup/events/245308722/)
5 changes: 5 additions & 0 deletions main.py
@@ -16,6 +16,8 @@
'checkpoint': './temp/',
'load_model': False,
'load_folder_file': ('/dev/models/8x100x50','best.pth.tar'),
'numItersForTrainExamplesHistory': 20,

})

if __name__=="__main__":
@@ -26,4 +28,7 @@
nnet.load_checkpoint(args.load_folder_file[0], args.load_folder_file[1])

c = Coach(g, nnet, args)
if args.load_model:
print("Load trainExamples from file")
c.loadTrainExamples()
c.learn()
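
Given the loadTrainExamples call added above, resuming a run amounts to pointing load_folder_file at a checkpoint whose .examples companion exists. A hypothetical edit to the args block in main.py (values are illustrative, not defaults, and dotdict is assumed to be the helper main.py already imports from utils):

    from utils import dotdict

    args = dotdict({
        # ... existing settings shown above ...
        'load_model': True,                                        # load weights AND call c.loadTrainExamples()
        'load_folder_file': ('./temp/', 'checkpoint_4.pth.tar'),   # hypothetical checkpoint; './temp/checkpoint_4.pth.tar.examples' must exist
        'numItersForTrainExamplesHistory': 20,
    })

With these settings, loadTrainExamples also sets skipFirstSelfPlay = True, so the first iteration trains on the loaded examples instead of repeating self-play.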
3 changes: 3 additions & 0 deletions tictactoe/README.md
@@ -19,3 +19,6 @@ I trained a Keras model for 3x3 TicTacToe (3 iterations, 25 episodes, 10 epochs
* [Evgeny Tyurin](https://github.com/evg-tyurin)

The implementation is based on the game of Othello (https://github.com/suragnair/alpha-zero-general/tree/master/othello).

### AlphaGo / AlphaZero Events
* February 8, 2018 - [Solving Alpha Go Zero + TensorFlow, Kubernetes-based Serverless AI Models on GPU](https://www.meetup.com/Advanced-Spark-and-TensorFlow-Meetup/events/245308722/)
13 changes: 10 additions & 3 deletions tictactoe/TicTacToeGame.py
@@ -97,10 +97,14 @@ def stringRepresentation(self, board):
def display(board):
n = board.shape[0]

print(" ", end="")
for y in range(n):
print (y,"|",end="")
print (y,"", end="")
print("")
print(" -----------------------")
print(" ", end="")
for _ in range(n):
print ("-", end="-")
print("--")
for y in range(n):
print(y, "|",end="") # print the row #
for x in range(n):
@@ -114,4 +118,7 @@ def display(board):
print("- ",end="")
print("|")

print(" -----------------------")
print(" ", end="")
for _ in range(n):
print ("-", end="-")
print("--")
21 changes: 12 additions & 9 deletions tictactoe/TicTacToePlayers.py
@@ -31,14 +31,17 @@ def play(self, board):
for i in range(len(valid)):
if valid[i]:
print(int(i/self.game.n), int(i%self.game.n))
while True:
a = input()

x,y = [int(x) for x in a.split(' ')]
a = self.game.n * x + y if x!= -1 else self.game.n ** 2
if valid[a]:
break
else:
print('Invalid')
while True:
# Python 3.x
a = input()
# Python 2.x
# a = raw_input()

x,y = [int(x) for x in a.split(' ')]
a = self.game.n * x + y if x!= -1 else self.game.n ** 2
if valid[a]:
break
else:
print('Invalid')

return a
