# Easy Tutorial on Distributed Training in JAX with PyTorch Frame Data Loading

Author: Yiwen Yuan
Suggested Device: TPU VM v3-8

[PyTorch Frame](https://github.com/pyg-team/pytorch-frame) is a deep learning extension for PyTorch, designed for heterogeneous tabular data with different column type. It is super easy to load tables into PyTorch Frame and perform necessary stats computation for different semantic types. 

This tutorial is a simple example of distributed training in JAX using PyTorch Frame as data loader.

In [1]:
%%capture
!pip install pytorch_frame

In [2]:
# Import necessary libraries

import time
from functools import partial

import jax
import jax.numpy as jnp
import pandas as pd
import torch
from jax import jit, random, value_and_grad, vmap
from jax.nn import swish
from tqdm import tqdm

from torch_frame import categorical, numerical, timestamp
from torch_frame.data import DataLoader, Dataset
from torch_frame.data.stats import StatType

In [3]:
# Constants <3

NUM_FEAT = 24
NUM_DEVICES = jax.device_count()
BATCH_SIZE = 512
# Two layer simple MLP
LAYER_SIZES = [24, 1024, 1024, 1]
PARAM_SCALE = 0.01
INIT_LR = 0.0001
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS = 5
TARGET_COL = 'Premium Amount'

E0000 00:00:1733984901.725559      77 common_lib.cc:798] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:479
E1212 06:28:21.760614440      77 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:"2024-12-12T06:28:21.760595515+00:00"}


In [4]:
print(f"Number of devices: {NUM_DEVICES}")

Number of devices: 8


**Data Loading with PyTorch Frame**

To load data with PyTorch Frame, you need to specify the semantic column types.

In [5]:
col_to_stype = {
    'Age': numerical,
    'Annual Income': numerical,
    'Marital Status': categorical,
    'Number of Dependents': numerical,
    'Education Level': categorical,
    'Occupation': categorical,
    'Health Score': numerical,
    'Location': categorical,
    'Policy Type': categorical,
    'Previous Claims': numerical,
    'Vehicle Age': numerical,
    'Credit Score': numerical,
    'Insurance Duration': numerical,
    'Policy Start Date': timestamp,
    'Customer Feedback': categorical,
    'Smoking Status': categorical,
    'Exercise Frequency': categorical,
    'Property Type': categorical
}
test_dataset = Dataset(df=pd.read_csv('/kaggle/input/playground-series-s4e12/test.csv'),
                       col_to_stype=col_to_stype)
col_to_stype = col_to_stype.copy()
col_to_stype['Premium Amount'] = numerical

dataset = Dataset(df=pd.read_csv('/kaggle/input/playground-series-s4e12/train.csv'),
                  col_to_stype=col_to_stype, target_col=TARGET_COL)

# Saves the materialized tensors for easy reuse
dataset.materialize(path='/kaggle/working/data.pt')
test_dataset.materialize(path='/kaggle/working/test_data.pt')

Dataset()

**Nan Imputation**

The original data contains nans. We need to fill the nan values. We use MEAN to impute nan for numerical columns, MODE for categorical columns and MEDIAN for timestamp columns. The stats are already calculated as part of the PyTorch Frame data materialization process.

In [6]:
# Use mean value to fill nans in numerical columns
# Use mode value to fill nans in categorical columns
# Use newest time to fill nans in timestamp columns
fill_vals = []
means = []
stds = []
for stype in [numerical, categorical, timestamp]:
    col_names = dataset.tensor_frame.col_names_dict[stype]
    for col_name in col_names:
        if stype == numerical:
            fill_vals.append(dataset.col_stats[col_name][StatType.MEAN])
            means.append(dataset.col_stats[col_name][StatType.MEAN])
            stds.append(dataset.col_stats[col_name][StatType.STD])
        elif stype == categorical:
            fill_vals.append(0)
            means.append(0.)
            stds.append(1.)
        elif stype == timestamp:
            fill_vals += dataset.col_stats[col_name][StatType.NEWEST_TIME]
            means += dataset.col_stats[col_name][StatType.MEDIAN_TIME]
            stds += [1.] * 7
        else:
            raise ValueError("Unsupported stype")
fill_vals = jnp.array(fill_vals)
means = jnp.array(means)
stds = jnp.array(stds)

y_mean = torch.mean(dataset.tensor_frame.y).cpu().numpy()
y_std = torch.std(dataset.tensor_frame.y).cpu().numpy()

**2-layer MLP in JAX with distributed training capabilities**

In [7]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return (scale * random.normal(w_key, (n, m)),
                scale * random.normal(b_key, (n, )))

    keys = random.split(key, len(sizes))
    return [
        random_layer_params(m, n, k, scale)
        for m, n, k in zip(sizes[:-1], sizes[1:], keys)
    ]

def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
    logits = batched_predict(params, images)
    return jnp.mean(jnp.abs(logits - targets))

@partial(jax.pmap, axis_name='devices', in_axes=(None, 0, 0, None),
         out_axes=(None, 0))
def update(params, x, y, epoch_number):
    loss_value, grads = value_and_grad(loss)(params, x, y)
    grads = [(jax.lax.psum(dw, 'devices') / NUM_DEVICES,
              jax.lax.psum(db, 'devices') / NUM_DEVICES)
             for dw, db in grads]
    lr = INIT_LR * DECAY_RATE**(epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)], loss_value

@jit
def batched_mae(params, images, targets):
    images = jnp.reshape(images, (len(images), NUM_FEAT))
    predicted_targets = batched_predict(params, images)
    return jnp.mean(jnp.abs(predicted_targets - targets))


def mae(params, data_loader):
    maes = []
    for tf in data_loader:
        x = torch.cat([
            tf.feat_dict[categorical], tf.feat_dict[numerical],
            tf.feat_dict[timestamp].squeeze(1)
        ], dim=1).numpy()
        y = (tf.y.numpy() - y_mean) / y_std
        nan_mask = jnp.isnan(x)
        if x.any():
            x = jnp.where(nan_mask, fill_vals, x)
        x = (x - means) / stds
        maes.append(batched_mae(params, x, y))
    return jnp.mean(jnp.array(maes))

In [8]:
init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0),
                                  scale=PARAM_SCALE)

**Actual Training**

In [9]:
train_loader = DataLoader(dataset.tensor_frame, batch_size=BATCH_SIZE,
                          shuffle=True, drop_last=True)
params = init_params
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    losses = []
    for tf in tqdm(train_loader):
        x = torch.cat([
            tf.feat_dict[categorical], tf.feat_dict[numerical],
            tf.feat_dict[timestamp].squeeze(1)
        ], dim=1).numpy()
        y = (tf.y.numpy() - y_mean) / y_std
        nan_mask = jnp.isnan(x)
        if nan_mask.any():
            x = jnp.where(nan_mask, fill_vals, x)
        x = (x - means) / stds
        x = jnp.reshape(x, (NUM_DEVICES, BATCH_SIZE // NUM_DEVICES, NUM_FEAT))
        y = jnp.reshape(y, (NUM_DEVICES, BATCH_SIZE // NUM_DEVICES, 1))
        params, loss_value = update(params, x, y, epoch)
        losses.append(jnp.sum(loss_value))
    epoch_time = time.time() - start_time
    train_mae = mae(params, train_loader)
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set loss {jnp.mean(jnp.array(losses))}")
    print(f"Training set mae {train_mae}")


  0%|          | 0/2343 [00:00<?, ?it/s]

  0%|          | 1/2343 [00:00<28:35,  1.37it/s]

  1%|          | 18/2343 [00:00<01:20, 28.81it/s]

  1%|▏         | 35/2343 [00:00<00:41, 55.29it/s]

  2%|▏         | 52/2343 [00:01<00:28, 79.52it/s]

  3%|▎         | 69/2343 [00:01<00:22, 100.12it/s]

  4%|▎         | 86/2343 [00:01<00:19, 116.98it/s]

  4%|▍         | 103/2343 [00:01<00:17, 130.60it/s]

  5%|▌         | 120/2343 [00:01<00:15, 140.88it/s]

  6%|▌         | 137/2343 [00:01<00:14, 148.60it/s]

  7%|▋         | 154/2343 [00:01<00:14, 153.17it/s]

  7%|▋         | 172/2343 [00:01<00:13, 158.62it/s]

  8%|▊         | 190/2343 [00:01<00:13, 162.57it/s]

  9%|▉         | 207/2343 [00:01<00:13, 163.41it/s]

 10%|▉         | 224/2343 [00:02<00:12, 163.01it/s]

 10%|█         | 241/2343 [00:02<00:12, 164.16it/s]

 11%|█         | 258/2343 [00:02<00:12, 162.34it/s]

 12%|█▏        | 275/2343 [00:02<00:12, 164.22it/s]

 13%|█▎        | 293/2343 [00:02<00:12, 166.32it/s]

 13%|█▎        | 311/2343 [00:02<00:12, 167.76it/s]

 14%|█▍        | 329/2343 [00:02<00:11, 169.44it/s]

 15%|█▍        | 347/2343 [00:02<00:11, 170.57it/s]

 16%|█▌        | 365/2343 [00:02<00:11, 170.80it/s]

 16%|█▋        | 383/2343 [00:03<00:11, 171.21it/s]

 17%|█▋        | 401/2343 [00:03<00:11, 171.65it/s]

 18%|█▊        | 419/2343 [00:03<00:12, 156.73it/s]

 19%|█▊        | 436/2343 [00:03<00:11, 159.24it/s]

 19%|█▉        | 454/2343 [00:03<00:11, 162.53it/s]

 20%|██        | 471/2343 [00:03<00:11, 163.98it/s]

 21%|██        | 488/2343 [00:03<00:11, 165.34it/s]

 22%|██▏       | 506/2343 [00:03<00:11, 166.90it/s]

 22%|██▏       | 524/2343 [00:03<00:10, 168.24it/s]

 23%|██▎       | 541/2343 [00:03<00:10, 167.93it/s]

 24%|██▍       | 558/2343 [00:04<00:10, 167.80it/s]

 25%|██▍       | 575/2343 [00:04<00:10, 167.27it/s]

 25%|██▌       | 592/2343 [00:04<00:10, 165.77it/s]

 26%|██▌       | 610/2343 [00:04<00:10, 167.90it/s]

 27%|██▋       | 628/2343 [00:04<00:10, 169.23it/s]

 28%|██▊       | 646/2343 [00:04<00:09, 170.23it/s]

 28%|██▊       | 664/2343 [00:04<00:09, 171.16it/s]

 29%|██▉       | 682/2343 [00:04<00:09, 169.30it/s]

 30%|██▉       | 699/2343 [00:04<00:09, 168.70it/s]

 31%|███       | 716/2343 [00:05<00:09, 168.25it/s]

 31%|███▏      | 733/2343 [00:05<00:09, 168.00it/s]

 32%|███▏      | 751/2343 [00:05<00:09, 168.83it/s]

 33%|███▎      | 768/2343 [00:05<00:09, 165.66it/s]

 34%|███▎      | 785/2343 [00:05<00:09, 164.69it/s]

 34%|███▍      | 802/2343 [00:05<00:09, 164.64it/s]

 35%|███▍      | 819/2343 [00:05<00:09, 165.07it/s]

 36%|███▌      | 837/2343 [00:05<00:09, 166.73it/s]

 36%|███▋      | 855/2343 [00:05<00:08, 168.51it/s]

 37%|███▋      | 872/2343 [00:05<00:08, 168.32it/s]

 38%|███▊      | 889/2343 [00:06<00:08, 168.61it/s]

 39%|███▊      | 906/2343 [00:06<00:08, 168.57it/s]

 39%|███▉      | 923/2343 [00:06<00:08, 164.10it/s]

 40%|████      | 940/2343 [00:06<00:08, 164.92it/s]

 41%|████      | 958/2343 [00:06<00:08, 166.85it/s]

 42%|████▏     | 976/2343 [00:06<00:08, 168.20it/s]

 42%|████▏     | 994/2343 [00:06<00:07, 170.10it/s]

 43%|████▎     | 1012/2343 [00:06<00:07, 169.06it/s]

 44%|████▍     | 1029/2343 [00:06<00:07, 167.88it/s]

 45%|████▍     | 1046/2343 [00:06<00:07, 168.00it/s]

 45%|████▌     | 1063/2343 [00:07<00:07, 168.24it/s]

 46%|████▌     | 1080/2343 [00:07<00:07, 168.36it/s]

 47%|████▋     | 1097/2343 [00:07<00:07, 165.45it/s]

 48%|████▊     | 1114/2343 [00:07<00:07, 166.36it/s]

 48%|████▊     | 1131/2343 [00:07<00:07, 166.79it/s]

 49%|████▉     | 1148/2343 [00:07<00:07, 167.37it/s]

 50%|████▉     | 1165/2343 [00:07<00:07, 167.67it/s]

 50%|█████     | 1182/2343 [00:07<00:06, 167.80it/s]

 51%|█████     | 1199/2343 [00:07<00:06, 168.38it/s]

 52%|█████▏    | 1216/2343 [00:07<00:06, 167.55it/s]

 53%|█████▎    | 1234/2343 [00:08<00:06, 168.56it/s]

 53%|█████▎    | 1251/2343 [00:08<00:06, 157.73it/s]

 54%|█████▍    | 1267/2343 [00:08<00:07, 142.70it/s]

 55%|█████▍    | 1284/2343 [00:08<00:07, 148.91it/s]

 56%|█████▌    | 1301/2343 [00:08<00:06, 154.43it/s]

 56%|█████▋    | 1318/2343 [00:08<00:06, 158.54it/s]

 57%|█████▋    | 1335/2343 [00:08<00:06, 161.01it/s]

 58%|█████▊    | 1352/2343 [00:08<00:06, 162.63it/s]

 58%|█████▊    | 1369/2343 [00:08<00:05, 162.67it/s]

 59%|█████▉    | 1386/2343 [00:09<00:05, 163.59it/s]

 60%|█████▉    | 1403/2343 [00:09<00:05, 164.33it/s]

 61%|██████    | 1420/2343 [00:09<00:05, 161.18it/s]

 61%|██████▏   | 1437/2343 [00:09<00:05, 162.67it/s]

 62%|██████▏   | 1454/2343 [00:09<00:05, 164.02it/s]

 63%|██████▎   | 1471/2343 [00:09<00:05, 165.11it/s]

 64%|██████▎   | 1488/2343 [00:09<00:05, 166.54it/s]

 64%|██████▍   | 1506/2343 [00:09<00:04, 168.69it/s]

 65%|██████▌   | 1524/2343 [00:09<00:04, 169.69it/s]

 66%|██████▌   | 1541/2343 [00:10<00:04, 168.31it/s]

 66%|██████▋   | 1558/2343 [00:10<00:04, 167.86it/s]

 67%|██████▋   | 1575/2343 [00:10<00:04, 167.84it/s]

 68%|██████▊   | 1592/2343 [00:10<00:04, 166.45it/s]

 69%|██████▊   | 1609/2343 [00:10<00:04, 166.12it/s]

 69%|██████▉   | 1626/2343 [00:10<00:04, 165.38it/s]

 70%|███████   | 1643/2343 [00:10<00:04, 164.90it/s]

 71%|███████   | 1660/2343 [00:10<00:04, 164.59it/s]

 72%|███████▏  | 1678/2343 [00:10<00:03, 166.56it/s]

 72%|███████▏  | 1695/2343 [00:10<00:03, 165.87it/s]

 73%|███████▎  | 1712/2343 [00:11<00:03, 166.26it/s]

 74%|███████▍  | 1729/2343 [00:11<00:03, 166.08it/s]

 75%|███████▍  | 1747/2343 [00:11<00:03, 167.39it/s]

 75%|███████▌  | 1765/2343 [00:11<00:03, 168.32it/s]

 76%|███████▌  | 1783/2343 [00:11<00:03, 169.30it/s]

 77%|███████▋  | 1800/2343 [00:11<00:03, 169.24it/s]

 78%|███████▊  | 1817/2343 [00:11<00:03, 169.11it/s]

 78%|███████▊  | 1834/2343 [00:11<00:03, 166.38it/s]

 79%|███████▉  | 1851/2343 [00:11<00:02, 165.90it/s]

 80%|███████▉  | 1868/2343 [00:11<00:02, 165.67it/s]

 80%|████████  | 1885/2343 [00:12<00:02, 166.52it/s]

 81%|████████  | 1902/2343 [00:12<00:02, 164.96it/s]

 82%|████████▏ | 1919/2343 [00:12<00:02, 163.76it/s]

 83%|████████▎ | 1936/2343 [00:12<00:02, 163.18it/s]

 83%|████████▎ | 1953/2343 [00:12<00:02, 164.14it/s]

 84%|████████▍ | 1970/2343 [00:12<00:02, 165.01it/s]

 85%|████████▍ | 1987/2343 [00:12<00:02, 164.56it/s]

 86%|████████▌ | 2004/2343 [00:12<00:02, 165.30it/s]

 86%|████████▋ | 2021/2343 [00:12<00:01, 165.14it/s]

 87%|████████▋ | 2038/2343 [00:13<00:01, 164.88it/s]

 88%|████████▊ | 2055/2343 [00:13<00:01, 165.16it/s]

 88%|████████▊ | 2072/2343 [00:13<00:01, 159.29it/s]

 89%|████████▉ | 2090/2343 [00:13<00:01, 163.20it/s]

 90%|████████▉ | 2108/2343 [00:13<00:01, 166.31it/s]

 91%|█████████ | 2126/2343 [00:13<00:01, 168.90it/s]

 92%|█████████▏| 2145/2343 [00:13<00:01, 173.42it/s]

 92%|█████████▏| 2164/2343 [00:13<00:01, 175.83it/s]

 93%|█████████▎| 2183/2343 [00:13<00:00, 177.45it/s]

 94%|█████████▍| 2202/2343 [00:13<00:00, 178.48it/s]

 95%|█████████▍| 2220/2343 [00:14<00:00, 178.70it/s]

 96%|█████████▌| 2238/2343 [00:14<00:00, 178.34it/s]

 96%|█████████▋| 2256/2343 [00:14<00:00, 175.82it/s]

 97%|█████████▋| 2274/2343 [00:14<00:00, 173.64it/s]

 98%|█████████▊| 2292/2343 [00:14<00:00, 174.32it/s]

 99%|█████████▊| 2311/2343 [00:14<00:00, 176.73it/s]

 99%|█████████▉| 2330/2343 [00:14<00:00, 178.02it/s]

100%|██████████| 2343/2343 [00:14<00:00, 157.66it/s]




Epoch 1 in 14.87 sec


Training set loss 8.360461235046387
Training set mae 0.749265193939209


  0%|          | 0/2343 [00:00<?, ?it/s]

  0%|          | 2/2343 [00:00<02:01, 19.29it/s]

  1%|          | 19/2343 [00:00<00:22, 104.42it/s]

  2%|▏         | 36/2343 [00:00<00:17, 133.03it/s]

  2%|▏         | 53/2343 [00:00<00:15, 146.99it/s]

  3%|▎         | 70/2343 [00:00<00:14, 154.74it/s]

  4%|▎         | 87/2343 [00:00<00:14, 159.63it/s]

  4%|▍         | 104/2343 [00:00<00:13, 162.23it/s]

  5%|▌         | 122/2343 [00:00<00:13, 164.69it/s]

  6%|▌         | 140/2343 [00:00<00:13, 166.84it/s]

  7%|▋         | 158/2343 [00:01<00:13, 167.94it/s]

  7%|▋         | 175/2343 [00:01<00:12, 167.30it/s]

  8%|▊         | 192/2343 [00:01<00:12, 168.01it/s]

  9%|▉         | 209/2343 [00:01<00:12, 167.37it/s]

 10%|▉         | 226/2343 [00:01<00:13, 160.90it/s]

 10%|█         | 243/2343 [00:01<00:13, 158.73it/s]

 11%|█         | 261/2343 [00:01<00:12, 161.92it/s]

 12%|█▏        | 278/2343 [00:01<00:12, 162.62it/s]

 13%|█▎        | 295/2343 [00:01<00:12, 163.29it/s]

 13%|█▎        | 312/2343 [00:01<00:12, 164.39it/s]

 14%|█▍        | 329/2343 [00:02<00:12, 165.31it/s]

 15%|█▍        | 346/2343 [00:02<00:12, 165.49it/s]

 15%|█▌        | 363/2343 [00:02<00:11, 166.09it/s]

 16%|█▌        | 380/2343 [00:02<00:11, 166.26it/s]

 17%|█▋        | 397/2343 [00:02<00:11, 166.44it/s]

 18%|█▊        | 414/2343 [00:02<00:11, 166.36it/s]

 18%|█▊        | 431/2343 [00:02<00:11, 166.39it/s]

 19%|█▉        | 448/2343 [00:02<00:11, 166.33it/s]

 20%|█▉        | 465/2343 [00:02<00:11, 167.27it/s]

 21%|██        | 482/2343 [00:02<00:11, 167.73it/s]

 21%|██▏       | 499/2343 [00:03<00:11, 167.42it/s]

 22%|██▏       | 516/2343 [00:03<00:10, 167.74it/s]

 23%|██▎       | 533/2343 [00:03<00:10, 168.34it/s]

 23%|██▎       | 550/2343 [00:03<00:10, 168.30it/s]

 24%|██▍       | 567/2343 [00:03<00:10, 167.92it/s]

 25%|██▍       | 584/2343 [00:03<00:10, 167.79it/s]

 26%|██▌       | 601/2343 [00:03<00:11, 152.04it/s]

 26%|██▋       | 619/2343 [00:03<00:10, 157.33it/s]

 27%|██▋       | 637/2343 [00:03<00:10, 161.19it/s]

 28%|██▊       | 654/2343 [00:04<00:10, 163.34it/s]

 29%|██▊       | 671/2343 [00:04<00:10, 165.24it/s]

 29%|██▉       | 688/2343 [00:04<00:09, 166.52it/s]

 30%|███       | 705/2343 [00:04<00:09, 167.49it/s]

 31%|███       | 722/2343 [00:04<00:09, 167.32it/s]

 32%|███▏      | 740/2343 [00:04<00:09, 168.45it/s]

 32%|███▏      | 758/2343 [00:04<00:09, 170.04it/s]

 33%|███▎      | 776/2343 [00:04<00:09, 170.69it/s]

 34%|███▍      | 795/2343 [00:04<00:08, 174.29it/s]

 35%|███▍      | 813/2343 [00:04<00:08, 174.09it/s]

 35%|███▌      | 831/2343 [00:05<00:08, 173.22it/s]

 36%|███▌      | 849/2343 [00:05<00:08, 172.92it/s]

 37%|███▋      | 867/2343 [00:05<00:08, 173.55it/s]

 38%|███▊      | 885/2343 [00:05<00:08, 173.09it/s]

 39%|███▊      | 903/2343 [00:05<00:08, 171.74it/s]

 39%|███▉      | 921/2343 [00:05<00:08, 170.88it/s]

 40%|████      | 939/2343 [00:05<00:08, 170.99it/s]

 41%|████      | 957/2343 [00:05<00:08, 170.95it/s]

 42%|████▏     | 975/2343 [00:05<00:07, 171.62it/s]

 42%|████▏     | 993/2343 [00:06<00:07, 172.21it/s]

 43%|████▎     | 1011/2343 [00:06<00:07, 171.97it/s]

 44%|████▍     | 1029/2343 [00:06<00:07, 170.24it/s]

 45%|████▍     | 1047/2343 [00:06<00:07, 168.89it/s]

 45%|████▌     | 1064/2343 [00:06<00:07, 168.32it/s]

 46%|████▌     | 1081/2343 [00:06<00:07, 167.98it/s]

 47%|████▋     | 1098/2343 [00:06<00:07, 167.73it/s]

 48%|████▊     | 1116/2343 [00:06<00:07, 168.94it/s]

 48%|████▊     | 1133/2343 [00:06<00:07, 168.83it/s]

 49%|████▉     | 1151/2343 [00:06<00:07, 169.27it/s]

 50%|████▉     | 1168/2343 [00:07<00:06, 169.12it/s]

 51%|█████     | 1186/2343 [00:07<00:06, 170.53it/s]

 51%|█████▏    | 1204/2343 [00:07<00:06, 170.87it/s]

 52%|█████▏    | 1222/2343 [00:07<00:06, 170.67it/s]

 53%|█████▎    | 1240/2343 [00:07<00:06, 170.57it/s]

 54%|█████▎    | 1258/2343 [00:07<00:06, 170.82it/s]

 54%|█████▍    | 1276/2343 [00:07<00:06, 168.92it/s]

 55%|█████▌    | 1293/2343 [00:07<00:06, 168.42it/s]

 56%|█████▌    | 1310/2343 [00:07<00:06, 168.38it/s]

 57%|█████▋    | 1328/2343 [00:07<00:05, 169.97it/s]

 57%|█████▋    | 1346/2343 [00:08<00:05, 171.76it/s]

 58%|█████▊    | 1364/2343 [00:08<00:05, 172.96it/s]

 59%|█████▉    | 1382/2343 [00:08<00:05, 173.58it/s]

 60%|█████▉    | 1400/2343 [00:08<00:05, 169.70it/s]

 61%|██████    | 1418/2343 [00:08<00:05, 170.23it/s]

 61%|██████▏   | 1436/2343 [00:08<00:05, 171.36it/s]

 62%|██████▏   | 1454/2343 [00:08<00:05, 166.57it/s]

 63%|██████▎   | 1472/2343 [00:08<00:05, 168.95it/s]

 64%|██████▎   | 1490/2343 [00:08<00:04, 171.53it/s]

 64%|██████▍   | 1508/2343 [00:09<00:04, 173.63it/s]

 65%|██████▌   | 1526/2343 [00:09<00:04, 175.02it/s]

 66%|██████▌   | 1545/2343 [00:09<00:04, 177.26it/s]

 67%|██████▋   | 1563/2343 [00:09<00:04, 174.62it/s]

 67%|██████▋   | 1581/2343 [00:09<00:04, 172.76it/s]

 68%|██████▊   | 1599/2343 [00:09<00:04, 172.47it/s]

 69%|██████▉   | 1617/2343 [00:09<00:04, 174.44it/s]

 70%|██████▉   | 1635/2343 [00:09<00:04, 173.75it/s]

 71%|███████   | 1653/2343 [00:09<00:04, 169.58it/s]

 71%|███████▏  | 1671/2343 [00:09<00:03, 170.84it/s]

 72%|███████▏  | 1689/2343 [00:10<00:03, 171.52it/s]

 73%|███████▎  | 1707/2343 [00:10<00:03, 171.99it/s]

 74%|███████▎  | 1725/2343 [00:10<00:03, 171.91it/s]

 74%|███████▍  | 1743/2343 [00:10<00:03, 172.40it/s]

 75%|███████▌  | 1761/2343 [00:10<00:03, 172.37it/s]

 76%|███████▌  | 1779/2343 [00:10<00:03, 172.51it/s]

 77%|███████▋  | 1797/2343 [00:10<00:03, 173.02it/s]

 77%|███████▋  | 1815/2343 [00:10<00:03, 172.42it/s]

 78%|███████▊  | 1833/2343 [00:10<00:02, 172.18it/s]

 79%|███████▉  | 1851/2343 [00:11<00:02, 172.13it/s]

 80%|███████▉  | 1869/2343 [00:11<00:02, 172.57it/s]

 81%|████████  | 1887/2343 [00:11<00:02, 172.85it/s]

 81%|████████▏ | 1905/2343 [00:11<00:02, 172.49it/s]

 82%|████████▏ | 1923/2343 [00:11<00:02, 172.70it/s]

 83%|████████▎ | 1941/2343 [00:11<00:02, 172.14it/s]

 84%|████████▎ | 1959/2343 [00:11<00:02, 171.73it/s]

 84%|████████▍ | 1977/2343 [00:11<00:02, 171.31it/s]

 85%|████████▌ | 1995/2343 [00:11<00:02, 171.14it/s]

 86%|████████▌ | 2013/2343 [00:11<00:01, 172.45it/s]

 87%|████████▋ | 2031/2343 [00:12<00:01, 173.99it/s]

 87%|████████▋ | 2049/2343 [00:12<00:01, 174.26it/s]

 88%|████████▊ | 2067/2343 [00:12<00:01, 174.64it/s]

 89%|████████▉ | 2085/2343 [00:12<00:01, 175.16it/s]

 90%|████████▉ | 2103/2343 [00:12<00:01, 175.00it/s]

 91%|█████████ | 2121/2343 [00:12<00:01, 174.36it/s]

 91%|█████████▏| 2139/2343 [00:12<00:01, 173.23it/s]

 92%|█████████▏| 2157/2343 [00:12<00:01, 173.31it/s]

 93%|█████████▎| 2175/2343 [00:12<00:00, 172.43it/s]

 94%|█████████▎| 2193/2343 [00:13<00:00, 172.85it/s]

 94%|█████████▍| 2211/2343 [00:13<00:00, 172.84it/s]

 95%|█████████▌| 2229/2343 [00:13<00:00, 170.76it/s]

 96%|█████████▌| 2247/2343 [00:13<00:00, 168.33it/s]

 97%|█████████▋| 2264/2343 [00:13<00:00, 166.80it/s]

 97%|█████████▋| 2281/2343 [00:13<00:00, 162.25it/s]

 98%|█████████▊| 2298/2343 [00:13<00:00, 155.52it/s]

 99%|█████████▉| 2316/2343 [00:13<00:00, 161.30it/s]

100%|█████████▉| 2334/2343 [00:13<00:00, 165.69it/s]

100%|██████████| 2343/2343 [00:14<00:00, 166.91it/s]




Epoch 2 in 14.05 sec


Training set loss 5.904703617095947
Training set mae 0.7481929063796997


  0%|          | 0/2343 [00:00<?, ?it/s]

  0%|          | 2/2343 [00:00<01:57, 19.88it/s]

  1%|          | 19/2343 [00:00<00:21, 105.72it/s]

  2%|▏         | 36/2343 [00:00<00:17, 133.48it/s]

  2%|▏         | 50/2343 [00:00<00:17, 133.86it/s]

  3%|▎         | 67/2343 [00:00<00:15, 146.32it/s]

  4%|▎         | 85/2343 [00:00<00:14, 154.97it/s]

  4%|▍         | 103/2343 [00:00<00:13, 160.25it/s]

  5%|▌         | 121/2343 [00:00<00:13, 164.21it/s]

  6%|▌         | 139/2343 [00:00<00:13, 168.40it/s]

  7%|▋         | 157/2343 [00:01<00:12, 170.48it/s]

  7%|▋         | 175/2343 [00:01<00:12, 170.78it/s]

  8%|▊         | 193/2343 [00:01<00:12, 170.82it/s]

  9%|▉         | 212/2343 [00:01<00:12, 175.27it/s]

 10%|▉         | 231/2343 [00:01<00:11, 179.22it/s]

 11%|█         | 250/2343 [00:01<00:11, 180.60it/s]

 11%|█▏        | 269/2343 [00:01<00:11, 177.86it/s]

 12%|█▏        | 287/2343 [00:01<00:11, 175.63it/s]

 13%|█▎        | 305/2343 [00:01<00:11, 174.68it/s]

 14%|█▍        | 323/2343 [00:01<00:11, 174.97it/s]

 15%|█▍        | 341/2343 [00:02<00:11, 173.71it/s]

 15%|█▌        | 359/2343 [00:02<00:11, 165.66it/s]

 16%|█▌        | 376/2343 [00:02<00:11, 164.71it/s]

 17%|█▋        | 393/2343 [00:02<00:11, 164.39it/s]

 17%|█▋        | 410/2343 [00:02<00:11, 164.44it/s]

 18%|█▊        | 427/2343 [00:02<00:11, 165.24it/s]

 19%|█▉        | 444/2343 [00:02<00:11, 166.48it/s]

 20%|█▉        | 461/2343 [00:02<00:11, 165.61it/s]

 20%|██        | 478/2343 [00:02<00:11, 165.78it/s]

 21%|██        | 495/2343 [00:03<00:11, 166.15it/s]

 22%|██▏       | 512/2343 [00:03<00:11, 166.31it/s]

 23%|██▎       | 529/2343 [00:03<00:10, 166.61it/s]

 23%|██▎       | 546/2343 [00:03<00:10, 166.75it/s]

 24%|██▍       | 563/2343 [00:03<00:10, 164.61it/s]

 25%|██▍       | 580/2343 [00:03<00:10, 164.85it/s]

 25%|██▌       | 597/2343 [00:03<00:10, 165.07it/s]

 26%|██▌       | 614/2343 [00:03<00:10, 165.06it/s]

 27%|██▋       | 631/2343 [00:03<00:10, 166.03it/s]

 28%|██▊       | 648/2343 [00:03<00:10, 166.73it/s]

 28%|██▊       | 665/2343 [00:04<00:10, 167.55it/s]

 29%|██▉       | 682/2343 [00:04<00:09, 166.98it/s]

 30%|██▉       | 700/2343 [00:04<00:09, 169.22it/s]

 31%|███       | 718/2343 [00:04<00:09, 171.03it/s]

 31%|███▏      | 736/2343 [00:04<00:09, 170.76it/s]

 32%|███▏      | 754/2343 [00:04<00:09, 170.05it/s]

 33%|███▎      | 772/2343 [00:04<00:09, 167.66it/s]

 34%|███▎      | 789/2343 [00:04<00:09, 166.95it/s]

 34%|███▍      | 806/2343 [00:04<00:09, 167.60it/s]

 35%|███▌      | 825/2343 [00:04<00:08, 172.13it/s]

 36%|███▌      | 843/2343 [00:05<00:08, 173.11it/s]

 37%|███▋      | 861/2343 [00:05<00:08, 171.43it/s]

 38%|███▊      | 879/2343 [00:05<00:08, 168.88it/s]

 38%|███▊      | 896/2343 [00:05<00:09, 158.95it/s]

 39%|███▉      | 914/2343 [00:05<00:08, 162.35it/s]

 40%|███▉      | 932/2343 [00:05<00:08, 166.18it/s]

 41%|████      | 950/2343 [00:05<00:08, 167.78it/s]

 41%|████▏     | 968/2343 [00:05<00:08, 168.96it/s]

 42%|████▏     | 986/2343 [00:05<00:07, 169.68it/s]

 43%|████▎     | 1003/2343 [00:06<00:08, 167.13it/s]

 44%|████▎     | 1020/2343 [00:06<00:07, 167.20it/s]

 44%|████▍     | 1037/2343 [00:06<00:07, 167.15it/s]

 45%|████▍     | 1054/2343 [00:06<00:07, 167.01it/s]

 46%|████▌     | 1071/2343 [00:06<00:07, 166.09it/s]

 46%|████▋     | 1088/2343 [00:06<00:07, 164.92it/s]

 47%|████▋     | 1106/2343 [00:06<00:07, 166.74it/s]

 48%|████▊     | 1123/2343 [00:06<00:07, 167.41it/s]

 49%|████▊     | 1141/2343 [00:06<00:07, 170.28it/s]

 49%|████▉     | 1159/2343 [00:06<00:06, 171.15it/s]

 50%|█████     | 1177/2343 [00:07<00:06, 170.82it/s]

 51%|█████     | 1195/2343 [00:07<00:06, 170.44it/s]

 52%|█████▏    | 1213/2343 [00:07<00:06, 171.43it/s]

 53%|█████▎    | 1231/2343 [00:07<00:06, 169.71it/s]

 53%|█████▎    | 1249/2343 [00:07<00:06, 171.08it/s]

 54%|█████▍    | 1268/2343 [00:07<00:06, 173.88it/s]

 55%|█████▍    | 1287/2343 [00:07<00:05, 177.19it/s]

 56%|█████▌    | 1306/2343 [00:07<00:05, 180.33it/s]

 57%|█████▋    | 1325/2343 [00:07<00:05, 182.32it/s]

 57%|█████▋    | 1344/2343 [00:08<00:05, 183.36it/s]

 58%|█████▊    | 1363/2343 [00:08<00:05, 183.40it/s]

 59%|█████▉    | 1382/2343 [00:08<00:05, 178.16it/s]

 60%|█████▉    | 1400/2343 [00:08<00:05, 174.61it/s]

 61%|██████    | 1418/2343 [00:08<00:05, 171.73it/s]

 61%|██████▏   | 1436/2343 [00:08<00:05, 171.53it/s]

 62%|██████▏   | 1454/2343 [00:08<00:05, 171.45it/s]

 63%|██████▎   | 1472/2343 [00:08<00:05, 170.72it/s]

 64%|██████▎   | 1490/2343 [00:08<00:05, 170.01it/s]

 64%|██████▍   | 1508/2343 [00:08<00:04, 169.44it/s]

 65%|██████▌   | 1526/2343 [00:09<00:04, 170.31it/s]

 66%|██████▌   | 1544/2343 [00:09<00:04, 170.85it/s]

 67%|██████▋   | 1562/2343 [00:09<00:04, 170.75it/s]

 67%|██████▋   | 1580/2343 [00:09<00:04, 170.94it/s]

 68%|██████▊   | 1598/2343 [00:09<00:04, 171.25it/s]

 69%|██████▉   | 1616/2343 [00:09<00:04, 169.58it/s]

 70%|██████▉   | 1634/2343 [00:09<00:04, 169.78it/s]

 70%|███████   | 1651/2343 [00:09<00:04, 169.67it/s]

 71%|███████   | 1668/2343 [00:09<00:04, 167.71it/s]

 72%|███████▏  | 1685/2343 [00:10<00:03, 168.01it/s]

 73%|███████▎  | 1702/2343 [00:10<00:03, 168.19it/s]

 73%|███████▎  | 1719/2343 [00:10<00:03, 167.43it/s]

 74%|███████▍  | 1736/2343 [00:10<00:03, 167.70it/s]

 75%|███████▍  | 1753/2343 [00:10<00:03, 159.37it/s]

 76%|███████▌  | 1770/2343 [00:10<00:03, 161.54it/s]

 76%|███████▋  | 1787/2343 [00:10<00:03, 163.14it/s]

 77%|███████▋  | 1804/2343 [00:10<00:03, 164.00it/s]

 78%|███████▊  | 1821/2343 [00:10<00:03, 165.36it/s]

 78%|███████▊  | 1838/2343 [00:10<00:03, 166.69it/s]

 79%|███████▉  | 1855/2343 [00:11<00:02, 167.62it/s]

 80%|███████▉  | 1873/2343 [00:11<00:02, 168.51it/s]

 81%|████████  | 1890/2343 [00:11<00:02, 168.03it/s]

 81%|████████▏ | 1907/2343 [00:11<00:02, 168.28it/s]

 82%|████████▏ | 1924/2343 [00:11<00:02, 168.53it/s]

 83%|████████▎ | 1942/2343 [00:11<00:02, 169.04it/s]

 84%|████████▎ | 1959/2343 [00:11<00:02, 168.66it/s]

 84%|████████▍ | 1977/2343 [00:11<00:02, 169.98it/s]

 85%|████████▌ | 1994/2343 [00:11<00:02, 169.88it/s]

 86%|████████▌ | 2011/2343 [00:11<00:01, 169.10it/s]

 87%|████████▋ | 2029/2343 [00:12<00:01, 169.90it/s]

 87%|████████▋ | 2047/2343 [00:12<00:01, 170.81it/s]

 88%|████████▊ | 2065/2343 [00:12<00:01, 170.74it/s]

 89%|████████▉ | 2083/2343 [00:12<00:01, 170.17it/s]

 90%|████████▉ | 2101/2343 [00:12<00:01, 169.97it/s]

 90%|█████████ | 2118/2343 [00:12<00:01, 169.49it/s]

 91%|█████████ | 2135/2343 [00:12<00:01, 166.74it/s]

 92%|█████████▏| 2152/2343 [00:12<00:01, 166.44it/s]

 93%|█████████▎| 2169/2343 [00:12<00:01, 166.35it/s]

 93%|█████████▎| 2186/2343 [00:13<00:00, 165.58it/s]

 94%|█████████▍| 2203/2343 [00:13<00:00, 165.32it/s]

 95%|█████████▍| 2220/2343 [00:13<00:00, 165.08it/s]

 95%|█████████▌| 2237/2343 [00:13<00:00, 164.40it/s]

 96%|█████████▌| 2254/2343 [00:13<00:00, 165.02it/s]

 97%|█████████▋| 2271/2343 [00:13<00:00, 165.67it/s]

 98%|█████████▊| 2289/2343 [00:13<00:00, 168.51it/s]

 98%|█████████▊| 2307/2343 [00:13<00:00, 169.33it/s]

 99%|█████████▉| 2325/2343 [00:13<00:00, 169.59it/s]

100%|█████████▉| 2342/2343 [00:13<00:00, 169.56it/s]

100%|██████████| 2343/2343 [00:14<00:00, 166.49it/s]




Epoch 3 in 14.08 sec


Training set loss 5.9025983810424805
Training set mae 0.7476887106895447


  0%|          | 0/2343 [00:00<?, ?it/s]

  0%|          | 4/2343 [00:00<01:00, 38.51it/s]

  1%|          | 23/2343 [00:00<00:18, 122.61it/s]

  2%|▏         | 42/2343 [00:00<00:15, 149.75it/s]

  3%|▎         | 60/2343 [00:00<00:14, 161.31it/s]

  3%|▎         | 77/2343 [00:00<00:13, 163.98it/s]

  4%|▍         | 94/2343 [00:00<00:13, 165.97it/s]

  5%|▍         | 111/2343 [00:00<00:13, 167.04it/s]

  6%|▌         | 129/2343 [00:00<00:13, 168.09it/s]

  6%|▋         | 147/2343 [00:00<00:12, 170.12it/s]

  7%|▋         | 165/2343 [00:01<00:12, 170.11it/s]

  8%|▊         | 183/2343 [00:01<00:12, 168.22it/s]

  9%|▊         | 201/2343 [00:01<00:12, 168.99it/s]

  9%|▉         | 219/2343 [00:01<00:12, 169.37it/s]

 10%|█         | 236/2343 [00:01<00:12, 169.13it/s]

 11%|█         | 253/2343 [00:01<00:12, 169.28it/s]

 12%|█▏        | 270/2343 [00:01<00:12, 167.65it/s]

 12%|█▏        | 287/2343 [00:01<00:12, 166.78it/s]

 13%|█▎        | 304/2343 [00:01<00:12, 166.55it/s]

 14%|█▎        | 321/2343 [00:01<00:12, 166.54it/s]

 14%|█▍        | 338/2343 [00:02<00:12, 166.24it/s]

 15%|█▌        | 355/2343 [00:02<00:12, 157.58it/s]

 16%|█▌        | 372/2343 [00:02<00:12, 160.85it/s]

 17%|█▋        | 389/2343 [00:02<00:11, 163.11it/s]

 17%|█▋        | 406/2343 [00:02<00:11, 164.31it/s]

 18%|█▊        | 423/2343 [00:02<00:11, 165.59it/s]

 19%|█▉        | 440/2343 [00:02<00:11, 166.83it/s]

 20%|█▉        | 457/2343 [00:02<00:11, 166.98it/s]

 20%|██        | 474/2343 [00:02<00:11, 167.15it/s]

 21%|██        | 491/2343 [00:02<00:11, 166.17it/s]

 22%|██▏       | 508/2343 [00:03<00:11, 165.92it/s]

 22%|██▏       | 525/2343 [00:03<00:10, 166.06it/s]

 23%|██▎       | 542/2343 [00:03<00:10, 165.55it/s]

 24%|██▍       | 559/2343 [00:03<00:10, 164.93it/s]

 25%|██▍       | 576/2343 [00:03<00:10, 166.12it/s]

 25%|██▌       | 594/2343 [00:03<00:10, 167.36it/s]

 26%|██▌       | 611/2343 [00:03<00:10, 167.69it/s]

 27%|██▋       | 628/2343 [00:03<00:10, 162.18it/s]

 28%|██▊       | 645/2343 [00:03<00:10, 163.67it/s]

 28%|██▊       | 662/2343 [00:04<00:10, 165.00it/s]

 29%|██▉       | 679/2343 [00:04<00:10, 166.10it/s]

 30%|██▉       | 696/2343 [00:04<00:09, 166.51it/s]

 30%|███       | 714/2343 [00:04<00:09, 169.13it/s]

 31%|███       | 731/2343 [00:04<00:09, 168.93it/s]

 32%|███▏      | 748/2343 [00:04<00:09, 168.94it/s]

 33%|███▎      | 765/2343 [00:04<00:09, 168.48it/s]

 33%|███▎      | 782/2343 [00:04<00:09, 165.17it/s]

 34%|███▍      | 799/2343 [00:04<00:09, 164.66it/s]

 35%|███▍      | 816/2343 [00:04<00:09, 164.75it/s]

 36%|███▌      | 833/2343 [00:05<00:09, 165.32it/s]

 36%|███▋      | 851/2343 [00:05<00:08, 166.80it/s]

 37%|███▋      | 868/2343 [00:05<00:08, 166.09it/s]

 38%|███▊      | 885/2343 [00:05<00:08, 166.19it/s]

 38%|███▊      | 902/2343 [00:05<00:08, 166.75it/s]

 39%|███▉      | 919/2343 [00:05<00:08, 167.23it/s]

 40%|███▉      | 936/2343 [00:05<00:08, 167.40it/s]

 41%|████      | 953/2343 [00:05<00:08, 167.85it/s]

 41%|████▏     | 970/2343 [00:05<00:08, 167.77it/s]

 42%|████▏     | 987/2343 [00:05<00:08, 167.41it/s]

 43%|████▎     | 1005/2343 [00:06<00:07, 169.22it/s]

 44%|████▎     | 1022/2343 [00:06<00:07, 168.34it/s]

 44%|████▍     | 1039/2343 [00:06<00:07, 168.26it/s]

 45%|████▌     | 1057/2343 [00:06<00:07, 168.94it/s]

 46%|████▌     | 1074/2343 [00:06<00:07, 168.02it/s]

 47%|████▋     | 1091/2343 [00:06<00:07, 167.91it/s]

 47%|████▋     | 1108/2343 [00:06<00:07, 167.41it/s]

 48%|████▊     | 1126/2343 [00:06<00:07, 168.33it/s]

 49%|████▉     | 1143/2343 [00:06<00:07, 167.93it/s]

 50%|████▉     | 1161/2343 [00:07<00:06, 169.28it/s]

 50%|█████     | 1178/2343 [00:07<00:07, 163.23it/s]

 51%|█████     | 1196/2343 [00:07<00:06, 165.30it/s]

 52%|█████▏    | 1213/2343 [00:07<00:06, 166.47it/s]

 52%|█████▏    | 1230/2343 [00:07<00:06, 166.93it/s]

 53%|█████▎    | 1247/2343 [00:07<00:06, 167.55it/s]

 54%|█████▍    | 1264/2343 [00:07<00:06, 168.06it/s]

 55%|█████▍    | 1282/2343 [00:07<00:06, 168.73it/s]

 55%|█████▌    | 1300/2343 [00:07<00:06, 169.58it/s]

 56%|█████▌    | 1317/2343 [00:07<00:06, 168.33it/s]

 57%|█████▋    | 1334/2343 [00:08<00:05, 168.72it/s]

 58%|█████▊    | 1352/2343 [00:08<00:05, 169.64it/s]

 58%|█████▊    | 1370/2343 [00:08<00:05, 170.06it/s]

 59%|█████▉    | 1388/2343 [00:08<00:05, 167.57it/s]

 60%|█████▉    | 1405/2343 [00:08<00:05, 166.85it/s]

 61%|██████    | 1422/2343 [00:08<00:05, 167.41it/s]

 61%|██████▏   | 1439/2343 [00:08<00:05, 167.86it/s]

 62%|██████▏   | 1456/2343 [00:08<00:05, 167.44it/s]

 63%|██████▎   | 1473/2343 [00:08<00:05, 166.32it/s]

 64%|██████▎   | 1491/2343 [00:08<00:05, 167.84it/s]

 64%|██████▍   | 1508/2343 [00:09<00:04, 168.04it/s]

 65%|██████▌   | 1525/2343 [00:09<00:04, 168.16it/s]

 66%|██████▌   | 1543/2343 [00:09<00:04, 168.85it/s]

 67%|██████▋   | 1561/2343 [00:09<00:04, 169.48it/s]

 67%|██████▋   | 1579/2343 [00:09<00:04, 170.71it/s]

 68%|██████▊   | 1597/2343 [00:09<00:04, 170.47it/s]

 69%|██████▉   | 1615/2343 [00:09<00:04, 171.30it/s]

 70%|██████▉   | 1633/2343 [00:09<00:04, 171.46it/s]

 70%|███████   | 1651/2343 [00:09<00:04, 171.25it/s]

 71%|███████   | 1669/2343 [00:10<00:03, 171.28it/s]

 72%|███████▏  | 1687/2343 [00:10<00:03, 171.60it/s]

 73%|███████▎  | 1705/2343 [00:10<00:03, 171.21it/s]

 74%|███████▎  | 1723/2343 [00:10<00:03, 170.74it/s]

 74%|███████▍  | 1741/2343 [00:10<00:03, 170.55it/s]

 75%|███████▌  | 1759/2343 [00:10<00:03, 170.89it/s]

 76%|███████▌  | 1777/2343 [00:10<00:03, 170.93it/s]

 77%|███████▋  | 1795/2343 [00:10<00:03, 170.78it/s]

 77%|███████▋  | 1813/2343 [00:10<00:03, 169.87it/s]

 78%|███████▊  | 1830/2343 [00:10<00:03, 168.81it/s]

 79%|███████▉  | 1847/2343 [00:11<00:02, 168.27it/s]

 80%|███████▉  | 1864/2343 [00:11<00:02, 168.15it/s]

 80%|████████  | 1881/2343 [00:11<00:02, 168.65it/s]

 81%|████████  | 1898/2343 [00:11<00:02, 168.68it/s]

 82%|████████▏ | 1915/2343 [00:11<00:02, 166.25it/s]

 82%|████████▏ | 1932/2343 [00:11<00:02, 167.29it/s]

 83%|████████▎ | 1949/2343 [00:11<00:02, 167.29it/s]

 84%|████████▍ | 1966/2343 [00:11<00:02, 165.69it/s]

 85%|████████▍ | 1983/2343 [00:11<00:02, 166.25it/s]

 85%|████████▌ | 2000/2343 [00:11<00:02, 166.53it/s]

 86%|████████▌ | 2017/2343 [00:12<00:02, 158.81it/s]

 87%|████████▋ | 2034/2343 [00:12<00:01, 160.64it/s]

 88%|████████▊ | 2051/2343 [00:12<00:01, 161.33it/s]

 88%|████████▊ | 2068/2343 [00:12<00:01, 162.46it/s]

 89%|████████▉ | 2085/2343 [00:12<00:01, 164.32it/s]

 90%|████████▉ | 2102/2343 [00:12<00:01, 165.89it/s]

 90%|█████████ | 2119/2343 [00:12<00:01, 165.48it/s]

 91%|█████████ | 2136/2343 [00:12<00:01, 166.42it/s]

 92%|█████████▏| 2153/2343 [00:12<00:01, 167.05it/s]

 93%|█████████▎| 2170/2343 [00:13<00:01, 165.84it/s]

 93%|█████████▎| 2187/2343 [00:13<00:00, 166.19it/s]

 94%|█████████▍| 2204/2343 [00:13<00:00, 166.61it/s]

 95%|█████████▍| 2221/2343 [00:13<00:00, 166.98it/s]

 96%|█████████▌| 2238/2343 [00:13<00:00, 167.29it/s]

 96%|█████████▌| 2255/2343 [00:13<00:00, 168.02it/s]

 97%|█████████▋| 2272/2343 [00:13<00:00, 164.72it/s]

 98%|█████████▊| 2289/2343 [00:13<00:00, 162.29it/s]

 98%|█████████▊| 2306/2343 [00:13<00:00, 160.24it/s]

 99%|█████████▉| 2323/2343 [00:13<00:00, 161.75it/s]

100%|█████████▉| 2340/2343 [00:14<00:00, 162.82it/s]

100%|██████████| 2343/2343 [00:14<00:00, 165.07it/s]




Epoch 4 in 14.20 sec


Training set loss 5.901686191558838
Training set mae 0.7473083734512329


  0%|          | 0/2343 [00:00<?, ?it/s]

  0%|          | 4/2343 [00:00<01:00, 38.49it/s]

  1%|          | 21/2343 [00:00<00:20, 113.05it/s]

  2%|▏         | 38/2343 [00:00<00:16, 137.97it/s]

  2%|▏         | 55/2343 [00:00<00:15, 150.12it/s]

  3%|▎         | 73/2343 [00:00<00:14, 158.10it/s]

  4%|▍         | 91/2343 [00:00<00:13, 163.34it/s]

  5%|▍         | 108/2343 [00:00<00:13, 164.20it/s]

  5%|▌         | 126/2343 [00:00<00:13, 168.34it/s]

  6%|▌         | 144/2343 [00:00<00:12, 171.14it/s]

  7%|▋         | 162/2343 [00:01<00:12, 172.79it/s]

  8%|▊         | 180/2343 [00:01<00:12, 170.00it/s]

  8%|▊         | 198/2343 [00:01<00:12, 168.14it/s]

  9%|▉         | 215/2343 [00:01<00:12, 167.50it/s]

 10%|▉         | 233/2343 [00:01<00:12, 168.90it/s]

 11%|█         | 251/2343 [00:01<00:12, 169.46it/s]

 11%|█▏        | 268/2343 [00:01<00:12, 169.48it/s]

 12%|█▏        | 285/2343 [00:01<00:12, 168.42it/s]

 13%|█▎        | 302/2343 [00:01<00:12, 168.46it/s]

 14%|█▎        | 319/2343 [00:01<00:11, 168.88it/s]

 14%|█▍        | 337/2343 [00:02<00:11, 170.30it/s]

 15%|█▌        | 355/2343 [00:02<00:11, 168.37it/s]

 16%|█▌        | 372/2343 [00:02<00:11, 167.75it/s]

 17%|█▋        | 389/2343 [00:02<00:11, 167.45it/s]

 17%|█▋        | 406/2343 [00:02<00:11, 168.14it/s]

 18%|█▊        | 423/2343 [00:02<00:11, 168.50it/s]

 19%|█▉        | 440/2343 [00:02<00:11, 168.83it/s]

 20%|█▉        | 457/2343 [00:02<00:11, 168.83it/s]

 20%|██        | 474/2343 [00:02<00:11, 168.37it/s]

 21%|██        | 491/2343 [00:02<00:11, 168.11it/s]

 22%|██▏       | 508/2343 [00:03<00:10, 168.21it/s]

 22%|██▏       | 525/2343 [00:03<00:10, 167.79it/s]

 23%|██▎       | 542/2343 [00:03<00:10, 167.73it/s]

 24%|██▍       | 560/2343 [00:03<00:10, 168.65it/s]

 25%|██▍       | 578/2343 [00:03<00:10, 169.70it/s]

 25%|██▌       | 596/2343 [00:03<00:10, 170.70it/s]

 26%|██▌       | 614/2343 [00:03<00:10, 168.73it/s]

 27%|██▋       | 631/2343 [00:03<00:10, 159.51it/s]

 28%|██▊       | 648/2343 [00:03<00:10, 161.53it/s]

 28%|██▊       | 665/2343 [00:04<00:10, 156.00it/s]

 29%|██▉       | 681/2343 [00:04<00:10, 155.98it/s]

 30%|██▉       | 699/2343 [00:04<00:10, 160.26it/s]

 31%|███       | 717/2343 [00:04<00:09, 163.70it/s]

 31%|███▏      | 735/2343 [00:04<00:09, 166.60it/s]

 32%|███▏      | 753/2343 [00:04<00:09, 169.07it/s]

 33%|███▎      | 771/2343 [00:04<00:09, 170.55it/s]

 34%|███▎      | 789/2343 [00:04<00:09, 171.03it/s]

 34%|███▍      | 807/2343 [00:04<00:08, 171.56it/s]

 35%|███▌      | 825/2343 [00:04<00:08, 171.01it/s]

 36%|███▌      | 843/2343 [00:05<00:08, 171.16it/s]

 37%|███▋      | 861/2343 [00:05<00:08, 170.63it/s]

 38%|███▊      | 879/2343 [00:05<00:08, 171.04it/s]

 38%|███▊      | 897/2343 [00:05<00:08, 171.71it/s]

 39%|███▉      | 915/2343 [00:05<00:08, 172.31it/s]

 40%|███▉      | 933/2343 [00:05<00:08, 170.70it/s]

 41%|████      | 951/2343 [00:05<00:08, 169.20it/s]

 41%|████▏     | 968/2343 [00:05<00:08, 168.44it/s]

 42%|████▏     | 985/2343 [00:05<00:08, 167.71it/s]

 43%|████▎     | 1003/2343 [00:06<00:07, 168.86it/s]

 44%|████▎     | 1021/2343 [00:06<00:07, 170.01it/s]

 44%|████▍     | 1039/2343 [00:06<00:07, 171.54it/s]

 45%|████▌     | 1057/2343 [00:06<00:07, 172.41it/s]

 46%|████▌     | 1075/2343 [00:06<00:07, 172.84it/s]

 47%|████▋     | 1093/2343 [00:06<00:07, 171.25it/s]

 47%|████▋     | 1111/2343 [00:06<00:07, 169.51it/s]

 48%|████▊     | 1128/2343 [00:06<00:07, 167.78it/s]

 49%|████▉     | 1145/2343 [00:06<00:07, 168.10it/s]

 50%|████▉     | 1163/2343 [00:06<00:06, 169.75it/s]

 50%|█████     | 1181/2343 [00:07<00:06, 170.64it/s]

 51%|█████     | 1199/2343 [00:07<00:06, 171.63it/s]

 52%|█████▏    | 1217/2343 [00:07<00:06, 173.10it/s]

 53%|█████▎    | 1235/2343 [00:07<00:06, 173.62it/s]

 53%|█████▎    | 1253/2343 [00:07<00:06, 173.20it/s]

 54%|█████▍    | 1271/2343 [00:07<00:06, 173.46it/s]

 55%|█████▌    | 1289/2343 [00:07<00:06, 172.80it/s]

 56%|█████▌    | 1307/2343 [00:07<00:08, 122.25it/s]

 57%|█████▋    | 1325/2343 [00:08<00:07, 133.88it/s]

 57%|█████▋    | 1342/2343 [00:08<00:07, 141.55it/s]

 58%|█████▊    | 1359/2343 [00:08<00:06, 146.94it/s]

 59%|█████▊    | 1376/2343 [00:08<00:06, 151.22it/s]

 59%|█████▉    | 1394/2343 [00:08<00:06, 156.85it/s]

 60%|██████    | 1412/2343 [00:08<00:05, 161.16it/s]

 61%|██████    | 1430/2343 [00:08<00:05, 163.90it/s]

 62%|██████▏   | 1447/2343 [00:08<00:05, 163.48it/s]

 62%|██████▏   | 1464/2343 [00:08<00:05, 164.30it/s]

 63%|██████▎   | 1482/2343 [00:08<00:05, 167.12it/s]

 64%|██████▍   | 1499/2343 [00:09<00:05, 162.03it/s]

 65%|██████▍   | 1518/2343 [00:09<00:04, 169.54it/s]

 66%|██████▌   | 1537/2343 [00:09<00:04, 173.89it/s]

 66%|██████▋   | 1555/2343 [00:09<00:04, 174.88it/s]

 67%|██████▋   | 1573/2343 [00:09<00:04, 174.12it/s]

 68%|██████▊   | 1592/2343 [00:09<00:04, 176.40it/s]

 69%|██████▉   | 1611/2343 [00:09<00:04, 178.52it/s]

 70%|██████▉   | 1629/2343 [00:09<00:04, 173.99it/s]

 70%|███████   | 1647/2343 [00:09<00:04, 171.09it/s]

 71%|███████   | 1665/2343 [00:10<00:04, 169.37it/s]

 72%|███████▏  | 1682/2343 [00:10<00:03, 167.99it/s]

 73%|███████▎  | 1699/2343 [00:10<00:03, 166.86it/s]

 73%|███████▎  | 1716/2343 [00:10<00:03, 166.96it/s]

 74%|███████▍  | 1733/2343 [00:10<00:03, 164.20it/s]

 75%|███████▍  | 1750/2343 [00:10<00:03, 165.08it/s]

 75%|███████▌  | 1767/2343 [00:10<00:03, 166.08it/s]

 76%|███████▌  | 1785/2343 [00:10<00:03, 167.39it/s]

 77%|███████▋  | 1802/2343 [00:10<00:03, 167.71it/s]

 78%|███████▊  | 1819/2343 [00:10<00:03, 167.20it/s]

 78%|███████▊  | 1836/2343 [00:11<00:03, 167.46it/s]

 79%|███████▉  | 1853/2343 [00:11<00:02, 166.62it/s]

 80%|███████▉  | 1870/2343 [00:11<00:02, 166.89it/s]

 81%|████████  | 1887/2343 [00:11<00:02, 164.35it/s]

 81%|████████▏ | 1904/2343 [00:11<00:02, 164.80it/s]

 82%|████████▏ | 1921/2343 [00:11<00:02, 165.64it/s]

 83%|████████▎ | 1938/2343 [00:11<00:02, 165.77it/s]

 83%|████████▎ | 1955/2343 [00:11<00:02, 166.50it/s]

 84%|████████▍ | 1972/2343 [00:11<00:02, 167.18it/s]

 85%|████████▍ | 1989/2343 [00:11<00:02, 167.70it/s]

 86%|████████▌ | 2006/2343 [00:12<00:02, 167.29it/s]

 86%|████████▋ | 2023/2343 [00:12<00:01, 165.12it/s]

 87%|████████▋ | 2040/2343 [00:12<00:01, 166.47it/s]

 88%|████████▊ | 2058/2343 [00:12<00:01, 167.61it/s]

 89%|████████▊ | 2076/2343 [00:12<00:01, 168.44it/s]

 89%|████████▉ | 2093/2343 [00:12<00:01, 168.33it/s]

 90%|█████████ | 2110/2343 [00:12<00:01, 167.81it/s]

 91%|█████████ | 2127/2343 [00:12<00:01, 167.69it/s]

 92%|█████████▏| 2144/2343 [00:12<00:01, 167.85it/s]

 92%|█████████▏| 2162/2343 [00:13<00:01, 168.57it/s]

 93%|█████████▎| 2179/2343 [00:13<00:00, 168.80it/s]

 94%|█████████▍| 2197/2343 [00:13<00:00, 169.21it/s]

 94%|█████████▍| 2214/2343 [00:13<00:00, 168.89it/s]

 95%|█████████▌| 2232/2343 [00:13<00:00, 169.85it/s]

 96%|█████████▌| 2250/2343 [00:13<00:00, 170.65it/s]

 97%|█████████▋| 2268/2343 [00:13<00:00, 169.18it/s]

 98%|█████████▊| 2285/2343 [00:13<00:00, 168.79it/s]

 98%|█████████▊| 2302/2343 [00:13<00:00, 168.43it/s]

 99%|█████████▉| 2319/2343 [00:13<00:00, 168.36it/s]

100%|█████████▉| 2336/2343 [00:14<00:00, 160.65it/s]

100%|██████████| 2343/2343 [00:14<00:00, 164.62it/s]




Epoch 5 in 14.24 sec


Training set loss 5.901015281677246
Training set mae 0.7477869391441345


**Generating Predictions on Test Data**

In [10]:
test_loader = DataLoader(test_dataset.tensor_frame, batch_size=BATCH_SIZE,
                         shuffle=False, drop_last=False)
results = []

parallel_predict = jax.pmap(vmap(predict, in_axes=(None, 0)), in_axes=(None, 0))

for tf in tqdm(test_loader):
    x = torch.cat([
        tf.feat_dict[categorical], tf.feat_dict[numerical],
        tf.feat_dict[timestamp].squeeze(1)
    ], dim=1).numpy()
    nan_mask = jnp.isnan(x)
    if nan_mask.any():
        x = jnp.where(nan_mask, fill_vals, x)
    x = (x - means) / stds
    x = jnp.reshape(x, (NUM_DEVICES, len(x) // NUM_DEVICES, NUM_FEAT))
    result = parallel_predict(params, x).reshape(-1) * y_std + y_mean
    results.append(result)

outputs = jnp.concatenate(results)
submission = pd.read_csv("/kaggle/input/playground-series-s4e12/sample_submission.csv")
submission[TARGET_COL] = outputs
submission = submission[['id', TARGET_COL]]
submission.to_csv('/kaggle/working/final_submission.csv', index=False)

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 1/1563 [00:00<07:08,  3.64it/s]

  1%|▏         | 21/1563 [00:00<00:22, 70.04it/s]

  3%|▎         | 41/1563 [00:00<00:13, 111.97it/s]

  4%|▍         | 61/1563 [00:00<00:10, 139.29it/s]

  5%|▌         | 81/1563 [00:00<00:09, 157.16it/s]

  6%|▋         | 101/1563 [00:00<00:08, 169.20it/s]

  8%|▊         | 121/1563 [00:00<00:08, 177.89it/s]

  9%|▉         | 141/1563 [00:00<00:07, 183.46it/s]

 10%|█         | 161/1563 [00:01<00:07, 187.81it/s]

 12%|█▏        | 181/1563 [00:01<00:07, 189.71it/s]

 13%|█▎        | 201/1563 [00:01<00:07, 191.74it/s]

 14%|█▍        | 221/1563 [00:01<00:06, 193.78it/s]

 15%|█▌        | 242/1563 [00:01<00:06, 196.92it/s]

 17%|█▋        | 263/1563 [00:01<00:06, 198.28it/s]

 18%|█▊        | 284/1563 [00:01<00:06, 199.99it/s]

 20%|█▉        | 305/1563 [00:01<00:06, 200.31it/s]

 21%|██        | 326/1563 [00:01<00:06, 197.91it/s]

 22%|██▏       | 346/1563 [00:02<00:06, 195.58it/s]

 23%|██▎       | 366/1563 [00:02<00:06, 195.56it/s]

 25%|██▍       | 386/1563 [00:02<00:06, 195.18it/s]

 26%|██▌       | 406/1563 [00:02<00:05, 195.59it/s]

 27%|██▋       | 426/1563 [00:02<00:05, 196.07it/s]

 29%|██▊       | 446/1563 [00:02<00:05, 197.11it/s]

 30%|██▉       | 466/1563 [00:02<00:05, 197.85it/s]

 31%|███       | 486/1563 [00:02<00:05, 198.34it/s]

 32%|███▏      | 506/1563 [00:02<00:05, 198.35it/s]

 34%|███▎      | 526/1563 [00:02<00:05, 197.75it/s]

 35%|███▍      | 547/1563 [00:03<00:05, 198.87it/s]

 36%|███▋      | 567/1563 [00:03<00:05, 198.95it/s]

 38%|███▊      | 588/1563 [00:03<00:04, 200.35it/s]

 39%|███▉      | 609/1563 [00:03<00:04, 200.27it/s]

 40%|████      | 630/1563 [00:03<00:04, 200.68it/s]

 42%|████▏     | 651/1563 [00:03<00:04, 201.19it/s]

 43%|████▎     | 672/1563 [00:03<00:04, 201.69it/s]

 44%|████▍     | 693/1563 [00:03<00:04, 201.59it/s]

 46%|████▌     | 714/1563 [00:03<00:04, 201.50it/s]

 47%|████▋     | 735/1563 [00:03<00:04, 200.72it/s]

 48%|████▊     | 756/1563 [00:04<00:04, 200.92it/s]

 50%|████▉     | 777/1563 [00:04<00:03, 202.00it/s]

 51%|█████     | 798/1563 [00:04<00:03, 200.86it/s]

 52%|█████▏    | 819/1563 [00:04<00:03, 199.63it/s]

 54%|█████▎    | 839/1563 [00:04<00:03, 197.97it/s]

 55%|█████▍    | 859/1563 [00:04<00:03, 196.06it/s]

 56%|█████▌    | 879/1563 [00:04<00:03, 194.65it/s]

 58%|█████▊    | 899/1563 [00:04<00:03, 193.57it/s]

 59%|█████▉    | 919/1563 [00:04<00:03, 192.17it/s]

 60%|██████    | 939/1563 [00:05<00:03, 191.73it/s]

 61%|██████▏   | 959/1563 [00:05<00:03, 191.31it/s]

 63%|██████▎   | 979/1563 [00:05<00:03, 180.73it/s]

 64%|██████▍   | 999/1563 [00:05<00:03, 184.19it/s]

 65%|██████▌   | 1019/1563 [00:05<00:02, 186.75it/s]

 66%|██████▋   | 1039/1563 [00:05<00:02, 189.11it/s]

 68%|██████▊   | 1059/1563 [00:05<00:02, 190.92it/s]

 69%|██████▉   | 1079/1563 [00:05<00:02, 192.37it/s]

 70%|███████   | 1099/1563 [00:05<00:02, 194.34it/s]

 72%|███████▏  | 1120/1563 [00:05<00:02, 198.45it/s]

 73%|███████▎  | 1140/1563 [00:06<00:02, 197.81it/s]

 74%|███████▍  | 1160/1563 [00:06<00:02, 197.27it/s]

 75%|███████▌  | 1180/1563 [00:06<00:01, 197.91it/s]

 77%|███████▋  | 1200/1563 [00:06<00:01, 198.39it/s]

 78%|███████▊  | 1221/1563 [00:06<00:01, 199.53it/s]

 79%|███████▉  | 1241/1563 [00:06<00:01, 199.63it/s]

 81%|████████  | 1262/1563 [00:06<00:01, 200.38it/s]

 82%|████████▏ | 1283/1563 [00:06<00:01, 200.51it/s]

 83%|████████▎ | 1304/1563 [00:06<00:01, 199.97it/s]

 85%|████████▍ | 1324/1563 [00:06<00:01, 197.59it/s]

 86%|████████▌ | 1344/1563 [00:07<00:01, 197.72it/s]

 87%|████████▋ | 1364/1563 [00:07<00:01, 197.98it/s]

 89%|████████▊ | 1384/1563 [00:07<00:00, 198.29it/s]

 90%|████████▉ | 1404/1563 [00:07<00:00, 198.26it/s]

 91%|█████████ | 1424/1563 [00:07<00:00, 198.16it/s]

 92%|█████████▏| 1444/1563 [00:07<00:00, 197.50it/s]

 94%|█████████▎| 1464/1563 [00:07<00:00, 196.14it/s]

 95%|█████████▍| 1484/1563 [00:07<00:00, 196.40it/s]

 96%|█████████▌| 1504/1563 [00:07<00:00, 195.74it/s]

 98%|█████████▊| 1524/1563 [00:08<00:00, 195.52it/s]

 99%|█████████▉| 1544/1563 [00:08<00:00, 195.80it/s]

100%|██████████| 1563/1563 [00:08<00:00, 180.67it/s]


