# runner.py
import numpy as np
import torch as tr
import torch.nn.functional as F

from utils import node2obs, softmax, get_theta

device = tr.device("cuda:0" if tr.cuda.is_available() else "cpu")

class Runner(object):
    """Runs a sequence of tasks, training a policy on each and maintaining
    a shared default policy pi0 across tasks."""

    def __init__(self, T_dist, A):
        self.T_dist = T_dist  # task distribution; .sample() draws a task
        self.A = A            # training algorithm applied to each task
        self.nS = 24          # number of states
        # default policy pi0: uniform over the two actions in every state
        self.pi0 = 0.5 * tr.ones([self.nS, 2], requires_grad=False)

    def beta(self, k):
        # annealing temperature for the pi0 update; decays with task index k
        return np.exp(-k / 10.)

    def run(self, M_tasks, args, task_list=None):
        return_hists, pi0_hist = [], []
        T_dist = self.T_dist
        pi0 = self.pi0
        # for regularizers other than the ones below, pi0 is a learnable
        # parameter vector rather than a fixed action distribution
        if args['reg'] not in ['TV', 'log-barrier', 'maxent']:
            pi0 = get_theta([15, 1])
        for m in range(1, M_tasks + 1):
            # draw a task
            T_m = T_dist.sample() if task_list is None else task_list[m - 1]
            args['decisions'] = T_m.decisions
            # randomly initialize parameters
            n_s = len(args['decisions'])
            args['pi0'] = pi0
            policy, value_fn = get_theta([n_s, 1]), get_theta([n_s, 1])
            # train the policy on task T_m
            policy, pi0, returns = self.A(policy, value_fn, T_m, args,
                                          task_id=m, display_iter=0)
            # update pi0: move pi0[s] toward the current greedy action
            # distribution (1/m step), then sharpen it with a
            # temperature-annealed softmax
            if args['reg'] == 'TV':
                for s in T_m.decisions:
                    state = node2obs(T_m, s)  # encode the state as a single tensor
                    pi = tr.dot(softmax(policy).flatten(), state)
                    pi_star = F.one_hot(tr.argmax(tr.tensor([1 - pi, pi])), num_classes=2)
                    pi0[s] += (pi_star - pi0[s]) / m
                    pi0[s] = softmax(pi0[s] / self.beta(m))
            # save returns and a snapshot of pi0 (clone so later in-place
            # updates do not mutate the saved history)
            return_hists.append(returns)
            pi0_hist.append(pi0.clone())
        self.pi = pi0
        return return_hists, pi0_hist
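

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of how Runner might be driven. _ToyTask, _ToyTaskDist,
# and _toy_algorithm are hypothetical stand-ins for the real task
# distribution and training routine defined elsewhere in the repository;
# the run still assumes utils.get_theta returns a tensor parameter of the
# requested shape, as the rest of this file does.
if __name__ == "__main__":

    class _ToyTask(object):
        # a task only needs a .decisions attribute for Runner.run
        def __init__(self, decisions):
            self.decisions = decisions

    class _ToyTaskDist(object):
        # mirrors T_dist: .sample() returns a task
        def sample(self):
            return _ToyTask(decisions=list(range(4)))

    def _toy_algorithm(policy, value_fn, T_m, args, task_id=0, display_iter=0):
        # stand-in for the real trainer: return the (untrained) policy,
        # the current pi0, and a dummy return history for this task
        return policy, args['pi0'], [0.0]

    runner = Runner(_ToyTaskDist(), _toy_algorithm)
    # 'maxent' keeps the uniform default policy and skips the TV-specific update
    return_hists, pi0_hist = runner.run(M_tasks=3, args={'reg': 'maxent'})
    print(len(return_hists), pi0_hist[-1].shape)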