In [1]:
# %pip install -qq -U diffusers transformers ftfy pyarrow==9.0.0
%cd ..
%load_ext autoreload
%autoreload 2

/home/smehta/Projects/OverFlow


In [2]:
# Expose only GPU 0 to this kernel. Kept ahead of the project/torch imports
# below so the setting is in place before anything touches CUDA.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [3]:
from src.hparams import create_hparams
from src.utilities.data import TextMelLoader, TextMelMotionCollate
from torch.utils.data import DataLoader

In [4]:
# Hyper-parameter object for the whole notebook: later cells read the training
# file list, normalisers, frames-per-step, worker count and motion visualiser from it.
hparams = create_hparams()

In [5]:
# Training dataset built from the file list in `hparams`, with one mel
# normaliser and one motion normaliser passed as single-element lists.
trainset = TextMelLoader(
    hparams.training_files,
    hparams,
    [hparams.mel_normaliser],
    [hparams.motion_normaliser],
)

# Collate function that assembles individual examples into a padded batch.
collate_fn = TextMelMotionCollate(hparams.n_frames_per_step)

# NOTE(review): `shuffle` is left at its default (False), so batches arrive in
# the same order every epoch -- confirm this is intended for training.
train_dataloader = DataLoader(
    trainset,
    batch_size=32,
    collate_fn=collate_fn,
    num_workers=hparams.num_workers,
    pin_memory=True,
)

Data cache found at : data/filelists/cormac_train.txt.cleaned! Loading cache...


  0%|          | 0/4812 [00:00<?, ?it/s]

Done caching mels! New mels cached: 0


In [6]:
from diffusers import DDIMScheduler, DDPMScheduler
from src.model.diffusion import GradLogPEstimator2d
from src.model.wavegrad import WaveGrad

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from src.utilities.functions import fix_len_compatibility, get_mask_from_len
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
import matplotlib.pyplot as plt

In [8]:
# Noise schedule shared by training (to corrupt samples) and by the sampling cell.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

# Per-batch loss history. It lives in this cell so that re-running the training
# cell keeps appending to the same record.
losses = []

# Denoising network. Only this model is trained: the GradLogPEstimator2d that
# used to be built here was overwritten straight away, so it is no longer
# instantiated (it only cost construction time and device memory).
# Channel counts: 45 for the motion input, 80 for the conditioning input.
net = WaveGrad(
    motion_in_channels=45,
    latent_in_channels=80
).to(device)

opt = torch.optim.Adam(net.parameters(), lr=1e-3)


In [None]:
loss_fn = nn.L1Loss()

# Number of most recent batch losses averaged in the end-of-epoch report.
LOSS_WINDOW = 100
# A batch loss above this value is reported as a possible divergence.
LOSS_ALERT_THRESHOLD = 15

# The training loop
for epoch in range(40):
    pbar = tqdm(train_dataloader)
    pbar.set_description(f'Epoch: {epoch}')
    for batch in pbar:
        text_padded, input_lengths, mel_padded, motion_padded, output_lengths = batch

        # The motion features are the signal that gets noised; the mel features
        # are handed to the network as conditioning.
        x = motion_padded.to(device)
        condition = mel_padded.to(device)

        output_lengths = output_lengths.to(device)
        # Mask derived from the true sequence lengths; unsqueeze(1) inserts a
        # singleton axis so it broadcasts against `pred` / `noise` below.
        output_lengths_mask = get_mask_from_len(
            output_lengths, output_lengths.max(), device=x.device, dtype=x.dtype
        ).unsqueeze(1)

        noise = torch.randn_like(x)
        # torch.randint's upper bound is exclusive, so it must be the full
        # schedule length for the last timestep to be sampled as well
        # (a literal 999 would never train on timestep 999).
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=x.device,
        ).long()
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # Get the model prediction
        pred = net(noisy_x, output_lengths_mask, condition, timesteps)

        # How close is the prediction to the injected noise, with both sides
        # multiplied by the length mask.
        # NOTE(review): nn.L1Loss averages over every element, presumably
        # including the masked-out padding, so heavily padded batches report a
        # smaller loss -- confirm this normalisation is intended.
        loss = loss_fn(pred * output_lengths_mask, noise * output_lengths_mask)
        loss_value = loss.item()

        pbar.set_postfix({"loss": loss_value})

        # Report suspicious spikes instead of dropping into pdb, which blocks
        # the kernel in the middle of a long run.
        if loss_value > LOSS_ALERT_THRESHOLD:
            tqdm.write(
                f'Warning: loss {loss_value:.4f} exceeded {LOSS_ALERT_THRESHOLD} in epoch {epoch}'
            )

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss_value)

    # Print the average of the most recent losses to get an idea of progress.
    # Divide by the number of values actually available so the figure is not
    # deflated when fewer than LOSS_WINDOW losses have been recorded.
    recent_losses = losses[-LOSS_WINDOW:]
    avg_loss = sum(recent_losses) / max(len(recent_losses), 1)
    print(f'Finished epoch {epoch}. Average of the last {LOSS_WINDOW} loss values: {avg_loss:05f}')

  0%|          | 0/151 [00:01<?, ?it/s]

Finished epoch 0. Average of the last 100 loss values: 0.480170


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 1. Average of the last 100 loss values: 0.262889


  0%|          | 0/151 [00:02<?, ?it/s]

Finished epoch 2. Average of the last 100 loss values: 0.200206


  0%|          | 0/151 [00:02<?, ?it/s]

Finished epoch 3. Average of the last 100 loss values: 0.173795


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 4. Average of the last 100 loss values: 0.148546


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 5. Average of the last 100 loss values: 0.140949


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 6. Average of the last 100 loss values: 0.132708


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 7. Average of the last 100 loss values: 0.125844


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 8. Average of the last 100 loss values: 0.122321


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 9. Average of the last 100 loss values: 0.119541


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 10. Average of the last 100 loss values: 0.113181


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 11. Average of the last 100 loss values: 0.108509


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 12. Average of the last 100 loss values: 0.114314


  0%|          | 0/151 [00:03<?, ?it/s]

Finished epoch 13. Average of the last 100 loss values: 0.106735


  0%|          | 0/151 [00:04<?, ?it/s]

Finished epoch 14. Average of the last 100 loss values: 0.105727


  0%|          | 0/151 [00:04<?, ?it/s]

Finished epoch 15. Average of the last 100 loss values: 0.105370


  0%|          | 0/151 [00:04<?, ?it/s]

Finished epoch 16. Average of the last 100 loss values: 0.102220


  0%|          | 0/151 [00:05<?, ?it/s]

Finished epoch 17. Average of the last 100 loss values: 0.102118


  0%|          | 0/151 [00:05<?, ?it/s]

Finished epoch 18. Average of the last 100 loss values: 0.096308


  0%|          | 0/151 [00:05<?, ?it/s]

Finished epoch 19. Average of the last 100 loss values: 0.097329


  0%|          | 0/151 [00:05<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out


Finished epoch 20. Average of the last 100 loss values: 0.099135


  0%|          | 0/151 [00:07<?, ?it/s]

Finished epoch 21. Average of the last 100 loss values: 0.101331


  0%|          | 0/151 [00:08<?, ?it/s]

Finished epoch 22. Average of the last 100 loss values: 0.103214


  0%|          | 0/151 [00:10<?, ?it/s]

In [None]:
# Loss curve over all recorded training batches, labelled so the figure can be
# read on its own; plt.show() also avoids echoing the Line2D repr.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("Training batch")
ax.set_ylabel("L1 loss")
ax.set_title("Training loss")
plt.show()

In [None]:
## DDPM
# NOTE(review): `mel_padded` is whatever batch the training loop processed last,
# and `len_` is hard-coded -- presumably it has to match the frame count of that
# batch's first item. Confirm before re-running.
len_ = 929
x = torch.randn(1, 45, len_).to(device)
mask = get_mask_from_len(torch.tensor([len_], device=device, dtype=torch.long), len_, device=x.device, dtype=x.dtype).unsqueeze(1)

# Conditioning for the single sequence being generated; moved to the device
# once instead of on every denoising step.
condition = mel_padded[0:1].to(device)

# The number of denoising steps belongs to the noise schedule, not to the
# sequence length, so walk the full training schedule rather than `len_` steps.
noise_scheduler.set_timesteps(num_inference_steps=noise_scheduler.config.num_train_timesteps)

# Sampling loop
for t in tqdm(noise_scheduler.timesteps):

    x = noise_scheduler.scale_model_input(x, t)

    # The network takes a batched timestep on the model's device; the scheduler
    # keeps receiving its own timestep `t` unchanged.
    t_batch = t.unsqueeze(0).to(device)
    with torch.no_grad():
        residual = net(x, mask, condition, t_batch)

    # Update sample with step
    # NOTE(review): `temperature` is not part of the upstream
    # DDPMScheduler.step signature as far as this review could tell -- confirm
    # the installed scheduler accepts it.
    x = noise_scheduler.step(residual, t, x, temperature=0.334).prev_sample

In [None]:
# ### DDIM
# len_ = 1000
# x = torch.randn(1, 80, len_).to(device)
# mask = get_mask_from_len(torch.tensor([len_], device=device, dtype=torch.long), len_, device=x.device, dtype=x.dtype).unsqueeze(1)
# noise_scheduler.set_timesteps(num_inference_steps=40)
# # Sampling loop
# for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    
#     x = noise_scheduler.scale_model_input(x, t)
#     # Get model pred
#     t = t.unsqueeze(0).to(device)
#     with torch.no_grad():
#         residual = net(x, mask, x, t)  # Again, note that we pass in our labels y

#     # Update sample with step
#     x = noise_scheduler.step(residual, t, x, temperature=0.334).prev_sample * mask

In [None]:
# Sampled tensor as a NumPy array with the leading batch axis removed.
# NOTE(review): despite the name, `x` here is the sampled 45-channel tensor from
# the sampling cell, not a mel-spectrogram; the name is kept because later cells use it.
mel_spectrogram = x.detach().squeeze(0).cpu().numpy()


In [None]:
# Heat-map of the sampled array: frames along x, channels along y, origin at the bottom.
fig, ax = plt.subplots(figsize=(12, 3))
heatmap = ax.imshow(mel_spectrogram, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(heatmap, ax=ax)
ax.set_xlabel("Frames")
ax.set_ylabel("Channels")
fig.tight_layout()
plt.show()

In [None]:
%%capture
from pymo.preprocessing import MocapParameterizer
from pymo.viz_tools import render_mp4
hparams.motion_visualizer

bvh_values = hparams.motion_visualizer.inverse_transform([mel_spectrogram.T])

    # To stickfigure
X_pos = MocapParameterizer("position").fit_transform(bvh_values)

render_mp4(X_pos[0], "temp.mp4", axis_scale=200)