In [1]:
import os
os.chdir('../')

import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="6"

Mon Aug  7 22:46:15 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   62C    P0   301W / 300W |  32593MiB / 80994MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   31C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparams

In [3]:
n_mels = 768
n_outputs = 61
n_frames = 400

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

step = 0
device = 'cuda:0'

# Model
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


### Load

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-7/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 1043672
-rw-rw-r-- 1 scpark scpark       131  8월  7 22:45 events.out.tfevents.1691415822.GPUSVR11
-rw-rw-r-- 1 scpark scpark 534355839  8월  7 22:44 save_5
-rw-rw-r-- 1 scpark scpark 534355839  8월  7 22:44 save_0
-rw-rw-r-- 1 scpark scpark         0  8월  7 22:02 events.out.tfevents.1691413338.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

root_dir = '/data/speech/digital_human/preprocessed/yehunHwang'
files = sorted([os.path.join(root_dir, file) for file in os.listdir(root_dir) if 'ARKit' in file])
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    if '_1_' in file:
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

17
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_20_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_21_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_30_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_31_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_40_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_50_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_51_iPhone_raw.npy
/data/speech/digital_human/preprocessed/yehunHwang/MH_ARKit_001_52_iPhone_raw.npy
/data/speech/digi

In [7]:
train_loader = torch.utils.data.DataLoader(CombinedDataset(train_datasets), 
                                           num_workers=16, shuffle=True, batch_size=32, collate_fn=CombinedCollate())
test_loader = torch.utils.data.DataLoader(CombinedDataset(test_datasets), 
                                          num_workers=10, shuffle=True, batch_size=10, collate_fn=CombinedCollate())
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

resample = Resample(24000, 16000)

ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
hubert, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = hubert[0]
hubert = hubert.to(device)
hubert.eval()

def get_hubert_feature(wav):
    with torch.no_grad():
        # (b, t, c)
        wav = resample(torch.Tensor(wav)).to(device)
        feature = hubert.extract_features(wav, output_layer=12)[0]
        return feature.transpose(1, 2)
print('done')

2023-08-07 22:46:31 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:46:31 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:46:31 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

isnan = False
while True:
    if isnan:
        break
    for batch in train_loader:
        targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
        inputs = get_hubert_feature(batch['wav'])
        inputs = F.interpolate(inputs, size=(targets.shape[2]), mode='linear')
        
        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)
        
        print(step)
        loss = 0
        for key in outputs.keys():
            if 'loss' in key:
                loss += outputs[key]
                print(key, outputs[key].item())
        if torch.isnan(loss):
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())
        
        if step % 1000 == 0:
            writer.add_scalar('train_loss', loss.item(), step)
        
        if step % 1000 == 0:
            display.clear_output()
            
            losses = []
            for i, batch in enumerate(test_loader):
                if i >= 10:
                    break
                    
                targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
                inputs = get_hubert_feature(batch['wav'])
                inputs = F.interpolate(inputs, size=(targets.shape[2]), mode='linear')

                model.eval()
                with torch.no_grad():
                    outputs = model(inputs, targets)
                    
                loss = 0
                for key in outputs.keys():
                    if 'loss' in key:
                        loss += outputs[key]
                print('test :', i, loss.item())
                losses.append(loss)        
            
            test_loss = torch.stack(losses).mean().item()
            print('test_loss :', test_loss)
            writer.add_scalar('test_loss', test_loss, step)
            
#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(targets[0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
#             plt.show()
            
#             for i in [20, 37]:
#                 plt.figure(figsize=[18, 2])
#                 plt.title(str(i))
#                 plt.plot(targets[0].data.cpu().numpy()[i])
#                 plt.plot(outputs['y_pred'][0].data.cpu().numpy()[i])
#                 plt.show()

        if step % 1000 == 0:
            save(save_dir, step, model, None, optimizer)
    
        step += 1

test : 0 0.8581734895706177
test : 1 0.855551540851593
test : 2 0.8561782836914062
test : 3 0.8568066954612732
test : 4 0.8555034399032593
test : 5 0.8561365604400635
test : 6 0.856967031955719
test : 7 0.8577532172203064
test : 8 0.856499969959259
test : 9 0.856006383895874
test_loss : 0.8565576672554016
saved /data/scpark/save/lips/train08.08-7/save_0
1
loss 0.8370347023010254
1 0.8370347023010254
2
loss 0.5336255431175232
2 0.5336255431175232
3
loss 0.37246325612068176
3 0.37246325612068176
4
loss 0.2531784176826477
4 0.2531784176826477
5
loss 0.18763545155525208
5 0.18763545155525208
6
loss 0.1846725046634674
6 0.1846725046634674
7
loss 0.19174574315547943
7 0.19174574315547943
8
loss 0.1748339831829071
8 0.1748339831829071
9
loss 0.1574152112007141
9 0.1574152112007141
10
loss 0.1478520631790161
10 0.1478520631790161
11
loss 0.14165367186069489
11 0.14165367186069489
12
loss 0.13705889880657196
12 0.13705889880657196
13
loss 0.13575588166713715
13 0.13575588166713715
14
loss 0.134

155
loss 0.08308904618024826
155 0.08308904618024826
156
loss 0.08173616230487823
156 0.08173616230487823
157
loss 0.08922941982746124
157 0.08922941982746124
158
loss 0.07504663616418839
158 0.07504663616418839
159
loss 0.07734651118516922
159 0.07734651118516922
160
loss 0.08114420622587204
160 0.08114420622587204
161
loss 0.0781087577342987
161 0.0781087577342987
162
loss 0.07478395849466324
162 0.07478395849466324
163
loss 0.07423744350671768
163 0.07423744350671768
164
loss 0.0788407027721405
164 0.0788407027721405
165
loss 0.08124508708715439
165 0.08124508708715439
166
loss 0.07724631577730179
166 0.07724631577730179
167
loss 0.0765928402543068
167 0.0765928402543068
168
loss 0.07926450669765472
168 0.07926450669765472
169
loss 0.08260143548250198
169 0.08260143548250198
170
loss 0.07869573682546616
170 0.07869573682546616
171
loss 0.07328296452760696
171 0.07328296452760696
172
loss 0.08077511936426163
172 0.08077511936426163
173
loss 0.07781713455915451
173 0.07781713455915451

311
loss 0.06406255811452866
311 0.06406255811452866
312
loss 0.06568987667560577
312 0.06568987667560577
313
loss 0.06374513357877731
313 0.06374513357877731
314
loss 0.06344214081764221
314 0.06344214081764221
315
loss 0.05657150596380234
315 0.05657150596380234
316
loss 0.06053803116083145
316 0.06053803116083145
317
loss 0.0657505989074707
317 0.0657505989074707
318
loss 0.06096088886260986
318 0.06096088886260986
319
loss 0.06359031051397324
319 0.06359031051397324
320
loss 0.06592909246683121
320 0.06592909246683121
321
loss 0.06396003067493439
321 0.06396003067493439
322
loss 0.05330086499452591
322 0.05330086499452591
323
loss 0.0592854768037796
323 0.0592854768037796
324
loss 0.06062263995409012
324 0.06062263995409012
325
loss 0.06786029040813446
325 0.06786029040813446
326
loss 0.061075419187545776
326 0.061075419187545776
327
loss 0.06012643873691559
327 0.06012643873691559
328
loss 0.06304168701171875
328 0.06304168701171875
329
loss 0.06125587224960327
329 0.0612558722496

In [None]:
save(save_dir, step, model, None, optimizer)