In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gc
import tqdm

In [2]:
from torch.utils.data import DataLoader as torch_dl
from torch.utils.data import Dataset
from torch import  nn
from torch import optim
from torch.nn.init import *
from torch.nn import functional as F

In [11]:
# Load the train and test splits; both files are tab-separated, hence sep='\t'.
train = pd.read_csv('../Data/mercari/train.tsv', sep='\t')
test = pd.read_csv('../Data/mercari/test.tsv', sep='\t')

In [12]:
test.shape

(693359, 7)

In [14]:
# Relabel the columns positionally so both frames share the same names
# (the first column becomes 'id' in each).
# NOTE(review): this assumes the files' column order matches the lists below — confirm.
train.columns = ['id', 'name', 'item_condition_id', 'category_name', 'brand_name',
       'price', 'shipping', 'item_description']

# test has no price column; add a zero placeholder so it can be stacked with train.
# The new column is appended last, which is why 'price' is the final label below.
test['price'] = 0
test.columns = ['id', 'name', 'item_condition_id', 'category_name', 'brand_name',
       'shipping', 'item_description', 'price']

In [15]:
train.head()

Unnamed: 0,id,name,item_condition_id,category_name,brand_name,price,shipping,item_description
0,0,MLB Cincinnati Reds T Shirt Size XL,3,Men/Tops/T-shirts,,10.0,1,No description yet
1,1,Razer BlackWidow Chroma Keyboard,3,Electronics/Computers & Tablets/Components & P...,Razer,52.0,0,This keyboard is in great condition and works ...
2,2,AVA-VIV Blouse,1,Women/Tops & Blouses/Blouse,Target,10.0,1,Adorable top with a hint of lace and a key hol...
3,3,Leather Horse Statues,1,Home/Home Décor/Home Décor Accents,,35.0,1,New with tags. Leather horses. Retail for [rm]...
4,4,24K GOLD plated rose,1,Women/Jewelry/Necklaces,,44.0,0,Complete with certificate of authenticity


In [19]:
# Stack train on top of test so categorical encodings are computed over both.
# `axis` is passed by keyword: passing it positionally is deprecated in pandas 1.x
# and rejected in pandas 2.x. sort=False keeps train's column order.
train_test = pd.concat([train, test], axis=0, sort=False)

In [20]:
# Discard the identifier and free-text columns; only the remaining columns are used below.
train_test = train_test.drop(columns=['id', 'name', 'item_description'])

In [23]:
train_test.isna().sum()

item_condition_id         0
category_name          9385
brand_name           928207
price                     0
shipping                  0
dtype: int64

In [24]:
train_test.item_condition_id.value_counts()

1    940630
3    633834
2    551032
4     46815
5      3583
Name: item_condition_id, dtype: int64

In [25]:
train_test.head()

Unnamed: 0,item_condition_id,category_name,brand_name,price,shipping
0,3,Men/Tops/T-shirts,,10.0,1
1,3,Electronics/Computers & Tablets/Components & P...,Razer,52.0,0
2,1,Women/Tops & Blouses/Blouse,Target,10.0,1
3,1,Home/Home Décor/Home Décor Accents,,35.0,1
4,1,Women/Jewelry/Necklaces,,44.0,0


In [26]:
# Give absent values an explicit 'missing' label so they form their own category.
train_test = train_test.fillna('missing')

In [28]:
# Replace each category string with its integer category code.
train_test['category_name'] = train_test['category_name'].astype('category').cat.codes

In [31]:
# Replace each brand string with its integer category code.
train_test['brand_name'] = train_test['brand_name'].astype('category').cat.codes

In [32]:
train_test.head()

Unnamed: 0,item_condition_id,category_name,brand_name,price,shipping
0,3,829,5265,10.0,1
1,3,86,3889,52.0,0
2,1,1277,4588,10.0,1
3,1,503,5265,35.0,1
4,1,1204,5265,44.0,0


In [33]:
# The concat above kept each frame's own row labels, so labels repeat across
# train and test; replace them with a fresh 0..n-1 index (old labels discarded).
train_test = train_test.reset_index(drop=True)

In [40]:
# Columns treated as categorical: re-indexed by EmbeddingDataPreprocess and
# sized by get_embs_dims below.
cats = ['item_condition_id', 'category_name', 'brand_name']

In [41]:
def EmbeddingDataPreprocess(data, cats, inplace=True):
    """Re-encode categorical columns as contiguous embedding indices.

    Every column named in ``cats`` has its values replaced by integers
    ``0 .. n_unique - 1``, numbered in order of first appearance in the
    column, so the values can be used as row indices into an embedding matrix.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame holding the categorical columns.
    cats : iterable of str
        Names of the columns to re-encode.
    inplace : bool, default True
        If True, ``data`` itself is modified and returned. If False, a copy is
        modified and returned and ``data`` is left untouched.

    Returns
    -------
    pandas.DataFrame
        The frame whose ``cats`` columns now hold integer indices.
    """
    # One code path for both modes: pick the frame to write into up front.
    target = data if inplace else data.copy()
    for c in cats:
        index_of = {val: i for i, val in enumerate(target[c].unique())}
        # Assign the mapped column back explicitly. Calling an inplace method on
        # `target[c]` is a chained in-place update, which warns on recent pandas
        # and does not write back under Copy-on-Write. `map` is also a single
        # lookup pass instead of one pass per unique value.
        target[c] = target[c].map(index_of)
    return target

In [42]:
# Re-index every column in `cats` to 0..n_unique-1 (numbered by first appearance)
# so the values can serve as embedding-matrix row indices.
train_test = EmbeddingDataPreprocess(train_test, cats, inplace=True)

In [49]:
# train rows were stacked first, so the first len(train) rows are the training
# set and the remainder is the test set. Plain slices select the same rows as
# iloc[range(...)] without materialising every integer position.
train_df = train_test.iloc[:len(train)]
test_df = train_test.iloc[len(train):]

In [51]:
# The raw frames are not needed once train_df/test_df exist; release them.
del train
# Keep the test ids before discarding the test frame.
test_id = test['id']
del test
# Run a collection pass now; as the cell's last expression its return value is displayed.
gc.collect()

33

In [52]:
# Split features from the target. The target is log(1 + price); np.log1p computes
# that directly. `columns=` replaces the positional axis argument, which is
# deprecated in pandas 1.x and rejected in pandas 2.x.
train_input, train_y = train_df.drop(columns='price'), np.log1p(train_df.price)
# test prices are the zero placeholders added earlier, so test_y is all zeros.
test_input, test_y = test_df.drop(columns='price'), np.log1p(test_df.price)
# (min, max) of the transformed training target.
y_range = (train_y.min(), train_y.max())

In [53]:
def get_embs_dims(data, cats):
    """Return one ``(cardinality, embedding_width)`` pair per categorical column.

    ``cardinality`` is the number of distinct values found in the column and
    ``embedding_width`` is half of it rounded up, capped at 50. Pairs are
    returned in the same order as ``cats``.
    """
    dims = []
    for col in cats:
        cardinality = len(data[col].unique())
        width = min(50, (cardinality + 1) // 2)
        dims.append((cardinality, width))
    return dims

In [54]:
# (min, max) of the log-transformed training target; same expression as in the
# cell that builds train_y.
y_range = (train_y.min(), train_y.max())
# Embedding sizes are computed on the combined frame so that categories seen
# only in test are counted as well.
emb_szs = get_embs_dims(train_test, cats)

In [55]:
emb_szs

[(5, 3), (1311, 50), (5290, 50)]