In [1]:
import os

# Work from the parent of the notebook's directory (presumably the project
# root, so that the `model`, `utils` and `data` packages used further down can
# be imported -- confirm).
# The starting directory is remembered on first execution so that re-running
# this cell always lands in the same place instead of climbing one level
# higher on every run.
if 'notebook_dir' not in globals():
    notebook_dir = os.getcwd()
os.chdir(os.path.join(notebook_dir, '../'))

import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# IPython shell escape: print the GPUs on this machine and their current load.
!nvidia-smi
# Expose only GPU index 0 to this process via the CUDA environment variable.
# NOTE(review): this assignment runs after `import torch`; it presumably only
# takes effect if CUDA has not been initialised yet -- confirm, or move it
# above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Mon Aug  7 22:43:52 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   56C    P0    79W / 300W |  16890MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   31C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparams

In [3]:
# Sizes shared by the dataset and the model:
#   n_mels    -- passed to LipsDataset and used as the model's `in_dim`.
#                NOTE(review): despite the name, 768 matches the
#                `encoder_embed_dim` of the HuBERT encoder logged further down,
#                so this is presumably the feature width, not mel bins -- confirm.
#   n_outputs -- used as the model's `out_dim`
#                (presumably the number of blendshape channels -- confirm).
#   n_frames  -- passed to LipsDataset
#                (presumably the number of frames per training example -- confirm).
n_mels, n_outputs, n_frames = 768, 61, 400

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

# Device on which the model (and later the HuBERT features) live.
device = 'cuda:0'
# Global training-step counter; advanced by the training loop and used as the
# checkpoint / TensorBoard index.
step = 0

# Model: built from the hyperparameters above and moved to `device` in one go.
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs).to(device)
# Only `model`'s parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
print('done')

done


### Load

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-2/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 517332
-rw-rw-r-- 1 scpark scpark        40  8월  7 22:28 events.out.tfevents.1691414835.GPUSVR11
-rw-rw-r-- 1 scpark scpark 529735487  8월  7 22:25 save_0
-rw-rw-r-- 1 scpark scpark        86  8월  7 22:25 events.out.tfevents.1691413273.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

# Every entry of `root_dir` is treated as one preprocessed recording.
root_dir = '/data/speech/digital_human/preprocessed/jinwooOh'
files = sorted(os.path.join(root_dir, name) for name in os.listdir(root_dir))
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    # Recordings whose *file name* contains '_1_' are held out for testing.
    # Only the basename is inspected, so a '_1_' somewhere in the directory
    # path cannot silently send every recording to the test split.
    if '_1_' in os.path.basename(file):
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

10
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/jinwooOh/MH_ARKit_005_9_iPhone_raw.npy
9 1


In [7]:
# Training loader: all training recordings combined, shuffled, 32 items per batch.
train_set = CombinedDataset(train_datasets)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    num_workers=16,
    collate_fn=CombinedCollate(),
)

# Test loader: same construction over the held-out recordings, 10 items per batch.
# NOTE(review): the test loader is shuffled too, so the batches seen by each
# evaluation pass differ from one pass to the next -- confirm this is intended.
test_set = CombinedDataset(test_datasets)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=10,
    shuffle=True,
    num_workers=10,
    collate_fn=CombinedCollate(),
)
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

# Resampler from 24000 Hz to 16000 Hz, applied to waveforms before feature
# extraction (the task config logged by fairseq reports a 16000 Hz sample rate).
resample = Resample(orig_freq=24000, new_freq=16000)

# Load the pretrained checkpoint through fairseq; the loader returns a list of
# models plus the config and task objects. Only the first model is used.
ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = ensemble[0].to(device)
# Switch the feature extractor to evaluation mode.
hubert.eval()

def get_hubert_feature(wav):
    """Extract layer-12 HuBERT features for a batch of waveforms, without gradients.

    The batch is converted to a float tensor, passed through the notebook-level
    `resample` transform, moved to `device` and fed to
    `hubert.extract_features(..., output_layer=12)`. All of this runs under
    `torch.no_grad()`.

    Args:
        wav: batch of waveforms accepted by `torch.Tensor(...)`.

    Returns:
        The first element of the extractor's result with dims 1 and 2 swapped.
        Per the original author's note the extractor yields (b, t, c), which
        would make the returned tensor (b, c, t) -- TODO confirm.
    """
    with torch.no_grad():
        waveform = torch.Tensor(wav)
        waveform = resample(waveform).to(device)
        extracted = hubert.extract_features(waveform, output_layer=12)
        hidden = extracted[0]
        return hidden.transpose(1, 2)
print('done')

2023-08-07 22:44:07 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:44:07 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:44:07 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

# Training loop.
#
# Iterates over `train_loader` indefinitely (interrupt the kernel to stop) and
# leaves the loop on its own only when the summed training loss becomes NaN.
# Every `eval_interval` steps (including step 0) it:
#   1. logs the current training loss to TensorBoard,
#   2. clears the cell output,
#   3. evaluates on at most `max_test_batches` batches of `test_loader` and
#      logs the mean test loss,
#   4. writes a checkpoint with `save(...)`.
# The evaluation pass uses its own variable names so that it never overwrites
# the training step's `batch`, `inputs`, `targets`, `outputs` or `loss`.

eval_interval = 1000     # steps between logging / evaluation / checkpointing
max_test_batches = 10    # upper bound on test batches per evaluation pass


def prepare_batch(batch):
    """Turn one collated batch into `(inputs, targets)` tensors on `device`.

    `batch['blend']` becomes a float tensor with dims 1 and 2 swapped.
    `batch['wav']` goes through `get_hubert_feature` and is then linearly
    interpolated along its last dim to `targets.shape[2]`, so inputs and
    targets have the same length on that axis.
    """
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    inputs = get_hubert_feature(batch['wav'])
    inputs = F.interpolate(inputs, size=targets.shape[2], mode='linear')
    return inputs, targets


def sum_losses(outputs, verbose=False):
    """Add up every entry of `outputs` whose key contains the text 'loss'.

    With `verbose=True` each contributing key is printed with its value.
    Returns the integer 0 when no key matches.
    """
    total = 0
    for key in outputs.keys():
        if 'loss' in key:
            total += outputs[key]
            if verbose:
                print(key, outputs[key].item())
    return total


isnan = False
while not isnan:
    for batch in train_loader:
        inputs, targets = prepare_batch(batch)

        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        loss = sum_losses(outputs, verbose=True)
        if torch.isnan(loss):
            # Stop before back-propagating a NaN; no checkpoint is written here.
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % eval_interval == 0:
            writer.add_scalar('train_loss', loss.item(), step)
            display.clear_output()

            # --- evaluation on a bounded number of test batches ---
            model.eval()
            test_losses = []
            for i, test_batch in enumerate(test_loader):
                if i >= max_test_batches:
                    break

                test_inputs, test_targets = prepare_batch(test_batch)
                with torch.no_grad():
                    test_outputs = model(test_inputs, test_targets)

                batch_loss = sum_losses(test_outputs)
                print('test :', i, batch_loss.item())
                test_losses.append(batch_loss)

            test_loss = torch.stack(test_losses).mean().item()
            print('test_loss :', test_loss)
            writer.add_scalar('test_loss', test_loss, step)

            # Optional visual check (disabled): compare targets and predictions
            # of the last evaluated test batch.
#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(test_targets[0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(test_outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             for i in [20, 37]:
#                 plt.figure(figsize=[18, 2])
#                 plt.title(str(i))
#                 plt.plot(test_targets[0].data.cpu().numpy()[i])
#                 plt.plot(test_outputs['y_pred'][0].data.cpu().numpy()[i])
#                 plt.show()

            # --- checkpoint ---
            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 0.8649572730064392
test : 1 0.865056037902832
test : 2 0.8663179874420166
test : 3 0.8672628998756409
test : 4 0.8671029210090637
test : 5 0.866256594657898
test : 6 0.8668138980865479
test : 7 0.8649154305458069
test : 8 0.8678506016731262
test : 9 0.8670303225517273
test_loss : 0.866356372833252
saved /data/scpark/save/lips/train08.08-2/save_0
1
loss 0.8480110168457031
1 0.8480110168457031
2
loss 0.5695661306381226
2 0.5695661306381226
3
loss 0.34079110622406006
3 0.34079110622406006
4
loss 0.22818419337272644
4 0.22818419337272644
5
loss 0.1948510706424713
5 0.1948510706424713
6
loss 0.18334434926509857
6 0.18334434926509857
7
loss 0.16679030656814575
7 0.16679030656814575
8
loss 0.16373194754123688
8 0.16373194754123688
9
loss 0.1551969200372696
9 0.1551969200372696
10
loss 0.142522931098938
10 0.142522931098938
11
loss 0.12875749170780182
11 0.12875749170780182
12
loss 0.124574676156044
12 0.124574676156044
13
loss 0.12259794771671295
13 0.12259794771671295
14
loss 0.1201

154
loss 0.06317368894815445
154 0.06317368894815445
155
loss 0.06332185864448547
155 0.06332185864448547
156
loss 0.06294697523117065
156 0.06294697523117065
157
loss 0.06307604908943176
157 0.06307604908943176
158
loss 0.0609685443341732
158 0.0609685443341732
159
loss 0.06150931119918823
159 0.06150931119918823
160
loss 0.06312831491231918
160 0.06312831491231918
161
loss 0.06184443086385727
161 0.06184443086385727
162
loss 0.06107553094625473
162 0.06107553094625473
163
loss 0.061129260808229446
163 0.061129260808229446
164
loss 0.060775887221097946
164 0.060775887221097946
165
loss 0.06062214821577072
165 0.06062214821577072
166
loss 0.06114644929766655
166 0.06114644929766655
167
loss 0.06144142895936966
167 0.06144142895936966
168
loss 0.06105242669582367
168 0.06105242669582367
169
loss 0.061369989067316055
169 0.061369989067316055
170
loss 0.06205473467707634
170 0.06205473467707634
171
loss 0.06201155483722687
171 0.06201155483722687
172
loss 0.0618608221411705
172 0.06186082

307 0.045009031891822815
308
loss 0.04541768878698349
308 0.04541768878698349
309
loss 0.04585617408156395
309 0.04585617408156395
310
loss 0.045611973851919174
310 0.045611973851919174
311
loss 0.04486765339970589
311 0.04486765339970589
312
loss 0.044441189616918564
312 0.044441189616918564
313
loss 0.04422158747911453
313 0.04422158747911453
314
loss 0.04391438513994217
314 0.04391438513994217
315
loss 0.04360947757959366
315 0.04360947757959366
316
loss 0.04384268447756767
316 0.04384268447756767
317
loss 0.0442146360874176
317 0.0442146360874176
318
loss 0.04388372600078583
318 0.04388372600078583
319
loss 0.04365884140133858
319 0.04365884140133858
320
loss 0.042857300490140915
320 0.042857300490140915
321
loss 0.04357248172163963
321 0.04357248172163963
322
loss 0.04435290768742561
322 0.04435290768742561
323
loss 0.04380432516336441
323 0.04380432516336441
324
loss 0.04366764426231384
324 0.04366764426231384
325
loss 0.042889900505542755
325 0.042889900505542755
326
loss 0.0430

In [None]:
# Write a checkpoint for the current `step`, model and optimizer into
# `save_dir` on demand (for example after interrupting the training cell).
# The `None` argument mirrors the call made inside the training loop.
save(save_dir, step, model, None, optimizer)