In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.nvib_sa_transformer_encoder import (
    NVIBTransformerEncoder,
    NVIBTransformerEncoderLayer,
)
from models.seq2seq_lightning import Seq2SeqLightning
from models.transformer import *
from models.transformer_encoder import (
    CustomTransformerEncoder,
    CustomTransformerEncoderLayer,
)

from utils import *

In [40]:
class NVIBSaTransformerLightning(Seq2SeqLightning):
    def __init__(self, args, **kwargs):
        super().__init__(args, **kwargs)

        # Tokenizer
        self.tokenizer = CharizardTokenizer(model_max_length=args.max_length)
        # Model
        self.model = NVIBSaTransformer(tokenizer=self.tokenizer, **vars(args))

        # Nvib
        self.lambda_klg = args.klg_lambda
        self.lambda_kld = args.kld_lambda

        # Logging metrics
        self.log_bleu = True
        self.log_chrf = True
        self.plot_encoder_attention = True
        self.plot_cross_attention = True
        self.model_type = "NVIBSaTransformer"
        self.is_nvib = True
        self.weighted_kl = args.weighted_kl

        # Initialization
        init_weights(self.model)

        self.save_hyperparameters()

class NVIBSaTransformer(Transformer):
    """
    A vanilla Transformer Encoder-Decoder in Pytorch

    Data format:
    SRC: ... [EOS]
    TGT: ... [EOS]
    Encoder_input(SRC): ... [EOS]
    Decoder_input(TGT): [SOS] ...

    For an autoencoder x -> x (SRC = TGT)
        The loss function requires SRC and logits.
    For different models x -> y (Eg: translation SRC != TGT)
        The loss function requires TGT and logits.

    If we keep this format the attention masks for padding are identical for autoencoder's encoder + decoder .
    """

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer=tokenizer, **kwargs)
        print(kwargs)
        self.d_model = kwargs["d_model"]
        self.nhead = kwargs["nhead"]
        self.dim_feedforward = kwargs["dim_feedforward"]
        self.dropout = kwargs["dropout"]
        self.num_encoder_layers = kwargs["num_encoder_layers"]
        self.num_nvib_encoder_layers = kwargs["num_nvib_encoder_layers"]
        self.num_decoder_layers = kwargs["num_decoder_layers"]
        self.kappa = kwargs["kappa"]
        self.delta = kwargs["delta"]

        # Transformer encoder layer
        encoder_layer = CustomTransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation="relu",
        )

        # NVIB Transformer encoder layer
        nvib_transformer_encoder_layer = NVIBTransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation="relu",
            kappa=self.kappa,
            delta=self.delta,
        )
        encoder_norm = nn.LayerNorm(self.d_model, eps=1e-5)
        self.encoder = CustomTransformerEncoder(
            encoder_layer, self.num_encoder_layers, encoder_norm
        )
        self.nvib_transformer_encoder = NVIBTransformerEncoder(
            nvib_transformer_encoder_layer, self.num_nvib_encoder_layers, encoder_norm
        )

        # Transformer decoder
        decoder_layer = CustomTransformerDecoderLayer(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.dropout,
        )
        decoder_norm = nn.LayerNorm(self.d_model, eps=1e-5)
        self.decoder = CustomTransformerDecoder(
            decoder_layer, self.num_decoder_layers, decoder_norm
        )

        self.pad_token_id = tokenizer.pad_token_id
        self.decoder_start_token_id = tokenizer.cls_token_id
        self.args = kwargs
        self.embedding = nn.Embedding(tokenizer.vocab_size, self.d_model, padding_idx=0)
        self.positional_encoding = PositionalEncoding(self.d_model)
        self.output_proj = nn.Linear(self.d_model, tokenizer.vocab_size)
        self.drop = nn.Dropout(self.dropout)

    def encode(self, src, src_key_padding_mask):
        """
        Encode the input ids to embeddings and then pass to the transformer encoder
        :param src: source token ids [Ns, B]
        :param src_key_padding_mask: Trues where to mask [B,Ns]
        :return: memory: [Ns,B,H]
        """
        # Add position encodings + Embeddings
        src = self.positional_encoding(self.drop(self.embedding(src)))  # [Ns,B,H]

        # Transformer encoder
        memory1, attention1 = self.encoder(
            src, src_key_padding_mask=src_key_padding_mask
        )  # [Ns,B,H]

        # NVIB Transformer encoder
        memory2, attention2, klg, kld, latent_dict = self.nvib_transformer_encoder(
            memory1[-1], src_key_padding_mask=src_key_padding_mask
        )  # [Ns,B,H]
        # Concatenate the attention lists
        attention = attention1 + attention2
        return memory2, attention, klg, kld, latent_dict, memory1

    def decode(
        self, tgt, z, memory_key_padding_mask, tgt_key_padding_mask, *args, **kwargs
    ):
        """

        :param tgt: target token ids [Nt,B]
        :param z: output from the latent layer [Nl,B,H]
        :param memory_key_padding_mask: mask for latent layer [B, Nl] (typically Ns = Nl)
        :param tgt_key_padding_mask: target mask [B,Nt]
        :param args:
        :param kwargs:
        :return: logits over the vocabulary [Nt,B,V]
        """

        # Add position encodings + Embeddings
        tgt = self.positional_encoding(self.drop(self.embedding(tgt)))  # [Nt,B,H]
        # Generate target teacher forcing mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(
            tgt.device
        )  # [Nt, Nt]
        output, attention = self.decoder(
            tgt=tgt,  # [Nt,B,H]
            memory=z,  # [Nt,B,H]
            tgt_mask=tgt_mask,  # [Nt,Nt]
            tgt_key_padding_mask=tgt_key_padding_mask,  # [B,Nt]
            memory_key_padding_mask=memory_key_padding_mask,
            # latent_dict=kwargs["latent_dict"] if "latent_dict" in kwargs else None,
        )  # [B,Nl]
        logits = self.output_proj(output)  # [Nt,B,V]
        return logits, attention

    def generate(self, input_ids, max_new_tokens, *args, **kwargs):
        """
        Generate autoregressively without teacher forcing
        :param z: output from the latent layer [Nl,B,H]
        :param memory_key_padding_mask: mask from the latent layer [B,Nl]
        :param max_len: maximum generation length
        :param tokenizer: tokenizer
        :param args:
        :param kwargs:
        :return: logits [Nt,B,V] and list of predictions
        """
        # Encode
        src_key_padding_mask = ~(input_ids.bool()).transpose(0, 1)  # [B,Ns]
        memory, _, _, _, self_attention_latent, _ = self.encode(
            input_ids, src_key_padding_mask=src_key_padding_mask
        )  # [Ns,B,H]

        # Mask the src_key_padding_mask with the final latent layer's pi for cross attention
        src_key_padding_mask = src_key_padding_mask + self_attention_latent[-1][
            "alpha"
        ].squeeze(-1).transpose(0, 1)[:, 1:].le(0.1)

        # Soft weighting of vectors
        # memory = memory * self_attention_latent[-1]["pi"][1:, :, :]

        # latent layer
        latent_output_dict = self.latent_layer(memory, src_key_padding_mask)
        memory_key_padding_mask = latent_output_dict["memory_key_padding_mask"]
        z = latent_output_dict["z"]

        # Initialise target ids with BOS token
        target_ids = (
            torch.tensor([[self.decoder_start_token_id]])
            .expand(memory_key_padding_mask.shape[0], -1)
            .T.to(memory_key_padding_mask.device)
        )  # [1, B]
        # For each token in length
        for token_idx in range(max_new_tokens):
            # Decode the target ids regressively
            logits, _ = self.decode(
                target_ids,
                z,
                memory_key_padding_mask,
                None,
                # latent_dict=self_attention_latent[-1]
            )  # [token_idx, B, V]
            # Select only the final set of logits
            prediction = logits[-1, :, :].unsqueeze(0)  # [target_ids1,B,V]
            # Get prediction over vocabulary and return index
            prediction = prediction.argmax(-1)  # [1,B]
            # Concatenate the predictions to form next token_ids
            target_ids = torch.cat((target_ids, prediction), dim=0)  # [token_index, B]
        

        # Decode into a sentence
        # predictions = [tokenizer.decode(encoded) for encoded in target_ids[1:, :].T]  # list [B]
        return target_ids[1:, :]

    def forward(
        self,
        input_ids,
        decoder_input_ids,
        labels,
        attention_mask,
        **kwargs,
    ):
        """
        Forward pass for all transformer models

        :param src: the sequence to the encoder (required). [Ns,B]
        :param tgt: the sequence  nce to the decoder (required). [Nt,B]
        :param src_mask: the additive mask for the src sequence (optional). [Ns, Ns]
        :param tgt_mask: the additive mask for the tgt sequence (optional). [Nt, Nt]
        :param memory_mask: the additive mask for the encoder output (optional). [Nt,Ns]
        :param src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). [B,Ns]
        :param tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). [B,Nt]
        :param memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).[B,Nl]
        :return: logits and latent dimension dictionary

        Check out here for more info masks on https://stackoverflow.com/questions/62170439/difference-between-src-mask-and-src-key-padding-mask
        The memory ones are interesting. I use memory_key_padding_mask to mask the tokens in the latent space.

        """
        # Reformat the attention mask
        src_key_padding_mask = ~(attention_mask.bool())
        tgt_key_padding_mask = decoder_input_ids.transpose(0, 1) == self.pad_token_id

        # Encode
        (
            memory,
            encoder_attention,
            klg,
            kld,
            self_attention_latent,
            old_memory,
        ) = self.encode(
            input_ids, src_key_padding_mask=src_key_padding_mask
        )  # [Ns,B,H]

        # Mask the src_key_padding_mask with the final latent layer's pi for cross attention
        alpha = self_attention_latent[-1]["alpha"].squeeze(-1).transpose(0, 1)  # [B, Ns]
        if alpha.size(1) != src_key_padding_mask.size(1):
            # If alpha has an extra token (prior token), remove it
            alpha = alpha[:, 1:]
            # If alpha is still smaller, pad it to match src_key_padding_mask
            if alpha.size(1) < src_key_padding_mask.size(1):
                alpha = F.pad(alpha, (0, src_key_padding_mask.size(1) - alpha.size(1)), value=0)
        src_key_padding_mask = src_key_padding_mask + alpha.le(0.1)

        # Soft weighting of vectors
        # memory = memory * self_attention_latent[-1]["pi"][1:, :, :]

        # latent layer
        latent_output_dict = self.latent_layer(memory, src_key_padding_mask)
        # Decode
        output, decoder_attention = self.decode(
            tgt=decoder_input_ids,  # [Nt,B,H]
            z=latent_output_dict["z"],  # [Nl,B,H]
            tgt_key_padding_mask=tgt_key_padding_mask,  # [B,Nt]
            memory_key_padding_mask=latent_output_dict["memory_key_padding_mask"],
            # latent_dict=self_attention_latent[-1],
        )  # [B,Nl]

        return {
            "logits": output,  # [Nt, B, V]
            "encoder_attentions": encoder_attention,  # Self attention
            "cross_attentions": decoder_attention,  # Cross attention
            "kl_gaussian": klg,
            "kl_dirichlet": kld,
            "latent_dict_list": self_attention_latent,
            "old_memory": old_memory,
            "old_memory_mask": ~(attention_mask.bool()),
        }



In [41]:
class Args:
    # Paths + Naming
    experiment_name: str = "nvib-paper"
    project_name: str = "nvib_selfattention"
    output_dir: str = "outputs"
    entity: str = None

    # Data
    data: str = "wikitext"
    data_subset: str = "wikitext-2-raw-v1"
    num_workers: int = 17
    max_length: int = 128

    # Model
    model: str = "NVIBSaTransformer"

    # Training
    seed: int = 42
    fast_dev_run: bool = False
    fp16: bool = False

    # Transformer
    d_model: int = 512
    d_compress_model: int = 512
    nhead: int = 8
    num_encoder_layers: int = 3
    num_decoder_layers: int = 2
    dim_feedforward: int = 512
    dropout: float = 0.1

    # Token deletion probability
    deletion_prob: float = 0
    deletion_type: str = "token"  # choices: ["token", "word", "token_word", "None"]

    # NVIB
    num_nvib_encoder_layers: int = 1
    kappa: float = 1
    delta: float = 1
    klg_lambda: float = 0.001
    kld_lambda: float = 1
    kl_annealing_type: str = "constant"
    weighted_kl: bool = True

    # Learning rate + Schedulers
    learning_rate: float = 1e-4
    lr_scheduler: bool = False
    perc_warmup: float = 0
    batch_size: int = 1

    # PL Trainer
    max_time: str = None
    max_steps: int = 1000
    max_epochs: int = None
    accumulate_grad_batches: int = 1
    checkpoint_interval: int = 100
    validation_interval: int = None

    def __init__(self):
        # Make sure all attributes are properly initialized
        for attr_name in dir(self):
            if not attr_name.startswith('__'):
                setattr(self, attr_name, getattr(self, attr_name))

args = Args()


In [42]:
args.d_model

512

In [43]:
model = {
    "Transformer": TransformerLightning,
    "NVIBSaTransformer": NVIBSaTransformerLightning,
}[args.model](args)

{'accumulate_grad_batches': 1, 'batch_size': 1, 'checkpoint_interval': 100, 'd_compress_model': 512, 'd_model': 512, 'data': 'wikitext', 'data_subset': 'wikitext-2-raw-v1', 'deletion_prob': 0, 'deletion_type': 'token', 'delta': 1, 'dim_feedforward': 512, 'dropout': 0.1, 'entity': None, 'experiment_name': 'nvib-paper', 'fast_dev_run': False, 'fp16': False, 'kappa': 1, 'kl_annealing_type': 'constant', 'kld_lambda': 1, 'klg_lambda': 0.001, 'learning_rate': 0.0001, 'lr_scheduler': False, 'max_epochs': None, 'max_length': 128, 'max_steps': 1000, 'max_time': None, 'model': 'NVIBSaTransformer', 'nhead': 8, 'num_decoder_layers': 2, 'num_encoder_layers': 3, 'num_nvib_encoder_layers': 1, 'num_workers': 17, 'output_dir': 'outputs', 'perc_warmup': 0, 'project_name': 'nvib_selfattention', 'seed': 42, 'validation_interval': None, 'weighted_kl': True}


In [44]:
print(f"Model type: {args.model}")
print(f"d_model: {args.d_model}")
print(f"num_nvib_encoder_layers: {args.num_nvib_encoder_layers}")
print(f"weighted_kl: {args.weighted_kl}")
print(f"klg_lambda: {args.klg_lambda}")
print(f"kld_lambda: {args.kld_lambda}")

Model type: NVIBSaTransformer
d_model: 512
num_nvib_encoder_layers: 1
weighted_kl: True
klg_lambda: 0.001
kld_lambda: 1


In [45]:
OUTPUT_PATH = "/project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug4/"
CHECKPOINT_PATH = get_checkpoint_path(OUTPUT_PATH)
BEST_MODEL_PATH = "/project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug4/best_model.ckpt"

In [46]:
print(CHECKPOINT_PATH)

/project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug4/epoch=40-step=6500.ckpt


In [47]:
model, wandb_id = create_or_load_model(OUTPUT_PATH, BEST_MODEL_PATH, model, args)

Loading model
Loading model:  /project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug4/best_model.ckpt
{'accumulate_grad_batches': 1, 'batch_size': 1, 'checkpoint_interval': 100, 'd_compress_model': 512, 'd_model': 512, 'data': 'wikitext', 'data_subset': 'wikitext-2-raw-v1', 'deletion_prob': 0, 'deletion_type': 'token', 'delta': 1, 'dim_feedforward': 512, 'dropout': 0.1, 'entity': None, 'experiment_name': 'nvib-paper', 'fast_dev_run': False, 'fp16': False, 'kappa': 1, 'kl_annealing_type': 'constant', 'kld_lambda': 1, 'klg_lambda': 0.001, 'learning_rate': 0.0001, 'lr_scheduler': False, 'max_epochs': None, 'max_length': 128, 'max_steps': 1000, 'max_time': None, 'model': 'NVIBSaTransformer', 'nhead': 8, 'num_decoder_layers': 2, 'num_encoder_layers': 3, 'num_nvib_encoder_layers': 1, 'num_workers': 17, 'output_dir': 'outputs', 'perc_warmup': 0, 'project_name': 'nvib_selfattention', 'seed': 42, 'validation_interval': None, 'weighted_kl': True}
Loading W&B ID


In [48]:
print(model)

NVIBSaTransformerLightning(
  (model): NVIBSaTransformer(
    (encoder): CustomTransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x CustomTransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): CustomTransformerDecoder(
      (layers): ModuleList(
        (0-1): 2 x CustomTransformerDecoderLayer(
      

In [49]:
import torch

input_text = "Hello, how are you?"
inputs = model.tokenizer(input_text, 
                      padding=True, 
                      truncation=True, 
                      return_tensors="pt")
print(inputs)

{'input_ids': tensor([[  0,  50,  21,  28,  28,  31,  80, 101,  24,  31,  39, 101,  17,  34,
          21, 101,  41,  31,  37,  89,   1]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [50]:
input_ids = inputs["input_ids"].transpose(0, 1).to('cuda')  # [Ns,B]
attention_mask = inputs["attention_mask"].to('cuda')  # [B,Ns]

In [51]:
print(attention_mask)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')


In [52]:
print(len(inputs["input_ids"][0]))

21


In [53]:
# Initialize decoder with BOS token
input_text = "the game's opening theme was sung by"
inputs = model.tokenizer(input_text, 
                      padding=True, 
                      truncation=True, 
                      return_tensors="pt")

input_ids = inputs["input_ids"].transpose(0, 1).to('cuda')  # [Ns,B]

# Set model to evaluation mode
model.model.eval()
model.model.to('cuda')

# Generate sequence using the model's generate method
with torch.no_grad():
    # Generate up to 50 new tokens
    generated_ids = model.model.generate(
        input_ids=input_ids,
        max_new_tokens=50
    )
    
    # Decode the generated sequence
    input_text = model.tokenizer.decode(input_ids.transpose(0, 1)[0], skip_special_tokens=True)
    decoded_text = model.tokenizer.decode(generated_ids.transpose(0, 1)[0], skip_special_tokens=True)
    print(f"Input text: {input_text}")
    print(f"Generated text: {decoded_text}")

Input text: the game's opening theme was sung by
Generated text: l t .........................................


In [54]:
decoded = model.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)
print(f"Decoded text: {decoded}")

Decoded text: the game's opening theme was sung by
