In [1]:
import warnings
import os

# Move up one directory -- presumably to the project root so that the later
# `model.*`, `utils.*` and `data.*` imports resolve against the current directory.
# NOTE(review): the path is relative to the current cwd, so re-running this
# cell in the same kernel climbs one more level.
os.chdir('../')

# Silence every warning for the rest of the kernel session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="6"

Fri May 26 19:38:34 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:1B:00.0 Off |                    0 |
| 30%   26C    P8    14W / 230W |   2884MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000    Off  | 00000000:1C:00.0 Off |                  Off |
| 30%   51C    P2   118W / 230W |  17677MiB / 24564MiB |     17%      Default |
|       

### Hyperparams

In [3]:
# --- feature / target sizes -------------------------------------------------
n_mels = 8        # input feature channels; passed to the model as in_dim
n_outputs = 61    # target channels per frame; passed to the model as out_dim
n_frames = 400    # window length handed to LipsDataset and Collate
sr = 24000        # audio sample rate handed to LipsDataset (sr=)
fps = 30          # target frame rate handed to LipsDataset (fps=); matches the *_30fps files

# --- data files ----------------------------------------------------------------
# Training: takes 1-4 of the rvh_viseme2 recordings; test: one separate take.
_face_dir = '/Storage/speech/face/'

train_csv_files = [f'{_face_dir}rvh_viseme2_{take}_iPhone_raw_30fps.csv' for take in range(1, 5)]
train_wav_files = [f'{_face_dir}rvh_viseme2_{take}_iPhone.wav' for take in range(1, 5)]

test_csv_files = [_face_dir + 'MySlate_6_박수철의_iPhone_30fps.csv']
test_wav_files = [_face_dir + 'MySlate_6_박수철의_iPhone.wav']

### Model

In [5]:
from model.model_vqvae import Model
# NOTE(review): star import -- `load` / `save`, used in later cells, presumably come from here.
from utils.util import *
from tensorboardX import SummaryWriter

# Global optimizer-step counter; the checkpoint-loading cell may overwrite it.
step = 0
device = 'cuda:0'

# Model from `model.model_vqvae`, sized by the hyper-parameter cell.
# K and latent_dim are fixed here rather than in the hyper-parameter cell.
model = Model(
    in_dim=n_mels,
    out_dim=n_outputs,
    K=16,
    latent_dim=128,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


### Load

In [6]:
save_dir = '/data/scpark/save/lips/train05.24-3/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 1000, model, None, optimizer)

total 0


In [7]:
import torch
from data.dataset import LipsDataset, Collate


def _make_loader(ds):
    """Wrap `ds` in the DataLoader settings shared by both splits (a fresh Collate per loader)."""
    return torch.utils.data.DataLoader(
        ds,
        batch_size=8,
        shuffle=True,
        num_workers=1,
        collate_fn=Collate(n_frames, n_mels),
    )


# Training split (the train_* file lists from the hyper-parameter cell).
dataset = LipsDataset(train_wav_files, train_csv_files, n_frames, n_mels=n_mels, sr=sr, fps=fps)
print('train dataset :', len(dataset))
train_loader = _make_loader(dataset)
print(train_loader)

# Held-out split. NOTE(review): because this loader also shuffles, the training
# loop's "first 10 test batches" are a different random subset at every
# evaluation, so test_loss is noisy across steps.
dataset = LipsDataset(test_wav_files, test_csv_files, n_frames, n_mels=n_mels, sr=sr, fps=fps)
print('test dataset :', len(dataset))
test_loader = _make_loader(dataset)
print(test_loader)

train dataset : 28308
<torch.utils.data.dataloader.DataLoader object at 0x7f87862e37f0>
test dataset : 1413
<torch.utils.data.dataloader.DataLoader object at 0x7f8716126220>


In [8]:
def preprocess(x):
    """Log-compress target values elementwise: returns 10 * log10(1 + x).

    x == 0 maps to 0, and the output has the same shape as `x`.
    NOTE(review): the result is -inf at x == -1 and NaN for x < -1; whether any
    target channel can be that negative depends on the CSV data -- confirm.
    """
    shifted = x + 1
    return 10 * torch.log10(shifted)

In [None]:
import itertools

from IPython import display
import librosa.display  # not used in this cell
import matplotlib.pyplot as plt  # not used in this cell


def _prepare_batch(batch):
    """Return (inputs, targets) for one collated batch, on `device`.

    Both tensors are transposed on dims 1 and 2 (layout expected by the model --
    TODO confirm the exact (batch, channels, time) convention); targets are
    additionally passed through `preprocess`.
    """
    inputs = batch['inputs'].transpose(1, 2).to(device)
    targets = preprocess(batch['outputs'].transpose(1, 2).to(device))
    return inputs, targets


def _sum_losses(outputs, verbose=False):
    """Sum every entry of the model's output dict whose key contains 'loss'.

    With verbose=True each component is printed as `<key> <value>`.
    Returns the int 0 when no key matches.
    """
    total = 0
    for key, value in outputs.items():
        if 'loss' in key:
            total += value
            if verbose:
                print(key, value.item())
    return total


def _evaluate(loader, max_batches=10):
    """Mean total loss over the first `max_batches` batches of `loader` (eval mode, no grad).

    Prints each batch as `test : <i> <loss>`. Returns None when the loader yields
    no batches. Leaves the model in eval mode; the training loop switches it back
    with model.train() on its next step.
    NOTE(review): test_loader is built with shuffle=True, so each call sees a
    different random subset and the returned value is noisy across steps.
    """
    model.eval()
    batch_losses = []
    # islice stops after max_batches without fetching (and discarding) one more batch.
    for i, batch in enumerate(itertools.islice(loader, max_batches)):
        inputs, targets = _prepare_batch(batch)
        with torch.no_grad():
            outputs = model(inputs, targets)
        loss = _sum_losses(outputs)
        print('test :', i, loss.item())
        batch_losses.append(loss)
    if not batch_losses:
        return None
    return torch.stack(batch_losses).mean().item()


# Set when training hits a non-finite loss (name kept for compatibility); ends the loop.
isnan = False
while not isnan:
    for batch in train_loader:
        inputs, targets = _prepare_batch(batch)

        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        loss = _sum_losses(outputs, verbose=True)
        # Stop before backward() on NaN *or* inf: a non-finite loss would push
        # non-finite gradients through AdamW and corrupt the weights.
        if not torch.isfinite(loss):
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % 100 == 0:
            writer.add_scalar('train_loss', loss.item(), step)
            display.clear_output()

            test_loss = _evaluate(test_loader)
            if test_loss is not None:
                print('test_loss :', test_loss)
                writer.add_scalar('test_loss', test_loss, step)

        # NOTE(review): save runs before `step += 1`, so the checkpoint labelled
        # `step` already includes step + 1 optimizer updates.
        if step % 1000 == 0:
            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 5.177490711212158
test : 1 5.17063570022583
test : 2 5.269093990325928
test : 3 5.301322937011719
test : 4 5.237942218780518
test : 5 5.35686731338501
test : 6 5.158825397491455
test : 7 5.294137477874756
test : 8 5.333480358123779
test : 9 5.224009037017822
test_loss : 5.25238037109375
4101
auto_encoding_loss 0.032544348388910294
commit_loss 0.01419934444129467
zi_prediction_loss 0.47786498069763184
4101 0.5246086716651917
4102
auto_encoding_loss 0.03392971679568291
commit_loss 0.014050749130547047
zi_prediction_loss 0.5034300088882446
4102 0.5514104962348938
4103
auto_encoding_loss 0.033825017511844635
commit_loss 0.014151062816381454
zi_prediction_loss 0.504743754863739
4103 0.5527198314666748
4104
auto_encoding_loss 0.03348485380411148
commit_loss 0.0153043232858181
zi_prediction_loss 0.4886641800403595
4104 0.5374533534049988
4105
auto_encoding_loss 0.03134307265281677
commit_loss 0.014694973826408386
zi_prediction_loss 0.5100514888763428
4105 0.5560895204544067
4106
auto

4158 0.5313160419464111
4159
auto_encoding_loss 0.03471828252077103
commit_loss 0.01421065628528595
zi_prediction_loss 0.4838577210903168
4159 0.5327866673469543
4160
auto_encoding_loss 0.0321192592382431
commit_loss 0.013345185667276382
zi_prediction_loss 0.4879230856895447
4160 0.5333875417709351
4161
auto_encoding_loss 0.03057291731238365
commit_loss 0.013443058356642723
zi_prediction_loss 0.5052249431610107
4161 0.5492409467697144
4162
auto_encoding_loss 0.03250075504183769
commit_loss 0.015609064139425755
zi_prediction_loss 0.4817618131637573
4162 0.5298716425895691
4163
auto_encoding_loss 0.031765323132276535
commit_loss 0.013135994784533978
zi_prediction_loss 0.4584120810031891
4163 0.503313422203064
4164
auto_encoding_loss 0.03143935278058052
commit_loss 0.01281314343214035
zi_prediction_loss 0.44785788655281067
4164 0.49211037158966064
4165
auto_encoding_loss 0.03184980899095535
commit_loss 0.013209501281380653
zi_prediction_loss 0.4800184667110443
4165 0.525077760219574
4166
