In [1]:
import os
os.chdir('../')

import warnings
warnings.filterwarnings('ignore')

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# IPython shell escape: print the current GPU inventory and memory usage so
# the device choice below can be checked against it.
!nvidia-smi
# Expose only the GPU with index 4 to this process.  With a single visible
# device, the 'cuda:0' used later in the notebook refers to that GPU.
# NOTE(review): this is set after `import torch`; it only takes effect if CUDA
# has not been initialised yet -- confirm nothing touches `torch.cuda` earlier.
os.environ["CUDA_VISIBLE_DEVICES"]="4"

Fri Jul 14 13:14:07 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:1B:00.0 Off |                    0 |
| 30%   27C    P8    13W / 230W |   1706MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000    Off  | 00000000:1C:00.0 Off |                  Off |
| 30%   26C    P8    14W / 230W |   2866MiB / 24564MiB |      0%      Default |
|       

### Hyperparameters

In [4]:
# Sizes shared by the dataset and the model:
#   n_mels    -> passed to LipsDataset and used as the model's `in_dim`
#   n_outputs -> used as the model's `out_dim`
#   n_frames  -> passed to LipsDataset
n_mels, n_outputs, n_frames = 80, 61, 400

### Model

In [5]:
from model.model_transformer_reg import Model
from utils.util import *
from tensorboardX import SummaryWriter

# `step` is the global training-step counter (advanced by the training loop);
# `device` is where the model and every batch are placed.
step, device = 0, 'cuda:0'

# Build the network directly on the target device and attach its optimizer.
model = Model(
    in_dim=n_mels,
    h_dim=1024,
    out_dim=n_outputs,
    n_layers=6,
    window_size=8,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
print('done')

done


In [6]:
# # warm start
# checkpoint = torch.load('/data/scpark/save/lips/train06.23-1/save_120000', map_location=torch.device('cpu'))
# model_state_dict = model.state_dict()

# for key in checkpoint['model_state_dict']:
#     if key in model_state_dict.keys():
#         if checkpoint['model_state_dict'][key].shape == model_state_dict[key].shape:
#             model_state_dict[key] = checkpoint['model_state_dict'][key]
#             print(key)
# model.load_state_dict(model_state_dict, strict=True)
# print('warm start')

### Save directory, TensorBoard writer and optional checkpoint resume

In [7]:
save_dir = '/data/scpark/save/lips/train07.14-1/'
!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, _, optimizer = load(save_dir, 56389, model, None, optimizer)

total 0


### Dataset

In [8]:
import os
from data.arkit_dataset import LipsDataset, CombinedDataset, CombinedCollate

# Every entry of `root_dir` becomes one LipsDataset; entries whose path
# contains '_10_' are held out for testing, the rest are used for training.
root_dir = '/data/speech/digital_human/preprocessed/'
files = [os.path.join(root_dir, name) for name in os.listdir(root_dir)]
files.sort()
print(len(files))

train_datasets, test_datasets = [], []

for file in files:
    print(file)
    dataset = LipsDataset(file, n_mels, n_frames)
    bucket = test_datasets if '_10_' in file else train_datasets
    bucket.append(dataset)
print(len(train_datasets), len(test_datasets))

39
/data/speech/digital_human/preprocessed/MH_ARKit_001_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_4_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_5_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_6_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_7_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_8_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_001_9_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_002_10_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_002_1_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_002_2_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_002_3_iPhone_raw.npy
/data/speech/digital_human/preprocessed/MH_ARKit_002_4_iPhone_raw.npy
/data/speech/dig

In [9]:
# Merge the per-file datasets and batch them with the project's collate class.
train_set = CombinedDataset(train_datasets)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    num_workers=16,
    collate_fn=CombinedCollate(),
)

# NOTE(review): the test loader is shuffled as well, so each evaluation in the
# training loop sees a different random subset -- confirm this is intended.
test_set = CombinedDataset(test_datasets)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=10,
    shuffle=True,
    num_workers=10,
    collate_fn=CombinedCollate(),
)
print('done')

done


In [None]:
from IPython import display
import librosa.display
import matplotlib.pyplot as plt

EVAL_INTERVAL = 1000  # steps between train-loss logging, evaluation and checkpointing
EVAL_BATCHES = 10     # number of test batches averaged per evaluation


def to_model_inputs(batch):
    """Turn a collated batch into `(inputs, targets)` tensors on `device`.

    Reads `batch['mel']` and `batch['blend']`, swaps dimensions 1 and 2 of
    each, and moves the results to `device`.
    """
    inputs = torch.Tensor(batch['mel']).transpose(1, 2).to(device)
    targets = torch.Tensor(batch['blend']).transpose(1, 2).to(device)
    return inputs, targets


def sum_losses(outputs):
    """Add up every entry of `outputs` whose key contains the substring 'loss'."""
    return sum(outputs[key] for key in outputs.keys() if 'loss' in key)


# Train until the loss diverges (or the cell is interrupted).  Every
# EVAL_INTERVAL steps the loop logs the train loss, evaluates on up to
# EVAL_BATCHES test batches, logs the mean test loss and writes a checkpoint.
diverged = False
while not diverged:
    for batch in train_loader:
        inputs, targets = to_model_inputs(batch)

        # Evaluation below switches to eval mode, so re-enable train mode here.
        model.train()
        model.zero_grad()
        outputs = model(inputs, targets)

        print(step)
        for key in outputs.keys():
            if 'loss' in key:
                print(key, outputs[key].item())
        loss = sum_losses(outputs)

        # Stop *before* the backward pass on a NaN or infinite loss.  Checking
        # only for NaN would let an infinite loss be back-propagated and
        # stepped, corrupting the weights before the next NaN is noticed.
        if not torch.isfinite(loss):
            diverged = True
            break
        loss.backward()
        optimizer.step()
        print(step, loss.item())

        if step % EVAL_INTERVAL == 0:
            writer.add_scalar('train_loss', loss.item(), step)

            # Drop the per-step log accumulated since the last evaluation.
            display.clear_output()

            # Evaluation uses its own `test_*` names so the training step's
            # `batch` / `inputs` / `targets` / `outputs` / `loss` stay intact.
            model.eval()
            test_losses = []
            for i, test_batch in enumerate(test_loader):
                if i >= EVAL_BATCHES:
                    break

                test_inputs, test_targets = to_model_inputs(test_batch)
                with torch.no_grad():
                    test_outputs = model(test_inputs, test_targets)

                batch_loss = sum_losses(test_outputs)
                print('test :', i, batch_loss.item())
                test_losses.append(batch_loss)

            test_loss = torch.stack(test_losses).mean().item()
            print('test_loss :', test_loss)
            writer.add_scalar('test_loss', test_loss, step)

            # Optional visual check of the last evaluated test batch:
            # plt.figure(figsize=[18, 4])
            # librosa.display.specshow(test_targets[0].data.cpu().numpy(), cmap='magma')
            # plt.show()
            #
            # plt.figure(figsize=[18, 4])
            # librosa.display.specshow(test_outputs['y_pred'][0].data.cpu().numpy(), cmap='magma')
            # plt.show()
            #
            # for i in [20, 37]:
            #     plt.figure(figsize=[18, 2])
            #     plt.title(str(i))
            #     plt.plot(test_targets[0].data.cpu().numpy()[i])
            #     plt.plot(test_outputs['y_pred'][0].data.cpu().numpy()[i])
            #     plt.show()

            save(save_dir, step, model, None, optimizer)

        step += 1

test : 0 1.5759835243225098
test : 1 1.5748674869537354
test : 2 1.5759663581848145
test : 3 1.5770506858825684
test : 4 1.5749685764312744
test : 5 1.575554609298706
test : 6 1.575116515159607
test : 7 1.575774908065796
test : 8 1.573891282081604
test : 9 1.574268102645874
test_loss : 1.575344204902649
saved /data/scpark/save/lips/train07.14-1/save_0
1
loss 1.5433765649795532
1 1.5433765649795532
2
loss 1.0331677198410034
2 1.0331677198410034
3
loss 0.5754046440124512
3 0.5754046440124512
4
loss 0.35595980286598206
4 0.35595980286598206
5
loss 0.31548383831977844
5 0.31548383831977844
6
loss 0.30122923851013184
6 0.30122923851013184
7
loss 0.2689697742462158
7 0.2689697742462158
8
loss 0.21715256571769714
8 0.21715256571769714
9
loss 0.19307632744312286
9 0.19307632744312286
10
loss 0.18477487564086914
10 0.18477487564086914
11
loss 0.1798853576183319
11 0.1798853576183319
12
loss 0.1693449467420578
12 0.1693449467420578
13
loss 0.1568654179573059
13 0.1568654179573059
14
loss 0.15437

154 0.091191366314888
155
loss 0.08785400539636612
155 0.08785400539636612
156
loss 0.09714507311582565
156 0.09714507311582565
157
loss 0.09180588275194168
157 0.09180588275194168
158
loss 0.09261419624090195
158 0.09261419624090195
159
loss 0.09102822095155716
159 0.09102822095155716
160
loss 0.0888974741101265
160 0.0888974741101265
161
loss 0.09424292296171188
161 0.09424292296171188
162
loss 0.08910277485847473
162 0.08910277485847473
163
loss 0.08785797655582428
163 0.08785797655582428
164
loss 0.09203680604696274
164 0.09203680604696274
165
loss 0.09362271428108215
165 0.09362271428108215
166
loss 0.08867571502923965
166 0.08867571502923965
167
loss 0.09004274755716324
167 0.09004274755716324
168
loss 0.08928389102220535
168 0.08928389102220535
169
loss 0.08885063976049423
169 0.08885063976049423
170
loss 0.09508737921714783
170 0.09508737921714783
171
loss 0.08917352557182312
171 0.08917352557182312
172
loss 0.09548790007829666
172 0.09548790007829666
173
loss 0.089227803051471

In [None]:
# Write a checkpoint for the current `step` on demand, in addition to the
# periodic saves performed inside the training loop.  `save` is not defined in
# this notebook; it presumably comes from the `utils.util` wildcard import.
save(save_dir, step, model, None, optimizer)

In [None]:
# Prints a marker once this cell runs; presumably used to confirm that the
# preceding cell (the manual checkpoint) has finished.
print('done')