-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
133 lines (115 loc) · 3.45 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main
import (
"bufio"
"flag"
"fmt"
"math/rand"
"net/http"
_ "net/http/pprof"
"os"
"strconv"
"strings"
"github.com/golang/glog"
"github.com/timpalpant/go-cfr"
"github.com/timpalpant/go-cfr/mcts"
"github.com/timpalpant/go-cfr/sampling"
"github.com/timpalpant/alphacats"
"github.com/timpalpant/alphacats/cards"
"github.com/timpalpant/alphacats/gamestate"
"github.com/timpalpant/alphacats/model"
)
// stdin is a buffered reader over the process's standard input. prompt reads
// the player's typed selections from it one line at a time.
var stdin = bufio.NewReader(os.Stdin)
// main parses the command-line flags, loads the opponent and recommender
// policies from disk, and then plays interactive games in an endless loop.
//
// Flags:
//   - model: path of the policy the human plays against.
//   - recommender_model: path of the policy used to suggest moves.
//   - sampling.seed: seed passed to math/rand.
func main() {
	// Named *Path rather than "model" so the locals do not shadow the imported
	// model package.
	opponentPath := flag.String("model", "models/player_0.model", "Model to play against")
	recommenderPath := flag.String("recommender_model", "models/player_1.model", "Model to recommend moves")
	seed := flag.Int64("sampling.seed", 123, "Random seed")
	flag.Parse()

	rand.Seed(*seed)

	// Serve the default mux (which the blank net/http/pprof import registers
	// its handlers on) in the background. ListenAndServe only returns on
	// failure, e.g. when the port is already taken; report that instead of
	// dropping it, but keep the game running since the endpoint is optional.
	go func() {
		if err := http.ListenAndServe("localhost:4123", nil); err != nil {
			fmt.Fprintf(os.Stderr, "debug HTTP server on localhost:4123 stopped: %v\n", err)
		}
	}()

	deck := cards.CoreDeck.AsSlice()
	const cardsPerPlayer = 4
	opponent := loadPolicy(*opponentPath)
	recommender := loadPolicy(*recommenderPath)

	// Each iteration samples fresh policies and a fresh random deal, then
	// plays one full game. The loop never exits on its own.
	for {
		opponentPolicy := opponent.SamplePolicy()
		recommenderPolicy := recommender.SamplePolicy()
		deal := alphacats.NewRandomDeal(deck, cardsPerPlayer)
		playGame(opponentPolicy, recommenderPolicy, deal)
	}
}
// loadPolicy opens the file at modelPath and decodes an MCTSPSRO policy from
// it via model.LoadMCTSPSRO. If the file cannot be opened or decoded, the
// failure is logged with glog.Fatalf.
func loadPolicy(modelPath string) *model.MCTSPSRO {
	file, openErr := os.Open(modelPath)
	if openErr != nil {
		glog.Fatalf("Unable to load policy: %v", openErr)
	}
	defer file.Close()

	// Decode through a buffered reader wrapped around the open file.
	loaded, loadErr := model.LoadMCTSPSRO(bufio.NewReader(file))
	if loadErr != nil {
		glog.Fatalf("Unable to load policy: %v", loadErr)
	}
	return loaded
}
// playGame plays one game from the given deal until a terminal node is
// reached, logging every step with glog.
//
// Node handling:
//   - Chance nodes are resolved by sampling a child.
//   - Nodes where game.Player() == 1 belong to the human: the hand and the
//     available actions are listed together with the recommender policy's
//     probability for each, and the human is prompted for an index.
//   - All other nodes are played by sampling an action from the opponent
//     policy. The logged action has its private fields cleared via
//     hidePrivateInfo; the unredacted action is only logged at verbosity 4.
//
// After the game ends, the outcome and the full action history are logged.
func playGame(opponent, recommender mcts.Policy, deal alphacats.Deal) {
	var game cfr.GameTreeNode = alphacats.NewGame(deal.DrawPile, deal.P0Deal, deal.P1Deal)
	for game.Type() != cfr.TerminalNodeType {
		switch {
		case game.Type() == cfr.ChanceNodeType:
			var p float64
			game, p = game.SampleChild()
			glog.Infof("[chance] Sampled child node with probability %v", p)

		case game.Player() == 1:
			is := game.InfoSet(game.Player()).(*alphacats.AbstractedInfoSet)
			glog.Infof("[player] Your turn. %d cards remaining in draw pile.",
				game.(*alphacats.GameNode).GetDrawPile().Len())
			p := recommender.GetPolicy(game)
			glog.Infof("[player] Hand: %v, Choices:", is.Hand)
			for i, action := range is.AvailableActions {
				glog.Infof("%d: %v (recommender P: %.3f)", i, action, p[i])
			}

			// prompt only guarantees an integer. Keep asking until the index
			// names one of the choices listed above, so that GetChild is never
			// handed an index that was not offered to the player.
			numChoices := len(is.AvailableActions)
			selected := prompt("Which action? ")
			for selected < 0 || selected >= numChoices {
				glog.Errorf("Invalid selection: %d (expected 0-%d)", selected, numChoices-1)
				selected = prompt("Which action? ")
			}

			game = game.GetChild(selected)
			lastAction := game.(*alphacats.GameNode).LastAction()
			glog.Infof("[player] Chose to %v", lastAction)

		default:
			p := opponent.GetPolicy(game)
			selected := sampling.SampleOne(p, rand.Float32())
			game = game.GetChild(selected)
			lastAction := game.(*alphacats.GameNode).LastAction()
			glog.Infof("[strategy] Chose to %v with probability %v: %v",
				hidePrivateInfo(lastAction), p[selected], p)
			glog.V(4).Infof("[strategy] Action result was: %v", lastAction)
		}
	}

	glog.Info("GAME OVER")
	// The terminal node's Player() is taken to be the winner here.
	// NOTE(review): confirm against alphacats.GameNode's terminal semantics.
	if game.Player() == 1 {
		glog.Info("You win!")
	} else {
		glog.Info("Computer wins!")
	}

	glog.Info("Game history:")
	h := game.(*alphacats.GameNode).GetHistory()
	for i, action := range h.AsSlice() {
		glog.Infof("%d: %v", i, action)
	}
}
// prompt prints msg and reads one line from stdin, repeating until the line
// parses as a base-10 integer, which is returned. The value is not
// range-checked; callers must validate it against the choices they offered.
//
// A read error from stdin (including end of input) panics, because there is
// no way to obtain a selection after that.
func prompt(msg string) int {
	for {
		fmt.Print(msg)
		line, err := stdin.ReadString('\n')
		if err != nil {
			panic(err)
		}

		// Trim all surrounding whitespace, not just "\n": a CRLF line ending
		// or stray blanks around the number would otherwise make Atoi reject
		// an otherwise valid selection every time.
		line = strings.TrimSpace(line)

		choice, err := strconv.Atoi(line)
		if err != nil {
			glog.Errorf("Invalid selection: %v", line)
			continue
		}
		return choice
	}
}
// hidePrivateInfo returns a copy of a with PositionInDrawPile and CardsSeen
// reset to their zero values. playGame uses it when logging the opponent's
// chosen action. The argument is received by value, so the caller's action is
// left unchanged.
func hidePrivateInfo(a gamestate.Action) gamestate.Action {
	var noCards [3]cards.Card
	a.CardsSeen = noCards
	a.PositionInDrawPile = 0
	return a
}