In [5]:
import glob
import os
import sys
import warnings

import numpy as np
import mxnet as mx
from mxnet import gluon, autograd ,nd
from mxnet.gluon import nn, rnn,utils
import mxnet.gluon.data.dataset as dataset
from mxnet.gluon.data.vision import datasets
from mxnet.gluon.data import DataLoader
import mxnet.ndarray as F
from mxnet.gluon.data.vision import transforms

from tqdm import tqdm, trange

from mxnet.gluon.data.vision.datasets import image

In [6]:
from utils.common import *

In [7]:
from utils.align import Align

In [8]:
class LipsDataset(dataset.Dataset):
    """Dataset yielding one stacked frame sequence and its alignment label per folder.

    Sample folders are discovered two directory levels below ``root``
    (``root/*/*``). The image files inside a folder, sorted by path, are the
    frames of one sample; the folder's base name selects the alignment file
    ``<align_root>/<name>.align``.

    Parameters
    ----------
    root : str
        Top-level directory of the frame folders; ``~`` is expanded.
    align_root : str
        Directory that contains the ``.align`` files. A trailing path
        separator is optional.
    flag : int, default 1
        Passed unchanged to ``image.imread`` for every frame.
    transform : callable, optional
        Applied to every frame individually right after it is read.

    Attributes
    ----------
    items : list of (list of str, str)
        One ``(sorted frame paths, folder name)`` pair per sample.
    labels : list
        Created empty and never filled by this class; kept only so existing
        attribute access keeps working.
    """

    def __init__(self, root, align_root, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._align_root = align_root
        self._flag = flag
        self._transform = transform
        # Only files with one of these extensions (case-insensitive) count as frames.
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root)

    def _list_images(self, root):
        """Populate ``self.items`` with one entry per non-empty frame folder under ``root``."""
        self.labels = []
        self.items = []

        # glob returns entries in filesystem order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        for folder in sorted(glob.glob(os.path.join(root, "*", "*"))):
            if not os.path.isdir(folder):
                # A plain file at this depth has no frames to offer.
                continue

            frames = []
            for path in sorted(glob.glob(os.path.join(folder, "*"))):
                if os.path.splitext(path)[1].lower() in self._exts:
                    frames.append(path)
                else:
                    # e.g. .DS_Store / Thumbs.db would otherwise crash imread later.
                    warnings.warn('Ignoring %s: extension not in %s' % (path, self._exts))

            if not frames:
                # An empty sequence cannot be stacked in __getitem__.
                warnings.warn('Skipping %s: no image files found' % folder)
                continue

            self.items.append((frames, os.path.basename(folder)))

    def align_generation(self, file_nm, padding=75):
        """Load ``<align_root>/<file_nm>.align`` and return its sentence as an NDArray.

        ``padding`` is forwarded to ``Align.sentence``; its exact meaning is
        defined by that helper (presumably the padded label length -- confirm
        in ``utils.align``).
        """
        # os.path.join works whether or not align_root ends with a separator.
        align = Align(os.path.join(self._align_root, file_nm + '.align'))
        return nd.array(align.sentence(padding))

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for sample ``idx``.

        ``frames`` is every frame of the sample read with ``image.imread``,
        optionally transformed, and stacked along a new leading axis.
        """
        frame_paths, name = self.items[idx]
        frames = []
        for frame_path in frame_paths:
            frame = image.imread(frame_path, self._flag)
            if self._transform is not None:
                frame = self._transform(frame)
            frames.append(frame)
        stacked = nd.stack(*frames)
        label = self.align_generation(name)
        return stacked, label

    def __len__(self):
        """Number of frame folders that contain at least one image."""
        return len(self.items)
    
# Device on which the network is initialized and to which each batch is copied.
ctx = mx.cpu()

In [9]:
class LipNet(nn.Block):
    """Three Conv3D stages followed by two bidirectional GRUs and a per-step classifier.

    Each convolutional stage is Conv3D -> BatchNorm -> ReLU -> Dropout ->
    MaxPool3D. The pooled features are flattened per time step, passed through
    two bidirectional GRUs (hidden size 256 each) and projected to 28 scores
    per step, returned as log-probabilities.

    Parameters
    ----------
    dr_rate : float
        Dropout rate shared by the three convolutional stages.
    **kwargs
        Forwarded to ``nn.Block``.
    """

    def __init__(self, dr_rate, **kwargs):
        super(LipNet, self).__init__(**kwargs)

        # Children are created in the same order as before so that Gluon's
        # automatically generated parameter names do not change.
        # (InstanceNorm was previously tried in place of BatchNorm.)
        with self.name_scope():
            # Stage 1: 32 channels, spatial stride 2.
            self.conv1 = nn.Conv3D(channels=32, kernel_size=(3, 5, 5),
                                   strides=(1, 2, 2), padding=(1, 2, 2))
            self.bn1 = nn.BatchNorm()
            self.dr1 = nn.Dropout(rate=dr_rate)
            self.pool1 = nn.MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2))

            # Stage 2: 64 channels.
            self.conv2 = nn.Conv3D(channels=64, kernel_size=(3, 5, 5),
                                   strides=(1, 1, 1), padding=(1, 2, 2))
            self.bn2 = nn.BatchNorm()
            self.dr2 = nn.Dropout(rate=dr_rate)
            self.pool2 = nn.MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2))

            # Stage 3: 96 channels, 3x3x3 kernel.
            self.conv3 = nn.Conv3D(channels=96, kernel_size=(3, 3, 3),
                                   strides=(1, 1, 1), padding=(1, 2, 2))
            self.bn3 = nn.BatchNorm()
            self.dr3 = nn.Dropout(rate=dr_rate)
            self.pool3 = nn.MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2))

            # Recurrent part; both layers use the GRU's default layout.
            self.gru1 = rnn.GRU(hidden_size=256, bidirectional=True)
            self.gru2 = rnn.GRU(hidden_size=256, bidirectional=True)

            # 27 + 1 = 28 scores per step, applied to the last axis only.
            self.dense = nn.Dense(units=27 + 1, flatten=False)

    def summary(self, desc, out):
        """Print a separator line followed by ``desc`` and the shape of ``out``."""
        separator = "======================================="
        print(separator)
        print("{d} shape : {o}".format(d=desc, o=out.shape))

    def forward(self, x):
        """Return per-step log-probabilities with axes 0 and 1 of the GRU output swapped.

        ``x`` is a 5-D array fed straight into ``Conv3D`` (presumably
        ``(batch, channels, frames, height, width)`` given the default layout
        -- confirm against the caller).
        """
        stages = (
            (self.conv1, self.bn1, self.dr1, self.pool1),
            (self.conv2, self.bn2, self.dr2, self.pool2),
            (self.conv3, self.bn3, self.dr3, self.pool3),
        )

        feat = x
        for conv, norm, drop, pool in stages:
            feat = conv(feat)
            feat = norm(feat)
            feat = F.relu(feat)
            feat = drop(feat)
            feat = pool(feat)

        # Move axis 2 to the front, then flatten everything after the first
        # two axes into one feature vector per (axis0, axis1) position.
        feat = nd.transpose(feat, axes=(2, 0, 1, 3, 4))
        lead, second = feat.shape[0], feat.shape[1]
        feat = feat.reshape((lead, second, -1))

        feat = self.gru1(feat)
        feat = self.gru2(feat)

        scores = self.dense(feat)
        scores = F.log_softmax(scores, axis=2)

        # Swap the first two axes back before returning.
        return nd.transpose(scores, axes=(1, 0, 2))

In [10]:
def char_conv(out):
    """Turn a 2-D array of character indices into a list of strings, one per row.

    Negative entries are skipped. Every non-negative entry is mapped through
    ``int2char``; index 27 contributes an empty string instead of the mapped
    value. Repeated symbols are kept as they are.
    """
    decoded = []
    for row in range(out.shape[0]):
        text = ''
        for col in range(out.shape[1]):
            code = int(out[row][col])
            if code < 0:
                continue
            symbol = int2char(code)
            if code == 27:
                symbol = ''
            text = text + symbol
        decoded.append(text)
    return decoded

# Build the network with a dropout rate of 0.5 and initialize its parameters on `ctx`.
net = LipNet(0.5)
net.initialize(ctx=ctx)

In [11]:
# Per-frame preprocessing: convert the image to a tensor, then normalize each
# of the three channels with the given (mean, std) pairs.
# NOTE(review): the constants look like dataset-specific channel statistics -- confirm their origin.
input_transform  = transforms.Compose([transforms.ToTensor()
                                    , transforms.Normalize((0.7136,0.4906,0.3283),(0.1138,0.1078,0.0917))
                                 ])

In [12]:
# Frame folders are read from ./datasets/TARGET/ and alignment files from
# ./datasets/align/; every frame goes through `input_transform`.
training_dataset = LipsDataset(
    root='./datasets/TARGET/',
    align_root='./datasets/align/',
    transform=input_transform,
)

# Shuffled batches of 16 samples, loaded by 4 worker processes.
train_dataloader = mx.gluon.data.DataLoader(
    training_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [13]:
# Run exactly one batch through the network (the loop breaks after the first
# iteration) and keep the labels and per-step argmax predictions as NumPy arrays.
for input_data, label in tqdm(train_dataloader):
    # Swap axes 1 and 2 of the batch -- presumably
    # (batch, frames, C, H, W) -> (batch, C, frames, H, W); confirm against LipsDataset.
    input_data = nd.transpose(input_data,(0,2,1,3,4))
    input_data = input_data.copyto(ctx)
    label = label.copyto(ctx)
    # NOTE(review): the forward pass runs in training mode, so Dropout is active
    # and the predictions are not deterministic -- confirm this is intended.
    with autograd.train_mode():
        pred = net(input_data)
    label = label.asnumpy()
    # Most likely class index for every step of every sample.
    pred = pred.argmax(2).asnumpy()
    break

  0%|          | 0/2005 [00:00<?, ?it/s]

In [15]:
# Decode the index arrays of the sampled batch into strings for inspection.
label_conv = char_conv(label)
pred_conv = char_conv(pred)

In [16]:
# Sentences decoded from the batch labels.
label_conv

['bin red in p six soon',
 'set blue with z two again',
 'bin red with a zero please',
 'place red in n six now',
 'lay green in q one please',
 'set red with i three soon',
 'bin green in y four again',
 'set green with j five please',
 'place green at d nine again',
 'lay red at x five soon',
 'bin red by x four soon',
 'set white with s seven again',
 'place white by p zero now',
 'bin green at f eight again',
 'set blue by r nine soon',
 'bin white at y seven again']

In [17]:
# Per-step argmax decoded as-is: char_conv drops index 27 but does not merge repeated symbols.
pred_conv

['ppcpxpccccccccccgaccxcccxlacltqlttttttttttatttctcccccctqcttaactxxgttppcllll',
 'xfgggggg ggxxogggggggxxgggggzzzggggxxxxxxxxxxxxxxxxgxxgxtgxxxzxzzzxzxxxxxz',
 'xxxcccxxxxxxxxxloootoxoxxxxxlllxlllclcggllllxxxccooxlxlllcceqqccccxxxxetczc',
 'xgtttagxxpppppgglppllgggggpglllgllgaggggggggxxxgxxxxxgglgggggxxxxctgtfaaass',
 'rrlaaalllllllllgllgglglggggglgggglgggggggggggggggggggggllgaagggggptgggggggo',
 'ippltllccccccoxglllrlggllcgglclllclclllgccllllllcccccccccgalcrarllll  aagls',
 'ccccccgttttttcctxtggtttccctttcgggtgttxxtttttccggggttgccccxcccgllttccxxtxx  ',
 'xxxxxxllxxxxlllllulxxlllllggazooollllggllllllll    lggggnggglllllllggtzxzz',
 'nnnnyyyoyyylgxnaaaaaaaaaaaaaaaaxxxraggggxaaaaaaaaaaaaayylllllllllllggaaggsb',
 'yeyylllooaxooelllggglllllllllllloloaaaaaaaaalplaalllllllaaaaaaaalllllllaall',
 'pxxxcccoxxxglxxxxxxxxapclloolacccggccaccgxcagggggttttcgcaxxooccoalglllltll',
 'xnxalllplllbbbbblbllxxfynyxxxllllllt    xxxxxxlxlllllllllllxxxxbbxxxxxllxxg',
 'ptglplgggggggggllglllllllllllllllllllllll