In [2]:
#main_train.py
import wandb
import pandas as pd
from loguru import logger
import torch
import torch.nn.functional as F
from datetime import datetime as dt
import os
from dateutil.relativedelta import relativedelta  # type: ignore
import functools

from process_data import *
from evaluate import *
from constants import *
from model import *

wandb.login()

start_date = dt.strptime("2021-10-24", "%Y-%m-%d").date()
end_date, nxt_start_date = split_date_by_period_months(start_date, TOTAL_MONTHS_PER_ITERATION)
print(start_date, end_date)
directory = "/Users/yhchan/Downloads/FYP/data/processed"
# reviews = pd.read_parquet(f"{directory}/reviews_with_interactions.parquet")
reviews = pd.read_parquet(f"eval_result/i.parquet")
listings = pd.read_parquet(f"{directory}/listings_with_interactions.parquet")

config = {
        "architecture": "Rating-Weighted GraphSAGE",
        "start_date": start_date,
        "end_date": end_date,
        "learning_rate": 0.01,
        "hidden_channels": 64,
        "train_batch_size": 128,
        "test_batch_size": 128,
        "epochs": 300,
        "train_num_neighbours": [10, 10],
        "test_num_neighbours": [-1],
        "train_split_period_months": 10,
        "total_months_of_data": TOTAL_MONTHS_PER_ITERATION,
    }

wandb.init(
    project=PROJECT_NAME,
    config=config,
)
wandb.define_metric("train_loss", step_metric="epoch", summary="min")
wandb.define_metric("test_loss", step_metric="epoch", summary="min")

(
    train_reviews,
    train_listings,
    train_reviewers,
    test_reviews,
    test_listings,
    test_reviewers,
) = main_train_test(
    reviews,
    listings,
    start_date,
    end_date,
    config["train_split_period_months"],
)

# Build Graph
involved_reviews = pd.concat([train_reviews, test_reviews])
involved_listings, involved_reviewers = build_partitioned_data(involved_reviews, listings)
involved_data = build_heterograph(involved_reviews, involved_listings, involved_reviewers, True)
train_data = build_heterograph(train_reviews, train_listings, train_reviewers, True)
test_data = build_heterograph(test_reviews, test_listings, test_reviewers, True)

print("Whole Graph", involved_data)
print("Training Heterogenous Graph", train_data)
print("Test Heterogenous Graph", test_data)

involved_listings2dict = get_entity2dict(involved_listings, "listing_id")
reverse_involved_listings2dict = {k: v for v, k in involved_listings2dict.items()}

metadata_dict = {
    "num_reviews": len(involved_reviews),
    "num_train_reviews": len(train_reviews),
    "num_test_reviews": len(test_reviews),
    
    "num_unique_listings": len(involved_listings),
    "num_unique_train_listings": len(train_listings),
    "num_unique_test_listings": len(test_listings),
    
    "num_unique_reviewers": len(involved_reviewers),
    "num_unique_train_reviewers": len(train_reviewers),
    "num_unique_test_reviewers": len(test_reviewers),

}

wandb.log(metadata_dict)
train_reviews.to_parquet("train/train_reviews.parquet", index=False)
train_listings.to_parquet("train/train_listings.parquet", index=False)
train_reviewers.to_parquet("train/train_reviewers.parquet", index=False)
test_reviews.to_parquet("test/test_reviews.parquet", index=False)
test_listings.to_parquet("test/test_listings.parquet", index=False)
test_reviewers.to_parquet("test/test_reviewers.parquet", index=False)

dataset_art = wandb.Artifact(f"{start_date}_{end_date}_data", type="dataset")
for dir in ["train", "test"]:
    dataset_art.add_dir(dir)
wandb.log_artifact(dataset_art)

2021-10-24 2022-10-23


2023-05-06 17:11:41.075 | INFO     | process_data:main_train_test:140 - Split df into train and test portion
  temp = torch.from_numpy(val).view(-1, 1).to(torch.float32)


Whole Graph HeteroData(
  [1mlisting[0m={ x=[18523, 158] },
  [1muser[0m={ x=[394551, 384] },
  [1m(user, rates, listing)[0m={
    edge_index=[2, 408596],
    edge_label=[408596],
    edge_label_index=[2, 408596]
  },
  [1m(listing, rev_rates, user)[0m={ edge_index=[2, 408596] }
)
Training Heterogenous Graph HeteroData(
  [1mlisting[0m={ x=[17229, 158] },
  [1muser[0m={ x=[324135, 384] },
  [1m(user, rates, listing)[0m={
    edge_index=[2, 334678],
    edge_label=[334678],
    edge_label_index=[2, 334678]
  },
  [1m(listing, rev_rates, user)[0m={ edge_index=[2, 334678] }
)
Test Heterogenous Graph HeteroData(
  [1mlisting[0m={ x=[14380, 158] },
  [1muser[0m={ x=[72447, 384] },
  [1m(user, rates, listing)[0m={
    edge_index=[2, 73918],
    edge_label=[73918],
    edge_label_index=[2, 73918]
  },
  [1m(listing, rev_rates, user)[0m={ edge_index=[2, 73918] }
)


[34m[1mwandb[0m: Adding directory to artifact (./train)... Done. 15.7s
[34m[1mwandb[0m: Adding directory to artifact (./test)... Done. 3.0s


<wandb.sdk.wandb_artifacts.Artifact at 0x2d7c6dd00>

Unnamed: 0,listing_id,id,rating,comments,localized_comments,response,localized_response,language,created_at,localized_date,...,comment_embedding_374,comment_embedding_375,comment_embedding_376,comment_embedding_377,comment_embedding_378,comment_embedding_379,comment_embedding_380,comment_embedding_381,comment_embedding_382,comment_embedding_383
327020,565363545634411496,700148910573345619,5,Cute interior and it’s clear they are working ...,,,,en,2022-08-24 00:30:54+00:00,August 2022,...,0.038409,-0.066551,-0.014902,0.010823,0.041533,0.011691,-0.009620,0.029076,-0.088270,0.102734
79084,1160513,700174478220621084,5,We had a wonderful four nights at the stables....,,,,en,2022-08-24 01:21:42+00:00,August 2022,...,0.032808,0.037328,-0.007237,0.013807,0.049898,-0.020697,-0.041794,0.070709,-0.065858,-0.021460
113978,18257779,700197391577040975,5,A beautiful spot that's very well appointed a...,,,,en,2022-08-24 02:07:13+00:00,August 2022,...,0.094430,-0.070190,-0.000054,-0.010329,0.025568,0.069839,-0.032196,0.121626,-0.076257,0.003784
169334,53943018,700212335835357732,5,We could not have been more pleased with our h...,,,,en,2022-08-24 02:36:55+00:00,August 2022,...,0.032651,0.074519,-0.092788,-0.018733,0.008588,0.100716,-0.013107,0.004675,-0.086565,0.011032
168345,45354179,700214006782856333,5,This place is incredible! Beautiful location a...,,We’re glad you had a great stay! We look forwa...,,en,2022-08-24 02:40:14+00:00,August 2022,...,0.044387,0.031949,-0.034708,-0.060949,0.036342,0.062206,-0.082926,-0.028832,-0.103534,0.014343
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
384193,14064796,743754797882360083,5,房子很乾淨，空間也很舒服，準備的東西也很齊全，唯一美中不足的是，離地鐵站有一段距離，需要走1...,"The house is very clean, the space is also ver...",,,zh-TW,2022-10-23 04:28:01+00:00,October 2022,...,0.063871,-0.013477,-0.029722,0.015655,0.053216,-0.017010,-0.039557,0.056494,-0.085848,-0.010447
294531,46286939,743757361545674933,5,The pictures don’t do this place justice. Perf...,,,,en,2022-10-23 04:33:06+00:00,October 2022,...,0.019919,-0.059019,0.029331,0.022194,0.075792,0.094105,-0.005719,-0.037350,-0.136996,0.003359
52977,51197659,743805310969690631,5,1) 프라이빗하고 귀여운 숙소 찾는다면 필수<br/>2) 한적한 산 향을 코 끝까지...,1) If you are looking for a private and cute p...,,,ko,2022-10-23 06:08:22+00:00,October 2022,...,0.076334,0.019734,-0.027781,-0.032406,0.052717,0.057844,0.030669,0.063396,-0.058295,-0.025689
294815,13660583,743878492821039786,5,Lovely and quiet place away from the main stri...,,,,en,2022-10-23 08:33:46+00:00,October 2022,...,0.033631,0.029964,-0.045013,0.038327,0.025526,0.100794,0.028082,0.017888,-0.078899,0.041347


In [None]:
# Modelling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = train_data.to(device)
train_loader = prepare_data_loader(
    data=train_data,
    batch_size=config["train_batch_size"],
    num_neighbours=config["train_num_neighbours"],
)
test_loader = prepare_data_loader(
    data=test_data,
    batch_size=config["test_batch_size"],
    num_neighbours=config["test_num_neighbours"],
)
model = Model(hidden_channels=config["hidden_channels"], data=involved_data).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])


# Train and Evaluate Loss
best_train_loss = float("inf")
best_test_loss = float("inf")
model_prefix = "./rating_weighted_models"
for epoch in range(1, config["epochs"] + 1):
    model_is_best = False
    train_loss = train(model, optimizer, train_loader, device)
    test_loss = test(test_loader, device, model)

    if train_loss < best_train_loss:
        best_train_loss = train_loss

    if test_loss < best_test_loss:
        best_test_loss = test_loss
        model_is_best = True
        
    metrics_dict = {
        "train_loss": train_loss,
        "test_loss": test_loss,
        "epoch": epoch,
    }
    wandb.log(metrics_dict)
    logger.info(
        f"Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f} "
    )
    
    model_path = f"{model_prefix}/{epoch}_model_state_dict.pt"
    torch.save(model.state_dict(), model_path)
    model_art = wandb.Artifact(f"{MODEL_NAME}_epoch_epoch", type="model")
    model_art.add_file(model_path)
    wandb.log_artifact(
        model_art,
        aliases=[
            "BEST",
        ]
        if model_is_best
        else None,
    )
        
logger.info("End of Training")

## Check ground truth data

In [9]:
test_reviews['reviewer_id'].value_counts().describe()

count    72447.000000
mean         1.020304
std          0.154582
min          1.000000
25%          1.000000
50%          1.000000
75%          1.000000
max          6.000000
Name: reviewer_id, dtype: float64

## Model Summary

In [17]:
from collections import defaultdict
from typing import Any, List, Optional, Union

import torch
from torch.jit import ScriptModule
from torch.nn import Module

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import SparseTensor


def summary(
    model: torch.nn.Module,
    *args,
    max_depth: int = 3,
    leaf_module: Optional[Union[Module, List[Module]]] = 'MessagePassing',
    **kwargs,
) -> str:
    r"""Summarizes a given :class:`torch.nn.Module`.
    The summarized information includes (1) layer names, (2) input and output
    shapes, and (3) the number of parameters.

    .. code-block:: python

        import torch
        from torch_geometric.nn import GCN, summary

        model = GCN(128, 64, num_layers=2, out_channels=32)
        x = torch.randn(100, 128)
        edge_index = torch.randint(100, size=(2, 20))

        print(summary(model, x, edge_index))

    .. code-block::

        +---------------------+---------------------+--------------+--------+
        | Layer               | Input Shape         | Output Shape | #Param |
        |---------------------+---------------------+--------------+--------|
        | GCN                 | [100, 128], [2, 20] | [100, 32]    | 10,336 |
        | ├─(act)ReLU         | [100, 64]           | [100, 64]    | --     |
        | ├─(convs)ModuleList | --                  | --           | 10,336 |
        | │    └─(0)GCNConv   | [100, 128], [2, 20] | [100, 64]    | 8,256  |
        | │    └─(1)GCNConv   | [100, 64], [2, 20]  | [100, 32]    | 2,080  |
        +---------------------+---------------------+--------------+--------+

    Args:
        model (torch.nn.Module): The model to summarize.
        *args: The arguments of the :obj:`model`.
        max_depth (int, optional): The depth of nested layers to display.
            Any layers deeper than this depth will not be displayed in the
            summary. (default: :obj:`3`)
        leaf_module (torch.nn.Module or [torch.nn.Module], optional): The
            modules to be treated as leaf modules, whose submodules are
            excluded from the summary.
            (default: :class:`~torch_geometric.nn.conv.MessagePassing`)
        **kwargs: Additional arguments of the :obj:`model`.
    """
    # NOTE This is just for the doc-string to render nicely:
    if leaf_module == 'MessagePassing':
        leaf_module = MessagePassing

    def register_hook(info):
        def hook(module, inputs, output):
            info['input_shape'].append(get_shape(inputs))
            info['output_shape'].append(get_shape(output))

        return hook

    hooks = {}
    depth = 0
    stack = [(model.__class__.__name__, model, depth)]

    info_list = []
    input_shape = defaultdict(list)
    output_shape = defaultdict(list)
    while stack:
        name, module, depth = stack.pop()
        module_id = id(module)

        if name.startswith('(_'):  # Do not summarize private modules.
            continue

        if module_id in hooks:  # Avoid duplicated hooks.
            hooks[module_id].remove()

        info = {}
        info['name'] = name
        info['input_shape'] = input_shape[module_id]
        info['output_shape'] = output_shape[module_id]
        info['depth'] = depth
        num_params = sum(p.numel() for p in module.parameters())
        info['#param'] = f'{num_params:,}' if num_params > 0 else '--'
        info_list.append(info)

        if not isinstance(module, ScriptModule):
            hooks[module_id] = module.register_forward_hook(
                register_hook(info))

        if depth >= max_depth:
            continue

        if (leaf_module is not None and isinstance(module, leaf_module)):
            continue

        module_items = reversed(module._modules.items())
        stack += [(f"({name}){mod.__class__.__name__}", mod, depth + 1)
                  for name, mod in module_items if mod is not None]

    training = model.training
    model.eval()

    with torch.no_grad():
        model(*args, **kwargs)

    model.train(training)

    for h in hooks.values():  # Remove hooks.
        h.remove()

    info_list = postprocess(info_list)
    return make_table(info_list, max_depth=max_depth)


def get_shape(inputs: Any) -> str:
    if not isinstance(inputs, (tuple, list)):
        inputs = (inputs, )

    out = []
    for x in inputs:
        if isinstance(x, SparseTensor):
            out.append(str(list(x.sizes())))
        elif hasattr(x, 'size'):
            out.append(str(list(x.size())))
    return ', '.join(out)


def postprocess(info_list: List[dict]) -> List[dict]:
    for idx, info in enumerate(info_list):
        depth = info['depth']
        if idx > 0:  # root module (0) is exclued
            if depth == 1:
                prefix = '├─'
            else:
                prefix = f"{'│    '*(depth-1)}└─"
            info['name'] = prefix + info['name']

        if info['input_shape']:
            info['input_shape'] = info['input_shape'].pop(0)
            info['output_shape'] = info['output_shape'].pop(0)
        else:
            info['input_shape'] = '--'
            info['output_shape'] = '--'
    return info_list


def make_table(info_list: List[dict], max_depth: int) -> str:
    from tabulate import tabulate
    content = [['Layer', 'Input Shape', 'Output Shape', '#Param']]
    for info in info_list:
        content.append([
            info['name'],
            info['input_shape'],
            info['output_shape'],
            info['#param'],
        ])
    return tabulate(content, headers='firstrow', tablefmt='psql')

In [25]:
len(batch)

6

In [26]:
train_data

HeteroData(
  [1mlisting[0m={ x=[17229, 158] },
  [1muser[0m={ x=[324135, 384] },
  [1m(user, rates, listing)[0m={
    edge_index=[2, 334678],
    edge_label=[334678],
    edge_label_index=[2, 334678]
  },
  [1m(listing, rev_rates, user)[0m={ edge_index=[2, 334678] }
)

In [24]:
batch = next(iter(train_loader))
with torch.no_grad():  # Initialize lazy modules.
    out = model(batch.x_dict, batch.edge_index_dict)
print(model)

print(summary(model, batch.x_dict, batch.edge_index_dict))

Model(
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__listing): SAGEConv((-1, -1), 64, aggr=mean)
      (listing__rev_rates__user): SAGEConv((-1, -1), 64, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__listing): SAGEConv((-1, -1), 64, aggr=mean)
      (listing__rev_rates__user): SAGEConv((-1, -1), 64, aggr=mean)
    )
  )
)
+------------------------------------------------+---------------+----------------+----------+
| Layer                                          | Input Shape   | Output Shape   | #Param   |
|------------------------------------------------+---------------+----------------+----------|
| Model                                          |               |                | 86,016   |
| ├─(encoder)GraphModule                         |               |                | 86,016   |
| │    └─(conv1)ModuleDict                       | --            | --             | 69,504   |
| │    │    └─(user__rates__listing)SAGEConv     | [2, 919]      