In [1]:
## Install pytorch tabular
!pip install --quiet -r /content/Plato_Neural_Net_requirements.txt

[K     |████████████████████████████████| 2.0 MB 4.3 MB/s 
[K     |████████████████████████████████| 80 kB 8.5 MB/s 
[K     |████████████████████████████████| 809 kB 19.4 MB/s 
[K     |████████████████████████████████| 74 kB 3.3 MB/s 
[K     |████████████████████████████████| 329 kB 53.3 MB/s 
[K     |████████████████████████████████| 596 kB 48.5 MB/s 
[K     |████████████████████████████████| 1.8 MB 36.0 MB/s 
[K     |████████████████████████████████| 13.2 MB 154 kB/s 
[K     |████████████████████████████████| 829 kB 48.2 MB/s 
[K     |████████████████████████████████| 132 kB 48.4 MB/s 
[K     |████████████████████████████████| 636 kB 51.7 MB/s 
[K     |████████████████████████████████| 76 kB 4.6 MB/s 
[K     |████████████████████████████████| 180 kB 50.2 MB/s 
[K     |████████████████████████████████| 97 kB 6.5 MB/s 
[K     |████████████████████████████████| 139 kB 47.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 37.9 MB/s 
[K     |████████████████████████

In [2]:
# Importing Libraries
import pandas as pd
import numpy as np
import torch
import random
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from sklearn.model_selection import StratifiedKFold
import warnings
warnings.filterwarnings('ignore')

  import pandas.util.testing as tm


In [None]:
# Seed every RNG used by the notebook (Python, NumPy, PyTorch) for reproducibility.
seed = 100100

for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

if torch.cuda.is_available():
    # Seed the current CUDA device, then all CUDA devices, and switch cuDNN to
    # deterministic mode with benchmark mode turned off.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# Load the pixel dataset and the submission template.
df = pd.read_csv('radiant_pixels.csv')
sample_submission = pd.read_csv('SampleSubmission.csv')

# Rows with crop_type 0 form the test split; every other row is training data.
has_label = df['crop_type'] != 0
train = df[has_label]
test = df[~has_label]
train.head()

Unnamed: 0,Field ID,crop_type,0_B01,0_B02,0_B03,0_B04,0_B05,0_B06,0_B07,0_B08,0_B8A,0_B09,0_B11,0_B12,0_CLM,1_B01,1_B02,1_B03,1_B04,1_B05,1_B06,1_B07,1_B08,1_B8A,1_B09,1_B11,1_B12,1_CLM,2_B01,2_B02,2_B03,2_B04,2_B05,2_B06,2_B07,2_B08,2_B8A,2_B09,2_B11,2_B12,...,21_CLM,22_B01,22_B02,22_B03,22_B04,22_B05,22_B06,22_B07,22_B08,22_B8A,22_B09,22_B11,22_B12,22_CLM,23_B01,23_B02,23_B03,23_B04,23_B05,23_B06,23_B07,23_B08,23_B8A,23_B09,23_B11,23_B12,23_CLM,24_B01,24_B02,24_B03,24_B04,24_B05,24_B06,24_B07,24_B08,24_B8A,24_B09,24_B11,24_B12,24_CLM
0,1,4,17.972797,21.714743,28.888264,35.199116,43.310113,51.914518,56.716495,61.86545,62.589646,63.789743,67.713135,45.579791,251.232395,12.392508,14.695611,22.883006,30.30164,39.662074,49.342428,55.085064,58.148971,61.867042,62.498923,72.701334,48.374711,0.0,10.71492,13.251994,21.380498,28.931238,37.393987,46.389936,51.411559,56.495257,58.435466,59.088039,72.932122,48.82037,...,0.0,13.942765,18.017363,27.907814,35.66865,45.651415,57.813135,63.208794,67.756109,68.415048,69.451206,84.713441,62.890466,0.0,14.528489,18.420048,28.003312,36.34828,46.310354,59.610096,65.381383,70.352074,71.209968,72.308199,88.385354,64.010981,0.0,11.829068,16.377974,27.102444,35.009711,44.464051,58.472074,64.281559,70.347299,69.858666,70.148344,87.061109,64.93254,0.0
1,2,7,57.678261,52.814348,59.959565,73.303043,81.438261,87.894783,94.265217,99.215217,102.787826,116.69087,96.460435,71.968696,252.45,8.737826,20.316522,31.723043,46.443913,54.148696,61.853478,67.621304,73.173913,77.004783,64.651304,93.877826,65.813478,0.0,23.501739,24.061304,34.047391,46.228696,52.685217,59.615217,65.555217,72.614348,74.077826,86.732609,95.90087,67.32,...,0.0,19.498696,20.703913,30.646957,45.367826,51.049565,54.406957,58.797391,65.296957,65.856522,71.753478,98.096087,78.769565,0.0,18.766957,19.843043,29.57087,44.636087,50.834348,54.062609,58.797391,63.919565,65.727391,69.859565,101.410435,80.448261,0.0,21.435652,22.296522,32.110435,45.927391,50.963478,54.493043,57.635217,64.263913,64.737391,68.18087,101.711739,84.580435,0.0
2,3,6,25.681465,29.68473,44.715424,63.017699,72.983869,79.479949,86.882044,93.826041,98.857481,99.869113,133.503663,102.065437,0.0,9.861825,13.203393,22.84126,35.265887,43.02937,48.549447,55.44509,62.144769,68.589949,70.273458,109.919267,77.561028,0.0,13.820553,17.696568,26.936144,39.05919,45.436928,50.46964,56.623419,64.528149,67.995694,69.945154,108.82874,77.136015,...,0.0,23.139023,32.865964,48.134614,69.020051,77.792622,82.560656,87.900039,95.961285,96.872391,97.077262,115.557686,86.019293,0.0,23.416427,33.030116,49.284949,71.622301,81.350514,85.827147,91.365039,98.143612,100.460823,100.034537,123.538766,90.875129,0.0,26.002134,35.575103,51.581799,70.70865,78.400874,81.694087,86.362866,93.043458,93.674614,94.127622,124.671285,95.122712,0.0
3,4,8,30.585099,35.220397,42.79947,50.273642,55.249868,57.452781,60.219536,67.103642,65.602252,84.73351,79.914636,61.629139,252.45,25.694106,32.689669,45.389205,60.449007,67.785497,71.430795,77.010199,82.156887,86.366026,86.733179,127.703444,110.257152,0.0,18.541192,27.962583,39.757351,53.623907,59.288543,63.569801,68.913179,76.170993,78.144437,75.200662,116.918344,99.072119,...,0.0,14.273046,21.24894,31.011258,43.002715,51.07351,57.819934,63.064967,69.824503,72.197881,69.011523,95.67596,74.944967,0.0,18.836225,28.618212,40.858808,57.944503,65.51702,72.21755,78.105099,87.264238,88.136225,82.851854,113.31894,86.103775,0.0,19.236159,28.428079,40.432649,55.171192,61.917616,69.378675,75.069536,84.064768,83.992649,78.727947,112.230596,87.828079,0.0
4,6,4,17.8425,18.97875,29.694375,37.36125,49.449375,65.671875,74.075625,79.936875,83.16,83.32875,90.19125,59.76,0.0,15.19875,18.241875,28.85625,36.264375,47.8125,60.53625,67.68,72.545625,76.179375,75.29625,89.251875,60.2775,0.0,13.545,16.700625,27.039375,34.965,45.208125,57.211875,63.534375,70.554375,72.0225,71.375625,86.445,58.089375,...,0.0,14.551875,15.946875,27.275625,27.41625,44.58375,72.849375,80.87625,86.8275,87.8175,88.475625,76.708125,48.414375,0.0,13.5675,14.259375,25.48125,25.245,43.250625,75.38625,85.275,91.9125,92.2725,92.5425,76.303125,47.1375,0.0,13.03875,15.024375,27.8775,29.37375,46.485,76.629375,86.52375,94.809375,94.815,93.67875,83.221875,52.2675,0.0


In [None]:
# Setting parameters
batch_size = 2048
max_epochs = 300
# NOTE(review): neural_trainer enables auto_lr_find, which presumably replaces this value — confirm.
lr = 1e-3
layers = '2048-1024-512'  # hyphen-separated layer sizes, in pytorch_tabular's `layers` format
target = ['crop_type']
# Every column other than the ID and the target is used as a continuous feature.
# `columns=` keyword: pandas >= 2.0 no longer accepts `axis` positionally in drop().
continuous_cols = list(train.drop(columns=['Field ID', 'crop_type']).columns)
categorical_cols = []
# Use one GPU when present; fall back to CPU so the notebook still runs without CUDA.
gpus = 1 if torch.cuda.is_available() else 0

def neural_trainer(X_train, X_valid, X_test):
  """Train one pytorch_tabular classifier on a fold and predict on X_test.

  Reads the module-level settings: target, continuous_cols, categorical_cols,
  batch_size, max_epochs, gpus, layers and lr.

  Args:
    X_train: training rows with the feature columns and the target column.
    X_valid: validation rows with the same columns, passed to `fit` as the
      validation set (early_stopping_patience=25 is configured).
    X_test: rows to predict. The calling cell adds a placeholder target column.

  Returns:
    A numpy array with one row per X_test row. It holds every column of
    `model.predict(X_test)` except the feature, target and 'prediction' columns.
    Presumably these are the per-class probability columns; verify against the
    installed pytorch_tabular version.
  """
  # Configure pytorch tabular
  data_config = DataConfig(
    target=target,
    continuous_cols=continuous_cols,
    categorical_cols=categorical_cols,
    normalize_continuous_features=True,
  )
  # NOTE(review): with auto_lr_find=True the learning-rate finder likely overrides `lr`; confirm.
  trainer_config = TrainerConfig(
    auto_lr_find=True,
    profiler='simple',
    batch_size=batch_size,
    max_epochs=max_epochs,
    gpus=gpus,
    early_stopping_patience=25,
  )
  optimizer_config = OptimizerConfig()
  model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers=layers,
    activation="LeakyReLU",
    learning_rate=lr,
  )
  model = TabularModel(
    data_config=data_config,
    trainer_config=trainer_config,
    optimizer_config=optimizer_config,
    model_config=model_config,
  )

  # Fit model
  model.fit(X_train, X_valid)

  # Make predictions and drop the input/target/label columns.
  # `columns=` keyword: pandas >= 2.0 no longer accepts `axis` positionally in drop().
  predictions = model.predict(X_test)
  return predictions.drop(columns=continuous_cols + target + ['prediction']).to_numpy()

In [None]:
# Test data: features plus a placeholder target column.
# NOTE(review): 8 appears to be an arbitrary dummy label so X_test has the configured
# target column; neural_trainer drops that column from its output.
X_test = test.drop(columns='Field ID')
X_test['crop_type'] = 8

# Target
crop_type = train.crop_type

# Training frame without the ID column, built once rather than on every fold.
train_features = train.drop(columns='Field ID')

# Cross validation training: one model per fold, each predicting the whole test set.
# (Despite the name, these are per-fold test-set predictions, not out-of-fold predictions.)
oof_predictions = []
folds = StratifiedKFold(n_splits=10, random_state=seed, shuffle=True)
for train_index, test_index in folds.split(train_features, crop_type):
  # split() yields positional indices, so rows must be selected with .iloc.
  X_train = train_features.iloc[train_index]
  X_valid = train_features.iloc[test_index]

  oof_predictions.append(neural_trainer(X_train, X_valid, X_test))

# Submission file preparation: average the fold predictions.
mean_predictions = np.mean(oof_predictions, axis=0)
class_columns = sample_submission.columns[1:]
# Prediction columns are mapped onto the template's class columns by position.
assert mean_predictions.shape[1] == len(class_columns), (
  f"model produced {mean_predictions.shape[1]} prediction columns, "
  f"submission template expects {len(class_columns)}"
)

submission_file = pd.DataFrame({'Field ID': test['Field ID']})
for i, column in enumerate(class_columns):
  submission_file[column] = mean_predictions[:, i]

submission_file.to_csv('pytorch_tabular.csv', index = False)
submission_file.head()

Unnamed: 0,Field ID,Crop_Lucerne/Medics,Crop_Planted pastures (perennial),Crop_Fallow,Crop_Wine grapes,Crop_Weeds,Crop_Small grain grazing,Crop_Wheat,Crop_Canola,Crop_Rooibos
87113,5,4e-06,3.5e-05,1.330641e-08,6e-06,2.74722e-08,0.004926,0.991181,0.003848,1.341721e-09
87114,10,0.225719,0.625817,0.01671481,0.001045,0.007133731,0.105978,0.012205,0.004886,0.0005009401
87115,11,0.145782,0.437603,0.003529492,0.000327,0.002854581,0.381725,0.013288,0.014798,9.256932e-05
87116,17,0.000201,0.008122,0.3868369,7.7e-05,0.3775796,0.21158,0.000638,1.8e-05,0.01494759
87117,18,0.001086,0.24917,0.09511567,0.00962,0.6383551,0.006176,8.9e-05,1.1e-05,0.0003764822
