In [None]:
from torch.utils.data import Dataset

class CharLMDataset(Dataset):
    """Character-level language-modelling dataset read from a transcription file.

    Each non-blank line of the file is expected to contain at least two
    tab-separated fields; the second field is used as the sentence text.
    The kept sentences are ordered by length (shortest first) and a character
    vocabulary is built from them.

    Attributes:
        sentences: Kept sentences, sorted by length.
        texts: Alias of ``sentences`` (the same list object).
        vocab: Sorted list of the unique characters found in ``sentences``.
        char_to_idx: Mapping from character to its index in ``vocab``.
        idx_to_char: Inverse mapping of ``char_to_idx``.
    """

    def __init__(self, data="../data/transcription.txt", *, min_length=10):
        """Load, filter and index the sentences stored in ``data``.

        Args:
            data (str): Path to the tab-separated transcription file.
            min_length (int): Only sentences strictly longer than this many
                characters are kept. Defaults to 10.

        Raises:
            OSError: If the file cannot be opened.
            IndexError: If a non-blank line has no tab-separated second field.
        """
        super().__init__()

        kept = []
        # Explicit encoding so the result does not depend on the platform locale.
        with open(data, "r", encoding="utf-8") as f:
            for line in f:
                # Blank lines (e.g. a trailing empty line) carry no sentence;
                # skip them instead of failing on the missing second field.
                if not line.strip():
                    continue
                sentence = line.split("\t")[1].strip()
                # Drop very short sentences.
                # NOTE(review): the original comment hinted that these correspond
                # to very short (~2 second) utterances - confirm.
                if len(sentence) > min_length:
                    kept.append(sentence)

        # list.sort is stable, so equal-length sentences keep their file order.
        kept.sort(key=len)

        self.sentences = kept
        self.texts = self.sentences
        self.vocab = self.build_vocab(self.sentences)
        self.char_to_idx = {char: idx for idx, char in enumerate(self.vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}

    def build_vocab(self, texts):
        """Return the sorted list of unique characters occurring in ``texts``."""
        return sorted(set("".join(texts)))

    def encode(self, text):
        """Encode ``text`` into a list of vocabulary indices.

        Raises:
            KeyError: If ``text`` contains a character that is not in the vocabulary.
        """
        return [self.char_to_idx[char] for char in text]

    def decode(self, indices):
        """Decode an iterable of vocabulary indices back into a string.

        Raises:
            KeyError: If an index is not present in the vocabulary.
        """
        return "".join(self.idx_to_char[idx] for idx in indices)

    def __len__(self):
        """Return the number of kept sentences."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the next-character prediction pair for sentence ``idx``.

        Returns:
            dict: ``"input_ids"`` holds every encoded character except the last,
            ``"labels"`` holds every encoded character except the first, so
            ``labels[i]`` is the character that follows ``input_ids[i]``.
            Both values are plain Python lists of ints.
        """
        text = self.texts[idx]
        input_ids = self.encode(text)
        return {
            "input_ids": input_ids[:-1],  # Input sequence
            "labels": input_ids[1:]       # Target sequence (shifted by 1)
        }


In [15]:
# Build the dataset from the default transcription path.
d = CharLMDataset()

In [19]:
# Inspect the sorted character vocabulary (the cell output is listed below).
d.vocab

[' ',
 "'",
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z']