## Setup data and task information

In [1]:
%load_ext autoreload
%autoreload 2
import sys
import logging
base_dir = '../'
sys.path.append(base_dir)
import os
from utils import *
import os
import polars as pl
import numpy as np
import pandas as pd
from functools import lru_cache

In [2]:
train_data_dir = '../data/raw_data/'
test_data_dir = '../data/raw_data/'
task = 'task3'
PREDS_PER_SESSION = 100

In [3]:
# Cache loading of data for multiple calls

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(os.path.join(train_data_dir, 'products_train.csv'))

@lru_cache(maxsize=1)
def read_train_data():
    return pd.read_csv(os.path.join(train_data_dir, 'sessions_train.csv'))

@lru_cache(maxsize=3)
def read_test_data(task):
    return pd.read_csv(os.path.join(test_data_dir, f'sessions_test_{task}.csv'))

## Data Description

The Multilingual Shopping Session Dataset is a collection of **anonymized customer sessions** containing products from six different locales, namely English, German, Japanese, French, Italian, and Spanish. It consists of two main components: **user sessions** and **product attributes**. User sessions are a list of products that a user has engaged with in chronological order, while product attributes include various details like product title, price in local currency, brand, color, and description.

---

### Each product as its associated information:


**locale**: the locale code of the product (e.g., DE)

**id**: a unique for the product. Also known as Amazon Standard Item Number (ASIN) (e.g., B07WSY3MG8)

**title**: title of the item (e.g., “Japanese Aesthetic Sakura Flowers Vaporwave Soft Grunge Gift T-Shirt”)

**price**: price of the item in local currency (e.g., 24.99)

**brand**: item brand name (e.g., “Japanese Aesthetic Flowers & Vaporwave Clothing”)

**color**: color of the item (e.g., “Black”)

**size**: size of the item (e.g., “xxl”)

**model**: model of the item (e.g., “iphone 13”)

**material**: material of the item (e.g., “cotton”)

**author**: author of the item (e.g., “J. K. Rowling”)

**desc**: description about a item’s key features and benefits called out via bullet points (e.g., “Solid colors: 100% Cotton; Heather Grey: 90% Cotton, 10% Polyester; All Other Heathers …”)


## EDA 💽

In [4]:
def read_locale_data(locale, task):
    products = read_product_data().query(f'locale == "{locale}"')
    sess_train = read_train_data().query(f'locale == "{locale}"')
    sess_test = read_test_data(task).query(f'locale == "{locale}"')
    return products, sess_train, sess_test

def show_locale_info(locale, task):
    products, sess_train, sess_test = read_locale_data(locale, task)

    train_l = sess_train['prev_items'].apply(lambda sess: len(sess))
    test_l = sess_test['prev_items'].apply(lambda sess: len(sess))

    print(f"Locale: {locale} \n"
          f"Number of products: {products['id'].nunique()} \n"
          f"Number of train sessions: {len(sess_train)} \n"
          f"Train session lengths - "
          f"Mean: {train_l.mean():.2f} | Median {train_l.median():.2f} | "
          f"Min: {train_l.min():.2f} | Max {train_l.max():.2f} \n"
          f"Number of test sessions: {len(sess_test)}"
        )
    if len(sess_test) > 0:
        print(
             f"Test session lengths - "
            f"Mean: {test_l.mean():.2f} | Median {test_l.median():.2f} | "
            f"Min: {test_l.min():.2f} | Max {test_l.max():.2f} \n"
        )
    print("======================================================================== \n")

In [5]:
products = read_product_data()
locale_names = products['locale'].unique()
for locale in locale_names:
    show_locale_info(locale, task)

Locale: DE 
Number of products: 518327 
Number of train sessions: 1111416 
Train session lengths - Mean: 57.89 | Median 40.00 | Min: 27.00 | Max 2060.00 
Number of test sessions: 10000
Test session lengths - Mean: 40.07 | Median 27.00 | Min: 27.00 | Max 502.00 


Locale: JP 
Number of products: 395009 
Number of train sessions: 979119 
Train session lengths - Mean: 59.61 | Median 40.00 | Min: 27.00 | Max 6257.00 
Number of test sessions: 10000
Test session lengths - Mean: 40.32 | Median 27.00 | Min: 27.00 | Max 502.00 


Locale: UK 
Number of products: 500180 
Number of train sessions: 1182181 
Train session lengths - Mean: 54.85 | Median 40.00 | Min: 27.00 | Max 2654.00 
Number of test sessions: 10000
Test session lengths - Mean: 48.44 | Median 40.00 | Min: 27.00 | Max 594.00 


Locale: ES 
Number of products: 42503 
Number of train sessions: 89047 
Train session lengths - Mean: 48.82 | Median 40.00 | Min: 27.00 | Max 792.00 
Number of test sessions: 6422
Test session lengths - Mean: 

In [6]:
products.sample(5)

Unnamed: 0,id,locale,title,price,brand,color,size,model,material,author,desc
1083538,B01EZA4YMK,UK,"Yankee Candle Car Jar Scented Air Freshener, M...",6.05,Yankee Candle,Dark Brown,Car Jar Bonus 3 Pack,5038580069723,Plastic,,Also ideal in small spaces like pet areas or g...
25160,B0B5TFZC95,DE,"Make up Pinsel, Grundierung, flache Oberseite,...",9.99,BEASOFEE,01,,,,,【Basispinsel für Selbstbräunungsprozesse】: Der...
1064110,B00ET02XD4,UK,The Hitchhiker's Guide to the Galaxy,5.99,,,,,,Alan Rickman,
1227861,B00J8NA9EC,UK,KitchenCraft Living Nostalgia Mechanical Kitch...,37.82,KitchenCraft,Vintage Blue,,LNSCALEBLU,Alloy Steel,,BEAUTIFUL AND FUNCTIONAL: They feature a gorge...
1016976,B09P58R7TJ,UK,Integral 250GB SSD NVME M.2 2280 PCIe Gen3x4 R...,23.95,Integral,,250GB,INSSD250GM280NM2,,,3 Year Free UK-based help and support team


In [7]:
train_sessions = read_train_data()
train_sessions.sample(5)

Unnamed: 0,prev_items,next_item,locale
449209,['B00LSTQUHO' 'B01JANOAO4'],B075DFD4QF,DE
3591690,['B09Y8THMVH' 'B09PC2G77G' 'B09BZ3RZ5X' 'B09PC...,B09H5PTNV1,IT
259958,['B0B6FYGP15' 'B0B6FVTZ5Z'],B09FG5QWFL,DE
324675,['B00E4ZNGVW' 'B08XQZ7LZJ' 'B00E67346A' 'B00V8...,B008YW7YK0,DE
3450957,['B099XB56K5' 'B099X6Y7TT'],B099X8NS9C,FR


In [8]:
test_sessions = read_test_data(task)
test_sessions.sample(5)

Unnamed: 0,prev_items,locale
10653,['B01KV7YJT4' 'B01KV7YY5S' 'B01KV7YUHK'],DE
50099,['B016X03RM8' 'B016X03U3O' 'B0142XIDFE'],UK
11775,['B07NMF9CPC' 'B07NMF9CPC'],DE
42972,['B000WLMR3A' 'B000WMDGU2'],JP
35650,['B08434KVN2' 'B08434KVN2' 'B086XPR3W5' 'B086X...,IT


In [9]:
test_sessions.shape

(56422, 2)

## Generate Submission 🏋️‍♀️



Submission format:
1. The submission should be a **parquet** file with the sessions from all the locales.  
2. Predictions should be added in new column named **"next_item_prediction"**.
3. Predictions should be a single string, the next product title for the session.

In [10]:
def random_predicitons(locale, sess_test_locale):
    random_state = np.random.RandomState(42)
    products = read_product_data().query(f'locale == "{locale}"')
    predictions = (products['title']
                   .sample(len(sess_test_locale), replace=True, random_state=random_state)
                   .values
    )
    sess_test_locale['next_item_prediction'] = predictions
    sess_test_locale.drop('prev_items', inplace=True, axis=1)
    return sess_test_locale

In [14]:
test_sessions = read_test_data(task)
predictions = []
test_locale_names = test_sessions['locale'].unique()
for locale in test_locale_names:
    sess_test_locale = test_sessions.query(f'locale == "{locale}"').copy()
    predictions.append(
        random_predicitons(locale, sess_test_locale)
    )
predictions = pd.concat(predictions).reset_index(drop=True)
predictions.sample(5)

Unnamed: 0,locale,next_item_prediction
44509,JP,メンズ ランニングウェア セット コンプレッションウェア メンズ スポーツパーカー トレーニ...
5164,ES,"Diyife Candado Combinacion, Impermeable Candad..."
2142,ES,"Nenuco Agua de Colonia Fragancia Original, 600ml"
2587,ES,NAWA Home & Work Pack de 2 estanterías 100% me...
50457,UK,Gummy Styling Wax 150 ml Ultra Hold


In [21]:
predictions.head(1)['next_item_prediction'].values

array(['Amazon Basics - Set de 2 fundas de almohada de 400 hilos, 50 x 80 cm - Azul marino'],
      dtype=object)

# Read data 

In [12]:
task2 = pl.scan_parquet('../data/sub_files/task2_test4task3_task2_rank_lgbm_v2.parquet')
task1 = pl.scan_parquet('../data/sub_files/task1_test4task3_rank_lgbm_v4.parquet')


test4task3_pl = pl.scan_parquet(os.path.join('../data/raw_data_session_id/', 'sessions_test_task3.parquet'), 
                                
                                # n_rows=n_rows
                               )


In [13]:
test4task3_pl.head().collect()

prev_items,locale,session_id
str,str,i64
"""['B0BF9JMVDG' …","""ES""",4365996
"""['B09QQG85HM' …","""ES""",4365997
"""['B09NSKDG4K' …","""ES""",4365998
"""['B09B7NYDJ7' …","""ES""",4365999
"""['B0B6J17LK4' …","""ES""",4366000


In [14]:
task2.head().collect()

prev_items,locale,session_id,next_item_prediction
list[str],str,i64,list[str]
"[""B0BF9JMVDG"", ""B01ET9V90M""]","""ES""",4365996,"[""B09HSK3MR5""]"
"[""B09QQG85HM"", ""B09J4T4JF5""]","""ES""",4365997,"[""B0B1V6Q61B""]"
"[""B09NSKDG4K"", ""B09YY6J1ZM""]","""ES""",4365998,"[""B09XM6Z7VY""]"
"[""B09B7NYDJ7"", ""B09B7NYDJ7""]","""ES""",4365999,"[""B014EWSGX2""]"
"[""B0B6J17LK4"", ""B0B6R7X6GY"", ""B07HXY5SGH""]","""ES""",4366000,"[""B086MSYD32""]"


In [15]:
task1.head().collect()

prev_items,locale,session_id,next_item_prediction
list[str],str,i64,list[str]
"[""B01D37JZDO"", ""B09798DT5N""]","""DE""",4372418,"[""B07QV6GZ6P""]"
"[""B09TPHS4J1"", ""B09TPHGHR8""]","""DE""",4372419,"[""B09TPHBD98""]"
"[""B09CPRS6QK"", ""B09XMGPTZ2""]","""DE""",4372420,"[""B093ZXQQ9Y""]"
"[""3785586620"", ""3809439908""]","""DE""",4372421,"[""3741525065""]"
"[""B09Q3DCGW3"", ""B09Q3C5Z33""]","""DE""",4372422,"[""B09Q3C3CKB""]"


In [16]:
res = pl.concat([task1, task2], how='vertical')
product_pl = (pl.scan_parquet(os.path.join('../data/raw_data_session_id/', 
                                          'products_train.parquet'))
                  # .with_columns(
                  #     pl.when(pl.col('locale')=='DE').then(1).when(pl.col('locale')=='DE')
                  #       .then(2)
                  #       .otherwise(3).alias('locale')
                  # )
             )

In [17]:
# product_pl.select(['id', 'locale', 'title']).head().collect()

In [18]:
final_res = (
    res.explode('next_item_prediction')
        .join(product_pl.select(['id', 'locale', 'title']), how='left', right_on=['id', 'locale']
                 , left_on=['next_item_prediction', 'locale'])
)#.head().collect()

In [19]:
predictions = (
    test4task3_pl.join(
        final_res.select(['session_id', 'locale', 'title'])
        , how='left'
        , on=['session_id', 'locale']
    )
        .with_columns(pl.col('title').alias('next_item_prediction').cast(pl.Utf8))
)#.head().collect()

In [20]:
test4task3_pl.select('locale').collect().to_series().value_counts()

locale,counts
str,u32
"""FR""",10000
"""UK""",10000
"""ES""",6422
"""DE""",10000
"""IT""",10000
"""JP""",10000


In [21]:
predictions.head().collect()

prev_items,locale,session_id,title,next_item_prediction
str,str,i64,str,str
"""['B0BF9JMVDG' …","""ES""",4365996,"""Pata Negra - E…","""Pata Negra - E…"
"""['B09QQG85HM' …","""ES""",4365997,"""Oppo Enco Air2…","""Oppo Enco Air2…"
"""['B09NSKDG4K' …","""ES""",4365998,"""Yisica Correa …","""Yisica Correa …"
"""['B09B7NYDJ7' …","""ES""",4365999,"""Rowenta Compac…","""Rowenta Compac…"
"""['B0B6J17LK4' …","""ES""",4366000,"""HOMEASY Fiambr…","""HOMEASY Fiambr…"


In [22]:

predictions.select('locale').collect().to_series().value_counts()

locale,counts
str,u32
"""JP""",10000
"""IT""",10000
"""UK""",10000
"""ES""",6422
"""DE""",10000
"""FR""",10000


In [23]:
predictions_df = predictions.collect().to_pandas()

In [24]:
predictions_df['next_item_prediction'] = predictions_df['next_item_prediction'].astype('str')

# Validate predictions ✅

In [25]:
def check_predictions(predictions):
    """
    These tests need to pass as they will also be applied on the evaluator
    """
    test_locale_names = test_sessions['locale'].unique()
    for locale in test_locale_names:
        sess_test = test_sessions.query(f'locale == "{locale}"')
        preds_locale =  predictions[predictions['locale'] == sess_test['locale'].iloc[0]]
        assert sorted(preds_locale.index.values) == sorted(sess_test.index.values), f"Session ids of {locale} doesn't match"
        assert predictions['next_item_prediction'].apply(lambda x: isinstance(x, str)).all(), "Predictions should all be strings"

In [26]:
predictions_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 56422 entries, 0 to 56421
Data columns (total 5 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   prev_items            56422 non-null  object
 1   locale                56422 non-null  object
 2   session_id            56422 non-null  int64 
 3   title                 55306 non-null  object
 4   next_item_prediction  56422 non-null  object
dtypes: int64(1), object(4)
memory usage: 2.2+ MB


In [27]:
check_predictions(predictions_df)

In [28]:
task

'task3'

In [29]:
# Its important that the parquet file you submit is saved with pyarrow backend
predictions_df.to_parquet(f'submission_{task}.parquet', engine='pyarrow')

In [30]:
f'submission_{task}.parquet'

'submission_task3.parquet'

## Submit to AIcrowd 🚀

In [31]:
# You can submit with aicrowd-cli, or upload manually on the challenge page.
!aicrowd submission create -c task-3-next-product-title-generation -f f'submission_{task}.parquet'

[2K[1;34msubmission_task3.parquet[0m [90m━━━━━━━━━━━━[0m [35m100.0%[0m • [32m12.5/12.5 MB[0m • [31m4.5 MB/s[0m • [36m0:00:00[0m00:01[0m00:01[0m
[?25h                                                                                  ╭─────────────────────────╮                                                                                  
                                                                                  │ [1mSuccessfully submitted![0m │                                                                                  
                                                                                  ╰─────────────────────────╯                                                                                  
[3m                                                                                        Important links                                                                                        [0m
┌──────────────────┬───────────────────────────────