## 1. Prepare imports 

In [None]:
import os
import glob
import sys

import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

# Setup (may take a few minutes)
# Installs CLIP and other dependencies
# NOTE(review): the notebook never says which file this Google Drive id refers to.
# It is fetched *before* the repository is cloned and before `%cd stylemc`, so it
# lands in the directory the notebook was started from — confirm what it is and
# whether later steps expect it there.
!gdown --id 1kwJndtv5tCd0LEzRTi4NHJ2sJbVZeDLG

# Drop torchtext/torchaudio, then install pinned torch/torchvision CUDA 10.1 wheels.
# NOTE(review): presumably the uninstall avoids version conflicts with the pinned
# torch build — confirm.
!pip uninstall -y torchtext torchaudio
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install ftfy regex tqdm click requests pyspng ninja imageio-ffmpeg==0.4.3 ffmpeg-python wandb
!pip install git+https://github.com/openai/CLIP.git

# Clone the project and make it the working directory; every relative path used
# by the remaining cells is resolved inside `stylemc/`.
!git clone https://github.com/thepowerfuldeez/stylemc
%cd stylemc

# id loss model
!gdown --id 1xG_YHGcbzd5LWwqQCDDQDcqsSw7OvODY -O id_loss/model_ir_se50.pth

# model for masks, not used at the moment
# !mkdir -p deeplab_model/
# !gdown --id 1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM -O deeplab_model/R-101-GN-WS.pth.tar
# !gdown --id 1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY -O deeplab_model/deeplab_model.pth

# landmarks model
# No `-O` is given, so the file keeps its remote name and is written to the
# current directory (the cloned repository root).
!gdown --id 1Le5UdpMkKOTRr1sTp4lwkw8263sbgdSe

## 2. Prepare data

Generate 700 images. Set `--network=NETWORK` to the path of the official NVIDIA StyleGAN2 FFHQ 512x512 checkpoint, or remove the argument to use the default.

1. Generate w values from z
2. Convert w values to S-space
3. Generate images using w values

In [None]:
# Step 1: sample w latents for seeds 100000-100699 (700 seeds) with truncation 0.7.
# NOTE(review): no --out_file is passed here, yet the next two commands read
# encoder4editing/projected_w.npz — presumably that is the script's default
# output path; confirm.
!python generate_w.py --network="" --trunc=0.7 --seeds="100000-100699"
# Step 2: convert the w latents to S-space and store them in out/input.npz.
!python w_s_converter.py --network="" --out_file=out/input.npz --projected-w=encoder4editing/projected_w.npz
# Step 3: render the images into out/. The text prompt is empty and change_power
# is 0, so presumably no edit is applied to the images — confirm.
!python generate_fromS.py --network="" --text_prompt="" --change_power=0 --outdir=out --projected-w=encoder4editing/projected_w.npz

In [None]:
# Same three steps as for the test images, but for the training set:
# seeds 10000-29999 (20000 seeds), with all artifacts kept separate from the
# test run (projected_w_train.npz and the out_train/ directory).
!python generate_w.py --network="" --trunc=0.7 --seeds="10000-29999" --out_file=encoder4editing/projected_w_train.npz
# Convert the training w latents to S-space (out_train/input.npz).
!python w_s_converter.py --network="" --out_file=out_train/input.npz --projected-w=encoder4editing/projected_w_train.npz
# Render the training images into out_train/ with an empty prompt and change_power=0.
!python generate_fromS.py --network="" --text_prompt="" --change_power=0 --outdir=out_train/ --projected-w=encoder4editing/projected_w_train.npz

Process the 700 images in the `out/` directory: predict a male/female label for each image with CLIP, then save the S values of the confidently male and confidently female samples to separate files.

In [None]:
import numpy as np
import torch
import clip
from PIL import Image

# Classify every generated test image as male/female with zero-shot CLIP and
# split the corresponding S-space styles by the predicted label.

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Prompt order fixes the column meaning of `all_probs`:
# column 0 = male prompt, column 1 = female prompt.
text = clip.tokenize(["a photo of a male person", "a photo of a female person"]).to(device)

all_probs = []
for i in tqdm(range(0, 700)):
    # Close the file handle as soon as the image has been converted to a tensor.
    with Image.open(f"out/proj{i:02d}.png") as img:
        image = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # One forward pass yields the image-vs-prompt logits; a separate
        # `encode_image` call is not needed because its result was never used.
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    all_probs.append(probs)

# One row of prompt probabilities per image.
all_probs = np.concatenate(all_probs)

styles = np.load("out/input.npz")['s']

# Keep only confident predictions (probability above 0.85); images in between
# end up in neither file.
male_idx = (all_probs[:, 0] > 0.85).nonzero()[0]
female_idx = (all_probs[:, 1] > 0.85).nonzero()[0]

# `idx` stores the image indices so each style can be traced back to its image.
np.savez("out/female_s.npz", s=styles[female_idx], idx=female_idx)
np.savez("out/male_s.npz", s=styles[male_idx], idx=male_idx)

Generate the training dataset. Apply the same male/female filtering as before (but with a lower threshold), then remove some samples from over-represented classes such as European man, white man, and middle-aged man.

This gives a more balanced dataset.

In [None]:
import numpy as np
import torch
import clip
from PIL import Image

# Groups of mutually exclusive attribute prompts. Each group is scored with its
# own softmax, and the per-group probabilities are later laid out side by side
# as columns of `all_pair_probs` in exactly this order.
pairs = [
         ["a photo of a asian man", "a photo of a european man"],
         ["a photo of a white man", "a photo of a black man"],
         ["a photo of a young boy", "a photo of a middle-aged man", "a photo of a old man"],
         ["a photo of a man with long hair", "a photo of a man with short hair"],
         ["a photo of a man with glasses", "a photo of a man without glasses"],
]

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Column 0 = male prompt, column 1 = female prompt.
text = clip.tokenize(["a photo of a male person", "a photo of a female person"]).to(device)
text_pairs = [clip.tokenize(p).to(device) for p in pairs]

all_probs = []
all_pair_probs = [[] for _ in range(len(pairs))]
for i in tqdm(range(0, 20000)):
    # Close the file handle as soon as the image has been converted to a tensor.
    with Image.open(f"out_train/proj{i:02d}.png") as img:
        image = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Gender probabilities. The previously computed `encode_image` result was
        # never used, so that extra forward pass is dropped.
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        all_probs.append(probs)

        # Attribute probabilities, one softmax per prompt group. The loop variable
        # has its own name so it no longer shadows the image index `i`.
        for group_idx, group_tokens in enumerate(text_pairs):
            group_logits, _group_logits_text = model(image, group_tokens)
            group_probs = group_logits.softmax(dim=-1).cpu().numpy()
            all_pair_probs[group_idx].append(group_probs)

all_pair_probs = [np.concatenate(x) for x in all_pair_probs]
# Shape [n_images, total number of prompts across all groups] = [20000, 11].
all_pair_probs = np.concatenate(all_pair_probs, axis=-1)
all_probs = np.concatenate(all_probs)

styles = np.load("out_train/input.npz")['s']

# Lower confidence threshold (0.75) than for the test images.
male_idx = (all_probs[:, 0] > 0.75).nonzero()[0]
female_idx = (all_probs[:, 1] > 0.75).nonzero()[0]

# `idx` stores the image indices so each style can be traced back to its image.
np.savez("out_train/female_s.npz", s=styles[female_idx], idx=female_idx)
np.savez("out_train/male_s.npz", s=styles[male_idx], idx=male_idx)

In [None]:
# For every prompt group, print the prompts, the group's first column in
# `all_pair_probs`, and how many male samples score above 1.5x the uniform
# probability (1.5 / group size) for each prompt of the group.
col_start = 0
for prompt_group in pairs:
    group_size = len(prompt_group)
    print(prompt_group)
    group_probs = all_pair_probs[male_idx, col_start:col_start + group_size]
    print(col_start, (group_probs > 1.5 / group_size).sum(0))
    col_start += group_size

# indices of overpopulated classes
# Each rule is (column in all_pair_probs, probability threshold, limit): among the
# male samples above the threshold, every 10th of the first `limit` matches is dropped.
thinning_rules = [
    (1, 0.7, 4000),   # "a photo of a european man"
    (2, 0.7, 3500),   # "a photo of a white man"
    (5, 0.5, 4000),   # "a photo of a middle-aged man"
    (8, 0.7, 4000),   # "a photo of a man with short hair"
    (10, 0.7, 4000),  # "a photo of a man without glasses"
]
exclude_parts = [
    np.flatnonzero(all_pair_probs[male_idx, col] > threshold)[0:limit:10]
    for col, threshold, limit in thinning_rules
]

# Sorted, de-duplicated positions (relative to `male_idx`) that get removed.
exclude_idx = np.unique(np.concatenate(exclude_parts))

# Boolean mask over positions in `male_idx`: True means the sample is kept.
keep_mask = np.ones(len(male_idx), dtype=bool)
keep_mask[exclude_idx] = False
male_idx_balanced = male_idx[keep_mask]

print("male_idx", len(male_idx), "male_idx balanced", len(male_idx_balanced))
np.savez("out_train/male_s_balanced.npz", s=styles[male_idx_balanced], idx=male_idx_balanced)

## 3. Training

To train the latent mapper and the global direction, you must specify the network (or remove the argument to use the default). The other arguments are already filled in. You can try increasing the batch size to reduce training time.

Training one network takes around 4 hours.

In [None]:
# first, log-in to wandb to log losses and images
import wandb
wandb.login()

In [None]:
# Train the latent mapper on the balanced male S values. The input is
# out_train/male_s_balanced.npz, the file written by the balancing cell; the
# previous "male_s_balanced2.npz" is never created anywhere in this notebook.
# NOTE(review): the run directory name mentions id0.3 / l2_0.8, but neither
# coefficient is passed here — presumably those are the script defaults; confirm.
!python train_latent_mapper.py --network="" --s_input="out_train/male_s_balanced.npz" --text_prompt="a photo of a face of a feminine woman with no makeup" --outdir=runs/male2female_mapper_id0.3_clip1.5_l2_0.8_landmarks0.1_batch2_epoch10_power1.8_big/ --clip_loss_coef=1.5 --landmarks_loss_coef=0.1 --batch_size=2 --n_epochs=10 --negative_text_prompt="a photo of a face of a man"
# Apply the trained mapper (change_power=1.8) to the unbalanced male set and
# write the edited images into the same run directory.
!python generate_fromS.py --network="" --s_input="out_train/male_s.npz" --text_prompt="a photo of a face of a feminine woman with no makeup" --outdir=runs/male2female_mapper_id0.3_clip1.5_l2_0.8_landmarks0.1_batch2_epoch10_power1.8_big/ --change_power=1.8 --use_mapper=1

In [None]:
# Learn a single global S-space direction on the balanced male S values. The
# input is out_train/male_s_balanced.npz, the file written by the balancing cell;
# the previous "male_s_balanced2.npz" is never created anywhere in this notebook.
!python find_direction.py --network="" --s_input="out_train/male_s_balanced.npz" --resolution=512 --text_prompt="a photo of a face of a feminine woman with no makeup" --outdir=runs/male2female_512_id0.3_clip1.0_l2_0.8_landmarks0.1_lr5.0_batch2_epoch10_power1.8_big/ --batch_size=2 --n_epochs=10 --negative_text_prompt="a photo of a face of a man" --clip_loss_coef=1.0 --identity_loss_coef=0.3 --landmarks_loss_coef=0.1 --l2_reg_coef=0.8 --learning_rate=5.0
# Apply the learned direction (change_power=1.8, no mapper) to the unbalanced
# male set and write the edited images into the same run directory.
!python generate_fromS.py --network="" --s_input="out_train/male_s.npz" --text_prompt="a photo of a face of a feminine woman with no makeup" --outdir=runs/male2female_512_id0.3_clip1.0_l2_0.8_landmarks0.1_lr5.0_batch2_epoch10_power1.8_big/ --change_power=1.8

In [None]:
# Save results with test images to an archive: pack the whole runs/ directory
# (the output directories used by the training cells) into a gzip-compressed
# tarball, runs.tgz, in the current directory.
!tar czf runs.tgz runs/