In [1]:
import os
import warnings

# Move to the parent directory exactly once per kernel. A bare
# os.chdir('../') climbs one more level every time this cell is re-run,
# after which the project-relative imports in later cells (presumably the
# reason for the chdir) no longer resolve.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('../')
    os.chdir(PROJECT_ROOT)

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="5"

Mon Aug  7 22:43:38 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   55C    P0    78W / 300W |  16890MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   31C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparams

In [3]:
# Sizes shared by the model and dataset cells.
n_frames = 400    # handed to LipsDataset together with n_mels
n_outputs = 61    # used as the model's `out_dim`
# Used as the model's `in_dim`. NOTE(review): despite the name this looks like
# the 768-dim HuBERT/ContentVec feature size rather than a mel-bin count - confirm.
n_mels = 768

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

# Training state: target device and the global step counter.
device = 'cuda:0'
step = 0

# Build the network directly on the target device and attach an AdamW
# optimizer over all of its parameters.
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs).to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
print('done')

done


### Load

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-6/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 0
-rw-rw-r-- 1 scpark scpark 0  8월  7 22:02 events.out.tfevents.1691413331.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

# Every entry of root_dir becomes one LipsDataset.
root_dir = '/data/speech/digital_human/preprocessed/soochulPark'
files = sorted(os.path.join(root_dir, name) for name in os.listdir(root_dir))
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    # Hold out the recording whose *file name* contains '_1_'. Matching on
    # the basename keeps a '_1_' in a parent directory name from sending
    # every file to the test split.
    if '_1_' in os.path.basename(file):
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

10
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/soochulPark/MH_ARKit_004_9_iPhone_raw.npy
9 1


In [7]:
# Training batches: 32 items, reshuffled every pass over the data.
train_loader = torch.utils.data.DataLoader(
    CombinedDataset(train_datasets),
    num_workers=16,
    shuffle=True,
    batch_size=32,
    collate_fn=CombinedCollate(),
)

# Held-out batches: 10 items each.
# NOTE(review): shuffle=True means the batches seen by a partial pass over
# this loader differ from one evaluation to the next - confirm this is intended.
test_loader = torch.utils.data.DataLoader(
    CombinedDataset(test_datasets),
    num_workers=10,
    shuffle=True,
    batch_size=10,
    collate_fn=CombinedCollate(),
)
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

# Resampler from 24 kHz to 16 kHz (presumably the dataset audio rate and the
# rate the pretrained encoder expects - TODO confirm).
resample = Resample(24000, 16000)

# Load the pretrained checkpoint; the loader returns a list of models plus the
# config and task, and only the first model is used.
ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
hubert_ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = hubert_ensemble[0].to(device)
# Inference only: switch the encoder to eval mode.
hubert.eval()

@torch.no_grad()
def get_hubert_feature(wav):
    """Extract layer-12 encoder features for a batch of waveforms.

    `wav` is converted to a float tensor, passed through the module-level
    `resample`, moved to `device` and fed to `hubert.extract_features`.
    The first element of that result (noted as (b, t, c) in the original
    code) is returned with dims 1 and 2 swapped. No gradients are tracked.
    """
    resampled = resample(torch.Tensor(wav))
    on_device = resampled.to(device)
    extracted = hubert.extract_features(on_device, output_layer=12)
    # (b, t, c) -> swap the last two axes for the caller.
    return extracted[0].transpose(1, 2)

print('done')

2023-08-07 22:43:53 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:43:53 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:43:53 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

EVAL_INTERVAL = 1000     # steps between train-loss logging, evaluation and checkpointing
MAX_EVAL_BATCHES = 10    # upper bound on test batches consumed per evaluation


def _prepare_batch(batch):
    """Build `(inputs, targets)` on `device` from one collated batch.

    `batch['blend']` is transposed on dims 1/2 to form the targets; the
    features returned by `get_hubert_feature(batch['wav'])` are linearly
    interpolated along their last axis to `targets.shape[2]` so that both
    sequences have the same length.
    """
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    inputs = get_hubert_feature(batch['wav'])
    inputs = F.interpolate(inputs, size=targets.shape[2], mode='linear')
    return inputs, targets


def _loss_terms(outputs):
    """Return the entries of `outputs` whose key contains 'loss'.

    Raises:
        ValueError: if no such key exists, instead of letting an integer 0
            reach `torch.isnan` / `.backward()` further down.
    """
    terms = {key: outputs[key] for key in outputs.keys() if 'loss' in key}
    if not terms:
        raise ValueError("model output contains no key with 'loss' in it")
    return terms


def _evaluate():
    """Print per-batch test losses and return their mean.

    At most MAX_EVAL_BATCHES batches of `test_loader` are used. Returns
    None when the loader yields no batch at all.
    """
    model.eval()
    batch_losses = []
    for i, test_batch in enumerate(test_loader):
        if i >= MAX_EVAL_BATCHES:
            break
        test_inputs, test_targets = _prepare_batch(test_batch)
        with torch.no_grad():
            test_outputs = model(test_inputs, test_targets)
        batch_loss = sum(_loss_terms(test_outputs).values())
        print('test :', i, batch_loss.item())
        batch_losses.append(batch_loss)
    if not batch_losses:
        return None
    return torch.stack(batch_losses).mean().item()


# Train until a NaN loss shows up (or the cell is interrupted by hand).
isnan = False
while not isnan:
    for batch in train_loader:
        inputs, targets = _prepare_batch(batch)

        # _evaluate() leaves the model in eval mode, so re-arm train mode
        # on every step.
        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        loss = 0
        for key, value in _loss_terms(outputs).items():
            loss = loss + value
            print(key, value.item())
        if torch.isnan(loss):
            # Stop before the NaN reaches the weights; nothing is saved here.
            print('NaN loss at step', step, '- stopping')
            isnan = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % EVAL_INTERVAL == 0:
            writer.add_scalar('train_loss', loss.item(), step)

            # Drop the accumulated per-step prints before the evaluation report.
            display.clear_output()
            test_loss = _evaluate()
            if test_loss is None:
                print('test_loss : skipped (test_loader yielded no batches)')
            else:
                print('test_loss :', test_loss)
                writer.add_scalar('test_loss', test_loss, step)

            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 0.8768720030784607
test : 1 0.8770356774330139
test : 2 0.8755492568016052
test : 3 0.8760419487953186
test : 4 0.8720812797546387
test : 5 0.8741520643234253
test : 6 0.8762869238853455
test : 7 0.8726571202278137
test : 8 0.8758551478385925
test : 9 0.8771385550498962
test_loss : 0.8753669857978821
saved /data/scpark/save/lips/train08.08-6/save_0
1
loss 0.8561713695526123
1 0.8561713695526123
2
loss 0.5445883870124817
2 0.5445883870124817
3
loss 0.35501930117607117
3 0.35501930117607117
4
loss 0.23418588936328888
4 0.23418588936328888
5
loss 0.1981951743364334
5 0.1981951743364334
6
loss 0.18681427836418152
6 0.18681427836418152
7
loss 0.1761457771062851
7 0.1761457771062851
8
loss 0.16852396726608276
8 0.16852396726608276
9
loss 0.15741026401519775
9 0.15741026401519775
10
loss 0.14633271098136902
10 0.14633271098136902
11
loss 0.1347547322511673
11 0.1347547322511673
12
loss 0.1309252381324768
12 0.1309252381324768
13
loss 0.12992645800113678
13 0.12992645800113678
14
loss

154
loss 0.07062923163175583
154 0.07062923163175583
155
loss 0.07112514972686768
155 0.07112514972686768
156
loss 0.07156626135110855
156 0.07156626135110855
157
loss 0.07094894349575043
157 0.07094894349575043
158
loss 0.06992191821336746
158 0.06992191821336746
159
loss 0.07124248892068863
159 0.07124248892068863
160
loss 0.07129284739494324
160 0.07129284739494324
161
loss 0.07041992992162704
161 0.07041992992162704
162
loss 0.07142740488052368
162 0.07142740488052368
163
loss 0.07069578766822815
163 0.07069578766822815
164
loss 0.0704142302274704
164 0.0704142302274704
165
loss 0.06988407671451569
165 0.06988407671451569
166
loss 0.06943239271640778
166 0.06943239271640778
167
loss 0.06988085061311722
167 0.06988085061311722
168
loss 0.07024980336427689
168 0.07024980336427689
169
loss 0.07150250673294067
169 0.07150250673294067
170
loss 0.07164429873228073
170 0.07164429873228073
171
loss 0.06833618879318237
171 0.06833618879318237
172
loss 0.06892484426498413
172 0.0689248442649

309
loss 0.05542251840233803
309 0.05542251840233803
310
loss 0.05611395090818405
310 0.05611395090818405
311
loss 0.056653670966625214
311 0.056653670966625214
312
loss 0.05620202794671059
312 0.05620202794671059
313
loss 0.05676119402050972
313 0.05676119402050972
314
loss 0.05539197474718094
314 0.05539197474718094
315
loss 0.05480325594544411
315 0.05480325594544411
316
loss 0.054009780287742615
316 0.054009780287742615
317
loss 0.0548454150557518
317 0.0548454150557518
318
loss 0.05444837361574173
318 0.05444837361574173
319
loss 0.05495728552341461
319 0.05495728552341461
320
loss 0.05651646479964256
320 0.05651646479964256
321
loss 0.05478201434016228
321 0.05478201434016228
322
loss 0.055815089493989944
322 0.055815089493989944
323
loss 0.05590566247701645
323 0.05590566247701645
324
loss 0.05613316223025322
324 0.05613316223025322
325
loss 0.05433952063322067
325 0.05433952063322067
326
loss 0.054495058953762054
326 0.054495058953762054
327
loss 0.05350011587142944
327 0.05350

In [None]:
# On-demand checkpoint: call the project `save` helper with the current
# `step`, model and optimizer (e.g. after interrupting the training cell).
save(save_dir, step, model, None, optimizer)