In [None]:
import yaml
import time
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import time
from envs.photo_env import PhotoEnhancementEnv
from envs.photo_env import PhotoEnhancementEnvTest
from sac.sac_algorithm import SAC
import multiprocessing as mp
# Select the 'spawn' start method for multiprocessing; force=True replaces any
# start method already set in this interpreter (e.g. from an earlier run of this cell).
# NOTE(review): presumably chosen so worker processes can safely use CUDA — confirm.
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    # Defensive: keep whatever start method is already in effect.
    pass  


In [None]:
# Training environment (handed to the SAC agent below) and a separate test
# environment used only for the periodic evaluation inside the training loop.
env = PhotoEnhancementEnv()
test_env = PhotoEnhancementEnvTest()

In [None]:
# Load experiment hyperparameters. The config is plain key/value data, so
# yaml.safe_load is sufficient and cannot construct arbitrary Python objects.
# An explicit encoding keeps the result independent of the OS locale.
with open("configs/hyperparameters.yaml", encoding="utf-8") as f:
    config_dict = yaml.safe_load(f)

# safe_load returns None for an empty file; fail here with a clear message
# instead of a confusing TypeError when the dict is turned into a Config.
if not isinstance(config_dict, dict):
    raise ValueError(
        "configs/hyperparameters.yaml must contain a mapping of hyperparameters, "
        f"got {type(config_dict).__name__}"
    )
    
class Config:
    """Attribute-style wrapper around a flat hyperparameter mapping.

    Every top-level key of ``dictionary`` becomes an instance attribute, so
    ``Config({"seed": 1}).seed == 1`` and ``vars(cfg)`` yields the mapping back.
    """

    def __init__(self, dictionary):
        vars(self).update(dictionary)
sac_config = Config(config_dict)

SEED = sac_config.seed
# torch device strings are lowercase ("cuda", "cpu"); 'CUDA' is rejected by
# torch.device(). Reuse the device the rest of the notebook reads from the
# config (sac_config.device), falling back to CUDA when it is available.
DEVICE = getattr(sac_config, "device", "cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the run touches (torch.manual_seed covers all devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = sac_config.torch_deterministic

# Autograd anomaly detection adds checks to every backward pass and slows
# training considerably; enable it only when the config requests it (e.g.
# `detect_anomaly: true` while hunting NaNs). Set explicitly so re-running the
# cell also turns it off again.
torch.autograd.set_detect_anomaly(bool(getattr(sac_config, "detect_anomaly", False)))

In [None]:
# Unique TensorBoard run directory: <exp_name>__<seed>__<unix time>.
run_name = f"{sac_config.exp_name}__{sac_config.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{run_name}")

# Record every hyperparameter as a two-column markdown table in the Text tab.
hparam_rows = "\n".join(f"|{key}|{value}|" for key, value in vars(sac_config).items())
writer.add_text("hyperparameters", "|param|value|\n|-|-|\n%s" % hparam_rows)

agent = SAC(env, sac_config, writer)

In [5]:
agent.start_time = time.time()

# NOTE(review): despite its name, ``sac_config.total_timesteps`` bounds the
# number of *episodes* here; each episode runs up to
# ``sac_config.max_episode_timesteps`` calls to ``agent.train()``.
EVAL_INTERVAL = getattr(sac_config, "eval_interval", 200)  # global steps between test evaluations
N_TEST_IMAGES = 5  # images per kind written to TensorBoard at each evaluation


def _run_training_episode(agent, max_steps):
    """Run one training episode on the batched env and return its return.

    Each step calls ``agent.train()`` once and advances ``agent.global_step``.
    The episode stops when every sub-environment reports done or after
    ``max_steps`` steps. The returned value is the sum over steps of the
    per-step reward averaged across sub-environments.
    """
    agent.reset_env()
    episode_return = 0.0
    for _ in range(max_steps):
        agent.global_step += 1
        rewards, batch_dones = agent.train()
        episode_return += rewards.mean().item()
        done_mask = batch_dones == True  # noqa: E712 -- elementwise tensor comparison
        if done_mask.any():
            agent.writer.add_scalar(
                "charts/num_env_done", int(done_mask.sum().item()), agent.global_step
            )
        if done_mask.all():
            break
    return episode_return


def _evaluate_on_test_env(agent, test_env, device, n_images):
    """Run one action on the test env and log its mean reward and sample images.

    The backbone is switched to eval mode for the duration and always restored
    to train mode afterwards, even if evaluation raises.
    """
    agent.backbone.eval()
    try:
        with torch.no_grad():
            obs = test_env.reset()
            # NOTE(review): element 0 of get_action's output is assumed to be
            # the action batch (as in the original code) — confirm.
            actions = agent.actor.get_action(obs.to(device))
            _, rewards, _ = test_env.step(actions[0].cpu())
            step = agent.global_step
            agent.writer.add_scalar(
                "charts/test_mean_episodic_return", rewards.mean().item(), step
            )
            # One tag per image kind, stamped with the training step, so that
            # successive evaluations form a timeline instead of all landing on
            # steps 0/1/2 of a single tag.
            for key in ("source_image", "enhanced_image", "target_image"):
                agent.writer.add_images(f"test_images/{key}", test_env.state[key][:n_images], step)
    finally:
        agent.backbone.train()


next_eval_step = (agent.global_step // EVAL_INTERVAL + 1) * EVAL_INTERVAL
for _ in range(sac_config.total_timesteps):
    episode_return = _run_training_episode(agent, sac_config.max_episode_timesteps)
    # Logged once per finished episode, so the value is a full-episode return.
    agent.writer.add_scalar("charts/mean_episodic_return", episode_return, agent.global_step)

    # Episodes can end early on arbitrary steps, so check whether an interval
    # boundary was crossed rather than requiring an exact multiple.
    if agent.global_step >= next_eval_step:
        _evaluate_on_test_env(agent, test_env, sac_config.device, N_TEST_IMAGES)
        next_eval_step = (agent.global_step // EVAL_INTERVAL + 1) * EVAL_INTERVAL
            

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False])
tensor([])
tensor([-0.0855, -0.2336, -0.1189, -0.1057, -0.0802, -0.1202, -0.0797, -0.2552,
        -0.0652, -0.1022, -0.0857, -0.0501, -0.2029, -0.2656, -0.1410, -0.0995,
        -0.0713, -0.1752, -0.0602, -0.0384, -0.3845, -0.2845, -0.2289, -0.0764,
        -0.2044, -0.0637, -0.0520, -0.1045, -0.0801, -0.1407, -0.1775, -0.1949])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False])
tensor([])
tensor([-0.1070, -0.2030, -0.1092, -0.1119, -0.1349, -0.1120, -0.1036, -0.0789,
        -0.0921, -0.0836, -0.0633, -0.1977, -0.1084, -0.0573, -

In [None]:
import torch

def sample_near_values_batch(tensor, batch_size, std_dev=0.05, clip_min=0.0, clip_max=1.0):
    """
    Generate a batch of noisy samples centred on ``tensor``.

    Each sample is ``tensor`` plus i.i.d. Gaussian noise with standard
    deviation ``std_dev``, clamped to ``[clip_min, clip_max]``.

    Args:
    tensor (torch.Tensor): The input tensor to sample near. Any rank is
        accepted; a 0-d tensor is treated as a length-1 vector.
    batch_size (int): The number of samples to generate (must be >= 0).
    std_dev (float): Standard deviation for the normal distribution.
    clip_min (float): Minimum value to clip the result.
    clip_max (float): Maximum value to clip the result.

    Returns:
    torch.Tensor: Tensor of shape ``(batch_size, *tensor.shape)`` (or
    ``(batch_size, 1)`` for a 0-d input) holding the sampled values.

    Raises:
    ValueError: If ``batch_size`` is negative.
    """
    if batch_size < 0:
        # expand() reads -1 as "keep this size", which would silently return
        # a single sample instead of failing.
        raise ValueError(f"batch_size must be non-negative, got {batch_size}")

    # Preserve the historical (batch_size, 1) result for 0-d inputs.
    if tensor.dim() == 0:
        tensor = tensor.reshape(1)

    # Prepend the batch dimension without copying; works for inputs of any rank
    # (the previous expand(batch_size, -1) only accepted 1-D tensors).
    batched_tensor = tensor.unsqueeze(0).expand(batch_size, *tensor.shape)

    # Independent Gaussian noise for every element of the batch.
    noise = torch.randn_like(batched_tensor) * std_dev

    # Keep the perturbed values inside the requested range.
    return torch.clamp(batched_tensor + noise, clip_min, clip_max)

# Example: jitter a known 12-value vector to get a few nearby candidates.
original_tensor = torch.tensor(
    [0.125, 0.125, 0.375, 0.125, 0.0, 0.0625, 0.9375, 0.375, 0.0625, 0.0, 0.125, 0.125]
)
batch_size = 5  # number of noisy samples to draw

sampled_batch = sample_near_values_batch(original_tensor, batch_size=batch_size)
sampled_batch

In [None]:
import cv2
import torch
import matplotlib.pyplot as plt
def _load_rgb_image(path, size=(64, 64)):
    """Read ``path`` with OpenCV and return it as an RGB float array in [0, 1].

    The image is converted from OpenCV's BGR channel order to RGB and resized
    to ``size`` (width, height) before scaling by 1/255.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; without this check the failure surfaces in cvtColor.
        raise FileNotFoundError(f"Could not read image: {path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return cv2.resize(image_rgb, size) / 255.0


source_image = _load_rgb_image("sample_images/a0001-jmac_DSC1459.jpg")
target_image = _load_rgb_image("sample_images/a0676-kme_609_C.jpg")

# Channels-first batch of one, (1, 3, 64, 64). NOTE: shadows the builtin input().
input = torch.Tensor(source_image).permute(2, 0, 1).unsqueeze(0)

# Working copy that later cells replace with edited versions.
enhanced_image = input.clone()

In [None]:
# Expected (1, 3, 64, 64): one RGB image, channels-first, resized to 64x64 when loaded.
input.shape

In [None]:
# Target photo after RGB conversion and resizing to 64x64.
plt.imshow(target_image)
plt.title("Target image (64x64)")
plt.axis("off");

In [None]:
# Source photo after RGB conversion and resizing to 64x64.
plt.imshow(source_image)
plt.title("Source image (64x64)")
plt.axis("off");

In [None]:
from envs.features_extractor import ResnetEncoder
from envs.new_edit_photo import PhotoEditor
import matplotlib.pyplot as plt
# photo_editor applies a parameter vector to an image (called below as
# photo_editor(image, parameters)); image_encoder.encode produces the features
# that are concatenated into the actor's input.
photo_editor = PhotoEditor()
image_encoder = ResnetEncoder()

In [None]:
# Number of agent.train() calls so far; the training loop increments it once per call.
agent.global_step

In [None]:
# Encode the fixed source image once. This cell only runs inference, so skip
# autograd bookkeeping (the refinement cell encodes under no_grad as well).
with torch.no_grad():
    encoded_source = image_encoder.encode(input)

# Channels-last (1, 64, 64, 3) copy of the source for photo_editor.
original_64 = input.permute(0, 2, 3, 1)

# NOTE(review): source_image was already resized to 64x64 when loaded, so this
# holds the same pixels as original_64 rather than a full-resolution image.
original_image = torch.Tensor(source_image).unsqueeze(0)

In [None]:
# Hand-picked 12-value vector as a batch of one, shape (1, 12)
# (presumably a photo_editor parameter setting — confirm).
param = torch.tensor(
    [[0.125, 0.125, 0.375, 0.125, 0.0, 0.0625, 0.9375, 0.375, 0.0625, 0.0, 0.125, 0.125]]
)

In [None]:
# Reset the working image to the source in channels-last (1, 64, 64, 3) layout.
# Derived from `input` instead of permuting `enhanced_image` in place, so
# re-running this cell (e.g. after the refinement loop) cannot scramble the axes.
enhanced_image = input.clone().permute(0, 2, 3, 1)

In [None]:
# Closed-loop refinement: encode (source, current enhancement), ask the actor
# for parameters, and re-apply them to the original 64x64 source.
N_REFINE_STEPS = 1
assert N_REFINE_STEPS >= 1, "at least one step is needed to define `parameters`"

with torch.no_grad():
    for _ in range(N_REFINE_STEPS):
        # The working image is channels-last; the encoder takes channels-first
        # like `input`, hence the permute.
        encoded_enhanced = image_encoder.encode(enhanced_image.permute(0, 3, 1, 2))
        encoded_input = torch.cat([encoded_source, encoded_enhanced], dim=1)
        parameters = agent.actor.get_action(encoded_input)
        enhanced_image = photo_editor(original_64.cpu(), parameters[0].cpu())
    # NOTE(review): despite the "_512" name, original_image is 64x64 as well.
    enhanced_image_512 = photo_editor(original_image.cpu(), parameters[0].cpu())
parameters

In [None]:
# Shape of the latest photo_editor output for the 64x64 source; presumably
# channels-last (N, H, W, C), since it is permuted (0, 3, 1, 2) before encoding — confirm.
enhanced_image.shape

In [None]:
# Per-image error between the edited image and the unedited source.
# NOTE(review): despite the name `target`, both tensors come from the source
# photo (original_image); use target_image instead if that was the intent.
enhanced = torch.flatten(enhanced_image_512, start_dim=1)
target = torch.flatten(original_image, start_dim=1)

mse = (enhanced - target).pow(2).mean(dim=1)
# The value named `rmse` previously held the MSE; take the square root so
# the number matches its name (`mse` keeps the squared error).
rmse = torch.sqrt(mse)

In [None]:
# Per-image error from the cell above: one value per batch row, averaged over dimension 1.
rmse

In [None]:
# Clamp to imshow's valid float range for RGB data; out-of-range values would
# otherwise be clipped with a warning.
plt.imshow(enhanced_image_512[0].clamp(0.0, 1.0))
plt.title("Enhanced image (photo_editor applied to original_image)")
plt.axis("off");