In [1]:
import torch
from torch import distributions, nn
import pfrl
from pfrl import experiments, replay_buffers, utils
from pfrl.nn.lmbda import Lambda
from pfrl.nn import BoundByTanh, ConcatObsAndAction

In [2]:
train_options = {
    'gamma':0.9,
    'minibatch_size':64,
    'tau':1e-2,
    'replay_memory_size':50000

}

In [3]:
def create_agent_sac(state_size,action_size,train_options,gpu=0):
    def squashed_diagonal_gaussian_head(x):
        assert x.shape[-1] == action_size * 2
        mean, log_scale = torch.chunk(x, 2, dim=1)
        log_scale = torch.clamp(log_scale, -20.0, 2.0)
        var = torch.exp(log_scale * 2)
        base_distribution = distributions.Independent(
            distributions.Normal(loc=mean, scale=torch.sqrt(var)), 1
        )
        # cache_size=1 is required for numerical stability
        return distributions.transformed_distribution.TransformedDistribution(
            base_distribution, [distributions.transforms.TanhTransform(cache_size=1)]
        )


    #critic_networks
    q_func_1 = nn.Sequential(
        ConcatObsAndAction(),
        nn.Linear(state_size + action_size, 16*state_size),
        nn.ReLU(),
        nn.Linear(16*state_size, 16*action_size),
        nn.ReLU(),
        nn.Linear(16*action_size, 1),
    )
    q_func_2 = nn.Sequential(
        ConcatObsAndAction(),
        nn.Linear(state_size + action_size, 16*state_size),
        nn.ReLU(),
        nn.Linear(16*state_size, 16*action_size),
        nn.ReLU(),
        nn.Linear(16*action_size, 1),
    )
    
    #actor_network
    policy = nn.Sequential(
        nn.Linear(state_size, 16*state_size),
        nn.ReLU(),
        nn.Linear(16*state_size, 16*action_size),
        nn.ReLU(),
        nn.Linear(16*action_size, action_size* 2),
        pfrl.nn.lmbda.Lambda(squashed_diagonal_gaussian_head),
    )


    opt_a = torch.optim.Adam(policy.parameters(), lr=1e-3, weight_decay=1e-3)
    opt_c_1 = torch.optim.Adam(q_func_1.parameters(), lr=1e-3, weight_decay=1e-2)
    opt_c_2 = torch.optim.Adam(q_func_2.parameters(), lr=1e-3, weight_decay=1e-2)
    
    rbuf = replay_buffers.ReplayBuffer(train_options['replay_memory_size'])

    agent = pfrl.agents.SoftActorCritic(
        policy,
        q_func_1,
        q_func_2,
        opt_a,
        opt_c_1,
        opt_c_2,
        rbuf,
        gamma=train_options['gamma'],
        update_interval=1,
        replay_start_size=100,
        gpu=gpu,
        soft_update_tau= train_options['tau'],
        minibatch_size = train_options['minibatch_size'],
        entropy_target=-action_size,
        temperature_optimizer_lr=1e-3,
    )
    return agent

In [4]:
from envs.photo_env import PhotoEnhancementEnv
env = PhotoEnhancementEnv(64)

Encoding testing data ...


  0%|          | 0/4 [00:00<?, ?it/s]

finished...
Encoding training data ...


  0%|          | 0/36 [00:00<?, ?it/s]

finished...


In [5]:
EPISODES = 1000
agent = create_agent_sac(512*2,env.num_parameters,train_options)
for i in range(EPISODES):
    obs = env.reset()
    obs=[obs]
    episode_count = 0
    while True:
        encoded_source_image = obs[0]['encoded_source']
        encoded_enhanced_image = obs[0]['encoded_enhanced_image']
        batch_observation = torch.cat((encoded_source_image,encoded_enhanced_image),dim=1)
        batch_actions = torch.tensor(agent.batch_act(batch_observation))


        obs = env.step(batch_actions)
        encoded_source_image = obs[0]['encoded_source']
        encoded_enhanced_image = obs[0]['encoded_enhanced_image']
        batch_observation = torch.cat((encoded_source_image,encoded_enhanced_image),dim=1)
        batch_rewards =  obs[1]
        batch_done = obs[2]
        batch_reset = [False for i in range(len(batch_done))]
        agent.batch_observe(batch_observation ,batch_rewards,batch_done,batch_reset)
        episode_count+=1
        if (batch_done==True).all()==True or episode_count==10:
            print(batch_rewards)
            break      

  "action": torch.as_tensor(


tensor([-0.5264, -0.2358, -0.4945, -0.4050, -0.4357, -0.1342, -0.4530, -0.3219,
        -0.4113, -0.4120, -0.5664, -0.4006, -0.3757, -0.3983, -0.4537, -0.3185,
        -0.6678, -0.5228, -0.4673, -0.4227, -0.3658, -0.3754, -0.4243, -0.3456,
        -0.6177, -0.3061, -0.3853, -0.2711, -0.4708, -0.4922, -0.2825, -0.5136,
        -0.6022, -0.3146, -0.3994, -0.2765, -0.5079, -0.3880, -0.4522, -0.4903,
        -0.3782, -0.4390, -0.4811, -0.4507, -0.7097, -0.5293, -0.2938, -0.4309,
        -0.4425, -0.5188, -0.5621, -0.3814, -0.4158, -0.4829, -0.4698, -0.4283,
        -0.4882, -0.2653, -0.4311, -0.3662, -0.3546, -0.5339, -0.4477, -0.2685])
tensor([-0.4683, -0.5954, -0.2804, -0.5853, -0.4061, -0.6482, -0.2889, -0.5263,
        -0.4454, -0.3507, -0.3673, -0.6227, -0.4594, -0.3944, -0.3305, -0.4275,
        -0.5740, -0.2600, -0.3022, -0.4856, -0.2186, -0.5512, -0.2910, -0.5667,
        -0.3125, -0.2349, -0.3296, -0.6392, -0.2509, -0.5585, -0.4085, -0.6118,
        -0.5535, -0.4067, -0.4935, -0.4

KeyboardInterrupt: 