/
test-agent.js
70 lines (60 loc) · 1.95 KB
/
test-agent.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
'use strict';
/* eslint-disable */
const _ = require('underscore');
const Encoder = require('../core/encoder');
const BaseAgent = require('../core/base-agent');
const tools = require('../../utils/tools');
// Species ids known to the encoder vocabulary.
const pokemon = [
  'tapukoko', 'landorus', 'landorustherian', 'toxapex', 'serperior', 'celesteela', 'medicham',
  'medichammega', 'ninetalesalola', 'magearna', 'zygarde', 'mawile', 'mawilemega', 'manaphy',
];
// Move ids known to the encoder vocabulary. The list is grouped in fours and
// some moves repeat across groups (earthquake, uturn, hiddenpowerice,
// leechseed), so it is deduplicated below before being used as a vocabulary.
const moves = [
  'fakeout', 'highjumpkick', 'zenheadbutt', 'icepunch',
  'heavyslam', 'protect', 'earthquake', 'leechseed',
  'thunderbolt', 'uturn', 'hiddenpowerice', 'taunt',
  'scald', 'toxicspikes', 'recover', 'toxic',
  'earthquake', 'uturn', 'stealthrock', 'hiddenpowerice',
  'leafstorm', 'leechseed', 'substitute', 'hiddenpowerfire',
];
// Vocabulary handed to the Encoder: each category is a sorted list of unique
// ids, so every id occupies exactly one vocabulary entry.
const vocab = {
  'pokemon': [...new Set(pokemon)].sort(),
  'move': [...new Set(moves)].sort(),
};
/**
 * A test agent that picks actions uniformly at random and exposes its battle
 * state encoded through an `Encoder` built over the module vocabulary with
 * format 'gen7ou'.
 */
class TestAgent extends BaseAgent {
  /**
   * @param {string} id - Agent identifier, forwarded to BaseAgent.
   */
  constructor(id) {
    super(id);
    this.encoder = new Encoder(vocab, 'gen7ou');
  }

  /**
   * Records the latest observation, then returns a random action.
   *
   * @override
   * @param {*} actionSpace - Collection of candidate actions, sampled with `_.sample`.
   * @param {*} observation - Forwarded to BaseAgent#updateState.
   * @param {*} reward - Unused by this agent.
   * @param {*} done - Unused by this agent.
   * @return {*} One randomly chosen element of `actionSpace`.
   */
  act(actionSpace, observation, reward, done) {
    super.updateState(observation);
    return _.sample(actionSpace);
  }

  /**
   * Encodes the current battle state with `this.encoder`.
   *
   * Features, in order: a constant bias of 1, my active Pokemon id, the
   * opponent's active Pokemon id, the ids of my team from the latest request,
   * and the active Pokemon's move ids padded with `null` up to four slots.
   *
   * @override
   * @return {*} The result of `this.encoder.encode` on the feature list.
   */
  get state() {
    const MOVE_SLOTS = 4;
    const current = super.state;
    const request = current.request;

    const myActive = current.mySide.active[0].id;
    const yourActive = current.yourSide.active[0].id;

    // NOTE(review): idents are presumably of the form "p1: Name"; the first
    // three characters (the side prefix) are dropped before normalizing.
    const teamIds = request.side.pokemon.map(
        (member) => tools.toId(member.ident.substring(3)));

    const moveIds = request.active ?
      request.active[0].moves.map((m) => tools.toId(m.move)) :
      [];
    // Pad with null up to MOVE_SLOTS entries. This matches the all-null
    // fallback used when there is no active request, so the feature width
    // does not shrink when a request offers fewer moves (presumably e.g. a
    // single locked-in move).
    while (moveIds.length < MOVE_SLOTS) {
      moveIds.push(null);
    }

    const features = [
      ['bias', 'number', 1],
      ['my_active', 'pokemon', myActive],
      ['your_active', 'pokemon', yourActive],
      ['team', 'pokemon', teamIds],
      ['moves', 'move', moveIds],
    ];
    return this.encoder.encode(features);
  }
}
module.exports = TestAgent;