In [None]:
"""
from CoLLM paper

ML-1M
preserve the interactions from the most recent twenty months, using the first 10 months for training, the middle 5 months for validation, and the last 5 months for testing
Train: 33,891
Valid: 10,401
Test: 7,331
User: 839
Item: 3,256


Amazon-Book dataset
preserve interactions from the year 2017 (including about 4 million interactions)
allocating the first 11 months for training, and the remaining two half months for validation and testing, respectively

filtered out users and items with fewer than 20 interactions to ensure data quality for measuring warm-start performance

Train: 727,468
Valid: 25,747
Test: 25,747
User: 22,967
Item: 34,154
"""

In [None]:
from src.data.datasets import AmazonDataset
from src.data.graphDatasets import AmazonGraphDataset

In [None]:
import os

In [None]:
# Root directory handed to the dataset class; created up front if it is missing.
root="data/AmazonReviews"
os.makedirs(root, exist_ok=True)
# Path of the JSON dataset config and the dataset name passed to the loader.
datasetConfig = "src/data/datasetConfigAmazon.json"
datasetName = "AmazonAllBeautyDataset"

# Instantiate the project-local AmazonDataset with the settings above.
AmazonAllBeautyDataset = AmazonDataset(root, datasetConfig, datasetName)

In [None]:
# Same root / config / dataset name as in the previous cell, repeated here so
# this cell defines everything it passes to the graph dataset constructor.
root="data/AmazonReviews"
os.makedirs(root, exist_ok=True)
datasetConfig = "src/data/datasetConfigAmazon.json"
datasetName = "AmazonAllBeautyDataset"

# Graph variant of the dataset (project-local AmazonGraphDataset).
# NOTE(review): the effect of devCtrl=True is not visible in this notebook —
# check AmazonGraphDataset for its meaning.
AmazonAllBeautyGraphDataset = AmazonGraphDataset(root, datasetConfig, datasetName, devCtrl=True)

In [None]:
# Show the rows of interactionData whose "Split" column equals "train".
AmazonAllBeautyGraphDataset.interactionData[ AmazonAllBeautyGraphDataset.interactionData["Split"] == "train" ]

In [None]:
# Display the dataset's trainingData attribute (the object that is later
# converted with .to_homogeneous() further down in this notebook).
AmazonAllBeautyGraphDataset.trainingData

In [None]:
# Inspect the entry stored under the ('user', 'item') key of the training data
# (presumably the user-item edge store — confirm against AmazonGraphDataset).
AmazonAllBeautyGraphDataset.trainingData['user', 'item']

In [None]:
# Inspect the entry stored under the 'user' key of the training data.
AmazonAllBeautyGraphDataset.trainingData['user']

In [None]:
# Inspect the entry stored under the 'item' key of the training data.
AmazonAllBeautyGraphDataset.trainingData['item']

In [None]:
# Show the rows of interactionData whose "Split" column equals "valid".
AmazonAllBeautyGraphDataset.interactionData[ AmazonAllBeautyGraphDataset.interactionData["Split"] == "valid" ]

In [None]:
# First two columns of the edge_index stored under ('user', 'item') in the
# validation data, as a quick sanity check of the edge endpoints.
AmazonAllBeautyGraphDataset.validationData['user', 'item'].edge_index[:, :2]

In [None]:
# The `y` attribute of the ('user', 'item') entry in the validation data
# (presumably the per-edge targets — confirm against AmazonGraphDataset).
AmazonAllBeautyGraphDataset.validationData['user', 'item'].y

In [None]:
# Show the rows of interactionData whose "Split" column equals "test".
AmazonAllBeautyGraphDataset.interactionData[ AmazonAllBeautyGraphDataset.interactionData["Split"] == "test" ]

In [None]:
# Inspect the entry stored under the ('user', 'item') key of the test data.
AmazonAllBeautyGraphDataset.testData['user', 'item']

In [None]:
# Inspect the entry stored under the 'user' key of the test data.
AmazonAllBeautyGraphDataset.testData['user']

In [None]:
# Length of the first dimension of the user node_id tensor in the test data
# (presumably the number of user nodes — confirm).
AmazonAllBeautyGraphDataset.testData['user'].node_id.shape[0]

In [None]:
# Shape of the `x` attribute stored for 'user' in the test data
# (presumably the user feature matrix — confirm).
AmazonAllBeautyGraphDataset.testData['user'].x.shape

In [None]:
# Inspect the entry stored under the 'item' key of the test data.
AmazonAllBeautyGraphDataset.testData['item']

In [None]:
# NOTE(review): `stop` is not defined anywhere in the visible notebook, so
# evaluating it raises NameError. Presumably a deliberate guard that halts
# "Run All" before the work-in-progress cells below — confirm intent.
stop

In [None]:
# NEXT TODOS:
# then feature creation

In [None]:
# training

In [None]:
# Load the AmazonBook dataset from torch_geometric, rooted at data/AmazonPyG.
from torch_geometric.datasets import AmazonBook
osp = os.path  # short alias for os.path

path = osp.join('data', 'AmazonPyG')
dataset = AmazonBook(path)
data = dataset[0]  # first entry of the dataset
# Node counts of the 'user' and 'book' node types.
num_users, num_books = data['user'].num_nodes, data['book'].num_nodes
# The conversion below is disabled (`device` is only defined in a later cell),
# so `data` stays the type-indexed object returned by the dataset.
# data = data.to_homogeneous().to(device)

In [None]:
# Display the AmazonBook graph object loaded in the previous cell.
data

In [None]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Homogeneous versions of the AllBeauty training and test graphs, moved to
# `device`. Neither name is referenced by the other cells visible here.
trainHomo = AmazonAllBeautyGraphDataset.trainingData.to_homogeneous().to(device)
testHomo = AmazonAllBeautyGraphDataset.testData.to_homogeneous().to(device)

In [None]:
# Display trainingData again (same expression as in an earlier cell).
AmazonAllBeautyGraphDataset.trainingData

In [None]:
# ADAPT THIS HERE

# then test vanilla model
import tqdm
from torch_geometric.data import HeteroData
from torch_geometric.nn import LightGCN

# Everything below (and train()/test()) reads `data.edge_index`, which only
# exists on a homogeneous graph; the AmazonBook object loaded earlier is still
# indexed by node type. Convert it once — the isinstance guard keeps this cell
# safe to re-run — and move it to the same device as the model.
if isinstance(data, HeteroData):
    data = data.to_homogeneous()
data = data.to(device)

# Use all message passing edges as training labels:
batch_size = 16
# Keep only edges whose source id is smaller than their target id, i.e. one
# direction of each edge pair.
mask = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask]
# The loader yields shuffled batches of column indices into
# train_edge_label_index rather than the edges themselves.
train_loader = torch.utils.data.DataLoader(
    range(train_edge_label_index.size(1)),
    shuffle=True,
    batch_size=batch_size,
)

# One 64-dimensional embedding per node (users and books share the id space),
# propagated through two LightGCN layers.
model = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=2,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train():
    """Run one training epoch and return the mean loss per positive edge.

    Works on notebook-level state: ``model``, ``optimizer``, ``data``,
    ``train_loader``, ``train_edge_label_index``, ``num_users``,
    ``num_books`` and ``device``.

    For every batch of positive edges, one negative edge per positive is
    built by keeping the source node and drawing a random node id from
    ``[num_users, num_users + num_books)``.

    Returns:
        float: sum of the batch losses, each weighted by the number of
        positive scores in the batch, divided by the total number of
        positive scores seen.
    """
    total_loss = total_examples = 0

    # The notebook does `import tqdm` (the module), so the progress bar has to
    # be reached as `tqdm.tqdm`; calling the module itself raises TypeError.
    for index in tqdm.tqdm(train_loader):
        # Sample positive and negative labels.
        pos_edge_label_index = train_edge_label_index[:, index]
        neg_edge_label_index = torch.stack([
            pos_edge_label_index[0],
            torch.randint(num_users, num_users + num_books,
                          (index.numel(), ), device=device)
        ], dim=0)
        # Positives first, negatives second, so the model output can be split
        # back into the two halves with chunk(2).
        edge_label_index = torch.cat([
            pos_edge_label_index,
            neg_edge_label_index,
        ], dim=1)

        optimizer.zero_grad()
        pos_rank, neg_rank = model(data.edge_index, edge_label_index).chunk(2)

        loss = model.recommendation_loss(
            pos_rank,
            neg_rank,
            node_id=edge_label_index.unique(),
        )
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its number of positive scores so the
        # returned value is a per-example mean even with a short last batch.
        total_loss += float(loss) * pos_rank.numel()
        total_examples += pos_rank.numel()

    return total_loss / total_examples


@torch.no_grad()
def test(k: int):
    """Compute precision@k and recall@k of the notebook-level ``model``.

    Works on notebook-level state: ``model``, ``data``, ``num_users``,
    ``batch_size`` and ``train_edge_label_index``. Users are scored against
    all books in batches of ``batch_size``; training edges are masked out
    with ``-inf`` and the edges in ``data.edge_label_index`` serve as ground
    truth.

    Args:
        k: Number of top-scored books considered per user.

    Returns:
        tuple[float, float]: ``(precision, recall)``, each summed over all
        users and divided by the number of users that have at least one
        ground-truth edge.
    """
    # `degree` is not imported anywhere in the visible notebook; importing it
    # here keeps the function from failing with NameError.
    from torch_geometric.utils import degree

    emb = model.get_embedding(data.edge_index)
    # Node ids below num_users are users, the rest are books.
    user_emb, book_emb = emb[:num_users], emb[num_users:]

    precision = recall = total_examples = 0
    for start in range(0, num_users, batch_size):
        end = start + batch_size
        logits = user_emb[start:end] @ book_emb.t()

        # Exclude training edges:
        mask = ((train_edge_label_index[0] >= start) &
                (train_edge_label_index[0] < end))
        logits[train_edge_label_index[0, mask] - start,
               train_edge_label_index[1, mask] - num_users] = float('-inf')

        # Computing precision and recall:
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((data.edge_label_index[0] >= start) &
                (data.edge_label_index[0] < end))
        ground_truth[data.edge_label_index[0, mask] - start,
                     data.edge_label_index[1, mask] - num_users] = True
        # Number of ground-truth edges per user in this batch.
        node_count = degree(data.edge_label_index[0, mask] - start,
                            num_nodes=logits.size(0))

        topk_index = logits.topk(k, dim=-1).indices
        # True wherever a top-k book is a ground-truth book for that user.
        isin_mat = ground_truth.gather(1, topk_index)

        precision += float((isin_mat.sum(dim=-1) / k).sum())
        # clamp avoids division by zero for users without ground-truth edges.
        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
        total_examples += int((node_count > 0).sum())

    return precision / total_examples, recall / total_examples


# Train for 100 epochs; after each epoch evaluate with k=20 and print the
# epoch number, training loss, precision and recall.
for epoch in range(1, 101):
    loss = train()
    precision, recall = test(k=20)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@20: '
          f'{precision:.4f}, Recall@20: {recall:.4f}')

In [None]:
# data = data.to(device)
# do we have a neg_sampling_ratio command somewhere?
