In [None]:
import torch
from torch import distributions, nn
import pfrl
from pfrl import experiments, replay_buffers, utils
from pfrl.nn.lmbda import Lambda
from pfrl.nn import BoundByTanh, ConcatObsAndAction
import yaml
import time
import random 
import numpy as np

In [None]:
# Load the experiment hyperparameters from the YAML file.
with open("configs/hyperparameters.yaml") as f:
    # NOTE(review): FullLoader can construct more than plain scalars/containers.
    # If this file can ever come from an untrusted source, switch to
    # yaml.safe_load(f).
    config_dict = yaml.load(f, Loader=yaml.FullLoader)


class Config(object):
    """Attribute-style view over a hyperparameter dictionary.

    Every key of ``dictionary`` becomes an instance attribute, so a value
    loaded as ``{"seed": 1}`` is read as ``config.seed``.
    """

    def __init__(self, dictionary):
        self.__dict__.update(dictionary)


sac_config = Config(config_dict)

SEED = sac_config.seed
# torch device strings are lowercase and case-sensitive ('CUDA' is rejected by
# torch.device), so use 'cuda' and fall back to the CPU when no GPU is visible.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook touches so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = sac_config.torch_deterministic
# NOTE(review): anomaly detection is a debugging aid and slows down every
# backward pass considerably; disable it for real training runs.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Stand-alone training options (discount, batch size, target-update rate and
# replay capacity), kept as a plain dictionary.
train_options = dict(
    gamma=0.9,
    minibatch_size=124,
    tau=0.005,
    replay_memory_size=40000,
)

In [None]:
def create_agent_sac(state_size,action_size,config=sac_config,gpu=0):
    """Build a PFRL Soft Actor-Critic agent with two critics and one actor.

    Args:
        state_size: Length of the observation vector fed to the networks.
        action_size: Number of action dimensions.
        config: Object exposing the hyperparameters as attributes:
            ``policy_lr``, ``q_lr``, ``buffer_size``, ``gamma``,
            ``learning_starts``, ``tau`` and ``batch_size``.
        gpu: GPU index forwarded to ``pfrl.agents.SoftActorCritic``.

    Returns:
        The constructed ``pfrl.agents.SoftActorCritic`` instance.
    """
    hidden_size = 256

    def squashed_diagonal_gaussian_head(x):
        """Turn the actor's raw output into a tanh-squashed diagonal Gaussian."""
        assert x.shape[-1] == action_size * 2
        # Split along the last axis, the same axis the assert above checks.
        mean, log_scale = torch.chunk(x, 2, dim=-1)
        log_scale = torch.clamp(log_scale, -20.0, 2.0)
        var = torch.exp(log_scale * 2)
        base_distribution = distributions.Independent(
            distributions.Normal(loc=mean, scale=torch.sqrt(var)), 1
        )
        # cache_size=1 is required for numerical stability
        return distributions.transformed_distribution.TransformedDistribution(
            base_distribution, [distributions.transforms.TanhTransform(cache_size=1)]
        )

    def make_q_func():
        """Create one critic: (observation, action) -> scalar Q-value."""
        return nn.Sequential(
            ConcatObsAndAction(),
            nn.Linear(state_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    # critic_networks (two independently initialised critics)
    q_func_1 = make_q_func()
    q_func_2 = make_q_func()

    # actor_network
    policy = nn.Sequential(
        nn.Linear(state_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, action_size * 2),
        pfrl.nn.lmbda.Lambda(squashed_diagonal_gaussian_head),
    )

    opt_a = torch.optim.Adam(policy.parameters(), lr=config.policy_lr)
    opt_c_1 = torch.optim.Adam(q_func_1.parameters(), lr=config.q_lr)
    opt_c_2 = torch.optim.Adam(q_func_2.parameters(), lr=config.q_lr)

    # Size the buffer from the `config` argument rather than the module-level
    # default, so a caller-supplied configuration is actually honoured.
    rbuf = replay_buffers.ReplayBuffer(config.buffer_size)

    agent = pfrl.agents.SoftActorCritic(
        policy,
        q_func_1,
        q_func_2,
        opt_a,
        opt_c_1,
        opt_c_2,
        rbuf,
        # Discount factor. This used to receive config.buffer_size, i.e. the
        # replay capacity, which is not a valid discount.
        gamma=config.gamma,
        update_interval=1,
        # NOTE(review): the factor 64 is hard-coded; presumably the number of
        # parallel sub-environments — confirm against the environment.
        replay_start_size=config.learning_starts*64,
        gpu=gpu,
        soft_update_tau=config.tau,
        minibatch_size=config.batch_size,
        entropy_target=-action_size,
        temperature_optimizer_lr=config.q_lr,
    )
    return agent

In [None]:
from envs.photo_env import PhotoEnhancementEnv, PhotoEnhancementEnvTest

# One environment instance for training and a separate one for evaluation.
env, test_env = PhotoEnhancementEnv(), PhotoEnhancementEnvTest()

In [None]:

from torch.utils.tensorboard import SummaryWriter

# Observation vector length handed to the networks.
# NOTE(review): presumably the size of the encoded observation produced by the
# environment — confirm against envs.photo_env.
STATE_SIZE = 512

# TensorBoard run directory: experiment name, seed and start timestamp.
run_name = f"{sac_config.exp_name}__{sac_config.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{run_name}")

# Record every hyperparameter as a markdown table in the run.
hyperparameter_rows = "\n".join(
    [f"|{key}|{value}|" for key, value in vars(sac_config).items()]
)
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % (hyperparameter_rows),
)

# create_agent_sac reads its configuration through attributes
# (config.policy_lr, config.q_lr, ...), so it needs the Config object; the
# plain `train_options` dict has neither attribute access nor those keys.
agent = create_agent_sac(STATE_SIZE, env.num_parameters, sac_config)

In [7]:
# Training loop: roll out batched episodes, feed every transition to the agent,
# log to TensorBoard and periodically evaluate on the test environment.
LOG_INTERVAL = 100    # environment steps between TensorBoard/SPS reports
EVAL_INTERVAL = 200   # environment steps between test evaluations
N_TEST_IMAGES = 10    # images logged per evaluation and image kind

global_step = 0
next_eval_step = EVAL_INTERVAL
# Mean reward of the most recently finished episode; None until one finishes,
# so the periodic logging below never reads an undefined name.
ens_mean_episodic_return = None
start_time = time.time()
# NOTE(review): this iterates once per *episode* although the bound is named
# total_timesteps — confirm which of the two is intended.
for i in range(sac_config.total_timesteps):
    episode_count = 0
    agent.training = True
    obs = env.reset()
    envs_mean_rewards = []
    while True:
        episode_count += 1
        global_step += 1
        batch_actions = torch.tensor(agent.batch_act(obs))
        next_obs, rewards, batch_dones = env.step(batch_actions)
        envs_mean_rewards.append(rewards)
        # Tell the agent when the episode is being cut off by the step limit
        # (as opposed to a true terminal state) so it closes the episode.
        truncated = episode_count == sac_config.max_episode_timesteps
        batch_reset = [truncated] * len(batch_dones)
        agent.batch_observe(next_obs, rewards, batch_dones, batch_reset)
        obs = next_obs

        if (batch_dones == True).any():
            print('one done')
            print(env.sub_env_running)

        if global_step % LOG_INTERVAL == 0:
            if ens_mean_episodic_return is not None:
                writer.add_scalar("charts/mean_episodic_return", ens_mean_episodic_return, global_step)
            print("SPS:", int(global_step / (time.time() - start_time)))
            for stat_name, stat_value in agent.get_statistics():
                writer.add_scalar(f"charts/{stat_name}", stat_value, global_step)

        if bool((batch_dones == True).all()) or truncated:
            ens_mean_episodic_return = np.mean(envs_mean_rewards)
            break

    # Evaluate once the step counter has passed the next evaluation mark.
    # Testing for an exact multiple would skip evaluation whenever an episode
    # does not end precisely on a multiple of EVAL_INTERVAL.
    if global_step >= next_eval_step:
        next_eval_step = (global_step // EVAL_INTERVAL + 1) * EVAL_INTERVAL
        # eval_mode() switches the agent to evaluation and restores the
        # previous mode on exit, even if evaluation raises.
        with agent.eval_mode(), torch.no_grad():
            obs = test_env.reset()
            batch_actions = torch.tensor(agent.batch_act(obs))
            _, rewards, _ = test_env.step(batch_actions)
            writer.add_scalar("charts/test_mean_episodic_return", rewards.mean().item(), global_step)
            # One tag per image kind, indexed by the real training step, so
            # successive evaluations do not land on the same tag and step.
            for image_key in ('source_image', 'enhanced_image', 'target_image'):
                writer.add_images(
                    f"test_images/{image_key}",
                    test_env.state[image_key][:N_TEST_IMAGES],
                    global_step,
                )

In [None]:

# for i in range(EPISODES):
#     obs = env.reset()
#     obs=[obs]
#     episode_count = 0
#     while True:
#         encoded_source_image = obs[0]['encoded_source']
#         encoded_enhanced_image = obs[0]['encoded_enhanced_image']
#         batch_observation = torch.cat((encoded_source_image,encoded_enhanced_image),dim=1)
#         batch_actions = torch.tensor(agent.batch_act(batch_observation))


#         obs = env.step(batch_actions)
#         encoded_source_image = obs[0]['encoded_source']
#         encoded_enhanced_image = obs[0]['encoded_enhanced_image']
#         batch_observation = torch.cat((encoded_source_image,encoded_enhanced_image),dim=1)
#         batch_rewards =  obs[1]
#         batch_done = obs[2]
#         batch_reset = [False for i in range(len(batch_done))]
#         agent.batch_observe(batch_observation ,batch_rewards,batch_done,batch_reset)
#         episode_count+=1
#         if (batch_done==True).all()==True or episode_count==10:
#             print(batch_rewards)
#             break      