In [1]:
import os
import warnings

# Run from the project root (one level above this notebook) so that the
# project-local packages imported later (`model`, `utils`, `data`) resolve.
# The notebook's own directory is remembered the first time this cell runs,
# so re-running the cell does not keep climbing the directory tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

# NOTE(review): this silences *all* warnings; consider narrowing the filter
# by category or module so genuinely useful warnings still surface.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="5"

Mon Aug  7 22:43:42 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   57C    P0    85W / 300W |  16890MiB / 80994MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   31C    P0    42W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Hyperparameters

In [3]:
# Sizes shared by the model and the datasets.
# n_mels:    per-frame input feature size, passed to the model as `in_dim`.
#            768 matches the ContentVec/HuBERT `encoder_embed_dim`, whose
#            features are what the model is fed (these are not mel bins).
# n_outputs: number of regressed output channels per frame (`out_dim`);
#            presumably the ARKit face curves — TODO confirm.
# n_frames:  window length handed to LipsDataset (presumably frames per example).
n_mels, n_outputs, n_frames = 768, 61, 400

### Model

In [4]:
from model.model_transformer_reg import Model
# NOTE(review): star import; `save`/`load` used later in this notebook
# presumably come from here — prefer importing them explicitly.
from utils.util import *
from tensorboardX import SummaryWriter

# Global optimisation step: advanced by the training loop, and possibly
# overwritten when resuming from a checkpoint.
step = 0
# With CUDA_VISIBLE_DEVICES pinned to a single GPU, that GPU is 'cuda:0'.
device = 'cuda:0'

# Regressor from per-frame features (n_mels wide) to n_outputs channels.
model = Model(in_dim=n_mels, h_dim=512, out_dim=n_outputs).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


### Output directory, TensorBoard writer, and optional checkpoint resume

In [5]:
save_dir = '/data/scpark/save/lips/train08.08-5/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 133000, model, None, optimizer)

total 0
-rw-rw-r-- 1 scpark scpark 0  8월  7 22:01 events.out.tfevents.1691413315.GPUSVR11


### Dataset

In [6]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

root_dir = '/data/speech/digital_human/preprocessed/nohsikPark'
# Takes whose *file name* contains this tag are held out for evaluation.
TEST_TAG = '_1_'

files = sorted(os.path.join(root_dir, name) for name in os.listdir(root_dir))
print(len(files))

train_datasets = []
test_datasets = []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames, perturb=False)
    # Match on the file name only: matching the full path would put every
    # take into the test split if `root_dir` itself ever contained the tag.
    if TEST_TAG in os.path.basename(file):
        test_datasets.append(dataset)
    else:
        train_datasets.append(dataset)
print(len(train_datasets), len(test_datasets))

# Fail here rather than deep inside the training/evaluation loop.
assert train_datasets and test_datasets, \
    f'expected non-empty train and test splits (TEST_TAG={TEST_TAG!r})'

9
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_0_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/nohsikPark/MH_ARKit_010_8_iPhone_raw.npy
8 1


In [7]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    CombinedDataset(train_datasets),
    num_workers=16, shuffle=True, batch_size=32, collate_fn=CombinedCollate())

# Evaluation only consumes the first few test batches, so keep the order
# fixed: every evaluation then scores the same samples and successive
# test_loss values are comparable. NOTE(review): if CombinedDataset/LipsDataset
# pick windows at random internally, this reduces but does not remove the
# variance — TODO confirm.
test_loader = torch.utils.data.DataLoader(
    CombinedDataset(test_datasets),
    num_workers=10, shuffle=False, batch_size=10, collate_fn=CombinedCollate())
print('done')

done


In [8]:
import fairseq
from torchaudio.transforms import Resample

# The ContentVec/HuBERT task config expects 16 kHz audio (sample_rate=16000);
# the dataset audio is assumed to be 24 kHz — TODO confirm against LipsDataset.
# The resampler lives on `device` so resampling runs next to the encoder
# instead of on CPU for every batch.
resample = Resample(24000, 16000).to(device)

ckpt_path = "/Storage/speech/pretrained/contentvec/checkpoint_best_legacy_500.pt"
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert = models[0].to(device)
# Frozen feature extractor: eval mode disables dropout/layerdrop. The training
# loop only toggles `model`, never `hubert`.
hubert.eval()


def get_hubert_feature(wav, output_layer=12):
    """Encode a batch of waveforms with the frozen ContentVec/HuBERT encoder.

    Args:
        wav: batch of waveforms (array-like or tensor), presumably
            (batch, samples) at 24 kHz — TODO confirm against the collate fn.
        output_layer: encoder layer whose hidden states are returned
            (defaults to 12, the previous hard-coded value).

    Returns:
        Float tensor on `device` shaped (batch, channels, time). The encoder
        yields (batch, time, channels); it is transposed so the time axis is
        last, as `F.interpolate` and the model expect.
    """
    with torch.no_grad():
        # as_tensor also accepts tensors of any float dtype, unlike torch.Tensor(...)
        wav = torch.as_tensor(wav, dtype=torch.float32, device=device)
        wav = resample(wav)
        feature = hubert.extract_features(wav, output_layer=output_layer)[0]  # (b, t, c)
        return feature.transpose(1, 2)  # (b, c, t)
print('done')

2023-08-07 22:44:01 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/scpark/projects/wav2face
2023-08-07 22:44:01 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': 'metadata', 'fine_tuning': False, 'labels': ['km'], 'label_dir': 'label', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-08-07 22:44:01 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1,

done


In [None]:
import itertools

from IPython import display
import librosa.display
import matplotlib.pyplot as plt

EVAL_EVERY = 1000         # steps between train/test logging, evaluation and checkpointing
N_EVAL_BATCHES = 10       # test batches scored per evaluation
PLOT_EVAL = False         # plot one test example's target vs. prediction after each evaluation
PLOT_CHANNELS = (20, 37)  # output channels drawn as line plots when PLOT_EVAL is on


def prepare_batch(batch):
    """Return (inputs, targets) tensors on `device` for a collated batch.

    `targets` is `batch['blend']` with axes 1 and 2 swapped (assumed
    (batch, time, channels) -> (batch, channels, time) — TODO confirm).
    `inputs` are HuBERT features linearly resampled along time to the
    target frame count.
    """
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    inputs = get_hubert_feature(batch['wav'])
    inputs = F.interpolate(inputs, size=targets.shape[2], mode='linear')
    return inputs, targets


def total_loss(outputs, verbose=False):
    """Sum every entry of the model output dict whose key contains 'loss'.

    Prints each term when `verbose`. Raises ValueError if there is no loss
    term at all, instead of silently producing the int 0, which would crash
    later on `.item()` / `torch.isnan`.
    """
    loss = None
    for key, value in outputs.items():
        if 'loss' in key:
            if verbose:
                print(key, value.item())
            loss = value if loss is None else loss + value
    if loss is None:
        raise ValueError(f'model outputs contain no loss term: {list(outputs)}')
    return loss


def plot_predictions(targets, y_pred, channels=PLOT_CHANNELS):
    """Plot the first example's target and predicted curves, as images and per channel."""
    target = targets[0].data.cpu().numpy()
    pred = y_pred[0].data.cpu().numpy()
    for title, curves in (('target', target), ('prediction', pred)):
        plt.figure(figsize=[18, 4])
        plt.title(title)
        librosa.display.specshow(curves, cmap='magma')
        plt.show()
    for channel in channels:
        plt.figure(figsize=[18, 2])
        plt.title(str(channel))
        plt.plot(target[channel])
        plt.plot(pred[channel])
        plt.show()


def evaluate(n_batches=N_EVAL_BATCHES, plot=PLOT_EVAL):
    """Mean summed loss of `model` over the first `n_batches` test batches.

    Leaves `model` in eval mode; `train_step` switches it back. When `plot`
    is set, the last scored batch is plotted via `plot_predictions`
    (assumes the model output dict has a 'y_pred' entry).
    """
    model.eval()
    losses = []
    targets = outputs = None
    with torch.no_grad():
        # islice stops after n_batches without fetching one extra batch.
        for i, batch in enumerate(itertools.islice(test_loader, n_batches)):
            inputs, targets = prepare_batch(batch)
            outputs = model(inputs, targets)
            loss = total_loss(outputs)
            print('test :', i, loss.item())
            losses.append(loss)
    if not losses:
        raise ValueError('test_loader yielded no batches')
    if plot:
        plot_predictions(targets, outputs['y_pred'])
    return torch.stack(losses).mean().item()


def train_step(batch):
    """Run one optimisation step on `batch` and return the summed loss tensor.

    If the loss is NaN/inf, the backward pass and optimizer update are
    skipped, so the weights are not corrupted before the caller stops.
    """
    inputs, targets = prepare_batch(batch)
    model.train()
    model.zero_grad()
    outputs = model(inputs, targets)
    loss = total_loss(outputs, verbose=True)
    if not torch.isfinite(loss):
        return loss
    loss.backward()
    optimizer.step()
    return loss


# Train until interrupted, or until the loss becomes NaN/inf.
diverged = False
while not diverged:
    for batch in train_loader:
        print(step)
        loss = train_step(batch)
        if not torch.isfinite(loss):
            print(f'non-finite loss at step {step}: {loss.item()}; stopping')
            diverged = True
            break
        print(step, loss.item())

        if step % EVAL_EVERY == 0:
            writer.add_scalar('train_loss', loss.item(), step)
            display.clear_output()
            test_loss = evaluate()
            print('test_loss :', test_loss)
            writer.add_scalar('test_loss', test_loss, step)
            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 0.9149622917175293
test : 1 0.9156110882759094
test : 2 0.9170243144035339
test : 3 0.9165120124816895
test : 4 0.9138364791870117
test : 5 0.9092640280723572
test : 6 0.9128372073173523
test : 7 0.9144492149353027
test : 8 0.9183736443519592
test : 9 0.9189295768737793
test_loss : 0.9151800274848938
saved /data/scpark/save/lips/train08.08-5/save_0
1
loss 0.8941007852554321
1 0.8941007852554321
2
loss 0.5358812808990479
2 0.5358812808990479
3
loss 0.3441655933856964
3 0.3441655933856964
4
loss 0.23247218132019043
4 0.23247218132019043
5
loss 0.19305472075939178
5 0.19305472075939178
6
loss 0.20285972952842712
6 0.20285972952842712
7
loss 0.19993285834789276
7 0.19993285834789276
8
loss 0.177408829331398
8 0.177408829331398
9
loss 0.1515730321407318
9 0.1515730321407318
10
loss 0.14223960041999817
10 0.14223960041999817
11
loss 0.14766639471054077
11 0.14766639471054077
12
loss 0.14899305999279022
12 0.14899305999279022
13
loss 0.14370273053646088
13 0.14370273053646088
14
loss

155
loss 0.08061681687831879
155 0.08061681687831879
156
loss 0.08192145824432373
156 0.08192145824432373
157
loss 0.08052922785282135
157 0.08052922785282135
158
loss 0.08382955938577652
158 0.08382955938577652
159
loss 0.0815148577094078
159 0.0815148577094078
160
loss 0.08089287579059601
160 0.08089287579059601
161
loss 0.08104439079761505
161 0.08104439079761505
162
loss 0.0833008885383606
162 0.0833008885383606
163
loss 0.08062469959259033
163 0.08062469959259033
164
loss 0.08041752129793167
164 0.08041752129793167
165
loss 0.08054567128419876
165 0.08054567128419876
166
loss 0.08231647312641144
166 0.08231647312641144
167
loss 0.07993926852941513
167 0.07993926852941513
168
loss 0.08388256281614304
168 0.08388256281614304
169
loss 0.08420916646718979
169 0.08420916646718979
170
loss 0.08206815272569656
170 0.08206815272569656
171
loss 0.0835239589214325
171 0.0835239589214325
172
loss 0.08306045830249786
172 0.08306045830249786
173
loss 0.08222246170043945
173 0.08222246170043945

310 0.06543402373790741
311
loss 0.06852234899997711
311 0.06852234899997711
312
loss 0.06534479558467865
312 0.06534479558467865
313
loss 0.06552668660879135
313 0.06552668660879135
314
loss 0.06223458796739578
314 0.06223458796739578
315
loss 0.06464899331331253
315 0.06464899331331253
316
loss 0.0629190132021904
316 0.0629190132021904
317
loss 0.06780614703893661
317 0.06780614703893661
318
loss 0.06313353031873703
318 0.06313353031873703
319
loss 0.06600811332464218
319 0.06600811332464218
320
loss 0.0650218203663826
320 0.0650218203663826
321
loss 0.06291189044713974
321 0.06291189044713974
322
loss 0.06328864395618439
322 0.06328864395618439
323
loss 0.06206401064991951
323 0.06206401064991951
324
loss 0.06227393448352814
324 0.06227393448352814
325
loss 0.06088859215378761
325 0.06088859215378761
326
loss 0.06793217360973358
326 0.06793217360973358
327
loss 0.06429751962423325
327 0.06429751962423325
328
loss 0.06351840496063232
328 0.06351840496063232
329
loss 0.063564494252204

In [None]:
# Manually checkpoint the current model/optimizer state under `save_dir`,
# presumably after interrupting the (otherwise endless) training loop above.
# NOTE(review): if the loop was interrupted between `optimizer.step()` and
# `step += 1`, the saved step number lags the weights by one update.
save(save_dir, step, model, None, optimizer)