In [20]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchtext.data import Field
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torchsummaryX import summary

import warnings
warnings.filterwarnings("ignore")
import os
import time
import copy

import utils
import models
from dataset import PhoenixDataset, ToTensorVideo, RandomResizedCropVideo

## Dataset statistics

In [21]:
# root = '/mnt/data/public/datasets'
# print('video length: maximum / minimum / average / std')
# print(utils.DatasetStatistic(root, 'train'))
# print(utils.DatasetStatistic(root, 'dev'))
# print(utils.DatasetStatistic(root, 'test'))

## Building vocab

In [22]:
# Load the annotation table first; the vocabulary is built from its contents.
root = '/mnt/data/public/datasets'
csv_file = utils.get_csv(root)

# Target-side text field: lower-cases tokens and wraps every sequence in
# <sos> ... <eos>. A German spaCy tokenizer is configured on the field.
TRG = Field(
    sequential=True,
    use_vocab=True,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    tokenize='spacy',
    tokenizer_language='de',
)

# Each entry of column 0 is a '|'-separated record. The 4th field is
# lower-cased and whitespace-split into a token list
# (presumably the target sentence, going by the variable name -- TODO confirm).
tgt_sents = [
    record.lower().split('|')[3].split()
    for record in (csv_file.iloc[row, 0] for row in range(len(csv_file)))
]

# Hyper-parameter: min_freq=1 keeps every token that occurs at least once.
TRG.build_vocab(tgt_sents, min_freq=1)

## Process batch

In [23]:
def collate_fn(batch):
    '''
    Collate a list of dataset samples into a single batch.

    Each sample is a dict with:
        'video':      tensor laid out as [C, T, H, W]
        'annotation': whitespace-separated sentence string

    Returns a dict with:
        'videos':      videos zero-padded along T to the longest clip in the
                       batch, laid out as [N, C, T, H, W]
        'annotations': output of TRG.process on the tokenized sentences
                       (presumably an index tensor of shape [L, N] -- TODO confirm)
    '''
    # pad_sequence pads along the leading dimension, so put time first: [T, C, H, W].
    time_major = [sample['video'].permute(1, 0, 2, 3) for sample in batch]
    padded = pad_sequence(time_major, batch_first=True)  # [N, T_max, C, H, W]
    # Swap back to channel-before-time: [N, C, T_max, H, W].
    videos = padded.permute(0, 2, 1, 3, 4)

    # Sentences are split on whitespace, then padded/numericalized by the field.
    tokenized = [sample['annotation'].split() for sample in batch]
    annotations = TRG.process(tokenized)

    return {'videos': videos, 'annotations': annotations}

## Loading dataset

In [24]:
BSZ = 1
root = '/mnt/data/public/datasets'
# Per-sample video transform shared by all splits.
# NOTE(review): RandomResizedCropVideo is a random augmentation and is also
# applied to dev/test here; consider a deterministic crop for evaluation.
transform = transforms.Compose([ToTensorVideo(),
                                RandomResizedCropVideo(112)])


def _make_loader(split, shuffle):
    """Build a DataLoader over one PhoenixDataset split.

    Args:
        split: split name passed to PhoenixDataset ('train', 'dev' or 'test').
        shuffle: whether the loader reshuffles samples every epoch.

    Returns:
        A DataLoader using the notebook-level BSZ, transform and collate_fn.
    """
    return DataLoader(PhoenixDataset(root, split, transform=transform),
                      batch_size=BSZ, shuffle=shuffle, num_workers=4,
                      collate_fn=collate_fn)


# Only the training split is shuffled; evaluation splits are iterated in
# dataset order so that repeated evaluations visit samples deterministically.
train_loader = _make_loader('train', shuffle=True)
dev_loader = _make_loader('dev', shuffle=False)
test_loader = _make_loader('test', shuffle=False)


In [None]:
# videos: [N, C, T, H, W]
# annotations: [L, N]

# batch = next(iter(train_loader))
# print(batch['videos'].shape)
# print(batch['annotations'].shape)

# print(utils.itos(batch['annotations'].squeeze(1), TRG))

# print(len(train_loader))
# print(len(dev_loader))
# print(len(test_loader))

## Define model

In [25]:
# Hyper-parameters. D_MODEL .. ACTIVATION are passed positionally to
# models.Transformer below; NCLIP goes to models.Seq2Seq; NEPOCH bounds the
# training loop.
D_MODEL = 512
DROPOUT = 0.5
NHEAD = 1
# The original cell contained a bare `NLAYER` expression with no assignment,
# which raises NameError on a fresh kernel (Restart & Run All).
# TODO(review): 1 is a placeholder consistent with the other small settings;
# confirm the intended depth -- it must match any checkpoint loaded later.
NLAYER = 1
NHID = 64
ACTIVATION = 'relu'
NCLIP = 10
NEPOCH = 1

# Use the third GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

# TensorBoard writer; each run logs into its own timestamped directory under ./log.
writer = SummaryWriter(os.path.join('./log', time.strftime(
    "%Y-%m-%d %H:%M:%S", time.localtime(time.time()))))

# Pretrained torchvision R3D-18 video network, wrapped by the project's encoder.
res3d = torchvision.models.video.r3d_18(pretrained=True)

encoder = models.Res3D(res3d)

# NLAYER is passed twice (presumably encoder and decoder layer counts --
# TODO confirm against models.Transformer).
decoder = models.Transformer(
    device, len(TRG.vocab), D_MODEL, DROPOUT,
    NHEAD, NLAYER, NLAYER, NHID, ACTIVATION)

model = models.Seq2Seq(NCLIP, encoder, decoder, device).to(device)

# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

# Adam with its default settings.
optimizer = optim.Adam(model.parameters())

## Train and evaluate

In [26]:
# Track the weights with the best dev-set BLEU.
# Start from -inf rather than 0.0: with a strict `>` comparison, a run whose
# dev BLEU never rises above 0.0 would otherwise keep the *untrained* initial
# snapshot and silently discard all training before the test evaluation.
best_val_bleu = float('-inf')
# Initial snapshot so best_val_model is defined even when NEPOCH == 0.
best_val_model = copy.deepcopy(model.state_dict())
for n_epoch in range(NEPOCH):
    models.train(model, train_loader, device, criterion, optimizer, TRG, writer, n_epoch)
    val_loss, val_bleu, val_wer = models.evaluate(model, dev_loader, device, criterion, TRG)
    print(val_loss, val_bleu, val_wer)

    if val_bleu > best_val_bleu:
        best_val_bleu = val_bleu
        # deepcopy: state_dict() returns references to the live parameters,
        # which later epochs would overwrite.
        best_val_model = copy.deepcopy(model.state_dict())

# Restore the best dev-set weights before the final test evaluation.
model.load_state_dict(best_val_model)
test_loss, test_bleu, test_wer = models.evaluate(model, test_loader, device, criterion, TRG)
print(test_loss, test_bleu, test_wer)

## Save model

In [27]:
# Create the checkpoint directory if it does not exist yet. exist_ok=True
# avoids the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs('./save', exist_ok=True)
# Checkpoints are named after the local wall-clock time of the save.
dir_name = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
torch.save(model.state_dict(), './save/'+dir_name+'.pth')

# # change input shape from [N, C, T, H, W] to [N, T, C, H, W]
# videos = batch['videos'].permute(0, 2, 1, 3, 4)
# texts = batch['annotations'].permute(1, 0)
# texts = [' '.join([TRG.vocab.itos[i] for i in sent]) for sent in texts]
# writer.add_video('input', videos, global_step=0, fps=32)
# writer.add_text('annotations', str(texts), 0)

## Load and test

In [1]:
%%time
model.load_state_dict(torch.load('./save/2020-03-01 18:17:57.pth'))
test_loss, test_bleu, test_wer = models.evaluate(model, test_loader, device, criterion, TRG)