In [None]:
#@title Setup

%load_ext autoreload
%autoreload 2

import os
import sys

import subprocess
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex pytorch-ignite visdom
!pip install git+https://github.com/nielsrolf/siren-pytorch
!pip install git+https://github.com/pollinations/AudioCLIP
!pip install git+https://github.com/pollinations/CLIPTranslate

!wget https://github.com/AndreyGuzhov/AudioCLIP/releases/download/v0.1/bpe_simple_vocab_16e6.txt.gz -O 'AudioCLIP/assets/bpe_simple_vocab_16e6.txt.gz'
!wget https://github.com/AndreyGuzhov/AudioCLIP/releases/download/v0.1/AudioCLIP-Partial-Training.pt

!wget https://raw.githubusercontent.com/pollinations/CLIPTranslate/main/notebooks/data/cat.jpg
!wget https://raw.githubusercontent.com/pollinations/CLIPTranslate/main/notebooks/data/gt_bach.wav
!wget https://raw.githubusercontent.com/pollinations/CLIPTranslate/main/notebooks/data/hearbeat.jpg


from clip_translate.utils import load_img, imshow, load_audio, play

# The wget calls above save into the current working directory (/content on
# Colab), so resolve the sample files relative to it instead of hard-coding
# the Colab path.
sample_img = load_img(os.path.join(os.getcwd(), "cat.jpg"))
imshow(sample_img)

sample_audio = load_audio(os.path.join(os.getcwd(), "gt_bach.wav"))
play(sample_audio)


from siren_pytorch import SirenNet, SirenWrapperNDim
import time
from IPython.display import clear_output
from clip_translate import AudioImagine
import torch

# GPU used for the experiments in this notebook.
# NOTE(review): later cells hard-code 'cuda' / `.cuda()` instead of using DEVICE.
DEVICE = torch.device('cuda:0')

CUDA version: 11.0
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu110
[?25l  Downloading https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp37-cp37m-linux_x86_64.whl (1156.8MB)
[K     |███████████████████████         | 834.1MB 1.3MB/s eta 0:04:10tcmalloc: large alloc 1147494400 bytes == 0x555b1d924000 @  0x7f7440cdc615 0x555ae4c3fcdc 0x555ae4d1f52a 0x555ae4c42afd 0x555ae4d33fed 0x555ae4cb6988 0x555ae4cb14ae 0x555ae4c443ea 0x555ae4cb67f0 0x555ae4cb14ae 0x555ae4c443ea 0x555ae4cb332a 0x555ae4d34e36 0x555ae4cb2853 0x555ae4d34e36 0x555ae4cb2853 0x555ae4d34e36 0x555ae4cb2853 0x555ae4d34e36 0x555ae4db73e1 0x555ae4d176a9 0x555ae4c82cc4 0x555ae4c43559 0x555ae4cb74f8 0x555ae4c4430a 0x555ae4cb23b5 0x555ae4cb17ad 0x555ae4c443ea 0x555ae4cb23b5 0x555ae4c4430a 0x555ae4cb23b5
[K     |█████████████████████████████▏  | 1055.7MB 1.2MB/s eta 0:01:27tcmalloc: large alloc 1434370048 bytes == 0x555b61f7a000 @  0x7f7440cdc615 0x555ae4c3fcdc 0x555ae4

# Fit an audio clip to a text and/or image prompt

In [None]:
from clip_translate import AudioImagine, fit_siren, get_siren_decoder

# Perceptor checkpoint downloaded in the setup cell; both the text and the
# sample image are passed in as prompts.
imagine = AudioImagine(
    perceptor="AudioCLIP-Partial-Training.pt",
    text="A cat",
    image=sample_img,
)

# latent_dim=None — presumably a SIREN with no latent conditioning (TODO confirm in clip_translate).
siren = get_siren_decoder(sample_audio.shape, latent_dim=None)
fit_siren(imagine, siren, steps=1000)

# Fitting a hypernetwork

In [None]:
import torch.nn.functional as F
from torch import nn
from clip_translate import AudioImagine, get_siren_decoder, fit_siren
from matplotlib import pyplot as plt

class Autoencoder(nn.Module):
  """Chains an encoder, a latent-conditioned decoder and a reconstruction loss.

  Args:
    encoder: callable mapping an input to a latent code.
    decoder: callable invoked as ``decoder(latent=...)`` returning a reconstruction.
    loss: callable invoked as ``loss(target, reconstructed)`` (target first).
  """

  def __init__(self, encoder, decoder, loss):
    super().__init__()
    self.encoder, self.decoder, self.loss = encoder, decoder, loss

  def forward(self, audio, target=None):
    """Reconstruct ``audio``; with a ``target`` return the loss instead."""
    reconstructed = self.decoder(latent=self.encoder(audio))
    if target is None:
      return reconstructed
    return self.loss(target, reconstructed)


# Learnable latent returned for every input, regardless of the audio.
# Create it directly on the GPU: `nn.Parameter(...).to('cuda')` returns a plain
# non-leaf Tensor copy (not a Parameter), which an optimizer refuses to update.
latent = nn.Parameter(torch.randn(1024, device='cuda'))
def constant_encoder(audio):
  """Ignore ``audio`` and return the shared learnable ``latent``."""
  return latent


def clip_encoder(audio):
  """Embed ``audio`` with the perceptor held by the notebook-level ``imagine``.

  The input is flattened to one row and encoded under ``torch.no_grad()``, so no
  gradient flows back through this encoder; the result is reshaped to (1024,).
  """
  flat = audio.reshape(1, -1).detach()
  with torch.no_grad():
    embedding = imagine.encode_audio(flat)
  return embedding.reshape(1024)


def get_siren_decoder(output_shape, latent_dim=1024, dim_hidden=256, num_layers=3,
                      w0=30., w0_initial=10000., device="cuda"):
    """Build a latent-conditioned SIREN decoder producing a signal of ``output_shape``.

    NOTE: this definition shadows ``clip_translate.get_siren_decoder`` imported
    earlier in the notebook; cells after this one use this version.

    Args:
        output_shape: shape of the signal the wrapper renders (e.g. ``audio.shape``).
        latent_dim: size of the latent vector the wrapper is conditioned on.
        dim_hidden: width of each hidden SIREN layer.
        num_layers: number of SIREN layers.
        w0: sine frequency factor of the hidden layers.
        w0_initial: sine frequency factor of the first layer (the large default is
            presumably chosen for high-frequency audio — TODO confirm).
        device: device the decoder is moved to.

    Returns:
        A ``SirenWrapperNDim`` module placed on ``device``.
    """
    net = SirenNet(
        dim_in=1,
        dim_hidden=dim_hidden,
        dim_out=1,
        num_layers=num_layers,
        w0=w0,
        w0_initial=w0_initial,
        use_bias=True,
        final_activation=None)

    decoder = SirenWrapperNDim(
        net,
        latent_dim=latent_dim,
        output_shape=output_shape
    )
    decoder.to(device)

    return decoder



def train_on_single_sample(ae, lr=1e-4, steps=2000, steps_till_summary=1000, audio=None):
  """Overfit ``ae`` on one clip, periodically printing the loss and playing /
  plotting the current reconstruction.

  Args:
    ae: module where ``ae(x, x)`` returns a scalar reconstruction loss and
      ``ae(x)`` returns the reconstruction (e.g. ``Autoencoder``).
    lr: Adam learning rate.
    steps: number of optimisation steps.
    steps_till_summary: report every this many steps, and after the final step.
    audio: clip to fit; defaults to the notebook-level ``sample_audio``.
  """
  if audio is None:
    audio = sample_audio
  optim = torch.optim.Adam(lr=lr, params=ae.parameters())
  for step in range(steps):
    loss = ae(audio, audio)
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Also report after the last step; otherwise the final reconstruction is never shown.
    if step % steps_till_summary == 0 or step == steps - 1:
      print(f"step {step}: loss {loss.item():.6f}")
      # Not wrapped in torch.no_grad(): GON's encoder itself needs autograd.
      pred_audio = ae(audio)
      play(pred_audio)
      plt.plot(pred_audio.cpu().detach().numpy().squeeze())
      plt.title(f"Reconstruction after step {step}")
      plt.show()

# decoder = get_siren_decoder(sample_audio.shape, 1024)
# ae = Autoencoder(encoder=clip_encoder, decoder=decoder, loss=F.mse_loss)
# train_on_single_sample(ae)

In [None]:
class GON(Autoencoder):
    """Gradient Origin Network: an autoencoder whose encoder is one gradient step.

    The latent for an input is the negative gradient of the reconstruction loss
    with respect to an all-zero latent, so the decoder holds all trainable weights.

    Args:
        decoder: latent-conditioned decoder called as ``decoder(latent=z)``.
        latent_dim: size of the latent; must match the decoder's ``latent_dim``
            (1024 by default, the value used throughout this notebook).
    """

    def __init__(self, decoder, latent_dim=1024):
        super().__init__(encoder=self.encode, decoder=decoder, loss=F.mse_loss)
        self.latent_dim = latent_dim

    def encode(self, audio):
        """Return ``-dL/dz`` at ``z = 0`` for the reconstruction loss of ``audio``.

        ``audio`` must be on the decoder's device (the loss requires it anyway);
        the zero latent is created on that same device.
        """
        # A plain leaf with requires_grad suffices; nn.Parameter(...).to() only
        # produced a non-leaf copy.
        origin = torch.zeros(self.latent_dim, device=audio.device, requires_grad=True)
        inner_loss = self.loss(audio, self.decoder(latent=origin))
        # create_graph=True keeps this gradient differentiable so the outer loss can
        # back-propagate through it (it also implies retain_graph=True).
        (grad,) = torch.autograd.grad(inner_loss, [origin], create_graph=True)
        return -grad

# Overfit a GON on the single example clip.
decoder = get_siren_decoder(sample_audio.shape, latent_dim=1024)
ae = GON(decoder)
train_on_single_sample(ae)

In [None]:
# Train GON on multiple audio examples

from glob import glob


# glob's result order depends on the filesystem; sort it so the file list (and
# any seeded sampling from it) is reproducible.
audio_files = sorted(glob("/content/drive/MyDrive/ddsp/samples/*/*.wav"))
# Fail here with a clear message rather than later inside np.random.randint(0).
assert audio_files, "no .wav files found under /content/drive/MyDrive/ddsp/samples/ - is Drive mounted?"
audio_files

In [None]:
# Load every training clip into memory, preserving the order of audio_files.
audios = list(map(load_audio, audio_files))



In [None]:
# Shape of the first loaded clip; random_crop below slices clips along axis 0.
audios[0].shape

In [None]:
import numpy as np

def random_crop(audio, seconds=4, rate=44100):
    """Return a random contiguous window of ``seconds`` of ``audio``.

    Args:
        audio: array/tensor whose first axis indexes samples.
        seconds: window length in seconds.
        rate: samples per second. This was previously read from a global ``rate``
            that is not defined in the notebook; 44100 (AudioCLIP's rate) is
            assumed — TODO confirm it matches ``load_audio``.

    Returns:
        ``audio[start:start + int(seconds * rate)]`` for a uniformly random ``start``.

    Raises:
        ValueError: if ``audio`` is shorter than the requested window.
    """
    frames = int(seconds * rate)
    slack = audio.shape[0] - frames
    if slack < 0:
        raise ValueError(
            f"audio has {audio.shape[0]} samples, fewer than the {frames} "
            f"needed for a {seconds}s crop")
    # randint's upper bound is exclusive: +1 lets the window end on the last
    # sample and makes an exact-length clip valid instead of raising.
    start = np.random.randint(0, slack + 1)
    return audio[start:start + frames]

def get_sample():
    """Pick a random clip from the global ``audios`` list and crop 4 seconds of it."""
    index = np.random.randint(len(audios))
    return random_crop(audios[index], seconds=4)


# Listen to a few random training crops as a sanity check.
for clip in (get_sample() for _ in range(4)):
    play(clip)

In [None]:
steps = 10000
lr = 1e-3
steps_till_summary = 1000

decoder = get_siren_decoder(get_sample().shape, 1024)
ae = GON(decoder=decoder)

optim = torch.optim.Adam(lr=lr, params=ae.parameters())
for step in range(steps):
    # A fresh random crop every step, so training covers many clips rather than one.
    x = get_sample()
    loss = ae(x, x)
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Also report after the final step; otherwise the fully trained result is never shown.
    if step % steps_till_summary == 0 or step == steps - 1:
        print(f"step {step}: loss {loss.item():.6f}")
        pred_audio = ae(x)
        play(pred_audio)
        plt.plot(pred_audio.cpu().detach().numpy().squeeze())
        plt.title(f"GON reconstruction after step {step}")
        plt.show()