In [None]:
import torch
import numpy as np
import pickle
import os
from PIL import Image
import dnnlib, legacy
import clip
import matplotlib.pyplot as plt

In [None]:
class Generator:
    """Thin wrapper around a generator network loaded from a network pickle.

    The pickle is read with `legacy.load_network_pkl`; its `G_ema` network is used
    as the model and its `D` network is kept on `self.D`. Helpers are provided to
    run the generator, to synthesize from precomputed styles, and to convert the
    output tensor into a PIL image.
    """

    def __init__(self, device, path):
        self.name = 'generator'
        self.model = self.load_model(device, path)
        self.device = device
        # Forwarded to the network as `force_fp32` on every call.
        self.force_32 = False

    def load_model(self, device, path):
        """Load the network pickle at `path` (anything `dnnlib.util.open_url` accepts).

        Moves `G_ema` and `D` to `device`, stores them on `self.G_ema` / `self.D`,
        and returns `G_ema`, which becomes `self.model`.
        """
        with dnnlib.util.open_url(path) as f:
            network = legacy.load_network_pkl(f)
            self.G_ema = network['G_ema'].to(device)
            # NOTE(review): D is not used by the generation helpers below; it only
            # occupies device memory. Kept for callers that may access `.D`.
            self.D = network['D'].to(device)
            return self.G_ema

    def generate(self, z, c, fts, noise_mode='const', return_styles=True):
        """Run the generator forward pass on latent `z`, label `c` and text features `fts`.

        Returns whatever the network returns. With `return_styles=True`, callers in
        this notebook unpack the result as `(img, styles)`.
        """
        return self.model(z, c, fts=fts, noise_mode=noise_mode,
                          return_styles=return_styles, force_fp32=self.force_32)

    def generate_from_style(self, style, noise_mode='const'):
        """Synthesize an image directly from precomputed `style` vectors.

        A random `ws` tensor is still passed to `synthesis`. It is now created on
        `self.device`, the device the model was moved to. Previously it was created
        on the CPU, which mismatches a CUDA model.
        """
        # Prefer the network's own W dimensionality if it exposes one; 512 was the
        # previously hard-coded value.
        w_dim = getattr(self.model, 'w_dim', 512)
        ws = torch.randn(1, self.model.num_ws, w_dim, device=self.device)
        return self.model.synthesis(ws, fts=None, styles=style, noise_mode=noise_mode,
                                    force_fp32=self.force_32)

    def tensor_to_img(self, tensor):
        """Convert a batch of image tensors into one horizontally tiled PIL image.

        Values are mapped by `(x + 1) * 127.5` and clamped to [0, 255], so inputs
        are presumably in [-1, 1]. The permute below assumes `tensor` is
        (N, C, H, W). TODO confirm with the generator's output.
        """
        pixels = torch.clamp((tensor + 1.) * 127.5, 0., 255.)
        # (N, C, H, W) -> (N, H, W, C); concatenating along W tiles the batch left to right.
        hwc = pixels.permute(0, 2, 3, 1)
        strip = torch.cat(list(hwc), dim=-2)
        return Image.fromarray(strip.detach().cpu().numpy().astype(np.uint8))

In [None]:
# Use the first GPU when available. Fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# TODO: set this to the pickled text-feature mean. It is subtracted from the CLIP
# text features for the 'c21' and 'c3' models in run_generation.
text_mean_path = ...

# SECURITY: pickle.load can execute arbitrary code. Only load files you trust.
with open(text_mean_path, 'rb') as f:
    txt_mean = pickle.load(f).to(device)

def get_ground_truth_image(img_path, placeholder_size=(224, 224)):
    """Load the reference image for a caption as a NumPy array.

    Args:
        img_path: path to the image file. If it is None, missing, or not a regular
            file (e.g. a directory), a placeholder is returned instead of raising.
        placeholder_size: (height, width) of the placeholder. Defaults to 224x224,
            the previously hard-coded size.

    Returns:
        np.ndarray: the decoded image as produced by PIL, or a float64 all-ones
        array of shape (height, width, 3), which `plt.imshow` renders as white.
    """
    if img_path is None or not os.path.isfile(img_path):
        return np.ones((*placeholder_size, 3))
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(img_path) as pil_img:
        return np.array(pil_img)


def _encode_sentence(clip_model, sentence):
    """Encode `sentence` with CLIP and return the L2-normalized text embedding."""
    tokenized_text = clip.tokenize([sentence]).to(device)
    txt_fts = clip_model.encode_text(tokenized_text)
    return txt_fts / txt_fts.norm(dim=-1, keepdim=True)


def _center_text_features(txt_fts):
    """Return a NEW tensor: `txt_fts` minus the global `txt_mean`, re-normalized.

    `txt_fts` itself is left untouched, so models processed later still see the
    raw features.
    """
    # Cast back to the input dtype. The old in-place `-=` kept txt_fts's dtype,
    # and CLIP may return fp16 on GPU while txt_mean could be fp32.
    centered = (txt_fts - txt_mean).to(txt_fts.dtype)
    return centered / centered.norm(dim=-1, keepdim=True)


def _show_row(images, model_types, show_titles=False):
    """Plot `images[model_type]` for each entry of `model_types` side by side in one row."""
    n_panels = len(model_types)
    # About 10 inches per panel. The old height of 40 * n_panels inches produced a
    # huge canvas (60 x 240 inches for six models) for a single row of images.
    fig, axes = plt.subplots(1, n_panels, figsize=(10 * n_panels, 10), squeeze=False)
    for ax, model_type in zip(axes[0], model_types):
        ax.axis('off')
        if show_titles:
            ax.set_title(model_type)
        ax.imshow(images[model_type])
    plt.show()
    # Release the figure; this is typically called once per caption inside a loop.
    plt.close(fig)


def run_generation(model_paths, img_path, sentence, show_titles=False):
    """Generate one image per checkpoint for `sentence` and plot them in a row.

    All checkpoints share the same `z`, drawn from the global torch RNG, so seed
    torch before calling for reproducible results.

    Args:
        model_paths: ordered mapping model_type -> checkpoint path/URL. The key
            'ground_truth' is not loaded; the image at `img_path` fills its slot.
            'c21' and 'c3' receive mean-centered text features; every other model
            receives the raw normalized CLIP features.
        img_path: path of the reference image (a placeholder is shown if missing).
        sentence: caption to condition generation on.
        show_titles: label each panel with its model type. Off by default, which
            matches the previous output.
    """
    with torch.no_grad():
        # NOTE(review): CLIP and every generator are reloaded on each call; cache
        # them if this becomes a bottleneck for many captions.
        clip_model, _ = clip.load("ViT-B/32", device=device)
        clip_model = clip_model.eval()
        txt_fts = _encode_sentence(clip_model, sentence)

        z = torch.randn((1, 512)).to(device)
        c = torch.randn((1, 1)).to(device)  # label is actually not used

        print(f"Image path: {img_path}")
        print(f"Sentence: {sentence}")

        images = {}
        for model_type, path in model_paths.items():
            if model_type == 'ground_truth':
                continue
            # Compute per-model features from the untouched txt_fts. The previous
            # in-place `txt_fts -= txt_mean` leaked centered features into 'c22'
            # and subtracted the mean twice for 'c3'.
            if model_type in ('c21', 'c3'):
                model_fts = _center_text_features(txt_fts)
            else:
                model_fts = txt_fts

            generator = Generator(device=device, path=path)
            img, _ = generator.generate(z=z, c=c, fts=model_fts)
            images[model_type] = generator.tensor_to_img(img)

        images['ground_truth'] = get_ground_truth_image(img_path)

        _show_row(images, list(model_paths), show_titles=show_titles)

In [None]:
# Checkpoints to compare, in display order (left to right).
# 'ground_truth' is a sentinel: no model is loaded, and the reference image is shown instead.
# TODO: fill in the checkpoint paths/URLs (anything dnnlib.util.open_url accepts).
model_paths = {
    'ground_truth': None,
    'lafite_reprod': ...,
    'c1': ...,
    'c21': ...,
    'c22': ...,
    'c3': ...,
}

# Maps reference-image path -> caption to generate from.
# TODO: replace the placeholder entry with real image/caption pairs.
sentences = {
    "path/to/image": "caption"
}

SEED = 1234

for img_path, sentence in sentences.items():
    # Re-seed per caption so each caption starts from the same RNG state,
    # independent of how many captions were processed before it.
    torch.manual_seed(SEED)
    run_generation(model_paths, img_path, sentence)