### Imports

In [1]:
import os
import numpy as np
import pandas as pd
import json
import warnings
import logging
import gc
import random
import math
import re
import ast
from tqdm import tqdm
from typing import Optional
from datetime import datetime

import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import meteor_score
from rouge_score.rouge_scorer import RougeScorer

warnings.filterwarnings("ignore")

import torch
from torch.utils.data import DataLoader, TensorDataset


### Config

In [2]:
# --- Dataset column holding the reference text, and input locations ---
TARGET_COLUMN = "code_mixed_explanation"
TEXT_INPUT_PATH = "../../Data/Text/"
ACOUSTIC_INPUT_PATH = "../../Data/Audio/"
VISUAL_INPUT_PATH = "../../Data/Video/"

# --- Output locations ---
MODEL_OUTPUT_DIR = "../../models/MAF-TAV-BART/"
RESULT_OUTPUT_DIR = "../../results/MAF-TAV-BART/"

# --- Feature flags ---
LOWERCASE_UTTERANCES = False
UNFOLDED_DIALOGUE = True

# Only one set of source-column names is defined, depending on UNFOLDED_DIALOGUE:
# a single SOURCE_COLUMN when it is True, otherwise SOURCE_COLUMN_1/SOURCE_COLUMN_2.
if not UNFOLDED_DIALOGUE:
    SOURCE_COLUMN_1, SOURCE_COLUMN_2 = "target", "context"
else:
    SOURCE_COLUMN = "dialogue"

# --- Sequence-length limits ---
SOURCE_MAX_LEN = 480
TARGET_MAX_LEN = 50  # also forwarded as the generation `max_length`
MAX_UTTERANCES = 25

# --- Acoustic / visual input sizes ---
ACOUSTIC_DIM = 154
ACOUSTIC_MAX_LEN = 600

VISUAL_DIM = 2048
VISUAL_MAX_LEN = 96

# --- Optimisation settings ---
BATCH_SIZE = 4
MAX_EPOCHS = 60

# Two learning rates are handed to `train` (base_learning_rate / new_learning_rate).
BASE_LEARNING_RATE = 5e-6
NEW_LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4

# --- Generation settings (packed into `gen_kwargs` in the training cell) ---
NUM_BEAMS = 5
EARLY_STOPPING = True
NO_REPEAT_NGRAM_SIZE = 3

EARLY_STOPPING_THRESHOLD = 5


In [6]:
def set_random_seed(seed: int):
    """
    Helper function to seed experiment for reproducibility.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all GPUs) with the same
    value and switches cuDNN to its deterministic configuration.
    If -1 is provided as seed, experiment uses random seed from 0~9999.

    Args:
        seed (int): integer to be used as seed, use -1 to randomly seed experiment.
            The value that is printed is always the seed actually applied, so a
            randomly drawn seed can be reused to repeat the run.
    """
    # Resolve the -1 sentinel *before* reporting, otherwise the log would show
    # "-1" and the seed that was really used would be lost.
    if seed == -1:
        seed = random.randint(0, 9999)

    print("Seed: {}".format(seed))

    # These lines below ensure determinism by turning off some optimizations
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    # We set all seeds in libraries to our seed
    random.seed(seed)  # python built-in
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this assignment
    # only influences processes launched after this point, not the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)  # seed for hash based ops in python
    np.random.seed(seed)  # numpy random gen
    torch.manual_seed(seed)  # torch CPU random generator
    torch.cuda.manual_seed(seed)  # current gpu rand gen
    torch.cuda.manual_seed_all(seed)  # all gpus rand gen


In [7]:
# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using GPU" if DEVICE.type == "cuda" else "Using CPU")


Using GPU


In [8]:
# Seed every RNG up front with a fixed value so repeated runs are comparable.
set_random_seed(seed=42)

Seed: 42


### Model

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, MSELoss

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel, unwrap_model

from transformers import (
    BartTokenizerFast,
    AdamW
)

from transformers.models.bart.configuration_bart import BartConfig

# from transformers.models.bart.modeling_bart import (
#     BartPretrainedModel,
#     BartDecoder,
#     BartLearnedPositionalEmbedding,
#     BartEncoderLayer,
#     shift_tokens_right,
#     _make_causal_mask,
#     _expand_mask 
# )


from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput
)


# from transformer_encoder import TransformerEncoder


### MCA2
Refer to the [MCA2 doc](./MCA2_test.ipynb) for an explanation of how this module works.

In [11]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

In [12]:
class ContextAwareAttention(nn.Module):
    """
    Single-head attention whose keys and values are first blended with a context tensor.

    The context is linearly projected from `dim_context` to `dim_model` (separately
    for keys and values). A sigmoid gate, computed from the key (or value) and its
    projected context, produces one scalar per position that interpolates between
    the original key/value and the projected context. The blended keys and values
    are then passed to a `nn.MultiheadAttention` layer with `num_heads=1` and
    `batch_first=True`.

    Args:
        dim_model (int): feature size of the query/key/value tensors.
        dim_context (int): feature size of the context tensor.
        dropout_rate (float, optional): dropout applied inside the attention layer.
    """

    def __init__(self,
                 dim_model: int,
                 dim_context: int,
                 dropout_rate: Optional[float]=0.0):
        super().__init__()

        self.dim_model = dim_model
        self.dim_context = dim_context
        self.dropout_rate = dropout_rate
        # NOTE(review): only this sub-layer is created directly on the global DEVICE;
        # the Linear layers below are created on the default device, so the module
        # presumably gets moved as a whole later on -- confirm at the call site.
        self.attention_layer = nn.MultiheadAttention(embed_dim=self.dim_model,
                                                     num_heads=1,
                                                     dropout=self.dropout_rate,
                                                     bias=True,
                                                     add_zero_attn=False,
                                                     batch_first=True,
                                                     device=DEVICE)

        # Key branch: context projection (u_k) and the two gate projections (w1_k, w2_k)
        self.u_k = nn.Linear(self.dim_context, self.dim_model, bias=False)
        self.w1_k = nn.Linear(self.dim_model, 1, bias=False)
        self.w2_k = nn.Linear(self.dim_model, 1, bias=False)

        # Value branch: same structure with its own parameters
        self.u_v = nn.Linear(self.dim_context, self.dim_model, bias=False)
        self.w1_v = nn.Linear(self.dim_model, 1, bias=False)
        self.w2_v = nn.Linear(self.dim_model, 1, bias=False)

    def forward(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                context: Optional[torch.Tensor]=None):
        """
        Attend from `q` over context-blended versions of `k` and `v`.

        Args:
            q: query tensor, batch-first, last dimension `dim_model`.
            k: key tensor, batch-first, last dimension `dim_model`.
            v: value tensor, batch-first, last dimension `dim_model`.
            context: tensor with last dimension `dim_context`; after projection it
                is combined element-wise with `k` and `v`, so its leading dimensions
                must be broadcast-compatible with theirs. When `None`, no blending
                is performed and plain attention over `k` and `v` is returned.

        Returns:
            The attention output tensor (attention weights are discarded).
        """
        if context is None:
            # The signature advertises `context` as optional; without it there is
            # nothing to blend in, so attend over the unmodified keys and values.
            k_cap, v_cap = k, v
        else:
            key_context = self.u_k(context)
            value_context = self.u_v(context)

            # One gate value per position (the w* layers project to a single feature)
            lambda_k = torch.sigmoid(self.w1_k(k) + self.w2_k(key_context))
            lambda_v = torch.sigmoid(self.w1_v(v) + self.w2_v(value_context))

            # Convex combination of the original tensor and its projected context
            k_cap = (1 - lambda_k) * k + lambda_k * key_context
            v_cap = (1 - lambda_v) * v + lambda_v * value_context

        attention_output, _ = self.attention_layer(query=q,
                                                   key=k_cap,
                                                   value=v_cap)
        return attention_output

### Training setup

In [None]:
# Decoding options; they are forwarded to `train` as extra keyword arguments.
gen_kwargs = {
    'num_beams': NUM_BEAMS,
    'max_length': TARGET_MAX_LEN,
    'early_stopping': EARLY_STOPPING,
    'no_repeat_ngram_size': NO_REPEAT_NGRAM_SIZE
}

# NOTE(review): `train`, MODEL, TOKENIZER and the three *_dataset objects are not
# defined in the visible part of the notebook -- they must exist before this cell runs.
train(model=MODEL,
      tokenizer=TOKENIZER,
      train_data_loader=train_dataset,
      val_data_loader=val_dataset,
      test_data_loader=test_dataset,
      base_learning_rate=BASE_LEARNING_RATE,
      new_learning_rate=NEW_LEARNING_RATE,
      weight_decay=WEIGHT_DECAY,
      **gen_kwargs)

print("Model Trained!")