In [152]:
import torch
import pandas as pd
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.feature_extraction.text import TfidfVectorizer

In [2]:
train_df = pd.read_csv('../data/50K_1K_R_train.csv')
test_df = pd.read_csv('../data/50K_1K_R_test.csv')

In [7]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50000 entries, 0 to 49999
Data columns (total 8 columns):
train_id             50000 non-null int64
name                 50000 non-null object
item_condition_id    50000 non-null int64
category_name        49817 non-null object
brand_name           28484 non-null object
price                50000 non-null float64
shipping             50000 non-null int64
item_description     50000 non-null object
dtypes: float64(1), int64(3), object(4)
memory usage: 3.1+ MB


In [8]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50000 entries, 0 to 49999
Data columns (total 8 columns):
train_id             50000 non-null int64
name                 50000 non-null object
item_condition_id    50000 non-null int64
category_name        49817 non-null object
brand_name           28484 non-null object
price                50000 non-null float64
shipping             50000 non-null int64
item_description     50000 non-null object
dtypes: float64(1), int64(3), object(4)
memory usage: 3.1+ MB


In [181]:
tf = TfidfVectorizer(max_features=1000)
X_train = tf.fit_transform(train_df['item_description'].values)
X_test = tf.transform(test_df['item_description'].values)
Y_train = np.log1p(train_df['price'].values)
Y_test = np.log1p(test_df['price'].values)

In [11]:
print(X_train.shape, X_test.shape, Y_train.shape)

(50000, 1000) (50000, 1000) (50000,)


In [148]:
class TextDataset(Dataset):
    
    def __init__(self, X, Y):
        if X.shape[0] != Y.shape[0]:
            raise ValueError('X and Y must be same len')
        self._X = X
        self._Y = Y
        self._sparse_size = torch.Size([1, X.shape[1]]) 
    
    def __len__(self):
        return self._X.shape[0]

    def __getitem__(self, i):
        X = self._X[i]
        Y = self._Y[i]
        indices_size = X.indices.shape[0]
        indices = np.append(np.zeros([indices_size, 1]), X.indices.reshape(indices_size, 1), axis=1)
        if indices.shape[0]:
            i = torch.from_numpy(indices).type(torch.LongTensor)
            v = torch.from_numpy(X.data).type(torch.FloatTensor)
            data = torch.sparse.FloatTensor(i.t(), v, self._sparse_size).to_dense()
        else:
            data = torch.sparse.FloatTensor(self._sparse_size).to_dense()
        return {'data': data, 'target': Y}

In [177]:
batch_size = 10
dataset = TextDataset(X_train, Y_train)
loader = DataLoader(dataset, batch_size=batch_size)

In [179]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Variable of input data and we must return
        a Variable of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Variables.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = batch_size, 1000, 100, 1

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in loader:
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(Variable(t['data']))

    # Compute and print loss
    loss = criterion(y_pred, Variable(t['target']).type(torch.FloatTensor))
    print(loss.data[0])

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

69.27415466308594
98.05586242675781
113.06080627441406
110.19693756103516
133.05014038085938
81.9401626586914
83.45996856689453
89.29167938232422
87.31775665283203
104.5847396850586
106.79246520996094
91.37677001953125
88.01837921142578
73.36383056640625
93.26463317871094
97.21378326416016
65.95162200927734
94.34750366210938
72.30450439453125
65.110107421875
111.12226104736328
80.7280502319336
76.71150207519531
84.9918212890625
112.44255828857422
69.92308807373047
72.84857177734375
89.22653198242188
86.03199005126953
88.4156723022461
81.9052505493164
80.10733032226562
103.37017822265625
87.54008483886719
102.99297332763672
91.54987335205078
98.39675903320312
80.54597473144531
84.13878631591797
109.25601196289062
62.630714416503906
103.92832946777344
63.36895751953125
74.78907012939453
71.61268615722656
71.0043716430664
84.58391571044922
60.31776809692383
89.29618072509766
88.10643768310547
93.99876403808594
68.01327514648438
77.85311889648438
52.476810455322266
81.907470703125
68.91115

2.4489855766296387
9.1051025390625
24.85993194580078
10.567963600158691
19.665788650512695
8.697900772094727
10.570738792419434
17.50668716430664
12.872322082519531
3.002581834793091
7.553407669067383
7.15053129196167
7.658703327178955
2.703014612197876
5.876870155334473
7.674907684326172
10.04802417755127
7.536769866943359
3.4137139320373535
13.312444686889648
3.2259154319763184
8.70157241821289
9.725056648254395
8.710277557373047
4.644178867340088
9.274775505065918
3.603076934814453
9.630376815795898
12.174673080444336
13.694945335388184
7.775801658630371
3.271726369857788
9.661983489990234
6.09834623336792
4.665202617645264
7.297691345214844
11.25481128692627
15.67796516418457
8.234142303466797
32.42180633544922
6.6612019538879395
10.129732131958008
5.341192245483398
7.3177056312561035
6.083393096923828
12.432079315185547
9.753915786743164
13.884550094604492
2.99055552482605
6.811795711517334
7.440742015838623
12.932124137878418
8.345868110656738
5.448704242706299
12.898265838623047

16.139060974121094
10.848488807678223
5.422901153564453
4.346826553344727
4.851825714111328
5.213144779205322
2.7874062061309814
0.9361686706542969
3.229170322418213
2.686117172241211
5.686120986938477
12.983726501464844
5.9818644523620605
3.5408518314361572
2.8509654998779297
10.036745071411133
4.975166320800781
1.171015739440918
5.171454906463623
8.789322853088379
6.508940696716309
3.0923876762390137
5.050140857696533
2.63907790184021
3.515531063079834
3.610121250152588
5.319393634796143
10.818153381347656
4.703774452209473
6.894707679748535
4.3811516761779785
4.034586429595947
5.954945087432861
5.612804889678955
8.237969398498535
2.505688428878784
5.651148319244385
2.37705397605896
1.3244482278823853
2.985275983810425
6.189765453338623
6.921513557434082
1.8916356563568115
4.666758060455322
7.783088207244873
7.466863632202148
4.100712776184082
10.368195533752441
6.435361862182617
3.6161351203918457
4.10204553604126
8.539475440979004
8.139025688171387
4.860469818115234
2.2843751907348

2.672988176345825
3.06994366645813
3.691541910171509
6.609806060791016
6.243494987487793
1.9374709129333496
3.0113484859466553
2.401578664779663
2.5564725399017334
4.789926528930664
11.31527328491211
9.336465835571289
5.822862148284912
4.141042709350586
2.624276876449585
6.836803913116455
8.347652435302734
3.368687152862549
15.779791831970215
4.126307010650635
5.7621588706970215
5.635601997375488
2.900318145751953
5.941375732421875
2.5752508640289307
5.003165245056152
4.213350772857666
4.343987464904785
5.864757537841797
6.4808030128479
4.117509365081787
2.677309989929199
2.8853039741516113
5.797059059143066
3.948582887649536
1.8562146425247192
1.5226181745529175
3.6030306816101074
3.779778242111206
2.878740072250366
4.801275253295898
2.6509203910827637
6.62585973739624
6.528929233551025
4.629093170166016
3.041011095046997
6.762441158294678
2.2264719009399414
3.347369432449341
2.770852565765381
5.5651373863220215
5.585241794586182
2.9630916118621826
2.0815608501434326
1.049515366554260

8.517023086547852
5.730642318725586
1.191056728363037
4.369935035705566
7.579133987426758
4.4947357177734375
2.546351194381714
2.060525894165039
4.043412685394287
4.12092924118042
9.172698020935059
14.330012321472168
8.218741416931152
5.359267711639404
3.5683889389038086
8.775279998779297
10.02114486694336
3.286179542541504
3.7894275188446045
5.504963397979736
8.21374225616455
8.048097610473633
10.108133316040039
2.5004801750183105
4.280254364013672
7.757200241088867
2.34321928024292
2.630549669265747
4.785092830657959
3.2807869911193848
9.162548065185547
4.294368267059326
4.469855308532715
7.4547905921936035
2.808335065841675
7.008872032165527
9.946816444396973
3.3946280479431152
7.903842926025391
4.291398525238037
1.7893407344818115
3.538942337036133
5.771637916564941
2.8490774631500244
2.798574924468994
3.8814375400543213
5.625006198883057
7.82311487197876
3.2912960052490234
8.398715019226074
9.653091430664062
1.6184104681015015
4.679900169372559
9.273048400878906
4.005084991455078


4.006716251373291
7.271731853485107
4.042892932891846
1.9424920082092285
2.4586181640625
3.805908441543579
6.672966003417969
7.2566375732421875
5.388357639312744
4.4076247215271
2.3383407592773438
6.060356616973877
5.1587138175964355
3.572387218475342
5.063686370849609
7.375541687011719
1.9971551895141602
5.052272796630859
7.480877876281738
5.361124038696289
5.208392143249512
7.681761264801025
3.547027111053467
3.1484744548797607
5.016903400421143
1.7895491123199463
3.6637930870056152
5.104625701904297
5.963303565979004
10.959890365600586
20.91139793395996
2.2181923389434814
6.8234171867370605
7.171722412109375
10.687872886657715
1.2373896837234497
10.842305183410645
5.224625110626221
2.6901164054870605
5.98533821105957
4.111292839050293
3.5908737182617188
3.0645763874053955
4.1460957527160645
4.964346885681152
0.9707373380661011
12.237727165222168
4.288321495056152
8.395561218261719
6.973427772521973
7.220684051513672
3.708874225616455
6.424964427947998
6.530986785888672
3.02896022796

4.336616516113281
5.167959690093994
9.719587326049805
9.116131782531738
7.452897071838379
6.429579734802246
6.338644981384277
6.372892379760742
4.580075740814209
4.318845272064209
6.36552095413208
5.290266513824463
5.012463569641113
4.742961883544922
3.3293087482452393
6.515216827392578
5.831027984619141
4.514964580535889
1.5280711650848389
3.5078606605529785
7.580153465270996
6.994128227233887
5.013180732727051
2.289463758468628
5.376167297363281
2.6513214111328125
4.798013687133789
2.517237424850464
2.491915702819824
5.546175956726074
10.638849258422852
10.50993537902832
7.138739585876465
5.334479331970215
7.298709869384766
5.731610298156738
7.221772193908691
10.777266502380371
3.123136520385742
7.46103572845459
4.661559581756592
7.8485941886901855
4.682251453399658
4.0651164054870605
1.9886765480041504
3.4656333923339844
7.144335746765137
6.487508296966553
2.2186734676361084
3.236754894256592
3.345299243927002
6.7425689697265625
4.0221757888793945
6.303460597991943
1.671109914779663

2.6096458435058594
8.694451332092285
12.75821304321289
4.6268815994262695
3.367805004119873
5.752384662628174
1.4867347478866577
2.5378708839416504
2.876638174057007
2.4247381687164307
8.49404239654541
11.372133255004883
5.11704683303833
4.85622501373291
3.7398457527160645
5.348134517669678
2.4627907276153564
5.187452793121338
2.0713930130004883
6.891374111175537
2.8536555767059326
1.355959177017212
7.149406433105469
6.41940450668335
12.721183776855469
4.005789756774902
2.4451773166656494
6.47117805480957
7.714834213256836
7.188117027282715
3.022543430328369
5.076902866363525
3.48709774017334
3.9045543670654297
2.620748281478882
3.445159435272217
3.8084311485290527
2.8452467918395996
5.60280704498291
4.967496871948242
9.482276916503906
2.7587013244628906
5.204675197601318
4.612271308898926
8.054248809814453
3.7706005573272705
5.524694442749023
2.701556921005249
2.0972182750701904
5.978957176208496
5.6894073486328125
5.942451477050781
8.778736114501953
6.5851898193359375
6.0635094642639

4.14719295501709
2.170093297958374
4.621443748474121
4.630993843078613
4.317547798156738
3.6142399311065674
15.724212646484375
1.5436526536941528
9.105579376220703
5.832883358001709
7.55577278137207
6.273493766784668
3.8622870445251465
7.750178337097168
4.314785003662109
10.748883247375488
4.139461994171143
6.509088039398193
3.4357972145080566
6.0884575843811035
5.6671857833862305
8.65779972076416
2.056600570678711
2.213655948638916
5.585132598876953
6.955280780792236
2.169205665588379
8.812883377075195
5.6359453201293945
6.923859596252441
3.558279275894165
3.5658960342407227
2.4162535667419434
7.186689853668213
7.6403117179870605
5.881947994232178
5.105988502502441
9.128632545471191
6.412230491638184
8.906667709350586
2.3515987396240234
5.019322872161865
4.267756462097168
15.961119651794434
7.813410758972168
5.314545631408691
3.2801079750061035
2.05753493309021
4.235881328582764
12.713916778564453
3.0251758098602295
5.15285062789917
1.9661575555801392
10.142868041992188
7.663320541381

4.8671183586120605
6.769717216491699
8.658233642578125
3.0976815223693848
4.556129455566406
3.2968034744262695
3.298210620880127
10.579523086547852
1.6487706899642944
5.0347700119018555
4.661306381225586
8.825225830078125
5.816458702087402
2.845184564590454
4.323267936706543
3.6431455612182617
8.229033470153809
11.77115249633789
2.7129907608032227
4.319991588592529
2.064694881439209
4.102275848388672
2.2457189559936523
5.7662811279296875
2.541813373565674
3.047114133834839
6.929861068725586
8.831093788146973
1.6846083402633667
5.190451145172119
6.789389133453369
3.6230058670043945
12.949962615966797
8.385841369628906
12.41191577911377
7.375239372253418
3.762551784515381
2.572680711746216
0.8738710880279541
3.9145278930664062
4.966689586639404
13.183104515075684
4.386415958404541
2.7836015224456787
3.1717257499694824
3.3677501678466797
3.2253472805023193
5.269674301147461
5.828909397125244
9.353590965270996
5.964693069458008
1.4343081712722778
4.668096542358398
7.861514568328857
5.37889

4.683337688446045
4.569417476654053
7.245110511779785
3.187776803970337
7.216116428375244
2.818514823913574
2.6642231941223145
8.323209762573242
7.933629512786865
2.6552882194519043
3.5475871562957764
6.595272064208984
12.180856704711914
4.8653693199157715
6.226613521575928
3.0966148376464844
4.028882026672363
6.337127685546875
3.1201934814453125
4.808385848999023
2.7044591903686523
3.580702781677246
4.2186198234558105
7.734091281890869
4.082413673400879
3.4631330966949463
15.739370346069336
8.340505599975586
10.18293285369873
4.381714344024658
10.817546844482422
5.699502944946289
2.9361307621002197
7.049714088439941
3.275618553161621
4.964550971984863
4.056212902069092
5.788803577423096
6.789497375488281
2.2783758640289307
5.487668514251709
2.8623061180114746
7.732900619506836
1.9851138591766357
5.685057640075684
3.7713916301727295
5.991044998168945
2.037606954574585
6.276732921600342
3.9068307876586914
3.061643362045288
8.41323471069336
2.9393370151519775
7.196328163146973
4.74733448

In [182]:
test_dataset = TextDataset(X_test, Y_test)
# model.forward()

In [208]:
def test(num):
    print(model(Variable(test_dataset[num]['data'])))
    print(test_dataset[num]['target'])
    print(np.expm1(model(Variable(test_dataset[num]['data'])).data.numpy()))
    print(test_df['price'][num])
test(2000)

Variable containing:
 2.8760
[torch.FloatTensor of size 1x1]

2.94443897917
[[ 16.74291611]]
18.0
