In [1]:
import torch
import torch.nn as nn

In [2]:
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HfDataset

class StoryDataset(Dataset):
    """Next-word prediction samples drawn from a parquet file of stories.

    The ``text`` column is joined with newlines and split on whitespace into a
    single flat word list. Sample ``i`` pairs the ``context_size`` words that
    start at position ``i`` (the input) with the word that follows them (the
    target).

    Windows are sliced from the word list on demand instead of being stored up
    front, so memory stays proportional to the corpus rather than to
    ``corpus * (context_size + 1)``.

    Args:
        name: Path handed to ``datasets.Dataset.from_parquet``.
        sample: If truthy, keep only the first ``sample`` rows of the file.
        context_size: Number of input words per sample.

    Raises:
        ValueError: If ``context_size`` is negative.
    """

    def __init__(self, name, sample=None, context_size=32):
        super().__init__()
        if context_size < 0:
            raise ValueError(f"context_size must be >= 0, got {context_size}")
        self.context_size = context_size

        # Load data
        self.train = HfDataset.from_parquet(name)
        if sample:
            self.train = self.train.select(range(sample))

        # Split into words
        self._texts = "\n".join(self.train['text']).split()

    @property
    def data(self):
        """Every window of ``context_size + 1`` words, as a list of lists.

        Kept so code that read the old ``data`` attribute still works. The list
        is rebuilt on each access; prefer indexing the dataset directly.
        """
        width = self.context_size + 1
        return [self._texts[start:start + width] for start in range(len(self))]

    def __len__(self):
        """Number of complete (input, target) windows in the word list."""
        # A window needs context_size inputs plus one target word; a corpus
        # shorter than that yields zero samples rather than a negative length.
        return max(0, len(self._texts) - self.context_size)

    def __getitem__(self, index):
        """Return ``(inputs, target)`` for window ``index``.

        ``inputs`` is a list of ``context_size`` words and ``target`` is the
        single word that follows them. Negative indices count from the end.

        Raises:
            IndexError: If ``index`` falls outside the dataset.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError("StoryDataset index out of range")

        # Split into input and target
        boundary = index + self.context_size
        inputs = self._texts[index:boundary]   # First context_size tokens
        target = self._texts[boundary]         # Next token (what we want to predict)

        return inputs, target

# Custom collate function
def collate_fn(batch):
    """Regroup a batch of ``(inputs, target)`` samples into two parallel lists.

    Returns a tuple ``(inputs, targets)`` where ``inputs[i]`` and ``targets[i]``
    are the first and second element of ``batch[i]``. Nothing is converted to
    tensors.
    """
    inputs = []
    for sample in batch:
        inputs.append(sample[0])   # the input sequence of this sample

    targets = []
    for sample in batch:
        targets.append(sample[1])  # the target token of this sample

    return inputs, targets

# Usage
train_dataset = StoryDataset('train.parquet', sample=10, context_size=32)
print(f"Dataset length: {len(train_dataset)}")

# Test individual sample
inputs, target = train_dataset[0]
print(f"Input length: {len(inputs)}")
print(f"Input: {inputs[:10]}...")  # First 10 words
print(f"Target: '{target}'")

# DataLoader
# A freshly seeded generator drives the shuffle, so re-running this cell (or
# Restart & Run All) prints the same first batch every time.
shuffle_rng = torch.Generator().manual_seed(0)
dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
    generator=shuffle_rng,
)

print(f"\nNumber of batches: {len(dataloader)}")
# Inspect only the first batch; the break stops after one iteration.
for batch_inputs, batch_targets in dataloader:
    print(f"Batch inputs type: {type(batch_inputs)}")
    print(f"Batch inputs length: {len(batch_inputs)} samples")
    print(f"Each input length: {len(batch_inputs[0])} words")
    print(f"Batch targets: {batch_targets}")
    print(f"First input: {batch_inputs[0][:5]}...")  # First 5 words
    print(f"First target: '{batch_targets[0]}'")
    break

  from .autonotebook import tqdm as notebook_tqdm


Dataset length: 1413
Input length: 32
Input: ['One', 'day,', 'a', 'little', 'girl', 'named', 'Lily', 'found', 'a', 'needle']...
Target: 'with'

Number of batches: 353
Batch inputs type: <class 'list'>
Batch inputs length: 4 samples
Each input length: 32 words
Batch targets: ['for', 'a', 'All', 'their']
First input: ['they', 'shared', 'the', 'needle', 'and']...
First target: 'for'


In [None]:
class HAT(nn.Module):
    """
    Placeholder for the HAT model; no layers or forward pass are defined yet.

    Components listed by the author, none of which are implemented here:
        charencoder
        backbone
        chardecoder
    """

In [6]:
# Spell the word out character by character, with a '[W]' marker in front
# and a single space behind.
word = 'for'
chars = ['[W]', *word, ' ']
chars

['[W]', 'f', 'o', 'r', ' ']

In [9]:
# Multi-hot byte encoding: one 256-wide row per entry of `chars`, with a 1 at
# every UTF-8 byte value that occurs in that entry.
char_embed = []
for ch in chars:
    ar = [0]*256
    # Use a real name for the byte value: assigning to `_` at notebook top
    # level shadows IPython's "last output" variable and stops it updating.
    for byte_value in ch.encode('utf-8'):
        ar[byte_value] = 1
    char_embed.append(ar)

# NOTE(review): '[W]' encodes to three bytes, so its row carries three 1s
# (for '[', 'W' and ']') rather than a single dedicated slot - confirm intended.
char_emb = torch.tensor(char_embed)
    

In [12]:
# One row per entry of `chars`, one column per possible byte value (256).
char_emb.shape

torch.Size([5, 256])

In [None]:
def self_embed(char_embed):
    """Mix the rows of ``char_embed`` with scaled dot-product self-attention.

    Each output row is a softmax-weighted average of all input rows, where the
    weights come from the dot product between rows divided by the square root
    of the row width.

    NOTE(review): this is a parameter-free stand-in so the cell runs; it has no
    learned query/key/value projections. Replace it with a trainable attention
    layer once the character encoder is designed.

    Args:
        char_embed: 2-D tensor or nested list of numbers, one row per character.

    Returns:
        A float32 tensor with the same shape as the input.

    Raises:
        ValueError: If ``char_embed`` is not 2-D.
    """
    # The notebook passes the plain nested list, so convert (and cast ints to
    # float, which softmax requires).
    x = torch.as_tensor(char_embed, dtype=torch.float32)
    if x.dim() != 2:
        raise ValueError(f"char_embed must be 2-D, got {x.dim()} dimension(s)")

    # max(..., 1) avoids dividing by zero for zero-width rows.
    scale = max(x.shape[-1], 1) ** 0.5
    scores = (x @ x.T) / scale
    weights = torch.softmax(scores, dim=-1)
    return weights @ x

In [None]:
# Row 0 corresponds to the leading '[W]' entry of `chars`; take it as the
# vector for the whole word. The helper in this notebook is named self_embed.
word_embed = self_embed(char_embed)[0,:]