# Sampling From Scratch

In [1]:
import math
import torch
import torch.nn.functional as F
from torch import Tensor, tensor
from jaxtyping import Float, Int
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False); # disable backprop

In [2]:
local_files_only = True
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("device:", device)

# model_name = "gpt2"
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_name, local_files_only=local_files_only).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)
tokenizer.pad_token = tokenizer.eos_token

device: mps


LLMs do not operate on words -- each word is converted into a high dimensional vector that contains information that gets passed through the model. At each layer, the model reads the vector, performs some computation (attention or MLP) and writes it back to the vector.

We call this vector the **residual stream**. To initially create these vectors from a sentence, we have a large lookup table of each "word" (or sub-word, see [here](TODO) for more info) to a this high dimensional vector.

> We call each "word" a **token**.  
> You can imagine `token ~= word`

This is 768 dimensions on GPT2, and can also be thought of the _width_ of the model

_depth_ being the number of layers.

We look up each word in an _embedding_ table. This is a map of 50,000 words to a high dimensional embedding.

In [3]:
if model_name == "gpt2":
    W_E = model.transformer.wte
else: # llama
    W_E = model.model.embed_tokens
print(W_E)

Embedding(128256, 2048)


Let's see the first 10 dimensions of the token (word) 9246

In [4]:
token = 9246
first_n_dimensions = 10
W_E.weight[token, :first_n_dimensions]

tensor([ 0.0378,  0.0211,  0.0037,  0.0184,  0.0267, -0.0135,  0.0104, -0.0383,
         0.0137, -0.0064], device='mps:0', requires_grad=True)

And to find the corresponding string word associated with token 9246:

In [5]:
def decode(tokens) -> str:
    return tokenizer.decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

print(f"decoded token: {repr(decode(token))}")

decoded token: ' strings'


Using the `tokenize` and `decode` functions, we can convert back and forth between a string and the initial model vectors ("embeddings").

Notably the model adds a "batch" dimension to the input, which allows us to process multiple inputs at the same time, imagine this allows us to run "the cat sat on the mat" and "I took my dog for a walk" at the _same time_.

Input to a LLM is a list of tokens, which we call length sequence length (or `seq` / `T` (for time dimesion) for short.


In [6]:
if model_name == "gpt2":
    context_length = model.config.n_ctx
else: # llama
    context_length = model.config.max_position_embeddings
print(f"context length: {context_length}")

context length: 131072


In [7]:
def tokenize(input) -> Int[Tensor, "bs seq"]:
    return tokenizer(
        input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=context_length
    )["input_ids"].to(device)

prompt = 'the cat sat on a mat'
tokens = tokenize(prompt)
embeddings = W_E.weight[tokens]

decoded = decode(tokens[0])

print(f"""\
# prompt
{prompt}

# tokens shape: {tuple(tokens.shape)}
{tokens.tolist()}

# decoded
{decoded}

# embeddings shape: {tuple(embeddings.shape)}
""")

# prompt
the cat sat on a mat

# tokens shape: (1, 7)
[[128000, 1820, 8415, 7731, 389, 264, 5634]]

# decoded
the cat sat on a mat

# embeddings shape: (1, 7, 2048)



## Output

Now given the prompt input, lets run the tokens through the model and look at the output. These are called **logits**.

In [8]:
logits = model(tokens).logits

print(f"""\
# Tokens ({tuple(tokens.shape)})

# Logit Output ({tuple(logits.shape)})
""")

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


# Tokens ((1, 7))

# Logit Output ((1, 7, 128256))



The input has shape, `(batch size, sequence length)`, with output `(batch size, sequence length, logits)`

For each token in the sequence, the model outputs a score for _every next token_ (50K) representing how likely that token is to come next.

For each token, we can see which token the model predicted as _most likely_.

In [9]:
for i in range(tokens.shape[1]):
    inp = decode(tokens[0, :i+1])
    pred = decode(logits[0, i].argmax())
    print(f"{repr(decode(tokens[0, :i+1]))} => {repr(pred)}")

'' => 'Question'
'the' => ' '
'the cat' => ' is'
'the cat sat' => ' on'
'the cat sat on' => ' the'
'the cat sat on a' => ' hot'
'the cat sat on a mat' => '\n'


So to continue generating tokens, we need to run an **auto regressive** function, that selects a token from the _last_ word in the sequence, and append it to the prompt.

In [10]:
def generate(prompt, num_tokens, verbose=False):
    tokens = tokenize(prompt)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1] # get the scores of the final token [shape: (n_vocab)]
        next_token = logits.argmax(keepdim=True) # pick the largest one
        tokens = torch.cat([ tokens, next_token[None] ], dim=1) # concatenate to the current text
        if verbose:
            print("---")
            print(decode(tokens[0]))
    return decode(tokens[0])

generate(prompt, num_tokens=20, verbose=True);

---
the cat sat on a mat

---
the cat sat on a mat
the
---
the cat sat on a mat
the cat
---
the cat sat on a mat
the cat sat
---
the cat sat on a mat
the cat sat on
---
the cat sat on a mat
the cat sat on a
---
the cat sat on a mat
the cat sat on a mat
---
the cat sat on a mat
the cat sat on a mat

---
the cat sat on a mat
the cat sat on a mat
the
---
the cat sat on a mat
the cat sat on a mat
the cat
---
the cat sat on a mat
the cat sat on a mat
the cat sat
---
the cat sat on a mat
the cat sat on a mat
the cat sat on
---
the cat sat on a mat
the cat sat on a mat
the cat sat on a
---
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
---
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat

---
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the
---
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat
---
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat
---
the cat sat on a mat
the cat sat on a mat
the cat 

## Sampling Probability Distribution

But just picking the most likely can give quite bland output

This takes the model output (which can be any number) and create a _probability distribution_ such that all the scores add up to 1.

To do this we use the **softmax** function.


In [11]:
def generate(prompt, num_tokens, verbose=False, seed=42): # add a seed to keep the output deterministic. Try other seeds!
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        ### New lines
        probs = F.softmax(logits, dim=-1) # create probability distribution of scores
        next_token = torch.multinomial(probs, 1) # pick a single token from distribution
        ###
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
        if verbose:
            print(decode(tokens[0]))
    return decode(tokens[0])

generate(prompt, num_tokens=25, verbose=True);

the cat sat on a mat from
the cat sat on a mat from pee
the cat sat on a mat from pee

the cat sat on a mat from pee
The
the cat sat on a mat from pee
The owner
the cat sat on a mat from pee
The owner (
the cat sat on a mat from pee
The owner (the
the cat sat on a mat from pee
The owner (the man
the cat sat on a mat from pee
The owner (the man)
the cat sat on a mat from pee
The owner (the man) came
the cat sat on a mat from pee
The owner (the man) came home
the cat sat on a mat from pee
The owner (the man) came home at
the cat sat on a mat from pee
The owner (the man) came home at noon
the cat sat on a mat from pee
The owner (the man) came home at noon (
the cat sat on a mat from pee
The owner (the man) came home at noon (duck
the cat sat on a mat from pee
The owner (the man) came home at noon (duck)
the cat sat on a mat from pee
The owner (the man) came home at noon (duck).
the cat sat on a mat from pee
The owner (the man) came home at noon (duck).When
the cat sat on a mat from pee
Th

This already gives a much more interesting output! But perhaps we want to control 

Now how can we _control_ how much of the distribution we sample.

## Temperature

**Temperature** controls how the distribution is sampled. It's best shown in the context of the examples above

- Temperature 0: Completely flattens the distrubution, all probability is given to the token with the largest score
- Temperature 1: Standard softmax distrubution, same as sampling above

By increasing the temperature, we increase the chance of a token with a lower probability getting picked.

In [12]:
def generate(prompt, num_tokens, temperature=1.0, seed=42):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8) # temperature 0 => divide by _very small_ constant
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        probs = F.softmax(logits / temperature, dim=-1) # divide scores, flattening distribution
        next_token = torch.multinomial(probs, 1)
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for temp in torch.arange(0, 2.2, 0.2):
    print(f"\n### {temp.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=temp))



### 0.0 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### 0.2 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### 0.4 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### 0.6 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### 0.8 ###
the cat sat on a mat from the beginning of time (the pre-colonial period) (1969-1970)


### 1.0 ###
the cat sat on a mat from pee
The owner (the man) came home at noon (duck).When we were

### 1.2 ###
the cat sat on a mat from pee over it, (the blood doesn't fade at all! he wasn't using blood pressure

### 1.4 ###
the cat sat on a mat from pee overмосячна по елеруму_KPapasonêduckdourdanination.The

### 1.6 ###
the cat sat on a mat(нула共同мосячна постілербалиenorasonêduckisticallyisiemptination.The

### 1.8 ###
the cat sat on a mat�sнула共同мосячна постілерб_KPenorasonêduckisticallyisiemptination.The

### 2.0 ##

As the temperature increases, less likely tokens are predicted, which can lead to more interesting output. Setting the temperature hyperparameter correctly can be key to model performance.

## Top K

Another parameter used in sampling is `top_k`. This essentially limits the model predicting too "wild" predictions by limiting the probability distribution to the top k results.

A.k.a currently we are sampling from the entire distribution of 50,000 tokens. But it makes sense that only the top 50 tokens are reasonable continuations

In [13]:
def generate(prompt, num_tokens, temperature=1.0, top_k=50, seed=42):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        if top_k:
            logits, idxs = logits.topk(top_k) # Sample only topk tokens
        else:
            idxs = torch.arange(len(logits), device=device) # All idxs
    
        probs = F.softmax(logits / temperature, dim=-1)
        next_token = idxs[torch.multinomial(probs, 1)] # we use the idxs of topk only
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)

        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for temp in torch.arange(0, 2.2, 0.2):
    print(f"\n### Temperature {temp.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=temp))


### Temperature 0.0 ###
the cat sat on a mat

##11

##11

##11

##11

##11

##11

##

### Temperature 0.2 ###
the cat sat on a mat

##11

##11

##11

##11

##11

##11

##

### Temperature 0.4 ###
the cat sat on a mat of of grass grass

thethe cat cat sat sat on on a a mat mat of of grass grass

thethe cat cat sat sat on on on on a a mat mat of of

### Temperature 0.6 ###
the cat sat on a mat of of grass grass

thethe cat cat sat sat on on mat mat of of grass grass

thethe cat cat sat sat on on mat mat of of grass grass

thethe

### Temperature 0.8 ###
the cat sat on a mat of of grass grass

QuestionQuestion:: The The cat cat sat sat on on a a mat mat of of

AnswerAnswer::##

### Temperature 1.0 ###
the cat sat on a mat of of many many stripes stripes of of this this cat cat

thethe cat cat sat sat on on a a mat mat of ofmanymany many many much much many many many many of of

### Temperature 1.2 ###
the cat sat on a mat of of many many stripes stripes of of this this cat cat

thethe ca

You can see at even very high temperatures, the output does not devolve into gibberish.

## Min P

Top K can often be a to naive heuristic for sampling. A more common technique nowdays is to instead dispose of tokens that have too low probability.

We do this by computing the fraction of the of the probability of a token compared to the most probable token.

A.k.a If the most probable token has 60% proability and we have `min_p = 0.1`, we dispose of all tokens with a probability less than 6%.

In [14]:
def generate(
    prompt,
    num_tokens,
    temperature=1.0,
    top_k=None,
    min_p=None,
    seed=42
):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        if top_k:
            logits, idxs = logits.topk(top_k)
        else:
            idxs = torch.arange(len(logits), device=device)

        # TODO: temperature before or after min_p?
        probs = F.softmax(logits / temperature, dim=-1)

        if min_p is not None:
            mask = probs >= (probs.max() * min_p) 
            idxs, probs = idxs[mask], probs[mask]

        next_token = idxs[torch.multinomial(probs, 1)]
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for min_p in reversed(torch.logspace(start=math.log10(0.01), end=math.log10(0.5), steps=10, base=10)):
    print(f"\n### Min P: {min_p.item():.2f} ###")
    print(generate(prompt, num_tokens=20, temperature=1.5, min_p=min_p))


### Min P: 0.50 ###
the cat sat on a mat
I’ve been thinking a lot about the cat sat on a mat. It’s a phrase that

### Min P: 0.32 ###
the cat sat on a mat
A cat sat on a mat. The cat sat on a mat. The cat sat on a

### Min P: 0.21 ###
the cat sat on a mat.
I am so thankful to have the opportunity to write for you! I am looking forward to this

### Min P: 0.14 ###
the cat sat on a mat on a mat.
He is on the mat!
He is sitting on the cat.
The cat is

### Min P: 0.09 ###
the cat sat on a mat (it’s not my house).
I remember sitting in a room, an office perhaps, and the

### Min P: 0.06 ###
the cat sat on a mat the
What is the the cat sat on a mat the formula? In this regard, how can

### Min P: 0.04 ###
the cat sat on a mat

When you interact with a mat that does not respond to you, you assume that they are a

### Min P: 0.02 ###
the cat sat on a mat? - When We Have No Clothes, Will That Mean We Have a LUXURY BOWL?

### Min P: 0.02 ###
the cat sat on a mat? he scared the mat! now, all you

## Frequency Penalty

As we've seen at low temperatures, the model has a tendancy to repeat itself. For this we can apply a frequency penalty to discourage the model from predicting the same token again.

higher frequency -> higher penalty. If token not in sequence, count will be 0 and no penalty applied

In [15]:
def generate(
    prompt,
    num_tokens,
    temperature=1.0,
    top_k=None,
    min_p=None,
    frequency_penalty=None,
    seed=42,
):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]

        if frequency_penalty:
            *_, vocab_size = logits.shape
            # get frequency of each of the logits in the current output
            id_freqs = torch.bincount(tokens[0], minlength=vocab_size)
            logits -= frequency_penalty * id_freqs

        if top_k:
            logits, idxs = logits.topk(top_k)
        else:
            idxs = torch.arange(len(logits), device=device)

        # TODO: temperature before or after min_p?
        probs = F.softmax(logits / temperature, dim=-1)

        if min_p is not None:
            mask = probs >= (probs.max() * min_p) 
            idxs, probs = idxs[mask], probs[mask]

        next_token = idxs[torch.multinomial(probs, 1)]
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for freq_penalty in torch.linspace(start=0, end=1., steps=6):
    print(f"\n### Frequency Penalty {freq_penalty.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=0., frequency_penalty=freq_penalty))



### Frequency Penalty 0.0 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### Frequency Penalty 0.2 ###
the cat sat on a mat
the cat sat on a mat
the cat sat on a mat
the cat sat on a

### Frequency Penalty 0.4 ###
the cat sat on a mat
the cat sat on a mat
The cat sat on a mat. The cat sat on a

### Frequency Penalty 0.6 ###
the cat sat on a mat
The cat sat on a mat. The cat sat on a mat. The cat sat on a

### Frequency Penalty 0.8 ###
the cat sat on a mat
The cat sat on a mat. The cat sat on the mat. The cat sat on the

### Frequency Penalty 1.0 ###
the cat sat on a mat
The cat sat on a mat. The cat was very happy. The mouse was very sad.



## Soft Sampling

GPT2 has tied embeddings, so this should be easy!

In [16]:
def generate(
    prompt,
    num_tokens,
    temperature=1.0,
    top_k=None,
    min_p=None,
    frequency_penalty=None,
    soft_sample=False,
    seed=42,
):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    inputs_embeds = W_E(tokens)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(inputs_embeds=inputs_embeds).logits[0, -1]

        if frequency_penalty:
            *_, vocab_size = logits.shape
            # get frequency of each of the logits in the current output
            id_freqs = torch.bincount(tokens[0], minlength=vocab_size)
            logits -= frequency_penalty * id_freqs

        if top_k:
            logits, idxs = logits.topk(top_k)
        else:
            idxs = torch.arange(len(logits), device=device)

        # TODO: temperature before or after min_p?
        probs = F.softmax(logits / temperature, dim=-1)
        
        if min_p is not None:
            mask = probs >= (probs.max() * min_p) 
            idxs, probs = idxs[mask], probs[mask]

        next_token = idxs[torch.multinomial(probs, 1)]
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)

        if soft_sample:
            # tied embeddings
            next_embed = probs[None] @ W_E.weight[idxs]
        else:
            next_embed = W_E(next_token)
            
        inputs_embeds = torch.cat( [ inputs_embeds, next_embed[None] ], dim=1)
    
    return decode(tokens[0])

In [17]:
for soft_sample in False, True:
    print(f"\n\n==== SOFT SAMPLE: {soft_sample} ====")
    for min_p in reversed(torch.logspace(start=math.log10(0.01), end=math.log10(0.5), steps=10, base=10)):
        print(f"\n### Min P: {min_p.item():.2f} ###")
        print(generate(prompt, num_tokens=20, temperature=1., min_p=min_p, frequency_penalty=1., soft_sample=soft_sample))



==== SOFT SAMPLE: False ====

### Min P: 0.50 ###
the cat sat on a mat
I have been reading a lot about the importance of being present. It’s not just for meditation

### Min P: 0.32 ###
the cat sat on a mat
A cat sat on a mat. The cat was very pleased with herself. She had got her

### Min P: 0.21 ###
the cat sat on a mat
A cat sat on a mat. The cat was very fat, but it wasn't very happy

### Min P: 0.14 ###
the cat sat on a mat
A cat sat on a mat. The cat sat there for 5 minutes, then jumped up

### Min P: 0.09 ###
the cat sat on a mat.
A little girl said, "Mama, can I have a cookie?" Her mother replied,

### Min P: 0.06 ###
the cat sat on a mat that was not warm enough, so I put him in the freezer, which is why he’s not

### Min P: 0.04 ###
the cat sat on a mat is cute!
I don't know what it is about cats that I love, but I just do

### Min P: 0.02 ###
the cat sat on a mat (an example of what the cat could eat if she chose to) with the child's favourite food

### Min P: 0.02 ###
the

## Mock Training

In [34]:
import torch
import torch.nn.functional as F
from fastcore.meta import delegates

def mk_proba_dist(
    logits, # (batch_size, d_vocab)
    temperature=1.0,
    top_k=None,
    min_p=None,
):
    batch_size, d_vocab = logits.shape
    device = logits.device
    if top_k:
        logits, idxs = logits.topk(top_k, dim=-1)
    else:
        idxs = (
            torch.arange(d_vocab, device=device)
            .repeat(batch_size)
            .reshape(batch_size, d_vocab)
        )

    # TODO: temperature before or after min_p?
    probs = F.softmax(logits / temperature, dim=-1)

    if min_p is not None:
        max_probs = probs.max(dim=-1, keepdim=True).values
        threshold = max_probs * min_p
        mask = probs >= threshold
        probs = probs * mask
        probs = probs / probs.sum(dim=-1, keepdim=True) # renormalize
        idxs = idxs * mask
    return idxs, probs

@delegates(mk_proba_dist)
def soft_sampling_train_step(
    model,
    batch, # tokens of shape (batch_size, seq_len)
    W_E, # model's embedding matrix
    guidance_alpha, # guidance weighting -- 1 equivalent to discrete sampling
    **kwargs, # passed to mk_proba_dist
):
    "Single train step using soft sampling"
    assert 0 <= guidance_alpha <= 1
    batch_size, seq_len = batch.shape
    device = batch.device

    # cache
    past_key_values = None
    position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)


    loss = torch.tensor(0., device=device)
    embeds = W_E[batch[:, :1]]  # BOS shape: (batch_size, 1, d_model)
    tokens = [ batch[:, :1].detach().cpu() ]
    for t in range(1, seq_len):
        outputs = model(
            inputs_embeds=embeds,
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=True
        )

        logits_t = outputs.logits[:, -1]
        past_key_values = outputs.past_key_values

        i_t, p_t = mk_proba_dist(logits_t, **kwargs)

        # loss
        loss_t = F.cross_entropy(p_t, batch[:, t])
        loss += loss_t

        # discrete sample -- for logging
        indices = torch.multinomial(p_t, 1) # (batch_size, 1)
        batch_indices = torch.arange(batch_size)[:, None] # (batch_size, 1)
        next_token = i_t[batch_indices, indices].detach().cpu()
        tokens.append(next_token)

        # soft sample
        next_emb_soft = p_t @ W_E      # soft sampling
        next_emb_gt = W_E[batch[:, t]] # guidance sampling

        next_embed = (
            guidance_alpha * next_emb_gt +
            (1 - guidance_alpha) * next_emb_soft
        )
        embeds = torch.cat([embeds, next_embed[:, None, :]], dim=1)
        position_ids += 1

    if return_tokens:
        tokens = torch.cat(tokens, dim=1)
    # normalize gradient: sum batch, mean sequence length
    loss /= seq_len
    return loss, tokens

prompts = [
    "the cat sat in the hat",
    "where did all the hats go?",
    "I am the rizzler",
]

batch = tokenize(prompts)
loss, tokens = soft_sampling_train_step(
    model,
    batch,
    W_E.weight,
    guidance_alpha=0.5,
    return_tokens=True,
    min_p=0.1
)
print(f"Loss: {loss.item():.4f}")
for token_samp in tokens:
    print("---")
    print(decode(token_samp))


Loss: 10.2525
---
Question
 (y the cat the
---
def
 you
 two
 where
---
Question  is number
inic


## Broken training loop

In [27]:
from dataclasses import dataclass
from functools import partial
from tqdm import tqdm

@dataclass
class TrainConfig:
    # training
    max_lr: float = 1e-3
    max_samp_alpha: 0.7
    lr_warmup_frac: float = 0.1
    samp_alpha_warmup_frac: float = 0.5
    # model
    W_E_k: tuple[str, str]
    # sampling
    temperature: float = 1.0
    top_k: float | None = None
    min_p: float | None = None
    frequency_penalty: float | None = None
    # logging / checkpoint
    log_every: int | None = None
    # validation
    val_every: int | None = None
    val_samples: int | None = None

    def get_W_E(model):
        mod_k, submod_k = W_E_k
        return getattr(getattr(model, mod_k), submod_k)

class DataLoaders:
    def __init__(train_dl, val_dl): self.train, self.val = train_dl, val_dl
        
def linear_sched(t, max_t):
    return t / max_t

@torch.no_grad()
def validate(model, tokenizer, val_dl, cfg: TrainConfig):
    losses = []
    samples = []
    for step, batch in enumerate(tqdm(val_dl)):
        loss, tokens = do_step(model, batch, cfg, return_tokens=True)
        if step < cfg.val_samples: samples.append(tokens)
        losses.append(loss)
    print(f"Val Loss: {losses.mean().item():.4f}")

def train(model, tokenizer, dls: DataLoaders, cfg: TrainConfig):
    W_E = cfg.get_W_E(model)
    total_steps = len(dl)
    for step, batch in enumerate(tqdm(dls.train)):
        lr = cfg.max_lr * linear_sched(step, total_steps)
        opt.param_groups[0]['lr'] = lr
        opt.zero_grad()
        loss, _ = do_step(model, batch, W_E, sched_sample_alpha, return_tokens=False)
        loss.backward()
        opt.step()
        if step % cfg.log_every == 0: print(f"Train Loss: {loss.item():.4f}")
        if step % cfg.val_every == 0: validate(model, tokenizer, dls.valid, cfg)
        

cfg = TrainConfig(
    W_E_k=("model", "embed_tokens")
)
dls = DataLoaders(train_dl, val_dl)
train(model, tokenizer, dls, cfg)

TypeError: non-default argument 'max_samp_alpha' follows default argument