In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchtext.data import Field
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torchsummaryX import summary

import warnings
warnings.filterwarnings("ignore")
import os
import time
import datetime
import copy

import utils
from dataset import PhoenixDataset, ToTensorVideo, RandomResizedCropVideo
from models import cnn, transformer, seq2seq

## Dataset statistics

In [2]:
# root = '/mnt/data/public/datasets'
# print('video length: maximum / minimum / average / std')
# print(utils.DatasetStatistic(root, 'train'))
# print(utils.DatasetStatistic(root, 'dev'))
# print(utils.DatasetStatistic(root, 'test'))

video length: maximum / minimum / average / std
[299, 16, 140.8684767277856, 42.786819204979956]
[251, 32, 139.23333333333332, 44.30025081443922]
[251, 39, 142.24483306836248, 43.06641256741527]


## Building vocab

In [3]:
# Target-side text field: adds <sos>/<eos> markers and lower-cases tokens.
TRG = Field(
    sequential=True,
    use_vocab=True,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    tokenize='spacy',
    tokenizer_language='de',
)

root = '/mnt/data/public/datasets'
csv_file = utils.get_csv(root)

# Every entry of the first CSV column is a '|'-separated record; field 3 is
# lower-cased and whitespace-split into the token list used for the vocabulary
# (presumably the target sentence -- confirm against utils.get_csv).
tgt_sents = [
    record.lower().split('|')[3].split()
    for record in csv_file.iloc[:, 0]
]

# min_freq=1 keeps every token that occurs at least once.
TRG.build_vocab(tgt_sents, min_freq=1)
VOCAB_SIZE = len(TRG.vocab)
print(VOCAB_SIZE)

1235


## Process batch

In [4]:
def collate_fn(batch):
    """Collate a list of dataset samples into one padded batch.

    Each sample is a dict with a 'video' tensor laid out as [C, T, H, W] and
    an 'annotation' string. Videos are zero-padded along the time axis to the
    longest clip in the batch; annotations are whitespace-tokenized and handed
    to the module-level ``TRG`` field for numericalization.

    Returns:
        dict with
            'videos':      padded video batch, [N, C, T_max, H, W]
            'annotations': result of ``TRG.process`` on the token lists
    """
    time_first_clips = []
    token_lists = []
    for sample in batch:
        # pad_sequence pads along the leading dimension, so move time to the
        # front: [C, T, H, W] -> [T, C, H, W].
        time_first_clips.append(sample['video'].permute(1, 0, 2, 3))
        token_lists.append(sample['annotation'].split())

    # [N, T_max, C, H, W] after padding; swap back to channels-before-time.
    padded = pad_sequence(time_first_clips, batch_first=True)

    return {
        'videos': padded.permute(0, 2, 1, 3, 4),
        'annotations': TRG.process(token_lists),
    }

## Loading dataset

In [5]:
BSZ = 4
SIZE = 224
root = '/mnt/data/public/datasets'

# Frame pipeline shared by all three splits: tensor conversion followed by a
# random resized crop to SIZE.
# NOTE(review): the same random crop is also applied to dev/test -- confirm
# that a random transform at evaluation time is intended.
transform = transforms.Compose([
    ToTensorVideo(),
    RandomResizedCropVideo(SIZE),
])

# One loader per split, built in the order train -> dev -> test with identical
# settings. Shuffling is off for every split; the original note ties this to
# training on a smaller (fixed) subset of the data.
train_loader, dev_loader, test_loader = (
    DataLoader(
        PhoenixDataset(root, split, transform=transform),
        batch_size=BSZ,
        shuffle=False,
        num_workers=10,
        collate_fn=collate_fn,
    )
    for split in ('train', 'dev', 'test')
)


In [6]:
# Batch layout produced by the loaders:
#   videos:      [N, C, T, H, W]
#   annotations: [L, N]

# Optional sanity check of one batch (left disabled):
# batch = next(iter(train_loader))
# print(batch['videos'].shape)
# print(batch['annotations'].shape)
# print(utils.itos(batch['annotations'].squeeze(1), TRG))

# Number of batches per split, one per line: train, dev, test.
print(*(len(loader) for loader in (train_loader, dev_loader, test_loader)), sep='\n')

1418
135
158


## Define model

In [7]:
# Hyperparameters. D_MODEL .. ACTIVATION are passed positionally to
# transformer.Transformer below; CLIP_SIZE and SEGMENT are only used in the
# run name (and in the commented-out model constructor); NEPOCH drives the
# training loop and LR the optimizer.
D_MODEL = 512
DROPOUT = 0.1
NHEAD = 8
NLAYER = 6
NHID = 1024
ACTIVATION = 'relu'
CLIP_SIZE = 10
NEPOCH = 34
LR=1e-4
SEGMENT = 'OVERLAP'

# Second GPU when CUDA is available, otherwise CPU.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Run name encoding the hyperparameters; it becomes the TensorBoard log
# sub-directory under ./log. The trailing backslash continues the string
# literal, so the name contains no newline.
path = f'bsz:{BSZ}-lr:{LR}-epoch:{NEPOCH}-size:{SIZE}-dmodel:{D_MODEL}-dropout:{DROPOUT}\
-nhead:{NHEAD}-nlayer:{NLAYER}-nhid:{NHID}-activation:{ACTIVATION}-clip_size:{CLIP_SIZE}-segment:{SEGMENT}'
writer = SummaryWriter(os.path.join('./log', path))

# The 3D-ResNet backbone is read from a local file instead of being downloaded
# through torchvision (alternative kept below for reference).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# res3d = torchvision.models.video.r3d_18(pretrained=True)
res3d = torch.load('./save/res3d18.pth', map_location = device)
CNN = cnn.Res3D(res3d)

# Transformer sized to the target vocabulary built earlier in the notebook.
Transformer = transformer.Transformer(
    len(TRG.vocab), D_MODEL, DROPOUT,
    NHEAD, NLAYER, NHID, ACTIVATION)

# The full model is resumed from a previously saved checkpoint rather than
# assembled from CNN + Transformer; un-comment the constructor line below to
# start from scratch instead.
# NOTE(review): with the checkpoint path active, res3d / CNN / Transformer
# above are built but not used by `model` -- confirm this is intended.
# model = seq2seq.Res3d_transformer(CLIP_SIZE, CNN, Transformer).to(device)
model = torch.load('./save/res3d18_transformer.pth', map_location=device)

# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

optimizer = optim.Adam(model.parameters(), lr = LR)

In [8]:
# for batch_idx, batch in enumerate(test_loader):
#     print(batch['videos'].shape)
#     print(batch['annotations'].shape)
#     video = batch['videos'].permute(0,2,1,3,4)
#     annotation = batch['annotations'].permute(1,0)[0]
#     annotation = ' '.join([TRG.vocab.itos[i] for i in annotation])
    
#     writer.add_video('test', video, global_step=0, fps=4)
    
#     if batch_idx == 0:
#         break

In [None]:
# summary(model, torch.zeros(1,3,100,112,112).to(device), 
#         torch.zeros(10,1, dtype=torch.long).to(device),
#         torch.zeros(1,100,dtype=torch.bool).to(device))

## Define train

In [None]:
def train(model, train_loader, device, criterion, optimizer, TRG, writer, n_epoch,
          max_batches=91, log_interval=9):
    """Train ``model`` for one (optionally truncated) pass over ``train_loader``.

    The ground-truth tokens ``targets[:-1]`` are fed as decoder input and the
    loss is computed against ``targets[1:]``. Loss, BLEU and WER are averaged
    over windows of ``log_interval`` batches and written to ``writer``.

    Args:
        model: called as ``model(videos, decoder_inputs, src_padding_mask, device)``.
        train_loader: iterable of dicts with 'videos' and 'annotations'
            tensors; must support ``len()``.
        device: device the batch tensors are moved to.
        criterion: loss applied to the flattened outputs and shifted targets.
        optimizer: optimizer stepped once per batch.
        TRG: forwarded to ``utils.bleu_count`` / ``utils.wer_count``.
        writer: object exposing ``add_scalar(tag, value, step)``.
        n_epoch: zero-based epoch index, used for the global logging step.
        max_batches: stop after this many batches; ``None`` uses the whole
            loader. The default (91) reproduces the former hard-coded
            ``batch_idx == 90`` stop.
        log_interval: number of batches per logged average. The default (9)
            reproduces the former hard-coded window.

    Returns:
        None

    Raises:
        ValueError: if ``log_interval`` or ``max_batches`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError('log_interval must be a positive integer')
    if max_batches is not None and max_batches <= 0:
        raise ValueError('max_batches must be a positive integer or None')

    model.train()
    running_loss = 0.0
    running_bleu = 0.0
    running_wer = 0.0
    for batch_idx, batch in enumerate(train_loader):
        inputs = batch['videos'].to(device)
        targets = batch['annotations'].to(device)
        # TODO: possible improvement -- have collate_fn return the padding mask.
        src_padding_mask = transformer.get_padding_mask(inputs).to(device)
        optimizer.zero_grad()
        outputs = model(inputs, targets[:-1, :], src_padding_mask, device)
        loss = criterion(outputs.view(-1, outputs.size(-1)),
                         targets[1:, :].view(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_bleu += utils.bleu_count(outputs, targets[1:, :], TRG)
        running_wer += utils.wer_count(outputs, targets[1:, :], TRG)

        # Flush the running averages at the end of every full window.
        if batch_idx % log_interval == log_interval - 1:
            # The step is based on the full loader length even when the epoch
            # is truncated, so steps stay monotonic across epochs (with gaps).
            global_step = n_epoch * len(train_loader) + batch_idx
            writer.add_scalar('train loss', running_loss / log_interval, global_step)
            writer.add_scalar('train bleu', running_bleu / log_interval, global_step)
            writer.add_scalar('train wer', running_wer / log_interval, global_step)

            running_loss = 0.0
            running_bleu = 0.0
            running_wer = 0.0

        # Checked after the batch has been trained on, so the next batch is
        # never fetched from the loader once the cap is reached.
        if max_batches is not None and batch_idx + 1 >= max_batches:
            break

## Define evaluate

In [None]:
def evaluate(model, dev_loader, device, criterion, TRG, writer, n_epoch, max_batches=10):
    """Evaluate ``model`` on (a prefix of) ``dev_loader`` without gradients.

    The ground-truth tokens ``targets[:-1]`` are fed as decoder input and the
    loss is computed against ``targets[1:]``. Metrics are averaged over the
    batches that were actually processed (previously the sums were always
    divided by a hard-coded 10, which under-reported every metric whenever
    fewer than 10 batches were available).

    Args:
        model: called as ``model(videos, decoder_inputs, src_padding_mask, device)``.
        dev_loader: iterable of dicts with 'videos' and 'annotations' tensors.
        device: device the batch tensors are moved to.
        criterion: loss applied to the flattened outputs and shifted targets.
        TRG: forwarded to ``utils.bleu_count`` / ``utils.wer_count``.
        writer: object exposing ``add_scalar(tag, value, step)``; any falsy
            value (e.g. ``None``) disables logging.
        n_epoch: step under which the validation scalars are logged.
        max_batches: evaluate at most this many batches; ``None`` uses the
            whole loader. The default (10) matches the former hard-coded cap.

    Returns:
        tuple ``(loss, bleu, wer)`` of per-batch averages; all zeros when the
        loader yields no batch.

    Raises:
        ValueError: if ``max_batches`` is not positive.
    """
    if max_batches is not None and max_batches <= 0:
        raise ValueError('max_batches must be a positive integer or None')

    model.eval()
    epoch_loss = 0.0
    epoch_bleu = 0.0
    epoch_wer = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in dev_loader:
            inputs = batch['videos'].to(device)
            targets = batch['annotations'].to(device)
            # validation and test share the same src_padding_mask
            src_padding_mask = transformer.get_padding_mask(inputs).to(device)
            outputs = model(inputs, targets[:-1, :], src_padding_mask, device)
            loss = criterion(outputs.view(-1, outputs.size(-1)),
                             targets[1:, :].view(-1))

            epoch_loss += loss.item()
            epoch_bleu += utils.bleu_count(outputs, targets[1:, :], TRG)
            epoch_wer += utils.wer_count(outputs, targets[1:, :], TRG)
            n_batches += 1

            # Stop before fetching another batch once the cap is reached.
            if max_batches is not None and n_batches >= max_batches:
                break

    # Average over what was really seen; max(..., 1) keeps an empty loader
    # returning zeros instead of dividing by zero.
    denom = max(n_batches, 1)
    epoch_loss /= denom
    epoch_bleu /= denom
    epoch_wer /= denom

    # A falsy writer lets callers run a one-off evaluation without logging.
    if writer:
        writer.add_scalar('val loss', epoch_loss, n_epoch)
        writer.add_scalar('val bleu', epoch_bleu, n_epoch)
        writer.add_scalar('val wer', epoch_wer, n_epoch)
    return epoch_loss, epoch_bleu, epoch_wer

## Define test

In [None]:
def test(model, test_loader, device, criterion, TRG):
    model.eval()
    epoch_loss = 0.0
    epoch_bleu = 0.0
    epoch_wer = 0.0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            inputs = batch['videos'].to(device)
            targets = batch['annotations'].to(device)
            src_padding_mask = transformer.get_padding_mask(inputs).to(device)
            dec_inputs = transformer.greedy_decoder(model, inputs, targets[:-1, :], src_padding_mask, device).to(device)
            outputs = model(inputs, dec_inputs, src_padding_mask, device)
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets[1:,:].view(-1))
            
            epoch_loss += loss.item()
            epoch_bleu += utils.bleu_count(outputs, targets[1:,:], TRG)
            epoch_wer += utils.wer_count(outputs, targets[1:,:], TRG)
            
            if batch_idx == 9:
                break
          
    # ? len(test_loader) -> 10
    epoch_loss /= 10
    epoch_bleu /= 10
    epoch_wer /= 10
                
    return epoch_loss, epoch_bleu, epoch_wer

## Train and evaluate

In [None]:
# Checkpoint threshold: a model is only saved once its validation WER drops
# below this value (lower WER is treated as better below).
# NOTE(review): `model` was itself loaded from the same checkpoint file that
# is overwritten here -- confirm 0.9 is the right starting threshold when
# resuming, otherwise a worse model could replace the loaded one.
best_val_wer = 0.9
for n_epoch in range(NEPOCH):
    # %time is an IPython line magic: it runs the statement in the notebook
    # namespace and prints CPU / wall time, so this cell only runs in IPython.
    %time train(model, train_loader, device, criterion, optimizer, TRG, writer, n_epoch)
    %time val_loss, val_bleu, val_wer = evaluate(model, dev_loader, device, criterion, TRG, writer, n_epoch)
    print(f'epoch:{n_epoch} | val loss:{val_loss} | val bleu:{val_bleu} | val wer:{val_wer}')
    
    # Keep only the best model so far, selected on validation WER.
    if val_wer < best_val_wer:
        best_val_wer = val_wer
        # Saves the whole model object (not just a state_dict); the loading
        # cells read it back with torch.load from the same path.
        torch.save(model, './save/res3d18_transformer.pth')
        print(f'best model params saved in epoch {n_epoch} with best val wer: {best_val_wer}')
        

CPU times: user 4min 24s, sys: 3min 52s, total: 8min 17s
Wall time: 5min 12s
CPU times: user 27.2 s, sys: 26 s, total: 53.1 s
Wall time: 42.3 s
epoch:0 | val loss:4.6117489576339725 | val bleu:0.009405815601348877 | val wer:0.8757083799419846
best model params saved in epoch 0 with best val wer: 0.8757083799419846
CPU times: user 5min 4s, sys: 4min 10s, total: 9min 14s
Wall time: 5min 42s
CPU times: user 36.1 s, sys: 28.9 s, total: 1min 4s
Wall time: 45.5 s
epoch:1 | val loss:4.616577982902527 | val bleu:0.010252672433853149 | val wer:0.8670510711452775
best model params saved in epoch 1 with best val wer: 0.8670510711452775
CPU times: user 5min 31s, sys: 4min 24s, total: 9min 56s
Wall time: 6min 20s
CPU times: user 37.1 s, sys: 31.3 s, total: 1min 8s
Wall time: 46.8 s
epoch:2 | val loss:4.672018146514892 | val bleu:0.020611317455768587 | val wer:0.8794950222961253
CPU times: user 5min 32s, sys: 4min 11s, total: 9min 44s
Wall time: 6min 20s
CPU times: user 30.2 s, sys: 28.7 s, total: 5

KeyboardInterrupt: 

## Load and test

In [None]:
%%time
# Rebuild everything the test run needs so this cell does not depend on the
# training cells having been executed: device, checkpoint, and loss.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model_ft = torch.load('./save/res3d18_transformer.pth', map_location=device)
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

# Greedy-decoding metrics of the saved checkpoint on the test split.
test_loss, test_bleu, test_wer = test(
    model_ft, test_loader, device, criterion, TRG)
print(f'test loss {test_loss} | test bleu {test_bleu} | test wer {test_wer}')

# Optional: re-score the same checkpoint on the dev split without logging.
# val_loss, val_bleu, val_wer = evaluate(model_ft, dev_loader, device, criterion, TRG, writer=None, n_epoch=0)
# print(val_loss, val_bleu, val_wer)