## Dependencies and imports

In [None]:
import nltk

# Stop-word list and tokenizer data used by the question pre-processing cells.
# 'punkt_tab' is the tokenizer resource that newer NLTK releases look up for
# word_tokenize; NOTE(review): on older releases this extra download should only
# report an unknown package and continue - confirm for the pinned NLTK version.
for resource in ('stopwords', 'punkt', 'punkt_tab'):
  nltk.download(resource)

In [None]:
!pip install -q transformers datasets

In [None]:
import torch
import pickle
import json
import os
import numpy as np
import pandas as pd
import h5py
import torch
from transformers import BertTokenizer, BertModel, VisualBertForQuestionAnswering, VisualBertModel
from torch.nn.utils.rnn import pad_sequence
from nltk.corpus import stopwords
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

### Download data from TVQA

In [None]:
# Root directory that the download cells write into and that the TVQAPlus
# dataset class reads its JSON files from.
DOWNLOAD_PATH = "."

In [None]:
# Download and extract QA data
!wget https://tvqa.cs.unc.edu/files/tvqa_qa_release.tar.gz -P ${DOWNLOAD_PATH}
!tar xzf ${BASE_PATH}/tvqa_qa_release.tar.gz -C ${DOWNLOAD_PATH}

In [None]:
# Download and extract subtitles

# TVQA
!wget https://tvqa.cs.unc.edu/files/tvqa_subtitles.tar.gz -P ${DOWNLOAD_PATH}
!tar xzf ${DOWNLOAD_PATH}/tvqa_subtitles.tar.gz  -C ${DOWNLOAD_PATH}

# TVQA+
wget https://tvqa.cs.unc.edu/files/tvqa_plus_subtitles.tar.gz  -P ${DOWNLOAD_PATH}/tvqa_plus/
tar -xf ${DOWNLOAD_PATH}/tvqa_plus/tvqa_plus_subtitles.tar.gz -C ${DOWNLOAD_PATH}/tvqa_plus/

In [None]:
# Download and extract resnet features
# The archive is fetched from docs.google.com in two steps:
#   1. the inner wget requests the download page, stores the session cookies in
#      /tmp/cookies.txt and writes the page to stdout (-O-); sed prints the value
#      of the `confirm=` token found in that page;
#   2. the outer wget reuses the cookies and the token to save the file as
#      tvqa_imagenet_resnet101_pool5_hq.tar.gz in the current directory.
# The cookie file is removed only when the download succeeds (&&).
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1m8bC4lefQsP2tRhMLAaiy0AVuBXZtegc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1m8bC4lefQsP2tRhMLAaiy0AVuBXZtegc" -O tvqa_imagenet_resnet101_pool5_hq.tar.gz && rm -rf /tmp/cookies.txt
# No -C is given, so the archive is unpacked into the current working directory.
!tar xzf tvqa_imagenet_resnet101_pool5_hq.tar.gz

In [None]:
# Download and extract visual concept features
!wget http://tvqa.cs.unc.edu/files/det_visual_concepts_hq.pickle.tar.gz
!tar -xvf det_visual_concepts_hq.pickle.tar.gz

In [None]:
# Download and extract glove embeddings 
!wget http://nlp.stanford.edu/data/wordvecs/glove.6B.zip -P ${DOWNLOAD_PATH}
!unzip -qqq ${DOWNLOAD_PATH}/glove.6B -C ${DOWNLOAD_PATH}
!unzip ../glove.6B.zip

## Model and training


In [None]:
# Interrogative words are removed from the stop-word list (when present) so that
# they survive question filtering.
question_types = ["what", "who", "whom", "how", "where", "why"]

# NLTK English stop words without the question words, plus the "?" token.
english_stopwords = [
  word for word in stopwords.words('english') if word not in question_types
] + ["?"]

In [None]:
# Visual-concept detections, indexed as vcpt[clip_name][frame_number] by the
# helper functions below.
# NOTE(review): pickle.load can execute arbitrary code - only load the file
# obtained from the trusted TVQA download.
# NOTE(review): assumes the archive unpacks the pickle directly into
# DOWNLOAD_PATH - confirm against the extracted layout.
with open(os.path.join(DOWNLOAD_PATH, 'det_visual_concepts_hq.pickle'), 'rb') as f:
  vcpt = pickle.load(f)

# Create embedding dict: word -> float32 GloVe vector.
embeddingDict = {}
with open(os.path.join(DOWNLOAD_PATH, 'glove.6B.50d.txt'), encoding='utf-8') as f:
  # Stream the file line by line instead of materialising it with readlines().
  for line in f:
    arr = line.split()
    if not arr:
      continue  # tolerate blank lines
    embeddingDict[arr[0]] = np.asarray(arr[1:], dtype='float32')

In [None]:
def extract_objects_from_frame(clip_name, frame_number):
  """Return the object labels stored in `vcpt` for one frame of a clip.

  The frame entry is a comma-separated string; it is split on commas and each
  piece is stripped of surrounding whitespace.
  """
  raw_labels = vcpt[clip_name][frame_number]
  return [label.strip() for label in raw_labels.split(",")]

def split_phrases(phrases):
  """Return the set of distinct whitespace-separated words over all phrases."""
  return {word for phrase in phrases for word in phrase.split()}

def get_list_embedding(word_list):
  """Stack the `embeddingDict` vectors of the words in `word_list`.

  Words missing from `embeddingDict` are skipped. The result is a numpy array
  with one row per retained word (an empty array when nothing is retained).
  """
  known_vectors = [embeddingDict[word] for word in word_list if word in embeddingDict]
  return np.array(known_vectors)

def get_question_tokens(question):
  """Tokenize a question and keep only the informative, embeddable tokens.

  Tokens are lower-cased; a token is kept when it is not in
  `english_stopwords` and has an entry in `embeddingDict`.
  """
  kept = []
  for token in nltk.word_tokenize(question):
    lowered = token.lower()
    if lowered in english_stopwords:
      continue
    if lowered not in embeddingDict:
      continue
    kept.append(lowered)
  return kept

def get_best_frames_start(match_scores, window_size):
  """Return the start index of the contiguous window with the largest score sum.

  Args:
    match_scores: 1-D sequence/array of per-frame scores.
    window_size: number of consecutive frames in a window.

  Returns:
    Index of the first frame of the best window. Ties resolve to the earliest
    window. When there are fewer scores than `window_size`, 0 is returned so the
    caller's slice covers whatever frames exist.
  """
  # +1 so that the window ending on the very last frame is also considered.
  num_windows = len(match_scores) - window_size + 1
  if num_windows <= 0:
    return 0

  # Start from -inf: score sums may be negative, so -1 is not a safe floor.
  best_sum = float("-inf")
  best_index = 0
  for start in range(num_windows):
    window_sum = np.sum(match_scores[start:start + window_size])
    if window_sum > best_sum:
      best_sum = window_sum
      best_index = start
  return best_index

def get_frame_scores(clip_name, question):
  """Score every frame of a clip by how well its objects match the question.

  For each frame, the words of its first NUM_OBJECTS detected object labels are
  embedded and the frame score is the sum of all pairwise dot products between
  those object embeddings and the question-token embeddings.

  Args:
    clip_name: key into `vcpt`.
    question: question text.

  Returns:
    numpy array with one score per frame of the clip; frames (or questions)
    without any embeddable word score 0.
  """
  NUM_OBJECTS = 10  # only the first NUM_OBJECTS labels of a frame are used
  total_frames = len(vcpt[clip_name])
  frame_scores = np.zeros(total_frames)

  question_embedding = get_list_embedding(get_question_tokens(question))
  if question_embedding.size == 0:
    # Nothing to match against; the dot product below would fail on an empty array.
    return frame_scores

  for frame_number in range(total_frames):
    objects_in_frame = split_phrases(extract_objects_from_frame(clip_name, frame_number)[:NUM_OBJECTS])
    frame_embedding = get_list_embedding(objects_in_frame)
    # Objects may exist in the frame without any of their words being embeddable.
    if frame_embedding.size:
      frame_scores[frame_number] = np.sum(frame_embedding.dot(question_embedding.T))
  return frame_scores

def get_unique_objects_from_frame(clip_name, frame_start_index, window_size):
  """Collect the distinct object words seen in a window of frames.

  Args:
    clip_name: key into `vcpt`.
    frame_start_index: first frame of the window (negative values are treated as 0).
    window_size: number of frames in the window; the window is cut off at the
      end of the clip.

  Returns:
    Sorted list of unique words from the object labels of the frames in the
    window. Sorting keeps the order stable between runs.
  """
  total_frames = len(vcpt[clip_name])
  first_frame = max(frame_start_index, 0)
  last_frame = min(first_frame + window_size, total_frames)

  unique_objects = set()
  for frame_number in range(first_frame, last_frame):
    unique_objects.update(split_phrases(extract_objects_from_frame(clip_name, frame_number)))
  return sorted(unique_objects)


In [None]:
class TVQAPlus(torch.utils.data.Dataset):
    """Question/answer samples read from the TVQA+ JSON files under DOWNLOAD_PATH.

    Each item is the tuple
    (question, subtitle text, a0, a1, a2, a3, a4, video name, answer index).
    """

    def __init__(self, isTraining=True):
        """Load the train split when `isTraining` is true, otherwise the val split."""
        split_file = "tvqa_plus_train.json" if isTraining else "tvqa_plus_val.json"
        with open(DOWNLOAD_PATH + "/tvqa_plus/" + split_file) as f:
            self.qa = json.load(f)

        with open(DOWNLOAD_PATH + "/tvqa_plus/tvqa_plus_subtitles.json") as f:
            self.subtitles = json.load(f)

    def __len__(self):
        return len(self.qa)

    def __getitem__(self, i):
        sample = self.qa[i]
        vid_name = sample['vid_name']
        subt_text = self.subtitles[vid_name]['sub_text']
        answer_idx = int(sample['answer_idx'])
        return (sample['q'], subt_text,
                sample['a0'], sample['a1'], sample['a2'], sample['a3'], sample['a4'],
                vid_name, answer_idx)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class VBERT_Wrapper(torch.nn.Module):
    """VisualBERT encoder with an MLP head that scores each answer option.

    For every option, the subtitle text, the detected-object words, the question
    and the option are encoded together with a window of video features; the
    first-token hidden state is projected to a single score. `forward` returns
    one score per option so the result can be used as classification logits.
    """

    def __init__(self):
        super(VBERT_Wrapper, self).__init__()
        # The encoder is moved to `device` because bert_forward sends its inputs there.
        self.model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre", output_hidden_states=True).to(device)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(768, 256), torch.nn.GELU(),
            torch.nn.Linear(256, 64), torch.nn.GELU(),
            torch.nn.Linear(64, 1),
        ).to(device)
        # Only parameters whose name contains one of these fragments stay trainable;
        # every other encoder parameter is frozen.
        trainable_layers = ['encoder.layer.11', 'pooler.dense.weight', 'pooler.dense.bias']
        for name, parameter in self.model.named_parameters():
            parameter.requires_grad = any(layer in name for layer in trainable_layers)

    def get_vid_feats_and_objects(self, clip_names, questions, window_size=10):
        """Pick the best-matching frame window per clip.

        Returns:
            (resnet_feats, objects_from_video): the zero-padded, batch-first
            tensor of per-clip window features and, per clip, the space-joined
            object words seen in that window.
        """
        resnet_feats = []
        objects_from_video = []
        for clip_name, question in zip(clip_names, questions):
            frame_scores = get_frame_scores(clip_name, question)
            # Guard against a negative start index, which would slice from the end.
            start = max(get_best_frames_start(frame_scores, window_size), 0)
            unique_objects = get_unique_objects_from_frame(clip_name, start, window_size)
            objects_from_video.append(' '.join(unique_objects))
            # NOTE(review): `vid_h5` is not defined anywhere in the visible notebook;
            # it must be opened (e.g. with h5py) on the extracted feature file
            # before this method is called - confirm the file name.
            resnet_feats.append(torch.tensor(vid_h5[clip_name][start:start + window_size, :], device=device))
        resnet_feats = pad_sequence(resnet_feats, batch_first=True)
        return resnet_feats, objects_from_video

    def bert_forward(self, questions, options, subtitles, clip_names):
        """Encode one answer option per sample and return the first-token hidden states."""
        qa_representation = [questions[i] + ' [SEP] ' + options[i] for i in range(len(questions))]
        resnet_feats, objects_from_video = self.get_vid_feats_and_objects(clip_names, questions)
        subtitle_representation = [subtitles[i] + ' [SEP] ' + objects_from_video[i] for i in range(len(subtitles))]
        # One token-type id and one attention-mask entry per visual position.
        token_ids = torch.ones(resnet_feats.shape[:-1], dtype=torch.long).to(device)
        attention_mask = torch.ones(resnet_feats.shape[:-1], dtype=torch.float).to(device)
        inputs = self.tokenizer(subtitle_representation, qa_representation, padding="max_length", truncation=True, return_token_type_ids=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt")
        inputs.update({"visual_embeds": resnet_feats, "visual_token_type_ids": token_ids, "visual_attention_mask": attention_mask,})
        inputs = inputs.to(device)
        output = self.model(**inputs)
        hidden_states = output.last_hidden_state
        cls_tokens = hidden_states[:, 0, :]
        return cls_tokens

    def forward(self, question, subt_text, options, video_names):
        """Score every answer option.

        Args:
            question: batch of question strings.
            subt_text: batch of subtitle strings.
            options: list with one batch of answer strings per option.
            video_names: batch of clip names.

        Returns:
            Tensor of shape (batch, len(options)) with one score per option.
        """
        scores = [
            self.proj(self.bert_forward(question, option_batch, subt_text, video_names))
            for option_batch in options
        ]
        # Each entry is (batch, 1); concatenating keeps the autograd graph, which
        # torch.tensor(list_of_tensors) would not.
        return torch.cat(scores, dim=1)


In [None]:
# Batch sizes are defined first because the DataLoader constructors below use them.
batch_size = 24
dev_batch_size = 24

train_dataset = TVQAPlus(isTraining=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = TVQAPlus(isTraining=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=dev_batch_size, shuffle=False)

# Move the whole wrapper (encoder and projection head) to the compute device.
vbert_mod = VBERT_Wrapper().to(device)

In [None]:
def val_acc(model, loader=None, batch_size=None):
  """Return the accuracy (in percent) of `model` over a validation loader.

  Args:
    model: module called as model(questions, subt_text, options, video_names)
      and returning logits of shape (batch, num_options).
    loader: iterable of batches shaped like TVQAPlus items; defaults to the
      notebook-level `val_loader`.
    batch_size: accepted so existing three-argument calls keep working; the
      accuracy is computed from the number of samples actually seen, so a
      smaller last batch does not skew the result.

  Returns:
    Accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.
    The model is switched back to train mode before returning.
  """
  if loader is None:
    loader = val_loader

  model.eval()
  num_correct = 0
  num_seen = 0
  for questions, subt_text, a0, a1, a2, a3, a4, video_names, answer_idx in loader:
    with torch.no_grad():
      logits = model(questions, subt_text, [a0, a1, a2, a3, a4], video_names)
    # Compare on whatever device the model produced its logits on.
    answer_idx = answer_idx.to(logits.device)
    num_correct += int((torch.argmax(logits, dim=1) == answer_idx).sum())
    num_seen += int(answer_idx.shape[0])
  model.train()
  return 100 * num_correct / num_seen if num_seen else 0.0

In [None]:
# Adam over the parameters of the wrapper, with a fixed starting learning rate.
optimizer = torch.optim.Adam(params=vbert_mod.parameters(), lr=1e-3)
# Cross-entropy over the per-option logits.
criterion = nn.CrossEntropyLoss()
# Each scheduler step multiplies the learning rate by gamma.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.8)

In [None]:
def get_fileName(epoch, isDev=False):
  """Return the checkpoint file name.

  The best-dev checkpoint always uses the same name; per-epoch training
  checkpoints carry the epoch number.
  """
  if not isDev:
    return "train_" + str(epoch) + ".pth"
  return "best_acc_model.pth"
  
torch.cuda.empty_cache()

NUM_EPOCHS = 10
best_dev_acc = 0

for epoch in range(1, NUM_EPOCHS + 1):

    correct_prediction_count = 0
    samples_seen = 0
    loss_epoch = 0
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    vbert_mod.train()

    # Appending an "s" to the var names to make it clear that we are dealing with a batch
    for batch_idx, (questions, subt_texts, a0s, a1s, a2s, a3s, a4s, video_names, answer_idx) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = vbert_mod(questions, subt_texts, [a0s, a1s, a2s, a3s, a4s], video_names)
        answer_idx = answer_idx.to(logits.device)
        loss = criterion(logits, answer_idx)
        loss.backward()
        optimizer.step()

        correct_prediction_count += int((torch.argmax(logits, dim=1) == answer_idx).sum())
        # Count real samples so a smaller last batch does not skew the accuracy.
        samples_seen += int(answer_idx.shape[0])
        loss_epoch += float(loss)
        batch_number = batch_idx + 1
        batch_bar.set_postfix(
            acc="{:.03f}%".format(100 * correct_prediction_count / samples_seen),
            loss="{:.03f}".format(float(loss_epoch / batch_number)),
            num_correct=correct_prediction_count,
            # Scientific notation: the decayed learning rate is too small for three decimals.
            lr="{:.2e}".format(float(optimizer.param_groups[0]['lr'])))
        batch_bar.update()
        torch.cuda.empty_cache()
    batch_bar.close()

    train_acc = 100 * correct_prediction_count / max(samples_seen, 1)
    dev_acc = val_acc(vbert_mod)
    # Decay the learning rate once per epoch; without this call the scheduler has no effect.
    scheduler.step()
    print("epoch {}: train_acc={:.03f}% dev_acc={:.03f}%".format(epoch, train_acc, dev_acc))

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': vbert_mod.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_epoch / max(len(train_loader), 1),
    }
    fileName = get_fileName(epoch, isDev=False)
    torch.save(checkpoint, f'../{fileName}')

    if dev_acc > best_dev_acc:
        fileName = get_fileName(epoch, isDev=True)
        best_dev_acc = dev_acc
        torch.save(checkpoint, f'../{fileName}')