<a href="https://colab.research.google.com/github/pollinations/hive/blob/main/notebooks/3%20Audio-To-Video/2%20StyleGAN-3%20Dance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://ipfs.pollinations.ai/ipfs/QmWx5jVzhS63C2Zn26bVazCthcMeM6dSXnvsaFEA5AMZfE" />

*An eye > A planet > The moon > Brush strokes > The globe > Kandinsky > Salvador Dali*

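Each prompt in your story generates an image; the video then morphs from one image to the next, driven by the rhythm of the music. Separate prompts with `->`; use `|` to blend several texts into a single prompt (e.g. `kandinsky|salvador dali`).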

---

*Note: the Wiki Art and Flickr Faces models are high-resolution and may fail on the free version of Google Colab.*

Credits: [nielsrolf](https://github.com/nielsrolf), inspired by [lucid-sonic-dreams](https://github.com/mikaelalafriz/lucid-sonic-dreams).


[UPD 13.12.2021] Added more visual feedback during generation, and smoothing of the latent trajectory.

In [None]:
# Story: prompts separated by "->" (newlines also work). Within one prompt,
# "|" combines several texts and "text:weight" sets the weight of a text.
story = "An eye->A planet->The moon->brush strokes->the globe->kandinsky|salvador dali" #@param {type: "string"}

# Added ("|"-separated, same weight syntax) to every prompt the latent search optimises.
style_suffix = "psychedelic painting" #@param {type: "string"}

# Path of the music file (replaced by the download when youtube_dl_link is set in the audio cell)
audio_file = "" #@param {type: "string"}

# Experimental setting: freezes the generator's input coordinate grid (no camera movement) and focuses on transformation
fix_camera = True #@param {type:"boolean"}
# Model type to use. Wiki Art and Flickr Faces are high-resolution and may only work well with Colab Pro.
model = 'Wiki Art'  #@param ['Painted Faces', 'Animal Faces', 'Flickr Faces', 'Wiki Art', 'Landscapes']

text_prompt_bass = 'Mysterious and deep, violet' #@param {type: "string"}
text_prompt_treble = "dreamy and gold full of holograms" #@param {type: "string"}
text_prompt_mids = "gold and holographic dreams about futuristic nature" #@param {type: "string"}

# Example of a weighted style suffix:
# "painting by Hieronymus Bosch:0.5|A confusing image:-0.5"

# How much of a pulsating effect the bass creates. A bass sound moves towards text_prompt_bass and moves back once it is released (0 disables)
bass_pulse_impact = 0 #@param {type: "number"}
# How much of a pulsating effect the mids create. A mid sound moves towards text_prompt_mids and moves back once it is released (0 disables)
mids_pulse_impact = 0 #@param {type: "number"}
# How much of a pulsating effect a high-pitched sound creates. A treble sound moves towards text_prompt_treble and moves back once it is released (0 disables)
trebles_pulse_impact = 0 #@param {type: "number"}

# How much the bass pushes the story forward (relative weight; the three speeds are normalised to sum to 1, so they must not all be 0)
bass_story_speed = 1 #@param {type: "number"}
# How much the mids push the story forward
mids_story_speed = 1 #@param {type: "number"}
# How much the trebles push the story forward
trebles_story_speed = 0 #@param {type: "number"}

# Half-width, in seconds, of the moving-average window that smooths the latent trajectory over time (0 disables smoothing)
smoothing = 0.1 #@param {type: "number"}


# It can take quite long to generate a 6 min video; use these inputs to make the video shorter
start_second =  0 #@param {type: "number"}
end_second = 30 #@param {type: "number"}


# Speed at which to try approximating the text. Too fast seems to give strange results. Maximum is 100.
# NOTE(review): `speed` does not appear to be read by any later cell.
speed = 20  #@param {type: "number"}

# Change the seed to generate variations of the same prompt
seed = 336 #@param {type: "number"}

# Batch size. Keep this low (e.g. 2) if you are using the free Google Colab. Colab Pro users can safely set this to 8 or more. Increases image quality and generation speed.
batch_size = 2 #@param {type: "number"}

# Download locations of the pretrained StyleGAN3 generators. Every entry must be
# a full URL: the model cell downloads it and caches it under its basename.
model_map = {
    'Painted Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfacesu-1024x1024.pkl',
    'Animal Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl',
    'Flickr Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl',
    'Wiki Art': 'https://ipfs.pollinations.ai/ipfs/QmZkrYwEUnykVQJfJw3opTj1HfdNUCm87amsR3LHp1QnuV/wikiart-1024-stylegan3-t-17.2Mimg.pkl',
    'Landscapes': 'https://ipfs.pollinations.ai/ipfs/QmZkrYwEUnykVQJfJw3opTj1HfdNUCm87amsR3LHp1QnuV/lhq-256-stylegan3-t-25Mimg.pkl'
}


output_path = "/content/output"

# Optimisation steps of the CLIP-guided latent search per prompt
steps = 150

#@markdown ---




In [None]:
!nvidia-smi
!mkdir {output_path}

In [None]:
# Chapter keyframes: the story is split at "->"; a newline counts as "->" too.
prompts = story.replace("\n", "->").split("->")

# Interpolation needs at least two keyframes, so a single prompt is used twice.
if len(prompts) == 1:
    prompts = [prompts[0], prompts[0]]

prompts

In [None]:
#@title Processing Audio
!pip install python-slugify
import librosa
import librosa.display
from matplotlib import pyplot as plt
import numpy as np
from slugify import slugify
from IPython.display import display

# alternative link to youtube or soundcloud
youtube_dl_link = "" 

if youtube_dl_link.startswith("http"):
  print(f"Downloading from {youtube_dl_link}...")
  !pip install -q youtube-dl
  !youtube-dl --rm-cache-dir
  !youtube-dl --extract-audio --audio-format wav {youtube_dl_link} --output /tmp/audio_file.wav
  audio_file = "/tmp/audio_file.wav"
  from glob import glob
  print(glob("/tmp/*.wav"))


specno=0
def specshow(spec, hop_length=1024, fmax=None):
  """Plot a mel spectrogram and its per-frame energy, saving the energy plot.

  Args:
    spec: 2-D array of mel bins x frames, as produced in this cell.
    hop_length: audio samples between consecutive columns of `spec`. The
      spectrograms here use librosa's default hop of 512 and then keep every
      second column (`[:, ::2]`), hence 1024. Otherwise the time axis would
      show half of the real duration.
    fmax: highest frequency covered by the mel bins. Defaults to sr / 2, the
      default upper limit of librosa.feature.melspectrogram, so the mel-axis
      labels match the data.

  Reads the module-level `sr` and `output_path`. The energy plot is written to
  `{output_path}/a_spec_{specno}.png`, and `specno` is then incremented.
  """
  global specno
  if fmax is None:
    fmax = sr / 2
  fig, ax = plt.subplots()
  librosa.display.specshow(spec, x_axis='time', y_axis='mel', sr=sr,
                           hop_length=hop_length, fmax=fmax, ax=ax)
  ax.set(title='Mel spectrogram')
  plt.show()
  fig, ax = plt.subplots(figsize=(15, 7))
  ax.plot(spec.sum(0))
  ax.set(title='Energy per frame (sum over mel bins)', xlabel='frame', ylabel='sum')
  fig.savefig(f"{output_path}/a_spec_{specno}.png")
  plt.show()
  specno += 1



image_counts = {}
def save_image(img, text_prompt="progress", save_every=20):
  """Display and save every `save_every`-th image logged under `text_prompt`.

  Each slugified prompt has its own call counter in `image_counts`. When that
  counter is a multiple of `save_every`, the image tensor is converted to PIL,
  shown inline and written to `{output_path}/{slug}_{counter:04d}.jpg`. The
  counter is incremented on every call.
  """
  slug = slugify(text_prompt)
  count = image_counts.setdefault(slug, 0)
  if not count % save_every:
    picture = TF.to_pil_image(img)
    display(picture)
    picture.save(f'{output_path}/{slug}_{count:04d}.jpg')
  image_counts[slug] = count + 1

audio, sr = librosa.load(audio_file)
# Crop to [start_second, end_second). Colab "number" params may be floats, but
# slice indices must be ints.
audio = audio[int(start_second * sr):int(end_second * sr)]
if len(audio) == 0:
  raise ValueError(f"No audio between start_second={start_second} and end_second={end_second}.")
# Add a second of silence in the end
audio = np.concatenate([audio, np.zeros(sr)], axis=0)
# Mel spectrogram; [:, ::2] keeps every second frame, halving the frame rate.
spec = librosa.feature.melspectrogram(y=audio, sr=sr)[:,::2]

# Convert to dB and shift so the quietest bin is 0.
spec_s = librosa.amplitude_to_db(spec)
spec_s = spec_s - spec_s.min()
specshow(spec_s)

# Frequency bands: first 12 mel bins = bass, last 35 = treble, the rest = mids.
mids = spec_s[12:-35]

# Keep only the part of the bass above its mean.
bass = spec_s[:12]
bass = bass - bass.mean()
bass[bass<0] = 0

# Note: only HALF of the treble mean is subtracted (unlike the bass).
treble = spec_s[-35:]
treble = treble - treble.mean() / 2
treble[treble<0] = 0

# Spectrogram columns per second; the video is rendered at this frame rate.
seconds = len(audio) / sr
frame_rate = spec_s.shape[1] / seconds


def get_spec_slice(spec, i, chapters=None):
  """Return the columns (time steps) of `spec` that belong to chapter `i`.

  The time axis (axis 1) is split into `chapters` contiguous, near-equal
  slices. Integer arithmetic guarantees the slices tile the axis exactly. With
  float boundaries, `int(chapters * (tsteps / chapters))` can round down to
  `tsteps - 1` and silently drop the last column.

  Args:
    spec: 2-D array (bins x time steps).
    i: chapter index, 0 <= i < chapters.
    chapters: number of chapters; defaults to `len(prompts) - 1`, one per
      transition between consecutive story prompts.
  """
  if chapters is None:
    chapters = len(prompts) - 1
  tsteps = spec.shape[1]
  start = i * tsteps // chapters
  end = (i + 1) * tsteps // chapters
  return spec[:, start:end]
#specshow(bass)
#specshow(mids)
#specshow(treble)


# Preprocessing other controls

model_url = model_map[model]

# Normalise the per-band story speeds into weights that sum to 1.
summed_speed = bass_story_speed + mids_story_speed + trebles_story_speed
if summed_speed == 0:
  raise ValueError("At least one of bass_story_speed, mids_story_speed and "
                   "trebles_story_speed must be non-zero.")
bass_story_speed /= summed_speed
mids_story_speed /= summed_speed
trebles_story_speed /= summed_speed



In [None]:
# Check GPU and CUDA: list the attached GPU(s) and print the CUDA compiler version.
!nvidia-smi
!nvcc --version

In [None]:
#@title Installing packages

# torch / torchvision pinned to the 1.9.1 / 0.10.1 CUDA 11.1 wheels.
# NOTE(review): presumably chosen to match the Colab runtime; verify before bumping.
!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# Provides get_clip_guided_loss, used by the latent search below.
!pip install --upgrade git+https://github.com/pollinations/pytorch_clip_guided_loss.git
# Alternative: nightly torch/torchvision wheels (kept for reference).
#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl
# StyleGAN3 sources; the model cell appends ./stylegan3 to sys.path (presumably so the model pickles can be unpickled).
!git clone https://github.com/NVlabs/stylegan3
!pip install kornia
# !pip install -e ./CLIP
# einops is imported by the model cell; ninja presumably for building StyleGAN3's custom ops — verify.
!pip install einops ninja

In [None]:
#@title Initializing models

# Importing tensorflow first seems to resolve some weird dependency errors (it is included later by one of the dependencies anyway)
import tensorflow

from pytorch_clip_guided_loss import get_clip_guided_loss
import sys
sys.path.append('./stylegan3')

import io
import os, time
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
from google.colab import files
from functools import lru_cache
import kornia.augmentation as K


# Seed torch's RNG so the random latents/cutouts below are reproducible per `seed`.
torch.manual_seed(seed)
# Everything runs on the first CUDA GPU.
device = torch.device('cuda', 0)
print('Using device:', device, file=sys.stderr)


def fetch(url_or_path):
    """Open `url_or_path` as a binary, readable file object.

    http(s) URLs are downloaded completely (raising for HTTP error statuses)
    and returned as an in-memory buffer positioned at the start; anything else
    is opened as a local file in 'rb' mode.
    """
    is_remote = str(url_or_path).startswith(('http://', 'https://'))
    if not is_remote:
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)

def fetch_model(url_or_path):
    """Return a local path to a model file, downloading it if necessary.

    Downloads are cached in the current directory under the URL's basename,
    and an existing file with that name is reused. Data is first written to
    '<basename>.part' and renamed only after the download completes, so a
    failed or interrupted download never leaves a truncated file that a later
    run would mistake for the cached model.

    Raises:
        FileNotFoundError: `url_or_path` is neither a URL nor an existing path.
        urllib.error.URLError: the download failed.
    """
    import urllib.request

    url_or_path = str(url_or_path)
    basename = os.path.basename(url_or_path)
    if os.path.exists(basename):
        return basename
    if "://" not in url_or_path:
        # Plain local path (possibly outside the current directory).
        if os.path.exists(url_or_path):
            return url_or_path
        raise FileNotFoundError(f"Model file not found: {url_or_path}")
    print(f"Downloading {url_or_path} ...")
    partial = basename + ".part"
    try:
        urllib.request.urlretrieve(url_or_path, partial)
        os.replace(partial, basename)
    finally:
        if os.path.exists(partial):
            os.remove(partial)
    return basename

# Load the EMA generator from the model pickle.
# SECURITY: pickle.load can execute arbitrary code contained in the file; only
# use model URLs you trust.
with open(fetch_model(model_url), 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)

# Fix the coordinate grid to w_avg: make the input layer's affine map output the
# constant affine(w_avg) by zeroing its weight and SETTING its bias to that value.
# affine(w_avg) already includes the bias, so adding it onto the existing bias
# would count the bias twice and fix the grid to a different transform.
if fix_camera:
    affine = G.synthesis.input.affine
    shift = affine(G.mapping.w_avg.unsqueeze(0)).squeeze(0)
    # NOTE(review): StyleGAN3's FullyConnectedLayer scales its bias by `bias_gain`
    # (1 unless an lr multiplier is used); dividing keeps the effective bias == shift.
    affine.bias.data.copy_(shift / getattr(affine, 'bias_gain', 1))
    affine.weight.data.zero_()

# Per-dimension std of W over 10k random z; the latent search optimises
# q = (w - w_avg) / w_stds in these normalised units. No graph is needed here.
with torch.no_grad():
    zs = torch.randn([10000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)

# CLIP-guided loss; input_range (0, 1) matches the images, which are produced as .add(1).div(2).
clip_guided_loss = get_clip_guided_loss("clip", input_range = (0, 1))
clip_guided_loss.to(device)

In [None]:
#@title Define functions to search the latent space

from kornia.filters import gaussian_blur2d

# Ensure the output directory exists (no error if it already does); save_image
# and get_latents_for write their files into it.
os.makedirs(output_path, exist_ok=True)


class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

make_cutouts = MakeCutouts(224, 32, 0.5)




def save_latent_image(latent, text_prompt):
    """Render `latent` with the generator and save/display the first image.

    The synthesis output is mapped from [-1, 1] to [0, 1], clamped and moved to
    the CPU. It is handed to save_image with save_every=1, so it is written on
    every call.
    """
    rendered = G.synthesis(latent, noise_mode='const')
    batch = rendered.add(1).div(2).clamp(0, 1).cpu().detach()
    save_image(batch[0], text_prompt, 1)


#def show_image(img_batch, text_prompt=None):
#    for i in range(len(img_batch)):
#        print(i)
#
#         img = TF.to_pil_image(img_batch[i])
#        display(img)
#    if text_prompt is not None:
#        img.save(f'{output_path}/{text_prompt}.jpg')



@torch.no_grad()
def get_initial_q(clip_guided_loss, samples, truncation_psi=0.7):
    """Pick the best of `samples` random starting points for the latent search.

    Each candidate is a truncated mapping-network output expressed in the
    normalised coordinates the optimiser uses, q = (w - w_avg) / w_stds. The
    candidate whose image gets the lowest CLIP-guided loss against the
    currently registered prompts wins (the first one on ties).

    Args:
        clip_guided_loss: loss module with the prompts already added.
        samples: number of random candidates to evaluate (>= 1).
        truncation_psi: truncation passed to G.mapping.

    Returns:
        The winning q, shaped like one G.mapping output for a batch of 1.

    Raises:
        ValueError: samples < 1.
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")
    best_q, best_loss = None, None
    for _ in range(samples):
        z = torch.randn([1, G.mapping.z_dim], device=device)
        q = (G.mapping(z, None, truncation_psi=truncation_psi) - G.mapping.w_avg) / w_stds
        images = G.synthesis(q * w_stds + G.mapping.w_avg).add(1).div(2)
        loss = clip_guided_loss(image=images)["loss"]
        if best_loss is None or loss < best_loss:
            best_q, best_loss = q, loss
    print(f"Initial latent: best loss {best_loss.item():.4f} out of {samples} samples")
    return best_q

def load_pickled_latent(url):
  """Load a latent saved with torch.save from a URL or local path.

  The file handle returned by `fetch` is closed after loading.
  SECURITY: torch.load unpickles the data, which can execute arbitrary code;
  only load latents from sources you trust.
  """
  with fetch(url) as fd:
    return torch.load(fd)

prompt_count = 0


def _parse_weighted_prompt(part):
    """Split one "text:weight" prompt into (text, weight); the weight defaults to 1.

    The weight is read after the LAST ':' so texts containing colons still work.
    If that suffix is not a number, the whole part is used as text with weight 1.
    """
    text, sep, weight = part.rpartition(":")
    if sep:
        try:
            return text, float(weight)
        except ValueError:
            pass
    return part, 1


@lru_cache(maxsize=None)
def get_latents_for(text_prompt, steps=steps, learning_rate=0.035, num_augmentations=batch_size, initial_latent=None):
    """Find a generator latent (W space) whose image matches `text_prompt`.

    First tries to load a saved latent from `text_prompt` treated as a URL/path.
    Otherwise, "|"-separated prompt parts (plus the global `style_suffix`) are
    registered with the CLIP-guided loss, and q = (w - w_avg) / w_stds is
    optimised with AdamW for `steps` steps. The latent with the lowest loss
    seen is returned and saved to `{output_path}/{slug}.pt`. Results are
    memoised per argument tuple (lru_cache).

    Args:
        text_prompt: prompt text, e.g. "a cat:1|a dog:-0.5".
        steps: optimisation steps.
        learning_rate: AdamW learning rate.
        num_augmentations: copies of the image that are cut out for CLIP.
        initial_latent: optional starting latent in W space; otherwise the
            best of 8 random samples is used.
    """
    global prompt_count
    try:
        # NOTE: torch.load unpickles; only trusted URLs should be used as prompts.
        return load_pickled_latent(text_prompt)
    except Exception:
        print("Could not load latent from url, searching now")

    # Register the prompt parts and the style suffix with the CLIP loss. Empty
    # style parts (e.g. style_suffix == "") are skipped instead of being added
    # as an empty text prompt.
    clip_guided_loss.clear_prompts()
    style_parts = [part for part in style_suffix.split("|") if part.strip()]
    for part in text_prompt.split("|") + style_parts:
        prompt, weight = _parse_weighted_prompt(part)
        print(prompt, weight)
        clip_guided_loss.add_prompt(text=prompt, weight=weight)

    # Starting point in normalised latent coordinates.
    if initial_latent is None:
        q = get_initial_q(clip_guided_loss, samples=8)
    else:
        q = (initial_latent - G.mapping.w_avg) / w_stds
    q = q.detach().requires_grad_()

    augs = make_cutouts
    opt = torch.optim.AdamW([q], lr=learning_rate, betas=(0.0, 0.999))

    best_loss = float("inf")
    # clone(): q.detach() alone shares storage with q, so later opt.step()
    # updates would overwrite the "best" snapshot with the final value.
    best_q = q.detach().clone()
    iterator = tqdm(range(steps))
    for i in iterator:
        opt.zero_grad()
        latent = q * w_stds + G.mapping.w_avg
        img = G.synthesis(latent, noise_mode='const').add(1).div(2).clamp(0, 1)
        x = augs(img.repeat_interleave(num_augmentations, dim=0))
        loss = clip_guided_loss(image=x)["loss"]
        save_image(img[0].detach(), f"b_{prompt_count}_{text_prompt}")
        loss_value = loss.item()
        if i % 11 == 0:
            print("Loss:", loss_value, "Std:", torch.std(q).item())
        if loss_value < best_loss:
            best_loss = loss_value
            best_q = q.detach().clone()
        loss.backward()
        opt.step()
        iterator.set_description(f"loss: {loss_value}")

    latent = (best_q * w_stds + G.mapping.w_avg).detach()
    prompt_count += 1
    torch.save(latent, f"{output_path}/{slugify(text_prompt)}.pt")
    return latent


In [None]:
# latent = get_latents_for('An elephant:0.5|This is an elephant:0.5|elephant trunk:0.3|elephant ears:0.3|abstract:-0.2|confusing:-0.2|landscape:-0.2',
#                          steps=5, learning_rate=0.04)

In [None]:
#@title Running text to image for all prompts - this can take a while

latent_story = [get_latents_for(prompt) for prompt in prompts]


def _pulse_latent(text_prompt, impact):
  """Latent for a pulse prompt; weak pulses (impact < 0.2) only get a 10-step search."""
  return get_latents_for(text_prompt, steps=10 if impact < 0.2 else steps)


# The pulses are less important, so weak ones are optimised for fewer steps.
if trebles_pulse_impact > 0:
  latent_treble = _pulse_latent(text_prompt_treble, trebles_pulse_impact)
if mids_pulse_impact > 0:
  latent_middle = _pulse_latent(text_prompt_mids, mids_pulse_impact)
if bass_pulse_impact > 0:
  latent_bass = _pulse_latent(text_prompt_bass, bass_pulse_impact)

In [None]:
#@title Interpolate in the latent space

# del clip_guided_loss


@torch.no_grad()
def slerp(val, low, high):
  """Batched spherical interpolation.

  Everything after axis 0 is flattened and treated as one vector. Where the two
  endpoints are parallel (sin(omega) == 0), linear interpolation is used.

  Arguments:
    val: interpolation weights, a 1d tensor (n_frames) or a scalar.
    low: n-d tensor: (n_frames or 1, *frame_shape)
    high: n-d tensor: (n_frames or 1, *frame_shape)
  Returns:
    interpolated: (n_frames, *frame_shape)
  """
  val = torch.as_tensor(val, device=low.device).reshape(-1, 1)
  shape = low.shape
  low = low.reshape([low.shape[0], -1])
  high = high.reshape([high.shape[0], -1])

  low_ = low / torch.norm(low, dim=1, keepdim=True)
  high_ = high / torch.norm(high, dim=1, keepdim=True)
  omega = torch.arccos(torch.clip(torch.sum(low_ * high_, dim=1, keepdim=True), -1, 1))
  so = torch.sin(omega)
  lerped = (1.0 - val) * low + val * high
  slerped = torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high
  result = torch.where(so != 0, slerped, lerped)
  return result.reshape([result.shape[0]] + list(shape[1:]))



latent_chapters = []
for chapter in range(len(prompts)-1):
  latent_start = latent_story[chapter]
  latent_end = latent_story[chapter + 1][None]
  # Per-frame "story speed": the band loudness in this chapter, weighted by the
  # normalised per-band speeds.
  story_speed = get_spec_slice(bass, chapter).sum(0) * bass_story_speed \
              + get_spec_slice(mids, chapter).sum(0) * mids_story_speed \
              + get_spec_slice(treble, chapter).sum(0) * trebles_story_speed
  total_speed = story_speed.sum()
  if total_speed > 0:
    story_speed = story_speed / total_speed
  else:
    # Silent chapter: advance at a constant rate instead of dividing by zero (NaN latents).
    story_speed = np.full_like(story_speed, 1.0 / max(len(story_speed), 1))
  # Cumulative progress 0 -> 1 through the chapter, one value per frame.
  progress = torch.tensor(story_speed.cumsum(0)).to(device)
  latent_chapters.append(slerp(progress, latent_start, latent_end))

latent_vid = torch.cat(latent_chapters, dim=0)

def _pulse_weights(band, impact):
  """Per-frame slerp weights in [0, impact] from a band's summed loudness.

  The loudness is min-max normalised to [0, 1]. A band with constant loudness
  (e.g. silence) gets all-zero weights instead of NaNs from dividing by zero.
  """
  loudness = band.sum(0)
  lowest = loudness.min()
  span = loudness.max() - lowest
  if span > 0:
    normalised = (loudness - lowest) / span
  else:
    normalised = np.zeros_like(loudness)
  return torch.tensor(normalised * impact).to(device)


# Pull each frame towards the band's pulse latent in proportion to its loudness.
if trebles_pulse_impact > 0:
  latent_vid = slerp(_pulse_weights(treble, trebles_pulse_impact), latent_vid, latent_treble)
if mids_pulse_impact > 0:
  latent_vid = slerp(_pulse_weights(mids, mids_pulse_impact), latent_vid, latent_middle)
if bass_pulse_impact > 0:
  latents = slerp(_pulse_weights(bass, bass_pulse_impact), latent_vid, latent_bass)
else:
  latents = latent_vid

In [None]:
#@title Smoothing in the latent space

from torch.nn.functional import conv1d

# Build a box (moving-average) filter of constant values and convolve every
# latent dimension with it over time, independently (one conv group per dimension).
if smoothing > 0:
    half_window = int(frame_rate * smoothing)
    smoothing_frames = half_window * 2 + 1
    # (frames, ...) -> (1, dims, frames): conv1d runs along the last axis.
    latents_flat = latents.reshape(1, latents.shape[0], -1).transpose(1, 2)
    n_dims = latents_flat.shape[1]
    smoothing_filter = torch.ones([n_dims, 1, smoothing_frames], dtype=latents.dtype).to(device) / smoothing_frames
    # Repeat the first/last frame instead of zero padding. Zero padding averaged
    # the opening and closing frames towards the all-zero latent.
    latents_padded = torch.nn.functional.pad(latents_flat, (half_window, half_window), mode='replicate')
    latents_smoothed = conv1d(latents_padded, smoothing_filter, groups=n_dims)
    latents = latents_smoothed.transpose(1, 2).reshape(latents.shape)

In [None]:
#@title Generate frames - this can take a while

# Start from an empty frame directory so frames of a previous (longer) run
# cannot end up in this video.
shutil.rmtree("parts", ignore_errors=True)
os.makedirs("parts", exist_ok=True)


def write_frames(frames, start_id):
  """Save each image tensor in `frames` as parts/output_<start_id + offset>.jpg."""
  for offset, frame in enumerate(frames):
    TF.to_pil_image(frame).save(f'parts/output_{start_id + offset:08}.jpg')


batch = max(1, int(batch_size))
with torch.no_grad():
  # Render ALL latents; the last batch may be shorter than `batch`.
  for start_frame in range(0, len(latents), batch):
    frames = G.synthesis(latents[start_frame:start_frame + batch], noise_mode='const').add(1).div(2).clamp(0,1).cpu().detach()
    write_frames(frames, start_frame)
    save_image(frames[0], "c_frame")

In [None]:
#@title Create video from frames and audio
import soundfile as sf
sf.write("audio_cut.wav", audio, sr)
!mkdir {output_path}
!ffmpeg  -r {frame_rate} -i parts/%*.jpg -y -c:v libx264 vid_no_audio.mp4
!ffmpeg -y -i audio_cut.wav -i vid_no_audio.mp4 final_video.mp4
!ffmpeg -i final_video.mp4 {output_path}/d_compressed.mp4
!sleep 20

import os.path
if not os.path.exists(output_path + '/d_compressed.mp4'):
  raise Exception("Expected output file does not exist.")

In [None]:
#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2021 nshepperd; Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# !mkdir drive/MyDrive/weihnachten2021
# !cp -r parts/ drive/MyDrive/weihnachten2021/frames
# !cp -r final_video.mp4 drive/MyDrive/weihnachten2021