In [25]:
import pandas as pd
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import random

In [2]:
# Step 1: Load transactions from CSV
with open('groceries.csv', 'r') as f:
    transactions = [line.strip().split(',') for line in f.readlines()]

# Step 2: Index items and baskets
item_set = set(item.strip() for txn in transactions for item in txn)
item2idx = {item: i for i, item in enumerate(sorted(item_set))}
basket2idx = {i: i + len(item2idx) for i in range(len(transactions))}

In [5]:
item_set

{'Instant food products',
 'UHT-milk',
 'abrasive cleaner',
 'artif. sweetener',
 'baby cosmetics',
 'baby food',
 'bags',
 'baking powder',
 'bathroom cleaner',
 'beef',
 'berries',
 'beverages',
 'bottled beer',
 'bottled water',
 'brandy',
 'brown bread',
 'butter',
 'butter milk',
 'cake bar',
 'candles',
 'candy',
 'canned beer',
 'canned fish',
 'canned fruit',
 'canned vegetables',
 'cat food',
 'cereals',
 'chewing gum',
 'chicken',
 'chocolate',
 'chocolate marshmallow',
 'citrus fruit',
 'cleaner',
 'cling film/bags',
 'cocoa drinks',
 'coffee',
 'condensed milk',
 'cooking chocolate',
 'cookware',
 'cream',
 'cream cheese',
 'curd',
 'curd cheese',
 'decalcifier',
 'dental care',
 'dessert',
 'detergent',
 'dish cleaner',
 'dishes',
 'dog food',
 'domestic eggs',
 'female sanitary products',
 'finished products',
 'fish',
 'flour',
 'flower (seeds)',
 'flower soil/fertilizer',
 'frankfurter',
 'frozen chicken',
 'frozen dessert',
 'frozen fish',
 'frozen fruits',
 'frozen me

In [6]:
item2idx

{'Instant food products': 0,
 'UHT-milk': 1,
 'abrasive cleaner': 2,
 'artif. sweetener': 3,
 'baby cosmetics': 4,
 'baby food': 5,
 'bags': 6,
 'baking powder': 7,
 'bathroom cleaner': 8,
 'beef': 9,
 'berries': 10,
 'beverages': 11,
 'bottled beer': 12,
 'bottled water': 13,
 'brandy': 14,
 'brown bread': 15,
 'butter': 16,
 'butter milk': 17,
 'cake bar': 18,
 'candles': 19,
 'candy': 20,
 'canned beer': 21,
 'canned fish': 22,
 'canned fruit': 23,
 'canned vegetables': 24,
 'cat food': 25,
 'cereals': 26,
 'chewing gum': 27,
 'chicken': 28,
 'chocolate': 29,
 'chocolate marshmallow': 30,
 'citrus fruit': 31,
 'cleaner': 32,
 'cling film/bags': 33,
 'cocoa drinks': 34,
 'coffee': 35,
 'condensed milk': 36,
 'cooking chocolate': 37,
 'cookware': 38,
 'cream': 39,
 'cream cheese': 40,
 'curd': 41,
 'curd cheese': 42,
 'decalcifier': 43,
 'dental care': 44,
 'dessert': 45,
 'detergent': 46,
 'dish cleaner': 47,
 'dishes': 48,
 'dog food': 49,
 'domestic eggs': 50,
 'female sanitary pro

In [7]:
basket2idx

{0: 169,
 1: 170,
 2: 171,
 3: 172,
 4: 173,
 5: 174,
 6: 175,
 7: 176,
 8: 177,
 9: 178,
 10: 179,
 11: 180,
 12: 181,
 13: 182,
 14: 183,
 15: 184,
 16: 185,
 17: 186,
 18: 187,
 19: 188,
 20: 189,
 21: 190,
 22: 191,
 23: 192,
 24: 193,
 25: 194,
 26: 195,
 27: 196,
 28: 197,
 29: 198,
 30: 199,
 31: 200,
 32: 201,
 33: 202,
 34: 203,
 35: 204,
 36: 205,
 37: 206,
 38: 207,
 39: 208,
 40: 209,
 41: 210,
 42: 211,
 43: 212,
 44: 213,
 45: 214,
 46: 215,
 47: 216,
 48: 217,
 49: 218,
 50: 219,
 51: 220,
 52: 221,
 53: 222,
 54: 223,
 55: 224,
 56: 225,
 57: 226,
 58: 227,
 59: 228,
 60: 229,
 61: 230,
 62: 231,
 63: 232,
 64: 233,
 65: 234,
 66: 235,
 67: 236,
 68: 237,
 69: 238,
 70: 239,
 71: 240,
 72: 241,
 73: 242,
 74: 243,
 75: 244,
 76: 245,
 77: 246,
 78: 247,
 79: 248,
 80: 249,
 81: 250,
 82: 251,
 83: 252,
 84: 253,
 85: 254,
 86: 255,
 87: 256,
 88: 257,
 89: 258,
 90: 259,
 91: 260,
 92: 261,
 93: 262,
 94: 263,
 95: 264,
 96: 265,
 97: 266,
 98: 267,
 99: 268,
 100: 269,

In [8]:
# Step 3: Build edges (basket ↔ item)
edges = []
for basket_id, items in enumerate(transactions):
    b_idx = basket2idx[basket_id]
    for item in items:
        i_idx = item2idx[item.strip()]
        edges.append((b_idx, i_idx))
        edges.append((i_idx, b_idx))


In [9]:
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()


In [10]:
# Step 4: Create identity features for each node (or random/learnable later)
num_nodes = len(item2idx) + len(basket2idx)
x = torch.eye(num_nodes)

In [11]:
# Step 5: Build PyG graph data object
data = Data(x=x, edge_index=edge_index)


In [12]:
# Step 6: Save item/basket mappings
idx2item = {v: k for k, v in item2idx.items()}
idx2basket = {v: f'basket_{k}' for k, v in basket2idx.items()}
idx2label = {**idx2item, **idx2basket}

In [13]:
data

Data(x=[10004, 10004], edge_index=[2, 86734])

In [15]:
class BasketGCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BasketGCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


In [38]:
model = BasketGCN(input_dim=num_nodes, hidden_dim=16, output_dim=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)

    edge_index = data.edge_index
    num_edges = edge_index.size(1)

    # Positive edges (sample subset)
    pos_indices = random.sample(range(num_edges), min(300, num_edges))
    pos_edge = edge_index[:, pos_indices]
    pos_src = out[pos_edge[0]]
    pos_dst = out[pos_edge[1]]
    #score_pos = F.cosine_similarity(pos_src, pos_dst)
    
    # Negative edges (random node pairs)
    neg_edge = torch.randint(0, num_nodes, (2, len(pos_indices)))
    neg_src = out[neg_edge[0]]
    neg_dst = out[neg_edge[1]]
    #score_neg = F.cosine_similarity(neg_src, neg_dst)

    score_pos = (F.cosine_similarity(pos_src, pos_dst) + 1) / 2  # scale to [0, 1]
    score_neg = (F.cosine_similarity(neg_src, neg_dst) + 1) / 2

    # Binary cross-entropy loss (no logits)
    loss_pos = F.binary_cross_entropy(score_pos, torch.ones_like(score_pos))
    loss_neg = F.binary_cross_entropy(score_neg, torch.zeros_like(score_neg))
    loss = loss_pos + loss_neg

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Epoch 0, Loss: 1.6616
Epoch 10, Loss: 2.8186
Epoch 20, Loss: 1.8038
Epoch 30, Loss: 1.3594
Epoch 40, Loss: 1.1900
Epoch 50, Loss: 1.1711
Epoch 60, Loss: 3.2694
Epoch 70, Loss: 3.4537
Epoch 80, Loss: 3.2850
Epoch 90, Loss: 3.1567
Epoch 100, Loss: 2.7952
Epoch 110, Loss: 2.4344
Epoch 120, Loss: 1.9376
Epoch 130, Loss: 1.7028
Epoch 140, Loss: 1.6767
Epoch 150, Loss: 1.5930
Epoch 160, Loss: 1.6077
Epoch 170, Loss: 1.5744
Epoch 180, Loss: 1.4442
Epoch 190, Loss: 1.4045


In [39]:
model.eval()
embeddings = model(data).detach()
item_embeddings = embeddings[:len(item2idx)]


In [40]:
print("Item Embeddings:")
for item, idx in item2idx.items():
    print(f"{item}: {item_embeddings[idx].numpy()}")

Item Embeddings:
Instant food products: [ 0.28691065  0.10472018  0.18477124  0.13224946 -0.20400688 -0.14767487
 -0.1453765   0.01538046]
UHT-milk: [ 0.5793189   0.11398551 -0.1541357   0.14099294 -0.18475586 -0.6190041
 -0.33707806  0.5365653 ]
abrasive cleaner: [ 0.26948285 -0.05958717  0.25391948  0.19738285  0.0120545   0.1305749
 -0.06025033  0.02362671]
artif. sweetener: [ 0.15783814 -0.0044165   0.38101238  0.20899852  0.03467907  0.19012894
 -0.10490061 -0.00060692]
baby cosmetics: [ 0.255792   -0.12740023  0.07701311  0.11928042  0.06476794 -0.01256869
 -0.05181247  0.10493633]
baby food: [ 0.25802308 -0.19688001  0.02580818  0.08136772  0.10204689  0.03065911
  0.06608375  0.08585261]
bags: [ 0.09851272 -0.295339   -0.165068    0.3200921   0.28454578  0.09863161
  0.1321632  -0.1382392 ]
baking powder: [ 0.14443867  0.04674324  0.39336854  0.35996297 -0.09546812 -0.20937124
 -0.1424881  -0.09753466]
bathroom cleaner: [-0.01121519 -0.12268078  0.0069975   0.08853971  0.140868

In [32]:
target_item = "candy"
idx_of_target_item = item2idx[target_item]


In [41]:
from sklearn.metrics.pairwise import cosine_similarity

# Convert to numpy
embedding_matrix = item_embeddings.numpy()
similarities = cosine_similarity([embedding_matrix[idx_of_target_item]], embedding_matrix)[0]


### The score is the cosine similarity between "candy" and the other item's GNN embedding.


In [42]:
import numpy as np

top_k = 10  # number of similar items to retrieve
similar_indices = np.argsort(similarities)[-top_k-1:-1][::-1]  # exclude the item itself

for idx in similar_indices:
    similar_item = list(item2idx.keys())[list(item2idx.values()).index(idx)]
    score = similarities[idx]
    print(f"Similar to {target_item}: {similar_item} (score = {score:.4f})")


Similar to candy: abrasive cleaner (score = 0.9601)
Similar to candy: cling film/bags (score = 0.9593)
Similar to candy: house keeping products (score = 0.9495)
Similar to candy: pastry (score = 0.9397)
Similar to candy: frozen potato products (score = 0.9323)
Similar to candy: dog food (score = 0.9254)
Similar to candy: hygiene articles (score = 0.9179)
Similar to candy: salad dressing (score = 0.9118)
Similar to candy: bottled beer (score = 0.9016)
Similar to candy: misc. beverages (score = 0.9004)
