In [1]:
import pandas as pd
from fastai.tabular.all import *

In [2]:
train_df = pd.read_csv('train.csv', index_col=0)
test_df = pd.read_csv('test.csv', index_col=0)

drop_columns = ['koi_disposition', 'koi_pdisposition', 'kepid', 'kepoi_name', 'kepler_name', 'koi_tce_delivname']

train_df = train_df.drop(columns=drop_columns)
test_df_Y = torch.tensor(test_df['koi_score'].to_numpy().reshape(-1, 1), dtype=torch.float32)
test_df_X = test_df.drop(columns=drop_columns+['koi_score'])

In [3]:
train_df

Unnamed: 0_level_0,koi_score,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,...,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag
rowid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
7730,0.000,0,1,0,0,31.005393,0.000002,-0.000002,150.861652,0.000058,...,-151.0,3.900,0.700,-0.300,1.784,1.027,-1.129,284.04434,39.661469,14.991
4272,0.000,0,1,0,1,2.161327,0.000016,-0.000016,131.902360,0.005950,...,-67.0,4.818,0.060,-0.040,0.430,0.041,-0.050,291.54214,48.991192,14.751
5313,0.998,0,0,0,0,34.211502,0.000300,-0.000300,154.575080,0.007850,...,-100.0,4.610,0.018,-0.046,0.721,0.047,-0.029,286.66296,47.413700,14.017
9068,0.000,0,0,1,1,0.933747,0.000015,-0.000015,131.526700,0.015100,...,-235.0,4.528,0.050,-0.200,0.887,0.272,-0.091,292.14102,47.626240,15.024
674,1.000,0,0,0,0,4.807103,0.000006,-0.000006,134.640371,0.000956,...,-74.0,4.338,0.138,-0.103,1.080,0.165,-0.150,291.39050,42.180592,14.466
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5285,0.000,0,1,1,0,3.204223,0.000006,-0.000006,133.822230,0.001680,...,-138.0,4.609,0.033,-0.077,0.741,0.096,-0.059,291.04214,51.026321,15.341
9295,0.323,0,0,0,0,433.220461,0.005076,-0.005076,271.042770,0.009770,...,-156.0,4.519,0.096,-0.168,0.777,0.126,-0.094,294.62714,41.933182,14.546
3855,0.000,1,0,0,0,319.306330,0.015710,-0.015710,221.310600,0.034400,...,-214.0,4.459,0.056,-0.224,1.014,0.341,-0.114,292.33884,38.844910,13.327
7903,0.000,0,0,1,1,2.404401,0.000031,-0.000031,133.197500,0.011800,...,-86.0,4.550,0.055,-0.055,0.769,0.052,-0.052,297.11740,39.859112,12.632


In [4]:
splits = RandomSplitter(valid_pct=0.2)(range_of(train_df))

In [5]:
to = TabularPandas(train_df, procs=[Categorify],
                   cat_names = ['koi_fpflag_nt', 'koi_fpflag_ss', 'koi_fpflag_co', 'koi_fpflag_ec'],
                   cont_names = [
                      'koi_period',
                      'koi_period_err1',
                      'koi_period_err2',
                      'koi_time0bk',
                      'koi_time0bk_err1',
                      'koi_time0bk_err2',
                      'koi_impact',
                      'koi_impact_err1',
                      'koi_impact_err2',
                      'koi_duration',
                      'koi_duration_err1',
                      'koi_duration_err2',
                      'koi_depth',
                      'koi_depth_err1',
                      'koi_depth_err2',
                      'koi_prad',
                      'koi_prad_err1',
                      'koi_prad_err2',
                      'koi_teq',
                      'koi_insol',
                      'koi_insol_err1',
                      'koi_insol_err2',
                      'koi_model_snr',
                      'koi_tce_plnt_num',
                      'koi_steff',
                      'koi_steff_err1',
                      'koi_steff_err2',
                      'koi_slogg',
                      'koi_slogg_err1',
                      'koi_slogg_err2',
                      'koi_srad',
                      'koi_srad_err1',
                      'koi_srad_err2',
                      'ra',
                      'dec',
                      'koi_kepmag',
                   ],
                   y_names='koi_score',
                   splits=splits)

In [6]:
dls = to.dataloaders(bs=64)

In [7]:
dls.show_batch()

Unnamed: 0,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,koi_score
0,0,1,0,0,11.521446,2e-06,-2e-06,170.839691,0.000131,-0.000131,2.483,2.851,-0.673,3.6399,0.0114,-0.0114,17984.0,31.9,-31.9,150.509995,39.759998,-13.31,753.0,75.879997,58.889999,-19.99,622.099976,1.0,5795.0,155.0,-172.0,4.554,0.033,-0.176,0.848,0.224,-0.075,297.079926,47.597401,15.472,0.0
1,0,0,0,0,30.183702,0.000115,-0.000115,146.646744,0.00311,-0.00311,0.011,0.392,-0.011,5.879,0.103,-0.103,974.900024,28.9,-28.9,3.65,0.57,-0.62,634.0,38.200001,17.02,-13.79,37.900002,2.0,5733.0,115.0,-104.0,4.281,0.143,-0.104,1.183,0.184,-0.202,294.316986,43.629341,15.505,1.0
2,0,0,0,0,4.61225,3.4e-05,-3.4e-05,133.861343,0.00535,-0.00535,0.012,0.44,-0.012,2.695,0.167,-0.167,29.1,2.4,-2.4,0.51,0.06,-0.05,1032.0,268.269989,86.529999,-58.849998,13.7,1.0,5495.0,74.0,-82.0,4.442,0.072,-0.099,0.961,0.115,-0.086,297.148834,46.296951,11.83,0.856
3,0,0,1,0,9.402233,0.00011,-0.00011,134.887604,0.0102,-0.0102,0.019,0.45,-0.019,2.79,0.397,-0.397,316.0,40.200001,-40.200001,23.610001,5.22,-4.69,2309.0,6719.569824,4121.830078,-2772.790039,9.8,1.0,4597.0,92.0,-101.0,2.42,0.033,-0.03,13.874,3.064,-2.758,293.279785,39.393318,10.59,0.0
4,0,0,0,0,3.26176,1e-06,-1e-06,131.778854,0.000322,-0.000322,0.949,0.035,-0.026,1.7593,0.0301,-0.0301,3876.100098,32.200001,-32.200001,8.01,2.66,-0.89,1262.0,599.539978,605.460022,-194.110001,190.399994,1.0,6131.0,191.0,-234.0,4.497,0.054,-0.229,0.945,0.315,-0.105,298.316742,40.996101,15.28,1.0
5,0,1,0,0,38.476746,1e-06,-1e-06,150.574219,3.1e-05,-3.1e-05,0.458,0.005,0.0,5.10744,0.00112,-0.00112,173170.0,57.799999,-57.799999,33.799999,4.36,-3.94,463.0,10.85,4.72,-3.26,3487.800049,1.0,5294.0,159.0,-143.0,4.507,0.095,-0.085,0.79,0.102,-0.092,283.201172,43.888981,14.579,0.0
6,0,0,0,1,1.804551,1.7e-05,-1.7e-05,132.426193,0.01,-0.01,0.348,0.101,-0.348,2.166,0.193,-0.193,39.700001,3.7,-3.7,0.65,0.2,-0.09,1577.0,1457.849976,1304.660034,-482.019989,12.5,1.0,6006.0,162.0,-180.0,4.41,0.09,-0.195,1.019,0.311,-0.133,292.314758,41.582321,13.73,0.0
7,0,0,1,1,3.556428,0.0001,-0.0001,133.537704,0.0263,-0.0263,0.359,0.103,-0.359,6.998,0.917,-0.917,131.399994,12.8,-12.8,0.8,0.12,-0.06,972.0,211.410004,110.790001,-46.740002,12.3,1.0,5404.0,178.0,-146.0,4.638,0.032,-0.097,0.694,0.109,-0.047,297.197388,46.22826,15.927,0.0
8,0,0,0,0,8.703121,8.5e-05,-8.5e-05,140.155304,0.00824,-0.00824,0.886,0.077,-0.615,4.807,0.237,-0.237,32.0,2.7,-2.7,0.72,0.11,-0.05,945.0,188.699997,79.68,-34.52,13.0,3.0,6006.0,72.0,-84.0,4.403,0.054,-0.117,1.082,0.178,-0.076,281.112335,43.227791,11.973,1.0
9,0,0,0,0,33.136288,0.000239,-0.000239,169.50676,0.00583,-0.00583,0.791,0.147,-0.538,4.163,0.18,-0.18,825.200012,54.0,-54.0,2.05,0.08,-0.1,357.0,3.86,0.67,-0.65,16.299999,3.0,4245.0,85.0,-85.0,4.641,0.027,-0.02,0.637,0.026,-0.032,300.242218,44.141399,15.673,0.993


In [8]:
learn = tabular_learner(dls, metrics=accuracy)

In [9]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.219653,0.182683,0.42906,00:02


In [10]:
learn.show_results()

Unnamed: 0,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,koi_score,koi_score_pred
0,1.0,1.0,1.0,1.0,92.729721,0.000398,-0.000398,177.43808,0.00347,-0.00347,0.857,0.056,-0.597,6.83,0.153,-0.153,1257.5,31.9,-31.9,2.81,0.2,-0.08,276.0,1.38,0.32,-0.19,43.299999,1.0,4580.0,92.0,-92.0,4.641,0.012,-0.045,0.675,0.046,-0.021,285.253662,46.582191,14.897,0.998,0.67637
1,1.0,1.0,1.0,1.0,28.505722,0.0002837,-0.0002837,132.920746,0.00803,-0.00803,0.695,0.03,-0.506,6.324,0.227,-0.227,755.799988,40.5,-40.5,2.55,0.75,-0.26,555.0,22.43,20.02,-6.28,20.799999,1.0,5712.0,169.0,-169.0,4.534,0.048,-0.192,0.87,0.259,-0.086,292.232422,38.89222,15.778,1.0,0.462324
2,1.0,1.0,1.0,1.0,13.639665,0.0002304,-0.0002304,141.730103,0.0117,-0.0117,0.105,0.344,-0.105,5.425,0.3,-0.3,162.699997,13.2,-13.2,1.1,0.11,-0.11,717.0,62.599998,19.52,-15.52,13.3,3.0,5600.0,112.0,-100.0,4.455,0.102,-0.077,0.872,0.088,-0.088,284.677277,44.79768,15.204,0.992,0.799688
3,2.0,1.0,2.0,2.0,0.566796,5.428e-06,-5.428e-06,131.814804,0.0105,-0.0105,0.653,0.035,-0.518,2.79,0.691,-0.691,18.9,2.2,-2.2,0.64,0.13,-0.16,2755.0,13656.589844,8619.969727,-6906.060059,12.4,1.0,6130.0,167.0,-167.0,4.151,0.21,-0.123,1.393,0.293,-0.358,291.657593,42.729649,13.379,0.0,-0.32866
4,1.0,2.0,1.0,1.0,24.061274,4.604e-05,-4.604e-05,138.387466,0.00176,-0.00176,0.461,0.007,-0.006,19.9438,0.0236,-0.0236,440150.0,467.0,-467.0,83.769997,24.290001,-13.1,749.0,74.370003,63.84,-29.18,1161.099976,1.0,6461.0,181.0,-250.0,4.338,0.108,-0.186,1.138,0.33,-0.178,295.366028,46.377232,14.158,0.0,-0.444665
5,2.0,1.0,1.0,1.0,0.638438,5.25e-07,-5.25e-07,131.592545,0.000708,-0.000708,0.904,0.039,-0.641,0.3028,0.0501,-0.0501,286.799988,13.8,-13.8,2.48,0.24,-0.14,1709.0,2024.339966,711.27002,-442.0,25.799999,1.0,5073.0,149.0,-164.0,4.571,0.02,-0.08,0.824,0.08,-0.046,294.639893,50.33419,15.23,0.0,0.327605
6,2.0,1.0,2.0,1.0,1.126798,1.27e-05,-1.27e-05,132.275696,0.0111,-0.0111,1.215,0.183,-0.118,12.079,0.923,-0.923,227.199997,13.3,-13.3,21.07,5.52,-1.83,1590.0,1514.630005,1198.380005,-409.76001,21.1,1.0,5662.0,169.0,-186.0,4.56,0.033,-0.176,0.84,0.22,-0.073,290.502411,44.96817,14.618,0.0,-0.222654
7,1.0,2.0,2.0,2.0,0.745087,2.3e-06,-2.3e-06,131.826553,0.00294,-0.00294,0.626,0.335,-0.625,4.29,1.7,-1.7,270.0,31.799999,-31.799999,1.29,0.2,-0.1,1654.0,1775.22998,892.859985,-432.660004,61.400002,1.0,5325.0,143.0,-159.0,4.607,0.032,-0.104,0.75,0.122,-0.057,289.169861,48.06678,15.489,0.0,-0.260683
8,1.0,2.0,1.0,1.0,1.085965,5.94e-07,-5.94e-07,132.006607,0.000415,-0.000415,0.916,0.003,-0.002,10.716,0.267,-0.267,10579.0,53.400002,-53.400002,29.66,0.0,0.0,6297.0,371873.875,0.0,0.0,675.900024,1.0,15896.0,0.0,0.0,4.152,0.0,0.0,2.404,0.0,0.0,291.422485,38.535229,15.799,0.0,-0.90611


In [14]:
dl = learn.dls.test_dl(test_df)
preds, targs = learn.get_preds(dl=dl)
test_close(targs, test_df_Y,  eps=1e-3)