In [7]:
import os 
os.chdir('../src')

In [9]:
from envs.env_dataloader import create_dataloaders
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torchmetrics.image import StructuralSimilarityIndexMeasure
from envs.new_edit_photo import PhotoEditor
from sac.sac_inference import InferenceAgent
import yaml
from envs.photo_env import PhotoEnhancementEnvTest
import numpy as np
import argparse
import logging
import os
from pathlib import Path
from tqdm import tqdm

In [38]:
class Config(object):
    def __init__(self, dictionary):
        self.__dict__.update(dictionary)

with open(os.path.join("configs/inference_config.yaml")) as f:
    inf_config_dict =yaml.load(f, Loader=yaml.FullLoader)
inference_config = Config(inf_config_dict)

In [39]:
experiments_path = '../experiments/runs/'
models_path = [os.path.join(experiments_path,model_path) for model_path in os.listdir(experiments_path)]

In [40]:
models_path.remove('../experiments/runs/ResNetEncoder__1__1720274282')
models_path.remove('../experiments/runs/ResNetEncoder__1__1719841165')
models_path

['../experiments/runs/ResNetEncoder__solos__2024-07-08_14-39-07',
 '../experiments/runs/ResNetEncoder__fivesliders__2024-07-10_15-57-44',
 '../experiments/runs/ResNetEncoder__fiveslidersaug__2024-07-12_14-59-37',
 '../experiments/runs/ResNetEncoder__sixsliders__2024-07-11_21-02-42',
 '../experiments/runs/ResNetEncoder__allsliders__2024-07-09_16-04-37',
 '../experiments/runs/ResNetEncoder__ninesliders__2024-07-11_22-58-05']

In [41]:

models_names = [name.split('/')[-1] for name in models_path]

In [42]:
models_names

['ResNetEncoder__solos__2024-07-08_14-39-07',
 'ResNetEncoder__fivesliders__2024-07-10_15-57-44',
 'ResNetEncoder__fiveslidersaug__2024-07-12_14-59-37',
 'ResNetEncoder__sixsliders__2024-07-11_21-02-42',
 'ResNetEncoder__allsliders__2024-07-09_16-04-37',
 'ResNetEncoder__ninesliders__2024-07-11_22-58-05']

In [43]:
all_psnrs = []
all_ssims = []
for model_path in models_path:
    with open(os.path.join(model_path,"configs/sac_config.yaml")) as f:
        sac_config_dict =yaml.load(f, Loader=yaml.FullLoader)
    with open(os.path.join(model_path,"configs/env_config.yaml")) as f:
        env_config_dict =yaml.load(f, Loader=yaml.FullLoader)

    sac_config = Config(sac_config_dict)
    env_config = Config(env_config_dict)

    photo_editor = PhotoEditor(env_config.sliders_to_use)

    inference_env = PhotoEnhancementEnvTest(
                        batch_size=inference_config.batch_size,
                        imsize=inference_config.imsize,
                        training_mode=False,
                        done_threshold=inference_config.threshold_psnr,
                        pre_encode=False,
                        edit_sliders=env_config.sliders_to_use,
                        features_size=inference_config.features_size,
                        discretize=env_config.discretize,
                        discretize_step= env_config.discretize_step,
                        logger=None)

    inf_agent =InferenceAgent(inference_env, inference_config)
    inf_agent.load_backbone(os.path.join(model_path,'models','backbone.pth'))
    inf_agent.load_actor_weights(os.path.join(model_path,'models','actor_head.pth'))
    inf_agent.load_critics_weights( os.path.join(model_path,'models','qf1_head.pth'), os.path.join(model_path,'models','qf2_head.pth'))

    ssim_metric = StructuralSimilarityIndexMeasure()
    test_512 = create_dataloaders(batch_size=1,image_size=64,train=False,pre_encode= False,shuffle=False,resize=False)
    transform = transforms.Compose([
                v2.Resize(size = (64,64), interpolation= transforms.InterpolationMode.BICUBIC),
            ])
    PSNRS = []
    SSIM = []
    for i,t in tqdm(test_512, position=0, leave=True):
        source = i/255.0
        target = t/255.0 
        parameters = inf_agent.act(obs=transform(source),deterministic=True)
        enhanced_image = photo_editor((source.permute(0,2,3,1)).cpu(),parameters[2].cpu())
        psnr = inference_env.compute_rewards(enhanced_image.permute(0,3,1,2),target).item()+50
        ssim = ssim_metric(enhanced_image.permute(0,3,1,2),target).item()
        PSNRS.append(psnr)
        SSIM.append(ssim)
    mean_PSNRS = round(np.mean(PSNRS),2)
    mean_SSIM = round(np.mean(SSIM),3)
    all_psnrs.append(mean_PSNRS)
    all_ssims.append(mean_SSIM)

 68%|██████▊   | 338/500 [01:09<00:36,  4.40it/s]

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

# PSNR plot
ax1.bar(model_names, all_psnrs)
ax1.set_title('Mean PSNR for Different Models')
ax1.set_xlabel('Models')
ax1.set_ylabel('PSNR')
ax1.set_xticklabels(model_names, rotation=45, ha='right')

# SSIM plot
ax2.bar(model_names, all_ssims)
ax2.set_title('Mean SSIM for Different Models')
ax2.set_xlabel('Models')
ax2.set_ylabel('SSIM')
ax2.set_xticklabels(model_names, rotation=45, ha='right')

plt.tight_layout()
plt.show()