Reproducing ExcelFormer #46

Merged: 13 commits, Sep 22, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `ExcelFormer` example. ([#46](https://github.com/pyg-team/pytorch-frame/pull/46))
- Support inductive `DataFrame` to `TensorFrame` transformation ([#75](https://github.com/pyg-team/pytorch-frame/pull/75))
- Added `CatBoost` baseline and tuned `CatBoost` example. ([#73](https://github.com/pyg-team/pytorch-frame/pull/73))
- Added `na_strategy` as argument in `StypeEncoder`. ([#69](https://github.com/pyg-team/pytorch-frame/pull/69))
171 changes: 171 additions & 0 deletions examples/excelformer.py
@@ -0,0 +1,171 @@
"""
Reported (reproduced) accuracy(rmse for regression task) of ExcelFormer based
on Table 1 of the paper. https://arxiv.org/pdf/2301.02819.pdf
ExcelFormer uses the same train-validation-test split as the Yandex paper.

california: 0.4587 (0.4733) num_layers=5, num_heads=4, num_layers=5,
channels=32, lr: 0.001,
jannis : 72.51 (72.38) num_heads=32, lr: 0.0001
covtype: 97.17 (95.37)
helena: 38.20 (36.80)
higgs_small: 80.75 (65.17) lr: 0.0001
"""
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

from torch_frame.data.loader import DataLoader
from torch_frame.datasets.yandex import Yandex
from torch_frame.nn import ExcelFormer
from torch_frame.transforms import (
    CategoricalCatBoostEncoder,
    MutualInformationSort,
)

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='higgs_small')
parser.add_argument('--channels', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--num_heads', type=int, default=4)
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=100)
# Note: argparse's `type=bool` treats any non-empty string as True, so use
# BooleanOptionalAction (Python >= 3.9) to support --mixup/--no-mixup.
parser.add_argument('--mixup', action=argparse.BooleanOptionalAction,
                    default=True)
parser.add_argument('--beta', type=float, default=0.5)
args = parser.parse_args()

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)
dataset = Yandex(root=path, name=args.dataset)
dataset.materialize()
train_dataset = dataset.get_split_dataset('train')
val_dataset = dataset.get_split_dataset('val')
test_dataset = dataset.get_split_dataset('test')
train_tensor_frame = train_dataset.tensor_frame.to(device)
val_tensor_frame = val_dataset.tensor_frame.to(device)
test_tensor_frame = test_dataset.tensor_frame.to(device)

# CategoricalCatBoostEncoder encodes the categorical features
# into numerical features with CatBoostEncoder.
categorical_transform = CategoricalCatBoostEncoder()
categorical_transform.fit(train_dataset.tensor_frame, train_dataset.col_stats)

train_tensor_frame = categorical_transform(train_tensor_frame)
val_tensor_frame = categorical_transform(val_tensor_frame)
test_tensor_frame = categorical_transform(test_tensor_frame)
col_stats = categorical_transform.transformed_stats
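
# For intuition: CatBoost encoding replaces each category with a smoothed
# target statistic. A standalone sketch using the external
# `category_encoders` package (illustrative only, not part of this script):
#
#     import pandas as pd
#     from category_encoders import CatBoostEncoder
#     X = pd.DataFrame({'color': ['red', 'blue', 'red', 'green']})
#     y = pd.Series([1, 0, 1, 0])
#     print(CatBoostEncoder().fit_transform(X, y))  # categories -> floats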

# MutualInformationSort sorts the features based on their mutual
# information with the target.
mutual_info_sort = MutualInformationSort(task_type=dataset.task_type)

mutual_info_sort.fit(train_tensor_frame, col_stats)
train_tensor_frame = mutual_info_sort(train_tensor_frame)
val_tensor_frame = mutual_info_sort(val_tensor_frame)
test_tensor_frame = mutual_info_sort(test_tensor_frame)
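
# For intuition: ExcelFormer expects columns ordered by mutual information
# with the target. A standalone sketch with scikit-learn (illustrative only,
# names are assumptions, not the torch-frame implementation):
#
#     import numpy as np
#     from sklearn.feature_selection import mutual_info_classif
#     X = np.random.rand(100, 4)
#     y = (X[:, 2] > 0.5).astype(int)   # column 2 is most informative
#     order = np.argsort(-mutual_info_classif(X, y))
#     X_sorted = X[:, order]            # most informative columns first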

train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
                          shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)

is_classification = dataset.task_type.is_classification

if is_classification:
    out_channels = dataset.num_classes
else:
    out_channels = 1

model = ExcelFormer(
    in_channels=args.channels,
    out_channels=out_channels,
    num_layers=args.num_layers,
    num_cols=len(dataset.col_to_stype) - 1,
    num_heads=args.num_heads,
    residual_dropout=0.,
    diam_dropout=0.3,
    aium_dropout=0.,
    col_stats=mutual_info_sort.transformed_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
lr_scheduler = ExponentialLR(optimizer, gamma=0.95)


def train(epoch: int) -> float:
    model.train()
    loss_accum = total_count = 0

    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
        # Mixup produces mixed-up features and (soft) targets; the plain
        # forward pass falls back to the original targets.
        if args.mixup:
            pred, y = model.forward_mixup(tf, beta=args.beta)
        else:
            pred, y = model(tf), tf.y
        if is_classification:
            loss = F.cross_entropy(pred, y)
        else:
            loss = F.mse_loss(pred.view(-1), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        loss_accum += float(loss) * len(y)
        total_count += len(y)
        optimizer.step()
    return loss_accum / total_count


@torch.no_grad()
def test(loader: DataLoader) -> float:
    model.eval()
    accum = total_count = 0

    for tf in loader:
        pred = model(tf)
        if is_classification:
            pred_class = pred.argmax(dim=-1)
            accum += float((tf.y == pred_class).sum())
        else:
            accum += float(
                F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))
        total_count += len(tf.y)

    if is_classification:
        accuracy = accum / total_count
        return accuracy
    else:
        rmse = (accum / total_count)**0.5
        return rmse


if is_classification:
    metric = 'Acc'
    best_val_metric = 0
    best_test_metric = 0
else:
    metric = 'RMSE'
    best_val_metric = float('inf')
    best_test_metric = float('inf')

for epoch in range(1, args.epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader)
    val_metric = test(val_loader)
    test_metric = test(test_loader)
    # Decay the learning rate once per epoch.
    lr_scheduler.step()

    if is_classification and val_metric > best_val_metric:
        best_val_metric = val_metric
        best_test_metric = test_metric
    elif not is_classification and val_metric < best_val_metric:
        best_val_metric = val_metric
        best_test_metric = test_metric

    print(f'Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, '
          f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}')

print(f'Best Val {metric}: {best_val_metric:.4f}, '
      f'Best Test {metric}: {best_test_metric:.4f}')
49 changes: 36 additions & 13 deletions test/nn/models/test_excelformer.py
@@ -1,28 +1,51 @@
import copy

import pytest
import torch

from torch_frame import TaskType
from torch_frame.data.dataset import Dataset
from torch_frame.datasets.fake import FakeDataset
from torch_frame.nn import ExcelFormer
from torch_frame.stype import stype


@pytest.mark.parametrize('task_type', [
    TaskType.REGRESSION, TaskType.BINARY_CLASSIFICATION,
    TaskType.MULTICLASS_CLASSIFICATION
])
def test_excelformer(task_type):
    batch_size = 10
    in_channels = 8
    num_heads = 2
    num_layers = 6
    dataset: Dataset = FakeDataset(num_rows=10, with_nan=False,
                                   stypes=[stype.numerical],
                                   task_type=task_type)
    dataset.materialize()
    if task_type.is_classification:
        out_channels = dataset.num_classes
    else:
        out_channels = 1
    num_cols = len(dataset.col_stats) - 1
    tensor_frame = dataset.tensor_frame
    model = ExcelFormer(in_channels=in_channels, out_channels=out_channels,
                        num_cols=num_cols, num_layers=num_layers,
                        num_heads=num_heads, col_stats=dataset.col_stats,
                        col_names_dict=tensor_frame.col_names_dict)

    # Test the original forward pass
    out = model(tensor_frame)
    assert out.shape == (batch_size, out_channels)

    # Test the mixup forward pass
    x_num = copy.copy(tensor_frame.x_dict[stype.numerical])
    out_mixedup, y_mixedup = model.forward_mixup(tensor_frame)
    assert out_mixedup.shape == (batch_size, out_channels)
    # Make sure the numerical features are not modified in place.
    assert torch.allclose(x_num, tensor_frame.x_dict[stype.numerical])

    if task_type.is_classification:
        assert y_mixedup.shape == (batch_size, out_channels)
    else:
        assert y_mixedup.shape == tensor_frame.y.shape
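
To run just this test with standard pytest (path as in this PR):

    pytest test/nn/models/test_excelformer.py -v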
112 changes: 106 additions & 6 deletions torch_frame/nn/models/excelformer.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, ModuleList
@@ -14,6 +16,52 @@
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder


def feature_mixup(
    x: Tensor,
    y: Tensor,
    num_classes: Optional[int] = None,
    beta: float = 0.5,
) -> Tuple[Tensor, Tensor]:
    r"""Mix up the input numerical feature tensor :obj:`x` by swapping some
    feature elements of two shuffled samples. The shuffle rate for each row
    is sampled from the Beta distribution. The target :obj:`y` is also
    linearly mixed up.

    Args:
        x (Tensor): The input numerical feature.
        y (Tensor): The target.
        num_classes (int, optional): Number of classes. Needs to be given in
            the case of classification tasks. (default: :obj:`None`)
        beta (float): The concentration parameter of the Beta distribution.
            (default: :obj:`0.5`)

    Returns:
        x_mixedup (Tensor): The mixed-up numerical feature.
        y_mixedup (Tensor): The mixed-up target of size
            [batch_size, num_classes] for classification and [batch_size]
            for regression.
    """
    beta = torch.tensor(beta, device=x.device)
    beta_distribution = torch.distributions.beta.Beta(beta, beta)
    shuffle_rates = beta_distribution.sample((len(x), 1))
    feat_masks = torch.rand(x.shape, device=x.device) < shuffle_rates
    shuffled_idx = torch.randperm(len(x), device=x.device)
    x_mixedup = feat_masks * x + ~feat_masks * x[shuffled_idx]

    y_shuffled = y[shuffled_idx]
    if y.is_floating_point():
        # Regression task
        shuffle_rates = shuffle_rates.view(-1)
        y_mixedup = shuffle_rates * y + (1 - shuffle_rates) * y_shuffled
    else:
        # Classification task
        assert num_classes is not None
        one_hot_y = F.one_hot(y, num_classes=num_classes)
        one_hot_y_shuffled = F.one_hot(y_shuffled, num_classes=num_classes)
        y_mixedup = (shuffle_rates * one_hot_y +
                     (1 - shuffle_rates) * one_hot_y_shuffled)
    return x_mixedup, y_mixedup
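

# Minimal sketch of what `feature_mixup` computes (illustrative only, not
# part of the module): a toy batch of 4 rows, 3 numerical columns, 2 classes.
#
#     x = torch.arange(12, dtype=torch.float32).view(4, 3)
#     y = torch.tensor([0, 1, 0, 1])
#     x_mix, y_mix = feature_mixup(x, y, num_classes=2)
#     assert x_mix.shape == (4, 3)  # features keep their shape
#     assert y_mix.shape == (4, 2)  # hard labels become one-hot mixtures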


class ExcelFormer(Module):
    r"""The ExcelFormer model introduced in
    https://arxiv.org/pdf/2301.02819.pdf

@@ -75,16 +123,68 @@ def reset_parameters(self):
        self.excelformer_decoder.reset_parameters()

    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transform :obj:`TensorFrame` object into output embeddings.

        Args:
            tf (TensorFrame): Input :obj:`TensorFrame` object.

        Returns:
            out (Tensor): The output embeddings of size
                [batch_size, out_channels].
        """
        if stype.numerical not in tf.x_dict or len(
                tf.x_dict[stype.numerical]) == 0:
            raise ValueError(
                "ExcelFormer only takes in numerical features, but the "
                "input TensorFrame object does not have numerical features.")
        x, _ = self.excelformer_encoder(tf)
        for excelformer_conv in self.excelformer_convs:
            x = excelformer_conv(x)
        out = self.excelformer_decoder(x)
        return out

    def forward_mixup(
        self,
        tf: TensorFrame,
        beta: float = 0.5,
    ) -> Tuple[Tensor, Tensor]:
        r"""Transform :obj:`TensorFrame` object into mixed-up output
        embeddings and mixed-up targets.

        Args:
            tf (TensorFrame): Input :obj:`TensorFrame` object.
            beta (float): The concentration parameter of the Beta
                distribution used to sample the shuffle rate in mixup.
                (default: :obj:`0.5`)

        Returns:
            out_mixedup (Tensor): The mixed-up output embeddings of size
                [batch_size, out_channels].
            y_mixedup (Tensor): The mixed-up targets of size
                [batch_size, num_classes] for classification and
                [batch_size] for regression.
        """
        # Mix up numerical features
        x_mixedup, y_mixedup = feature_mixup(
            tf.x_dict[stype.numerical],
            tf.y,
            num_classes=self.out_channels,
            beta=beta,
        )

        # Create a new `x_dict`, where stype.numerical is swapped with
        # the mixed-up features.
        x_dict: Dict[stype, Tensor] = {}
        for stype_name, x in tf.x_dict.items():
            if stype_name == stype.numerical:
                x_dict[stype_name] = x_mixedup
            else:
                x_dict[stype_name] = x
        tf_mixedup = TensorFrame(x_dict, tf.col_names_dict, tf.y)

        # Call the regular ExcelFormer forward pass on the mixed-up input
        out_mixedup = self(tf_mixedup)

        return out_mixedup, y_mixedup
1 change: 1 addition & 0 deletions torch_frame/transforms/categorical_catboost_encoder.py
@@ -23,6 +23,7 @@ def _fit(self, tf_train: TensorFrame, col_stats: Dict[str, Dict[StatType,
        logging.info(
            "The input TensorFrame does not contain any categorical "
            "columns. No fitting will be performed.")
        self._transformed_stats = col_stats
        return
        # TODO: Implement the CatBoostEncoder with PyTorch rather than
        # relying on an external library.