In [1]:
%matplotlib inline
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy
from PIL import Image
from torch.distributions import Categorical

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T


# GPU: choose the CUDA tensor constructors when a CUDA device is available,
# otherwise fall back to the CPU tensor types.
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
# Generic alias for the float tensor type selected above.
Tensor = FloatTensor

### 环境

In [2]:
# Build CartPole-v0 and keep the object returned by `.unwrapped`; later cells
# read `env.state` directly from it.
env = gym.make('CartPole-v0').unwrapped
# Cell output: the size of the discrete action space (2 in the recorded run).
env.action_space.n

2

### P网络

In [3]:
def ini_net(md):
    """Re-initialise every ``nn.Linear`` layer found inside ``md``.

    Weights are drawn with Xavier-normal initialisation and biases from a
    standard normal distribution. The function walks ``md.modules()`` itself,
    so calling it once on the root module covers all nested layers.

    Args:
        md: an ``nn.Module``; modules that are not ``nn.Linear`` are left as is.
    """
    # Prefer the in-place initialisers (trailing underscore). The names without
    # the underscore are deprecated aliases, kept only as a fallback for old
    # PyTorch releases that do not provide the in-place variants.
    init_weight = getattr(nn.init, "xavier_normal_", None) or nn.init.xavier_normal
    init_bias = getattr(nn.init, "normal_", None) or nn.init.normal

    for m in md.modules():
        if isinstance(m, nn.Linear):
            init_weight(m.weight.data)
            # Layers built with bias=False have bias set to None.
            if m.bias is not None:
                init_bias(m.bias.data)

class PN(nn.Module):
    """Policy network: a two-hidden-layer MLP that outputs action probabilities.

    Args:
        n_states: size of the input state vector (default 4).
        n_hidden: width of both hidden layers (default 200).
        n_actions: number of output probabilities (default 2).

    The defaults reproduce the original fixed 4 -> 200 -> 200 -> 2 layout, and
    the layers stay in ``self.MLP`` in the same order, so parameter names such
    as ``MLP.0.weight`` are unchanged.
    """

    def __init__(self, n_states=4, n_hidden=200, n_actions=2):
        super(PN, self).__init__()
        self.MLP = nn.Sequential(
            nn.Linear(n_states, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_actions),
            # Softmax over dim 1, so each row of a (batch, n_actions) output sums to 1.
            nn.Softmax(1)
        )

    def forward(self, x):
        """Map a (batch, n_states) input to (batch, n_actions) probabilities."""
        return self.MLP(x)
    
# Build the policy network and move it to the GPU only when one is available;
# an unconditional .cuda() raises on CPU-only machines.
pnet = PN()
if use_cuda:
    pnet = pnet.cuda()
# ini_net already walks every sub-module, so one direct call is enough
# (Module.apply would call it once per sub-module and re-initialise repeatedly).
ini_net(pnet)
# Smoke test: one random 1x4 state through the network, on the same device.
testx = Variable(torch.randn(1, 4))
if use_cuda:
    testx = testx.cuda()
pnet(testx)

Variable containing:
 0.9707  0.0293
[torch.cuda.FloatTensor of size 1x2 (GPU 0)]

### 动作策略

In [4]:
def select_action(state):
    """Sample an action from the policy network's output distribution.

    Args:
        state: input passed unchanged to ``pnet``.

    Returns:
        A pair ``(action, log_prob)``: the first element of the sampled action
        converted to a NumPy value, and the log-probability of the sampled
        action under the categorical distribution built from ``pnet(state)``.
    """
    action_dist = Categorical(pnet(state))
    sampled = action_dist.sample()
    # Keep the log-probability attached to the graph; only the action leaves it.
    sampled_log_prob = action_dist.log_prob(sampled)
    chosen = sampled.data.cpu().numpy()[0]
    return chosen, sampled_log_prob

# Smoke test: sample one action for a random 1x4 state. The input is moved to
# the GPU only when CUDA is available, so the cell also runs on CPU-only machines.
testx = Variable(torch.randn(1, 4))
if use_cuda:
    testx = testx.cuda()
print(select_action(testx))

(0, Variable containing:
1.00000e-02 *
 -2.5634
[torch.cuda.FloatTensor of size 1 (GPU 0)]
)


### Episode 回合片段

In [5]:
class Episode:
    """Buffer for one rollout: per-step log-probabilities and rewards.

    Also computes the normalised discounted returns and the REINFORCE loss
    for the recorded steps.

    Args:
        gamma: discount factor used when accumulating returns. Defaults to
            0.99, the value that used to be hard-coded in ``_reward``.
    """

    def __init__(self, gamma=0.99):
        self.log_probs = []
        self.rewards = []
        self.R = 0
        self.gamma = gamma

    def __len__(self):
        """Return the number of recorded steps."""
        # The previous implementation read self.frames, which is never defined,
        # so len(episode) always raised AttributeError.
        return len(self.rewards)

    def save_log_probs(self, log_prob):
        """Append the log-probability of the action taken at this step."""
        self.log_probs.append(log_prob)

    def save_rewards(self, r):
        """Append the reward received at this step."""
        self.rewards.append(r)

    def _reward(self):
        """Return the discounted returns, standardised to zero mean / unit std.

        Returns:
            A ``FloatTensor`` with one entry per recorded reward (empty when
            nothing was recorded).
        """
        n = len(self.rewards)
        ds = np.zeros(n)
        if n == 0:
            # Nothing to normalise; mean/std of an empty array would be NaN.
            return FloatTensor(ds)
        running_add = 0
        # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}.
        for t in reversed(range(n)):
            running_add = running_add * self.gamma + self.rewards[t]
            ds[t] = running_add
        # The small epsilon keeps the division finite when all returns are equal
        # (e.g. a one-step episode), where the std is exactly 0.
        eps = np.finfo(np.float32).eps
        ds = (ds - np.mean(ds)) / (np.std(ds) + eps)
        return FloatTensor(ds)

    def optimize(self):
        """Return the episode loss: the sum over steps of ``-log_prob * return``.

        Steps are paired with ``zip``, so any surplus entries in the longer of
        the two lists are ignored. With no recorded steps the result is the
        plain integer 0.
        """
        vt = self._reward()
        loss = 0
        for log_prob, ret in zip(self.log_probs, vt):
            loss += -log_prob * ret  # loss contribution of this step
        return loss

#ept = Episode(net, optimizer, gamma=1)

### 训练

In [6]:
# Optional TensorBoard logging (disabled).
#from tensorboardX import SummaryWriter
#writer = SummaryWriter()
# NOTE(review): GAMMA is never read in this cell; returns are discounted inside
# Episode (0.99 there), not with this value.
GAMMA = 1
# Number of training episodes to run.
num_episodes = 5000
# Adam over the policy network's parameters.
optimizer = optim.Adam(pnet.parameters(),lr=1e-4)

def get_state():
    """Return the environment's current state as a single-row float Variable.

    Reads ``env.state`` from the module-level environment and reshapes it to
    one row via ``view(1, -1)``.
    """
    # FloatTensor is already the CUDA tensor type when CUDA is available, so an
    # explicit .cuda() is redundant there and raises on CPU-only machines.
    state = FloatTensor(env.state).view(1, -1)
    return Variable(state)

try:
    for i_episode in range(num_episodes):
        # Reset the environment and start a fresh rollout buffer.
        env.reset()
        state = get_state()
        episode = Episode()
        for t in count():  # runs until the environment reports done
            env.render()
            action, log_prob = select_action(state)  # sample an action from the policy
            #print('a is {}'.format(action))
            _, reward, done, _ = env.step(action)  # apply it; keep the reward and the done flag

            episode.save_log_probs(log_prob)
            episode.save_rewards(reward)

            state = get_state()  # re-read the state after the step

            if done:
                #writer.add_scalar('t', t, i_episode)
                print(i_episode, t)
                break
        #print('len is {}'.format(len(episode)))
        # One gradient update per finished episode.
        optimizer.zero_grad()
        loss = episode.optimize()
        loss.backward()
        optimizer.step()
        print("loss:",loss.data.cpu().numpy()[0])
    print('Complete')
finally:
    # Release the render window even when the run is stopped early
    # (e.g. by a KeyboardInterrupt from the notebook's interrupt button).
    env.close()
#writer.close()

0 10
loss: 3.34932
1 9
loss: 0.0117204
2 11
loss: 3.32642
3 8
loss: 0.00790752
4 8
loss: 0.00607251
5 9
loss: 0.00299206
6 9
loss: -3.56032
7 8
loss: -0.000895217
8 9
loss: -0.00109181
9 7
loss: -0.000492826
10 10
loss: 4.63121
11 9
loss: -0.00264166
12 12
loss: 3.80385
13 12
loss: -0.0605809
14 9
loss: -0.0052804
15 12
loss: -2.11802
16 10
loss: -5.86494
17 8
loss: 0.000544339
18 9
loss: 0.00330254
19 8
loss: 0.00593133
20 8
loss: 0.00880396
21 9
loss: 0.00941009
22 8
loss: 0.00916514
23 8
loss: 0.00913035
24 12
loss: 4.41417
25 9
loss: 0.0104982
26 9
loss: -3.41678
27 8
loss: 0.010716
28 7
loss: 0.00880057
29 8
loss: 0.0114639
30 7
loss: 0.00840963
31 10
loss: 0.0549886
32 12
loss: 4.48532
33 9
loss: 2.40851
34 10
loss: 4.25967
35 10
loss: -4.3716
36 10
loss: -1.65244
37 8
loss: 0.0119274
38 13
loss: 5.82299
39 8
loss: 0.0138388
40 10
loss: 2.41215
41 10
loss: -6.3415
42 9
loss: -2.17397
43 8
loss: -1.93425
44 8
loss: 0.0216665
45 8
loss: -4.08158
46 15
loss: 6.43223
47 9
loss: 0.032

375 19
loss: 0.123737
376 36
loss: -0.676694
377 27
loss: 0.20659
378 24
loss: -0.110763
379 23
loss: -0.106606
380 29
loss: 0.298637
381 17
loss: -0.00559992
382 16
loss: -1.14071
383 16
loss: -0.367042
384 14
loss: 0.0298584
385 42
loss: 0.120656
386 25
loss: -1.1625
387 20
loss: -0.386058
388 17
loss: 0.311169
389 61
loss: -0.864615
390 25
loss: -1.59321
391 45
loss: 0.237398
392 51
loss: -1.29214
393 15
loss: -0.873669
394 36
loss: 0.105461
395 31
loss: -0.99533
396 42
loss: -0.439423
397 55
loss: -3.49064
398 15
loss: 0.224134
399 13
loss: 0.195295
400 18
loss: -0.0528025
401 31
loss: 0.570483
402 18
loss: -0.363383
403 13
loss: -0.0921918
404 9
loss: 0.372825
405 17
loss: -0.752997
406 23
loss: -0.609733
407 73
loss: -2.57664
408 24
loss: -0.463534
409 24
loss: 0.548029
410 34
loss: 0.25205
411 20
loss: 0.20724
412 11
loss: -1.21922
413 35
loss: -0.0535647
414 22
loss: 0.126202
415 22
loss: -1.53382
416 21
loss: -0.711646
417 42
loss: 0.0680338
418 39
loss: -0.220927
419 42
loss:

loss: 0.132485
744 13
loss: -1.10199
745 40
loss: 0.328551
746 55
loss: 1.186
747 17
loss: -0.328556
748 15
loss: 0.535255
749 15
loss: 0.716572
750 31
loss: -1.77313
751 32
loss: -1.24366
752 19
loss: 0.831642
753 12
loss: 1.2531
754 36
loss: 1.17663
755 42
loss: 0.490301
756 81
loss: 1.24821
757 25
loss: 0.155611
758 42
loss: 0.976464
759 11
loss: 2.03065
760 34
loss: -1.05061
761 22
loss: 0.996191
762 15
loss: 1.14211
763 73
loss: 0.398825
764 63
loss: 1.69398
765 44
loss: 0.306437
766 17
loss: -2.28092
767 45
loss: -2.24827
768 54
loss: 0.62125
769 51
loss: 0.920499
770 19
loss: 0.556341
771 19
loss: -1.59005
772 16
loss: 1.22334
773 71
loss: 0.765133
774 43
loss: -0.562337
775 22
loss: 1.41427
776 57
loss: 1.18048
777 49
loss: -0.505281
778 46
loss: 1.47848
779 27
loss: -1.79442
780 35
loss: 0.0628883
781 59
loss: 0.915944
782 19
loss: -2.67925
783 67
loss: 0.891003
784 23
loss: -0.559319
785 36
loss: -0.957521
786 29
loss: 0.817585
787 33
loss: 1.27933
788 59
loss: 0.332405
789 3

1112 158
loss: 1.35882
1113 52
loss: -0.370021
1114 56
loss: -0.613029
1115 60
loss: -0.78006
1116 83
loss: 0.533509
1117 57
loss: -1.00909
1118 26
loss: 0.978626
1119 31
loss: 0.318589
1120 15
loss: -0.304965
1121 19
loss: 0.225489
1122 44
loss: 1.44364
1123 19
loss: -0.155588
1124 17
loss: -0.528999
1125 23
loss: -0.054404
1126 14
loss: -0.398148
1127 22
loss: 0.669807
1128 41
loss: -0.462634
1129 20
loss: -0.0346994
1130 34
loss: 1.00935
1131 15
loss: -0.962135
1132 14
loss: -0.0449812
1133 100
loss: 5.19362
1134 40
loss: -1.21161
1135 10
loss: -5.87339
1136 36
loss: -0.511716
1137 15
loss: 1.01936
1138 42
loss: 0.298978
1139 20
loss: -0.187616
1140 15
loss: 1.32945
1141 17
loss: -0.22102
1142 30
loss: -0.561365
1143 29
loss: 1.19858
1144 38
loss: -0.0701108
1145 27
loss: 0.904449
1146 21
loss: -1.18103
1147 18
loss: -0.302352
1148 28
loss: -0.309958
1149 17
loss: -0.551721
1150 22
loss: -1.27179
1151 33
loss: -1.0387
1152 27
loss: -1.86114
1153 33
loss: -0.49758
1154 12
loss: 0.901

1469 103
loss: 0.325919
1470 16
loss: 1.93301
1471 139
loss: 2.59413
1472 49
loss: -2.40054
1473 65
loss: 0.568002
1474 29
loss: 0.492834
1475 74
loss: 2.50918
1476 56
loss: -2.31067
1477 48
loss: 0.483689
1478 56
loss: 0.245313
1479 34
loss: 1.90444
1480 41
loss: -5.5584
1481 49
loss: 0.329624
1482 129
loss: 2.24038
1483 107
loss: -4.36176
1484 115
loss: 2.66408
1485 81
loss: -2.02931
1486 116
loss: -0.190109
1487 41
loss: 0.889915
1488 46
loss: 1.60437
1489 179
loss: 1.57254
1490 115
loss: -3.3461
1491 103
loss: -0.215737
1492 140
loss: -0.351922
1493 69
loss: -6.70617
1494 161
loss: -1.15761
1495 67
loss: -2.45553
1496 171
loss: 0.856382
1497 23
loss: -1.76875
1498 36
loss: 0.100365
1499 98
loss: -1.67851
1500 104
loss: -1.89688
1501 120
loss: 1.46012
1502 139
loss: -2.71645
1503 141
loss: 1.41709
1504 83
loss: -0.849988
1505 90
loss: -0.324996
1506 154
loss: -1.96582
1507 127
loss: -1.52432
1508 47
loss: 0.311133
1509 107
loss: -1.63393
1510 70
loss: -3.64023
1511 87
loss: -0.02263

KeyboardInterrupt: 

In [None]:
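### The main problem still lies in how the loss is computed and how the return r steers the update
### Improve the network's ability to generalise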