In [1]:
import os
import warnings

# Work from the parent directory of the notebook (presumably the project root,
# so that the `model`, `utils` and `data` packages imported below resolve).
# The target is resolved only the first time this cell runs: a bare
# `os.chdir('../')` would climb one more level on every re-run of the cell.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="4"

Mon Aug  7 22:43:45 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   55C    P0    79W / 300W |  16890MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   31C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparams

In [3]:
# Sizes shared by the dataset and the model cells below:
#   n_mels    - input feature dimension (used as Model(in_dim=...) and passed to LipsDataset)
#   n_outputs - number of regression outputs (used as Model(out_dim=...))
#   n_frames  - frame count passed to LipsDataset
n_mels, n_outputs, n_frames = 768, 61, 400

### Model

In [4]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

# 'cuda:0' is the first GPU visible to this process.
device = 'cuda:0'

# Global training-step counter; advanced by the training cell and used when
# logging and checkpointing.
step = 0

# Regression model placed on the GPU, plus its AdamW optimizer.
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

print('done')

done


### Load

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-4/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 4
-rw-rw-r-- 1 scpark scpark 86  8월  7 22:25 events.out.tfevents.1691413299.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

# Every entry of `root_dir` becomes one LipsDataset, visited in sorted path order.
root_dir = '/data/speech/digital_human/preprocessed/kyuseokKim'
files = [os.path.join(root_dir, name) for name in os.listdir(root_dir)]
files.sort()
print(len(files))

train_datasets, test_datasets = [], []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    # Paths containing '_1_' form the held-out set; everything else is training data.
    split = test_datasets if '_1_' in file else train_datasets
    split.append(dataset)

print(len(train_datasets), len(test_datasets))

10
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/kyuseokKim/MH_ARKit_002_9_iPhone_raw.npy
9 1


In [7]:
# Training batches: all training datasets combined, shuffled, 32 items per batch.
train_loader = torch.utils.data.DataLoader(
    CombinedDataset(train_datasets),
    batch_size=32,
    shuffle=True,
    num_workers=16,
    collate_fn=CombinedCollate(),
)

# Held-out batches: 10 items per batch.
# NOTE(review): this loader is shuffled too, and the training cell only reads
# its first 10 batches per evaluation, so each evaluation may see different
# data - confirm this is intended.
test_loader = torch.utils.data.DataLoader(
    CombinedDataset(test_datasets),
    batch_size=10,
    shuffle=True,
    num_workers=10,
    collate_fn=CombinedCollate(),
)

print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

# Converts 24 kHz audio to 16 kHz before it is fed to the feature extractor.
resample = Resample(24000, 16000)

# Load the pretrained ContentVec checkpoint through fairseq. The loader returns
# (list of models, config, task); the single loaded model is the extractor.
ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = ensemble[0].to(device)
# Inference only: switch the extractor to eval mode.
hubert.eval()

@torch.no_grad()
def get_hubert_feature(wav):
    """Extract layer-12 features from `hubert` for a batch of waveforms.

    The waveform batch is converted to a float tensor, passed through the
    module-level `resample` transform, moved to `device` and fed to
    `hubert.extract_features(..., output_layer=12)`. No gradients are tracked.

    Args:
        wav: batch of waveforms, anything `torch.Tensor(...)` accepts.

    Returns:
        The extracted features with dims 1 and 2 swapped. The original
        author's note gives the extractor output as (b, t, c), so the result
        is presumably (b, c, t).
    """
    waveform = torch.Tensor(wav)
    waveform = resample(waveform).to(device)
    # extract_features returns a tuple; its first element is the feature tensor.
    extracted = hubert.extract_features(waveform, output_layer=12)
    # (b, t, c)
    feature = extracted[0]
    return feature.transpose(1, 2)
print('done')

2023-08-07 22:44:00 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:44:00 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:44:00 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

# Training loop.
# Repeats passes over `train_loader` until the summed loss stops being finite
# (or the kernel is interrupted). Every EVAL_INTERVAL steps the cell logs the
# training loss, evaluates on up to MAX_EVAL_BATCHES test batches and saves a
# checkpoint.
EVAL_INTERVAL = 1000
MAX_EVAL_BATCHES = 10


def _prepare_batch(batch):
    """Turn one collated batch into `(inputs, targets)` tensors on `device`.

    `batch['blend']` is transposed on dims 1 and 2 to give the targets.
    `batch['wav']` goes through `get_hubert_feature` and is then linearly
    interpolated along its last axis to the length of the targets' last axis.
    """
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    inputs = get_hubert_feature(batch['wav'])
    inputs = F.interpolate(inputs, size=targets.shape[2], mode='linear')
    return inputs, targets


def _total_loss(outputs, verbose=False):
    """Sum every entry of `outputs` whose key contains 'loss'.

    Args:
        outputs: mapping returned by the model.
        verbose: when True, print each contributing key with its value.

    Returns:
        The summed loss tensor.

    Raises:
        KeyError: if no key of `outputs` contains 'loss'. Without this check
            the sum would stay the plain int 0 and fail later with an
            unrelated-looking error.
    """
    total = 0
    found = False
    for key in outputs.keys():
        if 'loss' in key:
            total += outputs[key]
            found = True
            if verbose:
                print(key, outputs[key].item())
    if not found:
        raise KeyError("model outputs contain no key with 'loss' in its name")
    return total


diverged = False
while not diverged:
    for batch in train_loader:
        inputs, targets = _prepare_batch(batch)

        # model.train() is needed on every step because evaluation below
        # switches the model to eval mode.
        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        loss = _total_loss(outputs, verbose=True)
        # Stop before the update on NaN *or* inf so the weights are not
        # corrupted by a non-finite gradient.
        if not torch.isfinite(loss):
            print('non-finite loss at step', step, '- stopping')
            diverged = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % EVAL_INTERVAL == 0:
            writer.add_scalar('train_loss', loss.item(), step)

            # Drop the accumulated per-step prints before the evaluation report.
            display.clear_output()

            model.eval()
            losses = []
            for i, test_batch in enumerate(test_loader):
                if i >= MAX_EVAL_BATCHES:
                    break

                inputs, targets = _prepare_batch(test_batch)
                with torch.no_grad():
                    outputs = model(inputs, targets)

                batch_loss = _total_loss(outputs)
                print('test :', i, batch_loss.item())
                losses.append(batch_loss)

            # An empty test loader would make torch.stack raise; skip logging instead.
            if losses:
                test_loss = torch.stack(losses).mean().item()
                print('test_loss :', test_loss)
                writer.add_scalar('test_loss', test_loss, step)

            # Optional visual check of the last evaluated batch (disabled):
#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(targets[0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             plt.figure(figsize=[18, 4])
#             librosa.display.specshow(outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
#             plt.show()

#             for i in [20, 37]:
#                 plt.figure(figsize=[18, 2])
#                 plt.title(str(i))
#                 plt.plot(targets[0].data.cpu().numpy()[i])
#                 plt.plot(outputs['y_pred'][0].data.cpu().numpy()[i])
#                 plt.show()

            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 0.7893012166023254
test : 1 0.7913994789123535
test : 2 0.791231632232666
test : 3 0.7888653874397278
test : 4 0.7907230257987976
test : 5 0.790205180644989
test : 6 0.7908065915107727
test : 7 0.7891135811805725
test : 8 0.7904638648033142
test : 9 0.7911481261253357
test_loss : 0.7903258204460144
saved /data/scpark/save/lips/train08.08-4/save_0
1
loss 0.7805330157279968
1 0.7805330157279968
2
loss 0.5398089289665222
2 0.5398089289665222
3
loss 0.366061806678772
3 0.366061806678772
4
loss 0.23858694732189178
4 0.23858694732189178
5
loss 0.1873086839914322
5 0.1873086839914322
6
loss 0.17176775634288788
6 0.17176775634288788
7
loss 0.17642976343631744
7 0.17642976343631744
8
loss 0.16957271099090576
8 0.16957271099090576
9
loss 0.15811890363693237
9 0.15811890363693237
10
loss 0.14759767055511475
10 0.14759767055511475
11
loss 0.13748449087142944
11 0.13748449087142944
12
loss 0.12771090865135193
12 0.12771090865135193
13
loss 0.11926567554473877
13 0.11926567554473877
14
loss

154 0.0640484094619751
155
loss 0.06365308910608292
155 0.06365308910608292
156
loss 0.0649920329451561
156 0.0649920329451561
157
loss 0.06607849895954132
157 0.06607849895954132
158
loss 0.06421656161546707
158 0.06421656161546707
159
loss 0.06442838907241821
159 0.06442838907241821
160
loss 0.06435778737068176
160 0.06435778737068176
161
loss 0.06441661715507507
161 0.06441661715507507
162
loss 0.06379473954439163
162 0.06379473954439163
163
loss 0.0645180270075798
163 0.0645180270075798
164
loss 0.06217589229345322
164 0.06217589229345322
165
loss 0.06194538623094559
165 0.06194538623094559
166
loss 0.06323637068271637
166 0.06323637068271637
167
loss 0.06217409670352936
167 0.06217409670352936
168
loss 0.06109097972512245
168 0.06109097972512245
169
loss 0.06321562081575394
169 0.06321562081575394
170
loss 0.061188504099845886
170 0.061188504099845886
171
loss 0.061927150934934616
171 0.061927150934934616
172
loss 0.06363283097743988
172 0.06363283097743988
173
loss 0.061140753328

308
loss 0.04957477003335953
308 0.04957477003335953
309
loss 0.047177545726299286
309 0.047177545726299286
310
loss 0.04705219343304634
310 0.04705219343304634
311
loss 0.04669496789574623
311 0.04669496789574623
312
loss 0.04598230496048927
312 0.04598230496048927
313
loss 0.04826594144105911
313 0.04826594144105911
314
loss 0.04784676805138588
314 0.04784676805138588
315
loss 0.04596222564578056
315 0.04596222564578056
316
loss 0.04575224220752716
316 0.04575224220752716
317
loss 0.0457565039396286
317 0.0457565039396286
318
loss 0.045292872935533524
318 0.045292872935533524
319
loss 0.04402950033545494
319 0.04402950033545494
320
loss 0.04502849653363228
320 0.04502849653363228
321
loss 0.04451245069503784
321 0.04451245069503784
322
loss 0.04610089957714081
322 0.04610089957714081
323
loss 0.04600531980395317
323 0.04600531980395317
324
loss 0.04498664289712906
324 0.04498664289712906
325
loss 0.04707645624876022
325 0.04707645624876022
326
loss 0.0465426929295063
326 0.0465426929

In [None]:
# Manual checkpoint of the current `step`, `model` and `optimizer` into
# `save_dir`, e.g. after interrupting the training cell between periodic saves.
# NOTE(review): `save` presumably comes from the `utils.util` star import and
# the `None` fills a slot this notebook does not use - confirm against the helper.
save(save_dir, step, model, None, optimizer)