In [1]:
import os

In [2]:
os.makedirs(os.path.join('..','data'), exist_ok=True)
data_file = os.path.join('..','data','house_tiny.csv')

In [3]:
with open(data_file, 'w') as f:
    f.write('''NumRooms,RoofType,Price
    NA,NA,127500
    2,NA,106000
    4,Slate,178100
    NA,NA,140000''')

In [4]:
import pandas as pd

In [5]:
data = pd.read_csv(data_file)
print(data)

  NumRooms RoofType   Price
0       NA      NaN  127500
1        2      NaN  106000
2        4    Slate  178100
3       NA      NaN  140000


## Data Preparation

In [18]:
inputs, targets = data.iloc[:, 0:2], data.iloc[:, 2]
inputs = pd.get_dummies(inputs, dummy_na=True)
print(inputs)

   NumRooms_    2  NumRooms_    4  NumRooms_    NA  NumRooms_nan  \
0           False           False             True         False   
1            True           False            False         False   
2           False            True            False         False   
3           False           False             True         False   

   RoofType_Slate  RoofType_nan  
0           False          True  
1           False          True  
2            True         False  
3           False          True  


In [25]:
inputs = inputs.fillna(inputs.mean())
print(inputs)

   NumRooms_    2  NumRooms_    4  NumRooms_    NA  NumRooms_nan  \
0           False           False             True         False   
1            True           False            False         False   
2           False            True            False         False   
3           False           False             True         False   

   RoofType_Slate  RoofType_nan  
0           False          True  
1           False          True  
2            True         False  
3           False          True  


## Conversion to tensor format

In [19]:
from jax import numpy as jnp

In [26]:
X = jnp.array(inputs.to_numpy(dtype=float))
y = jnp.array(targets.to_numpy(dtype=float))
X, y

(Array([[0., 0., 1., 0., 0., 1.],
        [1., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 1.]], dtype=float32),
 Array([127500., 106000., 178100., 140000.], dtype=float32))