# Quantitative evaluation

Use ArFace metric and Emotion metric to compare the performance of : 
    - e4e initialization vs hybrid initialization
    - og loss vs custom loss

In [8]:
from argparse import Namespace
import time
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.amp import autocast
from tqdm import tqdm
import clip
import torch.nn.functional as F

from utils.common import tensor2im
from utils.alignment import run_alignment
from models.psp import pSp  # we use the pSp framework to load the e4e encoder.
from criteria.clip_loss import CLIPLoss
from criteria.id_loss import IDLoss

%load_ext autoreload
%autoreload 2
%matplotlib inline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
RESIZE_DIMS = (256, 256)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Define the arguments for the e4e encoder

pretrained_pSp_path = 'pretrained_models/e4e_ffhq_encode.pt'

e4e_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
)

In [3]:
# Load the pSp model

ckpt = torch.load(pretrained_pSp_path, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = pretrained_pSp_path
opts['ir_se50_weights'] = 'pretrained_models/model_ir_se50.pth'
opts= Namespace(**opts)
latent_avg = ckpt["latent_avg"]

psp_model = pSp(opts)
psp_model.eval()
psp_model.cuda()

print('Model successfully loaded!')

  ckpt = torch.load(pretrained_pSp_path, map_location='cpu')


Loading e4e over the pSp framework from checkpoint: pretrained_models/e4e_ffhq_encode.pt


  ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')


Model successfully loaded!


In [None]:
# Load the emotion model

from emotionmmodel.networks.DDAM import DDAMNet
emotion_model = DDAMNet(num_class=7, num_head=2, pretrained=False)
path = "pretrained_models/affecnet7_epoch19_acc0.671.pth"
checkpoint = torch.load(path, map_location=device)
emotion_model.load_state_dict(checkpoint['model_state_dict'])
emotion_model.to(device)
emotion_model.eval()    
val_transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])]) 

  checkpoint = torch.load(path, map_location=device)


In [24]:
import os

input_images_path = "faces_dataset_very_small"
N_images = len(os.listdir(input_images_path))
images = np.array(os.listdir(input_images_path))[np.random.randint(0, len(os.listdir(input_images_path)), N_images)]


In [26]:
# Parameters for inversion
optimization_steps_inv = 50
lambda_percept_inv = 1
lambda_L2_inv = 0.5
lr_inv = 0.01 
final_lr_inv = 0.001

# Parameters for editing
optimization_steps_edit = 50
lambda_L2_edit = 0.01
lambda_ID_edit = 0.01
lambda_E_edit = 0.01
lr_edit = 0.01
final_lr_edit = 0.001
truncation_edit = 0.9

id_loss = IDLoss(opts)
L2_loss = torch.nn.MSELoss().to(device)
clip_loss = CLIPLoss(opts)
emotion_loss = torch.nn.CrossEntropyLoss()

G = psp_model.decoder
prompt = "A person smiling"
# class7_names = ['Neutral', 'Happy', 'Sad', 'Surprise', 'Fear', 'Disgust', 'Angry']
goal_distribution = torch.tensor([[0., 1., 0., 0., 0., 0., 0.]], device=device)

Loading ResNet ArcFace


In [27]:
sims_e4e = []
sims_hybrid = []
sims_e4e_edit = []
sims_hybrid_edit = []
emo_scores_e4e = []
emo_scores_hybrid = []

for i_image, image_path in enumerate(images):
    print(f"*** Processing image {i_image + 1}/{N_images} ***")
    image = run_alignment(os.path.join(input_images_path, image_path)).resize(RESIZE_DIMS)
    input_image = e4e_transform(image)
    reference_image = transforms.Resize((1024, 1024))(input_image).unsqueeze(0).cuda()


    with torch.no_grad():
        e4e_inverted_latent = psp_model.encoder(input_image.unsqueeze(0).to(device))
    if psp_model.opts.start_from_latent_avg:
                    if e4e_inverted_latent.ndim == 2:
                        e4e_inverted_latent = e4e_inverted_latent + psp_model.latent_avg.repeat(e4e_inverted_latent.shape[0], 1, 1)[:, 0, :]
                    else:
                        e4e_inverted_latent = e4e_inverted_latent + psp_model.latent_avg.repeat(e4e_inverted_latent.shape[0], 1, 1)


    w_n = e4e_inverted_latent.clone().detach().requires_grad_(True).to(device) # initialization with e4e output

    optimizer = torch.optim.Adam([w_n], lr=lr_inv)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=optimization_steps_inv, eta_min=final_lr_inv)

 

    pbar = tqdm(range(optimization_steps_inv))

    imgs = []
    losses = []


    for i in pbar:

        with autocast("cuda"):
            
            optimizer.zero_grad()

            img_gen, _ = G([w_n], input_is_latent=True, randomize_noise=False)

            i_loss = id_loss(img_gen, reference_image)[0]
            l2_loss = L2_loss(img_gen, reference_image)

            loss = i_loss * lambda_percept_inv + l2_loss * lambda_L2_inv 
            losses.append(loss.item())
            pbar.set_description(f"Loss: {loss.item():.4f}")

            loss.backward()
            optimizer.step()
            scheduler.step()

    inverted_latent = w_n.detach()

    # Compute ID metric between inverted images and original image

    with torch.no_grad():
        e4e_inverted_image, _ = G([e4e_inverted_latent], input_is_latent=True, randomize_noise=False)
        inverted_image = img_gen.detach()

    sim_e4e = id_loss(e4e_inverted_image, reference_image)[0]
    sim_hybrid = id_loss(inverted_image, reference_image)[0]

    sims_e4e.append(sim_e4e.item())
    sims_hybrid.append(sim_hybrid.item())

    ##### Editing evaluation #####

    text_inputs = torch.cat([clip.tokenize(prompt)]).cuda()

    for latent, sims, emo_scores in zip([e4e_inverted_latent, inverted_latent], [sims_e4e_edit, sims_hybrid_edit], [emo_scores_e4e, emo_scores_hybrid]):

    # Initialization
        w_start = latent

        w_n = w_start.clone().detach().requires_grad_(True).to(device)
        optimizer = torch.optim.Adam([w_n], lr=lr_edit)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=optimization_steps_edit, eta_min=final_lr_edit)


        with torch.no_grad():
            img_orig, _ = G([w_n], input_is_latent=True)

        pbar = tqdm(range(optimization_steps_edit))

        imgs = []
        losses = []

        # Optimization loop

        for i in pbar:

            with autocast("cuda"):
                
                optimizer.zero_grad()

                img_gen, _ = G([w_n], input_is_latent=True, randomize_noise=False)

                c_loss = clip_loss(img_gen, text_inputs)
                i_loss = id_loss(img_gen, img_orig)[0] # original
                l2_loss = ((w_start - w_n) ** 2).sum()

                img_for_emotion = F.interpolate(img_gen, size=(112, 112), mode='bilinear', align_corners=False)
                prediction, feat, heads = emotion_model(img_for_emotion)
                e_loss = emotion_loss(prediction, goal_distribution)

                loss = c_loss + lambda_L2_edit * l2_loss + lambda_ID_edit * i_loss + lambda_E_edit * e_loss
                losses.append(loss.item())
                pbar.set_description(f"Loss: {loss.item():.4f}")

                loss.backward()
                optimizer.step()
                scheduler.step()

        final_image = img_gen.detach()

        # Compute ID metric between edited images and original image

        sim = id_loss(final_image, reference_image)[0]
        emo_score = e_loss # TODO: compute emotion score

        sims.append(sim.item())
        emo_scores.append(emo_score.item())

*** Processing image 1/65 ***


Loss: 0.0310: 100%|██████████| 50/50 [00:08<00:00,  6.11it/s]
Loss: 0.7192: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7134: 100%|██████████| 50/50 [00:16<00:00,  3.04it/s]


*** Processing image 2/65 ***


Loss: 0.0587: 100%|██████████| 50/50 [00:08<00:00,  6.09it/s]
Loss: 0.7222: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7236: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 3/65 ***


Loss: 0.0232: 100%|██████████| 50/50 [00:08<00:00,  6.06it/s]
Loss: 0.7461: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7422: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 4/65 ***


Loss: 0.0211: 100%|██████████| 50/50 [00:08<00:00,  6.05it/s]
Loss: 0.7354: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7354: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 5/65 ***


Loss: 0.0516: 100%|██████████| 50/50 [00:08<00:00,  6.02it/s]
Loss: 0.7192: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7183: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 6/65 ***


Loss: 0.0283: 100%|██████████| 50/50 [00:08<00:00,  6.02it/s]
Loss: 0.7549: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7261: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 7/65 ***


Loss: 0.0328: 100%|██████████| 50/50 [00:08<00:00,  6.02it/s]
Loss: 0.7319: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7319: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 8/65 ***


Loss: 0.0377: 100%|██████████| 50/50 [00:08<00:00,  6.01it/s]
Loss: 0.7407: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
Loss: 0.7480: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 9/65 ***


Loss: 0.0532: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7119: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7080: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 10/65 ***


Loss: 0.0196: 100%|██████████| 50/50 [00:08<00:00,  6.01it/s]
Loss: 0.7446: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7529: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 11/65 ***


Loss: 0.0390: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7227: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7207: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 12/65 ***


Loss: 0.0231: 100%|██████████| 50/50 [00:08<00:00,  6.01it/s]
Loss: 0.7456: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7422: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 13/65 ***


Loss: 0.0393: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7397: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7329: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 14/65 ***


Loss: 0.0456: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7095: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7070: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 15/65 ***


Loss: 0.0287: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7212: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7114: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


*** Processing image 16/65 ***


Loss: 0.0379: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7217: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7192: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 17/65 ***


Loss: 0.0387: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7061: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]
Loss: 0.7056: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 18/65 ***


Loss: 0.0326: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7153: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7329: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 19/65 ***


Loss: 0.0485: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7456: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7363: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 20/65 ***


Loss: 0.0487: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7451: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7363: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 21/65 ***


Loss: 0.0377: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7090: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7100: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 22/65 ***


Loss: 0.0413: 100%|██████████| 50/50 [00:08<00:00,  6.01it/s]
Loss: 0.7578: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7554: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


*** Processing image 23/65 ***


Loss: 0.0373: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7412: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7456: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 24/65 ***


Loss: 0.0373: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7412: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7456: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 25/65 ***


Loss: 0.0257: 100%|██████████| 50/50 [00:08<00:00,  6.01it/s]
Loss: 0.7310: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7227: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 26/65 ***


Loss: 0.0273: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7168: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7041: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 27/65 ***


Loss: 0.0254: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7305: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7407: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 28/65 ***


Loss: 0.0624: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7676: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7803: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 29/65 ***


Loss: 0.0326: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7236: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7178: 100%|██████████| 50/50 [00:16<00:00,  2.98it/s]


*** Processing image 30/65 ***


Loss: 0.0414: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7920: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7598: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 31/65 ***


Loss: 0.0623: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7656: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7798: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 32/65 ***


Loss: 0.0294: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7305: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7275: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 33/65 ***


Loss: 0.0455: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7090: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7070: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 34/65 ***


Loss: 0.0339: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7158: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7212: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 35/65 ***


Loss: 0.0299: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7314: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7275: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 36/65 ***


Loss: 0.0427: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7163: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7168: 100%|██████████| 50/50 [00:16<00:00,  2.99it/s]


*** Processing image 37/65 ***


Loss: 0.0620: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7681: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7798: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 38/65 ***


Loss: 0.0252: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7310: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7417: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 39/65 ***


Loss: 0.0328: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7329: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7314: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 40/65 ***


Loss: 0.0263: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7432: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7417: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 41/65 ***


Loss: 0.0326: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7236: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7178: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 42/65 ***


Loss: 0.0572: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7222: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7222: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


*** Processing image 43/65 ***


Loss: 0.0348: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
Loss: 0.7344: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7319: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


*** Processing image 44/65 ***


Loss: 0.0390: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]
Loss: 0.7231: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7197: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 45/65 ***


Loss: 0.0298: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7119: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
Loss: 0.7007: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


*** Processing image 46/65 ***


Loss: 0.0453: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7280: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
Loss: 0.7188: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


*** Processing image 47/65 ***


Loss: 0.0379: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7227: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
Loss: 0.7192: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 48/65 ***


Loss: 0.0298: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7129: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7002: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 49/65 ***


Loss: 0.0267: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7734: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7715: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 50/65 ***


Loss: 0.0233: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7456: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7339: 100%|██████████| 50/50 [00:16<00:00,  2.99it/s]


*** Processing image 51/65 ***


Loss: 0.0325: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7246: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7178: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 52/65 ***


Loss: 0.0326: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7246: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7178: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 53/65 ***


Loss: 0.0263: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7432: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7422: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 54/65 ***


Loss: 0.0333: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7217: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7129: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 55/65 ***


Loss: 0.0287: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7554: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7261: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 56/65 ***


Loss: 0.0391: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7217: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7212: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 57/65 ***


Loss: 0.0396: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7397: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7314: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 58/65 ***


Loss: 0.0516: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7188: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
Loss: 0.7188: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 59/65 ***


Loss: 0.0516: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7192: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Loss: 0.7183: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 60/65 ***


Loss: 0.0347: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.6924: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.6875: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 61/65 ***


Loss: 0.0341: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7212: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7373: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 62/65 ***


Loss: 0.0288: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7212: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7119: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


*** Processing image 63/65 ***


Loss: 0.0342: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]
Loss: 0.7539: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
Loss: 0.7485: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


*** Processing image 64/65 ***


Loss: 0.0433: 100%|██████████| 50/50 [00:08<00:00,  5.97it/s]
Loss: 0.7227: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Loss: 0.7202: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]


*** Processing image 65/65 ***


Loss: 0.0334: 100%|██████████| 50/50 [00:08<00:00,  6.02it/s]
Loss: 0.7217: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]
Loss: 0.7124: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


In [28]:
# Save results

save_path = "results"
os.makedirs(save_path, exist_ok=True)

SAVE_RESULTS = True

if SAVE_RESULTS:
    np.save(os.path.join(save_path, "sims_e4e.npy"), sims_e4e)
    np.save(os.path.join(save_path, "sims_hybrid.npy"), sims_hybrid)
    np.save(os.path.join(save_path, "sims_e4e_edit.npy"), sims_e4e_edit)
    np.save(os.path.join(save_path, "sims_hybrid_edit.npy"), sims_hybrid_edit)
    np.save(os.path.join(save_path, "emo_scores_e4e.npy"), emo_scores_e4e)

In [None]:
# Load results

LOAD_RESULTS = True

if LOAD_RESULTS:

    sims_e4e = np.load(os.path.join(save_path, "sims_e4e.npy"))
    sims_hybrid = np.load(os.path.join(save_path, "sims_hybrid.npy"))
    sims_e4e_edit = np.load(os.path.join(save_path, "sims_e4e_edit.npy"))
    sims_hybrid_edit = np.load(os.path.join(save_path, "sims_hybrid_edit.npy"))
    emo_scores_e4e = np.load(os.path.join(save_path, "emo_scores_e4e.npy"))

In [32]:
import plotly.express as px
import pandas as pd
import numpy as np

data = {
    "inversion type": ["e4e"] * len(sims_e4e) * 2 + ["hybrid"] * len(sims_hybrid) * 2,
    "operation": ["inversion"] * len(sims_e4e) + ["editing"] * len(sims_e4e) + ["inversion"] * len(sims_hybrid) + ["editing"] * len(sims_hybrid),
    "dissimilarity": (
        sims_e4e +  # Distribution A
        sims_e4e_edit +  # Distribution A
        sims_hybrid + # Distribution B
        sims_hybrid_edit  # Distribution B
    )
}

df = pd.DataFrame(data)

fig = px.box(df, x="inversion type", y="dissimilarity", color="operation")
fig.update_traces(quartilemethod="exclusive") # or "inclusive", or "linear" by default
fig.update_layout(title="Impact of inversion type on editing performance")

fig.show()



In [33]:
import plotly.express as px
import pandas as pd
import numpy as np

data = {
    "inversion type": ["e4e"] * len(sims_e4e) + ["hybrid"] * len(sims_hybrid),
    "emotion loss": (
        emo_scores_e4e + 
        emo_scores_hybrid
    )
}

df = pd.DataFrame(data)

fig = px.box(df, x="inversion type", y="emotion loss", color="inversion type")
fig.update_traces(quartilemethod="exclusive") # or "inclusive", or "linear" by default
fig.update_layout(title="Impact of inversion type on editing performance")

fig.show()