# Audio2Blendshape test notebook

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from scripts.Dataset import a2bsDataset

  from .autonotebook import tqdm as notebook_tqdm


## Initialize train-eval-test split
Set *build_cache=True* to buid cache if it doesn't already exist.

In [2]:
train_data = a2bsDataset(loader_type='train', build_cache=True)
train_loader = torch.utils.data.DataLoader(
            train_data, 
            batch_size=16,  
            shuffle=True,  
            num_workers=0,
            drop_last=True,
        )

In [3]:
len(train_data)

6460

In [4]:
eval_data = a2bsDataset(loader_type='eval', build_cache=False)
eval_loader = torch.utils.data.DataLoader(
            eval_data, 
            batch_size=16,  
            shuffle=True,  
            num_workers=0,
            drop_last=True,
        )

In [5]:
len(eval_data)

1447

In [4]:
class BasicBlock(nn.Module):
    '''
    from timm
    '''
    def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, cardinality=1, base_width=64,
                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU,   norm_layer=nn.BatchNorm1d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation,
            dilation=dilation, bias=True)
        self.bn1 = norm_layer(planes)
        self.act1 = act_layer(inplace=True)
        #self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None

        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=ker_size, padding=ker_size//2, dilation=dilation, bias=True)
        self.bn2 = norm_layer(planes)

        #self.se = create_attn(attn_layer, outplanes)

        self.act2 = act_layer(inplace=True)
        if downsample is not None:
            self.downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes,  stride=stride, kernel_size=ker_size, padding=first_dilation, dilation=dilation, bias=True),
                norm_layer(planes), 
            )
        else: self.downsample=None
        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        #print("x after 0", x.shape)
        shortcut = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        #print("x after 1", x.shape)
        x = self.conv2(x)
        x = self.bn2(x)
        #print("x after 2", x.shape)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        #print("x after 3", x.shape)
        return x

In [5]:
class WavEncoder(nn.Module):
    def __init__(self, out_dim):
        super().__init__() #128*1*140844 
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential( #b = (a+3200)/5 a 
                BasicBlock(1, 32, 15, 5, first_dilation=1600, downsample=True),
                BasicBlock(32, 32, 15, 6, first_dilation=0, downsample=True),
                BasicBlock(32, 32, 15, 1, first_dilation=7, ),
                BasicBlock(32, 64, 15, 6, first_dilation=0, downsample=True),
                BasicBlock(64, 64, 15, 1, first_dilation=7),
                BasicBlock(64, 128, 15, 6,  first_dilation=0,downsample=True),     
            )
        
    def forward(self, wav_data):
        wav_data = wav_data.unsqueeze(1)  # add channel dim
        out = self.feat_extractor(wav_data)
        return out.transpose(1, 2)  # to (batch x seq x dim)

class FaceGenerator(nn.Module):
    def __init__(self, facial_dims = 51, audio_f = 128, hidden_size = 256, n_layer = 4, dropout_prob = 0.3):
        super().__init__()
        #self.pre_length = args.pre_frames #4
        #self.gen_length = args.facial_length - args.pre_frames #30
        self.facial_dims = facial_dims
        #self.speaker_f = args.speaker_f
        self.audio_f = audio_f
        #self.facial_in = int(args.facial_rep[-2:])
        
        self.in_size = self.audio_f + self.facial_dims + 1
        self.audio_encoder = WavEncoder(self.audio_f)
        
        self.hidden_size = hidden_size
        self.n_layer = n_layer
        self.dropout_prob = dropout_prob

        # if self.facial_f is not 0:  
        #     self.facial_encoder = nn.Sequential( #b = (a+3200)/5 a 
        #         BasicBlock(self.facial_in, self.facial_f//2, 7, 1, first_dilation=3,  downsample=True),
        #         BasicBlock(self.facial_f//2, self.facial_f//2, 3, 1, first_dilation=1,  downsample=True),
        #         BasicBlock(self.facial_f//2, self.facial_f//2, 3, 1, first_dilation=1, ),
        #         BasicBlock(self.facial_f//2, self.facial_f, 3, 1, first_dilation=1,  downsample=True),   
        #     )
        # else:
        #     self.facial_encoder = None

        
        self.gru = nn.GRU(self.in_size, hidden_size=self.hidden_size, num_layers=self.n_layer, batch_first=True,
                          bidirectional=True, dropout=self.dropout_prob)
        self.out = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size//2),
            nn.LeakyReLU(True),
            nn.Linear(self.hidden_size//2, self.facial_dims)
        )
        
        self.do_flatten_parameters = False
        if torch.cuda.device_count() > 1:
            self.do_flatten_parameters = True
            

    def forward(self, pre_seq, in_audio, is_test=False): #pre_seq in this case is in_face
        decoder_hidden = decoder_hidden_hands = None
        
        if self.do_flatten_parameters:
            self.gru.flatten_parameters()

        audio_feat_seq = None
        audio_feat_seq = self.audio_encoder(in_audio)  # output (bs, n_frames, feat_size)
        
        # if self.facial_f is not 0:
        #     # facial
        #     # print(in_facial.shape)
        #     face_feat_seq = self.facial_encoder(in_facial.permute([0, 2, 1]))
        #     face_feat_seq = face_feat_seq.permute([0, 2, 1])
    
        # if is_test:
        #     if self.facial_f is not 0:
        #         min_length = min(pre_seq.shape[1], audio_feat_seq.shape[1], face_feat_seq.shape[1])
        #         pre_seq = pre_seq[:,: min_length]
        #         audio_feat_seq = audio_feat_seq[:, : min_length]
        #         face_feat_seq = face_feat_seq[:, : min_length]
        #     else:
        #         min_length = min(pre_seq.shape[1], audio_feat_seq.shape[1])
        #         pre_seq = pre_seq[:,: min_length]
        #         audio_feat_seq = audio_feat_seq[:, : min_length]
            #print(pre_seq.shape)
        # if self.audio_f is not 0 and self.facial_f is 0:
        #     in_data = torch.cat((pre_seq, audio_feat_seq), dim=2)
        # elif self.audio_f is not 0 and self.facial_f is not 0:
        #     in_data = torch.cat((pre_seq, audio_feat_seq, face_feat_seq), dim=2)
        # else: pass
        
        in_data = torch.cat((pre_seq, audio_feat_seq), dim=2)
        
        # if speaker_feat_seq is not None:
        #     #if print(z_context.shape)
        #     repeated_s = speaker_feat_seq
        #     #print(repeated_s.shape)
        #     if len(repeated_s.shape) == 2:
        #         repeated_s = repeated_s.reshape(1, repeated_s.shape[1], repeated_s.shape[0])
        #         #print(repeated_s.shape)
        #     repeated_s = repeated_s.repeat(1, in_data.shape[1], 1)
        #     #print(repeated_s.shape)
        #     #print(repeated_s.shape)
        #     in_data = torch.cat((in_data, repeated_s), dim=2)
        
        
        output, decoder_hidden = self.gru(in_data, decoder_hidden)
        output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]  # sum bidirectional outputs
        output = self.out(output.reshape(-1, output.shape[2]))
        decoder_outputs = output.reshape(in_data.shape[0], in_data.shape[1], -1)


        return decoder_outputs

In [31]:
in_audio, facial, in_id = next(iter(train_loader))

In [32]:
print(in_audio.shape, facial.shape, in_id.shape)

torch.Size([16, 36266]) torch.Size([16, 34, 51]) torch.Size([16, 1])


In [64]:
net = FaceGenerator().cuda()
optimizer = torch.optim.Adam( net.parameters(), lr=1e-3)
class RMSLELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))
loss_function = RMSLELoss()

In [66]:
net.train()
num_epochs = 5
log_period = 100
for epoch in range(num_epochs):
    for it, (in_audio, facial, in_id) in enumerate(train_loader):
        in_audio = in_audio.cuda()
        facial = facial.cuda()
        pre_frames = 4
        in_pre_face = facial.new_zeros((facial.shape[0], facial.shape[1], facial.shape[2] + 1)).cuda()
        in_pre_face[:, 0:pre_frames, :-1] = facial[:, 0:pre_frames]
        in_pre_face[:, 0:pre_frames, -1] = 1 
        
        optimizer.zero_grad()
        out_face = net(in_pre_face,in_audio)
        loss = loss_function(facial, out_face)
        loss.backward()
        optimizer.step()
        
        #logging
        if it % log_period == 0:
            print(f'[{epoch}][{it}/{len(train_loader)}] loss: {loss.item()}')

[0][0/535] loss: 0.05445299297571182
[0][100/535] loss: 0.05884157121181488
[0][200/535] loss: 0.05473871901631355
[0][300/535] loss: 0.059451572597026825
[0][400/535] loss: 0.05230134725570679
[0][500/535] loss: 0.049990464001894
[1][0/535] loss: 0.06423146277666092
[1][100/535] loss: 0.062476251274347305
[1][200/535] loss: 0.053438518196344376
[1][300/535] loss: 0.054198652505874634
[1][400/535] loss: 0.062087081372737885
[1][500/535] loss: 0.05568043887615204
[2][0/535] loss: 0.05325791984796524
[2][100/535] loss: 0.05018600821495056
[2][200/535] loss: 0.05615130811929703
[2][300/535] loss: 0.05180227756500244
[2][400/535] loss: 0.05200745910406113
[2][500/535] loss: 0.06145203486084938
[3][0/535] loss: 0.05457715317606926
[3][100/535] loss: 0.04981430992484093
[3][200/535] loss: 0.04949994385242462
[3][300/535] loss: 0.049588557332754135
[3][400/535] loss: 0.06405346840620041
[3][500/535] loss: 0.05206301808357239
[4][0/535] loss: 0.05488245189189911
[4][100/535] loss: 0.0568476989

In [98]:
net.eval()
in_audio, facial, in_id = next(iter(eval_loader))
in_audio = in_audio.cuda()
facial = facial.cuda()
pre_frames = 4
in_pre_face = facial.new_zeros((facial.shape[0], facial.shape[1], facial.shape[2] + 1)).cuda()
in_pre_face[:, 0:pre_frames, :-1] = facial[:, 0:pre_frames]
in_pre_face[:, 0:pre_frames, -1] = 1 
out_face = net(in_pre_face, in_audio)

In [99]:
out_face.shape, facial.shape

(torch.Size([1, 34, 51]), torch.Size([1, 34, 51]))

In [100]:
out_face[0,:,7]

tensor([0.2166, 0.2283, 0.2397, 0.2315, 0.2428, 0.2493, 0.2504, 0.2541, 0.2573,
        0.2609, 0.2718, 0.2830, 0.2924, 0.2983, 0.2963, 0.2858, 0.2360, 0.2100,
        0.1918, 0.2012, 0.2473, 0.2794, 0.2949, 0.2902, 0.2707, 0.2000, 0.2018,
        0.2292, 0.2660, 0.3031, 0.3274, 0.3392, 0.3476, 0.3398],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [101]:
facial[0,:,7]

tensor([0.2018, 0.1870, 0.1821, 0.1815, 0.1819, 0.1816, 0.1878, 0.1939, 0.1885,
        0.1680, 0.1425, 0.1257, 0.1093, 0.0941, 0.0808, 0.0839, 0.1211, 0.1340,
        0.1208, 0.1187, 0.1301, 0.1508, 0.1523, 0.1383, 0.1254, 0.1011, 0.0942,
        0.1182, 0.1292, 0.1447, 0.1581, 0.1300, 0.0912, 0.0982],
       device='cuda:0')

## Test set

In [11]:
test_data = a2bsDataset(loader_type='test', build_cache=False)
test_loader = torch.utils.data.DataLoader(
            test_data, 
            batch_size=16,  
            shuffle=True,  
            num_workers=0,
            drop_last=True,
        )

In [12]:
len(test_data)

660