In [None]:
# Skip NLTK imports entirely - use alternative approach
# We'll import the data loading functions we need without the full data module

import pickle
import os
from torch.utils.data import DataLoader


In [None]:
import sys
import os

# NLTK_DATA has to be in the environment before `import nltk` below (the original cell
# relied on the same ordering). setdefault keeps a value that was already exported by the
# user or the cluster and only falls back to this machine-specific conda path otherwise.
os.environ.setdefault(
    'NLTK_DATA',
    '/home/tqn/.conda/envs/nvib_sa/lib/python3.10/site-packages/nltk_data',
)

import nltk

# Make sure the punkt tokenizer data can be found; download it once if the lookup fails.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Project import, kept after the NLTK setup above.
# presumably the data module needs NLTK at import time — TODO confirm
from data_modules.ReconstructionDataModule import ReconstructionDataModule, load_prepared_data


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.nvib_sa_transformer_encoder import (
    NVIBTransformerEncoder,
    NVIBTransformerEncoderLayer,
)
from models.seq2seq_lightning import Seq2SeqLightning
from models.transformer import *
from models.transformer_encoder import (
    CustomTransformerEncoder,
    CustomTransformerEncoderLayer,
)

from utils import *

In [3]:

# TEMPORARY: Simple gradient checking - DELETE AFTER USE
def check_gradients_simple(model, step_name=""):
    """Count parameters whose gradient norm looks vanishing or exploding.

    Temporary debugging helper. Only parameters that currently hold a gradient
    (``param.grad is not None``) are inspected. Per-parameter warnings are printed
    for names containing ``'decoder'`` only, but every inspected parameter is counted.

    :param model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``)
    :param step_name: label appended to the printed header
    :return: tuple ``(vanishing_count, exploding_count)``
    """
    print(f"\nüîç GRADIENT CHECK {step_name}:")

    # (name, gradient norm) for every parameter that has a gradient
    grad_norms = [
        (param_name, torch.norm(parameter.grad).item())
        for param_name, parameter in model.named_parameters()
        if parameter.grad is not None
    ]
    n_checked = len(grad_norms)

    n_vanishing = 0
    n_exploding = 0
    for param_name, grad_norm in grad_norms:
        report = 'decoder' in param_name
        if grad_norm < 1e-7:
            n_vanishing += 1
            if report:
                print(f"   ‚ö†Ô∏è  VANISHING in {param_name}: {grad_norm:.2e}")
            continue
        if grad_norm > 100:
            n_exploding += 1
            if report:
                print(f"   ‚ö†Ô∏è  EXPLODING in {param_name}: {grad_norm:.2e}")

    print(f"   Total params: {n_checked}, Vanishing: {n_vanishing}, Exploding: {n_exploding}")
    # Raise the alarm when more than 10% of the inspected parameters have vanished
    if n_vanishing > n_checked * 0.1:
        print("   üö® GRADIENT VANISHING DETECTED!")
    return n_vanishing, n_exploding

# TEMPORARY: Alpha overflow checking - DELETE AFTER USE
def check_alpha_overflow(alpha_dict, step_name=""):
    """Print range statistics for each tensor and report NaN / Inf values.

    Temporary debugging helper. Entries that are ``None`` or not tensors are skipped.
    Empty tensors are reported and skipped, because ``torch.max`` / ``torch.min``
    raise on them.

    :param alpha_dict: mapping from a label to a tensor (or ``None``)
    :param step_name: label appended to the printed header
    :return: tuple ``(has_nan, has_inf)``; each flag is True if ANY inspected tensor
        contained such a value, and ``(False, False)`` when nothing was inspected
    """
    print(f"\nüìä ALPHA CHECK {step_name}:")

    # Accumulate across all entries. Returning only the flags of the last tensor would
    # hide a NaN in an earlier one, and would be undefined when no tensor is inspected.
    any_nan = False
    any_inf = False

    for key, tensor in alpha_dict.items():
        if tensor is None or not torch.is_tensor(tensor):
            continue

        print(f"   {key}: shape={tensor.shape}")
        if tensor.numel() == 0:
            print("      (empty tensor - skipped)")
            continue

        # Check for NaN/Inf
        has_nan = torch.isnan(tensor).any().item()
        has_inf = torch.isinf(tensor).any().item()
        any_nan = any_nan or has_nan
        any_inf = any_inf or has_inf

        # Get statistics
        tensor_max = torch.max(tensor).item()
        tensor_min = torch.min(tensor).item()
        tensor_mean = torch.mean(tensor).item()
        print(f"      min={tensor_min:.2e}, max={tensor_max:.2e}, mean={tensor_mean:.2e}")

        if has_nan:
            print(f"      üö® NaN detected in {key}!")
            nan_positions = torch.nonzero(torch.isnan(tensor))
            print(f"      NaN positions: {nan_positions[:5].tolist()}...")  # Show first 5

        if has_inf:
            print(f"      üö® Inf detected in {key}!")
            inf_positions = torch.nonzero(torch.isinf(tensor))
            print(f"      Inf positions: {inf_positions[:5].tolist()}...")  # Show first 5

        # Check for potential overflow (very large values)
        if tensor_max > 700:  # exp(700) is close to overflow
            print(f"      ‚ö†Ô∏è  Very large values in {key} (max={tensor_max:.2e}) - potential exp overflow!")

        # Check for potential underflow (very small values after exp)
        if 'alpha' in key.lower() and tensor_max < 1e-10:
            print(f"      ‚ö†Ô∏è  Very small alpha values (max={tensor_max:.2e}) - potential underflow!")

    return any_nan, any_inf

# Note:
# B: Batch size
# Ns: Source length
# Nt: Target length
# Nl: Latent length (typically = Ns)
# E: Embedding dimension
# H: Model dimension
# V: Vocab dimension


class NVIBSaTransformer(Transformer):
    """
    A vanilla Transformer Encoder-Decoder in Pytorch

    Data format:
    SRC: ... [EOS]
    TGT: ... [EOS]
    Encoder_input(SRC): ... [EOS]
    Decoder_input(TGT): [SOS] ...

    For an autoencoder x -> x (SRC = TGT)
        The loss function requires SRC and logits.
    For different models x -> y (Eg: translation SRC != TGT)
        The loss function requires TGT and logits.

    If we keep this format the attention masks for padding are identical for autoencoder's encoder + decoder .
    """

    def __init__(self, tokenizer, **kwargs):
        """
        Build the encoder stack, the NVIB encoder stack, the decoder and the embedding /
        projection layers.

        :param tokenizer: tokenizer providing ``pad_token_id``, ``cls_token_id`` and ``vocab_size``
        :param kwargs: hyper-parameters; the keys read here are ``d_model``, ``nhead``,
            ``dim_feedforward``, ``dropout``, ``num_encoder_layers``,
            ``num_nvib_encoder_layers``, ``num_decoder_layers``, ``kappa`` and ``delta``.
            The full dict is forwarded to the parent constructor and kept as ``self.args``.
        """
        super().__init__(tokenizer=tokenizer, **kwargs)

        # Hyper-parameters (a missing key raises KeyError)
        self.d_model = kwargs["d_model"]
        self.nhead = kwargs["nhead"]
        self.dim_feedforward = kwargs["dim_feedforward"]
        self.dropout = kwargs["dropout"]
        self.num_encoder_layers = kwargs["num_encoder_layers"]
        self.num_nvib_encoder_layers = kwargs["num_nvib_encoder_layers"]
        self.num_decoder_layers = kwargs["num_decoder_layers"]
        self.kappa = kwargs["kappa"]
        self.delta = kwargs["delta"]

        # Transformer encoder layer
        encoder_layer = CustomTransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation="relu",
            norm_first=True,
        )

        # NVIB Transformer encoder layer
        nvib_transformer_encoder_layer = NVIBTransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation="relu",
            kappa=self.kappa,
            delta=self.delta,
            norm_first=True,
        )
        # NOTE(review): the SAME LayerNorm instance is passed to both encoder stacks below.
        # Unless the encoder classes copy it internally, the two stacks share its
        # parameters — confirm this is intended.
        encoder_norm = nn.LayerNorm(self.d_model, eps=1e-5)
        self.encoder = CustomTransformerEncoder(
            encoder_layer, self.num_encoder_layers, encoder_norm
        )
        self.nvib_transformer_encoder = NVIBTransformerEncoder(
            nvib_transformer_encoder_layer, self.num_nvib_encoder_layers, encoder_norm
        )

        # Transformer decoder
        # CustomTransformerDecoderLayer / CustomTransformerDecoder are not imported by name in
        # this notebook; presumably they come from `from models.transformer import *` — TODO confirm
        decoder_layer = CustomTransformerDecoderLayer(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.dropout,
            norm_first=True,
        )
        decoder_norm = nn.LayerNorm(self.d_model, eps=1e-5)
        self.decoder = CustomTransformerDecoder(
            decoder_layer, self.num_decoder_layers, decoder_norm
        )

        self.pad_token_id = tokenizer.pad_token_id
        # The tokenizer's CLS token id is used as the decoder start token
        self.decoder_start_token_id = tokenizer.cls_token_id
        self.args = kwargs
        # NOTE(review): padding_idx is hard-coded to 0 while pad_token_id is read from the
        # tokenizer above — confirm the tokenizer's pad id really is 0.
        self.embedding = nn.Embedding(tokenizer.vocab_size, self.d_model, padding_idx=0)
        self.positional_encoding = PositionalEncoding(self.d_model)
        self.output_proj = nn.Linear(self.d_model, tokenizer.vocab_size)
        self.drop = nn.Dropout(self.dropout)
        
        # TEMPORARY: Add gradient hooks for debugging - DELETE AFTER USE
        # The hooks are registered unconditionally at construction time; the
        # `_gradient_debug_enabled` flag is only created later by enable_gradient_debug().
        self._gradient_debug_hooks = []
        self._gradient_debug_step = 0
        self._gradient_debug_frequency = 100  # Print every 100 steps
        self._add_gradient_hooks()

    # TEMPORARY: Gradient debugging methods - DELETE AFTER USE
    def _add_gradient_hooks(self):
        """
        Register a NaN-reporting gradient hook on every trainable parameter whose name
        contains 'decoder'. The returned handles are appended to
        ``self._gradient_debug_hooks`` so they can be removed later.
        """
        def report_nan(grad, param_name):
            # Stay silent unless the incoming gradient contains at least one NaN
            if grad is None or not torch.isnan(grad).any():
                return
            print(f"üö® NaN gradient in {param_name}!")
            print(f"   Grad shape: {grad.shape}")
            print(f"   Grad norm: {torch.norm(grad).item():.2e}")

        watched = [
            (param_name, parameter)
            for param_name, parameter in self.named_parameters()
            if 'decoder' in param_name and parameter.requires_grad
        ]
        for param_name, parameter in watched:
            # Bind the name through a default argument so each hook reports its own parameter
            handle = parameter.register_hook(
                lambda grad, param_name=param_name: report_nan(grad, param_name)
            )
            self._gradient_debug_hooks.append(handle)

    def enable_gradient_debug(self):
        """
        Switch gradient debugging on (call before training) and print a short status report.

        Debugging is signalled by the mere existence of the ``_gradient_debug_enabled``
        attribute; other methods test for it with ``hasattr``.
        """
        self._gradient_debug_enabled = True

        # TEMPORARY: the lines after the first one confirm that the debugging state is in place
        status_report = (
            "üîç Gradient debugging ENABLED",
            f"üîç Debug enabled: {hasattr(self, '_gradient_debug_enabled')}",
            f"üîç Number of gradient hooks: {len(self._gradient_debug_hooks)}",
            "üîç Ready to monitor gradients and alpha values!",
        )
        for status_line in status_report:
            print(status_line)

    def print_final_gradient_summary(self):
        """
        Print a final gradient report. Does nothing unless gradient debugging is enabled,
        i.e. unless the ``_gradient_debug_enabled`` attribute exists.

        The gradient check always runs here, independent of the monitoring frequency:
        ``check_gradients_simple`` takes only the model and a label and never reads the
        step counter. The counter is therefore left untouched, so it cannot end up with a
        wrong value if the check raises.
        """
        if not hasattr(self, '_gradient_debug_enabled'):
            return

        separator = f"{'='*60}"
        current_step = self._gradient_debug_step

        print(f"\n{separator}")
        print(f"üèÅ FINAL GRADIENT SUMMARY (Step {current_step})")
        print(separator)

        check_gradients_simple(self, f"FINAL_STEP_{current_step}")

        print(f"Total steps monitored: {current_step}")
        print(f"Monitoring frequency: every {self._gradient_debug_frequency} steps")
        print(separator)

    def disable_gradient_debug(self):
        """
        Switch gradient debugging off: print the final summary, drop the enable flag,
        remove every registered hook and reset the hook list.
        """
        # The summary must be printed first, while the enable flag still exists
        self.print_final_gradient_summary()

        if hasattr(self, '_gradient_debug_enabled'):
            del self._gradient_debug_enabled

        for handle in self._gradient_debug_hooks:
            handle.remove()
        self._gradient_debug_hooks = list()

        print("üîç Gradient debugging DISABLED")

    def encode(self, src, src_key_padding_mask):
        """
        Embed the source ids and run them through the plain encoder followed by the
        NVIB encoder.

        :param src: source token ids [Ns, B]
        :param src_key_padding_mask: Trues where to mask [B,Ns]
        :return: tuple ``(memory2, attention, klg, kld, latent_dict, memory1)`` where
            ``memory2`` is the NVIB encoder's first output, ``attention`` is the plain
            encoder's attention list followed by the NVIB encoder's, ``klg`` / ``kld`` /
            ``latent_dict`` are passed through from the NVIB encoder, and ``memory1`` is
            the plain encoder's first output (its last element feeds the NVIB encoder)
        """
        # Embeddings -> dropout -> position encodings
        token_embeddings = self.drop(self.embedding(src))
        encoder_input = self.positional_encoding(token_embeddings)  # [Ns,B,H]

        # Transformer encoder
        memory1, attention1 = self.encoder(
            encoder_input, src_key_padding_mask=src_key_padding_mask
        )

        # NVIB Transformer encoder, fed with the last element of the encoder output
        nvib_outputs = self.nvib_transformer_encoder(
            memory1[-1], src_key_padding_mask=src_key_padding_mask
        )
        memory2, attention2, klg, kld, latent_dict = nvib_outputs

        # Both attention results are joined with `+` (list concatenation per the original comment)
        return memory2, attention1 + attention2, klg, kld, latent_dict, memory1


    def decode(
        self, tgt, z, memory_key_padding_mask, tgt_key_padding_mask, *args, **kwargs
    ):
        """
        Teacher-forced decoding: embed the target ids, attend over the latent memory under
        a causal mask and project to vocabulary logits. Inputs, decoder output and logits
        are each squashed with ``100 * tanh(x / 100)``.

        :param tgt: target token ids [Nt,B]
        :param z: output from the latent layer [Nl,B,H]
        :param memory_key_padding_mask: mask for latent layer [B, Nl] (typically Ns = Nl)
        :param tgt_key_padding_mask: target mask [B,Nt]
        :param args: unused
        :param kwargs: unused
        :return: tuple of logits over the vocabulary [Nt,B,V] and the decoder's attention output
        """

        def soft_bound(values):
            # Smooth, differentiable bound to [-100, 100]
            return 100 * torch.tanh(values / 100)

        def normalise_last_dim(values):
            # Parameter-free layer norm over the feature dimension
            return F.layer_norm(values, [values.size(-1)])

        # Embeddings + position encodings, then normalise and bound
        tgt_embedded = self.positional_encoding(self.drop(self.embedding(tgt)))  # [Nt,B,H]
        tgt_embedded = soft_bound(normalise_last_dim(tgt_embedded))

        # Causal (teacher forcing) mask, moved to the target's device
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt_embedded.size(0)
        ).to(tgt_embedded.device)  # [Nt, Nt]

        # Same normalise-and-bound treatment for the memory
        memory = soft_bound(normalise_last_dim(z))  # [Nl,B,H]

        output, attention = self.decoder(
            tgt=tgt_embedded,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,  # [B,Nt]
            memory_key_padding_mask=memory_key_padding_mask,  # [B,Nl]
        )
        assert not torch.isnan(output).any(), (
            f"NaN values detected in output after decoder layer at {self.__class__.__name__}. "
            f"NaN indices: {torch.nonzero(torch.isnan(output), as_tuple=True)}. "
            f"Shape: {output.shape}, Max value: {torch.max(output)}"
            f"Output: {output}"
        )

        # Normalise the decoder output, then cap each vector's L2 norm at 5
        hidden = normalise_last_dim(output)
        vector_norm = torch.norm(hidden, p=2, dim=-1, keepdim=True)
        shrink = torch.min(torch.ones_like(vector_norm), 5 / (vector_norm + 1e-6))
        hidden = soft_bound(hidden * shrink)

        # Output projection followed by a final smooth bound on the logits
        logits = soft_bound(self.output_proj(hidden))  # [Nt,B,V]

        return logits, attention

    def generate(self, input_ids, max_new_tokens, *args, **kwargs):
        """
        Generate greedily (argmax) and autoregressively, without teacher forcing.

        The whole procedure runs under ``torch.no_grad()``: the result consists of integer
        token ids obtained by argmax, so no gradient could flow through it anyway, and
        building an autograd graph for every decoding step only costs memory.

        :param input_ids: source token ids [Ns,B]; ids equal to 0 are treated as padding
        :param max_new_tokens: number of decoding steps. There is no early stop on an
            end-of-sequence token, so exactly this many tokens are produced per sequence.
        :param args: unused
        :param kwargs: unused
        :return: generated token ids [max_new_tokens, B] (the start token is stripped)
        """
        with torch.no_grad():
            # Encode
            # NOTE(review): padding is detected as id == 0 here, not via self.pad_token_id —
            # confirm the tokenizer's pad id is 0.
            src_key_padding_mask = ~(input_ids.bool()).transpose(0, 1)  # [B,Ns]
            memory, _, _, _, self_attention_latent, _ = self.encode(
                input_ids, src_key_padding_mask=src_key_padding_mask
            )  # [Ns,B,H]

            # Additionally mask positions whose alpha in the final latent layer is <= 0.1.
            # The first alpha entry is dropped by the [:, 1:] slice.
            # NOTE(review): forward() has the equivalent line commented out, so this masking
            # is applied at generation time only — confirm that is intended.
            low_alpha_mask = (
                self_attention_latent[-1]["alpha"]
                .squeeze(-1)
                .transpose(0, 1)[:, 1:]
                .le(0.1)
            )
            src_key_padding_mask = src_key_padding_mask | low_alpha_mask

            # Soft weighting of vectors
            # memory = memory * self_attention_latent[-1]["pi"][1:, :, :]

            # latent layer
            latent_output_dict = self.latent_layer(memory, src_key_padding_mask)
            memory_key_padding_mask = latent_output_dict["memory_key_padding_mask"]
            z = latent_output_dict["z"]

            # Initialise target ids with the decoder start token, directly on the right device
            batch_size = memory_key_padding_mask.shape[0]
            target_ids = torch.full(
                (1, batch_size),
                self.decoder_start_token_id,
                dtype=torch.long,
                device=memory_key_padding_mask.device,
            )  # [1, B]

            for _ in range(max_new_tokens):
                # Re-decode the whole prefix generated so far (no target padding mask)
                logits, _ = self.decode(
                    target_ids,
                    z,
                    memory_key_padding_mask,
                    None,
                    # latent_dict=self_attention_latent[-1]
                )  # [len(target_ids), B, V]
                # Keep only the logits of the last position and take the most likely token
                next_ids = logits[-1:, :, :].argmax(-1)  # [1,B]
                # Append the prediction to form the next decoder input
                target_ids = torch.cat((target_ids, next_ids), dim=0)

        # Drop the start token
        return target_ids[1:, :]

    def forward(
        self,
        input_ids,
        decoder_input_ids,
        labels,
        attention_mask,
        **kwargs,
    ):
        """
        Teacher-forced forward pass: encode the source, pass the memory through the latent
        layer and decode against ``decoder_input_ids``.

        :param input_ids: source token ids passed to ``encode`` [Ns,B]
        :param decoder_input_ids: decoder input token ids [Nt,B]; positions equal to
            ``self.pad_token_id`` form the target padding mask
        :param labels: not used inside this method
        :param attention_mask: source mask, non-zero for positions to keep; its negation is
            used as ``src_key_padding_mask`` (presumably [B,Ns] — TODO confirm)
        :param kwargs: not used inside this method
        :return: dict with keys ``logits`` [Nt,B,V], ``encoder_attentions``,
            ``cross_attentions``, ``kl_gaussian``, ``kl_dirichlet``, ``latent_dict_list``,
            ``old_memory`` and ``old_memory_mask``

        Check out here for more info masks on https://stackoverflow.com/questions/62170439/difference-between-src-mask-and-src-key-padding-mask
        The memory ones are interesting. I use memory_key_padding_mask to mask the tokens in the latent space.

        """
        # TEMPORARY: Check if forward is being called with debugging - DELETE AFTER USE
        # The step counter only advances while debugging is enabled.
        if hasattr(self, '_gradient_debug_enabled'):
            self._gradient_debug_step += 1
            if self._gradient_debug_step % self._gradient_debug_frequency == 0:
                print(f"üîç Step {self._gradient_debug_step}: Forward pass with gradient debugging")
        # Reformat the attention mask
        
        # Fail if any sequence has an all-zero attention mask (every position masked).
        # NOTE(review): the None guard here is not matched below — the next statement
        # dereferences attention_mask unconditionally, so None still fails.
        if attention_mask is not None:
            all_zero_sequences = torch.all(attention_mask == 0, dim=-1)
            if torch.any(all_zero_sequences):
                assert False, f"Found {torch.sum(all_zero_sequences)} sequences with all 0s in attention_mask"
        src_key_padding_mask = ~(attention_mask.bool())
        tgt_key_padding_mask = decoder_input_ids.transpose(0, 1) == self.pad_token_id
        # NOTE(review): input_ids are token ids; if they are an integer tensor, torch.isnan
        # is always False and this assert can never fire — confirm the dtype.
        assert not torch.isnan(input_ids).any(), (
            f"NaN values detected in input_ids at {self.__class__.__name__}. "
            f"NaN indices: {torch.nonzero(torch.isnan(input_ids), as_tuple=True)}. "
            f"Shape: {input_ids.shape}, Max value: {torch.max(input_ids)}"
            f"Input ids: {input_ids}"
        )
        # Encode
        (
            memory,
            encoder_attention,
            klg,
            kld,
            self_attention_latent,
            old_memory,
        ) = self.encode(
            input_ids, src_key_padding_mask=src_key_padding_mask
        )  # [Ns,B,H]
        assert not torch.isnan(memory).any(), (
            f"NaN values detected in memory after encode at {self.__class__.__name__}. "
            f"Shape: {memory.shape}, Max value: {torch.max(memory)}"
        )
        
        # TEMPORARY: Quick alpha overflow check after encode - DELETE AFTER USE
        # Runs only on every `_gradient_debug_frequency`-th step while debugging is enabled.
        if hasattr(self, '_gradient_debug_enabled') and self._gradient_debug_step % self._gradient_debug_frequency == 0:
            alpha_check = self_attention_latent[-1].get("alpha")
            log_alpha_check = self_attention_latent[-1].get("log_alpha") 
            if alpha_check is not None:
                alpha_max = torch.max(alpha_check).item()
                alpha_has_nan = torch.isnan(alpha_check).any().item()
                alpha_has_inf = torch.isinf(alpha_check).any().item()
                if alpha_max > 1e10 or alpha_has_nan or alpha_has_inf:
                    print(f"üö® Step {self._gradient_debug_step}: ALPHA ISSUE after encode: max={alpha_max:.2e}, NaN={alpha_has_nan}, Inf={alpha_has_inf}")
                else:
                    print(f"‚úÖ Step {self._gradient_debug_step}: ALPHA OK after encode: max={alpha_max:.2e}")
            if log_alpha_check is not None:
                log_alpha_max = torch.max(log_alpha_check).item()
                if log_alpha_max > 700:  # Close to exp overflow
                    print(f"‚ö†Ô∏è  Step {self._gradient_debug_step}: LOG_ALPHA very large: max={log_alpha_max:.2e} - potential exp(log_alpha) overflow!")
        # Alpha of the final latent layer. Unlike the .get() calls above, this raises
        # KeyError when the key is missing.
        alpha_tensor = self_attention_latent[-1]["alpha"]
        
        # TEMPORARY: Check alpha for overflow - DELETE AFTER USE
        if hasattr(self, '_gradient_debug_enabled') and self._gradient_debug_step % self._gradient_debug_frequency == 0:
            # Check all latent dict items for potential overflow
            # Focus on log_alpha (before exp) and alpha (after exp) to detect overflow
            alpha_items_to_check = {
                'log_alpha_before_exp': self_attention_latent[-1].get("log_alpha"),
                'alpha_after_exp': self_attention_latent[-1].get("alpha"),
                'pi': self_attention_latent[-1].get("pi"),
                'mu': self_attention_latent[-1].get("mu"),
                'logvar': self_attention_latent[-1].get("logvar")
            }
            check_alpha_overflow(alpha_items_to_check, f"LATENT_DICT_STEP_{self._gradient_debug_step}")
        
        # Diagnostic only: find sequences in which every alpha (after dropping the first
        # entry along the sequence axis) is <= 0.1.
        # NOTE(review): this runs on every call, not just when debugging is enabled, and
        # alpha_mask is NOT applied to src_key_padding_mask in this method (the line that
        # did so is commented out below), so the "CRITICAL" message describes a masking
        # step that does not happen here. generate() does still apply it.
        alpha_for_masking = alpha_tensor.squeeze(-1).transpose(0, 1)[:, 1:]  # batch-first; presumably [B, Ns], since generate() adds the same slice to the [B,Ns] padding mask — TODO confirm
        alpha_mask = alpha_for_masking.le(0.1)  # Boolean mask where alpha <= 0.1
        
        # True for each sequence whose alpha values are ALL <= 0.1
        all_masked_sequences = alpha_mask.all(dim=1)  # [B]
        
        if all_masked_sequences.any():
            problematic_batches = torch.nonzero(all_masked_sequences, as_tuple=False).squeeze(-1)
            print(f"üö® CRITICAL: Found {all_masked_sequences.sum().item()} sequences with ALL alpha <= 0.1!")
            print(f"Problematic batch indices: {problematic_batches.tolist()}")
            print(f"This will create all-inf attention mask rows leading to NaN after softmax!")
            
            # Show details for first few problematic sequences
            for i, batch_idx in enumerate(problematic_batches[:3]):
                alpha_values = alpha_for_masking[batch_idx]
                print(f"Batch {batch_idx.item()}: alpha values = {alpha_values}")
                print(f"Batch {batch_idx.item()}: min_alpha = {alpha_values.min().item():.6f}, max_alpha = {alpha_values.max().item():.6f}")
                print(f"Batch {batch_idx.item()}: src_key_padding_mask before = {src_key_padding_mask[batch_idx]}")
        
        # src_key_padding_mask was assigned unconditionally above, so the None test is
        # always true here; the inner check rejects fully-masked source sequences.
        if src_key_padding_mask is not None:
            all_true_sequences = torch.all(src_key_padding_mask, dim=-1)
            if torch.any(all_true_sequences):
                assert False, f"Found {torch.sum(all_true_sequences)} sequences with all True in src_key_padding_mask"
        
        # Disabled: combining the padding mask with the alpha mask.
        # src_key_padding_mask = src_key_padding_mask + alpha_mask
        # if alpha_mask is not None:
        #     all_true_sequences = torch.all(alpha_mask, dim=-1)
        #     if torch.any(all_true_sequences):
        #         assert False, f"Found {torch.sum(all_true_sequences)} sequences with all True in alpha_mask"
        
        # if src_key_padding_mask is not None:
        #     all_zero_sequences = torch.all(src_key_padding_mask, dim=-1)
        #     if torch.any(all_zero_sequences):
        #         assert False, f"Found {torch.sum(all_zero_sequences)} sequences with all 0s in src_key_padding_mask"
        

        # Soft weighting of vectors
        # memory = memory * self_attention_latent[-1]["pi"][1:, :, :]

        # latent layer (self.latent_layer is not defined in this class; presumably inherited
        # from Transformer — TODO confirm). Its output dict provides "z" and
        # "memory_key_padding_mask".
        latent_output_dict = self.latent_layer(memory, src_key_padding_mask)
        # Decode
        assert not torch.isnan(latent_output_dict["z"]).any(), (
            f"NaN values detected in latent_output_dict['z'] at {self.__class__.__name__}. "
            f"Shape: {latent_output_dict['z'].shape}, Max value: {torch.max(latent_output_dict['z'])}"
        )

        output, decoder_attention = self.decode(
            tgt=decoder_input_ids,  # [Nt,B] token ids
            z=latent_output_dict["z"],  # [Nl,B,H]
            tgt_key_padding_mask=tgt_key_padding_mask,  # [B,Nt]
            memory_key_padding_mask=latent_output_dict["memory_key_padding_mask"],  # [B,Nl]
            # latent_dict=self_attention_latent[-1],
        )  # logits [Nt,B,V]
        assert not torch.isnan(output).any(), (
            f"NaN values detected in output decode at {self.__class__.__name__}. "
            f"Shape: {output.shape}, Max value: {torch.max(output)}"
        )

        # TEMPORARY: Add backward hook to check gradients - DELETE AFTER USE
        # NOTE(review): a tensor hook on the logits fires when the gradient w.r.t. the logits
        # is computed, i.e. before backward has reached the model's parameters. The .grad
        # values inspected by check_gradients_simple are therefore likely those left over
        # from an earlier backward pass (or None) — confirm this is what is wanted.
        def backward_hook():
            if hasattr(self, '_gradient_debug_enabled') and self._gradient_debug_step % self._gradient_debug_frequency == 0:
                check_gradients_simple(self, f"AFTER_DECODE_STEP_{self._gradient_debug_step}")
        
        # The lambda returns None, so the gradient itself is left unchanged
        if hasattr(self, '_gradient_debug_enabled') and output.requires_grad:
            output.register_hook(lambda grad: backward_hook())
        # Note that "logits" already went through decode()'s final tanh bounding
        return {
            "logits": output,  # [Nt, B, V]
            "encoder_attentions": encoder_attention,  # Self attention
            "cross_attentions": decoder_attention,  # Cross attention
            "kl_gaussian": klg,
            "kl_dirichlet": kld,
            "latent_dict_list": self_attention_latent,
            "old_memory": old_memory,
            "old_memory_mask": ~(attention_mask.bool()),  # recomputed; same expression as src_key_padding_mask
        }

class NVIBSaTransformerLightning(Seq2SeqLightning):
    """
    Lightning wrapper that owns the tokenizer and an ``NVIBSaTransformer``, plus the KL
    weights and logging switches stored as attributes below.
    """

    def __init__(self, args, **kwargs):
        """
        :param args: namespace-like object; ``max_length``, ``klg_lambda``, ``kld_lambda``
            and ``weighted_kl`` are read here, and ``vars(args)`` is forwarded to the model
        :param kwargs: forwarded to the parent constructor
        """
        super().__init__(args, **kwargs)

        # Tokenizer (alternative kept for reference:
        # AutoTokenizer.from_pretrained("bert-base-uncased"))
        self.tokenizer = CharizardTokenizer(model_max_length=args.max_length)

        # Model. Gradient debugging stays off unless self.model.enable_gradient_debug() is called.
        self.model = NVIBSaTransformer(tokenizer=self.tokenizer, **vars(args))

        # TEMPORARY: Register cleanup for crashes - DELETE AFTER USE
        # NOTE(review): atexit keeps a reference to the bound method, and therefore to the
        # model, until the interpreter exits; re-running this in a notebook keeps every
        # earlier model alive.
        import atexit
        atexit.register(self.model.print_final_gradient_summary)

        # Nvib
        self.lambda_klg, self.lambda_kld = args.klg_lambda, args.kld_lambda
        self.weighted_kl = args.weighted_kl
        self.is_nvib = True
        self.model_type = "NVIBSaTransformer"

        # Logging metrics
        self.log_bleu = self.log_chrf = True
        self.plot_encoder_attention = self.plot_cross_attention = True

        # Initialization
        init_weights(self.model)

        self.save_hyperparameters()


In [4]:
class Args:
    # Paths + Naming
    experiment_name: str = "nvib-paper"
    project_name: str = "nvib_selfattention"
    output_dir: str = "outputs"
    entity: str = None

    # Data
    data: str = "wikitext"
    data_subset: str = "wikitext-2-raw-v1"
    num_workers: int = 17
    max_length: int = 128

    # Model
    model: str = "NVIBSaTransformer"

    # Training
    seed: int = 42
    fast_dev_run: bool = False
    fp16: bool = False

    # Transformer
    d_model: int = 512
    d_compress_model: int = 512
    nhead: int = 8
    num_encoder_layers: int = 3
    num_decoder_layers: int = 2
    dim_feedforward: int = 512
    dropout: float = 0.1

    # Token deletion probability
    deletion_prob: float = 0
    deletion_type: str = "token"  # choices: ["token", "word", "token_word", "None"]

    # NVIB
    num_nvib_encoder_layers: int = 1
    kappa: float = 1
    delta: float = 1
    klg_lambda: float = 0.001
    kld_lambda: float = 1
    kl_annealing_type: str = "constant"
    weighted_kl: bool = True

    # Learning rate + Schedulers
    learning_rate: float = 1e-4
    lr_scheduler: bool = False
    perc_warmup: float = 0
    batch_size: int = 1

    # PL Trainer
    max_time: str = None
    max_steps: int = 1000
    max_epochs: int = None
    accumulate_grad_batches: int = 1
    checkpoint_interval: int = 100
    validation_interval: int = None

    def __init__(self):
        # Make sure all attributes are properly initialized
        for attr_name in dir(self):
            if not attr_name.startswith('__'):
                setattr(self, attr_name, getattr(self, attr_name))

args = Args()


In [None]:
# Look up the Lightning wrapper class by the configured model name and instantiate it
# with the notebook's args (an unknown name raises KeyError).
model = dict(
    Transformer=TransformerLightning,
    NVIBSaTransformer=NVIBSaTransformerLightning,
)[args.model](args)

In [None]:
# Print whatever `model` currently refers to. When the cells run top to bottom, this is
# the wrapper built in the previous cell, before any checkpoint has been loaded.
print(model)

In [None]:
# Echo the hyper-parameters that matter most for the NVIB model
print(f"Model type: {args.model}")
print(
    "\n".join(
        f"{field_name}: {getattr(args, field_name)}"
        for field_name in (
            "d_model",
            "num_nvib_encoder_layers",
            "weighted_kl",
            "klg_lambda",
            "kld_lambda",
        )
    )
)

In [8]:
# Run directory of the experiment (machine-specific absolute path; keeps its trailing slash)
OUTPUT_PATH = "/project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug13/"
# Checkpoint resolved by the project helper get_checkpoint_path for that directory
CHECKPOINT_PATH = get_checkpoint_path(OUTPUT_PATH)
# best_model.ckpt sits directly inside the run directory
BEST_MODEL_PATH = OUTPUT_PATH + "best_model.ckpt"

In [None]:
# Pin one specific checkpoint by hand. This rebinds CHECKPOINT_PATH, replacing whatever an
# earlier cell assigned to it, so the order in which the cells run decides which checkpoint
# is loaded afterwards.
CHECKPOINT_PATH = "/project/phan/tqn/new_Adapter/nvib_selfattention/outputs/nvib_selfattention/nvib-debug13/epoch=50-step=8000.ckpt"
print(CHECKPOINT_PATH)

In [None]:
# create_or_load_model is not imported by name in this notebook (presumably it comes from
# `from utils import *` — TODO confirm). It receives the freshly built `model` and returns
# a pair; `model` is rebound to the first element, so re-running this cell passes the
# previously returned model back in.
# NOTE(review): the second element is presumably a Weights & Biases run id — confirm.
model, wandb_id = create_or_load_model(OUTPUT_PATH, CHECKPOINT_PATH, model, args)

In [None]:
# Note cell: how to load data without ReconstructionDataModule.
# Prerequisite: train.py must have been run so the tokenized / preprocessed data exists.
#
# NOTE(review): load_prepared_data_direct() is not defined or imported in any cell visible
# in this notebook, and a later cell imports ReconstructionDataModule anyway — confirm
# whether this advice is still current.
#
# Sketch of the intended call:
# model = your_model  # Your transformer model
# tokenizer = model.tokenizer
# data_loader = load_prepared_data_direct(
#     tokenizer=tokenizer,
#     name="train", 
#     data_name="wikitext",
#     model_name="Transformer"
# )

print("Use load_prepared_data_direct() instead of importing ReconstructionDataModule")


In [None]:
# Print `model` again. When the cells run top to bottom, this is the object returned by
# create_or_load_model rather than the freshly constructed wrapper.
print(model)

In [None]:
# Build the data module from the current model and args, then look at one training batch.
# (This import repeats the one in the NLTK setup cell near the top of the notebook.)
from data_modules.ReconstructionDataModule import ReconstructionDataModule, load_prepared_data
dict_args = vars(args)  # instance attributes of args as a plain dict
dm = ReconstructionDataModule(model, **dict_args)
# Tokenizer, dataset name and model name are taken from the data module itself
data_loader = load_prepared_data(
    tokenizer=dm.tokenizer,
    name="train",
    data_name=dm.data,
    model_name=dm.model_name,
)
# Print only the first batch; `batch` stays bound to it afterwards
for batch in data_loader:
    print(batch)
    break

In [36]:
from utils import show_attention, strip_after_token

In [None]:
# Reconstruct one sentence with greedy autoregressive decoding
input_text = "homarus gammarus, known as the european lobster or common lobster, is a species of clawed lobster from the eastern atlantic"
inputs = model.tokenizer(input_text,
                      return_tensors="pt")
print(inputs)

# Use the GPU when one is visible, otherwise fall back to CPU instead of crashing
generation_device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = inputs["input_ids"].transpose(0, 1).to(generation_device)  # [Ns,B]

# Move the model to the same device and switch to evaluation mode
model.model.to(generation_device)
model.model.eval()

# Inference only: without no_grad each of the 512 decoding steps would keep its autograd
# graph alive, and memory would grow with every generated token.
# NOTE(review): 512 new tokens is well above the tokenizer's model_max_length
# (args.max_length = 128) — confirm the positional encoding supports this length.
with torch.no_grad():
    generated_ids = model.model.generate(
        input_ids=input_ids,
        max_new_tokens=512
    )

# Decode the generated ids and cut each string after the tokenizer's separator token
# input_text = model.tokenizer.decode(input_ids.transpose(0, 1), skip_special_tokens=True)
decoded_text = model.tokenizer.batch_decode(generated_ids.transpose(0, 1))
decoded_text = strip_after_token(decoded_text, model.tokenizer.sep_token)
print(f"Input text: {input_text}")
print(f"Generated text: {decoded_text}")

In [None]:
# Round-trip check: decode the first tokenized input sequence back to text with special
# tokens skipped. Relies on `inputs` from a previously executed cell.
decoded = model.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)
print(f"Decoded text: {decoded}")

In [None]:
# Test the hypothesis: Compare training vs eval mode
input_text = "homarus gammarus, known as the european lobster or common lobster, is a species of clawed lobster from the eastern atlantic"
inputs = model.tokenizer(input_text, return_tensors="pt")
input_ids = inputs["input_ids"].transpose(0, 1).to('cuda')  # [Ns,B]
print(input_ids.transpose(0, 1))


def _generate_in_mode(training, token_ids, max_new_tokens=5):
    """Put the model in train/eval mode, then generate and decode without gradients.

    :param training: True for training mode, False for evaluation mode
    :param token_ids: source token ids [Ns,B]
    :param max_new_tokens: number of tokens to generate
    :return: tuple of generated ids and the list of decoded strings
    """
    model.model.train(training)  # train(False) is what eval() does
    with torch.no_grad():
        new_ids = model.model.generate(
            input_ids=token_ids,
            max_new_tokens=max_new_tokens
        )
        texts = model.tokenizer.batch_decode(new_ids.transpose(0, 1), skip_special_tokens=True)
    return new_ids, texts


print("=== TRAINING MODE (like validation) ===")
generated_ids_train, decoded_text_train = _generate_in_mode(True, input_ids)
print(f"Training mode output: {decoded_text_train}")

print("\n=== EVAL MODE (deterministic) ===")
generated_ids_eval, decoded_text_eval = _generate_in_mode(False, input_ids)
print(f"Eval mode output: {decoded_text_eval}")

# The model is left in eval mode after this cell
print("\n=== COMPARISON ===")
print(f"Same output? {decoded_text_train == decoded_text_eval}")
