In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchtext.data import Field
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torchsummaryX import summary

import warnings
warnings.filterwarnings("ignore")
import os
import time
import datetime
import copy

import utils
import models
from dataset import PhoenixDataset, ToTensorVideo, RandomResizedCropVideo

## Dataset statistic

In [2]:
# root = '/mnt/data/public/datasets'
# print('video length: maximum / minimum / average / std')
# print(utils.DatasetStatistic(root, 'train'))
# print(utils.DatasetStatistic(root, 'dev'))
# print(utils.DatasetStatistic(root, 'test'))

## Building vocab

In [3]:
TRG = Field(sequential=True, use_vocab=True,
            init_token='<sos>', eos_token= '<eos>',
            lower=True, tokenize='spacy',
            tokenizer_language='de')

root = '/mnt/data/public/datasets'
csv_file = utils.get_csv(root)
tgt_sents = [csv_file.iloc[i, 0].lower().split('|')[3].split()
             for i in range(len(csv_file))]

# hyper
TRG.build_vocab(tgt_sents, min_freq=1)

## Process batch

In [4]:
def collate_fn(batch):
    '''
    process the batch:
        pad the video to fixed frame length
        convert sentence to index
        video: [C, T, H, W]
    '''
    videos = [item['video'].permute(1,0,2,3) for item in batch]
    videos= pad_sequence([v for v in videos], batch_first=True)
    videos = videos.permute(0, 2, 1, 3 , 4)
    
    annotations = [item['annotation'].split() for item in batch]
    annotations = TRG.process(annotations)

    return {'videos': videos, 'annotations': annotations}

## Loading dataset

In [5]:
BSZ = 16
root = '/mnt/data/public/datasets'
transform = transforms.Compose([ToTensorVideo(),
                                RandomResizedCropVideo(112)])

train_loader = DataLoader(PhoenixDataset(root, 'train', transform=transform),
                          batch_size=BSZ, shuffle=True, num_workers=10, collate_fn=collate_fn)

dev_loader = DataLoader(PhoenixDataset(root, 'dev', transform=transform),
                        batch_size=BSZ, shuffle=True, num_workers=10, collate_fn=collate_fn)

test_loader = DataLoader(PhoenixDataset(root, 'test', transform=transform),
                         batch_size=BSZ, shuffle=True, num_workers=10, collate_fn=collate_fn)


In [6]:
# videos: [N, C, T, H, W]
# annotations: [L, N]

# batch = next(iter(train_loader))
# print(batch['videos'].shape)
# print(batch['annotations'].shape)

# print(utils.itos(batch['annotations'].squeeze(1), TRG))

print(len(train_loader))
print(len(dev_loader))
print(len(test_loader))

355
34
40


## Define model

In [7]:
D_MODEL = 512
DROPOUT = 0.1
NHEAD = 8
NLAYER = 6
NHID = 1024
ACTIVATION = 'relu'
CLIP_SIZE = 10
NEPOCH = 50
LR=1e-4
SEGMENT = 'OVERLAP'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

path = f'bsz:{BSZ}-lr:{LR}-epoch:{NEPOCH}-dmodel:{D_MODEL}-dropout:{DROPOUT}\
-nhead:{NHEAD}-nlayer:{NLAYER}-nhid:{NHID}-activation:{ACTIVATION}-clip_size:{CLIP_SIZE}-segment:{SEGMENT}'
writer = SummaryWriter(os.path.join('./log', path))

res3d = torchvision.models.video.r3d_18(pretrained=True)

encoder = models.Res3D(res3d)

decoder = models.Transformer(
    device, len(TRG.vocab), D_MODEL, DROPOUT,
    NHEAD, NLAYER, NHID, ACTIVATION)

model = models.Seq2Seq(CLIP_SIZE, encoder, decoder, device).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

optimizer = optim.Adam(model.parameters(), lr = LR)

In [None]:
# summary(model, torch.zeros(1,3,100,112,112).to(device), 
#         torch.zeros(10,1, dtype=torch.long).to(device),
#         torch.zeros(1,100,dtype=torch.bool).to(device))

## Train and evaluate

In [None]:
best_val_bleu = 0.0
best_val_model = copy.deepcopy(model.state_dict())
for n_epoch in range(NEPOCH):
    %time models.train(model, train_loader, device, criterion, optimizer, TRG, writer, n_epoch)
    %time val_loss, val_bleu, val_wer = models.evaluate(model, dev_loader, device, criterion, TRG)
    print(n_epoch, val_loss, val_bleu, val_wer)
    
    if val_bleu > best_val_bleu:
        best_val_bleu = val_bleu
        best_val_model = copy.deepcopy(model.state_dict())
        
model.load_state_dict(best_val_model)
test_loss, test_bleu, test_wer = models.evaluate(model, test_loader, device, criterion, TRG, test_mode = True)
print(test_loss, test_bleu, test_wer)

CPU times: user 19min 28s, sys: 16min 51s, total: 36min 19s
Wall time: 1h 23min 20s
CPU times: user 1min 50s, sys: 1min 29s, total: 3min 20s
Wall time: 3min 43s
0 4.601124006159165 0.04133609704235021 0.8911424780576237
CPU times: user 20min 15s, sys: 16min 38s, total: 36min 54s
Wall time: 48min 3s
CPU times: user 1min 51s, sys: 1min 35s, total: 3min 27s
Wall time: 2min 39s
1 4.245331091039321 0.0513331392670379 0.8023410824316906
CPU times: user 20min 33s, sys: 17min 4s, total: 37min 37s
Wall time: 45min 53s
CPU times: user 1min 50s, sys: 1min 32s, total: 3min 22s
Wall time: 3min 46s
2 4.0913653513964485 0.047211908132714385 0.7929021234537266
CPU times: user 20min 35s, sys: 17min 5s, total: 37min 40s
Wall time: 48min 39s
CPU times: user 1min 51s, sys: 1min 33s, total: 3min 25s
Wall time: 3min 40s
3 4.081003504640916 0.0458398761556429 0.7848909117338055
CPU times: user 20min 28s, sys: 17min 7s, total: 37min 36s
Wall time: 42min 24s
CPU times: user 1min 51s, sys: 1min 34s, total: 3min

## Save model

In [None]:
# if not os.path.exists('./save'):
#     os.mkdir("save")
# dir_name = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
# torch.save(model.state_dict(), './save/'+dir_name+'.pth')


# # change input shape from [N, C, T, H, W] to [N, T, C, H, W]
# videos = batch['videos'].permute(0, 2, 1, 3, 4)
# texts = batch['annotations'].permute(1, 0)
# texts = [' '.join([TRG.vocab.itos[i] for i in sent]) for sent in texts]
# writer.add_video('input', videos, global_step=0, fps=32)
# writer.add_text('annotations', str(texts), 0)

## Load and test

In [None]:
# %%time
# model.load_state_dict(torch.load('./save/2020-03-01 18:17:57.pth'))
# test_loss, test_bleu, test_wer = models.evaluate(model, test_loader, device, criterion, TRG)