In [15]:
import gym
import numpy as np
from PPO import PPO
from SAC import SAC
import torch
from tqdm import tqdm
from buffer import ReplayBuffer
from Discriminator import Discriminator

In [16]:
device='cuda'
BUFFER_SIZE=8192
hidden_dim=32

In [17]:
class CustomEnv:
    def __init__(self,env_id):
        self.env=gym.make(env_id)
        self.replay_buffer=ReplayBuffer(8192)
        
    def reset(self):
        return self.env.reset()
    
    def step(self,action):
        return self.env.step(action)
    
    def get_state_dim(self):
        return self.env.observation_space.shape
    
    def get_action_dim(self):
        return self.env.action_space.shape

In [22]:
class StudentAgent:
    def __init__(self,state_dim,action_dim,max_action,env:CustomEnv):
        self.env=env
        self.replay_buffer=self.env.replay_buffer
        # self.model=PPO(state_dim,action_dim,self.replay_buffer)
        self.model=SAC(state_dim,action_dim,hidden_dim,max_action,self.replay_buffer)
        
    def generate_trajectory(self,step:int):
        pb=tqdm(range(step))
        num=0
        for i in pb:
            s=self.env.reset()
            d=False
            while not d:
                a,action,l,v=self.model.select_action(s)
                s_,r,d,_=self.env.step(a)
                # print(type(r))
                self.model.buffer.store(s,a,l,s_,r,v,d)
                num+=1
            pb.update()
        print('生成',num,'条轨迹')   
        
    def train(self,total_timestep,batch_size):
        pb=tqdm(range(total_timestep))
        for i in pb:
            self.model.update(batch_size)

根据环境设置ta.state_dim=env.state_dim+env.action_dim

更改取样函数，正常来说是取样state,action,reward,nextstate,done

| Teacher Parameters | Student Parameters |
| - | - |
| state | state,action |
| action | reward |
| reward | D(s,a) |
| state_ | state_ |
| done | done |

In [5]:
# class CustomSAC(SAC):
#     def sample(self):
#         return super().sample()
class TeacherAgent:
    def __init__(self,state_dim,action_dim,max_action,replay_buffer:ReplayBuffer):
        self.replay_buffer=replay_buffer
        self.model=SAC(np.prod(state_dim)+action_dim,1,hidden_dim,max_action,self.replay_buffer,spec=True)
        self.discriminator=Discriminator(np.prod(state_dim)+action_dim,hidden_dim,64,replay_buffer,replay_buffer)
        
    def ComputeReward(self):
        pb=tqdm(range(min(self.replay_buffer.index,self.replay_buffer.buffer_size)))
        _state=[]
        for i in pb:
            state=torch.cat((self.replay_buffer.state[i],self.replay_buffer.action[i]))
            _state.append(state)
            pb.update()
        _state=torch.stack(_state,dim=0)
        self.replay_buffer.reward=torch.FloatTensor(self.model.select_action(_state))
        self.replay_buffer.islatest=True
        
    def compute_teacher_reward(self):
        pb=tqdm(range(min(self.replay_buffer.index,self.replay_buffer.buffer_size)))
        for i in pb:
            sap=torch.cat((self.replay_buffer.state[i],self.replay_buffer.action[i]),-1)
            self.replay_buffer.teacher_reward[i]=self.discriminator.model(sap)
            pb.update()
    
    #interface for changing state and action to state-action pairs
    def teacher_sample(self):
        state,action,reward,next_state,done=self.model.sample()
        state=torch.cat((state,action),1)
        if self.replay_buffer.islatest:
            action=reward
        else:
            self.ComputeReward()
        reward=self.compute_teacher_reward(state)
        return state,action,reward,next_state,done
    
    def train(self,total_timestep:int):
        pb=tqdm(range(total_timestep))
        for _ in pb:
            self.model.update()
            pb.update()
        self.replay_buffer.islatest=False
    

In [19]:
custom_env=CustomEnv('Pendulum-v0')

In [23]:
sa=StudentAgent(custom_env.get_state_dim()[0],custom_env.get_action_dim()[0],1,custom_env)

In [24]:
sa.generate_trajectory(10)


  0%|          | 0/10 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 4, got 1)

In [9]:
sa.train(10,2048)

100%|██████████| 10/10 [00:03<00:00,  2.66it/s]


In [10]:
ta=TeacherAgent(custom_env.get_state_dim(),custom_env.get_action_dim()[0],1,custom_env.replay_buffer)

In [11]:
ta.ComputeReward()
ta.compute_teacher_reward()

100%|██████████| 2000/2000 [00:00<00:00, 536493.22it/s]
100%|██████████| 2000/2000 [00:00<00:00, 30961.93it/s]


In [12]:
s,_,_,_,_=ta.model.spec_sample()

In [13]:
s.shape

torch.Size([64, 4])

In [14]:
ta.train(10)

  0%|          | 0/10 [00:00<?, ?it/s]


spec


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x3 and 4x32)

In [None]:
ta.replay_buffer.reward.shape

torch.Size([8192, 1])