In [1]:
import os
os.chdir('../')

import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="4"

Mon Aug  7 22:43:48 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   60C    P0   342W / 300W |  16890MiB / 80994MiB |     24%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   32C    P0    52W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparams

In [3]:
n_mels = 768
n_outputs = 61
n_frames = 400

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

step = 0
device = 'cuda:0'

# Model
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


### Load

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-3/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 4
-rw-rw-r-- 1 scpark scpark 86  8ì›”  7 22:25 events.out.tfevents.1691413283.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

root_dir = '/data/speech/digital_human/preprocessed/kyuchulLee'
files = sorted([os.path.join(root_dir, file) for file in os.listdir(root_dir)])
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    if '_1_' in file:
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

10
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuchulLee/MH_ARKit_006_9_iPhone_raw.npy
9 1


In [7]:
train_loader = torch.utils.data.DataLoader(CombinedDataset(train_datasets), 
                                           num_workers=16, shuffle=True, batch_size=32, collate_fn=CombinedCollate())
test_loader = torch.utils.data.DataLoader(CombinedDataset(test_datasets), 
                                          num_workers=10, shuffle=True, batch_size=10, collate_fn=CombinedCollate())
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

resample = Resample(24000, 16000)

ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
hubert, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = hubert[0]
hubert = hubert.to(device)
hubert.eval()

def get_hubert_feature(wav):
    with torch.no_grad():
        # (b, t, c)
        wav = resample(torch.Tensor(wav)).to(device)
        feature = hubert.extract_features(wav, output_layer=12)[0]
        return feature.transpose(1, 2)
print('done')

2023-08-07 22:44:03 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:44:03 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:44:03 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

isnan = False
while True:
    if isnan:
        break
    for batch in train_loader:
        targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
        inputs = get_hubert_feature(batch['wav'])
        inputs = F.interpolate(inputs, size=(targets.shape[2]), mode='linear')
        
        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)
        
        print(step)
        loss = 0
        for key in outputs.keys():
            if 'loss' in key:
                loss += outputs[key]
                print(key, outputs[key].item())
        if torch.isnan(loss):
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())
        
        if step % 1000 == 0:
            writer.add_scalar('train_loss', loss.item(), step)
        
        if step % 1000 == 0:
            display.clear_output()
            
            losses = []
            for i, batch in enumerate(test_loader):
                if i >= 10:
                    break
                    
                targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
                inputs = get_hubert_feature(batch['wav'])
                inputs = F.interpolate(inputs, size=(targets.shape[2]), mode='linear')

                model.eval()
                with torch.no_grad():
                    outputs = model(inputs, targets)
                    
                loss = 0
                for key in outputs.keys():
                    if 'loss' in key:
                        loss += outputs[key]
                print('test :', i, loss.item())
                losses.append(loss)        
            
            test_loss = torch.stack(losses).mean().item()
            print('test_loss :', test_loss)
            writer.add_scalar('test_loss', test_loss, step)
            
#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(targets[0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
#             plt.show()
            
#             for i in [20, 37]:
#                 plt.figure(figsize=[18, 2])
#                 plt.title(str(i))
#                 plt.plot(targets[0].data.cpu().numpy()[i])
#                 plt.plot(outputs['y_pred'][0].data.cpu().numpy()[i])
#                 plt.show()

        if step % 1000 == 0:
            save(save_dir, step, model, None, optimizer)
    
        step += 1

test : 0 0.896209180355072
test : 1 0.8976234197616577
test : 2 0.8956179618835449
test : 3 0.8974289894104004
test : 4 0.8961043357849121
test : 5 0.897742509841919
test : 6 0.8973168134689331
test : 7 0.8974056839942932
test : 8 0.8989085555076599
test : 9 0.8987313508987427
test_loss : 0.8973089456558228
saved /data/scpark/save/lips/train08.08-3/save_0
1
loss 0.8751948475837708
1 0.8751948475837708
2
loss 0.5989212393760681
2 0.5989212393760681
3
loss 0.40325459837913513
3 0.40325459837913513
4
loss 0.26713234186172485
4 0.26713234186172485
5
loss 0.18325191736221313
5 0.18325191736221313
6
loss 0.1789175420999527
6 0.1789175420999527
7
loss 0.19478484988212585
7 0.19478484988212585
8
loss 0.18748913705348969
8 0.18748913705348969
9
loss 0.16869178414344788
9 0.16869178414344788
10
loss 0.14829805493354797
10 0.14829805493354797
11
loss 0.13345582783222198
11 0.13345582783222198
12
loss 0.13088661432266235
12 0.13088661432266235
13
loss 0.13414417207241058
13 0.13414417207241058
14


154
loss 0.07002505660057068
154 0.07002505660057068
155
loss 0.06991288810968399
155 0.06991288810968399
156
loss 0.06920268386602402
156 0.06920268386602402
157
loss 0.06944148987531662
157 0.06944148987531662
158
loss 0.06946910172700882
158 0.06946910172700882
159
loss 0.06994036585092545
159 0.06994036585092545
160
loss 0.07092101871967316
160 0.07092101871967316
161
loss 0.0690484493970871
161 0.0690484493970871
162
loss 0.06984995305538177
162 0.06984995305538177
163
loss 0.06946182250976562
163 0.06946182250976562
164
loss 0.07003691792488098
164 0.07003691792488098
165
loss 0.06926370412111282
165 0.06926370412111282
166
loss 0.06880224496126175
166 0.06880224496126175
167
loss 0.06725077331066132
167 0.06725077331066132
168
loss 0.0676376074552536
168 0.0676376074552536
169
loss 0.06717788428068161
169 0.06717788428068161
170
loss 0.0678446963429451
170 0.0678446963429451
171
loss 0.06702306866645813
171 0.06702306866645813
172
loss 0.0674283355474472
172 0.0674283355474472
1

308
loss 0.05353950709104538
308 0.05353950709104538
309
loss 0.05421829596161842
309 0.05421829596161842
310
loss 0.051462385803461075
310 0.051462385803461075
311
loss 0.052693162113428116
311 0.052693162113428116
312
loss 0.05152690038084984
312 0.05152690038084984
313
loss 0.051605794578790665
313 0.051605794578790665
314
loss 0.051923058927059174
314 0.051923058927059174
315
loss 0.05171351879835129
315 0.05171351879835129
316
loss 0.053044434636831284
316 0.053044434636831284
317
loss 0.0519656278192997
317 0.0519656278192997
318
loss 0.050730034708976746
318 0.050730034708976746
319
loss 0.05126674473285675
319 0.05126674473285675
320
loss 0.05030999705195427
320 0.05030999705195427
321
loss 0.05055698752403259
321 0.05055698752403259
322
loss 0.05012890323996544
322 0.05012890323996544
323
loss 0.0510915070772171
323 0.0510915070772171
324
loss 0.05245048552751541
324 0.05245048552751541
325
loss 0.05109250918030739
325 0.05109250918030739
326
loss 0.05160282179713249
326 0.051

In [None]:
save(save_dir, step, model, None, optimizer)