In [1]:
len(['p', ' ', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '?'])

30

'?'

In [17]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch.nn.functional as F
class Codebook(torch.nn.Module):
    def __init__(self, vocab, model_name="meta-llama/Llama-3.2-1B-Instruct"):
        super(Codebook, self).__init__()
        self.vocab = vocab
        self.model_name = model_name
        # Initialize the LLM discriminator and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        
        # Set EOS token as the padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        # Freeze LLM parameters
        for param in self.llm.parameters():
            param.requires_grad = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, embed):
        """
        Args:
            input_ids: [batch, seq_len] token IDs
        Returns:
            perplexity: scalar reward (lower is better)
        """
        with torch.no_grad():  # Ensure LLM stays frozen
            outputs = self.llm(inputs_embeds=embed, labels=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[:, :-1, :]       # [B, T-1, V]
            target = input_ids[:, 1:]             # [B, T-1]
            mask = attention_mask[:, 1:].float()  # [B, T-1]

            # Cross-entropy loss per token
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), 
                target.reshape(-1), 
                reduction='none'
            ).reshape(target.shape)  # [B, T-1]

            # Mask pad tokens and compute mean loss
            masked_loss = loss * mask  # [B, T-1]
            token_count = mask.sum(dim=1)  # [B]
            
            loss = masked_loss.sum(dim=1) / token_count  
            perplexity = torch.exp(loss)    

            return loss, perplexity, outputs.loss

In [21]:
sentence = ["FV D HEFYY", "We do not remember how the forest fell, only that fire burned close behind. Children cried, others ran, and silence grew louder in the smoky air. Desperate men shouted into the trees, but no answer came. In that fading light, everything felt final—like memory slipping through a door we couldn't close."]


# function to convert a string to a white space separated string ignoreing the spaces
def convert_to_whitespace_string(string):
    # Remove spaces and join characters with whitespace
    return ' '.join(string)

sentence = [convert_to_whitespace_string(s.upper()) for s in sentence]

print(sentence)
# Initialize Codebook discriminator
codebook = Codebook(vocab=None)
tokenizer = codebook.tokenizer

# Tokenize sentence to get input IDs
tokenized = tokenizer(sentence, padding=True, return_tensors='pt', add_special_tokens=False)
embed = codebook.llm.get_input_embeddings()(tokenized['input_ids'])

codebook(**tokenized, embed=embed)


['F V   D   H E F Y Y', "W E   D O   N O T   R E M E M B E R   H O W   T H E   F O R E S T   F E L L ,   O N L Y   T H A T   F I R E   B U R N E D   C L O S E   B E H I N D .   C H I L D R E N   C R I E D ,   O T H E R S   R A N ,   A N D   S I L E N C E   G R E W   L O U D E R   I N   T H E   S M O K Y   A I R .   D E S P E R A T E   M E N   S H O U T E D   I N T O   T H E   T R E E S ,   B U T   N O   A N S W E R   C A M E .   I N   T H A T   F A D I N G   L I G H T ,   E V E R Y T H I N G   F E L T   F I N A L — L I K E   M E M O R Y   S L I P P I N G   T H R O U G H   A   D O O R   W E   C O U L D N ' T   C L O S E ."]


(tensor([5.3233, 2.2565]), tensor([205.0494,   9.5500]), tensor(4.7987))

In [19]:
sentence

['F V   D   H E   O R N E C M R N H T W D O T A F   T A L   F Y Y B   C I N E   O I L O N   C F   C E   E C I B E A I   D E A V I A E U S   T M S   U N E T M R   R S T I L U N   E T F Y T L E N   O T W N R   O H L G H Y B Y   T E H N E   H T F   I   I A   R H T P G K H P D S K H O   O W H L O S G K H L Y   O E A E Y P H E R   N E O O W   C O T N T F I   D   R O N E S T   O H R S   I   I R Y R U E L   O N T   Y N   A R I N R T B N   H C F   O F   D E P I P T E   T R   A C K E A B G H T M R N D O   K A O S A V',
 "W E   D O   N O T   R E M E M B E R   H O W   T H E   F O R E S T   F E L L ,   O N L Y   T H A T   F I R E   B U R N E D   C L O S E   B E H I N D .   C H I L D R E N   C R I E D ,   O T H E R S   R A N ,   A N D   S I L E N C E   G R E W   L O U D E R   I N   T H E   S M O K Y   A I R .   D E S P E R A T E   M E N   S H O U T E D   I N T O   T H E   T R E E S ,   B U T   N O   A N S W E R   C A M E .   I N   T H A T   F A D I N G   L I G H T ,   E V E R Y T H I N G   F E L T 

In [7]:
import torch
from torch.nn.utils.rnn import pad_sequence

target = torch.tensor([
    [1, 9, 9, 9, 5, 6, 7, 8, 9],
    [1, 2, 3, 4, 5, 6, 7, 9, 9]
], dtype=torch.long)

target

tensor([[1, 9, 9, 9, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 9, 9]])

In [8]:
# Create a mask based on the valid indices
valid_mask = target != 9

valid_mask

tensor([[ True, False, False, False,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False]])

In [9]:
target = [target[i][valid_mask[i]] for i in range(target.size(0))]
target

[tensor([1, 5, 6, 7, 8]), tensor([1, 2, 3, 4, 5, 6, 7])]

In [10]:

# Pad the sequences for each tensor
target = pad_sequence(target, batch_first=True, padding_value=0)
target

tensor([[1, 5, 6, 7, 8, 0, 0],
        [1, 2, 3, 4, 5, 6, 7]])

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
from tqdm.auto import tqdm  
import logging
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt


class Dataset_txt(Dataset):
    def __init__(self, data="/raid/home/rajivratn/hemant_rajivratn/last/data/transcription.txt"):
        super(Dataset_txt, self).__init__()

        
        with open(data, "r") as f:
            out = f.readlines()
        texts = [x.strip() for x in tqdm(out) if len(x) > 10] # filtering out short texts that 2 second.
    
   
        self.save_histogram(texts)
   
    def save_histogram(self, texts):

        
        texts = self.add_question_marks(texts)

        print(f"Saving histogram of the REAL text data.")
        char_counts = Counter("".join(texts))  # Example output: [('a', 2), ('d', 1)]
        char_counts = dict(char_counts)
        print(f"char_counts: {char_counts}")
        c = [char_counts[v] for v in self.vocab if v not in ["p"]]  # Exclude padding and silence tokens
        c = np.array(c, dtype=np.float32)
        c /= c.sum()  # Normalize the counts to get probabilities
        
        self.prior = c # save the counts as prior for kl loss.
        
        # Plotting the histogram
        plt.figure(figsize=(10, 6))
        plt.bar(self.vocab[1:], c, color='blue', alpha=0.7)
        plt.xlabel('Codebook Entry (Char)')
        plt.ylabel('Probability')
        plt.title('Codebook Usage Distribution')
        plt.grid(axis='y')
        plt.savefig('REAL_codebook_usage_distribution.png', bbox_inches='tight')
        
        
    def add_question_marks(self, texts=[]):
        print(f"Preprocessing the text data by adding silence tokens.")
        
        modified_texts = []
        for sentence in tqdm(texts):
            modified_sentence = ['?']# Add question marks at start 
            previous_char = None
            for char in sentence:
                # if  char == previous_char insert a question mark
                if previous_char == char:
                    modified_sentence.append("?")
                
                modified_sentence.append(char)
    
                # Randomly insert question marks with 0.25 probability
                if random.random() < 0.25 and modified_sentence[-1] != '?':
                    modified_sentence.append("?")
                
                previous_char = char
                    
            if modified_sentence[-1] != '?': 
                modified_sentence.append("?")  # Add a question mark at the end
            modified_texts.append("".join(modified_sentence))
        print(f"Preprocessing done.")
        print(f"Modified text sample")
        print(f"{random.choice(modified_texts)}")
        print(f"{random.choice(modified_texts)}")
        
        return  modified_texts

    def build_vocab(self, texts):
        """
        Creates a sorted list of unique characters with special tokens.
        special_tokens = ["p", "?"]  # "p" = PAD, "?" = silence
        """
        unique_chars = sorted(set("".join(texts)))
        return ["p"] + unique_chars + ["?"]
    
    def encode(self, text):
        """Encodes text into a list of indices."""
        return [self.char_to_idx[char] for char in text]

    def decode(self, indices, keep_special_tokens=False):
        """Decodes indices back into text, removing all special tokens."""
        if keep_special_tokens:
            return "".join(self.idx_to_char[idx] for idx in indices)
        return "".join(self.idx_to_char[idx] for idx in indices if self.idx_to_char[idx] not in {"p", "?"})

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        result = ["?"]
        prev_char = ""

        for char in text:
            if char == prev_char:
                result.append("?")
            result.append(char)

            # Slightly more efficient: only check random if not already '?'
            if result[-1] != "?" and random.random() < 0.25:
                result.append("?")

            prev_char = char

        # Ensure it ends with a question mark
        if result[-1] != "?":
            result.append("?")

        modified_text = "".join(result)
        input_ids = self.encode(modified_text)
        return input_ids
        
    def collate_fn(self, batch):
        inp = [item for item in batch]
        pad_token_id = self.char_to_idx['p']
        max_length = max(len(seq) for seq in inp)

        # Pad sequences
        def pad_sequence(seq, max_length):
            return seq + [pad_token_id] * (max_length - len(seq))

        inp = torch.tensor([pad_sequence(seq, max_length) for seq in inp], dtype=torch.long)
        mask = torch.tensor([[False] * len(seq) + [True] * (max_length - len(seq)) for seq in batch], dtype=torch.bool)
    
        return inp, mask.unsqueeze(-1)
        
dataset = Dataset_txt(data="/raid/home/rajivratn/hemant_rajivratn/last/data/txt/train_norm.txt")