In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from buffer import ReplayBuffer
from PPO import PPO
import gym
from Discriminator import Discriminator
from reward_env import RewardEnv
import network_sim
import contextlib

In [2]:
device='cuda'
BUFFER_SIZE=8192
hidden_dim=32

In [3]:
class CustomEnv:
    def __init__(self,env_id):
        self.env=gym.make(env_id)
        self.student_buffer=ReplayBuffer(8192)
        self.teacher_buffer=ReplayBuffer(8192)
        
    def reset(self):
        return self.env.reset()
    
    def step(self,action):
        with contextlib.redirect_stdout(None):
            return self.env.step(action)
    
    def get_state_dim(self):
        return self.env.observation_space.shape
    
    def get_action_dim(self):
        return self.env.action_space.shape

In [4]:
class StudentAgent:
    def __init__(self,state_dim,action_dim,env:CustomEnv):
        self.env=env
        self.replay_buffer=self.env.student_buffer
        self.model=PPO(state_dim,action_dim,self.replay_buffer)
        
    def generate_trajectory(self,step:int):
        pb=tqdm(range(step))
        num=0
        for i in pb:
            s=self.env.reset()
            d=False
            while not d:
                a,_,l,v=self.model.select_action(s)
                s_,r,d,_=self.env.step(a)
                # print(type(r))
                self.model.buffer.store(s,a,l,s_,r,v,d)
                num+=1
            #pb.update()
        print('生成',num,'条轨迹')   
        
    def train(self,total_timestep,batch_size):
        pb=tqdm(range(total_timestep))
        for i in pb:
            self.model.update(batch_size)

In [5]:
env=CustomEnv('PccNs-v0')

History length: 10
Features: ['sent latency inflation', 'latency ratio', 'send ratio']
Getting min obs for ['sent latency inflation', 'latency ratio', 'send ratio']




In [6]:
sa=StudentAgent(np.prod(env.get_state_dim()),env.get_action_dim()[0],env)

In [7]:
sa.generate_trajectory(100)

  0%|          | 0/100 [00:00<?, ?it/s]

Reward: 0.00, Ewma Reward: 0.00


  1%|          | 1/100 [00:03<05:33,  3.37s/it]

Reward: -1651.17, Ewma Reward: -16.51


  2%|▏         | 2/100 [00:07<05:51,  3.59s/it]

Reward: -519.85, Ewma Reward: -21.55


  3%|▎         | 3/100 [00:07<03:29,  2.16s/it]

Reward: -173.21, Ewma Reward: -23.06


  4%|▍         | 4/100 [00:08<02:31,  1.57s/it]

Reward: -295.88, Ewma Reward: -25.79


  5%|▌         | 5/100 [00:08<01:53,  1.19s/it]

Reward: -249.29, Ewma Reward: -28.02


  6%|▌         | 6/100 [00:15<04:37,  2.95s/it]

Reward: -3951.49, Ewma Reward: -67.26


  7%|▋         | 7/100 [00:15<03:25,  2.21s/it]

Reward: -548.00, Ewma Reward: -72.07


  8%|▊         | 8/100 [00:18<03:51,  2.52s/it]

Reward: -943.13, Ewma Reward: -80.78


  9%|▉         | 9/100 [00:20<03:25,  2.25s/it]

Reward: -1321.44, Ewma Reward: -93.18


 10%|█         | 10/100 [00:23<03:45,  2.51s/it]

Reward: -1332.76, Ewma Reward: -105.58


 11%|█         | 11/100 [00:24<03:02,  2.05s/it]

Reward: -236.18, Ewma Reward: -106.89


 12%|█▏        | 12/100 [00:25<02:29,  1.70s/it]

Reward: 665.92, Ewma Reward: -99.16


 13%|█▎        | 13/100 [00:28<02:45,  1.91s/it]

Reward: -85.99, Ewma Reward: -99.03


 14%|█▍        | 14/100 [00:28<02:10,  1.52s/it]

Reward: 320.88, Ewma Reward: -94.83


 15%|█▌        | 15/100 [00:29<01:44,  1.23s/it]

Reward: 1091.06, Ewma Reward: -82.97


 16%|█▌        | 16/100 [00:29<01:28,  1.06s/it]

Reward: 898.47, Ewma Reward: -73.15


 17%|█▋        | 17/100 [00:30<01:12,  1.15it/s]

Reward: 1088.07, Ewma Reward: -61.54


 18%|█▊        | 18/100 [00:31<01:21,  1.01it/s]

Reward: -108.43, Ewma Reward: -62.01


 19%|█▉        | 19/100 [00:32<01:16,  1.06it/s]

Reward: 404.94, Ewma Reward: -57.34


 20%|██        | 20/100 [00:32<01:05,  1.22it/s]

Reward: 1307.58, Ewma Reward: -43.69


 21%|██        | 21/100 [00:33<01:01,  1.28it/s]

Reward: 1220.92, Ewma Reward: -31.05


 22%|██▏       | 22/100 [00:34<00:55,  1.42it/s]

Reward: 979.35, Ewma Reward: -20.94


 23%|██▎       | 23/100 [00:34<00:51,  1.49it/s]

Reward: 35.22, Ewma Reward: -20.38


 24%|██▍       | 24/100 [00:35<00:53,  1.43it/s]

Reward: 1335.64, Ewma Reward: -6.82


 25%|██▌       | 25/100 [00:37<01:18,  1.04s/it]

Reward: -1267.77, Ewma Reward: -19.43


 26%|██▌       | 26/100 [00:38<01:14,  1.01s/it]

Reward: 133.17, Ewma Reward: -17.90


 27%|██▋       | 27/100 [00:43<02:36,  2.15s/it]

Reward: -1990.36, Ewma Reward: -37.63


 28%|██▊       | 28/100 [00:43<01:56,  1.62s/it]

Reward: 442.29, Ewma Reward: -32.83


 29%|██▉       | 29/100 [00:45<01:53,  1.60s/it]

Reward: 40.15, Ewma Reward: -32.10


 30%|███       | 30/100 [00:46<01:50,  1.58s/it]

Reward: -752.27, Ewma Reward: -39.30


 31%|███       | 31/100 [00:49<02:18,  2.01s/it]

Reward: -1547.77, Ewma Reward: -54.39


 32%|███▏      | 32/100 [00:49<01:42,  1.50s/it]

Reward: 504.16, Ewma Reward: -48.80


 33%|███▎      | 33/100 [00:50<01:28,  1.32s/it]

Reward: 680.39, Ewma Reward: -41.51


 34%|███▍      | 34/100 [00:51<01:13,  1.12s/it]

Reward: 160.89, Ewma Reward: -39.48


 35%|███▌      | 35/100 [00:53<01:40,  1.54s/it]

Reward: -613.92, Ewma Reward: -45.23


 36%|███▌      | 36/100 [00:54<01:15,  1.19s/it]

Reward: -147.61, Ewma Reward: -46.25


 37%|███▋      | 37/100 [00:54<01:00,  1.04it/s]

Reward: 1109.88, Ewma Reward: -34.69


 38%|███▊      | 38/100 [00:55<00:56,  1.09it/s]

Reward: 1101.32, Ewma Reward: -23.33


 39%|███▉      | 39/100 [01:00<02:04,  2.04s/it]

Reward: -726.98, Ewma Reward: -30.37


 40%|████      | 40/100 [01:04<02:40,  2.68s/it]

Reward: -1621.94, Ewma Reward: -46.28


 41%|████      | 41/100 [01:04<01:57,  2.00s/it]

Reward: 323.82, Ewma Reward: -42.58


 42%|████▏     | 42/100 [01:13<04:00,  4.15s/it]

Reward: -7298.37, Ewma Reward: -115.14


 43%|████▎     | 43/100 [01:14<02:56,  3.09s/it]

Reward: 235.86, Ewma Reward: -111.63


 44%|████▍     | 44/100 [01:15<02:14,  2.40s/it]

Reward: -437.34, Ewma Reward: -114.89


 45%|████▌     | 45/100 [01:17<02:11,  2.39s/it]

Reward: -166.53, Ewma Reward: -115.40


 46%|████▌     | 46/100 [01:18<01:41,  1.88s/it]

Reward: 1331.81, Ewma Reward: -100.93


 47%|████▋     | 47/100 [01:19<01:19,  1.49s/it]

Reward: 979.30, Ewma Reward: -90.13


 48%|████▊     | 48/100 [01:19<01:04,  1.25s/it]

Reward: -20.91, Ewma Reward: -89.44


 49%|████▉     | 49/100 [01:20<00:54,  1.07s/it]

Reward: 293.37, Ewma Reward: -85.61


 50%|█████     | 50/100 [01:21<00:50,  1.01s/it]

Reward: 477.00, Ewma Reward: -79.98


 51%|█████     | 51/100 [01:21<00:41,  1.17it/s]

Reward: -309.98, Ewma Reward: -82.28


 52%|█████▏    | 52/100 [01:22<00:35,  1.35it/s]

Reward: 1292.57, Ewma Reward: -68.53


 53%|█████▎    | 53/100 [01:22<00:33,  1.40it/s]

Reward: 251.46, Ewma Reward: -65.33


 54%|█████▍    | 54/100 [01:23<00:33,  1.36it/s]

Reward: 1309.61, Ewma Reward: -51.58


 55%|█████▌    | 55/100 [01:24<00:29,  1.53it/s]

Reward: 272.03, Ewma Reward: -48.35


 56%|█████▌    | 56/100 [01:25<00:31,  1.38it/s]

Reward: 1218.18, Ewma Reward: -35.68


 57%|█████▋    | 57/100 [01:25<00:29,  1.46it/s]

Reward: 728.46, Ewma Reward: -28.04


 58%|█████▊    | 58/100 [01:25<00:25,  1.67it/s]

Reward: 818.06, Ewma Reward: -19.58


 59%|█████▉    | 59/100 [01:28<00:53,  1.30s/it]

Reward: -266.50, Ewma Reward: -22.05


 60%|██████    | 60/100 [01:29<00:46,  1.17s/it]

Reward: 632.69, Ewma Reward: -15.50


 61%|██████    | 61/100 [01:30<00:39,  1.01s/it]

Reward: 1127.15, Ewma Reward: -4.08


 62%|██████▏   | 62/100 [01:30<00:32,  1.16it/s]

Reward: 314.28, Ewma Reward: -0.89


 63%|██████▎   | 63/100 [01:34<00:56,  1.52s/it]

Reward: -3325.32, Ewma Reward: -34.14


 64%|██████▍   | 64/100 [01:35<00:59,  1.65s/it]

Reward: -1610.57, Ewma Reward: -49.90


 65%|██████▌   | 65/100 [01:43<01:59,  3.42s/it]

Reward: -4059.21, Ewma Reward: -89.99


 66%|██████▌   | 66/100 [01:44<01:29,  2.64s/it]

Reward: 1144.57, Ewma Reward: -77.65


 67%|██████▋   | 67/100 [01:46<01:26,  2.62s/it]

Reward: -67.10, Ewma Reward: -77.54


 68%|██████▊   | 68/100 [01:47<01:04,  2.02s/it]

Reward: 930.10, Ewma Reward: -67.47


 69%|██████▉   | 69/100 [01:48<00:51,  1.66s/it]

Reward: 658.76, Ewma Reward: -60.20


 70%|███████   | 70/100 [01:49<00:41,  1.38s/it]

Reward: 225.01, Ewma Reward: -57.35


 71%|███████   | 71/100 [01:49<00:32,  1.11s/it]

Reward: 78.09, Ewma Reward: -56.00


 72%|███████▏  | 72/100 [01:50<00:28,  1.03s/it]

Reward: -607.21, Ewma Reward: -61.51


 73%|███████▎  | 73/100 [01:50<00:22,  1.21it/s]

Reward: 941.80, Ewma Reward: -51.48


 74%|███████▍  | 74/100 [01:55<00:53,  2.05s/it]

Reward: -3122.59, Ewma Reward: -82.19


 75%|███████▌  | 75/100 [01:56<00:41,  1.68s/it]

Reward: 1289.33, Ewma Reward: -68.47


 76%|███████▌  | 76/100 [01:56<00:31,  1.32s/it]

Reward: 791.13, Ewma Reward: -59.88


 77%|███████▋  | 77/100 [01:57<00:26,  1.15s/it]

Reward: 919.30, Ewma Reward: -50.08


 78%|███████▊  | 78/100 [02:00<00:33,  1.52s/it]

Reward: -1551.30, Ewma Reward: -65.10


 79%|███████▉  | 79/100 [02:02<00:39,  1.87s/it]

Reward: -1003.48, Ewma Reward: -74.48


 80%|████████  | 80/100 [02:04<00:38,  1.94s/it]

Reward: -1432.94, Ewma Reward: -88.07


 81%|████████  | 81/100 [02:05<00:30,  1.59s/it]

Reward: 1264.71, Ewma Reward: -74.54


 82%|████████▏ | 82/100 [02:06<00:23,  1.31s/it]

Reward: 160.21, Ewma Reward: -72.19


 83%|████████▎ | 83/100 [02:08<00:28,  1.67s/it]

Reward: -261.45, Ewma Reward: -74.08


 84%|████████▍ | 84/100 [02:09<00:20,  1.29s/it]

Reward: 513.31, Ewma Reward: -68.21


 85%|████████▌ | 85/100 [02:10<00:17,  1.20s/it]

Reward: 754.75, Ewma Reward: -59.98


 86%|████████▌ | 86/100 [02:11<00:17,  1.24s/it]

Reward: 1013.21, Ewma Reward: -49.25


 87%|████████▋ | 87/100 [02:12<00:13,  1.07s/it]

Reward: 1033.25, Ewma Reward: -38.42


 88%|████████▊ | 88/100 [02:12<00:10,  1.12it/s]

Reward: 102.99, Ewma Reward: -37.01


 89%|████████▉ | 89/100 [02:13<00:08,  1.26it/s]

Reward: 1152.96, Ewma Reward: -25.11


 90%|█████████ | 90/100 [02:13<00:07,  1.40it/s]

Reward: 293.80, Ewma Reward: -21.92


 91%|█████████ | 91/100 [02:16<00:10,  1.22s/it]

Reward: 12.93, Ewma Reward: -21.57


 92%|█████████▏| 92/100 [02:17<00:11,  1.40s/it]

Reward: 232.24, Ewma Reward: -19.03


 93%|█████████▎| 93/100 [02:18<00:07,  1.09s/it]

Reward: 377.88, Ewma Reward: -15.06


 94%|█████████▍| 94/100 [02:18<00:05,  1.13it/s]

Reward: 143.72, Ewma Reward: -13.48


 95%|█████████▌| 95/100 [02:19<00:04,  1.21it/s]

Reward: 591.09, Ewma Reward: -7.43


 96%|█████████▌| 96/100 [02:20<00:03,  1.32it/s]

Reward: -7.10, Ewma Reward: -7.43


 97%|█████████▋| 97/100 [02:21<00:03,  1.11s/it]

Reward: -1527.25, Ewma Reward: -22.63


 98%|█████████▊| 98/100 [02:22<00:01,  1.03it/s]

Reward: 1278.86, Ewma Reward: -9.61


 99%|█████████▉| 99/100 [02:23<00:00,  1.24it/s]

Reward: 1327.15, Ewma Reward: 3.76


100%|██████████| 100/100 [02:26<00:00,  1.46s/it]

生成 40000 条轨迹





In [8]:
sa.train(10,2048)

100%|██████████| 10/10 [00:03<00:00,  2.70it/s]


In [9]:
class TeacherAgent():
    def __init__(self,state_dim,action_dim,env:CustomEnv):
        self.trajectory_buffer=env.student_buffer
        self.replay_buffer=env.teacher_buffer
        self.model=PPO(state_dim+action_dim,1,self.replay_buffer)
        self.discriminator=Discriminator(state_dim+action_dim,hidden_dim,64,self.trajectory_buffer,self.trajectory_buffer)
        
    def ComputeReward(self):
        pb=tqdm(range(min(self.trajectory_buffer.index,self.trajectory_buffer.buffer_size)))
        for i in pb:
            sa_pair=torch.cat((self.trajectory_buffer.state[i],self.trajectory_buffer.action[i]),-1)
            reward,_,l,v=self.model.select_action(sa_pair)
            self.replay_buffer.store(sa_pair,
                                     reward,
                                     l,
                                     self.trajectory_buffer.next_state[i],
                                     self.discriminator.model(sa_pair).detach().cpu().numpy(),
                                     v,
                                     self.trajectory_buffer.done[i],
                                     )
            self.trajectory_buffer.reward[i]=reward
            pb.update()
        self.discriminator.collect_expert()
            
    def trainPPO(self,total_timestep:int):
        for i in range(total_timestep):
            self.model.update(1024)
            
    def trainDiscriminator(self,total_timestep:int):
        self.discriminator.update(total_timestep,False)
        
    def train(self,total_timestep:int,PPO_timestep:int,D_timestep:int):
        pb=tqdm(range(total_timestep))
        for i in pb:
            self.trainDiscriminator(D_timestep)
            pb.update()
            self.trainPPO(PPO_timestep)
            pb.update()
        

In [10]:
ta=TeacherAgent(np.prod(env.get_state_dim()),env.get_action_dim()[0],env)

In [11]:
ta.ComputeReward()

100%|██████████| 8192/8192 [00:05<00:00, 1581.68it/s]


In [12]:
ta.trainDiscriminator(100)

In [13]:
ta.trainPPO(100)

In [14]:
for i in range(10):
    sa.train(10,1024)
    ta.train(10,5,5)

  rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 10/10 [00:01<00:00,  6.92it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  7.02it/s]
100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
100%|██████████| 10/10 [00:01<00:00,  6.97it/s]
100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
100%|██████████| 10/10 [00:01<00:00,  6.98it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  7.01it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  6.95it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  7.03it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  7.03it/s]
100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
100%|██████████| 10/10 [00:01<00:00,  7.01it/s]
100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
100%|████

In [15]:
rewards=[]
for i in range(100):
    s=env.reset()
    d=False
    reward=0
    while not d:
        a,_,_,_=sa.model.select_action(s)
        s_,r,d,_=env.step(a)
        reward+=r
    rewards.append(reward)

Reward: -1616.16, Ewma Reward: -12.44
Reward: -1475.92, Ewma Reward: -27.08
Reward: 523.15, Ewma Reward: -21.57
Reward: 57.40, Ewma Reward: -20.78
Reward: -104.62, Ewma Reward: -21.62
Reward: -317.19, Ewma Reward: -24.58
Reward: 987.47, Ewma Reward: -14.46
Reward: -722.73, Ewma Reward: -21.54
Reward: -1385.88, Ewma Reward: -35.18
Reward: 535.77, Ewma Reward: -29.48
Reward: 562.63, Ewma Reward: -23.55
Reward: -618.77, Ewma Reward: -29.51
Reward: -863.98, Ewma Reward: -37.85
Reward: 236.24, Ewma Reward: -35.11
Reward: 1087.96, Ewma Reward: -23.88
Reward: 195.06, Ewma Reward: -21.69
Reward: -261.17, Ewma Reward: -24.08
Reward: 174.48, Ewma Reward: -22.10
Reward: -418.33, Ewma Reward: -26.06
Reward: -267.37, Ewma Reward: -28.47
Reward: -333.26, Ewma Reward: -31.52
Reward: 601.25, Ewma Reward: -25.19
Reward: -640.03, Ewma Reward: -31.34
Reward: 408.71, Ewma Reward: -26.94
Reward: -976.63, Ewma Reward: -36.44
Reward: 917.98, Ewma Reward: -26.90
Reward: -233.88, Ewma Reward: -28.96
Reward: -9

In [16]:
np.mean(reward)

-306.07449684345556

In [20]:
env.env.reset()

Reward: 0.00, Ewma Reward: -50.22


array([0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1.,
       1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1.])