In [45]:
from model import Model
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import sklearn.datasets
import sklearn.preprocessing
import catboost
import time

## 📊 Dataset

In [30]:
dataset = sklearn.datasets.fetch_california_housing()

dataset['data'] = sklearn.preprocessing.QuantileTransformer(output_distribution='normal').fit_transform(dataset['data'])

In [31]:
all_idx = np.arange(len(dataset['data']))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
data_mx = {
    'train': {'x': mx.array(dataset['data'][train_idx]), 'y': mx.array(dataset['target'][train_idx])},
    'val': {'x': mx.array(dataset['data'][val_idx]), 'y': mx.array(dataset['target'][val_idx])},
    'test': {'x': mx.array(dataset['data'][test_idx]), 'y': mx.array(dataset['target'][test_idx])},
}

## 🤖 Model

In [36]:
def compute_bins(data, n_bins):
    bins = [
        np.unique(q)
        for q in np.nanquantile(
            data, np.linspace(0.0, 1.0, n_bins), axis=0
        ).T
    ]
    return bins

bins = compute_bins(data_mx['train']['x'], 30)

In [37]:
config = {
    'n_num_features': data_mx['train']['x'].shape[1],
    'n_classes': None,
    'backbone': {'n_blocks': 3, 'd_block': 576, 'dropout': 0.25},
    'arch_type': 'tabm-mini',
    'cat_cardinalities': [],
    'k': 32,
    'share_training_batches': False,
    'num_embeddings': {'d_embedding': 16, 'bins': bins}
}

model = Model(**config)

In [38]:
from mlx.utils import tree_flatten
params = tree_flatten(model.parameters())
total = sum(x[1].size for x in params)
print(f'Number of parametrs: {total}')

Number of parametrs: 766232


## 📈 Training

In [39]:
from tqdm import tqdm
from IPython.display import clear_output

def loss_fn_mse(model, X, y):
	if config['share_training_batches']:
		return mx.mean(nn.losses.mse_loss(model(X).flatten(0, 1)[:, 0], mx.repeat(y, repeats=config['k'])))
	else:
		return mx.mean(nn.losses.mse_loss(model(X).flatten(0, 1)[:, 0], y))
	
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

batch_size = 256

mx.eval(model.parameters())
loss_and_grad_fn = nn.value_and_grad(model, loss_fn_mse)
optimizer = optim.AdamW(learning_rate=0.0003)
train_size = data_mx['train']['x'].shape[0]
split_indexes = [i for i in range(batch_size, train_size, batch_size)]
best_epoch = {'val_loss': 1e9, 'test_loss': 1e9}
cur_patience = 16
patience = cur_patience
for e in tqdm(range(1000000)):
    batches = (
        mx.random.permutation(train_size).split(split_indexes)
        if config['share_training_batches']
        else
        [x_.transpose(0, 1).flatten() for x_ in mx.split(mx.argsort(mx.random.normal((model.k, train_size)), axis=1), split_indexes, axis=1)]
    )
    model.train()
    for batch_indexes in batches:
        loss, grads = loss_and_grad_fn(model, data_mx['train']['x'][batch_indexes], data_mx['train']['y'][batch_indexes])

        clipped_grads, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)

        optimizer.update(model, clipped_grads)

        mx.eval(model.state)
    model.eval()
    val_loss = ((model(data_mx['val']['x']).squeeze().mean(axis=1) - data_mx['val']['y']) ** 2).mean().sqrt()
    test_loss = ((model(data_mx['test']['x']).squeeze().mean(axis=1) - data_mx['test']['y']) ** 2).mean().sqrt()
    epoch = {'val_loss': val_loss.item(), 'test_loss': test_loss.item(), 'epoch_num': e}
    if epoch['val_loss'] < best_epoch['val_loss']:
        clear_output()
        print('New best epoch:', epoch)
        best_epoch = epoch
        cur_patience = patience
    else:
        cur_patience -= 1
        if not cur_patience:
            break

best_epoch

  0%|          | 191/1000000 [09:09<960:20:29,  3.46s/it]

New best epoch: {'val_loss': 0.44075900316238403, 'test_loss': 0.4432174861431122, 'epoch_num': 190}


  0%|          | 206/1000000 [10:02<811:39:06,  2.92s/it] 


{'val_loss': 0.44075900316238403,
 'test_loss': 0.4432174861431122,
 'epoch_num': 190}

## 🐈 Catboost 

In [67]:
model_cat = catboost.CatBoostRegressor()
model_cat.fit(data_mx['train']['x'].tolist(), data_mx['train']['y'].tolist(), eval_set=(data_mx['val']['x'].tolist(), data_mx['val']['y'].tolist()), verbose=100)

Learning rate set to 0.076361
0:	learn: 1.1046602	test: 1.1226143	best: 1.1226143 (0)	total: 2.48ms	remaining: 2.48s
100:	learn: 0.5037056	test: 0.5284944	best: 0.5284944 (100)	total: 111ms	remaining: 984ms
200:	learn: 0.4462318	test: 0.4904789	best: 0.4904789 (200)	total: 212ms	remaining: 843ms
300:	learn: 0.4143273	test: 0.4763013	best: 0.4763013 (300)	total: 317ms	remaining: 736ms
400:	learn: 0.3908162	test: 0.4657481	best: 0.4657481 (400)	total: 429ms	remaining: 640ms
500:	learn: 0.3731543	test: 0.4602580	best: 0.4602580 (500)	total: 538ms	remaining: 536ms
600:	learn: 0.3573078	test: 0.4562292	best: 0.4562292 (600)	total: 638ms	remaining: 424ms
700:	learn: 0.3434931	test: 0.4524826	best: 0.4524826 (700)	total: 754ms	remaining: 321ms
800:	learn: 0.3323861	test: 0.4506490	best: 0.4505686 (796)	total: 853ms	remaining: 212ms
900:	learn: 0.3217677	test: 0.4485375	best: 0.4485375 (900)	total: 952ms	remaining: 105ms
999:	learn: 0.3116026	test: 0.4456904	best: 0.4456904 (999)	total: 1.04s	

<catboost.core.CatBoostRegressor at 0x13d7c82c0>

In [68]:
print(f'Test error: {np.sqrt(((model_cat.predict(data_mx['test']['x'].tolist()) - data_mx['test']['y']) ** 2).mean())}')

Test error: 0.444754201770917


## ⏱️ Benchmarks

In [73]:
start = time.time()

for i in range(train_size):
    model(data_mx['train']['x'][[i]])

print(f'TabM mini throughput: {train_size / (time.time() - start)}')

TabM mini throughput: 19689.7415275832


In [74]:
data = data_mx['train']['x'].tolist()

start = time.time()

for i in range(train_size):
    model_cat.predict([data[i]])

print(f'Catboost throughput: {train_size / (time.time() - start)}')

Catboost throughput: 7834.57003081785
