In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import Field
from torch.utils.tensorboard import SummaryWriter
from torchsummaryX import summary

import warnings
warnings.filterwarnings("ignore")
from itertools import groupby
from jiwer import wer
import pandas as pd
import os
import time

from Core.cnn_rnn import CNN_RNN
from Core.dataset import PhoenixDataset

## Load dataset

In [None]:
# Target-side field for the annotation sentences. No init/eos tokens are
# configured, so a numericalized sequence has exactly one id per token.
TRG = Field(
    sequential=True,
    use_vocab=True,
    init_token=None,
    eos_token=None,
    lower=True,
    tokenize='spacy',
    tokenizer_language='de',
)

# Path of the training annotation file under the dataset root.
root = '/mnt/data/public/datasets'
csv_dir = os.path.join(
    root,
    'phoenix2014-release/phoenix-2014-multisigner',
    'annotations/manual/train.corpus.csv',
)
csv_file = pd.read_csv(csv_dir)

# Every row is read from the first column as one '|'-separated string; its
# fourth field is lower-cased and whitespace-split into a token list.
tgt_sents = [
    row.lower().split('|')[3].split()
    for row in csv_file.iloc[:, 0]
]

# Build the vocabulary from the training annotations only. VocabSize is also
# the index used as the CTC blank when the criterion is created.
TRG.build_vocab(tgt_sents, min_freq=1)
VocabSize = len(TRG.vocab)


def collate_fn(batch):
    """Merge a list of dataset samples into a single padded batch.

    Args:
        batch: list of samples; each sample is indexed with the keys
            'video' (a tensor, padded along its first dimension) and
            'annotation' (a string, split on whitespace into tokens).

    Returns:
        dict with the keys
            'videos': the videos padded to a common length, batch-first,
            'annotations': result of ``TRG.process`` on the token lists,
            'video_lens': tensor holding ``len(video)`` for every sample,
            'annotation_lens': tensor holding the token count per sample.
    """
    clips, token_lists = [], []
    for sample in batch:
        clips.append(sample['video'])
        token_lists.append(sample['annotation'].split())

    # Record the true lengths before padding hides them.
    video_lens = torch.tensor([len(clip) for clip in clips])
    annotation_lens = torch.tensor([len(tokens) for tokens in token_lists])

    return {'videos': pad_sequence(clips, batch_first=True),
            'annotations': TRG.process(token_lists),
            'video_lens': video_lens,
            'annotation_lens': annotation_lens}

# --- Data-loading configuration ---------------------------------------------
FrameSize = 224  # side length (pixels) of the square frames produced by the transforms
BSZ = 2          # batch size; also reused below as the number of loader workers
interval = 4     # forwarded to PhoenixDataset (presumably a frame-sampling stride -- TODO confirm)

# Training-time augmentation: a random crop resized to FrameSize x FrameSize.
transform = transforms.Compose([
    transforms.RandomResizedCrop(FrameSize),
    transforms.ToTensor()])

# Evaluation preprocessing must be deterministic: with a random crop the dev
# and test metrics change from run to run. The whole frame is resized to the
# same square size the training transform outputs.
# NOTE(review): confirm a plain resize is the desired eval-time preprocessing
# for the checkpoint being evaluated.
eval_transform = transforms.Compose([
    transforms.Resize((FrameSize, FrameSize)),
    transforms.ToTensor()])

# Only the training split is shuffled and augmented.
train_loader = DataLoader(
    PhoenixDataset(root, mode='train', interval=interval, transform=transform),
    batch_size=BSZ, shuffle=True, num_workers=BSZ,
    collate_fn=collate_fn, pin_memory=True)

dev_loader = DataLoader(
    PhoenixDataset(root, mode='dev', interval=interval, transform=eval_transform),
    batch_size=BSZ, shuffle=False, num_workers=BSZ,
    collate_fn=collate_fn, pin_memory=True)

test_loader = DataLoader(
    PhoenixDataset(root, mode='test', interval=interval, transform=eval_transform),
    batch_size=BSZ, shuffle=False, num_workers=BSZ,
    collate_fn=collate_fn, pin_memory=True)

## Define the evaluation routine (`val`)

In [None]:
def val(net, test_loader, criterion):
    """Evaluate ``net`` on ``test_loader`` and return mean loss and mean WER.

    Args:
        net: model called as ``net(videos)``; its output is passed to
            ``criterion`` and greedily decoded along the last (class)
            dimension, i.e. it is treated as (T, N, C) as ``nn.CTCLoss``
            expects.
        test_loader: iterable of batches shaped like the dict returned by
            ``collate_fn`` ('videos', 'annotations', 'video_lens',
            'annotation_lens').
        criterion: callable invoked as
            ``criterion(outs, targets, input_lens, target_lens)``.

    Returns:
        (mean_loss, mean_wer), both averaged over the number of batches.
        The WER of a batch is computed on the concatenation of its decoded
        samples against the concatenation of its reference sentences.

    Each sample is decoded on its own and trimmed to its true length, so
    that (a) the CTC repeat-collapse never merges tokens across two samples,
    (b) frames that are only padding are not decoded, and (c) padded target
    positions are not counted as reference words.
    """
    net.eval()
    epoch_wer = 0.0
    epoch_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['videos'].cuda()
            # TRG.process yields (S, N); the loss is fed (N, S).
            targets = batch['annotations'].permute(1, 0).contiguous().cuda()
            input_lens = batch['video_lens'].cuda()
            target_lens = batch['annotation_lens'].cuda()

            outs = net(inputs)
            loss = criterion(outs, targets, input_lens, target_lens)

            # Greedy best path: arg-max class per frame, laid out as (N, T)
            # and moved to the CPU once so decoding works on plain ints.
            best_path = outs.max(-1)[1].permute(1, 0).cpu()
            ref_ids = batch['annotations'].permute(1, 0)  # (N, S), already on CPU
            frame_counts = batch['video_lens'].tolist()
            token_counts = batch['annotation_lens'].tolist()

            hyp_words, ref_words = [], []
            for sample_idx, (n_frames, n_tokens) in enumerate(
                    zip(frame_counts, token_counts)):
                frame_ids = best_path[sample_idx, :n_frames].tolist()
                # CTC collapse: merge repeated ids, then drop the blank
                # (index VocabSize).
                hyp_words.extend(TRG.vocab.itos[k]
                                 for k, _ in groupby(frame_ids)
                                 if k != VocabSize)
                ref_words.extend(TRG.vocab.itos[k]
                                 for k in ref_ids[sample_idx, :n_tokens].tolist())

            epoch_wer += wer(' '.join(ref_words), ' '.join(hyp_words),
                             standardize=True)
            epoch_loss += loss.item()

    return epoch_loss/len(test_loader), epoch_wer/len(test_loader)

## Evaluate on the test set

In [None]:
# Restrict this process to a single physical GPU.
# NOTE(review): this only takes effect if CUDA has not been initialised yet
# in this kernel -- confirm no earlier cell touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
save_root = '/home/xieliang/Data/sign-language-recognition'
save_model = os.path.join(save_root, 'save/CNN_RNN_CTC3.pth')

# Deserialize onto the CPU first: without map_location, torch.load restores
# every tensor onto the CUDA device index recorded when the checkpoint was
# saved, which may not exist (or may be a different GPU) under the device
# mask set above. The model is moved to the GPU explicitly afterwards.
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints from a trusted source.
save_dict = torch.load(save_model, map_location='cpu')
best_dev_wer = save_dict['best_dev_wer']
# The checkpoint stores the module object itself under 'net'.
net = save_dict['net'].cuda()
# The blank label is the index one past the last vocabulary entry.
criterion = nn.CTCLoss(blank=VocabSize)

test_loss, test_wer = val(net, test_loader, criterion)
print(test_loss, test_wer)

