In [4]:
from torch import nn
import torch.nn.functional as F
import torch
from train_attack_model_env import attack_env
from attack_train_func import rewardFunction, normalizeData
import random

from PPO import PPO



class PolicyNetwork(nn.Module):
    """Two-layer MLP policy head that outputs log-probabilities over actions.

    Args:
        state_size: Number of input features per state vector.
        action_size: Number of discrete actions (size of the output layer).
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 128
        # Attribute names fc1/fc3 are kept: they are the state-dict keys and
        # are read by name elsewhere in the notebook.
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, x):
        """Return log-softmax action scores, normalized over the last dim."""
        hidden = self.fc1(x).relu()
        return self.fc3(hidden).log_softmax(dim=-1)

        

        
class ValueNetwork(nn.Module):
    """Two-layer MLP critic that maps a state vector to one scalar value.

    Args:
        state_size: Number of input features per state vector.
    """

    def __init__(self, state_size):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(state_size, hidden_units)
        # Single output unit: the estimated value of the state.
        self.fc3 = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return the raw (unbounded) value estimate, last dim of size 1."""
        return self.fc3(F.relu(self.fc1(x)))
    

class MyPPO(PPO):
    """PPO agent wired to a multi-player attack environment.

    One shared policy picks an action for every player ("1P" .. "6P") each
    environment step. The base ``PPO`` class supplies ``PolicyNetwork``,
    ``ValueNetwork``, ``getAction``, ``ExperienceHistory`` and ``train``.

    Args:
        env: Environment exposing ``reset``, ``update``, ``render``,
            ``not_done``, ``lives`` and ``actionSpace``.
        policyNetwork: Network forwarded to ``PPO.__init__``.
        valueNetwork: Network forwarded to ``PPO.__init__``.
    """

    # Players are indexed 1..NUM_PLAYERS; slot 0 of per-player lists is unused.
    NUM_PLAYERS = 6

    def __init__(self, env, policyNetwork, valueNetwork):
        super().__init__(policyNetwork, valueNetwork)
        self.env = env

    def show(self):
        """Play one rendered episode with the current policy on the CPU.

        Every player receives an action each step; the loop ends once
        ``env.not_done()`` turns false. Nothing is recorded or trained.
        """
        device = torch.device("cpu")
        self.PolicyNetwork.to(device, dtype=torch.float32)
        self.ValueNetwork.to(device, dtype=torch.float32)
        data = self.env.reset()
        while True:
            self.env.render()
            updateData = {}

            for j in range(1, self.NUM_PLAYERS + 1):
                playerID = str(j) + "P"

                # Each entry is (state, liveLoss, scoreUp); only state is used.
                state, _, _ = data[playerID]
                action = self.getAction(normalizeData(state))
                updateData[playerID] = [self.env.actionSpace[action]]
            data = self.env.update(updateData)

            if not self.env.not_done():
                break

    def learn(self, timeStep=10000, dataNum=4096, lr=0.003, episode=0.2, epoch=10, batchSize=256):
        """Alternate between collecting experience and running PPO updates.

        For each of ``timeStep`` iterations, whole episodes are played until
        ``ExperienceHistory['oldstate']`` holds at least ``dataNum`` entries,
        then ``self.train`` is called once.

        Args:
            timeStep: Number of collect-then-train iterations.
            dataNum: Minimum number of stored transitions before training.
            lr: Learning rate forwarded to ``self.train``.
            episode: Forwarded unchanged to ``self.train(episode=...)``.
                NOTE(review): presumably the PPO clip range (epsilon);
                confirm against ``PPO.train``.
            epoch: Forwarded to ``self.train`` as ``epochs``.
            batchSize: Forwarded to ``self.train`` as ``batch_size``.
        """
        print("start learning")
        player_range = range(1, self.NUM_PLAYERS + 1)
        for i in range(timeStep):
            while len(self.ExperienceHistory['oldstate']) < dataNum:
                data = self.env.reset()
                die = [False] * (self.NUM_PLAYERS + 1)
                agentRewards = [0] * (self.NUM_PLAYERS + 1)
                while True:
                    updateData = {}
                    oldStates = {}
                    actions = {}

                    # Phase 1: choose an action for every player still alive.
                    for j in player_range:
                        playerID = str(j) + "P"
                        if self.env.lives[j] == 0:
                            die[j] = True
                            continue
                        state, _, _ = data[playerID]
                        oldStates[playerID] = state
                        action = self.getAction(normalizeData(state))
                        updateData[playerID] = [self.env.actionSpace[action]]
                        actions[playerID] = action
                    data = self.env.update(updateData)

                    # Phase 2: record one transition per player that acted.
                    for j in player_range:
                        if die[j]:
                            continue
                        playerID = str(j) + "P"
                        old_state = oldStates[playerID]
                        player_action = actions[playerID]
                        new_state, liveLoss, scoreUp = data[playerID]

                        reward = rewardFunction(
                            old_state,
                            self.env.actionSpace[player_action],
                            scoreUp,
                            liveLoss,
                        )

                        done = int(self.env.lives[j] == 0)

                        # Store THIS player's own state and action. Using the
                        # leftover `state`/`action` variables from phase 1
                        # would record the last acting player's data for
                        # every player.
                        self.ExperienceHistory['oldstate'].append(normalizeData(old_state))
                        self.ExperienceHistory['state'].append(normalizeData(new_state))
                        self.ExperienceHistory['action'].append(player_action)
                        self.ExperienceHistory['reward'].append(reward)
                        self.ExperienceHistory['done'].append(done)

                        agentRewards[j] += reward

                    if not self.env.not_done():
                        # Log per-player episode returns on every 10th iteration.
                        if i % 10 == 0:
                            print(f"time step:{i + 1}", end=" ")
                            for j in player_range:
                                print(f"player{j} reward: {agentRewards[j]}", end=",")
                            print()
                        break
            self.train(epochs=epoch, lr=lr, episode=episode, batch_size=batchSize)





In [5]:
# Build the environment and both networks, then bind them into the agent.
# Creation order (env, policy, value) is kept so weight initialization draws
# from the RNG in the same sequence.
env = attack_env(FPS=300)
policyNetwork = PolicyNetwork(state_size=18, action_size=7)
valueNetwork = ValueNetwork(state_size=18)
agent = MyPPO(env=env, policyNetwork=policyNetwork, valueNetwork=valueNetwork)


In [3]:
# Start training. Only lr differs from learn()'s defaults; timeStep and
# dataNum restate them explicitly.
agent.learn(
    timeStep=10000,
    dataNum=4096,
    lr=0.01,
)

start learning
time step:1 player1 reward: 17.454000000000086,player2 reward: -0.11999999999999965,player3 reward: 30.041000000000007,player4 reward: 58.02199999999999,player5 reward: 16.2810000000001,player6 reward: 10.550999999999966,


  with torch.autograd.detect_anomaly():


KeyboardInterrupt: 

In [None]:
# Inspect the policy's learned weights. PolicyNetwork defines only fc1 and
# fc3, so there is no fc2 to print (accessing it raises AttributeError).
print(agent.PolicyNetwork.fc1.weight)
print(agent.PolicyNetwork.fc3.weight)

In [6]:
import pygame

# Watch one episode with the current policy. The finally-clause makes sure
# pygame is shut down even if show() raises or the kernel is interrupted.
try:
    agent.show()
finally:
    pygame.quit()

tensor([0.1513, 0.1245, 0.1532, 0.1520, 0.1726, 0.1172, 0.1291],
       grad_fn=<ExpBackward0>)
tensor([0.1340, 0.1340, 0.1731, 0.1576, 0.1728, 0.0978, 0.1308],
       grad_fn=<ExpBackward0>)
tensor([0.1539, 0.1351, 0.1566, 0.1454, 0.1846, 0.0962, 0.1282],
       grad_fn=<ExpBackward0>)
tensor([0.1344, 0.1493, 0.1722, 0.1286, 0.1305, 0.1491, 0.1358],
       grad_fn=<ExpBackward0>)
tensor([0.1402, 0.1539, 0.1730, 0.1353, 0.1404, 0.1267, 0.1305],
       grad_fn=<ExpBackward0>)
tensor([0.1320, 0.1521, 0.1673, 0.1415, 0.1480, 0.1356, 0.1236],
       grad_fn=<ExpBackward0>)
tensor([0.1508, 0.1244, 0.1536, 0.1519, 0.1730, 0.1172, 0.1290],
       grad_fn=<ExpBackward0>)
tensor([0.1340, 0.1340, 0.1731, 0.1576, 0.1728, 0.0978, 0.1308],
       grad_fn=<ExpBackward0>)
tensor([0.1605, 0.1272, 0.1611, 0.1362, 0.1885, 0.1028, 0.1236],
       grad_fn=<ExpBackward0>)
tensor([0.1364, 0.1481, 0.1728, 0.1289, 0.1300, 0.1490, 0.1349],
       grad_fn=<ExpBackward0>)
tensor([0.1398, 0.1538, 0.1721, 0.1356, 