Skip to content

Commit

Permalink
Merge pull request #1 from simonrouard/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
simonrouard committed Jul 27, 2021
2 parents 7710ebf + b63538e commit 97d13e2
Show file tree
Hide file tree
Showing 10 changed files with 950 additions and 833 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
saved_weights/
saved_weights_old/
.ipynb_checkpoints/
__pycache__
img/
img/
model_classifier_v2.py
14 changes: 0 additions & 14 deletions __main__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

from torch.cuda import device_count
Expand Down
17 changes: 0 additions & 17 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from torch.utils.data.distributed import DistributedSampler
from glob import glob
import numpy as np
import os
import random
import torch
import torchaudio
torchaudio.set_audio_backend("sox")
Expand Down
47 changes: 34 additions & 13 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,31 @@
from scipy.integrate import solve_ivp


def compute_interpolation_in_latent(latent1, latent2, lambd):
    '''
    Spherical linear interpolation (Slerp) between two latent vectors:
    https://en.wikipedia.org/wiki/Slerp

    latent1: tensor of shape (1, D), e.g. (1, 21000); row 0 is used to measure the angle
    latent2: tensor of shape (1, D)
    lambd: list of floats between 0 and 1 representing the parameter t of the Slerp

    Returns a tensor of shape (len(lambd), D) on the device of latent1, one
    interpolated latent per value in lambd (lambd=0 gives latent1, lambd=1
    gives latent2).

    When the angle between the two latents is degenerate (parallel,
    antiparallel, or undefined because one of them has zero norm) the Slerp
    weights would be 0/0, so the function falls back to plain linear
    interpolation instead of returning NaN.
    '''
    device = latent1.device
    lambd = torch.tensor(lambd)

    v1, v2 = latent1[0], latent2[0]
    cos_omega = v1 @ v2 / (torch.linalg.norm(v1) * torch.linalg.norm(v2))
    # Rounding can push the cosine marginally outside [-1, 1], which would
    # make arccos return NaN for (nearly) parallel vectors.
    cos_omega = torch.clamp(cos_omega, -1.0, 1.0)
    # Kept as a Python float so it can be combined with the CPU tensor lambd
    # regardless of which device the latents live on.
    omega = torch.arccos(cos_omega).item()
    sin_omega = float(np.sin(omega))

    if not np.isfinite(sin_omega) or abs(sin_omega) < 1e-6:
        # Degenerate angle: Slerp is undefined, use linear weights.
        a = 1 - lambd
        b = lambd
    else:
        a = torch.sin((1 - lambd) * omega) / sin_omega
        b = torch.sin(lambd * omega) / sin_omega

    # Shape (len(lambd), 1) so the weights broadcast against the (1, D) latents.
    a = a.unsqueeze(1).to(device)
    b = b.unsqueeze(1).to(device)
    return a * latent1 + b * latent2


class SDESampling:
"""
Euler-Maruyama discretisation of the SDE as in https://arxiv.org/abs/2011.13456
This the less precise discretization
Euler-Maruyama discretization of the SDE as in https://arxiv.org/abs/2011.13456
This is the least precise discretization
"""

def __init__(self, model, sde):
Expand Down Expand Up @@ -55,7 +76,7 @@ def predict(

class SDESampling2:
"""
DDPM-like discretisation of the SDE as in https://arxiv.org/abs/2107.00630
DDPM-like discretization of the SDE as in https://arxiv.org/abs/2107.00630
This is the most precise discretization
"""

Expand Down Expand Up @@ -104,7 +125,7 @@ def predict(

class SDESampling3:
"""
DDIM-like discretisation of the SDE as in https://arxiv.org/abs/2106.07431 Alg. 6
DDIM-like discretization of the SDE as in https://arxiv.org/abs/2106.07431 Alg. 6
This is an intermediate model in terms of precision
"""

Expand Down Expand Up @@ -316,7 +337,7 @@ def __init__(self, model, sde):
self.sde = sde

def create_schedules(self, nb_steps):
t_schedule = torch.arange(0, nb_steps + 1) / nb_steps
t_schedule = torch.arange(1, nb_steps + 1) / nb_steps
t_schedule = (self.sde.t_max - self.sde.t_min) * \
t_schedule + self.sde.t_min
sigma_schedule = self.sde.sigma(t_schedule)
Expand Down Expand Up @@ -596,7 +617,7 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / self.sde.mean(t_0)

return audio
return audio.detach()


class ClassMixingSDE2:
Expand Down Expand Up @@ -652,7 +673,7 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / m[0]

return audio
return audio.detach()


class ClassMixingODE:
Expand Down Expand Up @@ -704,7 +725,7 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / self.sde.mean(t_0)

return audio
return audio.detach()


class ClassMixingDDIM:
Expand Down Expand Up @@ -754,12 +775,12 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / m[0]

return audio
return audio.detach()


class RegenerateSDESampling2:
"""
Using the DDPM-like discretisation of the SDE (like SDESampling2 class) of a drum sound noised at the noise level sigma
Using the DDPM-like discretization of the SDE (like SDESampling2 class) of a drum sound noised at the noise level sigma
"""

def __init__(self, model, sde):
Expand Down Expand Up @@ -808,10 +829,10 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / m[0]

return audio
return audio.detach()


class RandomClassConditionalDDIMSampling:
class RandomClassMixingDDIM:
"""
Adapted the DDIM to the SDE Framework.
eta = 1 corresponds to the SDESampling2 class
Expand Down Expand Up @@ -878,4 +899,4 @@ def predict(
audio = (audio - sigma[0] * self.model(audio,
sigma[0])) / m[0]

return audio
return audio.detach()
Loading

0 comments on commit 97d13e2

Please sign in to comment.