<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/modeling_TabTransformer/test_sample_TabTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the repository and move into the TabTransformer folder so that
# later relative paths (e.g. 'adult.data') resolve inside it.
# NOTE(review): `%cd` is relative to the current directory; re-running this
# cell after it already succeeded will fail to find the path — run once.
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/modeling_TabTransformer/

In [None]:
! wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data

In [19]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [4]:
# adult.data has no header row, so the 15 column names are supplied here;
# 'target' (the income bracket) is the last column.
dataset = pd.read_csv(
    'adult.data',
    header=None,
    names=[
        'age', 'workclass', 'fnlwgt', 'education', 'education-num',
        'marital-status', 'occupation', 'relationship', 'race', 'sex',
        'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
        'target',
    ],
)

In [5]:
# Preview the first rows of the raw frame.
dataset.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


In [6]:
# Numeric feature columns (not passed to the label encoder).
cont_classes = [
    'age', 'fnlwgt', 'education-num',
    'capital-gain', 'capital-loss', 'hours-per-week',
]
# String feature columns, label-encoded to integer codes in this cell.
cat_classes = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'sex', 'native-country',
]


def transform_df(list_col, df=None):
  """Label-encode the given columns of a DataFrame in place.

  Each column in ``list_col`` is replaced by the integer codes produced by
  its own freshly fitted ``LabelEncoder``; other columns are untouched.

  Args:
    list_col: iterable of column names to encode.
    df: DataFrame to modify. Defaults to the notebook-level ``dataset``,
      matching the original behaviour of this helper.

  Returns:
    None. ``df`` is mutated in place.
  """
  if df is None:
    df = dataset
  for col in list_col:
    # fit_transform is equivalent to fit(...) followed by transform(...)
    # on the same values, without the redundant second pass.
    df[col] = LabelEncoder().fit_transform(df[col])

# Encode every categorical column of `dataset` in place.
transform_df(list_col=cat_classes)

In [7]:
# Preview the frame after label encoding of the categorical columns.
dataset.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,39,7,77516,9,13,4,1,1,4,1,2174,0,40,39,<=50K
1,50,6,83311,9,13,2,4,0,4,1,0,0,13,39,<=50K
2,38,4,215646,11,9,0,6,1,4,1,0,0,40,39,<=50K
3,53,4,234721,1,7,2,6,0,2,1,0,0,40,39,<=50K
4,28,4,338409,9,13,2,10,5,2,0,0,0,40,5,<=50K


In [8]:
# Binary target: 0 for the "<=50K" income bracket, 1 for anything else.
# Raw labels carry a leading space (" <=50K"), so surrounding whitespace and
# a trailing "." are stripped before comparing rather than relying on exact
# padding.
# The dtype guard keeps the cell re-runnable: on already-mapped 0/1 ints the
# old `x == " <=50K"` test was always False, turning every row into 1.
if dataset['target'].dtype == object:
  dataset['target'] = (
      dataset['target'].str.strip().str.rstrip('.') != '<=50K'
  ).astype('int64')
dataset.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,39,7,77516,9,13,4,1,1,4,1,2174,0,40,39,0
1,50,6,83311,9,13,2,4,0,4,1,0,0,13,39,0
2,38,4,215646,11,9,0,6,1,4,1,0,0,40,39,0
3,53,4,234721,1,7,2,6,0,2,1,0,0,40,39,0
4,28,4,338409,9,13,2,10,5,2,0,0,0,40,5,0


In [25]:
# Positional indices of the categorical / continuous columns within `dataset`,
# in the order given by the cat_classes / cont_classes name lists.
list_columns = dataset.columns.tolist()
cat_classes_index = list(map(list_columns.index, cat_classes))
cont_classes_index = list(map(list_columns.index, cont_classes))

# Column subsets selected by position.
cat_dataset = dataset.iloc[:, cat_classes_index]
cont_dataset = dataset.iloc[:, cont_classes_index]

# Convert to numpy arrays for the torch Dataset below.
cat_dataset_numpy, cont_dataset_numpy = cat_dataset.to_numpy(), cont_dataset.to_numpy()
target_numpy = dataset['target'].to_numpy()

# Every split must have one row per sample.
assert len(cat_dataset) == len(cont_dataset)
assert len(cont_dataset) == len(target_numpy)

In [28]:
class TabDataset(Dataset):
  """Map-style torch Dataset over row-aligned categorical/continuous/target data.

  Args:
    cat_dataset_numpy: row-indexable categorical codes (e.g. a 2-D integer
      NumPy array); each row is returned as a ``torch.long`` tensor.
    cont_dataset_numpy: row-indexable continuous features; each row is
      returned as a tensor of dtype ``cont_dtype``.
    targets: row-indexable labels; ``len(targets)`` sets the dataset size.
    cont_dtype: dtype of ``'cont_tensor'``. Defaults to ``torch.float``
      because continuous features are real-valued, and floating-point layers
      such as ``nn.Linear`` / ``nn.LayerNorm`` reject integer tensors
      (NOTE(review): confirm against the model consuming these batches).
      Pass ``torch.long`` to reproduce the previous integer output.

  Each item is a dict with keys ``'cat_tensor'``, ``'cont_tensor'`` and
  ``'target'``.
  """
  def __init__(self, cat_dataset_numpy, cont_dataset_numpy, targets, cont_dtype=torch.float):
    self.cat_dataset = cat_dataset_numpy
    self.cont_dataset = cont_dataset_numpy
    self.targets = targets
    self.cont_dtype = cont_dtype
    self.length = len(targets)
    self.list_samples = []
    self.build()

  def __len__(self):
    """Number of samples currently materialised by ``build()``."""
    return len(self.list_samples)

  def __getitem__(self, idx):
    """Return sample ``idx`` as a dict of tensors."""
    sample = self.list_samples[idx]
    return {
        'cat_tensor': torch.tensor(sample['cat_list'], dtype=torch.long),
        'cont_tensor': torch.tensor(sample['cont_list'], dtype=self.cont_dtype),
        'target': torch.tensor(sample['target'], dtype=torch.long),
    }

  def build(self):
    """(Re)build ``list_samples`` with one dict per row of the inputs."""
    # Assign a fresh list rather than appending, so calling build() again
    # does not duplicate every sample.
    self.list_samples = [
        {
            'cat_list': self.cat_dataset[i],
            'cont_list': self.cont_dataset[i],
            'target': self.targets[i],
        }
        for i in range(self.length)
    ]