## Things to do
1. Extract the user (`uuid`) and brand embeddings for the segment of users with >= 20 interactions (GE20).
2. Check that model inference works end to end.
3. _(TBD)_

### User and Brand Embeddings for SegGE20

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os, sys, json, joblib

In [2]:
# Make the offline training package importable from this notebook.
sys.path.append('../../offline/src/')
# Presumably provides the path/column constants used below (MODEL_DIR,
# USER_COL, INTERIM_DATA_DIR, ...) -- confirm against constants.py.
from constants import *
from network import ProductRecommendationModel
from baseline_feats_utils import feat_type_feats_dct

In [34]:
# Item-level baseline feature names. Their order is the positional layout of
# each entry in the item baseline-features mapping (get_baseline_feats
# indexes item_feats[item] by this position).
feat_type_feats_dct['item']

['num_interactions',
 'earliest_interaction_date',
 'min_num_interactions_per_user',
 'max_num_interactions_per_user',
 'mean_num_interactions_per_user']

In [3]:
# GLOBALS -- the vocabulary sizes below must equal the ones the checkpoint
# was trained with, or load_state_dict would fail on the embedding shapes.
SEGMENT = 'GE20'
N_USERS = 1_444_170
N_ITEMS = 1_175_648
N_ONTOLOGIES = 801
N_BRANDS = 1_686
MODEL_FN = os.path.join(MODEL_DIR, 'Class_model_SegGE20_E1_ckpt.pt')

In [4]:
def choose_embedding_size(cat_cols, cat_num_values, min_emb_dim=100):
    """
    Pick an (num_embeddings, embedding_dim) pair for each categorical column.

    The embedding dimension is half the column's cardinality (rounded up),
    capped at ``min_emb_dim``. Despite its name, ``min_emb_dim`` is an upper
    bound on the dimension, not a lower one.

    cat_cols: list of categorical columns
    cat_num_values: list of number of unique values for each categorical column
    min_emb_dim: maximum embedding dimension allowed for any column

    Returns a list of ``(n_categories, emb_dim)`` tuples, one per column, in
    the order of ``cat_cols``.

    Raises ValueError if ``cat_cols`` and ``cat_num_values`` differ in length.
    """
    cat_cols = list(cat_cols)
    cat_num_values = list(cat_num_values)
    if len(cat_cols) != len(cat_num_values):
        raise ValueError(
            'cat_cols has {} entries but cat_num_values has {}'.format(
                len(cat_cols), len(cat_num_values)))

    # Iterate the cardinalities directly instead of going through a dict
    # keyed by column name: a dict would silently drop a repeated column
    # name and misalign the result with the model's embedding list.
    return [(n_categories, min(min_emb_dim, (n_categories + 1) // 2))
            for n_categories in cat_num_values]

In [5]:
# choose embedding size
# Only the GE20 segment learns a per-user embedding; other segments embed
# items, ontologies and brands only.
include_user = SEGMENT == 'GE20'
cat_cols = ([USER_COL] if include_user else []) + [
    ITEM_COL, ONTOLOGY_COL, BRAND_COL]
cat_num_values = ([N_USERS] if include_user else []) + [
    N_ITEMS, N_ONTOLOGIES, N_BRANDS]

embedding_sizes = choose_embedding_size(cat_cols, cat_num_values, 150)

In [6]:
# One (num_embeddings, embedding_dim) pair per categorical column.
embedding_sizes

[(1444170, 150), (1175648, 150), (801, 150), (1686, 150)]

In [7]:
# 18 numeric input features (bn1 below is BatchNorm1d(18)) and 3 output
# classes (lin3 has 3 outputs).
N_NUMERIC_FEATS = 18
N_CLASSES = 3
model = ProductRecommendationModel(embedding_sizes, N_NUMERIC_FEATS, N_CLASSES)

In [8]:
# Print the architecture to check layer shapes against the checkpoint.
model

ProductRecommendationModel(
  (embeddings): ModuleList(
    (0): Embedding(1444170, 150)
    (1): Embedding(1175648, 150)
    (2): Embedding(801, 150)
    (3): Embedding(1686, 150)
  )
  (lin1): Linear(in_features=618, out_features=300, bias=True)
  (lin2): Linear(in_features=300, out_features=100, bias=True)
  (lin3): Linear(in_features=100, out_features=3, bias=True)
  (bn1): BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (emb_drop): Dropout(p=0.6, inplace=False)
  (drops): Dropout(p=0.3, inplace=False)
)

In [9]:
# Load the trained weights onto the CPU regardless of the device they were
# saved from. torch.load unpickles the file -- only load trusted checkpoints.
cpu_device = torch.device('cpu')
ckpt = torch.load(MODEL_FN, map_location=cpu_device)
model.load_state_dict(ckpt['model_state_dict'])

<All keys matched successfully>

In [10]:
# Embedding tables in categorical-feature order: user, item, ontology, brand.
model.embeddings

ModuleList(
  (0): Embedding(1444170, 150)
  (1): Embedding(1175648, 150)
  (2): Embedding(801, 150)
  (3): Embedding(1686, 150)
)

In [33]:
# Row 0 of the user embedding table as a plain Python list of floats.
user_emb_weight = model.embeddings[0].weight
user_emb_weight.detach()[0].cpu().numpy().tolist()

[-2.342577267482619e-39,
 1.213080641630508e-30,
 4.4801276403994006e-12,
 -1.0745834560718648e-12,
 -1.2736684536912568e-12,
 6.0091482803018485e-31,
 2.27411683100803e-39,
 -3.6126340930431473e-31,
 8.189247845003542e-31,
 -4.468207216983683e-13,
 -5.60428921796904e-32,
 1.2939744761797378e-12,
 -1.2274464051116813e-30,
 -1.1692410147495006e-32,
 -5.756624822276282e-31,
 -1.294582171497416e-12,
 5.167707353617018e-14,
 8.789154163014901e-40,
 1.9223996042439304e-33,
 -1.3788569423846007e-30,
 -4.247238354069029e-31,
 1.7108385021169514e-12,
 -8.790560732155114e-34,
 2.0065267271361113e-30,
 7.22623729517563e-13,
 5.943382857262827e-14,
 6.27796900294457e-14,
 -1.5818236477616043e-12,
 3.3994831147125815e-13,
 -4.610817478691021e-13,
 7.891153008809937e-13,
 2.3336201676986546e-39,
 -1.1290121597218619e-39,
 -7.991583740249442e-13,
 1.6740942387189768e-32,
 1.8846783211338293e-31,
 -3.4870087766718993e-32,
 -2.192589286353612e-39,
 6.515122642769544e-13,
 1.9089695102447735e-12,
 -2.8

### Get the ordered list of features fed as input to the NN model

In [12]:
from torch.utils.data import IterableDataset
from itertools import chain, islice
from metadata_utils import _find_files


class InteractionsStream(IterableDataset):
    """Iterable dataset that streams interaction rows as model inputs.

    Each gzipped, '|'-separated file is read in pandas chunks of
    ``chunksize`` rows. For every chunk only the rows belonging to the
    configured segment are kept (``segment == 'GE20'``: user has >= 20
    interactions; any other segment: < 20), and one
    ``(cat_rows, numeric_rows, dv)`` triple of Python lists is yielded.
    Chunks with no matching rows are skipped.
    """

    def __init__(self, sample, model_type, file_name=None,
                 interim_data_dir=INTERIM_DATA_DIR, user_col=USER_COL,
                 item_col=ITEM_COL, ontology_col=ONTOLOGY_COL,
                 brand_col=BRAND_COL, price_col=PRICE_COL, dv_col=DV_COL,
                 date_col=DATE_COL, end_token='.gz', chunksize=10,
                 segment='LE3', user2idx_fn=USER2IDX_SEGGE20_FN):
        """
        sample: 'train' (every file not starting with '0005') or 'test'
            (files starting with '0005'); only used when file_name is None.
        model_type: 'classification' shifts the dv down by one (see
            get_dv_for_classification); any other value keeps it as is.
        file_name: if given, read only this file from interim_data_dir.
        chunksize: number of rows per pandas chunk (one yielded item per
            chunk, before segment filtering).
        segment: 'GE20' keeps users with >= 20 interactions and adds the
            user as the first categorical feature; any other value keeps
            users with < 20 interactions.
        user2idx_fn: JSON file mapping user id (as str) to embedding index;
            only loaded for the GE20 segment.

        Raises ValueError if file_name is None and sample is neither
        'train' nor 'test'.
        """
        data_dir = interim_data_dir

        if file_name is None:
            files = _find_files(data_dir, end_token)
            if sample == 'train':
                selected = [x for x in files if not x.startswith('0005')]
            elif sample == 'test':
                selected = [x for x in files if x.startswith('0005')]
            else:
                # Fail here rather than with an AttributeError on
                # self.files further down.
                raise ValueError(
                    "sample must be 'train' or 'test', got {!r}".format(
                        sample))
            self.files = [os.path.join(data_dir, x) for x in selected]
        else:
            self.files = [os.path.join(data_dir, file_name)]
        print(self.files)

        self.model_type = model_type
        self.segment = segment
        self.user_col = user_col
        self.item_col = item_col
        self.ontology_col = ontology_col
        self.brand_col = brand_col
        self.price_col = price_col
        self.date_col = date_col
        self.dv_col = dv_col
        self.feat_type_feats_dct = feat_type_feats_dct
        self.chunksize = chunksize

        # The raw earliest-interaction date is replaced by a derived
        # "days since earliest interaction" column, appended last.
        user_feats = ['{}_{}'.format(self.user_col, x) for x in
                      self.feat_type_feats_dct['user']
                      if x != 'earliest_interaction_date']
        user_feats.append('{}_days_since_earliest_interaction'.format(
            self.user_col))
        item_feats = ['{}_{}'.format(self.item_col, x) for x in
                      self.feat_type_feats_dct['item']
                      if x != 'earliest_interaction_date']
        item_feats.append('{}_days_since_earliest_interaction'.format(
            self.item_col))
        self.numeric_feats = [self.price_col] + user_feats + item_feats

        if self.segment == 'GE20':
            self.cat_feats = [self.user_col, self.item_col,
                              self.ontology_col, self.brand_col]
            with open(user2idx_fn) as fh:
                self.user2idx = json.load(fh)
        else:
            self.cat_feats = [self.item_col, self.ontology_col,
                              self.brand_col]
            self.user2idx = None

    def read_file(self, fn):
        """Return a chunked pandas reader over the gzipped file ``fn``."""
        df = pd.read_csv(fn, compression='gzip', sep='|', iterator=True,
                         chunksize=self.chunksize)
        return df

    def get_dv_for_classification(self, dv_lst):
        """Convert dv values to ints.

        For classification each value is shifted down by one (presumably
        1-based labels -> 0-based class indices; confirm against the data).
        """
        if self.model_type == 'classification':
            return [int(x - 1) for x in dv_lst]
        return [int(x) for x in dv_lst]

    def _segment_filter(self, num_interactions_lst, feat_type, feats_lst):
        """Keep the entries of ``feats_lst`` whose user is in this segment.

        num_interactions_lst: per-row user interaction counts, aligned with
            feats_lst.
        feat_type: 'cat', 'numeric' or 'dv'; for 'cat' in the GE20 segment
            the first value of each row (the user id) is replaced by its
            embedding index from ``self.user2idx``.

        Returns the filtered list, or None if no row matches. The input
        rows are not modified.
        Raises KeyError if a GE20 user id is missing from ``user2idx``.
        """
        if self.segment == 'GE20':
            idxs = [i for i, x in enumerate(num_interactions_lst) if x >= 20]
        else:
            idxs = [i for i, x in enumerate(num_interactions_lst) if x < 20]

        if not idxs:
            return None

        if self.segment == 'GE20' and feat_type == 'cat':
            new_feats_lst = []
            for i in idxs:
                # Copy so the caller's row is not mutated in place.
                row = list(feats_lst[i])
                row[0] = self.user2idx[str(row[0])]
                new_feats_lst.append(row)
            return new_feats_lst
        return [feats_lst[i] for i in idxs]

    def process_data(self, fn):
        """Yield ``(cat_rows, numeric_rows, dv)`` per chunk of file ``fn``.

        Chunks in which no row belongs to the segment are skipped.
        """
        print('read data')
        num_interactions_col = '{}_num_interactions'.format(self.user_col)
        data = self.read_file(fn)
        try:
            for row in data:
                num_interactions = row[num_interactions_col].values.tolist()

                x1 = row[self.cat_feats].values.tolist()
                x2 = row[self.numeric_feats].values.tolist()
                y = self.get_dv_for_classification(
                    row[self.dv_col].tolist())
                x1 = self._segment_filter(num_interactions, 'cat', x1)
                if not x1:
                    continue
                x2 = self._segment_filter(num_interactions, 'numeric', x2)
                y = self._segment_filter(num_interactions, 'dv', y)
                yield (x1, x2, y)
        finally:
            # Release the underlying file even when the consumer stops
            # early (e.g. islice over the loader).
            data.close()

    def get_stream(self, files):
        """Chain the per-file generators of ``files`` into one stream."""
        return chain.from_iterable(map(self.process_data, files))

    def __iter__(self):
        return self.get_stream(self.files)

In [13]:
from torch.utils.data import DataLoader

# Stream the single GE20 test file, up to two rows per pandas chunk; the
# loader then batches two such chunks together.
test_dataset = InteractionsStream(
    sample='test',
    model_type='classification',
    file_name='0005_part_07.gz',
    chunksize=2,
    segment='GE20',
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=2)

['/Users/varunn/Documents/ExternalTest_Data/MAD/interim/0005_part_07.gz']


In [31]:
from itertools import islice
from torch import tensor


def construct_tensor(a):
    """Rebuild a 2-D tensor of rows from a collated batch.

    ``a`` is a sequence of groups. Each group is a sequence of 1-D tensors
    (one per column) of equal length. Each group is transposed into rows,
    and the rows of all groups are concatenated in order.
    (Presumably ``a`` is the DataLoader default collation of the per-chunk
    row lists yielded by InteractionsStream -- confirm.)
    """
    rows = []
    for group in a:
        columns = [col.tolist() for col in group]
        rows.extend([list(values) for values in zip(*columns)])
    return tensor(rows)


def construct_tensor_y(a):
    """Concatenate a sequence of 1-D tensors into one flat tensor."""
    return tensor([value for part in a for value in part.tolist()])

# Materialize and print only the first batch from the loader.
for cat_batch, num_batch, dv_batch in islice(test_loader, 1):
    x1 = construct_tensor(cat_batch)
    x2 = construct_tensor(num_batch)
    y = construct_tensor_y(dv_batch)
    for shown in (x1, x2):
        print(shown)
        print('\n')
    print(y)
    print(x1.shape)
    print(x1.shape)

read data
tensor([[ 659123,  911340,     431,    1480],
        [1003142,  693320,     696,    1507],
        [ 221334,  876085,     217,    1327],
        [ 822311,  630337,     329,    1222]])


tensor([[1.0990e+03, 7.7600e+02, 2.6948e+03, 1.0000e+00, 1.3000e+01, 1.0270e+00,
         1.0000e+00, 2.6000e+01, 1.5190e+00, 1.0000e+00, 2.3000e+01, 1.3889e+00,
         8.8655e+00, 1.8800e+02, 1.0000e+00, 4.0000e+00, 1.0271e+00, 3.8233e+01],
        [1.9990e+03, 9.3000e+01, 2.4667e+03, 1.0000e+00, 5.0000e+00, 2.2821e+00,
         1.0000e+00, 8.0000e+00, 2.3345e+00, 1.0000e+00, 5.0000e+00, 2.3013e+00,
         4.4421e-02, 4.7400e+02, 1.0000e+00, 3.0000e+00, 1.0091e+00, 1.5834e+01],
        [5.3990e+03, 1.0150e+03, 2.6539e+03, 1.0000e+00, 1.1000e+01, 1.0242e+00,
         1.0000e+00, 1.7000e+01, 1.2651e+00, 1.0000e+00, 1.9000e+01, 1.1541e+00,
         3.6472e+01, 1.2100e+03, 1.0000e+00, 4.0000e+00, 1.0462e+00, 4.6897e+01],
        [6.9990e+03, 1.4200e+02, 9.2414e+03, 1.0000e+00, 1.3000e+01, 1.

In [16]:
print(test_dataset.cat_feats)
print(test_dataset.numeric_feats)

['uuid', 'sourceprodid', 'ontology', 'brand']
['price', 'uuid_num_interactions', 'uuid_mean_price_interactions', 'uuid_min_num_interactions_per_pdt', 'uuid_max_num_interactions_per_pdt', 'uuid_mean_num_interactions_per_pdt', 'uuid_min_num_interactions_per_ont', 'uuid_max_num_interactions_per_ont', 'uuid_mean_num_interactions_per_ont', 'uuid_min_num_interactions_per_brand', 'uuid_max_num_interactions_per_brand', 'uuid_mean_num_interactions_per_brand', 'uuid_days_since_earliest_interaction', 'sourceprodid_num_interactions', 'sourceprodid_min_num_interactions_per_user', 'sourceprodid_max_num_interactions_per_user', 'sourceprodid_mean_num_interactions_per_user', 'sourceprodid_days_since_earliest_interaction']


### Model Inference

In [17]:
# Index-encoded (interim) version of the test file.
inp_fn = os.path.join(INTERIM_DATA_DIR, '0005_part_07.gz')
df = pd.read_csv(inp_fn, compression='gzip', sep='|')

In [18]:
# Sanity-check the interim file: row/column count and the first rows.
print(df.shape)
df.head()

(1371989, 24)


Unnamed: 0,uuid,userevent,sourceprodid,clicked_epoch,ontology,brand,price,uuid_num_interactions,uuid_mean_price_interactions,uuid_days_since_earliest_interaction,...,uuid_max_num_interactions_per_ont,uuid_mean_num_interactions_per_ont,uuid_min_num_interactions_per_brand,uuid_max_num_interactions_per_brand,uuid_mean_num_interactions_per_brand,sourceprodid_num_interactions,sourceprodid_days_since_earliest_interaction,sourceprodid_min_num_interactions_per_user,sourceprodid_max_num_interactions_per_user,sourceprodid_mean_num_interactions_per_user
0,4852310,1,911340,1551714898,431,1480,1099.0,776.0,2694.77895,8.865451,...,26.0,1.519018,1.0,23.0,1.388859,188.0,38.233414,1.0,4.0,1.027066
1,6013644,1,876085,1550334397,217,1327,5399.0,1015.0,2653.918513,36.472095,...,17.0,1.265118,1.0,19.0,1.154053,1210.0,46.896551,1.0,4.0,1.046217
2,551584,1,693320,1550337611,696,1507,1999.0,93.0,2466.662973,0.044421,...,8.0,2.334516,1.0,5.0,2.301251,474.0,15.834213,1.0,3.0,1.009087
3,2954929,1,630337,1550341067,329,1222,6999.0,142.0,9241.447513,7.639884,...,20.0,3.707738,1.0,14.0,2.38984,367.0,46.536771,1.0,4.0,1.048521
4,2936231,1,537273,1550336735,285,708,499.0,163.0,820.488392,46.41912,...,16.0,2.269754,1.0,15.0,2.29338,357.0,9.044988,1.0,3.0,1.016065


In [19]:
# Raw (hashed-id) version of the same file; used to pick an example request.
inp_fn = os.path.join(RAW_DATA_DIR, '0005_part_07.gz')
raw_df = pd.read_csv(inp_fn, compression='gzip', sep='|')

print(raw_df.shape)
raw_df.head()

(1371989, 7)


Unnamed: 0,uuid,userevent,sourceprodid,clicked_epoch,ontology,brand,price
0,cc1b580857481534abb2204b167915d7,pageView,3a9feb4237f4203b3118d2071b93c96c,1551714898,644e3342d3fb99e0b4d03b610dd4827d,a6b68a1deb25ba3f4b5a4c4f780094e4,1099.0
1,dae70b91a4c3707956e7dd17a5b03e5c,pageView,611f8d943412d7260bedac2f493f2c77,1550334397,f9e13a341127b189f97e6ee05923340c,b081d61f98a982edd345e81a9d70102a,5399.0
2,300665c14ec978de7eeb31466eb27712,pageView,cd2d2d61897732d4aed73db4af897010,1550337611,e3bb7a2fc0e60206b5b12b95c0c25b07,72c035606a07faa83f56eeb7a1be1beb,1999.0
3,94c68896d5a983923c5acfc62c2303a0,pageView,55eb55799a57b2cc969eb5025e655025,1550341067,87400c7f16b66890a0e0e97305291c92,787923f3de426787a37e7c024f96418d,6999.0
4,eb42828d343f5e56bbf969ac7b7a0a36,pageView,5c2c92c8a1a442024b23016f145d3fda,1550336735,ed414ff376ba74be64279ba9b31a94f3,49b42c44eb0bf64a6a33b4df5ce3b7e9,499.0


In [20]:
# clicked_epoch is in Unix seconds; convert one value to sanity-check it.
pd.to_datetime(1551714898, unit='s')

Timestamp('2019-03-04 15:54:58')

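#### Inputs
1. uuid
2. sourceprodid
3. clicked_epoch (needed to compute the days-since-earliest-interaction features)

#### Approach
0. Map sourceprodid to its ontology, brand and price.
1. Compute the user and item baseline features.
2. Map uuid, sourceprodid, ontology and brand to their embedding indices.
3. Prepare the input tensors.
4. Run the model forward pass.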

In [21]:
# Use the first raw interaction as the example request. Take its timestamp
# from the same row instead of hard-coding it, so the three stay in sync.
user = raw_df.loc[0, 'uuid']
item = raw_df.loc[0, 'sourceprodid']
clicked_epoch = int(raw_df.loc[0, 'clicked_epoch'])

print('User: ', user)
print('Item: ', item)
print('Clicked Epoch: ', clicked_epoch)

User:  cc1b580857481534abb2204b167915d7
Item:  3a9feb4237f4203b3118d2071b93c96c
Clicked Epoch:  1551714898


In [22]:
# Step 0: look up the product's ontology, brand and price.
with open(PDT_MAPPING_FN) as fh:
    pdt_mapping = json.load(fh)
ont, brand, price = pdt_mapping[item]
del pdt_mapping  # only this one entry is needed; free the mapping

print('Ontology: ', ont)
print('Brand: ', brand)
print('Price: ', price)

Ontology:  644e3342d3fb99e0b4d03b610dd4827d
Brand:  a6b68a1deb25ba3f4b5a4c4f780094e4
Price:  1099.0


In [23]:
# step 1

def get_baseline_feats(user_col, item_col, user, item, clicked_epoch,
                       user_feats, item_feats,
                       feat_type_dct=feat_type_feats_dct):
    """Assemble the user and item baseline features for one interaction.

    user_feats / item_feats map an id to a list of values laid out in the
    order of feat_type_dct['user'] / feat_type_dct['item']. Each feature is
    keyed '<col>_<name>', except 'earliest_interaction_date', which becomes
    '<col>_days_since_earliest_interaction': the days between that stored
    epoch and clicked_epoch (both in seconds), with negative values clamped
    to -1. User features come first in the returned dict, then item ones.

    Raises KeyError if user or item is missing from its mapping, and
    IndexError if a stored value list is shorter than its feature list.
    """
    def _add_entity_feats(col, entity_id, lookup, feat_names, out):
        # Write one entity's features into ``out`` in feat_names order.
        for pos, name in enumerate(feat_names):
            raw = lookup[entity_id][pos]
            if name != 'earliest_interaction_date':
                out[col + '_' + name] = raw
                continue
            days = (float(clicked_epoch) - float(raw)) / (60 * 60 * 24)
            out[col + '_days_since_earliest_interaction'] = (
                -1 if days < 0 else days)

    feats = {}
    print('User Features')
    _add_entity_feats(user_col, user, user_feats, feat_type_dct['user'],
                      feats)
    print('Item Features')
    _add_entity_feats(item_col, item, item_feats, feat_type_dct['item'],
                      feats)
    return feats


print('read baseline feats dct')
# Read each mapping inside a context manager so the file handles are closed.
with open(USER_BASELINE_FEATS_FN) as fh:
    user_feats = json.load(fh)
with open(ITEM_BASELINE_FEATS_FN) as fh:
    item_feats = json.load(fh)

print('get baseline feats')
baseline_feats = get_baseline_feats(USER_COL, ITEM_COL, user, item,
                                    clicked_epoch, user_feats,
                                    item_feats)

# Only this request's features are needed; free the full mappings.
del user_feats, item_feats

print(baseline_feats)

read baseline feats dct
get baseline feats
User Features
Item Features
{'uuid_num_interactions': 776.0, 'uuid_mean_price_interactions': 2694.7789502483192, 'uuid_days_since_earliest_interaction': 8.86545138888889, 'uuid_min_num_interactions_per_pdt': 1.0, 'uuid_max_num_interactions_per_pdt': 13.0, 'uuid_mean_num_interactions_per_pdt': 1.0270275027907854, 'uuid_min_num_interactions_per_ont': 1.0, 'uuid_max_num_interactions_per_ont': 26.0, 'uuid_mean_num_interactions_per_ont': 1.5190180584199862, 'uuid_min_num_interactions_per_brand': 1.0, 'uuid_max_num_interactions_per_brand': 23.0, 'uuid_mean_num_interactions_per_brand': 1.3888592131566546, 'sourceprodid_num_interactions': 188.0, 'sourceprodid_days_since_earliest_interaction': 38.23341435185185, 'sourceprodid_min_num_interactions_per_user': 1.0, 'sourceprodid_max_num_interactions_per_user': 4.0, 'sourceprodid_mean_num_interactions_per_user': 1.0270659587429387}


In [24]:
# step 2: map each raw id to the embedding index used by the model.


def _load_idx(mapping_fn, key):
    """Return ``key``'s index from the JSON id->index file ``mapping_fn``.

    The file is closed immediately and the mapping is not kept around.
    Raises KeyError if ``key`` is not in the mapping.
    """
    with open(mapping_fn) as fh:
        return json.load(fh)[key]


print('User IDX\n')
# NOTE(review): users with < 20 interactions get an index from the LT20
# map, but the model loaded above is the GE20 model whose user embedding
# table is sized for GE20 users -- confirm this branch is intended.
user2idx_fn = (FINAL_USER2IDX_SEGGE20_FN
               if baseline_feats[USER_COL + '_num_interactions'] >= 20
               else FINAL_USER2IDX_SEGLT20_FN)
print('user2idx FN: ', user2idx_fn)
user_idx = _load_idx(user2idx_fn, user)

print('Item IDX\n')
item_idx = _load_idx(ITEM2IDX_FN, item)

print('Ontology IDX\n')
ont_idx = _load_idx(ONT2IDX_FN, ont)

print('Brand IDX\n')
brand_idx = _load_idx(BRAND2IDX_FN, brand)

print(user_idx, '\t', item_idx, '\t', ont_idx, '\t', brand_idx)

User IDX

user2idx FN:  /Users/varunn/Documents/ExternalTest_Data/MAD/metadata/final_user2idx_segGE20.json
Item IDX

Ontology IDX

Brand IDX

659123 	 911340 	 431 	 1480


In [37]:
# step 3 - prepare input tensors

cat_feat_cols = [USER_COL, ITEM_COL, ONTOLOGY_COL, BRAND_COL]
# Must follow the exact column order the model was trained on (the same
# list InteractionsStream.numeric_feats prints above).
numeric_feat_cols = [
    PRICE_COL, 'uuid_num_interactions',
    'uuid_mean_price_interactions',
    'uuid_min_num_interactions_per_pdt',
    'uuid_max_num_interactions_per_pdt',
    'uuid_mean_num_interactions_per_pdt',
    'uuid_min_num_interactions_per_ont',
    'uuid_max_num_interactions_per_ont',
    'uuid_mean_num_interactions_per_ont',
    'uuid_min_num_interactions_per_brand',
    'uuid_max_num_interactions_per_brand',
    'uuid_mean_num_interactions_per_brand',
    'uuid_days_since_earliest_interaction',
    'sourceprodid_num_interactions',
    'sourceprodid_min_num_interactions_per_user',
    'sourceprodid_max_num_interactions_per_user',
    'sourceprodid_mean_num_interactions_per_user',
    'sourceprodid_days_since_earliest_interaction']

cat_feat_values = [user_idx, item_idx, ont_idx, brand_idx]
# Price comes from the product mapping; the rest from the baseline features.
numeric_feat_values = [price]
numeric_feat_values.extend(
    baseline_feats[col] for col in numeric_feat_cols[1:])
# The single request is duplicated into a batch of two identical rows
# (presumably so BatchNorm never sees a one-row batch -- confirm).
cat_feat_values = [cat_feat_values] * 2
numeric_feat_values = [numeric_feat_values] * 2

print(cat_feat_values)
print('\n')
print(numeric_feat_values)

print('convert feature lists to tensors\n')
cat_feat_tensor = torch.tensor(cat_feat_values)
numeric_feat_tensor = torch.tensor(numeric_feat_values)
# Promote an un-batched 1-D row to a batch of one.
if cat_feat_tensor.dim() == 1:
    cat_feat_tensor = cat_feat_tensor.unsqueeze(0)
if numeric_feat_tensor.dim() == 1:
    numeric_feat_tensor = numeric_feat_tensor.unsqueeze(0)

print(cat_feat_tensor)
print('\n')
print(numeric_feat_tensor)

[[659123, 911340, 431, 1480], [659123, 911340, 431, 1480]]


[[1099.0, 776.0, 2694.7789502483192, 1.0, 13.0, 1.0270275027907854, 1.0, 26.0, 1.5190180584199862, 1.0, 23.0, 1.3888592131566546, 8.86545138888889, 188.0, 1.0, 4.0, 1.0270659587429387, 38.23341435185185], [1099.0, 776.0, 2694.7789502483192, 1.0, 13.0, 1.0270275027907854, 1.0, 26.0, 1.5190180584199862, 1.0, 23.0, 1.3888592131566546, 8.86545138888889, 188.0, 1.0, 4.0, 1.0270659587429387, 38.23341435185185]]
convert feature lists to tensors

tensor([[659123, 911340,    431,   1480],
        [659123, 911340,    431,   1480]])


tensor([[1.0990e+03, 7.7600e+02, 2.6948e+03, 1.0000e+00, 1.3000e+01, 1.0270e+00,
         1.0000e+00, 2.6000e+01, 1.5190e+00, 1.0000e+00, 2.3000e+01, 1.3889e+00,
         8.8655e+00, 1.8800e+02, 1.0000e+00, 4.0000e+00, 1.0271e+00, 3.8233e+01],
        [1.0990e+03, 7.7600e+02, 2.6948e+03, 1.0000e+00, 1.3000e+01, 1.0270e+00,
         1.0000e+00, 2.6000e+01, 1.5190e+00, 1.0000e+00, 2.3000e+01, 1.3889e+00,
   

In [45]:
# prediction
model.eval()  # inference mode: BatchNorm uses running stats, dropout is off
with torch.no_grad():  # no autograd graph is needed for inference
    out = model(cat_feat_tensor, numeric_feat_tensor)
print(out)
pred_prob = F.softmax(out, dim=1)
pred = out.argmax(dim=1)
print(pred_prob)
print(pred)
# NOTE(review): column 2 is treated as the "buy" class per the variable
# name -- confirm against the label encoding used in training.
buy_pred_prob = pred_prob.cpu().numpy()[:, 2]
# Row indices ordered by descending buy probability.
buy_pred_prob.argsort()[::-1]

tensor([[ 3.1425, -0.7323, -2.7550],
        [ 3.1425, -0.7323, -2.7550]], grad_fn=<AddmmBackward>)
tensor([[0.9770, 0.0203, 0.0027],
        [0.9770, 0.0203, 0.0027]], grad_fn=<SoftmaxBackward>)
tensor([0, 0])


array([1, 0])