In [2]:
# Grouping geometry: B groups ("batches") of I sentences each.
B = 2
I = 5

# Ten toy sentences, in order, so they fill exactly B * I slots.
sentences_list = [
    "This is the first sentence.", "This is the second sentence.",
    "This is the third sentence.", "This is the fourth sentence.",
    "This is the fifth sentence.", "This is the sixth sentence.",
    "This is the seventh sentence.", "This is the eighth sentence.",
    "This is the ninth sentence.", "This is the tenth sentence.",
]

# Walk the flat list in strides of I and cut out one group per stride.
sentences_reshaped = [
    sentences_list[start:start + I]
    for start in range(0, B * I, I)
]

sentences_reshaped


[['This is the first sentence.',
  'This is the second sentence.',
  'This is the third sentence.',
  'This is the fourth sentence.',
  'This is the fifth sentence.'],
 ['This is the sixth sentence.',
  'This is the seventh sentence.',
  'This is the eighth sentence.',
  'This is the ninth sentence.',
  'This is the tenth sentence.']]

In [1]:
from sentence_transformers import SentenceTransformer
import torch

# Checkpoint to load and the folder handed to SentenceTransformer as its cache.
st_model_name = 'sentence-transformers/all-mpnet-base-v2'
cache_folder = '/home/anwoy/phuhoang/ped/saved_models/sentence-transformers/all-mpnet-base-v2'

# Use the first CUDA device when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Forwarded below as `local_files_only` and as the `torch_dtype` model kwarg.
local_file_only = True
torch_dtype = torch.float32

sbert = SentenceTransformer(
    st_model_name,
    device=device,
    cache_folder=cache_folder,
    local_files_only=local_file_only,
    model_kwargs=dict(torch_dtype=torch_dtype),
)

In [5]:
# encode() yields one embedding row per input sentence, i.e. a flat
# (num_sentences, embedding_dim) tensor -- not a (B, I, dim) one.
sentence = sbert.encode(sentences_list, convert_to_tensor=True)
print(sentence.shape)  # torch.Size([10, 768]) for the 10 sentences above
print(sentence)

# view(B, I, -1) would silently fold extra sentences into the last axis if the
# row count were some other multiple of B * I, so pin the count first.
assert sentence.shape[0] == B * I, (
    f"expected {B * I} sentence embeddings, got {sentence.shape[0]}"
)
reshaped_sentence = sentence.view(B, I, -1)  # Regroup rows into (B, I, embedding_dim)
print(reshaped_sentence.shape)  # Should print torch.Size([2, 5, 768])
print(reshaped_sentence)

torch.Size([10, 768])
tensor([[ 5.1659e-02, -5.9793e-02,  2.1608e-03,  ..., -1.0040e-02,
         -4.2099e-02, -2.0275e-02],
        [ 6.1750e-02, -3.6361e-02,  5.6251e-04,  ..., -1.3233e-02,
         -4.6793e-02, -2.9813e-02],
        [ 9.5577e-02, -6.4551e-02, -4.1595e-03,  ...,  1.9570e-04,
         -6.7684e-02, -2.2272e-02],
        ...,
        [ 3.1220e-02,  3.3037e-02, -3.3743e-03,  ...,  1.9906e-02,
         -9.1838e-02, -2.3821e-02],
        [ 2.1601e-02,  1.4657e-02,  8.7469e-05,  ...,  1.4291e-02,
         -6.8486e-02, -1.6401e-02],
        [ 2.1405e-02,  5.6179e-03, -6.9115e-03,  ...,  3.6593e-02,
         -8.4533e-02, -1.7980e-02]], device='cuda:0')
torch.Size([2, 5, 768])
tensor([[[ 5.1659e-02, -5.9793e-02,  2.1608e-03,  ..., -1.0040e-02,
          -4.2099e-02, -2.0275e-02],
         [ 6.1750e-02, -3.6361e-02,  5.6251e-04,  ..., -1.3233e-02,
          -4.6793e-02, -2.9813e-02],
         [ 9.5577e-02, -6.4551e-02, -4.1595e-03,  ...,  1.9570e-04,
          -6.7684e-02, -2.2

# Test new model

In [63]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
# import lightning as L
from dataclasses import dataclass
# from torch.optim.lr_scheduler import CosineAnnealingLR
from sentence_transformers import SentenceTransformer
# from lightning.pytorch.utilities import grad_norm

# from model.transformer_blocks import Block, CausalSelfAttention, MLP

In [64]:
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position only sees itself and earlier positions."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # every head must receive an equal slice of the channel dimension
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head

        # one fused projection producing queries, keys and values, then the output projection
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        # marker attribute; the owning models' _init_weights look for it via hasattr
        self.c_proj.SCALE_INIT = 1

    def forward(self, x):
        """Apply causal attention to x of shape (B, T, C) and return a tensor of the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        attended = F.scaled_dot_product_attention(
            to_heads(query), to_heads(key), to_heads(value), is_causal=True
        )
        # (B, n_head, T, head_dim) -> (B, T, C): lay the heads side by side again
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Feed-forward sub-layer: widen by mlp_ratio, apply tanh-approximated GELU, project back."""

    def __init__(self, n_embd, mlp_ratio=4.0):
        super().__init__()
        # hidden width is n_embd * mlp_ratio truncated to an int
        self.c_fc = nn.Linear(n_embd, int(n_embd * mlp_ratio))
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(self.c_fc.out_features, n_embd)
        # marker attribute; the owning models' _init_weights look for it via hasattr
        self.c_proj.SCALE_INIT = 1

    def forward(self, x):
        """Map x (..., n_embd) to a tensor of the same shape."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention, then an MLP, each behind a residual add.

    This variant takes no 'idea_vector' argument; the block only ever sees x.
    """

    def __init__(self, n_embd, n_head, mlp_ratio=4.0, norm_layer=nn.LayerNorm):
        super().__init__()
        # attention sub-layer and its normalisation
        self.ln_1 = norm_layer(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head)
        # feed-forward sub-layer and its normalisation
        self.ln_2 = norm_layer(n_embd)
        self.mlp = MLP(n_embd, mlp_ratio)

    def forward(self, x):
        """Return x after both residual sub-layers; the shape is unchanged."""
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

In [65]:
class PretrainedIdearEncoder(nn.Module):
    """Wraps a pretrained SentenceTransformer and exposes it as an inference-only idea encoder.

    Args:
        st_model_name (str): Name or path handed to SentenceTransformer.
    """

    def __init__(self, st_model_name='sentence-transformers/all-mpnet-base-v2'):
        super().__init__()
        # Load the Sentence Transformer model
        self.sbert = SentenceTransformer(st_model_name)
        # Ask the loaded model for its embedding width instead of assuming 768,
        # so the attribute stays correct when a different st_model_name is passed.
        self.sbert_dim = self.sbert.get_sentence_embedding_dimension()

    def forward(self, sentence_list):
        """Encode sentences into idea embeddings without tracking gradients.

        Args:
            sentence_list: Sentences to embed; passed straight to SentenceTransformer.encode.
        Returns:
            torch.Tensor: The embeddings returned by encode(..., convert_to_tensor=True);
            for a list of N sentences this is presumably (N, sbert_dim) -- confirm
            against the sentence-transformers version in use.
        """
        with torch.no_grad():
            idea_vecs = self.sbert.encode(sentence_list, convert_to_tensor=True)
        return idea_vecs

In [66]:
class IdeaPredictor(nn.Module):
    """Causal transformer that maps a sequence of idea embeddings to a sequence of the same shape."""

    def __init__(self, 
                block_size=32,
                input_dim=768,
                predictor_embed_dim=384,
                predictor_depth=12,
                predictor_num_heads=12,
                mlp_ratio=4.0,
                norm_layer=nn.LayerNorm,
                max_ideas=512,
                ):
        """
        Initialize the IdeaPredictor model.
        Args:
            block_size (int): The maximum number of sequences for the predictor. Stored on the
                instance only; the positional table is sized by max_ideas, not by this value.
            input_dim (int): The dimension of the input idea embeddings.
            predictor_embed_dim (int): The embedding dimension for the predictor.
            predictor_depth (int): The number of transformer blocks in the predictor.
            predictor_num_heads (int): The number of attention heads in each transformer block.
            mlp_ratio (float): The ratio of the hidden dimension to the embedding dimension in the MLP.
            norm_layer (nn.Module): The normalization layer to use inside each block.
            max_ideas (int): Number of rows in the positional embedding table, i.e. the longest
                idea sequence forward() accepts. Defaults to 512.
        """
        super().__init__()
        self.block_size = block_size
        self.input_dim = input_dim
        self.predictor_embed_dim = predictor_embed_dim
        self.predictor_depth = predictor_depth
        self.predictor_num_heads = predictor_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.max_ideas = max_ideas

        # Project from idea encoder dimension to predictor dimension (and back again on the
        # way out). When both widths already agree the projections are no-ops.
        if predictor_embed_dim != input_dim:
            self.idea_proj = nn.Linear(input_dim, predictor_embed_dim, bias=False)            
            self.predictor_proj = nn.Linear(predictor_embed_dim, input_dim, bias=False)
            # Tag both projections so _init_weights applies the depth-scaled std to them.
            self.predictor_proj.SCALE_INIT = 1
            self.idea_proj.SCALE_INIT = 1
        else:
            self.idea_proj = nn.Identity()
            self.predictor_proj = nn.Identity()

        # Learned positional embeddings, one row per idea position (up to max_ideas).
        self.predictor_wpe = nn.Embedding(max_ideas, predictor_embed_dim)
        self.predictor_blocks = nn.ModuleList([
            Block(
                n_embd=predictor_embed_dim,
                n_head=predictor_num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
            for _ in range(predictor_depth)
        ])

        self.predictor_ln = nn.LayerNorm(predictor_embed_dim)

        # No weight tying here: tying shares one matrix between input and output projections,
        # and as idea prediction is a regression task, we do not use weight tying.

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); SCALE_INIT layers get a smaller std."""
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                # max(1, ...) keeps a zero-depth predictor from raising ZeroDivisionError;
                # for any depth >= 1 the factor is exactly (2 * depth) ** -0.5 as before.
                std *= (2 * max(1, self.predictor_depth)) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idea_vecs, target=None):
        """
        Forward pass for the IdeaPredictor model.
        Args:
            idea_vecs (torch.Tensor): Input idea embeddings of shape (B, I, input_dim), where B is the batch size and I is the number of ideas.
            target (torch.Tensor, optional): Target values for training. Not used in this implementation.
        Returns:
            torch.Tensor: Output of shape (B, I, input_dim).
        Raises:
            ValueError: If I exceeds the number of rows in the positional embedding table.
        """
        B, I = idea_vecs.size(0), idea_vecs.size(1)

        # Fail early with a readable message rather than an out-of-range embedding lookup.
        max_positions = self.predictor_wpe.num_embeddings
        if I > max_positions:
            raise ValueError(
                f"Cannot process {I} ideas: the positional embedding table only has "
                f"{max_positions} rows"
            )

        # Project the idea embeddings
        idea_emb = self.idea_proj(idea_vecs)  # shape (B, I, predictor_embed_dim)

        # Add positional embeddings; the (I, D) table slice broadcasts over the batch axis
        pos_ids = torch.arange(0, I, device=idea_vecs.device)  # length I
        pos_emb = self.predictor_wpe(pos_ids)  # shape (I, predictor_embed_dim)
        hidden = idea_emb + pos_emb.unsqueeze(0)  # (B, I, predictor_embed_dim)

        # Forward through predictor blocks
        for block in self.predictor_blocks:
            hidden = block(hidden)
        hidden = self.predictor_ln(hidden)

        # Project back to the original dimension
        hidden = self.predictor_proj(hidden)

        return hidden

In [67]:
class IdeaDecoder(nn.Module):
    """Causal token decoder conditioned on one idea embedding that is prepended to the sequence."""

    def __init__(self,
                block_size=32,
                vocab_size=50304,
                input_dim=768,
                decoder_embed_dim=384,
                decoder_depth=6,
                decoder_num_heads=12,
                mlp_ratio=4.0,
                norm_layer=nn.LayerNorm,
                ):
        """ Initialize the IdeaDecoder model.
        Args:
            block_size (int): The maximum number of tokens for the decoder. The positional
                table is allocated with block_size + 1 rows.
            vocab_size (int): The size of the vocabulary for the decoder.
            input_dim (int): The dimension of the input idea embeddings.
            decoder_embed_dim (int): The embedding dimension for the decoder.
            decoder_depth (int): The number of transformer blocks in the decoder.
            decoder_num_heads (int): The number of attention heads in each transformer block.
            mlp_ratio (float): The ratio of the hidden dimension to the embedding dimension in the MLP.
            norm_layer (nn.Module): The normalization layer to use inside each block.
        """
        super().__init__()
        self.block_size = block_size
        self.input_dim = input_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vocab_size = vocab_size

        # Project from idea encoder dimension to decoder dimension.
        # NOTE(review): decoder_proj is created and initialised but forward() never applies
        # it; it is kept so parameter names and state_dict keys stay unchanged.
        if decoder_embed_dim != input_dim:
            self.idea_proj = nn.Linear(input_dim, decoder_embed_dim, bias=False)            
            self.decoder_proj = nn.Linear(decoder_embed_dim, input_dim, bias=False)
            # Tag both projections so _init_weights applies the depth-scaled std to them.
            self.decoder_proj.SCALE_INIT = 1
            self.idea_proj.SCALE_INIT = 1
        else:
            self.idea_proj = nn.Identity()
            self.decoder_proj = nn.Identity()

        self.decoder_wte = nn.Embedding(vocab_size, decoder_embed_dim)
        self.decoder_wpe = nn.Embedding(block_size+1, decoder_embed_dim)
        self.block = nn.ModuleList([
            Block(
                n_embd=decoder_embed_dim,
                n_head=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
            for _ in range(decoder_depth)
        ])

        self.decoder_ln = nn.LayerNorm(decoder_embed_dim)
        self.decoder_lm_head = nn.Linear(decoder_embed_dim, vocab_size, bias=False)
        # Weight tying: the token embedding and the LM head share one parameter tensor.
        self.decoder_wte.weight = self.decoder_lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); SCALE_INIT layers get a smaller std."""
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                # max(1, ...) keeps a zero-depth decoder from raising ZeroDivisionError;
                # for any depth >= 1 the factor is exactly (2 * depth) ** -0.5 as before.
                std *= (2 * max(1, self.decoder_depth)) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, idea_vec, target=None):
        """
        Forward pass for the IdeaDecoder model.
        Args:
            x (torch.Tensor): Input token indices of shape (B*I, T), where B is the batch size and I is the number of ideas.
            idea_vec (torch.Tensor): Input idea embeddings of shape (B*I, input_dim).
            target (torch.Tensor, optional): Target values for training. Not used in this implementation.
        Returns:
            logits (torch.Tensor): Logits for next-word prediction of shape (B*I, T, vocab_size).
        Raises:
            ValueError: If T exceeds the number of rows in the positional embedding table.
        """
        _, T = x.size()

        # Fail early with a readable message rather than an out-of-range embedding lookup.
        # The limit is the table size, so every length that used to work still works.
        max_positions = self.decoder_wpe.num_embeddings
        if T > max_positions:
            raise ValueError(
                f"Cannot decode a sequence of length {T}: the positional embedding table "
                f"only has {max_positions} rows (block_size={self.block_size})"
            )

        # Project the idea embeddings
        idea_vec = self.idea_proj(idea_vec)

        # forward the token and positional embeddings
        token_emb = self.decoder_wte(x)  # (B*I, T, decoder_embed_dim)
        pos_ids = torch.arange(0, T, device=x.device)
        pos_emb = self.decoder_wpe(pos_ids)       
        # The idea vector becomes one extra leading "token"; it receives no positional embedding.
        idea_token = idea_vec.unsqueeze(1)                 # (B*I, 1, D)
        token_and_pos = token_emb + pos_emb               # (B*I, T, D)
        hidden = torch.cat([idea_token, token_and_pos], 1) # (B*I, T+1, D)

        # Forward through decoder blocks
        for block in self.block:
            hidden = block(hidden)

        # Normalize the hidden states
        hidden = self.decoder_ln(hidden)

        # skip the idea token when producing logits for next-word prediction
        logits = self.decoder_lm_head(hidden[:, 1:, :])  # shape (B*I, T, vocab_size)

        return logits

In [68]:
@dataclass
class iGPTConfig:
    """Hyperparameters for iGPT; every field has a default and can be overridden per instance.

    The id_pred_* fields configure the idea predictor, the id_dec_* fields the idea decoder.
    """
    # idea predictor
    id_pred_block_size: int = 32
    id_pred_n_layer: int = 6
    id_pred_n_head: int = 12
    id_pred_n_embd: int = 384
    # idea decoder
    id_dec_block_size: int = 32
    id_dec_n_layer: int = 6
    id_dec_n_head: int = 12
    id_dec_n_embd: int = 384
    # shared
    mlp_ratio: float = 4.0
    vocab_size: int = 50304

In [69]:
class iGPT(nn.Module):
    """
    GPT model that prepends a single 'idea token' derived from
    all-mpnet-base-v2 embeddings at the start of the sequence.

    The model holds three parts: a SentenceTransformer (kept frozen and in eval
    mode; forward() does not call it because it receives pre-computed
    embeddings), an IdeaPredictor over the idea sequence, and an IdeaDecoder
    that turns each predicted idea plus its tokens into next-token logits.
    """

    def __init__(self, config: iGPTConfig, st_model_name='sentence-transformers/all-mpnet-base-v2'):
        """
        Args:
            config (iGPTConfig): Sizes for the idea predictor and idea decoder.
            st_model_name (str): Name or path handed to SentenceTransformer.
        """
        super().__init__()
        self.config = config

        # -- idea encoder
        # Load the Sentence Transformer model
        self.sbert = SentenceTransformer(st_model_name)
        self.sbert.eval()  # Set to evaluation mode
        # forward() never routes through sbert, so its weights can never receive a
        # gradient; mark them frozen so they are not reported as trainable parameters.
        self.sbert.requires_grad_(False)
        self.sbert_dim = self.sbert.get_sentence_embedding_dimension()  # typically 768
        
        # -- idea predictor
        self.idea_predictor = IdeaPredictor(
            block_size=config.id_pred_block_size,
            input_dim=self.sbert_dim,
            predictor_embed_dim=config.id_pred_n_embd,
            predictor_depth=config.id_pred_n_layer,
            predictor_num_heads=config.id_pred_n_head,
            mlp_ratio=config.mlp_ratio,
            norm_layer=nn.LayerNorm
        )

        # -- idea decoder
        self.idea_decoder = IdeaDecoder(
            block_size=config.id_dec_block_size,
            vocab_size=config.vocab_size,
            input_dim=self.sbert_dim,
            decoder_embed_dim=config.id_dec_n_embd,
            decoder_depth=config.id_dec_n_layer,
            decoder_num_heads=config.id_dec_n_head,
            mlp_ratio=config.mlp_ratio,
            norm_layer=nn.LayerNorm
        )

    def train(self, mode: bool = True):
        """Switch train/eval mode as usual, but always leave the sentence encoder in eval mode.

        Without this, a plain model.train() would undo the sbert.eval() done in __init__.
        """
        super().train(mode)
        self.sbert.eval()
        return self

    def forward(self, xt, xs):
        """
        Forward pass for the iGPT model.
        Args:
            xt (torch.Tensor): Input token indices of shape (B*I, T), where B is the batch size and I is the number of ideas.
            xs (torch.Tensor): Input idea embeddings of shape (B, I, sbert_dim) - pre-computed embeddings instead of sentences.
        Returns:
            yt (torch.Tensor): Logits for next-word prediction of shape (B*I, T, vocab_size).
            ys (torch.Tensor): Idea embeddings of shape (B*I, sbert_dim).
            idea_vecs (torch.Tensor): Original idea embeddings of shape (B, I, sbert_dim).
        Raises:
            ValueError: If xt does not have exactly B*I rows.
        """
        batch_size, num_ideas, _ = xs.shape  # Get actual batch size and number of ideas
        seq_length = xt.shape[1]

        # Each idea owns exactly one token row; catch a mismatch here instead of deep
        # inside the decoder's concatenation.
        if xt.shape[0] != batch_size * num_ideas:
            raise ValueError(
                f"xt has {xt.shape[0]} rows but xs describes "
                f"{batch_size} x {num_ideas} = {batch_size * num_ideas} ideas"
            )
        
        # NOTE(review): shape tracing printed on every call; consider removing once
        # the sanity check is no longer needed.
        print(f"xt shape: {xt.shape}")
        print(f"xs shape: {xs.shape}")
        print(f"B: {batch_size}, I: {num_ideas}, T: {seq_length}")

        # xs is already the idea embeddings tensor, no need for sentence encoding
        idea_vecs = xs  # shape (B, I, sbert_dim)
        
        # Pass to idea predictor
        ys = self.idea_predictor(idea_vecs)  # shape (B, I, sbert_dim)
        
        # Flatten ys to match xt's batch dimension: (B, I, sbert_dim) -> (B*I, sbert_dim)
        ys_flat = ys.reshape(batch_size * num_ideas, -1)

        # Pass to idea decoder
        yt = self.idea_decoder(xt, ys_flat)

        return yt, ys_flat, idea_vecs

In [70]:
class iGPTConfig:
    """Tiny configuration for the sanity check; rebinds the iGPTConfig name defined earlier."""

    # The idea predictor and the idea decoder share the same miniature geometry.
    id_pred_block_size = id_dec_block_size = 2
    id_pred_n_layer = id_dec_n_layer = 2
    id_pred_n_head = id_dec_n_head = 4
    id_pred_n_embd = id_dec_n_embd = 128

    mlp_ratio = 4
    # Reduced from 50304 for sanity check
    vocab_size = 1000

In [71]:
def test_igpt():
    """Smoke-test iGPT: run one forward pass on random inputs and check the output shapes.

    Returns:
        iGPT: The freshly constructed model, so the caller can inspect it.
    """
    # Small test configuration
    test_config = iGPTConfig()
    
    # Initialize model
    model = iGPT(test_config)
    
    # Create sample inputs
    batch_size = 2
    num_ideas = 5
    seq_length = test_config.id_dec_block_size  # FIX: The sequence length must not exceed the configured block size
    # Take the width from the model rather than repeating a literal 768, so the
    # fake embeddings always match what the predictor was built for.
    sbert_dim = model.sbert_dim
    
    sample_tokens = torch.randint(0, test_config.vocab_size, (batch_size * num_ideas, seq_length))
    sample_idea_embeddings = torch.randn(batch_size, num_ideas, sbert_dim)  # Random embeddings instead of sentences
    
    # Forward pass
    with torch.no_grad():
        logits, embeddings, _ = model(sample_tokens, sample_idea_embeddings)
    
    print(f"Output logits shape: {logits.shape}")
    print(f"Output embeddings shape: {embeddings.shape}")

    # Turn the printed shapes into actual checks so a regression fails loudly.
    assert logits.shape == (batch_size * num_ideas, seq_length, test_config.vocab_size)
    assert embeddings.shape == (batch_size * num_ideas, sbert_dim)
    
    return model

In [72]:
# Run the smoke test. The returned model is the cell's last expression, so the notebook displays its repr.
test_igpt()

xt shape: torch.Size([10, 2])
xs shape: torch.Size([2, 5, 768])
B: 2, I: 5, T: 2
Output logits shape: torch.Size([10, 2, 1000])
Output embeddings shape: torch.Size([10, 768])


iGPT(
  (sbert): SentenceTransformer(
    (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
    (2): Normalize()
  )
  (idea_predictor): IdeaPredictor(
    (idea_proj): Linear(in_features=768, out_features=128, bias=False)
    (predictor_proj): Linear(in_features=128, out_features=768, bias=False)
    (predictor_wpe): Embedding(512, 128)
    (predictor_blocks): ModuleList(
      (0-1): 2 x Block(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=True)
          (c_proj): Linear(in_features=128, out_features=128, bias=True