In [1]:
import os
import warnings

# Work from the parent of the directory the kernel started in (presumably the
# project root, so that the project packages imported below resolve).
# A bare relative `os.chdir('../')` climbs one more level every time this cell
# is re-run; remembering the start directory keeps the cell idempotent.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# Show the GPUs on this machine with their current memory use and utilisation.
!nvidia-smi
# Restrict CUDA in this process to physical GPU 0.
# NOTE(review): this is set after `import torch`; it only takes effect if CUDA
# has not been initialised yet in this kernel - confirm, or move it above the import.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Mon Aug  7 22:39:46 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   40C    P0    66W / 300W |   1193MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparameters

In [3]:
# Sizes shared by the dataset and the model in the cells below.
#   n_mels    - passed to Model as `in_dim` and to LipsDataset; despite the
#               name it presumably is the feature width of the speech encoder
#               (the HuBERT config logged later reports encoder_embed_dim 768).
#   n_outputs - passed to Model as `out_dim`.
#   n_frames  - passed to LipsDataset (presumably frames per training example
#               - TODO confirm).
n_mels, n_outputs, n_frames = 768, 61, 400

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

# `step` is the global training-step counter (overwritten if a checkpoint is
# loaded); `device` is where the model and all batches are placed.
step, device = 0, 'cuda:0'

# Model
# Build the network straight onto the target device, then hand its parameters
# to AdamW.
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


### Checkpoint directory, TensorBoard writer, optional resume

In [5]:
# Directory that receives both the TensorBoard event files and the model
# checkpoints written by `save(...)` in the training cell.
save_dir = '/data/scpark/save/lips/train08.08-1/'
# Create the directory if it does not exist, then list it newest-first so
# existing checkpoints are visible before training starts.
!mkdir -p $save_dir
!ls -lt $save_dir

# TensorBoard writer; every instantiation adds a new events file to save_dir.
writer = SummaryWriter(save_dir)

# Resume switch: change `False` to `True` to restore the step counter, model
# and optimizer via `load` (presumably provided by `from utils.util import *`).
# NOTE(review): 133000 looks like the step number of the checkpoint to load -
# confirm against `load`'s signature.
if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 521844
-rw-rw-r-- 1 scpark scpark       131  8월  7 22:37 events.out.tfevents.1691415353.GPUSVR11
-rw-rw-r-- 1 scpark scpark 534355839  8월  7 22:37 save_0
-rw-rw-r-- 1 scpark scpark       131  8월  7 22:24 events.out.tfevents.1691413257.GPUSVR11
-rw-rw-r-- 1 scpark scpark       131  8월  7 21:48 events.out.tfevents.1691412297.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

# Every entry of root_dir becomes one LipsDataset (one recording per file).
root_dir = '/data/speech/digital_human/preprocessed/jeewonPark'
files = sorted([os.path.join(root_dir, file) for file in os.listdir(root_dir)])
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    # Recording number 1 (file name contains '_1_') is held out for testing.
    # Match on the file name only: testing the full path would send *every*
    # file to the test split if any directory in root_dir contained '_1_'.
    if '_1_' in os.path.basename(file):
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

10
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_11_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_12_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jeewonPark/MH_ARKit_003_9_iPhone_raw.npy
9 1


In [7]:
# Training batches: all training recordings combined, reshuffled every pass.
train_loader = torch.utils.data.DataLoader(
    CombinedDataset(train_datasets),
    batch_size=32,
    shuffle=True,
    num_workers=16,
    collate_fn=CombinedCollate(),
)
# Held-out batches; note that these are shuffled as well, so each evaluation
# sees the test data in a different order.
test_loader = torch.utils.data.DataLoader(
    CombinedDataset(test_datasets),
    batch_size=10,
    shuffle=True,
    num_workers=10,
    collate_fn=CombinedCollate(),
)
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

# Converts waveforms from 24 kHz to 16 kHz (the logged task config below
# reports sample_rate 16000 for the pretrained encoder).
resample = Resample(orig_freq=24000, new_freq=16000)

# Pretrained ContentVec checkpoint, loaded through fairseq. The loader returns
# a list of models together with the config and task; only the first model is
# used, moved to `device` and put in eval mode.
ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = models[0].to(device)
hubert.eval()

def get_hubert_feature(wav):
    with torch.no_grad():
        # (b, t, c)
        wav = resample(torch.Tensor(wav)).to(device)
        feature = hubert.extract_features(wav, output_layer=12)[0]
        return feature.transpose(1, 2)
print('done')

2023-08-07 22:39:58 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:39:58 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:39:58 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

def _prepare_batch(batch):
    """Build `(inputs, targets)` for the model from one collated batch.

    `batch['blend']` becomes the target tensor (dimensions 1 and 2 swapped,
    moved to `device`). `batch['wav']` is encoded with `get_hubert_feature`
    and linearly interpolated along its last axis so that its length matches
    the last axis of the targets.
    """
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    inputs = get_hubert_feature(batch['wav'])
    inputs = F.interpolate(inputs, size=(targets.shape[2]), mode='linear')
    return inputs, targets


def _sum_losses(outputs, verbose=False):
    """Add up every entry of `outputs` whose key contains 'loss'.

    With `verbose=True` each contributing term is printed as `key value`.
    Raises ValueError when no such entry exists, instead of silently
    returning the int 0 (which would break `torch.isnan` / `.item()` later).
    """
    loss_keys = [key for key in outputs.keys() if 'loss' in key]
    if not loss_keys:
        raise ValueError("model outputs contain no entry whose key includes 'loss'")
    total = 0
    for key in loss_keys:
        total = total + outputs[key]
        if verbose:
            print(key, outputs[key].item())
    return total


# Every `eval_interval` steps: log the train loss, evaluate on at most
# `max_test_batches` test batches, and write a checkpoint.
eval_interval = 1000
max_test_batches = 10

# Train until a NaN loss shows up (or the kernel is interrupted); each pass of
# the outer loop is one full pass over train_loader.
isnan = False
while not isnan:
    for batch in train_loader:
        inputs, targets = _prepare_batch(batch)

        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        loss = _sum_losses(outputs, verbose=True)
        if torch.isnan(loss):
            # Stop before backward/step so the NaN never reaches the weights.
            print('NaN loss at step', step, '- stopping training')
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % eval_interval == 0:
            writer.add_scalar('train_loss', loss.item(), step)

            display.clear_output()

            # The evaluation uses its own variable names so the training
            # loop's batch/inputs/targets/outputs/loss are not overwritten.
            model.eval()
            losses = []
            for i, test_batch in enumerate(test_loader):
                if i >= max_test_batches:
                    break

                test_inputs, test_targets = _prepare_batch(test_batch)
                with torch.no_grad():
                    test_outputs = model(test_inputs, test_targets)

                batch_loss = _sum_losses(test_outputs)
                print('test :', i, batch_loss.item())
                losses.append(batch_loss)

            if losses:
                test_loss = torch.stack(losses).mean().item()
                print('test_loss :', test_loss)
                writer.add_scalar('test_loss', test_loss, step)
            else:
                # torch.stack([]) would raise; make the empty case explicit.
                print('test_loader yielded no batches - skipping test_loss')

            # Optional debug plots of the last evaluated batch (disabled):
#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(test_targets[0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(test_outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             for i in [20, 37]:
#                 plt.figure(figsize=[18, 2])
#                 plt.title(str(i))
#                 plt.plot(test_targets[0].data.cpu().numpy()[i])
#                 plt.plot(test_outputs['y_pred'][0].data.cpu().numpy()[i])
#                 plt.show()

            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 0.9556631445884705
test : 1 0.9548057913780212
test : 2 0.9545495510101318
test : 3 0.9615374803543091
test : 4 0.9496999382972717
test : 5 0.9566548466682434
test : 6 0.9565381407737732
test : 7 0.9510100483894348
test : 8 0.9537798762321472
test : 9 0.957465648651123
test_loss : 0.9551704525947571
saved /data/scpark/save/lips/train08.08-1/save_0
1
loss 0.9358871579170227
1 0.9358871579170227
2
loss 0.6398470401763916
2 0.6398470401763916
3
loss 0.40842926502227783
3 0.40842926502227783
4
loss 0.2681022584438324
4 0.2681022584438324
5
loss 0.20145370066165924
5 0.20145370066165924
6
loss 0.189541757106781
6 0.189541757106781
7
loss 0.19125716388225555
7 0.19125716388225555
8
loss 0.18208229541778564
8 0.18208229541778564
9
loss 0.16850534081459045
9 0.16850534081459045
10
loss 0.1509912759065628
10 0.1509912759065628
11
loss 0.13464529812335968
11 0.13464529812335968
12
loss 0.12982089817523956
12 0.12982089817523956
13
loss 0.13143368065357208
13 0.13143368065357208
14
loss 

155
loss 0.06812132894992828
155 0.06812132894992828
156
loss 0.06839161366224289
156 0.06839161366224289
157
loss 0.0678860992193222
157 0.0678860992193222
158
loss 0.06785641610622406
158 0.06785641610622406
159
loss 0.06898124516010284
159 0.06898124516010284
160
loss 0.06871748715639114
160 0.06871748715639114
161
loss 0.06932631880044937
161 0.06932631880044937
162
loss 0.06927046179771423
162 0.06927046179771423
163
loss 0.06982911378145218
163 0.06982911378145218
164
loss 0.0689605325460434
164 0.0689605325460434
165
loss 0.06777665764093399
165 0.06777665764093399
166
loss 0.06749743223190308
166 0.06749743223190308
167
loss 0.06598026305437088
167 0.06598026305437088
168
loss 0.06572447717189789
168 0.06572447717189789
169
loss 0.0660962238907814
169 0.0660962238907814
170
loss 0.06807547062635422
170 0.06807547062635422
171
loss 0.06644003838300705
171 0.06644003838300705
172
loss 0.06669612973928452
172 0.06669612973928452
173
loss 0.06480158865451813
173 0.06480158865451813

308 0.0504605770111084
309
loss 0.050204914063215256
309 0.050204914063215256
310
loss 0.05138054117560387
310 0.05138054117560387
311
loss 0.05154411122202873
311 0.05154411122202873
312
loss 0.05221162736415863
312 0.05221162736415863
313
loss 0.050761058926582336
313 0.050761058926582336
314
loss 0.04982731491327286
314 0.04982731491327286
315
loss 0.04877203330397606
315 0.04877203330397606
316
loss 0.05098896846175194
316 0.05098896846175194
317
loss 0.050457682460546494
317 0.050457682460546494
318
loss 0.04940557852387428
318 0.04940557852387428
319
loss 0.05043346807360649
319 0.05043346807360649
320
loss 0.050945933908224106
320 0.050945933908224106
321
loss 0.05025672912597656
321 0.05025672912597656
322
loss 0.050596583634614944
322 0.050596583634614944
323
loss 0.0504305474460125
323 0.0504305474460125
324
loss 0.04889893904328346
324 0.04889893904328346
325
loss 0.05020694434642792
325 0.05020694434642792
326
loss 0.049928389489650726
326 0.049928389489650726
327
loss 0.04

462
loss 0.04065677151083946
462 0.04065677151083946
463
loss 0.03970777615904808
463 0.03970777615904808
464
loss 0.041865330189466476
464 0.041865330189466476
465
loss 0.040672581642866135
465 0.040672581642866135
466
loss 0.04086286202073097
466 0.04086286202073097
467
loss 0.039712093770504
467 0.039712093770504
468
loss 0.04097637161612511
468 0.04097637161612511
469
loss 0.03964198753237724
469 0.03964198753237724
470
loss 0.040387313812971115
470 0.040387313812971115
471
loss 0.04159562289714813
471 0.04159562289714813
472
loss 0.038523901253938675
472 0.038523901253938675
473
loss 0.03985801711678505
473 0.03985801711678505
474
loss 0.039321742951869965
474 0.039321742951869965
475
loss 0.03951907902956009
475 0.03951907902956009
476
loss 0.04035477340221405
476 0.04035477340221405
477
loss 0.039769046008586884
477 0.039769046008586884
478
loss 0.04041342809796333
478 0.04041342809796333
479
loss 0.038831863552331924
479 0.038831863552331924
480
loss 0.03950481116771698
480 0.0

615 0.03582330420613289
616
loss 0.03730864077806473
616 0.03730864077806473
617
loss 0.03518914058804512
617 0.03518914058804512
618
loss 0.03532441705465317
618 0.03532441705465317
619
loss 0.03575912117958069
619 0.03575912117958069
620
loss 0.03576808422803879
620 0.03576808422803879
621
loss 0.03482901677489281
621 0.03482901677489281
622
loss 0.03575493395328522
622 0.03575493395328522
623
loss 0.03451371565461159
623 0.03451371565461159
624
loss 0.034048277884721756
624 0.034048277884721756
625
loss 0.03457092121243477
625 0.03457092121243477
626
loss 0.03506951034069061
626 0.03506951034069061
627
loss 0.035520363599061966
627 0.035520363599061966
628
loss 0.033969562500715256
628 0.033969562500715256
629
loss 0.03558104485273361
629 0.03558104485273361
630
loss 0.033978138118982315
630 0.033978138118982315


In [None]:
# Manually write a checkpoint for the current `step`, e.g. after interrupting
# the training loop above between two periodic saves. The arguments match the
# periodic `save(...)` call in the training cell.
save(save_dir, step, model, None, optimizer)