<a href="https://colab.research.google.com/github/sdelta/ImageGen/blob/main/stylegan2_clip_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (Colab-only; this cell needs the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3 open_clip_torch

In [None]:
# Download the pretrained checkpoint ffhq-res256-mirror-paper256-noaug.pkl into the
# working directory; the training command at the end of the notebook passes it to --resume.
!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl

In [None]:
# Copy the dataset archive into the working directory.
# NOTE(review): the relative path assumes the working directory is /content, where Drive was mounted.
!cp drive/MyDrive/datasets/ffhq_256/ffhq.zip ./

In [None]:
# List the working directory to confirm the files fetched above are present.
!ls

In [None]:
from PIL import Image
import requests

# Sample photo from the COCO val2017 set, used below to sanity-check CLIP.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Bound the wait and surface HTTP errors here, instead of hanging forever or
# letting PIL fail later with an unhelpful "cannot identify image file" error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

In [None]:
import torch
import open_clip
from PIL import Image
import numpy as np

# Select the compute device; the notebook is GPU-only, so stop right away without CUDA.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
assert device == "cuda"

# Build the open_clip model together with its inference transform (the second
# returned value is not used in this notebook) and move the model to the GPU.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32-quickgelu',
    pretrained='laion400m_e32',
)
model = model.to(device)

# Tokenizer matching the model architecture chosen above.
tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')

In [None]:
# Toy inputs for a CLIP sanity check: one image scored against two candidate captions.
src_images = [image]
src_texts = ["cat", "dog"]
# Stack the per-image tensors directly with torch instead of round-tripping
# through NumPy (np.stack -> torch.tensor copies the data twice).
# Assumes `preprocess` returns a torch.Tensor per image - TODO confirm.
images = torch.stack([preprocess(img) for img in src_images]).to(device)
texts = tokenizer(src_texts).to(device)

In [None]:
# Encode the captions and L2-normalise each embedding along the last dimension.
# This is inference only, so autograd is disabled to avoid building a graph
# (same pattern CLIPSubloss.__init__ uses for its prompt embedding).
with torch.no_grad():
    texts_features = model.encode_text(texts)
    texts_features /= texts_features.norm(dim=-1, keepdim=True)

In [None]:
# Encode the image batch and L2-normalise each embedding along the last dimension.
# This is inference only, so autograd is disabled to avoid building a graph.
with torch.no_grad():
    image_features = model.encode_image(images)
    image_features /= image_features.norm(dim=-1, keepdim=True)

In [None]:
# Dot products of the normalised embeddings: one row per text, one column per image.
sim = texts_features @ image_features.t()

In [None]:
# Inspect the shape of the preprocessed image batch.
images.shape

In [None]:
# Display the text-by-image similarity matrix computed above.
sim

In [None]:
from torch.nn import functional as tfn
from torchvision import transforms


def normalize(x, mean, std):
    """Standardise ``x`` by subtracting ``mean`` and dividing by ``std``.

    ``mean`` and ``std`` each get two singleton axes inserted after their
    first dimension, so 1-D per-channel statistics of length ``C`` broadcast
    against the trailing ``(C, H, W)`` dimensions of ``x``.

    Args:
        x: tensor to standardise.
        mean: per-channel offsets.
        std: per-channel scales.

    Returns:
        A new tensor ``(x - mean) / std``; ``x`` is not modified.
    """
    channel_mean = mean[:, None, None]
    channel_std = std[:, None, None]
    centred = x - channel_mean
    return centred / channel_std

# Re-create CLIP preprocessing with plain tensor ops so it stays differentiable:
# PIL image -> float tensor with a leading batch axis, on the GPU.
start = transforms.ToTensor()(image)[None].to(device)
# Bicubic resize of the spatial dimensions to 224 x 224.
sized = tfn.interpolate(start, size=224, mode='bicubic')
# Standardise each channel with the mean/std constants exported by open_clip.
normed = normalize(
    sized,
    mean=torch.tensor(open_clip.OPENAI_DATASET_MEAN).to(device),
    std=torch.tensor(open_clip.OPENAI_DATASET_STD).to(device),
)

In [None]:
# Exact (shape and element-wise) comparison of the library preprocessing result
# with the hand-written tensor pipeline.
# NOTE(review): the steps inside `preprocess` are not visible here; any difference in
# resizing or cropping makes this False.
torch.equal(images, normed)

In [None]:
import matplotlib.pyplot as plt

# Show the first preprocessed image; matplotlib wants channels last, hence the permute.
plt.imshow(images[0].permute(1, 2, 0).cpu())

In [None]:
# Show the hand-normalised image for visual comparison (channels moved last for matplotlib).
# Values are standardised, so anything outside [0, 1] is clipped by imshow.
plt.imshow(normed[0].permute(1, 2, 0).cpu())

In [None]:
# torchvision pipeline: bicubic resize to 224, then a 224 x 224 centre crop.
tr_lst = [
    transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=224),
]
my_preprocess = transforms.Compose(tr_lst)

In [None]:
# Apply only the Resize step (tr_lst[0]) to the un-normalised image tensor and show the result.
plt.imshow(tr_lst[0](start[0]).cpu().permute(1, 2, 0))

In [None]:
# Inspect the shape of the image embedding batch.
image_features.shape

In [None]:
# Detached view of the raw image tensor (shares storage with `start`) that tracks
# gradients, so a later backward pass can populate its .grad.
# NOTE(review): the name `input` shadows the Python builtin; a later cell reads it by this name.
input = start.detach()
input.requires_grad_(True)

In [None]:
class CLIPSubloss(object):
    """Scores image batches against one fixed text prompt using CLIP.

    The prompt embedding is computed once in ``__init__``; ``get_similarities``
    then compares any image batch with it using tensor operations only.

    NOTE(review): this class uses the notebook-level ``model`` (the open_clip
    model created earlier) instead of loading its own copy, so that cell must
    have been run first.
    """

    # Added to the embedding norm so the division never hits exactly zero.
    _NORM_EPS = 1e-5

    def __init__(self, device, clip_phrase):
        """Encode ``clip_phrase`` and cache everything that never changes.

        Args:
            device: device the model, the prompt embedding and the
                normalisation constants are placed on.
            clip_phrase: text prompt the images will be compared with.
        """
        self.device = device
        self.model = model.to(device)
        tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')
        # The prompt is a constant, so no autograd graph is needed for it.
        with torch.no_grad():
            self.texts_features = self.model.encode_text(tokenizer([clip_phrase]).to(device))
            self.texts_features /= self.texts_features.norm(dim=-1, keepdim=True)
        # Per-channel statistics shaped (C, 1, 1) for broadcasting; built once
        # here instead of being re-created and copied to the device on every call.
        self._mean = torch.tensor(open_clip.OPENAI_DATASET_MEAN).to(device).unsqueeze(1).unsqueeze(2)
        self._std = torch.tensor(open_clip.OPENAI_DATASET_STD).to(device).unsqueeze(1).unsqueeze(2)

    def _preprocess_images(self, images):
        """Resize ``images`` to 224 x 224 (bicubic) and standardise each channel.

        Assumes ``images`` is a channels-first batch on ``self.device`` whose
        channel count matches the open_clip mean/std constants - TODO confirm
        the value range expected by the caller.
        """
        resized = torch.nn.functional.interpolate(images, size=224, mode='bicubic')
        return (resized - self._mean) / self._std

    def get_similarities(self, images):
        """Return the similarity of every image to the stored prompt.

        Args:
            images: batch of images accepted by ``_preprocess_images``.

        Returns:
            ``texts_features @ normalised_image_features^T``; with the single
            prompt encoded in ``__init__`` this is expected to have one row
            and one column per image.
        """
        images_features = self.model.encode_image(self._preprocess_images(images))
        images_norm = images_features.norm(dim=-1, keepdim=True) + self._NORM_EPS
        return torch.matmul(self.texts_features, (images_features / images_norm).permute(1, 0))

# Smoke test: score the gradient-tracking image against the prompt "glasses" and
# back-propagate, with anomaly detection on so a failing backward op is reported
# together with the forward call that produced it.
clip_subloss = CLIPSubloss(device=device, clip_phrase="glasses")

with torch.autograd.set_detect_anomaly(True):
    gen_clip = clip_subloss.get_similarities(input)
    (gen_clip.mean() * 4).backward()


In [None]:
# Display the transform object returned by open_clip to inspect its steps.
preprocess

In [None]:
# Delete any existing stylegan2-ada-pytorch directory; -f keeps this quiet when there is none.
! rm -fR stylegan2-ada-pytorch

In [None]:
# Clone the sdelta fork of stylegan2-ada-pytorch; its train.py is run in the next cell.
!git clone https://github.com/sdelta/stylegan2-ada-pytorch.git

In [None]:
# Launch training from the cloned fork: data comes from ffhq.zip, training resumes from the
# downloaded ffhq-res256 checkpoint, and results are written to the mounted Drive folder.
# NOTE(review): --freezed_mapping, --clip_phrase and --clip_reg_interval are handled by the
# fork's train.py, which is not visible here - confirm their semantics there.
!python stylegan2-ada-pytorch/train.py --outdir=drive/MyDrive/stylegan_finetuning --data=ffhq.zip \
    --mirror=1 --gpus=1 --resume=ffhq-res256-mirror-paper256-noaug.pkl --kimg=1500 --cfg=paper256 \
    --freezed=10 --freezed_mapping=True \
    --clip_phrase='glasses' --clip_reg_interval=4