In [1]:
from torch import nn
import torch.nn.functional as F
import torch
from train_attack_model_env import attack_env
from attack_train_func import rewardFunction, normalizeData
import random

from PPO import PPO



class PolicyNetwork(nn.Module):
    """Two-layer MLP policy head that outputs per-action log-probabilities.

    Args:
        state_size: Number of input features per observation.
        action_size: Number of discrete actions to score.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        # One hidden layer of 128 units followed by the action projection.
        self.fc1 = nn.Linear(state_size, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        """Return log-softmax over the last dimension of the action logits."""
        hidden = F.relu(self.fc1(x))
        return F.log_softmax(self.fc3(hidden), dim=-1)

        

        
class ValueNetwork(nn.Module):
    """Two-layer MLP critic that maps an observation to a scalar value.

    Args:
        state_size: Number of input features per observation.
    """

    def __init__(self, state_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 128)
        # Single output unit: the estimated value of the state.
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        """Return the value estimate; the last dimension of the result is 1."""
        hidden = F.relu(self.fc1(x))
        return self.fc3(hidden)
    

class MyPPO(PPO):
    """PPO agent that drives every player of ``env`` with one shared policy.

    The ``PPO`` base class is not visible in this notebook; this subclass
    relies on it for ``getAction``, ``train``, ``ExperienceHistory``,
    ``PolicyNetwork`` and ``ValueNetwork``.
    """

    # Players are addressed as "1P" .. "6P" and indexed 1..6 in ``env.lives``;
    # index 0 of the per-player lists below is intentionally unused.
    NUM_PLAYERS = 6

    def __init__(self, env, policyNetwork, valueNetwork):
        """Store the environment and hand both networks to the PPO base class."""
        super().__init__(policyNetwork, valueNetwork)
        self.env = env

    def show(self):
        """Play one rendered game on the CPU using the current policy.

        Every player acts on each frame until ``env.not_done()`` turns false.
        No experience is recorded.
        """
        device = torch.device("cpu")
        self.PolicyNetwork.to(device, dtype=torch.float32)
        self.ValueNetwork.to(device, dtype=torch.float32)
        data = self.env.reset()
        while True:
            self.env.render()
            updateData = {}

            for j in range(1, self.NUM_PLAYERS + 1):
                playerID = str(j) + "P"

                state, _, _ = data[playerID]
                action = self.getAction(normalizeData(state))
                updateData[playerID] = [self.env.actionSpace[action]]
            data = self.env.update(updateData)

            if not self.env.not_done():
                break

    def learn(self, timeStep=10000, dataNum=4096, lr=0.003, episode=0.2, epoch=10, batchSize=256):
        """Alternate between collecting experience and running PPO updates.

        Args:
            timeStep: Number of collect-then-train iterations.
            dataNum: Minimum number of transitions to gather before each update.
            lr: Learning rate forwarded to ``self.train``.
            episode: Forwarded unchanged to ``self.train(episode=...)``
                (presumably the PPO clipping epsilon; verify against ``PPO.train``).
            epoch: Number of optimisation epochs per update.
            batchSize: Mini-batch size forwarded to ``self.train``.
        """
        print("start learning")
        player_indices = range(1, self.NUM_PLAYERS + 1)

        for i in range(timeStep):
            # Play whole games until the buffer holds at least ``dataNum`` transitions.
            while len(self.ExperienceHistory['oldstate']) < dataNum:
                data = self.env.reset()
                die = [False] * (self.NUM_PLAYERS + 1)
                agentRewards = [0] * (self.NUM_PLAYERS + 1)

                while True:
                    updateData = {}
                    oldStates = {}
                    observations = {}
                    actions = {}

                    # Phase 1: pick an action for every player that is still alive.
                    for j in player_indices:
                        playerID = str(j) + "P"
                        if self.env.lives[j] == 0:
                            die[j] = True
                            continue
                        state, _, _ = data[playerID]
                        oldStates[playerID] = state
                        observation = normalizeData(state)
                        action = self.getAction(observation)

                        observations[playerID] = observation
                        actions[playerID] = action
                        updateData[playerID] = [self.env.actionSpace[action]]

                    data = self.env.update(updateData)

                    # Phase 2: record one transition per player that acted this frame.
                    for j in player_indices:
                        if die[j]:
                            continue
                        playerID = str(j) + "P"
                        old_state = oldStates[playerID]
                        taken_action = actions[playerID]
                        new_state, liveLoss, scoreUp = data[playerID]

                        reward = rewardFunction(
                            old_state, self.env.actionSpace[taken_action], scoreUp, liveLoss
                        )
                        done = int(self.env.lives[j] == 0)

                        # Store THIS player's own observation and action. Reusing the
                        # loop variables left over from phase 1 would attribute the last
                        # living player's state/action to every player.
                        self.ExperienceHistory['oldstate'].append(observations[playerID])
                        self.ExperienceHistory['state'].append(normalizeData(new_state))
                        self.ExperienceHistory['action'].append(taken_action)
                        self.ExperienceHistory['reward'].append(reward)
                        self.ExperienceHistory['done'].append(done)

                        agentRewards[j] += reward

                    if not self.env.not_done():
                        # Log the per-player return of each finished game on every
                        # tenth training iteration.
                        if i % 10 == 0:
                            print(f"time step:{i + 1}", end=" ")
                            for j in player_indices:
                                print(f"player{j} reward: {agentRewards[j]}", end=",")
                            print()
                        break

            self.train(epochs=epoch, lr=lr, episode=episode, batch_size=batchSize)





pygame 2.5.2 (SDL 2.28.3, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Build the game environment, the actor/critic networks and the PPO agent.
# Both networks take 18 input features; the policy scores 7 discrete actions.
env = attack_env(FPS=300)
policyNetwork = PolicyNetwork(state_size=18, action_size=7)
valueNetwork = ValueNetwork(state_size=18)
agent = MyPPO(env, policyNetwork, valueNetwork)


In [5]:
# Train: up to 10000 collect/update iterations, gathering at least 4096
# transitions before each PPO update.
agent.learn(
    timeStep=10000,
    dataNum=4096,
    lr=0.003,
)

start learning
time step:11 player1 reward: 0.0,player2 reward: 0.1340000000000001,player3 reward: 0.0,player4 reward: 20.261000000000006,player5 reward: -0.9099999999999999,player6 reward: 30.03,
time step:21 player1 reward: 164.06100000000023,player2 reward: 0.5000000000000003,player3 reward: 85.63599999999974,player4 reward: 44.779999999999895,player5 reward: 150.91100000000034,player6 reward: 0.3060000000000002,
time step:31 player1 reward: 0.4840000000000004,player2 reward: 0.0,player3 reward: 0.004,player4 reward: 0.03200000000000002,player5 reward: 0.7420000000000002,player6 reward: 0.014000000000000005,
time step:41 player1 reward: 0.0,player2 reward: 0.0,player3 reward: 0.5000000000000003,player4 reward: 0.5200000000000004,player5 reward: 0.060000000000000005,player6 reward: 0.4440000000000001,
time step:51 player1 reward: 187.0,player2 reward: 187.5,player3 reward: 207.5,player4 reward: 144.20200000000003,player5 reward: -108.685,player6 reward: 244.1,
time step:61 player1 re

KeyboardInterrupt: 

In [11]:
import pygame

# Replay ten rendered games with the current policy.
# NOTE(review): pygame.quit() runs after every game, so the next iteration
# relies on the environment re-initialising pygame in reset()/render(); confirm.
for demo_round in range(10):
    agent.show()
    pygame.quit()