In [10]:
import torch 

class Attention(torch.nn.Module): # BSD -> BSD
    def __init__(self, D=768, head_dim=64, causal=True, device="cuda"): 
        super().__init__()
        self.D = D 
        self.head_dim = head_dim
        assert D % head_dim == 0
        self.nheads = D//head_dim
        self.Wq = torch.nn.Linear(D, D)
        self.Wk = torch.nn.Linear(D, D)
        self.Wv = torch.nn.Linear(D, D)
        self.causal = causal 
        self.Wo = torch.nn.Linear(D, D)
        self.device = device

    def forward(self, x): # input is [B, S, D] 
        B, S, D = x.shape
        # let's make this multi-head now, ie. make each QKV [B, S, D] --> [B, nh, S, hd]

        Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x) # all [B, S, D]

        Q = Q.view(B, S, self.nheads, self.head_dim).transpose(1,2) # [B, nh, S, hd]
        K = K.view(B, S, self.nheads, self.head_dim).transpose(1,2)
        V = V.view(B, S, self.nheads, self.head_dim).transpose(1,2)

        # [B, nh, S, hd] @ [B, nh, hd, S] -> [B, nh, S, S]
        logits = (Q@K.transpose(-2, -1))/torch.sqrt(torch.tensor(self.head_dim, device=self.device)) # [B, nh, S, S]
        if self.causal:
            mask = torch.triu(torch.ones_like(logits), diagonal=1).bool()
            logits_masked = logits.masked_fill(mask, float('-inf'))
        else:
            logits_masked = logits

        A = torch.nn.functional.softmax(logits_masked, dim=-1) # [B, nh, S, S]
        
        preout = torch.einsum('bnxy,bnyd->bnxd', A, V) # [B, nh, S, S] @ [B, nh, S, hd] -> [B, nh, S, hd]
        preout = preout.transpose(1, 2).reshape(B, S, -1) # [B, nh, S, hd] -> [B, S, nh * hd]
        
        out = self.Wo(preout) # [B, S, D]
        return out # [B, S, D]

B, S, D = 1, 512, 768
x = torch.randn(B, S, D)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)
mhsa = Attention(D, device=device)
mhsa(x).shape # expect [B, S, D] ie. [1, 512, 768]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [3]:
ACT2FN = {
    'relu': torch.nn.functional.relu,
    'gelu': torch.nn.functional.gelu,
    'silu': torch.nn.functional.silu,
    'swish': torch.nn.functional.silu,
}

class MLP(torch.nn.Module): 
    def __init__(self, D, hidden_multiplier=4, act='swish', device=None): 
        super().__init__()
        self.D = D
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.up_proj = torch.nn.Linear(D, D*hidden_multiplier).to(self.device)
        self.down_proj = torch.nn.Linear(D*hidden_multiplier, D).to(self.device)
        self.act = ACT2FN[act]

    def forward(self, x): # BSD -> BSD automatically on last dim 
        x = x.to(self.device)
        return self.down_proj(self.act(self.up_proj(x)))

B, S, D = 1, 512, 768
x = torch.randn(B, S, D)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)
mlp = MLP(D, device=device)
mlp(x).shape

torch.Size([1, 512, 768])

In [4]:
class LN(torch.nn.Module): 
    def __init__(self, D, eps=1e-9, device=None): 
        super().__init__()
        self.D = D 
        self.eps = eps
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mean_scale = torch.nn.Parameter(torch.zeros(D).to(self.device))
        self.std_scale = torch.nn.Parameter(torch.ones(D).to(self.device))

    def forward(self, x): # x is [B, S, D]
        x = x.to(self.device)
        mean = x.mean(dim=-1, keepdim=True) # [B, S, 1]
        std = (x.var(dim=-1, keepdim=True) + self.eps)**0.5 # [B, S, 1]
        x_norm = (x - mean)/(std) 
        return x_norm * self.std_scale + self.mean_scale

B, S, D = 1, 512, 768
x = torch.randn(B, S, D)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)
ln = LN(D, device=device)
ln(x).shape

torch.Size([1, 512, 768])

In [5]:
class TransformerLayer(torch.nn.Module): 
    def __init__(self, D, device=None): 
        super().__init__()
        self.D = D 
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.attn = Attention(D, device=self.device)
        self.mlp = MLP(D, device=self.device)
        self.ln1 = LN(D, device=self.device)
        self.ln2 = LN(D, device=self.device)  
    
    def forward(self, x): # x is BSD
        x = x.to(self.device)
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x 


class EmbeddingLayer(torch.nn.Module): 
    # this is just a lookup table 
    def __init__(self, vocab_size, D, device=None): 
        super().__init__()
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding = torch.nn.Parameter(torch.randn(vocab_size, D).to(self.device))

    def forward(self, x): 
        x = x.to(self.device)
        return self.embedding[x]

class UnembeddingLayer(torch.nn.Module): 
    # this is just a lookup table that maps embeddings back to logits
    def __init__(self, vocab_size, D, device=None): 
        super().__init__()
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.unembedding = torch.nn.Linear(D, vocab_size).to(self.device)

    def forward(self, x): # x is [B, S, D]
        # Return logits of shape [B, S, vocab_size]
        x = x.to(self.device)
        return self.unembedding(x)


class Transformer(torch.nn.Module): 
    def __init__(self, depth, hidden_dim, vocab_size, device=None): 
        super().__init__()
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.D = hidden_dim 
        self.depth = depth 
        self.emb = EmbeddingLayer(vocab_size, hidden_dim, device=self.device)
        self.unemb = UnembeddingLayer(vocab_size, hidden_dim, device=self.device)
        self.layers = torch.nn.ModuleList([TransformerLayer(self.D, device=self.device) for _ in range(depth)])
        self.to(self.device)

    def forward(self, x): # x is tokenized inputs, so BS1
        x = x.to(self.device)
        x = self.emb(x) # BSD 
        for layer in self.layers:
            x = layer(x) # BSD 
        x = self.unemb(x) # BSV 
        return x


In [7]:
# Download TinyStories dataset from Hugging Face
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("roneneldan/TinyStories")

# Print basic information about the dataset
print(f"Dataset structure: {dataset}")
print(f"Available splits: {dataset.keys()}")
print(f"Number of examples in train: {len(dataset['train'])}")
print(f"Number of examples in validation: {len(dataset['validation'])}")

# Display a sample story
sample_story = dataset['train'][0]
print("\nSample story:")
print(sample_story['text'])


Dataset structure: DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})
Available splits: dict_keys(['train', 'validation'])
Number of examples in train: 2119719
Number of examples in validation: 21990

Sample story:
One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt h

In [8]:
from tqdm import tqdm 

tokens = set()
for example in tqdm(dataset['train']):
    tokens.update(example['text'])
tokens = list(tokens)
tokenizer = {tok:i for (i,tok) in enumerate(tokens)}


  0%|▋                                                                                                                                                  | 9121/2119719 [00:00<00:23, 91203.78it/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2119719/2119719 [00:23<00:00, 88533.79it/s]


In [46]:
def tokenize(el): 
    return list(map(lambda x: tokenizer[x], el))

sentence = dataset['train'][0]['text']
tokenize(sentence)

class DataLoader:
    def __init__(self, dataset, batch_size, seq_len, device='cuda'):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        self.current_idx = 0
        self.device = device
        
    def __iter__(self):
        self.current_idx = 0
        return self
    
    def __next__(self):
        if self.current_idx >= len(self.dataset):
            raise StopIteration
        
        batch_texts = []
        batch_tokens = []
        
        # Collect batch_size examples
        for i in range(self.batch_size):
            if self.current_idx >= len(self.dataset):
                break
                
            text = self.dataset[self.current_idx]['text']
            tokens = tokenize(text)
            
            # Truncate or pad to seq_len
            if len(tokens) > self.seq_len:
                tokens = tokens[:self.seq_len]
            elif len(tokens) < self.seq_len:
                tokens = tokens + [0] * (self.seq_len - len(tokens))
                
            batch_texts.append(text)
            batch_tokens.append(tokens)
            self.current_idx += 1
        
        # Convert to tensors and move to CUDA
        batch_tokens = torch.tensor(batch_tokens, device=self.device)
        return batch_tokens

In [47]:
B, S, D, nlayers, vocab_size = 16, 512, 768, 8, len(tokenizer.items())
boi = Transformer(nlayers, D, vocab_size)
dataloader = DataLoader(dataset['train'], B, S)

In [50]:
batch = next(dataloader)
boi(batch).shape

torch.Size([16, 512, 174])

In [21]:
import torch
from tqdm import tqdm

def tokenize(tokenizer, el): 
    return list(map(lambda x: tokenizer[x], el))

# Define model parameters (should match those used during training)
# For example, when training with: python3 train.py --hidden_dim 768 --nlayers 6 --verbose --steps 250 --save
hidden_dim = 512  # Set to 768 as specified during training
nlayers = 6
vocab_size = len(tokenizer.items())

# Initialize a new model with the same architecture
inference_model = Transformer(nlayers, hidden_dim, vocab_size, device='cuda')

# Load the saved weights with a proper device mapping
save_dir = "/n/netscratch/gershman_lab/Lab/tkumar/datasets/dclm/global-shard_01_of_10/newest_data/gravity-models/"
model_name = "model_lr0.0003_bs256_seq512.pt"  # Matches training parameters
save_path = f"{save_dir}/{model_name}"
inference_model.load_state_dict(torch.load(save_path, map_location='cuda'))
inference_model.eval()  # Set to evaluation mode

def build_tokenizer_for_dataset(dataset):
    tokens = set()
    # Process a subset of the dataset to build the tokenizer
    for example in tqdm(dataset['train']):
        tokens.update(example['text'])
    tokens = sorted(list(tokens))  # Sort to ensure consistent ordering
    tokenizer = {tok: i for (i, tok) in enumerate(tokens)}
    return tokenizer

# Rebuild the tokenizer
tokenizer = build_tokenizer_for_dataset(dataset)

# Now fix the generate_text function
def generate_text(model, tokenizer, prompt="Once upon a time", max_length=100, temperature=1.0):
    # Create reverse mapping
    reverse_tokenizer = {v: k for k, v in tokenizer.items()}
    
    # Convert the prompt to tokens character by character
    tokens = []
    for char in prompt:
        if char in tokenizer:
            tokens.append(tokenizer[char])
        else:
            print(f"Warning: Character '{char}' not in tokenizer")
    
    input_tokens = torch.tensor([tokens], device=model.device)
    
    # Generate tokens one at a time
    for _ in tqdm(range(max_length)):
        with torch.no_grad():
            logits = model(input_tokens)
            next_token_logits = logits[0, -1, :] / temperature
            
            # Apply a mask to only select valid token indices
            # This prevents selecting tokens that don't exist in the reverse_tokenizer
            valid_indices = torch.tensor([i for i in range(len(logits[0, -1, :])) 
                                         if i in reverse_tokenizer], 
                                        device=model.device)
            
            if len(valid_indices) == 0:
                break
                
            masked_logits = torch.ones_like(next_token_logits) * float('-inf')
            masked_logits[valid_indices] = next_token_logits[valid_indices]
            
            probabilities = torch.nn.functional.softmax(masked_logits, dim=0)
            next_token = torch.multinomial(probabilities, 1)
            
            input_tokens = torch.cat([input_tokens, next_token.unsqueeze(0)], dim=1)
    
    # Convert the generated tokens back to text
    generated_tokens = input_tokens[0].tolist()
    generated_text = ""
    
    for token in generated_tokens:
        if token in reverse_tokenizer:
            generated_text += reverse_tokenizer[token]
        else:
            generated_text += "<?>"  # Placeholder for unknown tokens
    
    return generated_text

# Try generating text again
sample_text = generate_text(inference_model, tokenizer, prompt="Once upon a time")
print(sample_text)

  inference_model.load_state_dict(torch.load(save_path, map_location='cuda'))
  0%|                                                                                                                                                                 | 0/2119719 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2119719/2119719 [00:23<00:00, 91110.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 317.74it/s]

Once upon a timeE41 E41 E4d®E41d* ˜‰d®déTE{I{IE*1 Ey¹D–„­yEÊ® EyEéEdZ eEQIé EGÊyEKdy1EyIEy1 EÊTEyEÊ *E{kIE{Ê d®E*EyE



