<a href="https://colab.research.google.com/github/sizhky/stylegan2-pytorch/blob/master/stylegan_2_style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%time
import os

if not os.path.exists('stylegan2-pytorch'):
    !git clone https://github.com/sizhky/stylegan2-pytorch
    !wget --quiet https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
    !sudo unzip -q ninja-linux.zip -d /usr/local/bin/
    !sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
    !rm ninja-linux.zip
    !pip install -U -q PyDrive torch_snippets
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive
    from google.colab import auth
    from oauth2client.client import GoogleCredentials
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    drive = GoogleDrive(gauth) 
    downloaded = drive.CreateFile({'id':"1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO"})   # replace the id with id of file you want to access
    downloaded.GetContentFile('5500000.pth')        # replace the file name with your file
%cd stylegan2-pytorch

/content/stylegan2-pytorch
CPU times: user 852 µs, sys: 0 ns, total: 852 µs
Wall time: 859 µs


### Playing with the model

In [2]:
%load_ext autoreload
%autoreload 2

from torch_snippets import *
from generate import Generator
from torchvision.utils import save_image

device = 'cuda'
generator = Generator(size=256, style_dim=512,
                      n_mlp=8, channel_multiplier=2)

# Load onto the CPU first: only the 'g_ema' entry is used, so there is no need
# to put the whole checkpoint on the GPU. The model is moved to `device` below.
checkpoint = torch.load('../5500000.pth', map_location='cpu')

# strict=False tolerates state-dict key mismatches; report them instead of
# letting a wrong or partial checkpoint load silently.
load_report = generator.load_state_dict(checkpoint['g_ema'], strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f'missing keys: {load_report.missing_keys}')
    print(f'unexpected keys: {load_report.unexpected_keys}')
del checkpoint  # release the rest of the checkpoint dict

generator.eval().to(device);

  "Distutils was imported before Setuptools. This usage is discouraged "


In [4]:
def interpolate_two_points(p1, p2, n_steps=8):
    """Linearly interpolate from `p1` to `p2` in `n_steps` evenly spaced steps.

    Both endpoints are included: the first row equals `p1`, the last equals `p2`.
    Returns a tensor of shape (n_steps, *p1.shape).
    """
    weights = torch.linspace(0, 1, steps=n_steps)
    return torch.stack([(1.0 - w) * p1 + w * p2 for w in weights])
def interpolate_four_points(p1,p2,p3,p4,n_steps=8):
    """Build a flattened `n_steps` x `n_steps` grid interpolating between four corners.

    The edges p1 -> p2 and p3 -> p4 are interpolated first. Each row of the grid
    then interpolates between the matching points on those two edges. The rows
    are concatenated, so grid cell (row, col) is at index row * n_steps + col.
    Returns a tensor of shape (n_steps * n_steps, *p1.shape).
    """
    edge_a = interpolate_two_points(p1, p2, n_steps)
    edge_b = interpolate_two_points(p3, p4, n_steps)
    grid_rows = [
        interpolate_two_points(start, end, n_steps)
        for start, end in zip(edge_a, edge_b)
    ]
    return torch.cat(grid_rows)
# `device` and `generator` come from the model-loading cell above.
NOISE_CACHE = 'noise.vectors'

with torch.no_grad():
    steps = 8
    if os.path.exists(NOISE_CACHE):
        # Reuse the cached grid so the indices hand-picked in the next cell
        # keep pointing at the same faces across runs.
        # NOTE(review): loaddill presumably unpickles with dill — only load trusted files.
        zs = loaddill(NOISE_CACHE)
    else:
        # Four random corners -> (steps * steps, 14, 512) interpolation grid.
        corner_zs = torch.randn(4, 14, 512, device=device)
        zs = interpolate_four_points(*corner_zs, steps)
        dumpdill(zs, NOISE_CACHE)

    zs = generator.get_latent(zs)
    sample, _ = generator([zs])
    os.makedirs('sample', exist_ok=True)
    # NOTE(review): newer torchvision renamed save_image's `range` kwarg to
    # `value_range`; update this call if it raises on the installed version.
    save_image(sample, 'sample/interpolations.png', nrow=steps, normalize=True, range=(-1,1))

In [8]:
# Hand-picked target latents from the 8 x 8 interpolation grid built above.
# Index i is grid row i // 8, column i % 8 in sample/interpolations.png
# (e.g. 56 = row 7, col 0; 63 = row 7, col 7; 39 = row 4, col 7).
# These picks only match their labels while noise.vectors stays unchanged.
styles = {
    'frontal-black-hair-female': zs[56],
    'kid-worried': zs[39],
    'bearded-man': zs[63]
}

# Layer ranges (dim 1 of the latent) that each transfer copies from the target
# style. For each level, the generator keyword is `<level>_latents`.
LATENT_LEVELS = {
    'coarse': slice(None, 4),
    'middle': slice(4, 10),
    'fine': slice(10, None),
}


@torch.no_grad()  # inference only: avoid building an autograd graph
def _transfer_latent(source_styles, target_style, level, name):
    """Render `source_styles` twice, the second time with `level` layers taken from `target_style`.

    The result is saved to ``sample/{level}_transfer_{name}.png`` as a
    two-row grid, assuming both renders have the same batch size. The first
    row holds the originals and the second row the style-transferred images.

    Args:
        source_styles: list of latent tensors passed straight to `generator`.
        target_style: latent indexed as ``target_style[:, layers]``; presumably
            shaped (1, n_layers, 512) — confirm.
        level: a key of LATENT_LEVELS ('coarse', 'middle' or 'fine').
        name: label used in the output filename.
    """
    layers = LATENT_LEVELS[level]
    originals, _ = generator(source_styles)
    # NOTE(review): len(source_styles) counts list entries (1 in the loop
    # below), not the batch size, so this repeat is a no-op there. Presumably
    # Generator broadcasts over the batch — confirm.
    injected = target_style[:, layers].repeat(len(source_styles), 1, 1)
    style_transferred, _ = generator(source_styles, **{f'{level}_latents': injected})
    samples = torch.cat([originals, style_transferred], 0)
    # NOTE(review): newer torchvision renamed `range` to `value_range`.
    save_image(samples, f'sample/{level}_transfer_{name}.png', nrow=len(samples)//2, normalize=True, range=(-1,1))


def transfer_coarse_latent(source_styles, target_style, name='style'):
    """Swap in `target_style`'s coarse layers (0-3) and save the comparison grid."""
    _transfer_latent(source_styles, target_style, 'coarse', name)


def transfer_middle_latent(source_styles, target_style, name='style'):
    """Swap in `target_style`'s middle layers (4-9) and save the comparison grid."""
    _transfer_latent(source_styles, target_style, 'middle', name)


def transfer_fine_latent(source_styles, target_style, name='style'):
    """Swap in `target_style`'s fine layers (10 onward) and save the comparison grid."""
    _transfer_latent(source_styles, target_style, 'fine', name)


torch.manual_seed(0)  # fixed seed so the random source faces are reproducible
with torch.no_grad():
    for name, latent in styles.items():
        z = generator.get_latent(torch.randn(5, 14, 512, device=device))
        # Put the target itself first, so the first column of each grid shows it.
        z = torch.cat([latent[None], z])
        for transfer in (transfer_coarse_latent, transfer_middle_latent, transfer_fine_latent):
            transfer([z.clone()], latent[None].clone(), name=name)


In [9]:
import zipfile
# Bundle every generated image under sample/ into a single downloadable archive.
sample_paths = Glob('sample')
with zipfile.ZipFile('samples.zip', mode='w') as samples_archive:
    for sample_path in sample_paths:
        samples_archive.write(sample_path, compress_type=zipfile.ZIP_DEFLATED)


2020-08-01 20:42:31.888 | INFO     | torch_snippets.loader:Glob:150 - 15 files found at sample
