In [0]:
import TensorFlow

In [0]:
/// A `Policy` whose moves are typed in by a person on standard input.
public class HumanPolicy: Policy {
  /// The name of the game participant.
  public let participantName: String

  /// Creates a human-driven policy for the participant called `participantName`.
  public init(participantName: String) {
    self.participantName = participantName
  }

  /// Asks the user for the next move.
  ///
  /// - Parameters:
  ///   - boardState: The current board; its `legalMoves` are used to validate the input.
  ///   - previousMove: Not used by this policy.
  /// - Returns: `.pass` when there is no legal move or the prompt yields no position,
  ///   otherwise `.place` at the position the user entered.
  public func nextMove(for boardState: BoardState, after previousMove: Move?) -> Move {
    let candidates = boardState.legalMoves
    if candidates.isEmpty {
      return .pass
    }

    // The prompt keeps asking until the closure below accepts the position.
    let chosen = promptAndReadMove { position in
      if !candidates.contains(position) {
        throw HumanInputError.invalidInput(message: "The move is not legal.")
      }
    }

    if let chosen = chosen {
      return .place(position: chosen)
    }
    return .pass
  }
}

/// Errors raised while reading a move typed by the user.
enum HumanInputError: Error {
  /// No line could be read from standard input.
  case emptyInput
  /// A line was read but rejected; `message` explains why.
  case invalidInput(message: String)
}

/// Gets the next move from user via stdio, asking again until the input is usable.
///
/// The user enters the `x` and `y` coordinates on separate lines. Entering `-1` for both means
/// `pass`.
///
/// - Parameter validator: Called with the parsed position; throwing rejects the position and
///   the user is prompted again.
/// - Returns: The validated position, or `nil` when the user chose to pass or standard input
///   has reached end-of-file.
fileprivate func promptAndReadMove(validatingWith validator: (Position) throws -> ()) -> Position? {
  while true {
    do {
      print("Your input (x: -1, y: -1) means `pass`:")
      print("x: ", terminator: "")
      let x = try readCoordinate()
      print("y: ", terminator: "")
      let y = try readCoordinate()

      if x == -1 && y == -1 {
        return nil  // User chooses `pass`.
      }

      let position = Position(x: x, y: y)
      try validator(position)
      return position
    } catch let HumanInputError.invalidInput(message) {
      print("The input is invalid: \(message)")
      print("Please try again!")
    } catch HumanInputError.emptyInput {
      // `readCoordinate` only throws `emptyInput` when `readLine()` returns nil, which means
      // standard input is at end-of-file. Every retry would fail the same way, so prompting
      // again would spin forever; treat the situation as a pass instead.
      print("Empty input is not allowed.")
      print("No more input is available; treating this as `pass`.")
      return nil
    } catch {
      print("Unknown error: \(error)")
      print("Please try again!")
    }
  }
}

/// Reads one line from standard input and parses it as an integer coordinate.
///
/// - Throws: `HumanInputError.emptyInput` when no line can be read, and
///   `HumanInputError.invalidInput` when the line is not an `Int`.
fileprivate func readCoordinate() throws -> Int {
  if let text = readLine() {
    if let value = Int(text) {
      return value
    }
    throw HumanInputError.invalidInput(message: "Coordinate must be Int.")
  }
  throw HumanInputError.emptyInput
}
/// A `Policy` that picks its moves at random among the legal ones.
public class RandomPolicy: Policy {
  /// The name of the game participant.
  public let participantName: String

  /// Creates a random policy for the participant called `participantName`.
  public init(participantName: String) {
    self.participantName = participantName
  }

  /// Returns a random legal move, or `.pass` when no legal move exists.
  ///
  /// When the opponent has just passed, the move is chosen by
  /// `chooseMoveWithoutLoweringScore(for:)` instead of a plain random pick.
  public func nextMove(for boardState: BoardState, after previousMove: Move?) -> Move {
    let candidates = boardState.legalMoves
    if candidates.isEmpty {
      return .pass
    }

    // A nil `previousMove` means a new game and does not count as an opponent pass.
    //
    // After an opponent pass, the random algorithm should be a little smarter to avoid
    // committing a move that lowers the current score.
    if case .some(.pass) = previousMove {
      return chooseMoveWithoutLoweringScore(for: boardState)
    }

    if let pick = candidates.randomElement() {
      return .place(position: pick)
    }
    fatalError("randomElement should not return nil for non-empty legal moves: \(candidates).")
  }

  /// Samples legal moves without replacement and returns the first one that strictly raises
  /// the score of the player to move; passes when no sampled move does.
  private func chooseMoveWithoutLoweringScore(for boardState: BoardState) -> Move {
    var remaining = boardState.legalMoves
    precondition(!remaining.isEmpty)

    let mover = boardState.nextPlayerColor
    let baseline = boardState.score(for: mover)

    // Draw a random index each round rather than walking the list in order, so the
    // result stays random.
    while true {
      let pick = Int.random(in: 0..<remaining.count)
      let candidate = remaining.remove(at: pick)

      let outcome = try! boardState.placingNewStone(at: candidate)
      if outcome.score(for: mover) > baseline {
        return .place(position: candidate)
      }

      // If no better choice, then pass.
      if remaining.isEmpty {
        return .pass
      }
    }
  }
}

/// Settings that control the MCTS algorithm.
public struct MCTSConfiguration {
  /// The configuration of the Go game.
  let gameConfiguration: GameConfiguration

  /// How many simulations are run for each move.
  let simulationCountForOneMove: Int

  /// The game depth at which simulation stops expanding the tree.
  ///
  /// The value is compared with `BoardState.playedMoveCount`. Once it is reached, the board is
  /// scored immediately, which prevents endless game plays during simulation.
  let maxGameDepth: Int

  /// How the final move is picked once the simulations are done.
  public enum ExplorationOption {
    /// Disable exploration in MCTS algorithm when selecting a move.
    ///
    /// This should be used for real game play to generate the strongest move.
    case noExploration

    /// Enable exploration in the early stage of the game.
    ///
    /// To be specific, while `BoardState.playedMoveCount` is no larger than
    /// `maximumMoveCountToExplore`, exploration is used to select the move. This improves the
    /// diversity of the early stage.
    ///
    /// This is recommended for self plays to generate training data.
    case exploreMovesInEarlyStage(maximumMoveCountToExplore: Int)
  }

  /// The configuration for exploration.
  let explorationOption: ExplorationOption

  /// Creates a configuration.
  ///
  /// The default value, 1600, for `simulationCountForOneMove` was the number used by the
  /// AlphaGoZero paper.
  ///
  /// If `maxGameDepth` is `nil`, it is set to (gameConfiguration.size)^2 * 1.4 according to the
  /// MiniGo reference model, i.e., 505 moves for 19x19, 113 for 9x9. The AlphaGo paper chooses
  /// 2.0 instead of 1.4.
  ///
  /// - Precondition: `simulationCountForOneMove` and the resolved maximum game depth are both
  ///   positive.
  public init(
    gameConfiguration: GameConfiguration,
    simulationCountForOneMove: Int = 1600,
    maxGameDepth: Int? = nil,
    explorationOption: ExplorationOption = .noExploration
  ) {
    precondition(simulationCountForOneMove > 0)

    let resolvedDepth: Int
    if let explicitDepth = maxGameDepth {
      resolvedDepth = explicitDepth
    } else {
      let pointCount = gameConfiguration.size * gameConfiguration.size
      resolvedDepth = Int(Float(pointCount) * 1.4)
    }
    precondition(resolvedDepth > 0)

    self.gameConfiguration = gameConfiguration
    self.simulationCountForOneMove = simulationCountForOneMove
    self.maxGameDepth = resolvedDepth
    self.explorationOption = explorationOption
  }
}

/// An `MCTSPredictor` that returns a flat move distribution and a small random reward.
///
/// This is mainly for testing and debugging purposes.
public class MCTSRandomPredictor: MCTSPredictor {
  private let boardSize: Int

  /// Creates a predictor for a `boardSize` x `boardSize` board.
  public init(boardSize: Int) {
    self.boardSize = boardSize
  }

  /// Returns the same weight (1.0) for every position and for `pass`, together with a random
  /// reward. `boardState` itself is not inspected.
  public func prediction(for boardState: BoardState) -> MCTSPrediction {
    // Randomize the reward (range: 0.0 +/- 0.05).
    let draw = Int.random(in: 0..<100)
    let reward = (Float(draw) - 50.0) / 1000.0

    let flatWeights = ShapedArray<Float>(shape: [boardSize, boardSize], repeating: 1.0)
    return MCTSPrediction(
      rewardForNextPlayer: reward,
      distribution: MCTSPrediction.Distribution(positions: flatWeights, pass: 1.0))
  }
}

/// A model that can be queried by `MCTSModelBasedPredictor`.
public protocol InferenceModel {
  /// Predicts the model output based on input tensor.
  ///
  /// - Parameter input: The input tensor. `MCTSModelBasedPredictor` passes the result of
  ///   `BoardState.featurePlanes()` here.
  /// - Returns: The model output for `input`.
  func prediction(input: Tensor<Float>) -> GoModelOutput
}

/// An `MCTSPredictor` that asks an `InferenceModel` (a ResNet-like model) for the next-move
/// distribution and the reward.
public class MCTSModelBasedPredictor: MCTSPredictor {
  // Maintainer Note: As of Feb 2019, `Tensor` shape related APIs require `Int32`, but
  // `ShapedArray` requires `Int` for shape. So, we have to pick one here. Consider the fact we will
  // use `Tensor` anyway in future but might change another implementation to replace `ShapedArray`,
  // `Int32` is chosen here.
  private let boardSize: Int32
  private var model: InferenceModel

  /// Creates a predictor backed by `model`.
  ///
  /// Traps unless `boardSize` is 19.
  public init(boardSize: Int, model: InferenceModel) {
    if boardSize != 19 {
      fatalError("GoModel only supports boardSize=19 for now.")
    }
    self.boardSize = Int32(boardSize)
    self.model = model
  }

  /// Runs the model on the feature planes of `boardState` and converts the output into an
  /// `MCTSPrediction`, clipping the reward into [-1, 1].
  public func prediction(for boardState: BoardState) -> MCTSPrediction {
    let inference = model.prediction(input: boardState.featurePlanes())

    let flatPolicy = inference.policy.flattened()
    let pointCount = boardSize * boardSize
    assert(flatPolicy.shape == [pointCount + 1])

    // The leading boardSize * boardSize values are the per-position weights; the trailing
    // value is the weight of `pass`.
    let positionWeights = flatPolicy[0..<pointCount].reshaped(to: [boardSize, boardSize]).array
    let passWeight = flatPolicy[flatPolicy.scalarCount - 1].scalarized()
    let distribution = MCTSPrediction.Distribution(positions: positionWeights, pass: passWeight)

    assert(inference.value.shape == [1])
    let rawReward = inference.value.scalarized()

    // We occasionally see the model output fall out of the expected range, which should never
    // happen given the final activation function is `tanh`.
    //
    // To avoid a crash, the case is logged and the value is clipped.
    var reward = rawReward
    if rawReward > 1.0 || rawReward < -1.0 {
      print("Reward is out of range: value \(rawReward). \n \(boardState)")
      reward = rawReward > 1.0 ? 1.0 : -1.0
    }
    return MCTSPrediction(rewardForNextPlayer: reward, distribution: distribution)
  }
}

extension BoardState {

  /// Builds the network input for this board state.
  ///
  /// The returned `Tensor` has shape `[1, boardSize, boardSize, 17]`: two planes for the
  /// current board, two planes for each board in `history`, padding up to 16 planes, and one
  /// final plane for the color to play.
  ///
  /// For reference, see the AlphaGo Zero paper, Section: Method -> Neural network architecture.
  fileprivate func featurePlanes() -> Tensor<Float> {
    assert(gameConfiguration.maxHistoryCount <= 7, "Only support at most 8 board states in total.")

    let size = gameConfiguration.size
    let colorToPlay = self.nextPlayerColor

    var planes = ShapedArray<Float>(shape: [17, size, size], repeating: 0.0)

    // Planes 0..<16 describe recent boards, two planes per board. The current board comes
    // first, followed by the boards in `history`. `latestPair` always holds the pair written
    // most recently.
    var latestPair = self.board.binaryFeaturePlanes(activePlayerColor: colorToPlay)
    planes[0...1] = latestPair

    var cursor = 2
    for pastBoard in self.history {
      latestPair = pastBoard.binaryFeaturePlanes(activePlayerColor: colorToPlay)
      planes[cursor...cursor+1] = latestPair
      cursor += 2
    }
    assert(cursor == (self.history.count + 1) * 2)

    // Pad the unused slots by repeating the last pair written above.
    //
    // AlphaGo sets the remaining feature planes as all zeros (see "Method" -> "Neural network
    // architecture" section). But MiniGo reference model repeats the oldest board. We followed
    // the latter here.
    for start in stride(from: cursor, to: 16, by: 2) {
      planes[start...(start+1)] = latestPair
    }

    // Plane 16 encodes the color to play: all 1.0 when black is to play, all 0.0 for white.
    let colorValue: Float = colorToPlay == .black ? 1.0 : 0.0
    planes[16] = ShapedArraySlice<Float>(shape: [size, size], repeating: colorValue)

    // The Go prediction network expects `[batch, boardSize, boardSize, featurePlanes]` order,
    // so move the plane axis last and reshape to a single-element batch.
    let planesLast = Tensor(planes).transposed(withPermutations: 1, 2, 0)
    return planesLast.reshaped(to: [1, Int32(size), Int32(size), 17])
  }
}

extension Board {
  /// Encodes the stones on this board as two binary planes of shape `[2, size, size]`.
  ///
  /// Plane 0 holds 1.0 at every position whose stone has `activePlayerColor`; plane 1 holds
  /// 1.0 at every position holding a stone of the other color. Empty positions stay 0.0 in
  /// both planes.
  ///
  /// `activePlayerColor` is supplied by the caller and is not derived from the board itself.
  /// For example, for a board state A with white to play and state history [A, B, C, D], the
  /// calls are:
  ///
  ///     binaryFeaturePlanes(state: A, activePlayerColor: .white)
  ///     binaryFeaturePlanes(state: B, activePlayerColor: .white)
  ///     binaryFeaturePlanes(state: C, activePlayerColor: .white)
  ///     binaryFeaturePlanes(state: D, activePlayerColor: .white)
  fileprivate func binaryFeaturePlanes(activePlayerColor: Color) -> ShapedArraySlice<Float> {
    let side = self.size
    let opponentColor: Color = activePlayerColor == .black ? .white : .black

    var planes = ShapedArraySlice<Float>(shape: [2, side, side], repeating: 0.0)
    for x in 0..<side {
      for y in 0..<side {
        if let stone = color(at: Position(x: x, y: y)) {
          let ownsStone = stone == activePlayerColor
          if !ownsStone {
            assert(stone == opponentColor)
          }
          planes[ownsStone ? 0 : 1][x][y] = ShapedArraySlice(1.0)
        }
      }
    }
    return planes
  }
}
/// Tree node for the MCTS algorithm.
class MCTSNode {
  private let boardSize: Int

  /// Total visited count for this node during simulations.
  private var totalVisitedCount: Int = 0

  /// The corresponding board state for this node.
  let boardState: BoardState

  /// All children (nodes) for this node in the `MCTSTree`.
  var children: [Move: MCTSNode] = [:]

  /// Statistics kept for one legal move out of this node.
  private struct Action {
    /// The move this entry describes.
    let move: Move
    /// The normalized prior weight of the move.
    var prior: Float
    /// The running total of the values backed up through this move.
    var qValueTotal: Float
    /// How many times this move has been backed up.
    var visitedCount: Int
  }

  /// The `actionSpace` consists of all legal actions for current `BoardState`. The first action is
  /// `.pass`, followed by all legal positions.
  ///
  /// Note: `prior` in `actionSpace` must be normalized to form a valid probability.
  private var actionSpace: [Action]

  /// Creates a MCTS node.
  ///
  /// The priors of the legal actions are read from `distribution` and normalized so they sum
  /// to one. If the weights of the legal actions do not add up to a positive, finite number
  /// (for example when they are all zero), normalizing would divide by zero and fill the action
  /// space with NaN; in that case every legal action gets the same prior instead.
  ///
  /// - Precondition: The `distribution` is not expected to be normalized. And it is allowed to
  /// have positive values for illegal positions.
  ///
  /// - Parameters:
  ///   - boardSize: The size of the board.
  ///   - boardState: The board state this node stands for.
  ///   - distribution: The predicted weights for `pass` and for each board position.
  init(
    boardSize: Int,
    boardState: BoardState,
    distribution: MCTSPrediction.Distribution
  ) {
    self.boardSize = boardSize
    self.boardState = boardState

    // Read the legal moves once; `.pass` must be the first action.
    let legalMoves = boardState.legalMoves
    var moves: [Move] = [.pass]
    moves.reserveCapacity(legalMoves.count + 1)
    moves.append(contentsOf: legalMoves.map { Move.place(position: $0) })

    let rawPriors: [Float] = moves.map { (move: Move) -> Float in
      switch move {
      case .pass:
        return distribution.pass
      case .place(let position):
        return distribution.positions[position.x][position.y].scalars[0]
      }
    }
    let total = rawPriors.reduce(0, +)

    // Guard the normalization against a zero or non-finite denominator.
    let canNormalize = total.isFinite && total > 0
    let uniformPrior = 1.0 / Float(moves.count)

    self.actionSpace = zip(moves, rawPriors).map {
      Action(
        move: $0,
        prior: canNormalize ? $1 / total : uniformPrior,
        qValueTotal: 0.0,
        visitedCount: 0)
    }
  }
}

/// Supports the node backing up.
extension MCTSNode {
  /// Records one visit of `move` and adds the reward to its running value total.
  ///
  /// The reward is given from black's point of view; its sign is flipped when white is the
  /// player to move at this node. Traps if `move` is not in the action space.
  func backUp(for move: Move, withRewardForBlackPlayer rewardForBlackPlayer: Float) {
    guard let slot = actionSpace.indices.first(where: { actionSpace[$0].move == move }) else {
      fatalError("The action \(move) taken must be legal (all legal actions: \(actionSpace)).")
    }

    let perspective: Float = boardState.nextPlayerColor == .black ? 1.0 : -1.0

    totalVisitedCount += 1
    var entry = actionSpace[slot]
    entry.visitedCount += 1
    entry.qValueTotal += rewardForBlackPlayer * perspective
    actionSpace[slot] = entry
  }
}

/// Supports selecting the action.
extension MCTSNode {
  /// Returns the move to play, based on the visit counts gathered so far.
  ///
  /// With exploration enabled the move is sampled in proportion to the visit counts;
  /// otherwise the most visited move is returned.
  ///
  /// - Precondition: The node has been visited at least once.
  func nextMove(withExplorationEnabled: Bool) -> Move {
    precondition(totalVisitedCount > 0, "The node has not been visited after creation.")

    let visitCount: (Action) -> Float = { Float($0.visitedCount) }
    let chosen = withExplorationEnabled
      ? sampleFromPMF(actionSpace, with: visitCount)
      : maxScoringElement(actionSpace, withScoringFunction: visitCount)
    return chosen.move
  }

  /// Selects the action based on PUCT algorithm for simulation.
  ///
  /// PUCT stands for predictor + UCT, where UCT stands for UCB applied to trees. The
  /// action is selected based on the statistic in the search tree and has some levels of
  /// exploration. Initially, this algorithm prefers action with high prior probability and low
  /// visit count but asymptotically prefers action with high action value.
  ///
  /// See the AlphaGoZero paper and its references for details.
  var actionByPUCT: Move {
    // A node that has not been visited since creation has no statistics yet, so the prior
    // probability decides.
    return totalVisitedCount > 0 ? nextMoveWithHighestActionValue : nextMoveWithHighestPrior
  }
}

extension MCTSNode {
  /// The move whose prior is the largest.
  private var nextMoveWithHighestPrior: Move {
    return maxScoringElement(actionSpace, withScoringFunction: { $0.prior }).move
  }

  /// The move with the largest action value.
  ///
  /// See the AlphaGoZero paper ("Methods" -> "Select" section) for the formula of action value.
  private var nextMoveWithHighestActionValue: Move {
    let parentVisits = Float(totalVisitedCount)

    return maxScoringElement(actionSpace) { action in
      let visits = Float(action.visitedCount)

      // Exploration term: the prior scaled by how rarely this action has been tried.
      let explorationTerm = action.prior * (parentVisits / (1.0 + visits)).squareRoot()

      // The mean value is only defined once the action has been visited.
      guard action.visitedCount > 0 else {
        return explorationTerm
      }
      return explorationTerm + action.qValueTotal / visits
    }.move
  }
}

extension MCTSNode {
  /// Returns the element with the highest score, picking randomly among elements whose scores
  /// are within `Float.ulpOfOne` of the running maximum.
  ///
  /// - Precondition: `elements` is not empty.
  private func maxScoringElement<T>(
    _ elements: [T],
    withScoringFunction scoringFunction: (T) -> Float
  ) -> T {
    precondition(!elements.isEmpty)

    var best = scoringFunction(elements[0])
    var tied = [0]

    for (index, element) in elements.enumerated().dropFirst() {
      let score = scoringFunction(element)
      if score > best {
        // A new maximum discards every earlier candidate.
        best = score
        tied = [index]
      } else if abs(score - best) < .ulpOfOne {
        precondition(tied.count > 0)
        tied.append(index)
      }
    }

    assert(tied.count > 0)

    // Breaks the tie randomly.
    return elements[tied[Int.random(in: 0..<tied.count)]]
  }

  /// Samples an element with probability proportional to `pmfFunction(element)`.
  ///
  /// The weights do not need to sum to one. The random draw has a resolution of 1/10000.
  ///
  /// - Precondition: `elements` is not empty.
  private func sampleFromPMF<T>(_ elements: [T], with pmfFunction: (T) -> Float) -> T {
    precondition(!elements.isEmpty)

    // Running totals of the weights, i.e. an unnormalized CDF.
    var runningTotal: Float = 0.0
    let cumulative: [Float] = elements.map { (element: T) -> Float in
      let weight = pmfFunction(element)
      assert(weight >= 0)
      runningTotal += weight
      return runningTotal
    }

    let resolution = 10000
    let draw = Int.random(in: 0..<resolution)
    let threshold = Float(draw) / Float(resolution) * runningTotal

    // The first bucket whose upper edge lies above the threshold wins.
    if let hit = cumulative.firstIndex(where: { threshold < $0 }) {
      return elements[hit]
    }
    return elements[elements.count - 1]
  }
}
/// The search tree that MCTS maintains over the course of one game.
class MCTSTree {

  private let gameConfiguration: GameConfiguration
  private let predictor: MCTSPredictor

  /// The node for the current position of the game.
  var root: MCTSNode

  /// Creates a tree whose root is the empty board for `gameConfiguration`.
  init(gameConfiguration: GameConfiguration, predictor: MCTSPredictor) {
    self.gameConfiguration = gameConfiguration
    self.predictor = predictor

    let initialState = BoardState(gameConfiguration: gameConfiguration)
    self.root = MCTSNode(
      boardSize: gameConfiguration.size,
      boardState: initialState,
      distribution: predictor.prediction(for: initialState).distribution)
  }

  /// Makes the node reached by `previousMove` the new root and returns it.
  ///
  /// A nil `previousMove` means the game just started, so the current root is returned as is.
  /// An existing child for the move is reused; otherwise a new node is created.
  func promoteNewRoot(after previousMove: Move?) -> MCTSNode {
    guard let action = previousMove else {
      assert(root.boardState.playedMoveCount == 0)
      return root
    }

    if let knownChild = root.children[action] {
      root = knownChild
    } else {
      root = makeNode(from: root, applying: action).node
    }
    return root
  }

  /// Tells the caller whether a child already existed or had to be created.
  enum NodeKind {
    case existingNode(node: MCTSNode)
    case newNode(node: MCTSNode, rewardForNextPlayer: Float)
  }

  /// Returns the child node for `action`.
  ///
  /// A child that does not exist yet is created via the predictor, stored in
  /// `node.children`, and returned together with the predicted reward.
  func child(of node: MCTSNode, for action: Move) -> NodeKind {
    if let known = node.children[action] {
      return .existingNode(node: known)
    }

    // TODO(xiejw): Implement noise injection for predictions.
    let expansion = makeNode(from: node, applying: action)
    node.children[action] = expansion.node
    return .newNode(node: expansion.node, rewardForNextPlayer: expansion.rewardForNextPlayer)
  }

  /// Applies `action` to the board state of `parent`, asks the predictor about the resulting
  /// state, and wraps it in a fresh node. The node is not attached to `parent`.
  private func makeNode(
    from parent: MCTSNode,
    applying action: Move
  ) -> (node: MCTSNode, rewardForNextPlayer: Float) {
    let nextState: BoardState
    switch action {
    case .pass:
      nextState = parent.boardState.passing()
    case .place(let position):
      do {
        nextState = try parent.boardState.placingNewStone(at: position)
      } catch {
        fatalError("MCTS algorithm should never emit an illegal action. Got error: \(error).")
      }
    }

    let prediction = predictor.prediction(for: nextState)
    let node = MCTSNode(
      boardSize: gameConfiguration.size,
      boardState: nextState,
      distribution: prediction.distribution)
    return (node, prediction.rewardForNextPlayer)
  }
}
/// A `Policy` that picks its moves with Monte Carlo tree search (MCTS).
public class MCTSPolicy: Policy {
  /// The name of the game participant.
  public let participantName: String

  private let configuration: MCTSConfiguration
  private var tree: MCTSTree

  /// Creates an MCTS policy that evaluates board states with `predictor`.
  public init(participantName: String, predictor: MCTSPredictor, configuration: MCTSConfiguration) {
    self.participantName = participantName
    self.configuration = configuration
    self.tree = MCTSTree(gameConfiguration: configuration.gameConfiguration, predictor: predictor)
  }

  /// Runs the configured number of simulations from the current position and returns the
  /// selected move.
  public func nextMove(for boardState: BoardState, after previousMove: Move?) -> Move {
    // Stage 1: Make the node reached by the previous move the new root. Nodes that are no
    // longer reachable are dropped along the way.
    let root = tree.promoteNewRoot(after: previousMove)
    assert(
      root.boardState == boardState,
      "Expected board \(boardState),\n got: \(root.boardState).")

    // Stage 2: Grow the tree by running simulations.
    for _ in 0..<configuration.simulationCountForOneMove {
      runOneSimulation(previousMove: previousMove)
    }

    // Stage 3: Decide whether the move is sampled (exploration) or taken greedily.
    let exploreMove: Bool
    switch configuration.explorationOption {
    case .exploreMovesInEarlyStage(let maximumMoveCount):
      exploreMove = boardState.playedMoveCount <= maximumMoveCount
    case .noExploration:
      exploreMove = false
    }

    // Stage 4: Select the move and promote again to further trim the tree.
    let move = root.nextMove(withExplorationEnabled: exploreMove)
    _ = tree.promoteNewRoot(after: move)
    return move
  }

  /// Walks down the tree from the root, picking actions by PUCT, until a stopping condition
  /// yields a reward; then backs that reward up through every node visited on the way.
  private func runOneSimulation(previousMove: Move?) {
    var consecutivePassCount = 0
    if case .some(.pass) = previousMove {
      consecutivePassCount = 1
    }

    var rewardForBlackPlayer: Float?
    var trail: [(node: MCTSNode, action: Move)] = []
    var cursor: MCTSTree.NodeKind = .existingNode(node: tree.root)

    // Descend until the game is finished, a new node is seen, or the maximum game depth is
    // reached.
    descend: while true {
      assert(consecutivePassCount <= 1)

      switch cursor {
      case let .newNode(node, rewardForNextPlayer):
        // A freshly created node ends the descent. Its predicted reward is for the next
        // player, so convert it to black's point of view.
        var reward = rewardForNextPlayer
        if node.boardState.nextPlayerColor == .white {
          reward *= -1.0
        }
        rewardForBlackPlayer = reward
        break descend

      case .existingNode(let node):
        let action = node.actionByPUCT
        trail.append((node: node, action: action))

        // To avoid infinite tree expansion, stop once the max game depth is reached.
        if node.boardState.playedMoveCount >= configuration.maxGameDepth {
          rewardForBlackPlayer = node.boardState.rewardForBlackPlayer()
          break descend
        }

        if case .pass = action {
          consecutivePassCount += 1

          // Two consecutive passes end the game, so stop here.
          if consecutivePassCount == 2 {
            rewardForBlackPlayer = node.boardState.rewardForBlackPlayer()
            break descend
          }
        } else {
          consecutivePassCount = 0
        }

        cursor = tree.child(of: node, for: action)
      }
    }

    // Back the reward up to all visited nodes.
    guard let finalRewardForBlackPlayer = rewardForBlackPlayer else {
      fatalError("The reward must be set during simulation before quiting the tree expanding.")
    }
    for step in trail {
      step.node.backUp(for: step.action, withRewardForBlackPlayer: finalRewardForBlackPlayer)
    }
  }
}

extension BoardState {
  /// Converts the score to reward for black player.
  ///
  /// Score is a numerical value according to the scoring rule, like area score. Reward takes
  /// binary values: -1 when the sign of black's score is `.minus`, and 1 otherwise.
  fileprivate func rewardForBlackPlayer() -> Float {
    let blackScoreIsNegative = score(for: .black).sign == .minus
    return blackScoreIsNegative ? -1.0 : 1.0
  }
}
import TensorFlow

/// What an `MCTSPredictor` returns for one board state.
public struct MCTSPrediction {
  /// The reward is for the next player for current board state.
  ///
  /// The reward has float value in range [-1, 1].
  let rewardForNextPlayer: Float

  /// Weights for every board position plus the weight for `pass`.
  struct Distribution {
    let positions: ShapedArray<Float>
    let pass: Float
  }

  /// The probability distribution over the position, including `pass`, to take for next move.
  ///
  /// The distribution is over all positions (including pass and illegal ones). It does not
  /// need to be normalized.
  let distribution: Distribution

  /// Creates a prediction.
  ///
  /// - Precondition: `rewardForNextPlayer` lies within [-1, 1].
  init(rewardForNextPlayer: Float, distribution: Distribution) {
    let validRewards: ClosedRange<Float> = -1.0...1.0
    precondition(validRewards.contains(rewardForNextPlayer))

    self.distribution = distribution
    self.rewardForNextPlayer = rewardForNextPlayer
  }
}

/// Predicts the reward and distribution over positions for current `boardState`.
///
/// Predictor must be stateless and is expected to be used in multiple threads.
///
/// The protocol is restricted to class types. `AnyObject` is the current spelling of that
/// constraint; the older `class` keyword form is deprecated.
public protocol MCTSPredictor: AnyObject {
  /// Returns the predicted reward and move distribution for `boardState`.
  func prediction(for boardState: BoardState) -> MCTSPrediction
}

/// Holds the next move generated by a `Policy`.
public enum Move: Equatable, Hashable {
  /// The player passes instead of placing a stone.
  case pass
  /// The player places a new stone at `position`.
  case place(position: Position)
}

/// The strategy for playing a game.
///
/// Policy is not thread safe and should be used for one game only.
public protocol Policy {

  /// The name of the game participant.
  var participantName: String { get }

  /// Returns the next move for the current `BoardState` after previous move.
  /// The previous move is nil if the game just starts.
  ///
  /// - Parameters:
  ///   - boardState: The board state to choose a move for.
  ///   - previousMove: The move played just before this one, or nil at the start of a game.
  /// - Returns: The move chosen by the policy, either `.pass` or `.place`.
  func nextMove(for boardState: BoardState, after previousMove: Move?) -> Move
}
import TensorFlow
// implements the same architecture as https://github.com/tensorflow/minigo/blob/master/dual_net.py

/// Layer sizes of the Go model, derived from the board size.
public struct ModelConfiguration {
  /// size of Go board (typically 9 or 19)
  let boardSize: Int
  /// output feature count of conv layers in shared trunk
  let convWidth: Int
  /// output feature count of conv layer in policy head
  let policyConvWidth: Int
  /// output feature count of conv layer in value head
  let valueConvWidth: Int
  /// output feature count of dense layer in value head
  let valueDenseWidth: Int
  /// number of layers (typically equal to boardSize)
  let layerCount: Int

  /// Creates the configuration for a `boardSize` x `boardSize` board.
  ///
  /// A 19x19 board gets the wide settings (256 features in the trunk and in the value dense
  /// layer); every other size gets the narrow ones (32 and 64).
  public init(boardSize: Int) {
    let isFullSizeBoard = boardSize == 19

    self.boardSize = boardSize
    layerCount = boardSize

    convWidth = isFullSizeBoard ? 256 : 32
    valueDenseWidth = isFullSizeBoard ? 256 : 64

    policyConvWidth = 2
    valueConvWidth = 1
  }
}

/// A 2-D convolution followed by batch normalization.
struct ConvBN: Layer {
  var conv: Conv2D<Float>
  var norm: BatchNorm<Float>

  /// Creates the layer pair.
  ///
  /// The batch norm feature count is the last component of `filterShape`. `bias` and `affine`
  /// are accepted but not used yet (see the TODO below).
  init(
    filterShape: (Int, Int, Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding,
    bias: Bool = true,
    affine: Bool = true
  ) {
    // TODO(jekbradbury): thread through bias and affine boolean arguments
    // (behavior is correct for inference but this should be changed for training)
    let outputFeatureCount = filterShape.3

    conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
    norm = BatchNorm(
      featureCount: outputFeatureCount,
      momentum: Tensor<Float>(0.95),
      epsilon: Tensor<Float>(1e-5))
  }

  /// Applies the convolution and then the batch normalization to `input`.
  @differentiable(wrt: (self, input))
  func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
    let convolved = conv.applied(to: input, in: context)
    return norm.applied(to: convolved, in: context)
  }
}

extension ConvBN: LoadableFromPythonCheckpoint {
  /// Loads the parameters of `conv` and then of `norm` from `reader`.
  ///
  /// NOTE(review): the order of the two calls may matter if the reader hands out variables
  /// sequentially — confirm against `PythonCheckpointReader` before reordering.
  mutating func load(from reader: PythonCheckpointReader) {
    conv.load(from: reader)
    norm.load(from: reader)
  }
}

/// A residual block made of two `ConvBN` layers with an identity skip connection.
struct ResidualIdentityBlock: Layer {
  var layer1: ConvBN
  var layer2: ConvBN

  /// Creates the block.
  ///
  /// - Parameters:
  ///   - featureCounts: The input and output feature counts of the first layer. The second
  ///     layer maps the output count onto itself.
  ///   - kernelSize: The side length of the square convolution kernels.
  public init(featureCounts: (Int, Int), kernelSize: Int = 3) {
    let (inputCount, outputCount) = featureCounts

    layer1 = ConvBN(
      filterShape: (kernelSize, kernelSize, inputCount, outputCount),
      padding: .same,
      bias: false)
    layer2 = ConvBN(
      filterShape: (kernelSize, kernelSize, outputCount, outputCount),
      padding: .same,
      bias: false)
  }

  /// Returns `relu(layer2(relu(layer1(input))) + input)`.
  @differentiable(wrt: (self, input))
  func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
    let hidden = relu(layer1.applied(to: input, in: context))
    let residual = layer2.applied(to: hidden, in: context)
    return relu(residual + input)
  }
}

extension ResidualIdentityBlock: LoadableFromPythonCheckpoint {
  /// Loads the parameters of `layer1` and then of `layer2` from `reader`.
  ///
  /// NOTE(review): the order of the two calls may matter if the reader hands out variables
  /// sequentially — confirm against `PythonCheckpointReader` before reordering.
  mutating func load(from reader: PythonCheckpointReader) {
    layer1.load(from: reader)
    layer2.load(from: reader)
  }
}

// This is needed because we can't conform tuples to protocols
/// The outputs of the Go model, bundled in a struct.
public struct GoModelOutput: Differentiable {
  /// The move distribution. `MCTSModelBasedPredictor` flattens it and expects
  /// `boardSize * boardSize + 1` values, the last one being the value for `pass`.
  public let policy: Tensor<Float>
  /// The predicted value. `MCTSModelBasedPredictor` expects shape `[1]` and reads it as the
  /// reward for the next player.
  public let value: Tensor<Float>
  /// NOTE(review): presumably the policy head output before normalization — confirm against
  /// `GoModel.applied(to:in:)`.
  public let logits: Tensor<Float>
}

// This might be needed when we add training to work around an AD bug for memberwise initializers
// @differentiable(wrt: (policy, value, logits), vjp: _vjpMakeGoModelOutput)
// func makeGoModelOutput(
//   policy: Tensor<Float>, value: Tensor<Float>, logits: Tensor<Float>)
//   -> GoModelOutput {
//   return GoModelOutput(policy: policy, value: value, logits: logits)
// }
// func _vjpMakeGoModelOutput(
//   policy: Tensor<Float>, value: Tensor<Float>, logits: Tensor<Float>)
//   -> (GoModelOutput, (GoModelOutput.CotangentVector)
//   -> (Tensor<Float>, Tensor<Float>, Tensor<Float>)) {
//   let result = GoModelOutput(policy: policy, value: value, logits: logits)
//   return (result, { seed in (seed.policy, seed.value, seed.logits) })
// }

/// The Go network: an initial convolution, a stack of residual blocks, and separate policy and
/// value heads.
public struct GoModel: Layer {
  @noDerivative let configuration: ModelConfiguration
  var initialConv: ConvBN
  // TODO(jekbradbury): support differentiation wrt residualBlocks
  // [T] where T: Differentiable doesn't (shouldn't?) conform to Differentiable,
  // so we will likely need a LayerArray<T> where T: Layer type. But this
  // itself won't work until we have better generics support, and even then
  // T can't be an existential Layer. So it's @noDerivative for now.
  @noDerivative var residualBlocks: [ResidualIdentityBlock]
  var policyConv: ConvBN
  var policyDense: Dense<Float>
  var valueConv: ConvBN
  var valueDense1: Dense<Float>
  var valueDense2: Dense<Float>

  /// Builds all layers from `configuration`.
  ///
  /// The number of residual blocks equals `configuration.boardSize`, and the first convolution
  /// expects 17 input feature planes.
  public init(configuration: ModelConfiguration) {
    self.configuration = configuration

    let trunkWidth = configuration.convWidth
    // Sizes of the flattened head activations fed into the dense layers.
    let policyFlatSize =
      configuration.policyConvWidth * configuration.boardSize * configuration.boardSize
    let valueFlatSize =
      configuration.valueConvWidth * configuration.boardSize * configuration.boardSize
    let moveCount = configuration.boardSize * configuration.boardSize + 1

    initialConv = ConvBN(
      filterShape: (3, 3, 17, trunkWidth),
      padding: .same,
      bias: false)

    residualBlocks = (1...configuration.boardSize).map { _ in
      ResidualIdentityBlock(featureCounts: (trunkWidth, trunkWidth))
    }

    policyConv = ConvBN(
      filterShape: (1, 1, trunkWidth, configuration.policyConvWidth),
      padding: .same,
      bias: false,
      affine: false)
    policyDense = Dense<Float>(inputSize: policyFlatSize, outputSize: moveCount)

    valueConv = ConvBN(
      filterShape: (1, 1, trunkWidth, configuration.valueConvWidth),
      padding: .same,
      bias: false,
      affine: false)
    valueDense1 = Dense<Float>(
      inputSize: valueFlatSize,
      outputSize: configuration.valueDenseWidth,
      activation: relu)
    valueDense2 = Dense<Float>(
      inputSize: configuration.valueDenseWidth,
      outputSize: 1,
      activation: tanh)
  }

  /// Runs the network on `input` and returns the policy, value and raw policy logits.
  ///
  /// The first dimension of `input` is treated as the batch dimension.
  @differentiable(wrt: (self, input), vjp: _vjpApplied)
  public func applied(to input: Tensor<Float>, in context: Context) -> GoModelOutput {
    let batchSize = input.shape[0]

    // Shared trunk: initial convolution followed by the residual blocks.
    var trunk = relu(initialConv.applied(to: input, in: context))
    for blockIndex in 0..<configuration.boardSize {
      trunk = residualBlocks[blockIndex].applied(to: trunk, in: context)
    }

    // Policy head.
    let policyFeatures = relu(policyConv.applied(to: trunk, in: context))
    let policyFlatShape = Tensor<Int32>(
      [batchSize,
       Int32(configuration.policyConvWidth * configuration.boardSize * configuration.boardSize),
      ])
    let logits = policyDense.applied(
      to: policyFeatures.reshaped(toShape: policyFlatShape),
      in: context)
    let policy = softmax(logits)

    // Value head.
    let valueFeatures = relu(valueConv.applied(to: trunk, in: context))
    let valueFlatShape = Tensor<Int32>(
      [batchSize,
       Int32(configuration.valueConvWidth * configuration.boardSize * configuration.boardSize)
      ])
    let valueHidden = valueDense1.applied(
      to: valueFeatures.reshaped(toShape: valueFlatShape),
      in: context)
    let value = valueDense2.applied(to: valueHidden, in: context).reshaped(
      toShape: Tensor<Int32>([batchSize]))

    return GoModelOutput(policy: policy, value: value, logits: logits)
  }

  /// Placeholder VJP: returns the forward result together with a pullback that ignores its seed
  /// and yields zero cotangents.
  @usableFromInline
  func _vjpApplied(to input: Tensor<Float>, in context: Context)
    -> (GoModelOutput, (GoModelOutput.CotangentVector)
      -> (GoModel.CotangentVector, Tensor<Float>)) {
    // TODO(jekbradbury): add a real VJP
    // (we're only interested in inference for now and have control flow in our applied(to:) method)
    let forwardResult = applied(to: input, in: context)
    return (forwardResult, { _ in
      (GoModel.CotangentVector.zero, Tensor<Float>(0))
    })
  }
}

/// Lets `GoModel` be used where only inference is needed.
extension GoModel: InferenceModel {
  /// Runs the model on `input` using a `Context` whose learning phase is `.inference`.
  ///
  /// - Parameter input: The tensor passed unchanged to `applied(to:in:)`.
  /// - Returns: The output of `applied(to:in:)`.
  public func prediction(input: Tensor<Float>) -> GoModelOutput {
    return applied(to: input, in: Context(learningPhase: .inference))
  }
}

/// Restores every layer of `GoModel` from a Python checkpoint.
extension GoModel: LoadableFromPythonCheckpoint {
  /// Loads all layers in the order in which their variables are numbered in the checkpoint:
  /// initial convolution, residual blocks, policy head, then value head.
  public mutating func load(from reader: PythonCheckpointReader) {
    // Special-case loader for the two batchnorms that lack affine weights: only the moving
    // statistics are read for the normalization part.
    func loadWithoutAffineWeights(_ layer: inout ConvBN) {
      layer.conv.load(from: reader)
      layer.norm.runningMean.value = reader.readTensor(
        layerName: "batch_normalization",
        weightName: "moving_mean")!
      layer.norm.runningVariance.value = reader.readTensor(
        layerName: "batch_normalization",
        weightName: "moving_variance")!
      reader.increment(layerName: "batch_normalization")
    }

    initialConv.load(from: reader)
    for blockIndex in 0..<configuration.boardSize {
      residualBlocks[blockIndex].load(from: reader)
    }

    loadWithoutAffineWeights(&policyConv)
    policyDense.load(from: reader)

    loadWithoutAffineWeights(&valueConv)
    valueDense1.load(from: reader)
    valueDense2.load(from: reader)
  }
}
import TensorFlow

/// Reads tensors from a checkpoint file written by the Python model code.
///
/// Variable names are built from a layer name, a per-layer counter kept by this reader, and a
/// weight name.
public class PythonCheckpointReader {
  private let path: String
  private var layerCounts: [String: Int] = [:]

  /// Creates a reader for the checkpoint at `path`.
  public init(path: String) {
    self.path = path
  }

  /// Restores the tensor named `<layerName>[_<count>]/<weightName>` from the checkpoint.
  ///
  /// The `_<count>` suffix is omitted until `increment(layerName:)` has been called for
  /// `layerName` at least once.
  ///
  /// Currently returns Optional in order to support the case where the variable might not exist,
  /// but this is not implemented (see b/124126672).
  func readTensor(layerName: String, weightName: String) -> Tensor<Float>? {
    let countSuffix: String
    if let count = layerCounts[layerName] {
      countSuffix = "_\(count)"
    } else {
      countSuffix = ""
    }
    let tensorName = "\(layerName)\(countSuffix)/\(weightName)"

    // TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
    return Tensor<Float>(handle: #tfop(
      "RestoreV2",
      StringTensor(path),
      StringTensor([tensorName]),
      StringTensor([""]),
      dtypes$dtype: [Float.tensorFlowDataType]))
  }

  /// Increment a per-layer counter for variable names in the checkpoint file.
  /// As the Python model code uses low-level TensorFlow APIs, variables are namespaced only by
  /// layer name and this per-layer counter (e.g., conv2d_5/bias).
  func increment(layerName: String) {
    let previousCount = layerCounts[layerName] ?? 0
    layerCounts[layerName] = previousCount + 1
  }
}

/// Stops the program if the two tensors have different shapes, after printing both shapes.
private func checkShapes(_ existing: Tensor<Float>, _ loaded: Tensor<Float>) {
  if existing.shape != loaded.shape {
    print("Shape mismatch: \(existing.shape) != \(loaded.shape)")
    fatalError()
  }
}

/// A type whose weights can be restored through a `PythonCheckpointReader`.
protocol LoadableFromPythonCheckpoint {
  /// Replaces this value's weights with tensors read from `reader`.
  mutating func load(from reader: PythonCheckpointReader)
}

/// Restores a `Dense` layer from the checkpoint's `dense[_N]` variables.
extension Dense: LoadableFromPythonCheckpoint where Scalar == Float {
  /// Reads `kernel` (required) and `bias` (if returned by the reader), verifying that each has
  /// the same shape as the tensor it replaces, then advances the reader's `dense` counter.
  mutating func load(from reader: PythonCheckpointReader) {
    let layerName = "dense"

    let loadedKernel = reader.readTensor(layerName: layerName, weightName: "kernel")!
    checkShapes(weight, loadedKernel)
    weight = loadedKernel

    if let loadedBias = reader.readTensor(layerName: layerName, weightName: "bias") {
      checkShapes(bias, loadedBias)
      bias = loadedBias
    }

    reader.increment(layerName: layerName)
  }
}

/// Restores a `Conv2D` layer from the checkpoint's `conv2d[_N]` variables.
extension Conv2D: LoadableFromPythonCheckpoint where Scalar == Float {
  /// Reads `kernel` into `filter` after a shape check, then advances the reader's `conv2d`
  /// counter. The bias is not loaded.
  mutating func load(from reader: PythonCheckpointReader) {
    let layerName = "conv2d"

    let loadedFilter = reader.readTensor(layerName: layerName, weightName: "kernel")!
    checkShapes(filter, loadedFilter)
    filter = loadedFilter

    // TODO(jekbradbury): handle layers with optional weights
    // It would be helpful to have an op to see if a checkpoint contains a particular variable
    // (see b/124126672)
    // if let newBias = reader.readTensor(layerName: "conv2d", weightName: "bias") {
    //   checkShapes(bias, newBias)
    //   bias = newBias
    // }

    reader.increment(layerName: layerName)
  }
}

/// Restores a `BatchNorm` layer from the checkpoint's `batch_normalization[_N]` variables.
extension BatchNorm: LoadableFromPythonCheckpoint where Scalar == Float {
  /// Reads `beta`, `gamma`, `moving_mean` and `moving_variance` (each only if the reader returns
  /// a tensor), then advances the reader's `batch_normalization` counter.
  mutating func load(from reader: PythonCheckpointReader) {
    let layerName = "batch_normalization"

    if let loadedOffset = reader.readTensor(layerName: layerName, weightName: "beta") {
      checkShapes(offset, loadedOffset)
      offset = loadedOffset
    }

    if let loadedScale = reader.readTensor(layerName: layerName, weightName: "gamma") {
      checkShapes(scale, loadedScale)
      scale = loadedScale
    }

    // The running statistics are assigned without a shape check, because the Swift running
    // mean/variance are initialized to scalar tensors.
    if let loadedMean = reader.readTensor(layerName: layerName, weightName: "moving_mean") {
      runningMean.value = loadedMean
    }

    if let loadedVariance = reader.readTensor(
      layerName: layerName,
      weightName: "moving_variance") {
      runningVariance.value = loadedVariance
    }

    reader.increment(layerName: layerName)
  }
}

/// Plays one game between two participants, printing the board and every move.
///
/// The first participant plays black and the second plays white. The game ends after two
/// consecutive passes, at which point the score for the black player is printed.
///
/// - Parameters:
///   - gameConfiguration: The configuration used to create the initial board state.
///   - participants: Exactly two policies with different `participantName`s.
/// - Throws: Whatever `BoardState.placingNewStone(at:)` throws for a move chosen by a policy.
public func playOneGame(gameConfiguration: GameConfiguration, participants: [Policy]) throws {

  var boardState = BoardState(gameConfiguration: gameConfiguration)
  precondition(participants.count == 2, "Must provide two participants.")
  precondition(
    participants[0].participantName != participants[1].participantName,
    "Participants' names should not be same.")

  // TODO(xiejw): Choose a random participant to play black.
  let (blackPlayer, whitePlayer) = (participants[0], participants[1])

  var lastMove: Move? = nil
  var passStreak = 0

  // Each iteration prints the board, then either ends the game or plays one move.
  while true {
    print(boardState)

    if gameConfiguration.isVerboseDebuggingEnabled {
      print("Legal moves: \(boardState.legalMoves.count)")
      print("Stones on board: \(boardState.stoneCount)")
      switch boardState.ko {
      case .some(let ko):
        print("Found ko: \(ko).")
      case .none:
        print("No ko.")
      }
    }

    // The game ends with two passes in a row.
    if passStreak >= 2 {
      print("End of Game. Score for black player: \(boardState.score(for: .black)).")
      return
    }

    let isBlackTurn = boardState.nextPlayerColor == .black
    let policy: Policy = isBlackTurn ? blackPlayer : whitePlayer
    print(isBlackTurn ? "-> Black" : "-> White")

    let move = policy.nextMove(for: boardState, after: lastMove)
    lastMove = move

    if case .place(let position) = move {
      passStreak = 0
      print("- Placing stone at: \(position)")
      boardState = try boardState.placingNewStone(at: position)
    } else {
      passStreak += 1
      print("- Pass")
      boardState = boardState.passing()
    }
  }
}
import TensorFlow

/// Holds the stones currently on the board.
///
/// Callers may mutate the stones arbitrarily; this struct performs no legality checks when a
/// stone is placed. `BoardState` is designed for that.
struct Board: Hashable {
  // The stone `Color` at each position, or `nil` where the position is empty.
  private var stones: ShapedArray<Color?>
  let size: Int

  /// Creates an empty `size` x `size` board.
  init(size: Int) {
    self.size = size
    self.stones = ShapedArray<Color?>(shape: [size, size], repeating: nil)
  }

  /// Returns the color of the stone at `position`, or `nil` if the position is empty.
  func color(at position: Position) -> Color? {
    assert(isOnBoard(position))
    let cell = stones[position.x][position.y]
    return cell.scalars[0]
  }

  /// Puts a stone of `color` at `position`, replacing whatever was there.
  mutating func placeStone(at position: Position, withColor color: Color) {
    assert(isOnBoard(position))
    stones[position.x][position.y] = ShapedArraySlice(color)
  }

  /// Empties `position`.
  mutating func removeStone(at position: Position) {
    assert(isOnBoard(position))
    stones[position.x][position.y] = ShapedArraySlice(nil)
  }

  /// Whether both coordinates of `position` lie in `0..<size`.
  private func isOnBoard(_ position: Position) -> Bool {
    let validRange = 0..<size
    return validRange.contains(position.x) && validRange.contains(position.y)
  }
}

extension Board: CustomStringConvertible {

  /// A text rendering of the board: a header line with the `y` indices followed by one line per
  /// `x` index, using `X` for black stones, `O` for white stones and `.` for empty positions.
  var description: String {
    // The largest index is `size - 1`, so two-digit indices only appear when size >= 11. Up to
    // size 10 every cell is two characters wide:
    //
    //   x/y 0 1 2 3 4 5 6 7 8
    //
    // From size 11 on, cells are three characters wide so that two-digit indices fit:
    //
    //   x/y  0  1  2  3  4  5  6  7  8  9 10 11
    let usesWideCells = size >= 11
    let gapBetweenStones = usesWideCells ? "  " : " "

    // Head line.
    var text = "\nx/y"
    for y in 0..<size {
      if usesWideCells {
        text += " "
      }
      // As we cannot use Foundation, String(format:) is not available; pad by hand.
      text += y < 10 ? " \(y)" : "\(y)"
    }
    text += "\n"

    // One line per row.
    for x in 0..<size {
      // Row index, padded to three characters (two leading spaces for one digit, one for two).
      text += x < 10 ? "  \(x)" : " \(x)"

      for y in 0..<size {
        text += gapBetweenStones
        if let stone = color(at: Position(x: x, y: y)) {
          text += stone == .black ? "X" : "O"
        } else {
          text += "."  // Empty position.
        }
      }
      text += "\n"
    }
    return text
  }
}

/// Represents an (immutable) configuration of a Go game.
public struct GameConfiguration {
  /// The board size of the game, i.e., the number of positions along each side of the board.
  let size: Int

  /// The points added to the score of the player with the white stones as compensation for playing
  /// second.
  let komi: Float

  /// The maximum number of board states to track.
  ///
  /// This does not include the current board.
  let maxHistoryCount: Int

  /// If true, enables debugging information.
  let isVerboseDebuggingEnabled: Bool

  /// Creates a configuration.
  ///
  /// - Parameters:
  ///   - size: The board size of the game.
  ///   - komi: The compensation points for the player with the white stones.
  ///   - maxHistoryCount: The maximum number of previous board states to track. Defaults to 7.
  ///   - isVerboseDebuggingEnabled: Whether debugging information is enabled. Defaults to false.
  public init(
    size: Int,
    komi: Float,
    maxHistoryCount: Int = 7,
    isVerboseDebuggingEnabled: Bool = false
  ) {
    self.size = size
    self.komi = komi
    self.maxHistoryCount = maxHistoryCount
    self.isVerboseDebuggingEnabled = isVerboseDebuggingEnabled
  }
}

/// A group of connected stones of one color, together with the liberties the group shares.
struct LibertyGroup {
  /// A numerical unique ID for the group.
  var id: Int

  /// The color of every stone in the group.
  var color: Color

  /// The stones belonging to this group.
  var stones: Set<Position>

  /// The liberties for this group.
  var liberties: Set<Position>
}

/// Tracks the liberties of all stones on board.
///
/// `LibertyTracker` is a struct because it tracks the liberty information of one board snapshot
/// and is not expected to change afterwards. After placing a new stone, we make a copy, update
/// it and then attach it to the new board snapshot.
struct LibertyTracker {

  private let gameConfiguration: GameConfiguration

  // Tracks the liberty groups. For a position (stone) having no group,
  // groupIndex[stone] should be nil. Otherwise, the group ID should be
  // groupIndex[stone] and its group is groups[groupIndex[stone]].
  // The invariant can be checked via the checkLibertyGroupsInvariance helper
  // method.
  private var nextGroupIDToAssign = 0
  private var groupIndex: [[Int?]]
  private var groups: [Int: LibertyGroup] = [:]

  /// Creates a tracker for an empty board of size `gameConfiguration.size`.
  init(gameConfiguration: GameConfiguration) {
    self.gameConfiguration = gameConfiguration

    let boardSize = gameConfiguration.size
    let emptyRow = [Int?](repeating: nil, count: boardSize)
    groupIndex = [[Int?]](repeating: emptyRow, count: boardSize)
  }

  /// Returns the liberty group at the position, or `nil` if no group covers it.
  ///
  /// Stops the program if the position maps to a group ID that is missing from `groups`.
  func group(at position: Position) -> LibertyGroup? {
    guard let groupID = groupIndex(for: position) else {
      return nil
    }
    if let group = groups[groupID] {
      return group
    }
    fatalErrorForGroupsInvariance(groupID: groupID)
  }
}

/// Extends `LibertyTracker` with the mutating logic for placing a new stone.
extension LibertyTracker {

  /// Adds a new stone to the board and returns all captured stones.
  ///
  /// The new stone is merged with all friendly neighboring groups; opponent neighboring groups
  /// either lose `position` as a liberty or, if it was their only liberty, are captured.
  ///
  /// - Parameters:
  ///   - position: The position of the new stone. Must not belong to any group yet.
  ///   - color: The color of the new stone.
  /// - Returns: The positions of all opponent stones captured by this move.
  /// - Throws: `IllegalMove.suicide` if the new stone's group ends up without liberties. The
  ///   tracker has already been modified when this is thrown, so call this method on a copy that
  ///   can be discarded.
  mutating func addStone(at position: Position, withColor color: Color) throws -> Set<Position> {
    precondition(groupIndex(for: position) == nil)

    printDebugInfo(message: "Before adding stone.")

    var capturedStones = Set<Position>()

    // Records neighbors information.
    var emptyNeighbors = Set<Position>()
    var opponentNeighboringGroupIDs = Set<Int>()
    var friendlyNeighboringGroupIDs = Set<Int>()

    for neighbor in position.neighbors(boardSize: gameConfiguration.size) {

      // First, handle the case neighbor has no group.
      guard let neighborGroupID = groupIndex(for: neighbor) else {
        emptyNeighbors.insert(neighbor)
        continue
      }

      guard let neighborGroup = groups[neighborGroupID] else {
        fatalErrorForGroupsInvariance(groupID: neighborGroupID)
      }

      if neighborGroup.color == color {
        friendlyNeighboringGroupIDs.insert(neighborGroupID)
      } else {
        opponentNeighboringGroupIDs.insert(neighborGroupID)
      }
    }

    if gameConfiguration.isVerboseDebuggingEnabled {
      print("empty: \(emptyNeighbors)")
      print("friends: \(friendlyNeighboringGroupIDs)")
      print("opponents: \(opponentNeighboringGroupIDs)")
    }

    // Creates new group and sets its liberty as the empty neighbors at first.
    var newGroupID = makeGroup(
      color: color,
      stone: position,
      liberties: emptyNeighbors
    ).id

    // Merging all friend groups. Every merge produces a fresh group ID, so keep tracking it.
    for friendGroupID in friendlyNeighboringGroupIDs {
      newGroupID = mergeGroups(newGroupID, friendGroupID)
    }

    // Calculates captured stones.
    for opponentGroupID in opponentNeighboringGroupIDs {
      guard var opponentGroup = groups[opponentGroupID] else {
        fatalErrorForGroupsInvariance(groupID: opponentGroupID)
      }

      guard opponentGroup.liberties.count > 1 else {
        // The only liberty will be taken by the stone just placed. Delete it.
        capturedStones.formUnion(captureGroup(opponentGroupID))
        continue
      }

      // Updates the liberty taken by the stone just placed.
      opponentGroup.liberties.remove(position)
      // As group is struct, we need to explicitly write it back.
      groups[opponentGroupID] = opponentGroup
      assert(checkLibertyGroupsInvariance())
    }

    if gameConfiguration.isVerboseDebuggingEnabled {
      print("captured stones: \(capturedStones)")
    }

    // Update liberties for existing stones
    updateLibertiesAfterRemovingCapturedStones(capturedStones)

    printDebugInfo(message: "After adding stone.")

    // Suicide is illegal.
    guard let newGroup = groups[newGroupID] else {
      fatalErrorForGroupsInvariance(groupID: newGroupID)
    }

    guard newGroup.liberties.count > 0 else {
      throw IllegalMove.suicide
    }

    return capturedStones
  }

  /// Returns whether the set of group IDs referenced by `groupIndex` equals the set of keys in
  /// `groups`.
  private func checkLibertyGroupsInvariance() -> Bool {
    var groupIDsInGroupIndex = Set<Int>()
    let size = gameConfiguration.size
    for x in 0..<size {
      for y in 0..<size {
        guard let groupID = groupIndex[x][y] else {
          continue
        }
        groupIDsInGroupIndex.insert(groupID)
      }
    }
    return Set(groups.keys) == groupIDsInGroupIndex
  }

  /// Prints the missing group ID and all current groups, then stops the program.
  private func fatalErrorForGroupsInvariance(groupID: Int) -> Never {
    print("The group ID \(groupID) should exist.")
    print("Current groups are \(groups).")
    fatalError()
  }

  /// Returns the group index of the stone.
  private func groupIndex(for position: Position) -> Int? {
    return groupIndex[position.x][position.y]
  }

  /// Assigns a new unique group ID.
  mutating private func assignNewGroupID() -> Int {
    let newID = nextGroupIDToAssign
    precondition(groups[newID] == nil)

    nextGroupIDToAssign += 1
    return newID
  }

  /// Creates a new group for the single stone with liberties.
  mutating private func makeGroup(
    color: Color,
    stone: Position,
    liberties: Set<Position>
  ) -> LibertyGroup {
    let newID = assignNewGroupID()
    let newGroup = LibertyGroup(id: newID, color: color, stones: [stone], liberties: liberties)

    precondition(groups[newID] == nil)
    groups[newID] = newGroup
    groupIndex[stone.x][stone.y] = newID
    assert(checkLibertyGroupsInvariance())
    return newGroup
  }

  /// Returns a new group (id) by merging the groups identified by the IDs.
  ///
  /// Both input groups are removed from `groups`; the merged group gets a freshly assigned ID.
  mutating private func mergeGroups(_ groupID1: Int, _ groupID2: Int) -> Int {
    guard let group1 = groups.removeValue(forKey: groupID1) else {
      fatalErrorForGroupsInvariance(groupID: groupID1)
    }
    guard let group2 = groups.removeValue(forKey: groupID2) else {
      fatalErrorForGroupsInvariance(groupID: groupID2)
    }
    precondition(group1.color == group2.color)

    let newID = assignNewGroupID()

    let unionedStones = group1.stones.union(group2.stones)
    // A position occupied by a stone of either group is no longer a liberty of the merged group.
    var newLiberties = group1.liberties.union(group2.liberties)
    newLiberties.subtract(group1.stones)
    newLiberties.subtract(group2.stones)

    let newGroup = LibertyGroup(
      id: newID,
      color: group1.color,
      stones: unionedStones,
      liberties: newLiberties
    )

    groups[newID] = newGroup

    // Updates groups IDs for future lookups.
    for stone in unionedStones {
      groupIndex[stone.x][stone.y] = newID
    }
    assert(checkLibertyGroupsInvariance())
    return newID
  }

  /// Captures the whole group and returns all stones in it.
  ///
  /// Like the other group lookups in this type, a missing group is reported through
  /// `fatalErrorForGroupsInvariance(groupID:)`.
  mutating private func captureGroup(_ groupID: Int) -> Set<Position> {
    guard let deadGroup = groups.removeValue(forKey: groupID) else {
      fatalErrorForGroupsInvariance(groupID: groupID)
    }
    for stone in deadGroup.stones {
      groupIndex[stone.x][stone.y] = nil
    }
    return deadGroup.stones
  }

  /// Updates all neighboring groups' liberties.
  ///
  /// Every captured position becomes a liberty of each group that still has a stone next to it.
  mutating private func updateLibertiesAfterRemovingCapturedStones(_ capturedStones: Set<Position>) {
    let size = gameConfiguration.size
    for capturedStone in capturedStones {
      for neighbor in capturedStone.neighbors(boardSize: size) {
        if let neighborGroupID = groupIndex(for: neighbor) {
          guard groups.keys.contains(neighborGroupID) else {
            fatalErrorForGroupsInvariance(groupID: neighborGroupID)
          }
          // This force unwrap is safe as we checked the key existence above. As
          // the value in the groups is struct. We need the force unwrap to do
          // mutation in place.
          groups[neighborGroupID]!.liberties.insert(capturedStone)
        }
      }
    }
    assert(checkLibertyGroupsInvariance())
  }

  /// Prints the debug info for liberty tracked so far.
  ///
  /// Does nothing unless verbose debugging is enabled in the game configuration.
  private func printDebugInfo(message: String) {
    guard gameConfiguration.isVerboseDebuggingEnabled else {
      return
    }

    print(message)

    /// Prints the group index for the board.
    let size = gameConfiguration.size
    for x in 0..<size {
      for y in 0..<size {
        switch groupIndex[x][y] {
        case .none:
          print("  .", terminator: "")
        case .some(let id) where id < 10:
          print("  \(id)", terminator: "")
        case .some(let id):
          print(" \(id)", terminator: "")
        }
      }
      print("")
    }

    for (id, group) in groups {
      print(" id: \(id) -> liberty: \(group.liberties)")
    }
  }
}


/// Represents a position in a Go game as a pair of zero-based board coordinates.
///
/// `Hashable` already implies `Equatable`, so positions can be compared with `==` and stored in
/// sets.
public struct Position: Hashable {
  /// The first board coordinate.
  var x: Int
  /// The second board coordinate.
  var y: Int
}

extension Position {

  /// Returns all valid neighbors for the given position on board.
  ///
  /// Candidates are produced in the order x+1, x-1, y+1, y-1; those with a coordinate outside
  /// `0..<size` are dropped.
  ///
  /// - Parameter size: The board size.
  func neighbors(boardSize size: Int) -> [Position] {
    let offsets = [(1, 0), (-1, 0), (0, 1), (0, -1)]

    return offsets
      .map { Position(x: x + $0.0, y: y + $0.1) }
      .filter { (0..<size).contains($0.x) && (0..<size).contains($0.y) }
  }
}
/// The reasons why placing a stone at a position can be illegal.
enum IllegalMove: Error {
  /// The placed stone's group would be left without any liberty.
  case suicide
  /// The position already holds a stone.
  case occupied

  /// A `ko` fight is a tactical and strategic phase that can arise in the game
  /// of go.
  ///
  /// The existence of ko fights is implied by the rule of ko, a special rule of
  /// the game that prevents immediate repetition of position, by a short 'loop'
  /// in which a single stone is captured, and another single stone immediately
  /// taken back.
  ///
  /// See https://en.wikipedia.org/wiki/Ko_fight for details.
  case ko
}

/// The result of checking whether a stone may be placed at a position.
private enum PositionStatus: Equatable {
  /// A stone may be placed at the position.
  case legal
  /// A stone may not be placed at the position, for the attached reason.
  case illegal(reason: IllegalMove)
}

/// Represents an immutable snapshot of the current board state.
///
/// `BoardState` checks whether a newly placed stone is legal or not. If so, it creates a new
/// snapshot.
public struct BoardState {

  let gameConfiguration: GameConfiguration
  let nextPlayerColor: Color

  /// The position of potential `ko`. See `IllegalMove.ko` for details.
  let ko: Position?

  /// All legal positions to be considered as next move given the current board state.
  let legalMoves: [Position]

  /// All stones on the current board.
  let board: Board

  /// History of the previous board states (does not include current one).
  ///
  /// The most recent one is placed at index 0. The history count is truncated by
  /// `GameConfiguration.maxHistoryCount`.
  ///
  /// TODO(xiejw): Improve the efficiency of history.
  let history: [Board]

  // General statistics.
  let playedMoveCount: Int
  let stoneCount: Int

  // Internally maintained state.
  private let libertyTracker: LibertyTracker

  /// Constructs an empty board state.
  init(gameConfiguration: GameConfiguration) {
    self.init(
      gameConfiguration: gameConfiguration,
      nextPlayerColor: .black,  // First player is always black.
      playedMoveCount: 0,
      stoneCount: 0,
      ko: nil,
      history: [],
      board: Board(size: gameConfiguration.size),
      libertyTracker: LibertyTracker(gameConfiguration: gameConfiguration)
    )
  }

  /// Memberwise constructor that also computes `legalMoves` for the new snapshot.
  ///
  /// `board.size` must equal `gameConfiguration.size`.
  private init(
    gameConfiguration: GameConfiguration,
    nextPlayerColor: Color,
    playedMoveCount: Int,
    stoneCount: Int,
    ko: Position?,
    history: [Board],
    board: Board,
    libertyTracker: LibertyTracker
  ) {
    self.gameConfiguration = gameConfiguration
    self.nextPlayerColor = nextPlayerColor
    self.playedMoveCount = playedMoveCount
    self.stoneCount = stoneCount
    self.ko = ko

    assert(history.count <= gameConfiguration.maxHistoryCount)
    self.history = history

    self.libertyTracker = libertyTracker
    self.board = board
    precondition(board.size == gameConfiguration.size)

    if stoneCount == gameConfiguration.size * gameConfiguration.size {
      // Full board.
      self.legalMoves = []
    } else {
      self.legalMoves = board.allLegalMoves(
        ko: ko,
        libertyTracker: libertyTracker,
        nextPlayerColor: nextPlayerColor
      )
    }
  }

  /// Returns the history for the next snapshot: the current board placed at index 0 in front of
  /// the existing history, truncated to `gameConfiguration.maxHistoryCount` entries.
  private func historyIncludingCurrentBoard() -> [Board] {
    var newHistory = history
    newHistory.insert(board, at: 0)
    if newHistory.count > gameConfiguration.maxHistoryCount {
      // At most one entry was added, so dropping the oldest one restores the limit.
      _ = newHistory.popLast()
    }
    return newHistory
  }

  /// Returns a new `BoardState` after current player passed.
  func passing() -> BoardState {
    return BoardState(
      gameConfiguration: gameConfiguration,
      nextPlayerColor: nextPlayerColor.opponentColor,
      playedMoveCount: playedMoveCount + 1,
      stoneCount: stoneCount,
      ko: nil,  // Reset ko.
      history: historyIncludingCurrentBoard(),
      board: board,
      libertyTracker: libertyTracker
    )
  }

  /// Returns a new `BoardState` after placing a new stone at `position`.
  ///
  /// - Throws: The `IllegalMove` explaining why the stone cannot be placed at `position`.
  func placingNewStone(at position: Position) throws -> BoardState {
    // Sanity Check first.
    if case .illegal(let reason) = board.positionStatus(
      at: position,
      ko: ko,
      libertyTracker: libertyTracker,
      nextPlayerColor: nextPlayerColor
    ) {
      throw reason
    }

    // Gets copies of libertyTracker and board. Updates both by placing new stone.
    let currentStoneColor = nextPlayerColor
    var newLibertyTracker = libertyTracker
    var newBoard = board

    // Makes attempt to guess the possible ko. This must look at the board before the new stone
    // is placed.
    let isPotentialKo = newBoard.isKoish(at: position, withNewStoneColor: currentStoneColor)

    // Updates libertyTracker and board by placing a new stone.
    let capturedStones = try newLibertyTracker.addStone(at: position, withColor: currentStoneColor)
    newBoard.placeStone(at: position, withColor: currentStoneColor)

    // Removes capturedStones
    for capturedStone in capturedStones {
      newBoard.removeStone(at: capturedStone)
    }

    // Updates stone count on board.
    let newStoneCount = stoneCount + 1 - capturedStones.count

    // A single captured stone at a ko-like position becomes the ko for the next move.
    var newKo: Position?
    if let stone = capturedStones.first, capturedStones.count == 1, isPotentialKo {
      newKo = stone
    }

    return BoardState(
      gameConfiguration: gameConfiguration,
      nextPlayerColor: currentStoneColor.opponentColor,
      playedMoveCount: playedMoveCount + 1,
      stoneCount: newStoneCount,
      ko: newKo,
      history: historyIncludingCurrentBoard(),
      board: newBoard,
      libertyTracker: newLibertyTracker
    )
  }

  /// Returns the score of the player.
  ///
  /// The score for white is the negation of the score for black.
  func score(for playerColor: Color) -> Float {
    let scoreForBlackPlayer = board.scoreForBlackPlayer(komi: gameConfiguration.komi)
    switch playerColor {
    case .black:
      return scoreForBlackPlayer
    case .white:
      return -scoreForBlackPlayer
    }
  }
}

extension BoardState: CustomStringConvertible {
  /// The textual description of the current `board`; no other state is included.
  public var description: String {
    return board.description
  }
}

extension BoardState: Equatable {
  /// Two states are equal when their boards, next player colors, ko positions and histories all
  /// match. Together these are the sufficient and necessary condition for "equal".
  public static func == (lhs: BoardState, rhs: BoardState) -> Bool {
    guard lhs.board == rhs.board else { return false }
    guard lhs.nextPlayerColor == rhs.nextPlayerColor else { return false }
    guard lhs.ko == rhs.ko else { return false }
    return lhs.history == rhs.history
  }
}

extension Board {

  /// Calculates all legal moves on board.
  ///
  /// Positions are visited with `x` as the outer and `y` as the inner index, and returned in
  /// that order.
  fileprivate func allLegalMoves(
    ko: Position?,
    libertyTracker: LibertyTracker,
    nextPlayerColor: Color
  ) -> [Position] {
    var moves: [Position] = []
    for x in 0..<size {
      for y in 0..<size {
        let candidate = Position(x: x, y: y)
        let status = positionStatus(
          at: candidate,
          ko: ko,
          libertyTracker: libertyTracker,
          nextPlayerColor: nextPlayerColor
        )
        if status == .legal {
          moves.append(candidate)
        }
      }
    }
    return moves
  }

  /// Checks whether placing a stone at `position` is legal.
  ///
  /// The checks run in this order: occupied, ko, suicide. The first failing check determines the
  /// reason carried by the `.illegal` result.
  fileprivate func positionStatus(
    at position: Position,
    ko: Position?,
    libertyTracker: LibertyTracker,
    nextPlayerColor: Color
  ) -> PositionStatus {
    if color(at: position) != nil {
      return .illegal(reason: .occupied)
    }
    if position == ko {
      return .illegal(reason: .ko)
    }

    let suicidal = isSuicidal(
      at: position,
      libertyTracker: libertyTracker,
      nextPlayerColor: nextPlayerColor
    )
    return suicidal ? .illegal(reason: .suicide) : .legal
  }

  /// A fast algorithm to check a possible suicidal move.
  ///
  /// This method assumes the move is not `ko`.
  fileprivate func isSuicidal(
    at position: Position,
    libertyTracker: LibertyTracker,
    nextPlayerColor: Color
  ) -> Bool {
    var friendlyLiberties = Set<Position>()

    for neighbor in position.neighbors(boardSize: size) {
      guard let group = libertyTracker.group(at: neighbor) else {
        // The neighbor is not occupied (it has no liberty group), so the new stone would have
        // at least this liberty.
        return false
      }

      if group.color != nextPlayerColor {
        if group.liberties.count == 1 {
          // This move is capturing opponent's group. So, always legal.
          return false
        }
        continue
      }

      friendlyLiberties.formUnion(group.liberties)
    }

    // After removing the new position from the liberties, if there are no liberties left, this
    // move is suicide.
    friendlyLiberties.remove(position)
    return friendlyLiberties.isEmpty
  }

  /// Checks whether the position is a potential ko, i.e., whether every on-board neighbor of the
  /// position holds a stone of the opponent's color.
  ///
  /// This is an approximated algorithm to find `ko`. See https://en.wikipedia.org/wiki/Ko_fight
  /// for details.
  fileprivate func isKoish(at position: Position, withNewStoneColor stoneColor: Color) -> Bool {
    precondition(color(at: position) == nil)
    let opponent = stoneColor.opponentColor
    for neighbor in position.neighbors(boardSize: size) {
      if color(at: neighbor) != opponent {
        return false
      }
    }
    return true
  }
}

// Adds a helper on Color (for player) that yields the other player's Color.
extension Color {
  /// The color playing against `self`: white for black, and black for white.
  fileprivate var opponentColor: Color {
    switch self {
    case .black:
      return .white
    case .white:
      return .black
    }
  }
}

extension Board {
  /// Returns the score for black player.
  ///
  /// Every empty region whose bordering stones all share one color is counted for that color.
  /// The result is the number of black points minus the number of white points, minus `komi`.
  /// Empty regions bordered by both colors (or by no stones at all) count for neither player.
  ///
  /// - Parameter komi: The points added to the score of the player with the white stones as
  ///   compensation for playing second.
  /// - Returns: Black's points minus white's points minus `komi`.
  fileprivate func scoreForBlackPlayer(komi: Float) -> Float {

    // Makes a copy as we will modify it over time.
    var scoreBoard = self

    // First pass: Finds all empty positions on board.
    var emptyPositions = Set<Position>()
    for x in 0..<size {
      for y in 0..<size {
        let position = Position(x: x, y: y)
        if scoreBoard.color(at: position) == nil {
          emptyPositions.insert(position)
        }
      }
    }

    // Second pass: Calculates the territory and borders for each empty region, if there is any.
    // If territory is surrounded by the stones in same color, fills that color in territory.
    while let seed = emptyPositions.first {
      let (territory, borders) = territoryAndBorders(startingFrom: seed)

      // One flood fill settles the whole connected region, whatever its border looks like.
      // Dropping all of it from the work set (not just the seed) avoids re-running the same
      // flood fill for every position of a neutral region or of an empty board. `territory`
      // always contains `seed`, so the loop makes progress.
      emptyPositions.subtract(territory)

      guard !borders.isEmpty else {
        continue
      }

      // Fills the territory with black (or white) if the borders are all in black (or white).
      for color: Color in [.black, .white] {
        if borders.allSatisfy({ scoreBoard.color(at: $0) == color }) {
          for position in territory {
            scoreBoard.placeStone(at: position, withColor: color)
          }
        }
      }
    }

    // TODO(xiejw): Print out the modified board in debug mode.

    // Third pass: Counts stones now for scoring.
    var blackStoneCount = 0
    var whiteStoneCount = 0
    for x in 0..<size {
      for y in 0..<size {
        guard let color = scoreBoard.color(at: Position(x: x, y: y)) else {
          // This board position does not belong to either player. Could be seki or dame.
          // See https://en.wikipedia.org/wiki/Go_(game)#Seki_(mutual_life).
          continue
        }
        switch color {
        case .black:
          blackStoneCount += 1
        case .white:
          whiteStoneCount += 1
        }
      }
    }
    return Float(blackStoneCount - whiteStoneCount) - komi
  }

  /// Finds the `territory`, all connected empty positions starting from `position`, and the
  /// `borders`, either black or white stones, surrounding the `territory`.
  ///
  /// The `position` must be an empty position. The returned `territory` contains empty positions
  /// only and always includes `position`. The returned `borders` contains positions for placed
  /// stones. If the board is empty, `borders` will be empty.
  ///
  /// - Parameter position: An empty position from which the flood fill starts.
  /// - Returns: The connected empty region and the occupied positions adjacent to it.
  fileprivate func territoryAndBorders(
    startingFrom position: Position
  ) -> (territory: Set<Position>, borders: Set<Position>) {
    precondition(self.color(at: position) == nil)

    var territory = Set<Position>()
    var borders = Set<Position>()

    // Stores all candidates for the territory: empty positions found but not yet expanded.
    var candidates: Set = [position]
    while let currentPosition = candidates.popFirst() {
      territory.insert(currentPosition)

      for neighbor in currentPosition.neighbors(boardSize: self.size) {
        if self.color(at: neighbor) != nil {
          // Insert the stone (either black or white) into borders.
          borders.insert(neighbor)
        } else if !territory.contains(neighbor) {
          // We have not explored this (empty) position, so queue it up for processing.
          candidates.insert(neighbor)
        }
      }
    }

    precondition(
      territory.allSatisfy { self.color(at: $0) == nil },
      "territory must be all empty (no stones).")
    precondition(
      borders.allSatisfy { self.color(at: $0) != nil },
      "borders cannot have empty positions.")
    return (territory, borders)
  }
}
/// Identifies a side in the game, used both for players and for the stones they place.
///
/// The integer raw values are `1` for `black` and `-1` for `white`.
enum Color: Int {
  case black = 1, white = -1
}

extension Color: CustomStringConvertible {
  /// The one-character text form of the color: "X" for black and "O" for white.
  var description: String {
    let symbol: String
    switch self {
    case .black:
      symbol = "X"
    case .white:
      symbol = "O"
    }
    return symbol
  }
}



In [0]:
// Basic settings shared by the game and the MCTS search below.
let boardSize = 9
let simulationCountForOneMove = 1000

let gameConfiguration = GameConfiguration(
  size: boardSize,
  komi: 0.5,
  isVerboseDebuggingEnabled: false)

let mctsConfiguration = MCTSConfiguration(
  gameConfiguration: gameConfiguration,
  simulationCountForOneMove: simulationCountForOneMove)

// Disabled: a model-based predictor backed by a GoModel checkpoint. Uncomment to create the
// GoModel and load the checkpoint instead of relying on the random predictor.
//
// print("Loading checkpoint into GoModel. Might take a while.")
// let modelConfig = ModelConfiguration(boardSize: boardSize)
// var model = GoModel(configuration: modelConfig)
// let reader = PythonCheckpointReader(path: "./MiniGo/000939-heron")
// model.load(from: reader)
//
// let predictor = MCTSModelBasedPredictor(boardSize: boardSize, model: model)

// Choose the policies that play. The first policy in `participants` plays black.
// Policies currently available:
//   - RandomPolicy
//   - HumanPolicy
//   - MCTSPolicy
let blackParticipant = MCTSPolicy(
  participantName: "black",
  predictor: MCTSRandomPredictor(boardSize: boardSize),
  configuration: mctsConfiguration)
let whiteParticipant = MCTSPolicy(
  participantName: "white",
  predictor: MCTSRandomPredictor(boardSize: boardSize),
  configuration: mctsConfiguration)

try playOneGame(
  gameConfiguration: gameConfiguration,
  participants: [blackParticipant, whiteParticipant])



x/y 0 1 2 3 4 5 6 7 8
  0 . . . . . . . . .
  1 . . . . . . . . .
  2 . . . . . . . . .
  3 . . . . . . . . .
  4 . . . . . . . . .
  5 . . . . . . . . .
  6 . . . . . . . . .
  7 . . . . . . . . .
  8 . . . . . . . . .

-> Black
- Placing stone at: Position(x: 4, y: 3)

x/y 0 1 2 3 4 5 6 7 8
  0 . . . . . . . . .
  1 . . . . . . . . .
  2 . . . . . . . . .
  3 . . . . . . . . .
  4 . . . X . . . . .
  5 . . . . . . . . .
  6 . . . . . . . . .
  7 . . . . . . . . .
  8 . . . . . . . . .

-> White
- Placing stone at: Position(x: 8, y: 2)

x/y 0 1 2 3 4 5 6 7 8
  0 . . . . . . . . .
  1 . . . . . . . . .
  2 . . . . . . . . .
  3 . . . . . . . . .
  4 . . . X . . . . .
  5 . . . . . . . . .
  6 . . . . . . . . .
  7 . . . . . . . . .
  8 . . O . . . . . .

-> Black
- Placing stone at: Position(x: 5, y: 6)

x/y 0 1 2 3 4 5 6 7 8
  0 . . . . . . . . .
  1 . . . . . . . . .
  2 . . . . . . . . .
  3 . . . . . . . . .
  4 . . . X . . . . .
  5 . . . . . . X . .
  6 . . . . . . .

: ignored