# Classification - Customer Repurchase Window Prediction



Predicting customer repurchase behavior and timing using historical transaction data from an online retail business.

**Dataset Source**: [Online Retail II UCI Dataset](https://archive.ics.uci.edu/ml/datasets/Online+Retail+II)
**Problem Type**: Classification
**Target Variable**: Customer repurchase probability within specific time windows
**Use Case**: Customer retention strategies, inventory management, targeted marketing campaigns

## Package Imports

In [None]:
import pandas as pd
import xplainable as xp
from xplainable.core.models import XClassifier
from xplainable.core.optimisation.bayesian import XParamOptimiser
from xplainable.preprocessing.pipeline import XPipeline
from xplainable.preprocessing import transformers as xtf
from sklearn.model_selection import train_test_split
import requests
import json

# Additional imports specific to this example
import numpy as np
import datetime as dt
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# New refactored client import
from xplainable_client.client.client import XplainableClient
from xplainable_client.client.base import XplainableAPIError

In [None]:
!pip install xplainable
!pip install xplainable-client

## Xplainable Cloud Setup

In [None]:
# Initialize Xplainable Cloud client using new refactored client
client = XplainableClient(
    api_key="",  # Add your API key from https://platform.xplainable.io/
    hostname="https://platform.xplainable.io"  # Optional, defaults to production
)

## Data Loading and Exploration

Load the Online Retail II dataset and perform basic data exploration.

In [7]:
import pandas as pd
import requests
from io import BytesIO

def load_online_retail_ii():
    """
    Downloads the Online Retail II dataset directly from the UCI repository
    and returns a single DataFrame combining both sheets.
    """
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00502/online_retail_II.xlsx"
    r = requests.get(url)
    r.raise_for_status()  # fail early if we got a bad status

    # read both year‐sheets and concatenate
    xls = pd.ExcelFile(BytesIO(r.content))
    df1 = pd.read_excel(xls, sheet_name="Year 2009-2010", parse_dates=["InvoiceDate"])
    df2 = pd.read_excel(xls, sheet_name="Year 2010-2011", parse_dates=["InvoiceDate"])
    df = pd.concat([df1, df2], ignore_index=True)

    # cleanup exactly like you had before
    df = df.dropna(subset=["Customer ID"])
    df = df[(df.Price > 0) & (df.Quantity > 0)].copy()
    df["Amount"] = df.Price * df.Quantity
    return df

# usage
df = load_online_retail_ii()
df.head()

Unnamed: 0,Invoice,StockCode,Description,Quantity,InvoiceDate,Price,Customer ID,Country,Amount
0,489434,85048,15CM CHRISTMAS GLASS BALL 20 LIGHTS,12,2009-12-01 07:45:00,6.95,13085.0,United Kingdom,83.4
1,489434,79323P,PINK CHERRY LIGHTS,12,2009-12-01 07:45:00,6.75,13085.0,United Kingdom,81.0
2,489434,79323W,WHITE CHERRY LIGHTS,12,2009-12-01 07:45:00,6.75,13085.0,United Kingdom,81.0
3,489434,22041,"RECORD FRAME 7"" SINGLE SIZE",48,2009-12-01 07:45:00,2.1,13085.0,United Kingdom,100.8
4,489434,21232,STRAWBERRY CERAMIC TRINKET BOX,24,2009-12-01 07:45:00,1.25,13085.0,United Kingdom,30.0


The timeline below illustrates the core problem the model is solving: will a customer place another order within 30 days of a given purchase? Each row represents an individual customer (C1 – C4), and every blue dot marks one of their historical purchases. From each purchase, a magenta line extends 30 days—the evaluation window used to create the training label. When a follow-up order actually arrives inside that window, it is highlighted with a pink star. Purchases followed by a star are the positive cases (“repurchased”), while those without a star are negative. Visually stepping through these tracks makes it clear how the dataset converts raw transactions into a binary outcome that the model can learn to predict.

In [5]:
%%html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Repurchase Prediction Timeline (30 Days)</title>
  <style>
    body { font-family: Arial, sans-serif; }
    .axis path, .axis line { fill: none; stroke: #000; shape-rendering: crispEdges; }
    .tick line { stroke: #ccc; }
    .purchase { fill: #2774AE; }
    .window { stroke: #E44D9A; stroke-width: 4px; stroke-opacity: 0.4; }
    .rebuy { fill: #E44D9A; }
    .legend { font-size: 12px; }
    .button { position: absolute; top: 10px; right: 20px; padding: 6px 12px; background: #2774AE; color: #fff; border: none; border-radius: 4px; cursor: pointer; }
  </style>
</head>
<body>
  <button class="button" id="replayBtn">Replay</button>
  <h2>Repurchase Prediction Timeline (30 Days)</h2>
  <svg width="800" height="300"></svg>
  <script src="https://d3js.org/d3.v7.min.js"></script>
  <script>
    const rawData = [
      {id: 'C1', date: '2021-02-01', rebuy: '2021-02-25'},
      {id: 'C1', date: '2021-04-05', rebuy: null},
      {id: 'C2', date: '2021-03-01', rebuy: null},
      {id: 'C2', date: '2021-04-15', rebuy: null},
      {id: 'C3', date: '2021-06-01', rebuy: '2021-07-10'},
      {id: 'C3', date: '2021-09-12', rebuy: null},
      {id: 'C4', date: '2021-10-01', rebuy: null}
    ];
    const parseDate = d3.timeParse('%Y-%m-%d');
    function prepareData() {
      return rawData.map(d => {
        const date = parseDate(d.date);
        const rebuyDate = d.rebuy ? parseDate(d.rebuy) : null;
        return { id: d.id, date, end: d3.timeDay.offset(date, 30), rebuyDate };
      });
    }
    const svg = d3.select('svg');
    const margin = {top: 20, right: 20, bottom: 30, left: 60};
    const width = +svg.attr('width') - margin.left - margin.right;
    const height = +svg.attr('height') - margin.top - margin.bottom;
    const g = svg.append('g').attr('transform', `translate(${margin.left},${margin.top})`);

    function render(data) {
      g.selectAll('*').remove();
      const customers = [...new Set(data.map(d => d.id))];
      const x = d3.scaleTime()
        .domain(d3.extent(data.flatMap(d => [d.date, d.end])))
        .range([0, width]);
      const y = d3.scalePoint()
        .domain(customers)
        .range([0, height])
        .padding(0.5);

      g.append('g')
        .attr('class', 'axis')
        .attr('transform', `translate(0,${height})`)
        .call(d3.axisBottom(x).ticks(6).tickFormat(d3.timeFormat('%b-%d')));
      g.append('g')
        .attr('class', 'axis')
        .call(d3.axisLeft(y));

      // animate per customer
      customers.forEach((cust, i) => {
        const custData = data.filter(d => d.id === cust);
        custData.forEach((d, j) => {
          const delay = i * 1000 + j * 300;
          // window
          g.append('line')
            .datum(d)
            .attr('class', 'window')
            .attr('x1', x(d.date))
            .attr('x2', x(d.date))
            .attr('y1', y(d.id))
            .attr('y2', y(d.id))
            .transition()
            .delay(delay)
            .duration(600)
            .attr('x2', x(d.end));
          // purchase
          g.append('circle')
            .datum(d)
            .attr('class', 'purchase')
            .attr('cx', x(d.date))
            .attr('cy', y(d.id))
            .attr('r', 0)
            .transition()
            .delay(delay + 200)
            .duration(300)
            .attr('r', 6);
          // rebuy
          if (d.rebuyDate) {
            g.append('path')
              .datum(d)
              .attr('class', 'rebuy')
              .attr('d', d3.symbol().type(d3.symbolStar).size(200))
              .attr('transform', `translate(${x(d.rebuyDate)},${y(d.id)}) scale(0)`)  
              .transition()
              .delay(delay + 400)
              .duration(400)
              .attr('transform', `translate(${x(d.rebuyDate)},${y(d.id)}) scale(1)`);
          }
        });
      });
      // legend
      const legend = svg.selectAll('.legend').data([0]);
      const lg = legend.enter().append('g').attr('class','legend').merge(legend)
        .attr('transform', `translate(${margin.left},10)`);
      lg.selectAll('*').remove();
      lg.append('circle').attr('cx',0).attr('cy',0).attr('r',6).attr('fill','#2774AE');
      lg.append('text').attr('x',12).attr('y',4).text('Purchase');
      lg.append('line').attr('x1',100).attr('x2',120).attr('y1',0).attr('y2',0)
        .attr('stroke','#E44D9A').attr('stroke-width',4).attr('stroke-opacity',0.4);
      lg.append('text').attr('x',130).attr('y',4).text('30-Day Window');
      lg.append('path')
        .attr('d', d3.symbol().type(d3.symbolStar).size(200))
        .attr('transform', 'translate(240,0) scale(1)')
        .attr('fill','#E44D9A');
      lg.append('text').attr('x',250).attr('y',4).text('Rebuy');
    }

    // initial render & button handler
    function replay() {
      const data = prepareData();
      render(data);
    }
    d3.select('#replayBtn').on('click', replay);
    replay();
  </script>
</body>
</html>


In [21]:
# --- 1. LOAD & CLEAN -------------------------------------------------
df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"], dayfirst=True, errors="coerce")

df = df.dropna(subset=["Customer ID", "InvoiceDate"])
df = df[(df["Price"] > 0) & (df["Quantity"] > 0)].copy()

df["Amount"]       = df["Price"] * df["Quantity"]
df["InvoiceMonth"] = df["InvoiceDate"].dt.to_period("M")

## 1. Data Preprocessing

### Data Preview and Initial Exploration

In [9]:
df.head()

Unnamed: 0,Invoice,StockCode,Description,Quantity,InvoiceDate,Price,Customer ID,Country,Amount,InvoiceMonth
0,489434,85048,15CM CHRISTMAS GLASS BALL 20 LIGHTS,12,2009-12-01 07:45:00,6.95,13085.0,United Kingdom,83.4,2009-12
1,489434,79323P,PINK CHERRY LIGHTS,12,2009-12-01 07:45:00,6.75,13085.0,United Kingdom,81.0,2009-12
2,489434,79323W,WHITE CHERRY LIGHTS,12,2009-12-01 07:45:00,6.75,13085.0,United Kingdom,81.0,2009-12
3,489434,22041,"RECORD FRAME 7"" SINGLE SIZE",48,2009-12-01 07:45:00,2.1,13085.0,United Kingdom,100.8,2009-12
4,489434,21232,STRAWBERRY CERAMIC TRINKET BOX,24,2009-12-01 07:45:00,1.25,13085.0,United Kingdom,30.0,2009-12


### RFM Feature Engineering

In [22]:
# Sort by customer and invoice date
df_sorted = df.sort_values(["Customer ID", "InvoiceDate"])

# Track the most recent purchase for each row
df_sorted["LastPurchase"] = (
    df_sorted.groupby("Customer ID")["InvoiceDate"].shift()
)

# Add InvoiceMonth and MonthEnd again (safe even if already set)
df_sorted["InvoiceMonth"] = df_sorted["InvoiceDate"].dt.to_period("M")
df_sorted["MonthEnd"] = df_sorted["InvoiceMonth"].dt.to_timestamp("M")

# Keep only the last purchase as of each month
last_purchase = (
    df_sorted.dropna(subset=["LastPurchase"])
             .groupby(["Customer ID", "InvoiceMonth"])["LastPurchase"]
             .max()
             .reset_index()
)

# Create the monthly feature matrix (grp)
grp = (
    df.groupby(["Customer ID", "InvoiceMonth"])
      .agg({
          "Invoice": "nunique",        # Frequency
          "Quantity": "sum",           # DistinctItems or total quantity
          "Amount": "sum",             # Monetary
          "Country": "first",          # Keep Country
      })
      .rename(columns={
          "Invoice": "Frequency",
          "Quantity": "DistinctItems",
          "Amount": "Monetary"
      })
      .reset_index()
)

# Ensure InvoiceMonth is period type
grp["InvoiceMonth"] = grp["InvoiceMonth"].astype("period[M]")

# Merge last purchase dates and calculate Recency
grp = grp.merge(last_purchase, on=["Customer ID", "InvoiceMonth"], how="left")
grp["MonthEnd"] = grp["InvoiceMonth"].dt.to_timestamp("M")
grp["Recency"] = (grp["MonthEnd"] - grp["LastPurchase"]).dt.days
grp.drop(columns=["LastPurchase"], inplace=True)

# Add Month and Quarter for time-based grouping or encoding
grp["Month"] = grp["InvoiceMonth"].dt.month
grp["Quarter"] = grp["InvoiceMonth"].dt.quarter

### Build 30-day Repurchase Label

In [23]:
from pandas.tseries.offsets import Day

# Set the window size
DAYS = 30  # Change to 30 or 90 if needed

# Step 1: Unique (Customer ID, InvoiceDate) combinations
invoice_dates = df[["Customer ID", "InvoiceDate"]].drop_duplicates().copy()

# Step 2: Function to check if there's a purchase within N days
def has_purchase_within_n_days(row):
    cid, date = row["Customer ID"], row["InvoiceDate"]
    future_txns = invoice_dates[
        (invoice_dates["Customer ID"] == cid) &
        (invoice_dates["InvoiceDate"] > date) &
        (invoice_dates["InvoiceDate"] <= date + Day(DAYS))
    ]
    return 1 if len(future_txns) > 0 else 0

# Step 3: Apply the function row-wise (can take 30s+ on large data)
invoice_dates[f"rebuy_{DAYS}d"] = invoice_dates.apply(has_purchase_within_n_days, axis=1)

# Step 4: Convert to monthly and aggregate to get the label
invoice_dates["InvoiceMonth"] = invoice_dates["InvoiceDate"].dt.to_period("M")

label = (
    invoice_dates.groupby(["Customer ID", "InvoiceMonth"])
                 [f"rebuy_{DAYS}d"].max()
                 .reset_index()
                 .rename(columns={f"rebuy_{DAYS}d": f"will_rebuy_{DAYS}d"})
)

# Step 5: Merge with feature matrix
data = grp.merge(label, on=["Customer ID", "InvoiceMonth"], how="left")
data[f"will_rebuy_{DAYS}d"].fillna(0, inplace=True)
data[f"will_rebuy_{DAYS}d"] = data[f"will_rebuy_{DAYS}d"].astype(int)


In [24]:
data[f"will_rebuy_{DAYS}d"].value_counts()

0    16500
1     9095
Name: will_rebuy_30d, dtype: int64

### Train/Test Time-based Split

In [25]:
# --- 4. TIME-BASED SPLIT & MODEL (DYNAMIC DAYS, NO ONE-HOT) --------

data["Date"] = data["InvoiceMonth"].dt.to_timestamp()

train = data[data["Date"] < "2011-07-01"]
test  = data[data["Date"] >= "2011-07-01"]

label_col = f"will_rebuy_{DAYS}d"

X_train = train.drop(columns=[label_col, "InvoiceMonth", "Date", "MonthEnd", "Customer ID"])
y_train = train[label_col]
X_test  = test.drop(columns=[label_col, "InvoiceMonth", "Date", "MonthEnd", "Customer ID"])
y_test  = test[label_col]

In [26]:
X_train

Unnamed: 0,Frequency,DistinctItems,Monetary,Country,Recency,Month,Quarter
0,5,26,113.50,United Kingdom,12.0,12,4
1,4,20,90.00,United Kingdom,16.0,1,1
2,1,5,27.05,United Kingdom,28.0,3,1
3,1,19,142.31,United Kingdom,1.0,6,2
4,1,74215,77183.60,United Kingdom,216.0,1,1
...,...,...,...,...,...,...,...
25589,1,494,833.48,United Kingdom,10.0,8,3
25590,1,732,1071.61,United Kingdom,13.0,5,2
25591,2,508,892.60,United Kingdom,8.0,9,3
25592,1,187,381.50,United Kingdom,7.0,11,4


## 2. Model Optimization

In [27]:
opt = XParamOptimiser()
params = opt.optimise(X_train, y_train)

100%|██████████| 30/30 [00:08<00:00,  3.60trial/s, best loss: -0.8776764727397712]


## 3. Model Training

In [28]:
model = XClassifier(**params)
model.fit(X_train, y_train)

<xplainable.core.ml.classification.XClassifier at 0x2a566e140>

## 4. Model Interpretability and Explainability

In [29]:
model.explain()

## 7. Model Testing

### Hold-out Evaluation

In [30]:
model.evaluate(X_test, y_test)

{'confusion_matrix': [[4333, 25], [823, 1612]],
 'classification_report': {'0': {'precision': 0.8403801396431342,
   'recall': 0.9942634235888022,
   'f1-score': 0.9108681942400673,
   'support': 4358.0},
  '1': {'precision': 0.984728161270617,
   'recall': 0.6620123203285421,
   'f1-score': 0.7917485265225933,
   'support': 2435.0},
  'accuracy': 0.8751656116590608,
  'macro avg': {'precision': 0.9125541504568756,
   'recall': 0.8281378719586721,
   'f1-score': 0.8513083603813303,
   'support': 6793.0},
  'weighted avg': {'precision': 0.892122732409647,
   'recall': 0.8751656116590608,
   'f1-score': 0.868168887469561,
   'support': 6793.0}},
 'roc_auc': 0.8699849600395034,
 'neg_brier_loss': 0.8800487167761432,
 'log_loss': 0.4091278987464692,
 'cohen_kappa': 0.7074258976095472}

## 5. Model Persistence

In [None]:
# Create model using the new client's models service
try:
    model_id, version_id = client.models.create_model(
        model=model,
        model_name="Customer Repurchase - 30 Day Forecast",
        model_description="Predicts whether a customer will make another purchase within 30 days based on their recent order behaviour and RFM features.",
        x=X_train,
        y=y_train
    )
    print(f"Model created successfully!")
    print(f"Model ID: {model_id}")
    print(f"Version ID: {version_id}")
except XplainableAPIError as e:
    print(f"Error creating model: {e.message}")
    model_id, version_id = None, None

## 6. Model Deployment

In [None]:
# Deploy model using the new client's deployments service
try:
    deployment_response = client.deployments.deploy(model_version_id=version_id)
    deployment_id = deployment_response.deployment_id
    print(f"Model deployed successfully!")
    print(f"Deployment ID: {deployment_id}")
except XplainableAPIError as e:
    print(f"Error deploying model: {e.message}")
    deployment_id = None

In [None]:
# Activate deployment using the new client
try:
    client.deployments.activate_deployment(deployment_id)
    print("Deployment activated successfully!")
except XplainableAPIError as e:
    print(f"Error activating deployment: {e.message}")

In [None]:
# Generate deployment key using the new client
try:
    deploy_key = client.deployments.generate_deploy_key(
        deployment_id=deployment_id,
        description='Deployment API for Purchase Prediction',
        days_until_expiry=7
    )
    print(f"Deployment key generated: {str(deploy_key)[:20]}...")
except XplainableAPIError as e:
    print(f"Error generating deploy key: {e.message}")
    deploy_key = None

### Generate Example Payload

In [None]:
#Set the option to highlight multiple ways of creating data
option = 2

In [None]:
if option == 1:
    # Generate example payload using the new client
    try:
        body = client.deployments.generate_example_deployment_payload(deployment_id)
    except:
        body = json.loads(train.drop(columns=[label_col, "InvoiceMonth", "Date", "MonthEnd", "Customer ID"]).sample(1).to_json(orient="records"))
else:
    body = json.loads(train.drop(columns=[label_col, "InvoiceMonth", "Date", "MonthEnd", "Customer ID"]).sample(1).to_json(orient="records"))

In [None]:
body

### Call Inference Endpoint

In [None]:
# Make prediction request
if deploy_key:
    response = requests.post(
        url="https://inference.xplainable.io/v1/predict",
        headers={'api_key': str(deploy_key)},  # Convert deploy_key to string
        json=body
    )

    value = response.json()
    print("Prediction response:")
    print(value)
else:
    print("Deploy key not available, skipping prediction test")