In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

num_datasets = 10
num_points_in_each_dataset = 100

# Let's fix our constants
a = 0.1  # standard deviation of the observation noise on y
b = 1.0  # standard deviation of the prior over the weight w

# Seed torch's global RNG so that Restart & Run All reproduces the same samples
seed = 0
torch.manual_seed(seed)

# We sample the generative model top-down: first one weight w per dataset
ws = torch.distributions.Normal(torch.tensor([0.0]), b).sample((num_datasets,)) # a tensor of shape (num_datasets, 1)

# For each w we generate `num_points_in_each_dataset` many x's (uniform in [0, 1)) and noisy y's
xs = torch.rand(num_datasets, num_points_in_each_dataset, 1)
# einsum computes x * w per point; the result has shape (num_datasets, num_points_in_each_dataset)
ys = torch.distributions.Normal(torch.einsum('nmf, nf -> nm', xs, ws), a).sample()


In [None]:
# One scatter call per dataset, so matplotlib's colour cycle gives each dataset its own colour
for i in range(num_datasets):
    x_values = xs[i, :, 0].numpy()
    y_values = ys[i].numpy()
    plt.scatter(x_values, y_values)


In [None]:
def sample_from_prior(num_datasets = 10, num_features=1, num_points_in_each_dataset = 100,
                      hyperparameters=None):
    """Sample synthetic linear-regression datasets (with a bias term) from the prior.

    Args:
        num_datasets: number of independent datasets to draw.
        num_features: number of input features per point (the bias is added internally).
        num_points_in_each_dataset: number of (x, y) points per dataset.
        hyperparameters: dict with 'a' (scale of the Gaussian noise on y) and
            'b' (scale of the Gaussian prior on the weights). Defaults to
            {'a': 0.1, 'b': 1.0} when None.

    Returns:
        A tuple (xs, ys): xs has shape (num_datasets, num_points_in_each_dataset, num_features)
        and does NOT include the bias column; ys has shape
        (num_datasets, num_points_in_each_dataset).
    """
    # Build the default inside the call instead of sharing one mutable dict across calls
    if hyperparameters is None:
        hyperparameters = {'a': 0.1, 'b': 1.0}

    # One weight vector per dataset; the extra entry is the bias weight
    ws = torch.distributions.Normal(torch.zeros(num_features+1), hyperparameters['b']).sample((num_datasets,)) # a tensor of shape (num_datasets, num_features+1)

    xs = torch.rand(num_datasets, num_points_in_each_dataset, num_features)
    # Append a constant-one column so the last weight acts as the bias
    xs_with_bias = torch.cat([xs, torch.ones(num_datasets, num_points_in_each_dataset, 1)], 2)
    ys = torch.distributions.Normal(
        torch.einsum('nmf, nf -> nm', xs_with_bias, ws),
        hyperparameters['a']
    ).sample()
    return xs, ys


In [None]:
xs, ys = sample_from_prior()
# Iterate over the sampled tensors themselves rather than the global `num_datasets`,
# so the plot stays correct if sample_from_prior is called with a different size.
for dataset_xs, dataset_ys in zip(xs, ys):
    plt.scatter(dataset_xs[:, 0].numpy(), dataset_ys.numpy())


In [None]:
# in our convention we name the `num_datasets` -> `batch_size`, and the `num_points_in_each_dataset` -> `seq_len`

def get_batch_for_ridge_regression(batch_size=2,seq_len=100,num_features=1,
                                   hyperparameters=None, device='cpu', **kwargs):
    """Draw a batch of synthetic linear-regression datasets from the prior.

    Args:
        batch_size: number of datasets in the batch.
        seq_len: number of points in each dataset.
        num_features: number of input features (a constant bias column is appended).
        hyperparameters: dict with 'a' (noise scale on y) and 'b' (prior scale on the
            weights); falls back to {'a': 0.1, 'b': 1.0} when None.
        device: device the returned tensors are moved to.
        **kwargs: accepted and ignored.

    Returns:
        dict with 'x' of shape (batch_size, seq_len, num_features + 1), whose last
        column is all ones, and 'y' / 'target_y' of shape (batch_size, seq_len, 1)
        holding the same noisy targets.
    """
    hps = {'a': 0.1, 'b': 1.0} if hyperparameters is None else hyperparameters

    # One weight vector per dataset; the last entry multiplies the bias column
    weight_prior = torch.distributions.Normal(torch.zeros(num_features+1), hps['b'])
    weights = weight_prior.sample((batch_size,))

    # Uniform inputs in [0, 1) with a column of ones appended for the bias
    features = torch.rand(batch_size, seq_len, num_features)
    bias_column = torch.ones(batch_size, seq_len, 1)
    inputs = torch.cat([features, bias_column], 2)

    # Noise-free linear response, then Gaussian observation noise; add a trailing axis
    noiseless = torch.einsum('nmf, nf -> nm', inputs, weights)
    targets = torch.distributions.Normal(noiseless, hps['a']).sample()[..., None]

    # Simple return format for TinyPFN
    return {'x': inputs.to(device), 'y': targets.to(device), 'target_y': targets.to(device)}


In [None]:
from tiny_pfn import TinyPFN
import torch.nn as nn

def _mixture_nll(weights, means, stds, targets, n_components):
    """Mean negative log-likelihood of `targets` under a mixture of Gaussians.

    The mixture components are indexed along the last axis of `weights`, `means` and `stds`.
    """
    # NOTE(review): the layout returned by TinyPFN.get_distribution_params is not visible
    # here. Aligning the targets to the per-component shape is a no-op when the shapes
    # already match and prevents a silent (batch, T, T) broadcast when `means[..., i]`
    # lacks the targets' trailing singleton axis — confirm against tiny_pfn.
    targets = targets.reshape(means.shape[:-1])
    log_probs = torch.stack(
        [torch.distributions.Normal(means[..., i], stds[..., i]).log_prob(targets)
         for i in range(n_components)],
        dim=-1,
    )
    # log sum_k w_k * N(y | mu_k, sigma_k), evaluated stably in log space
    mixture_log_prob = torch.logsumexp(torch.log(weights) + log_probs, dim=-1)
    return -mixture_log_prob.mean()


def train_a_tiny_pfn(get_batch_function, epochs=10, max_dataset_size=20, batch_size=16, steps_per_epoch=100,
                     lr=0.0003, num_features=1, hyperparameters=None):
    """Train a distributional TinyPFN on synthetic datasets sampled fresh at every step.

    Args:
        get_batch_function: callable accepting `batch_size`, `seq_len`, `num_features`
            and `hyperparameters` keyword arguments and returning a dict with
            batch-first tensors under 'x' and 'y'.
        epochs: number of epochs.
        max_dataset_size: sequence length of every sampled dataset; must be at least 5
            so that a random train/test split point can be drawn.
        batch_size: number of datasets per optimisation step.
        steps_per_epoch: optimisation steps per epoch.
        lr: Adam learning rate.
        num_features: number of raw input features requested from `get_batch_function`.
            The model is built with `num_features + 1` inputs because the batch function
            is expected to append a bias column (as get_batch_for_ridge_regression does).
        hyperparameters: prior hyperparameters forwarded to `get_batch_function`;
            defaults to {'a': 0.1, 'b': 1.0} when None.

    Returns:
        A tuple (model, losses): the trained TinyPFN (still in train mode) and the list
        of per-epoch mean training losses.

    Raises:
        ValueError: if `max_dataset_size` is smaller than 5.
    """
    # torch.randint(2, max_dataset_size-2) below needs a non-empty range
    if max_dataset_size < 5:
        raise ValueError(f"max_dataset_size must be at least 5, got {max_dataset_size}")
    if hyperparameters is None:
        hyperparameters = {'a': 0.1, 'b': 1.0}

    # Create TinyPFN model with confidence intervals
    model = TinyPFN(
        num_features=num_features + 1,  # raw features + bias column
        d_model=64,
        n_heads=4,
        dropout=0.1,
        max_seq_len=max_dataset_size,
        output_mode='distributional',  # Enable confidence intervals
        n_mixture_components=3
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print(f"TinyPFN with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Training loop
    model.train()
    losses = []

    for epoch in range(epochs):
        epoch_losses = []

        for _ in range(steps_per_epoch):
            # Generate new synthetic data each step - this is the key PFN innovation
            batch = get_batch_function(
                batch_size=batch_size,
                seq_len=max_dataset_size,
                num_features=num_features,
                hyperparameters=hyperparameters
            )

            # Split into train/test (like original PFN) at a random position that
            # leaves at least 2 training points and at least 3 test points
            train_len = torch.randint(2, max_dataset_size-2, (1,)).item()
            x_train = batch['x'][:, :train_len, :]
            y_train = batch['y'][:, :train_len, :]
            x_test = batch['x'][:, train_len:, :]
            y_test = batch['y'][:, train_len:, :]

            # Forward pass
            optimizer.zero_grad()
            predictions = model(x_train, y_train, x_test)

            # Negative log-likelihood of the held-out targets under the predicted mixture
            weights, means, stds = model.get_distribution_params(predictions)
            loss = _mixture_nll(weights, means, stds, y_test, model.n_mixture_components)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

        avg_loss = np.mean(epoch_losses)
        losses.append(avg_loss)

        if epoch % 2 == 0:
            print(f"Epoch {epoch:3d}/{epochs} | Loss: {avg_loss:.4f}")

    return model, losses

# train_a_tiny_pfn returns (model, losses); unpack so that `trained_model` is the model
# itself and can be called / queried in the cells below.
trained_model, losses = train_a_tiny_pfn(get_batch_for_ridge_regression, epochs=10)


In [None]:
# Draw a fresh batch of 10 datasets with 100 points each to inspect the model on
batch = get_batch_for_ridge_regression(batch_size=10, seq_len=100)


In [None]:
# the model takes batch-first inputs (batch, seq, features), so a single dataset
# needs a leading batch axis added before it is passed in

batch_index = 0 # change this to see other examples
num_training_points = 4

train_x = batch['x'][batch_index, :num_training_points]
train_y = batch['y'][batch_index, :num_training_points]
test_x = batch['x'][batch_index]

# training left the model in train mode; switch to eval so dropout is disabled
# and the predictions are deterministic
trained_model.eval()
with torch.no_grad():
    # we add our batch dimension, as our transformer always expects that
    predictions = trained_model(train_x[None], train_y[None], test_x[None])

# the model outputs mixture parameters, we need to extract means and confidence intervals
pred_means = trained_model.mean(predictions)[0]
pred_confs = trained_model.quantile(predictions, quantiles=[0.1, 0.9])[0]

plt.scatter(train_x[...,0],train_y.squeeze())
# sort the test points by their feature value so the curves are drawn left to right
order_test_x = test_x[...,0].argsort()
plt.plot(test_x[order_test_x, 0],pred_means[order_test_x], color='green', label='TinyPFN')
plt.fill_between(test_x[order_test_x,0], pred_confs[order_test_x, 0], pred_confs[order_test_x, 1], alpha=.1, color='green')

import sklearn.linear_model

# Reference solution: with weights ~ N(0, b^2) and noise ~ N(0, a^2) the posterior mean
# is ridge regression with alpha = (a/b)^2. The inputs already carry a ones column whose
# weight has the same prior, so the intercept must be penalised like every other weight:
# fit_intercept=False lets that column play the role of the (regularised) bias.
ridge_model = sklearn.linear_model.Ridge(alpha=(a/b)**2, fit_intercept=False)
ridge_model.fit(train_x,train_y.squeeze())
plt.plot(test_x[order_test_x, 0], ridge_model.predict(test_x[order_test_x]), label='ridge regression')
plt.legend();
