## 출처: https://github.com/jrfiedler/xynn/blob/main/xynn/autoint/estimators.py

## load library

In [86]:
import textwrap
from typing import Type, Union, Callable, Tuple, List, Optional, Dict, Any, Iterable
from abc import ABCMeta, abstractmethod
import random
import zlib
import requests
from pathlib import Path
from collections import namedtuple
from tqdm.auto import tqdm
from scipy.special import expit

import sys

import time
import datetime

from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from torch.nn import functional as F

try:
    import pytorch_lightning as pl
except ImportError:
    HAS_PL = False
else:
    HAS_PL = True

import numpy as np
import pandas as pd

## create seed

In [4]:
SEED=14

In [5]:
def _set_seed(seed):
    # https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles
    random.seed(seed)

    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## create base module

In [6]:
class EmbeddingBase(nn.Module, metaclass=ABCMeta):
    """
    Base class for embeddings
    """

    def __init__(self):
        super().__init__()
        self._isfit = False

    @abstractmethod
    def _fit_array(self, X):
        return

    @abstractmethod
    def _fit_iterable(self, X):
        return

    def fit(self, X) -> "EmbeddingBase":
        """
        Create the embedding from training data
        Parameters
        ----------
        X : array-like or iterable of array-like
            should be a PyTorch Tensor, NumPy array, Pandas DataFrame
            or iterable of arrays/tensors (i.e., batches)
        Return
        ------
        self
        """
        if isinstance(X, (np.ndarray, Tensor, pd.DataFrame)):
            self._fit_array(X)
        elif isinstance(X, DataLoader):
            self._fit_iterable(X)
        else:
            raise TypeError(
                "input X must be a PyTorch Tensor, PyTorch DataLoader, "
                "NumPy array, or Pandas DataFrame"
            )

        self._isfit = True

        return self

In [7]:
class UniformBase(EmbeddingBase):
    """Base class for embeddings that have a single vector size for all fields"""

    def weight_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of embedding weights
        Return
        ------
        e1_sum : sum of absolute value of embedding values
        e2_sum : sum of squared embedding values
        """
        if not self._isfit:
            return 0.0, 0.0
        e1_sum = self.embedding.weight.abs().sum()
        e2_sum = (self.embedding.weight ** 2).sum()
        return e1_sum, e2_sum

In [8]:
class LinearEmbedding(UniformBase):
    """
    An embedding for numeric fields. There is one embedded vector for each field.
    The embedded vector for a value is that value times its field's vector.
    """

    def __init__(self, embedding_size: int = 10, device: Union[str, torch.device] = "cpu"):
        """
        Parameters
        ----------
        embedding_size : int, optional
            size of each value's embedding vector; default is 10
        device : string or torch.device
        """
        super().__init__()
        self.num_fields = 0
        self.output_size = 0
        self.embedding: Optional[nn.Embedding] = None
        self.embedding_size = embedding_size
        self._device = device
        self.to(device)
        self._isfit = False

    def __repr__(self):
        return f"LinearEmbedding({self.embedding_size}, {repr(self._device)})"

    def from_summary(self, num_fields: int):
        """
        Create the embedding for the given number of fields
        Parameters
        ----------
        num_fields : int
        Return
        ------
        self
        """
        self.num_fields = num_fields
        self.output_size = num_fields * self.embedding_size
        self.embedding = nn.Embedding(num_fields, self.embedding_size).to(device=self._device)
        nn.init.xavier_uniform_(self.embedding.weight)

        self._isfit = True

        return self

    def _fit_array(self, X):
        self.from_summary(X.shape[1])

    def _fit_iterable(self, X):
        for batch in X:
            self._fit_array(batch)
            break

    def forward(self, X: Tensor) -> Tensor:
        """
        Produce embedding for each value in input
        Parameters
        ----------
        X : torch.Tensor
        Return
        ------
        torch.Tensor
        """
        if not self._isfit:
            raise RuntimeError("need to call `fit` or `from_summary` first")
        return self.embedding.weight * X.unsqueeze(dim=-1)

In [9]:
class BaseEstimator(metaclass=ABCMeta):
    """
    Base class for Scikit-learn style classes in this package
    """

    def __init__(
        self,
        embedding_num: Optional[Union[str, EmbeddingBase]] = "auto",
        embedding_cat: Optional[Union[str, EmbeddingBase]] = "auto",
        loss_fn: Union[str, Callable] = "auto",
        seed: Union[int, None] = None,
        device: Union[str, torch.device] = "cpu",
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.task = ""
        self.embedding_num = embedding_num
        self.embedding_cat = embedding_cat
        self.model_kwargs = model_kwargs if model_kwargs else {}
        self.seed = seed
        self.train_info = []
        self._device = torch.device(device)
        self._model = None
        self._model_class: Optional[Type[BaseNN]] = None
        self._num_numeric_fields = 0
        self._num_categorical_fields = 0
        if seed is not None:
            _set_seed(seed)

        # record init parameters, mostly for logging
        init_params_bef = {
            "embedding_num": embedding_num,
            "embedding_cat": embedding_cat,
        }
        init_params_aft = {
            "loss_fn": loss_fn,
            "seed": seed,
            "device": device
        }
        self.init_parameters = {
            key: val
            for params in (init_params_bef, model_kwargs, init_params_aft)
            for key, val in params.items()
        }

    def __repr__(self):
        init_params = ",\n    ".join(
            f"{key}={_param_repr(val)}" for key, val in self.init_parameters.items()
        )
        repr_str = f"{self.__class__.__name__}(\n    {init_params},\n)"
        return repr_str

    def mlp_weight_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of weights in MLP layers
        Return
        ------
        w1 : sum of absolute value of MLP weights
        w2 : sum of squared MLP weights
        """
        if self._model:
            return self._model.mlp_weight_sum()
        return torch.tensor([0.0]), torch.tensor([0.0])

    def embedding_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of embedding values
        Return
        ------
        e1_sum : sum of absolute value of embedding values
        e2_sum : sum of squared embedding values
        """
        if self._model:
            return self._model.embedding_sum()
        return torch.tensor([0.0]), torch.tensor([0.0])

    def num_parameters(self) -> int:
        """
        Number of trainable parameters in the model
        Return
        ------
        int number of trainable parameters
        """
        if self._model:
            return self._model.num_parameters()
        return 0

    def _optimizer_init(self, optimizer, opt_kwargs, scheduler, sch_kwargs):
        self._model.set_optimizer(
            optimizer=optimizer,
            opt_kwargs=opt_kwargs,
            scheduler=scheduler,
            sch_kwargs=sch_kwargs,
        )
        self._model.configure_optimizers()

    def _create_embeddings(self, X_num, X_cat):
        # numeric embedding
        if X_num.shape[1]:
            if self.embedding_num is None:
                if self._require_numeric_embedding:
                    raise ValueError(
                        "embedding_num was set to None; "
                        f"expected zero numeric columns, got {X_num.shape[1]}"
                    )
            elif isinstance(self.embedding_num, EmbeddingBase):
                if not self.embedding_num._isfit:
                    self.embedding_num.fit(X_num)
            else: # initialized with embedding_num = "auto"
                self.embedding_num = LinearEmbedding(device=self._device)
                self.embedding_num.fit(X_num)
        else:
            self.embedding_num = None

        # categorical embedding
        if X_cat.shape[1]:
            if self.embedding_cat is None:
                raise ValueError(
                    "embedding_cat was set to None; "
                    f"expected zero categorical columns, got {X_cat.shape[1]}"
                )
            elif isinstance(self.embedding_cat, EmbeddingBase):
                if not self.embedding_cat._isfit:
                    self.embedding_cat.fit(X_cat)
            else:  # initialized with embedding_cat = "auto"
                self.embedding_cat = DefaultEmbedding(device=self._device)
                self.embedding_cat.fit(X_cat)
        else:
            self.embedding_cat = None

    @abstractmethod
    def _create_model(self, embedding_num, embedding_cat):
        return

    @abstractmethod
    def _fit_init(self, X_num, X_cat, y, warm_start=False):
        return X_num, X_cat, y

    def _convert_x(self, X_num, X_cat, y=None) -> Tuple[Tensor, Union[Tensor, np.ndarray]]:
        if X_num is None and X_cat is None:
            raise TypeError("X_num and X_cat cannot both be None")

        if X_num is None:
            X_num = torch.empty((X_cat.shape[0], 0))
            self._num_numeric_fields = 0
        else:
            self._num_numeric_fields = X_num.shape[1]
            if isinstance(X_num, np.ndarray):
                X_num = torch.from_numpy(X_num).to(dtype=torch.float32)

        if X_cat is None:
            X_cat = torch.empty((X_num.shape[0], 0))
            self._num_categorical_fields = 0
        else:
            self._num_categorical_fields = X_cat.shape[1]

        if X_num.shape[0] != X_cat.shape[0]:
            raise ValueError(
                f"mismatch in shapes for X_num {X_num.shape}, X_cat {X_cat.shape}"
            )
        if y is not None and X_num.shape[0] != y.shape[0]:
            raise ValueError(
                f"mismatch in shapes for X_num {X_num.shape}, "
                f"X_cat {X_cat.shape}, y {y.shape}"
            )

        return X_num, X_cat

    @abstractmethod
    def _convert_y(self, y):
        return y

    def _convert_xy(self, X_num, X_cat, y):
        X_num, X_cat = self._convert_x(X_num, X_cat, y)
        y = self._convert_y(y)
        return X_num, X_cat, y

    def fit(
        self,
        X_num: Optional[Union[Tensor, np.ndarray]],
        X_cat: Optional[Union[Tensor, np.ndarray]],
        y: Union[Tensor, np.ndarray],
        optimizer: Callable,
        opt_kwargs: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Callable] = None,
        sch_kwargs: Optional[Dict[str, Any]] = None,
        val_sets: Optional[List[Tuple[Tensor, Tensor, Tensor]]] = None,
        num_epochs: int = 5,
        batch_size: int = 128,
        warm_start: bool = False,
        extra_metrics: Optional[List[Tuple[str, Callable]]] = None,
        early_stopping_metric: str = "val_loss",
        early_stopping_patience: Union[int, float] = float("inf"),
        early_stopping_mode: str = "min",
        early_stopping_window: int = 1,
        shuffle: bool = True,
        log_path: Optional[str] = None,
        param_path: Optional[str] = None,
        callback: Optional[Callable] = None,
        verbose: bool = False,
    ):
        """
        Fit the model to the training data
        Parameters
        ----------
        X_num : torch.Tensor, numpy.ndarray, or None
        X_cat : torch.Tensor, numpy.ndarray, or None
        y : torch.Tensor or numpy.ndarray
        optimizer : PyTorch Optimizer class, optional
            uninitialized subclass of Optimizer; default is `torch.optim.Adam`
        opt_kwargs : dict or None, optional
            dict of keyword arguments to initialize optimizer with;
            default is None
        scheduler : PyTorch scheduler class, optional
            example: `torch.optim.lr_scheduler.ReduceLROnPlateau`
            default is None
        sch_kwargs : dict or None, optional
            dict of keyword arguments to initialize scheduler with;
            default is None
        val_sets : list of tuples, or None; optional
            each tuple should be (X_num, X_cat, y) validation data;
            default is None
        num_epochs : int, optional
            default is 5
        batch_size : int, optional
            default is 128
        warm_start : boolean, optional
            whether to re-create the model before fitting (warm_start == False),
            or refine the training (warm_start == True); default is False
        extra_metrics : list of (str, callable) tuples or None, optional
            default is None
        early_stopping_metric : str, optional
            should be "val_loss" or one of the passed `extra_metrics`;
            default is "val_loss"
        early_stopping_patience : int, float; optional
            default is float("inf") (no early stopping)
        early_stopping_mode : {"min", "max"}, optional
            use "min" if smaller values are better; default is "min"
        early_stopping_window : int, optional
            number of consecutive epochs to average to determine best;
            default is 1
        shuffle : boolean, optional
            default is True
        log_path : str or None, optional
            filename to save output to; default is None
        param_path : str or None, optional
            specify this to have the best parameters reloaded at end of training;
            default is None
        callback : callable or None, optional
            function to call after each epoch; the function will be passed a list
            of dictionaries, one dictionary for each epoch; default is None
        verbose : boolean, optional
            default is False
        """
        time_start = now()

        X_num, X_cat, y = self._fit_init(X_num, X_cat, y, warm_start)
        self._optimizer_init(optimizer, opt_kwargs, scheduler, sch_kwargs)

        train_dl = TabularDataLoader(
            task=self.task,
            X_num=X_num,
            X_cat=X_cat,
            y=y,
            batch_size=batch_size,
            shuffle=shuffle,
            device=self._device,
        )

        if val_sets is not None:
            valid_dl = [
                TabularDataLoader(
                    self.task,
                    *self._convert_x(*val_set),
                    y=self._convert_y(val_set[-1]),
                    batch_size=batch_size,
                    shuffle=False,
                    device=self._device,
                )
                for val_set in val_sets
            ]
        else:
            valid_dl = None

        train_info = train(
            self._model,
            train_data=train_dl,
            val_data=valid_dl,
            num_epochs=num_epochs,
            max_grad_norm=float("inf"),
            extra_metrics=extra_metrics,
            early_stopping_metric=early_stopping_metric,
            early_stopping_patience=early_stopping_patience,
            early_stopping_mode=early_stopping_mode,
            early_stopping_window=early_stopping_window,
            param_path=param_path,
            callback=callback,
            verbose=verbose,
        )

        if warm_start:
            self.train_info.extend(train_info)
        else:
            self.train_info = train_info

        if log_path:
            info = {
                "init_parameters": {
                    key: _param_json(val) for key, val in self.init_parameters.items()
                },
                "fit_parameters": {
                    "optimizer": str(optimizer.__name__),
                    "opt_kwargs": opt_kwargs,
                    "scheduler": str(scheduler.__name__) if scheduler is not None else None,
                    "sch_kwargs": sch_kwargs,
                    "num_epochs": num_epochs,
                    "batch_size": batch_size,
                    "extra_metrics": [x[0] for x in extra_metrics] if extra_metrics else None,
                    "early_stopping_metric": early_stopping_metric,
                    "early_stopping_patience": early_stopping_patience,
                    "early_stopping_mode": early_stopping_mode,
                    "shuffle": shuffle,
                },
                "num_parameters": self.num_parameters(),
                "time_start": time_start,
                "train_info": self.train_info,
                "time_end": now(),
            }
            _log(info, log_path)

In [10]:
class BaseClassifier(BaseEstimator):
    """
    Base class for Scikit-learn style classification classes in this package
    """

    def __init__(
        self,
        embedding_num: Optional[Union[str, EmbeddingBase]] = "auto",
        embedding_cat: Optional[Union[str, EmbeddingBase]] = "auto",
        loss_fn: Union[str, Callable] = "auto",
        seed: Union[int, None] = None,
        device: Union[str, torch.device] = "cpu",
        **model_kwargs,
    ):
        super().__init__(
            embedding_num=embedding_num,
            embedding_cat=embedding_cat,
            seed=seed,
            device=device,
            model_kwargs=model_kwargs,
        )
        self.task = "classification"
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn == "auto" else loss_fn
        self.classes = {}

    def _create_model(self):
        self._model = self._model_class(
            task="classification",
            output_size=len(self.classes),
            embedding_num=self.embedding_num,
            embedding_cat=self.embedding_cat,
            loss_fn=self.loss_fn,
            device=self._device,
            **self.model_kwargs
        )

    def _convert_y(self, y) -> Tensor:
        if len(y.shape) == 1:
            y = y.reshape((-1, 1))
        y = torch.tensor([self.classes[yval[0].item()] for yval in y])
        return y

    def _fit_init(self, X_num, X_cat, y, warm_start=False):
        if self._model is None or not warm_start:
            self.classes = {old : new for new, old in enumerate(np.unique(y))}
        X_num, X_cat, y = self._convert_xy(X_num, X_cat, y)
        if self._model is None or not warm_start:
            self._create_embeddings(X_num, X_cat)
            self._create_model()
        return X_num, X_cat, y

    def predict_logits(self, X_num, X_cat):
        """
        Calculate class logits
        Parameters
        ----------
        X_num : torch.Tensor, numpy.ndarray, or None
        X_cat : torch.Tensor, numpy.ndarray, or None
        Return
        ------
        torch.Tensor
        """
        if not self._model:
            raise RuntimeError("you need to fit the model first")
        X_num, X_cat = self._convert_x(X_num, X_cat)
        X_num = X_num.to(device=self._device)
        X_cat = X_cat.to(device=self._device)
        self._model.eval()
        with torch.no_grad():
            raw = self._model(X_num, X_cat)
        return raw

    def predict(self, X_num, X_cat):
        """
        Calculate class predictions
        Parameters
        ----------
        X_num : torch.Tensor, numpy.ndarray, or None
        X_cat : torch.Tensor, numpy.ndarray, or None
        Return
        ------
        torch.Tensor
        """
        if not self._model:
            raise RuntimeError("you need to fit the model first")
        class_inverse = {v: k for k, v in self.classes.items()}
        raw = self.predict_logits(X_num, X_cat)
        preds = torch.argmax(raw, dim=1)
        preds = torch.tensor(
            [class_inverse[pred.item()] for pred in preds]
        ).to(device=self._device)
        return preds

    def predict_proba(self, X_num, X_cat):
        """
        Calculate class "probabilities"
        Parameters
        ----------
        X_num : torch.Tensor, numpy.ndarray, or None
        X_cat : torch.Tensor, numpy.ndarray, or None
        Return
        ------
        torch.Tensor
        """
        if not self._model:
            raise RuntimeError("you need to fit the model first")
        raw = self.predict_logits(X_num, X_cat)
        proba = softmax(raw, dim=1)
        return proba

In [78]:
class RaggedBase(EmbeddingBase):
    """Base class for embeddings that allow a different vector size for each field"""

    def __init__(self):
        super().__init__()
        self.embedding: Optional[nn.ModuleList] = None

    def weight_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of embedding weights
        Return
        ------
        e1_sum : sum of absolute value of embedding values
        e2_sum : sum of squared embedding values
        """
        if not self._isfit:
            return 0.0, 0.0
        e1_sum = 0.0
        e2_sum = 0.0
        for embedding in self.embedding:
            e1_sum += embedding.weight.abs().sum()
            e2_sum += (embedding.weight ** 2).sum()
        return e1_sum, e2_sum

In [79]:
BaseClass = pl.LightningModule if HAS_PL else nn.Module

In [80]:
class BaseNN(BaseClass, metaclass=ABCMeta):
    """
    Base class for neural network models
    """

    @abstractmethod
    def __init__(
        self,
        task: str,
        embedding_num: Optional[EmbeddingBase],
        embedding_cat: Optional[EmbeddingBase],
        embedding_l1_reg: float,
        embedding_l2_reg: float,
        mlp_l1_reg: float,
        mlp_l2_reg: float,
        loss_fn: Union[str, Callable],
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        task : {"regression", "classification"}
        embedding_num : EmbeddingBase or None
            initialized and fit embedding for numeric fields
        embedding_cat : EmbeddingBase or None
            initialized and fit embedding for categorical fields
        embedding_l1_reg : float
            value for l1 regularization of embedding vectors
        embedding_l2_reg : float
            value for l2 regularization of embedding vectors
        mlp_l1_reg : float
            value for l1 regularization of MLP weights
        mlp_l2_reg : float
            value for l2 regularization of MLP weights
        loss_fn : "auto" or PyTorch loss function, optional
            default is "auto"
        device : string or torch.device, optional
            default is "cpu"
        """
        super().__init__()
        if task not in {"regression", "classification"}:
            raise ValueError(
                f"task {task} not recognized; should be 'regression' or 'classification'"
            )

        self.task = task
        self.num_epochs = 0

        if loss_fn != "auto":
            self.loss_fn = loss_fn
        elif task == "regression":
            self.loss_fn = nn.MSELoss()
        else:
            self.loss_fn = nn.CrossEntropyLoss()

        self.embedding_num = embedding_num
        self.embedding_cat = embedding_cat
        self.embedding_l1_reg = embedding_l1_reg
        self.embedding_l2_reg = embedding_l2_reg
        self.mlp_l1_reg = mlp_l1_reg
        self.mlp_l2_reg = mlp_l2_reg
        self.optimizer: Optional[Callable] = None
        self.optimizer_info: Dict[str, Any] = {}
        self.scheduler: Dict[str, Any] = {}
        self._device = device

    @abstractmethod
    def mlp_weight_sum(self) -> Tuple[Tensor, Tensor]:
        return torch.tensor([0.0]), torch.tensor([0.0])

    def embedding_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of embedding values
        Return
        ------
        e1_sum : sum of absolute value of embedding values
        e2_sum : sum of squared embedding values
        """
        e1_sum = 0.0
        e2_sum = 0.0

        if hasattr(self, "embedding_num") and self.embedding_num is not None:
            e1_sum_num, e2_sum_num = self.embedding_num.weight_sum()
            e1_sum += e1_sum_num
            e2_sum += e2_sum_num

        if hasattr(self, "embedding_cat") and self.embedding_cat is not None:
            e1_sum_cat, e2_sum_cat = self.embedding_cat.weight_sum()
            e1_sum += e1_sum_cat
            e2_sum += e2_sum_cat

        return e1_sum, e2_sum

    def num_parameters(self) -> int:
        """
        Number of trainable parameters in the model
        Return
        ------
        int number of trainable parameters
        """
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def embed(
        self,
        X_num: Tensor,
        X_cat: Tensor,
        num_dim: int = 3,
        concat: bool = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Embed the numeric and categorical input fields.
        Parameters
        ----------
        X_num : torch.Tensor or numpy.ndarray or None
        X_cat : torch.Tensor or numpy.ndarray or None
        num_dim : 2 or 3, optional
            default is 3
        concat : bool, optional
            whether to concatenate outputs into a single Tensor;
            if True, concatenation is on dim 1; default is True
        Return
        ------
        torch.Tensor if concat else (torch.Tensor, torch.Tensor)
        """
        if X_num is None and X_cat is None:
            raise ValueError("X_num and X_cat cannot both be None")

        if num_dim not in (2, 3):
            raise ValueError(f"num_dim should be 2 or 3, got {num_dim}")

        if num_dim == 3 and (
            isinstance(self.embedding_num, RaggedBase)
            or isinstance(self.embedding_cat, RaggedBase)
        ):
            raise ValueError("cannot use num_dim=3 with ragged embeddings")

        # handle X_num
        if X_num is not None and X_num.shape[1] and self.embedding_num:
            X_num_emb = self.embedding_num(X_num)
        elif (X_num is not None and X_num.shape[1]) or not self.embedding_cat:
            if num_dim == 3:
                X_num_emb = X_num.reshape((X_num.shape[0], X_num.shape[1], 1))
            else:
                X_num_emb = X_num
        else:  # (X_num is None or not X_num.shape[1]) and self.embedding_cat
            X_num_emb = torch.empty(
                (X_cat.shape[0], 0, self.embedding_cat.embedding_size),
                device=self._device,
            )

        # handle X_cat
        if X_cat is not None and X_cat.shape[1] and self.embedding_cat:
            X_cat_emb = self.embedding_cat(X_cat)
        else:
            embed_dim = self.embedding_num.embedding_size if self.embedding_num else 1
            X_cat_emb = torch.empty((X_num.shape[0], 0, embed_dim), device=self._device)

        # reshape, if necessary
        if num_dim == 2:
            X_num_emb = X_num_emb.reshape((X_num_emb.shape[0], -1))
            X_cat_emb = X_cat_emb.reshape((X_cat_emb.shape[0], -1))

        if concat:
            return torch.cat([X_num_emb, X_cat_emb], dim=1)

        return X_num_emb, X_cat_emb

    def training_step(self, train_batch: List[Tensor], batch_idx: int) -> Dict:
        """
        Create predictions on batch and compute loss
        Used by PyTorch Lightning and the Scikit-learn-style classes
        Parameters
        ----------
        train_batch : torch.Tensor
        batch_idx : int
        Returns
        -------
        dict mapping "train_step_loss" to torch.Tensor loss value
        """
        X_num, X_cat, y = train_batch
        y_hat = self.forward(X_num, X_cat)
        loss = self.loss_fn(y_hat, y)
        if self.mlp_l1_reg > 0 or self.mlp_l2_reg > 0:
            w1, w2 = self.mlp_weight_sum()
            loss += self.mlp_l1_reg * w1 + self.mlp_l2_reg * w2
        if self.embedding_l1_reg > 0 or self.embedding_l2_reg > 0:
            w1, w2 = self.embedding_sum()
            loss += self.embedding_l1_reg * w1 + self.embedding_l2_reg * w2
        return {"loss": loss}

    def training_epoch_end(self, outputs: List[Dict]):
        """
        Computes and logs average train loss
        Used by PyTorch Lightning and the Scikit-learn-style classes
        Parameters
        ----------
        outputs : list of dicts
            outputs after all of the training steps
        Side effect
        -----------
        logs average loss as "train_loss"
        """
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss)

    def validation_step(self, val_batch: List[Tensor], batch_idx: int) -> Tuple[Tensor, Tensor]:
        """
        Calculate validation loss
        Used by PyTorch Lightning and the Scikit-learn-style classes
        Parameters
        ----------
        val_batch : torch.Tensor
        batch_idx : int
        Returns
        -------
        (y_pred, y_true) pair of tensors
        """
        X_num, X_cat, y = val_batch
        y_hat = self.forward(X_num, X_cat)
        return (y_hat, y)

    def validation_epoch_end(self, validation_step_outputs: List[Tuple[Tensor, Tensor]]):
        """
        Computes average validation loss
        Used by PyTorch Lightning and the Scikit-learn-style classes
        Parameters
        ----------
        validation_step_outputs : list of (y_pred, y_true) tensors
            outputs after all of the validation steps
        Side effect
        -----------
        logs average loss as "val_loss"
        """
        preds = torch.stack([y_hat for y_hat, _ in validation_step_outputs])
        ytrue = torch.stack([y for _, y in validation_step_outputs])
        val_loss = self.loss_fn(preds, ytrue)
        self.log("val_loss", val_loss)

    def custom_val_epoch_end(
        self,
        validation_step_outputs: List[Tuple[Tensor, Tensor]],
        extra_metrics: Iterable[Tuple[str, Callable]],
    ) -> Dict:
        """
        Calculate validation loss and other metrics if provided
        Parameters
        ----------
        validation_step_outputs : list of (y_pred, y_true) tensors
            outputs after all of the validation steps
        extra_metrics: list of (str, callable)
            tuples of str name and callable metric
        Returns
        -------
        dict
        - maps "val_step_loss" to torch.Tensor loss value
        - maps each name in `extra_metrics` to the metric value
        """
        preds = torch.cat([y_hat for y_hat, _ in validation_step_outputs], dim=0)
        ytrue = torch.cat([y for _, y in validation_step_outputs], dim=0)
        loss = self.loss_fn(preds, ytrue)
        info = {"val_loss": loss.item()}
        for name, fn in extra_metrics:
            loss = fn(preds, ytrue)
            info[name] = loss.item() if isinstance(loss, (np.ndarray, Tensor)) else loss
        return info

    def test_step(self, test_batch: List[Tensor], batch_idx: int) -> Dict:
        """
        Calculate test loss
        Used by PyTorch Lightning
        Parameters
        ----------
        test_batch : list of torch.Tensor
        batch_idx : int
        Returns
        -------
        dict
        - maps "test_step_loss" to torch.Tensor loss value
        """
        X_num, X_cat, y = test_batch
        y_hat = self.forward(X_num, X_cat)
        loss = self.loss_fn(y_hat, y)
        info = {"test_step_loss": loss}
        return info

    def test_epoch_end(self, outputs: List[Dict]):
        """
        Computes average test loss
        Used by PyTorch Lightning
        Parameters
        ----------
        outputs : list of dicts
            outputs after all of the test steps
        Side effect
        -----------
        logs average loss as "test_loss"
        """
        avg_loss = torch.stack([x["test_step_loss"] for x in outputs]).mean()
        self.log("test_loss", avg_loss)

    def set_optimizer(
        self,
        optimizer: Type[Optimizer] = torch.optim.Adam,
        opt_kwargs: Optional[Dict] = None,
        scheduler: Optional[Type[_LRScheduler]] = None,
        sch_kwargs: Optional[Dict] = None,
        sch_options: Optional[Dict] = None,
    ):
        """
        Set the models optimizer and, optionally, the learning rate schedule
        Parameters
        ----------
        optimizer : PyTorch Optimizer class, optional
            uninitialized subclass of Optimizer; default is torch.optim.Adam
        opt_kwargs : dict or None, optional
            dict of keyword arguments to initialize optimizer with;
            default is None
        scheduler : PyTorch scheduler class, optional
            default is None
        sch_kwargs : dict or None, optional
            dict of keyword arguments to initialize scheduler with;
            default is None
        sch_options : dict or None, optional
            options for PyTorch Lightning's call to `configure_optimizers`;
            ignore if not using PyTorch Lightning or no options are needed;
            with PyTorch Lightning, `ReduceLROnPlateau` requires "monitor";
            default is None
        """
        if sch_options is None:
            sch_options = {}
        if scheduler is ReduceLROnPlateau and "monitor" not in sch_options:
            sch_options["monitor"] = "val_loss"

        self.optimizer_info = {
            "optimizer": optimizer,
            "opt_kwargs": opt_kwargs if opt_kwargs is not None else {},
            "scheduler": scheduler,
            "sch_kwargs": sch_kwargs if sch_kwargs is not None else {},
            "sch_options": sch_options,
        }

    def configure_optimizers(
        self
    ) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
        """
        Initializes the optimizer and learning rate scheduler
        The optimizer and learning rate info needs to first be set with
        the `set_optimizer` method
        Used by PyTorch Lightning and the Scikit-learn-style classes
        Returns
        -------
        if no scheduler is being used
            initialized optimizer
        else
            tuple with
                list containing just the initialized optimizer
                dict containing scheduler information
        """
        if not self.optimizer_info:
            raise RuntimeError(
                "The optimizer and learning rate info needs to first be set "
                "with the `set_optimizer` method"
            )

        optimizer = self.optimizer_info["optimizer"]
        opt_kwargs = self.optimizer_info["opt_kwargs"]
        self.optimizer = optimizer(self.parameters(), **opt_kwargs)

        if self.optimizer_info["scheduler"] is None:
            return self.optimizer

        scheduler = self.optimizer_info["scheduler"]
        sch_kwargs = self.optimizer_info["sch_kwargs"]
        sch_options = self.optimizer_info["sch_options"]
        self.scheduler = {"scheduler": scheduler(self.optimizer, **sch_kwargs)}
        self.scheduler.update(sch_options)

        return [self.optimizer], [self.scheduler]

## create AutoInt

In [15]:
MODULE_INIT_DOC = """
Parameters
----------
task : {{"regression", "classification"}}
output_size : int
    number of final output values; i.e., number of targets for
    regression or number of classes for classification
embedding_num : EmbeddingBase or None
    initialized and fit embedding for numeric fields
embedding_cat : EmbeddingBase or None
    initialized and fit embedding for categorical fields
embedding_l1_reg : float, optional
    value for l1 regularization of embedding vectors; default is 0.0
embedding_l2_reg : float, optional
    value for l2 regularization of embedding vectors; default is 0.0
{}
mlp_hidden_sizes : int or iterable of int, optional
    sizes for the linear transformations between the MLP input and
    the output size needed based on the target; default is (512, 256, 128, 64)
mlp_activation : subclass of torch.nn.Module (uninitialized), optional
    default is nn.LeakyReLU
mlp_use_bn : boolean, optional
    whether to use batch normalization between MLP linear layers;
    default is True
mlp_bn_momentum : float, optional
    only used if `mlp_use_bn` is True; default is 0.01
mlp_ghost_batch : int or None, optional
    only used if `mlp_use_bn` is True; size of batch in "ghost batch norm";
    if None, normal batch norm is used; defualt is None
mlp_dropout : float, optional
    whether and how much dropout to use between MLP linear layers;
    `0.0 <= mlp_dropout < 1.0`; default is 0.0
mlp_use_skip : boolean, optional
    use a side path in the MLP containing just the optional leaky gate
    plus single linear layer; default is True
mlp_l1_reg : float, optional
    value for l1 regularization of MLP weights; default is 0.0
mlp_l2_reg : float, optional
    value for l2 regularization of MLP weights; default is 0.0
use_leaky_gate : boolean, optional
    whether to include "leaky gate" layers; default is True
loss_fn : "auto" or PyTorch loss function, optional
    default is "auto"
device : string or torch.device, optional
    default is "cpu"
"""

In [16]:
INIT_DOC = MODULE_INIT_DOC.format(
    textwrap.dedent(
        """\
        attn_embedding_size : int, optional
            default is 8
        attn_num_layers : int, optional
            default is 3
        attn_num_heads : int, optional
            default is 2
        attn_activation : subclass of torch.nn.Module or None, optional
            applied to the transformation tensors; default is None
        attn_use_residual : bool, optional
            default is True
        attn_dropout : float, optional
            amount of dropout to use on the product of queries and keys;
            default is 0.1
        attn_normalize : bool, optional
            whether to normalize each attn layer output; default is True"""
    )
)

In [42]:
class LeakyGate(nn.Module):
    """
    This performs an element-wise linear transformation followed by a chosen
    activation; the default activation is nn.LeakyReLU. Fields may be
    represented by individual values or vectors of values (i.e., embedded).
    Input needs to be shaped like (num_rows, num_fields) or
    (num_rows, num_fields, embedding_size)
    """

    def __init__(
        self,
        input_size: int,
        bias: bool = True,
        activation: Type[nn.Module] = nn.LeakyReLU,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        input_size : int
        bias : boolean, optional
            whether to include an additive bias; default is True
        activation : torch.nn.Module, optional
            default is nn.LeakyReLU
        device : string or torch.device, optional
            default is "cpu"
        """
        super().__init__()
        self.weight = nn.Parameter(torch.normal(mean=0, std=1.0, size=(1, input_size)))
        self.bias = nn.Parameter(torch.zeros(size=(1, input_size)), requires_grad=bias)
        self.activation = activation()
        self.to(device)

    def forward(self, X: Tensor) -> Tensor:
        """
        Transform the input tensor
        Parameters
        ----------
        X : torch.Tensor
        Return
        ------
        torch.Tensor
        """
        out = X
        if len(X.shape) > 2:
            out = out.reshape((X.shape[0], -1))
        out = out * self.weight + self.bias
        if len(X.shape) > 2:
            out = out.reshape(X.shape)
        out = self.activation(out)
        return out

In [50]:
class AttnInteractionLayer(nn.Module):
    """
    The attention interaction layer for the AutoInt model.
    Paper for the original AutoInt model: https://arxiv.org/pdf/1810.11921v2.pdf
    """

    def __init__(
        self,
        field_input_size: int,
        field_output_size: int = 8,
        num_heads: int = 2,
        activation: Optional[Type[nn.Module]] = None,
        use_residual: bool = True,
        dropout: float = 0.1,
        normalize: bool = True,
        ghost_batch_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        field_input_size : int
            original embedding size for each field
        field_output_size : int, optional
            embedding size after transformation; default is 8
        num_heads : int, optional
            number of attention heads; default is 2
        activation : subclass of torch.nn.Module or None, optional
            applied to the W tensors; default is None
        use_residual : bool, optional
            default is True
        dropout : float, optional
            default is 0.1
        normalize : bool, optional
            default is True
        ghost_batch_size : int or None, optional
            only used if `use_bn` is True; size of batch in "ghost batch norm";
            if None, normal batch norm is used; defualt is None
        device : string or torch.device, optional
            default is "cpu"
        """
        super().__init__()

        self.use_residual = use_residual

        self.W_q = _initialized_tensor(field_input_size, field_output_size, num_heads)
        self.W_k = _initialized_tensor(field_input_size, field_output_size, num_heads)
        self.W_v = _initialized_tensor(field_input_size, field_output_size, num_heads)

        if use_residual:
            self.W_r = _initialized_tensor(field_input_size, field_output_size * num_heads)
        else:
            self.W_r = None

        if activation:
            self.w_act = activation()
        else:
            self.w_act = nn.Identity()

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        if normalize:
            self.layer_norm = nn.LayerNorm(field_output_size * num_heads)
        else:
            self.layer_norm = nn.Identity()

        self.to(device)

    def forward(self, x: Tensor) -> Tensor:
        """
        Transform the input tensor with attention interaction
        Parameters
        ----------
        x : torch.Tensor
            3-d tensor; for example, embedded numeric and/or categorical values,
            or the output of a previous attention interaction layer
        Return
        ------
        torch.Tensor
        """
        # R : # rows
        # F, D : # fields
        # I : field embedding size in
        # O : field embedding size out
        # H : # heads
        num_rows, num_fields, _ = x.shape  # R, F, I

        # (R, F, I) * (I, O, H) -> (R, F, O, H)
        qrys = torch.tensordot(x, self.w_act(self.W_q), dims=([-1], [0]))
        keys = torch.tensordot(x, self.w_act(self.W_k), dims=([-1], [0]))
        vals = torch.tensordot(x, self.w_act(self.W_v), dims=([-1], [0]))
        if self.use_residual:
            rsdl = torch.tensordot(x, self.w_act(self.W_r), dims=([-1], [0]))

        product = torch.einsum("rdoh,rfoh->rdfh", qrys, keys)  # (R, F, F, H)

        alpha = F.softmax(product, dim=2)  # (R, F, F, H)
        alpha = self.dropout(alpha)

        # (R, F, F, H) * (R, F, O, H) -> (R, F, O, H)
        out = torch.einsum("rfdh,rfoh->rfoh", alpha, vals)
        out = out.reshape((num_rows, num_fields, -1))  # (R, F, O * H)
        if self.use_residual:
            out = out + rsdl  # (R, F, O * H)
        out = F.leaky_relu(out)
        out = self.layer_norm(out)

        return out

In [51]:
class AttnInteractionBlock(nn.Module):
    """
    A collection of AttnInteractionLayers, followed by an optional "leaky gate"
    and then a linear layer. This block is originally for the AutoInt model.
    Paper for the original AutoInt model: https://arxiv.org/pdf/1810.11921v2.pdf
    """

    def __init__(
        self,
        field_input_size: int,
        field_output_size: int = 8,
        num_layers: int = 3,
        num_heads: int = 2,
        activation: Optional[Type[nn.Module]] = None,
        use_residual: bool = True,
        dropout: float = 0.1,
        normalize: bool = True,
        ghost_batch_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        field_input_size : int
            original embedding size for each field
        field_output_size : int, optional
            embedding size after transformation; default is 8
        num_layers : int, optional
            number of attention layers; default is 3
        num_heads : int, optional
            number of attention heads per layer; default is 2
        activation : subclass of torch.nn.Module or None, optional
            applied to the W tensors; default is None
        use_residual : bool, optional
            default is True
        dropout : float, optional
            default is 0.0
        normalize : bool, optional
            default is True
        ghost_batch_size : int or None, optional
            only used if `use_bn` is True; size of batch in "ghost batch norm";
            if None, normal batch norm is used; defualt is None
        device : string or torch.device, optional
            default is "cpu"
        """
        super().__init__()

        layers = []
        for _ in range(num_layers):
            layers.append(
                AttnInteractionLayer(
                    field_input_size,
                    field_output_size,
                    num_heads,
                    activation,
                    use_residual,
                    dropout,
                    normalize,
                    ghost_batch_size,
                    device,
                )
            )
            field_input_size = field_output_size * num_heads

        self.layers = nn.Sequential(*layers)
        self.to(device)

    def forward(self, x: Tensor) -> Tensor:
        """
        Transform the input tensor
        Parameters
        ----------
        x : torch.Tensor
            3-d tensor, usually embedded numeric and/or categorical values
        Return
        ------
        torch.Tensor
        """
        out = self.layers(x)
        return out

In [57]:
class MLP(nn.Module):
    """
    A "multi-layer perceptron". This forms layes of fully-connected linear
    transformations, with opional batch norm, dropout, and an initial
    "leaky gate".
    Input should be shaped like (num_rows, num_fields)
    """

    def __init__(
        self,
        task: str,
        input_size: int,
        hidden_sizes: Union[int, Tuple[int, ...], List[int]],
        output_size: int,
        activation: Type[nn.Module] = nn.LeakyReLU,
        dropout: Union[float, Tuple[float], List[float]] = 0.0,
        dropout_first: bool = False,
        use_bn: bool = True,
        bn_momentum: float = 0.1,
        ghost_batch: Optional[int] = None,
        leaky_gate: bool = True,
        use_skip: bool = True,
        weighted_sum: bool = True,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        task : {"regression", "classification"}
        input_size : int
            the number of inputs into the first layer
        hidden_sizes : iterable of int
            intermediate sizes between `input_size` and `output_size`
        output_size : int
            the number of outputs from the last layer
        activation : subclass of torch.nn.Module (uninitialized), optional
            default is nn.LeakyReLU
        dropout : float or iterable of float
            should be between 0.0 and 1.0; if iterable of float, there
            should be one value for each hidden size, plus an additional
            value if `use_bn` is True
        dropout_first : boolean, optional
            whether to include dropout before the first fully-connected
            linear layer (and after "leaky_gate", if using);
            default is False
        use_bn : boolean, optional
            whether to use batch normalization; default is True
        bn_momentum : float, optional
            default is 0.1
        ghost_batch : int or None, optional
            only used if `use_bn` is True; size of batch in "ghost batch norm";
            if None, normal batch norm is used; defualt is None
        leaky_gate : boolean, optional
            whether to include a LeakyGate layer before the linear layers;
            default is True
        use_skip : boolean, optional
            use a side path containing just the optional leaky gate plust
            a single linear layer; default is True
        weighted_sum : boolean, optional
            only used with use_skip; when adding main MLP output with side
            "skip" output, use a weighted sum with learnable weight; default is True
        device : string or torch.device, optional
            default is "cpu"
        """
        super().__init__()

        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]

        dropout_len = len(hidden_sizes) + (1 if dropout_first else 0)

        if isinstance(dropout, float):
            dropout = [dropout] * dropout_len
        elif not len(dropout) == dropout_len:
            raise ValueError(
                f"expected a single dropout value or {dropout_len} values "
                f"({'one more than' if dropout_first else 'same as'} hidden_sizes)"
            )

        main_layers: List[nn.Module] = []

        if leaky_gate:
            main_layers.append(LeakyGate(input_size))

        if dropout_first and dropout[0] > 0:
            main_layers.append(nn.Dropout(dropout[0]))
            dropout = dropout[1:]

        input_size_i = input_size
        for hidden_size_i, dropout_i in zip(hidden_sizes, dropout):
            main_layers.append(nn.Linear(input_size_i, hidden_size_i, bias=(not use_bn)))
            if use_bn:
                if ghost_batch is None:
                    bnlayer = nn.BatchNorm1d(hidden_size_i, momentum=bn_momentum)
                else:
                    bnlayer = GhostBatchNorm(
                        hidden_size_i, ghost_batch, momentum=bn_momentum
                    )
                main_layers.append(bnlayer)
            main_layers.append(activation())
            if dropout_i > 0:
                main_layers.append(nn.Dropout(dropout_i))
            input_size_i = hidden_size_i

        main_layers.append(
            nn.Linear(input_size_i, output_size, bias=(task != "classification"))
        )

        self.main_layers = nn.Sequential(*main_layers)

        self.use_skip = use_skip
        if use_skip:
            skip_linear = nn.Linear(input_size, output_size, bias=(task != "classification"))
            if leaky_gate:
                self.skip_layers = nn.Sequential(LeakyGate(input_size), skip_linear)
            else:
                self.skip_layers = skip_linear
            if weighted_sum:
                self.mix = nn.Parameter(torch.tensor([0.0]))
            else:
                self.mix = torch.tensor([0.0], device=device)
        else:
            self.skip_layers = None
            self.mix = None

        self.to(device)

    def weight_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and squared weights, for regularization
        Return
        ------
        w1 : float
            sum of absolute value of weights
        w2 : float
            sum of squared weights
        """
        w1_sum = 0.0
        w2_sum = 0.0
        for layer_group in (self.main_layers, self.skip_layers):
            if layer_group is None:
                continue
            for layer in layer_group:
                if not isinstance(layer, nn.Linear):
                    continue
                w1_sum += layer.weight.abs().sum()
                w2_sum += (layer.weight ** 2).sum()
        return w1_sum, w2_sum

    def forward(self, X: Tensor) -> Tuple[float, float]:
        """
        Transform the input tensor
        Parameters
        ----------
        X : torch.Tensor
        Return
        ------
        torch.Tensor
        """
        out = self.main_layers(X)
        if self.use_skip:
            mix = torch.sigmoid(self.mix)
            skip_out = self.skip_layers(X)
            out = mix * skip_out + (1 - mix) * out
        return out

In [58]:
class AutoInt(BaseNN):
    """
    The AutoInt model, with a side MLP component, aka "AutoInt+", with modifications.
    See AutoInt.diagram() for the general structure of the model.
    Paper for the original AutoInt model: https://arxiv.org/pdf/1810.11921v2.pdf
    """

    def __init__(
        self,
        task: str,
        output_size: int,
        embedding_num: Optional[EmbeddingBase],
        embedding_cat: Optional[EmbeddingBase],
        embedding_l1_reg: float = 0.0,
        embedding_l2_reg: float = 0.0,
        attn_embedding_size: int = 8,
        attn_num_layers: int = 3,
        attn_num_heads: int = 2,
        attn_activation: Optional[Type[nn.Module]] = None,
        attn_use_residual: bool = True,
        attn_dropout: float = 0.1,
        attn_normalize: bool = True,
        attn_use_mlp: bool = True,
        mlp_hidden_sizes: Union[int, Tuple[int, ...], List[int]] = (512, 256, 128, 64),
        mlp_activation: Type[nn.Module] = nn.LeakyReLU,
        mlp_use_bn: bool = True,
        mlp_bn_momentum: float = 0.1,
        mlp_ghost_batch: Optional[int] = None,
        mlp_dropout: float = 0.0,
        mlp_use_skip: bool = True,
        mlp_l1_reg: float = 0.0,
        mlp_l2_reg: float = 0.0,
        use_leaky_gate: bool = True,
        weighted_sum: bool = True,
        loss_fn: Union[str, Callable] = "auto",
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__(
            task,
            embedding_num,
            embedding_cat,
            embedding_l1_reg,
            embedding_l2_reg,
            mlp_l1_reg,
            mlp_l2_reg,
            loss_fn,
            device,
        )

        device = torch.device(device)
        embed_info = check_uniform_embeddings(embedding_num, embedding_cat)

        if use_leaky_gate:
            self.attn_gate = LeakyGate(embed_info.output_size, device=device)
        else:
            self.attn_gate = nn.Identity()

        self.attn_interact = AttnInteractionBlock(
            field_input_size=embed_info.embedding_size,
            field_output_size=attn_embedding_size,
            num_layers=attn_num_layers,
            num_heads=attn_num_heads,
            activation=attn_activation,
            use_residual=attn_use_residual,
            dropout=attn_dropout,
            normalize=attn_normalize,
            ghost_batch_size=mlp_ghost_batch,
            device=device,
        )

        self.attn_final = MLP(
            task=task,
            input_size=embed_info.num_fields * attn_embedding_size * attn_num_heads,
            hidden_sizes=(mlp_hidden_sizes if mlp_hidden_sizes and attn_use_mlp else []),
            output_size=output_size,
            activation=mlp_activation,
            dropout=mlp_dropout,
            use_bn=mlp_use_bn,
            bn_momentum=mlp_bn_momentum,
            ghost_batch=mlp_ghost_batch,
            leaky_gate=use_leaky_gate,
            use_skip=mlp_use_skip,
            device=device,
        )

        if mlp_hidden_sizes:
            self.mlp = MLP(
                task=task,
                input_size=embed_info.output_size,
                hidden_sizes=mlp_hidden_sizes,
                output_size=output_size,
                activation=mlp_activation,
                dropout=mlp_dropout,
                use_bn=mlp_use_bn,
                bn_momentum=mlp_bn_momentum,
                ghost_batch=mlp_ghost_batch,
                leaky_gate=use_leaky_gate,
                use_skip=mlp_use_skip,
                device=device,
            )
            if weighted_sum:
                self.mix = nn.Parameter(torch.tensor([0.0], device=device))
            else:
                self.mix = torch.tensor([0.0], device=device)
        else:
            self.mlp = None
            self.mix = None

        #self.to(device)

    __init__.__doc__ = INIT_DOC

    @staticmethod
    def diagram():
        """ Print a text diagram of this model """
        gram = """\
        if mlp_hidden_sizes (default)
        -----------------------------
        X_num ─ Num. embedding ┐ ┌─ Attn ─ ... ─ Attn ─ MLP ─┐
                               ├─┤                           w+ ── output
        X_cat ─ Cat. embedding ┘ └────────── MLP ────────────┘
        if no mlp_hidden_sizes
        ----------------------
        X_num ─ Num. embedding ┬─ Attn ─ ... ─ Attn ─ Linear ─ output
        X_cat ─ Cat. embedding ┘ 
        splits are copies and joins are concatenations;
        'w+' is weighted element-wise addition;
        "Attn" is AutoInt's AttentionInteractionLayer
        """
        print("\n" + textwrap.dedent(gram))

    def mlp_weight_sum(self) -> Tuple[Tensor, Tensor]:
        """
        Sum of absolute value and square of weights in MLP layers
        Return
        ------
        w1 : sum of absolute value of MLP weights
        w2 : sum of squared MLP weights
        """
        w1, w2 = self.attn_final.weight_sum()
        if self.mlp is not None:
            side_w1, side_w2 = self.mlp.weight_sum()
            w1 += side_w1
            w2 += side_w2
        return w1, w2

    def forward(self, X_num: Tensor, X_cat: Tensor) -> Tensor:
        """
        Transform the input tensor
        Parameters
        ----------
        X_num : torch.Tensor
            numeric fields
        X_cat : torch.Tensor
            categorical fields
        Return
        ------
        torch.Tensor
        """
        embedded = self.embed(X_num, X_cat)
        out = self.attn_gate(embedded)
        out = self.attn_interact(out)
        out = self.attn_final(out.reshape((out.shape[0], -1)))
        if self.mlp is not None:
            embedded_2d = embedded.reshape((embedded.shape[0], -1))
            mix = torch.sigmoid(self.mix)
            out = mix * out + (1 - mix) * self.mlp(embedded_2d)
        return out

## create AutoIntClassifier

In [18]:
ESTIMATOR_INIT_DOC = """
Parameters
----------
embedding_num : "auto", embedding.EmbeddingBase, or None, optional
    embedding for numeric fields; default is auto
embedding_cat : "auto", embedding.EmbeddingBase, or None, optional
    embedding for categorical fields; default is auto
embedding_l1_reg : float, optional
    value for l1 regularization of embedding vectors; default is 0.0
embedding_l2_reg : float, optional
    value for l2 regularization of embedding vectors; default is 0.0
{}
mlp_hidden_sizes : int or iterable of int, optional
    sizes for the linear transformations between the MLP input and
    the output size needed based on the target; default is (512, 256, 128, 64)
mlp_activation : subclass of torch.nn.Module, optional
    default is nn.LeakyReLU
mlp_use_bn : boolean, optional
    whether to use batch normalization between MLP linear layers;
    default is True
mlp_bn_momentum : float, optional
    only used if `mlp_use_bn` is True; default is 0.01
mlp_ghost_batch : int or None, optional
    only used if `mlp_use_bn` is True; size of batch in "ghost batch norm";
    if None, normal batch norm is used; defualt is None
mlp_dropout : float, optional
    whether and how much dropout to use between MLP linear layers;
    `0.0 <= mlp_dropout < 1.0`; default is 0.0
mlp_l1_reg : float, optional
    value for l1 regularization of MLP weights; default is 0.0
mlp_l2_reg : float, optional
    value for l2 regularization of MLP weights; default is 0.0
mlp_use_skip : boolean, optional
    use a side path in the MLP containing just the optional leaky gate
    plus single linear layer; default is True
use_leaky_gate : boolean, optional
    whether to include "leaky gate" layers; default is True
loss_fn : "auto" or PyTorch loss function, optional
    if "auto", nn.CrossEntropyLoss is used; default is "auto"
seed : int or None, optional
    if int, seed for `torch.manual_seed` and `numpy.random.seed`;
    if None, no seeding is done; default is None
device : string or torch.device, optional
    default is "cpu"
"""

In [19]:
INIT_DOC = ESTIMATOR_INIT_DOC.format(
    textwrap.dedent(
        """\
        attn_embedding_size : int, optional
            default is 8
        attn_num_layers : int, optional
            default is 3
        attn_num_head : int, optional
            default is 2
        attn_activation : subclass of torch.nn.Module or None, optional
            applied to the transformation tensors; default is None
        attn_use_residual : bool, optional
            default is True
        attn_dropout : float, optional
            amount of dropout to use on the product of queries and keys;
            default is 0.1
        attn_normalize : bool, optional
            whether to normalize each attn layer output; default is True"""
    )
)

In [20]:
class AutoIntClassifier(BaseClassifier):
    """
    Scikit-learn style classification model for the AutoInt model
    """

    diagram = AutoInt.diagram

    def __init__(
        self,
        embedding_num: Optional[Union[str, EmbeddingBase]] = "auto",
        embedding_cat: Optional[Union[str, EmbeddingBase]] = "auto",
        embedding_l1_reg: float=0.0,
        embedding_l2_reg: float=0.0,
        attn_embedding_size: int = 8,
        attn_num_layers: int = 3,
        attn_num_heads: int = 2,
        attn_activation: Optional[Type[nn.Module]] = None,
        attn_use_residual: bool = True,
        attn_dropout: float = 0.1,
        attn_normalize: bool = True,
        attn_use_mlp: bool = True,
        mlp_hidden_sizes: Union[int, Tuple[int, ...], List[int]] = (512, 256, 128, 64),
        mlp_activation: Type[nn.Module] = nn.LeakyReLU,
        mlp_use_bn: bool = True,
        mlp_bn_momentum: float = 0.1,
        mlp_ghost_batch: Optional[int] = None,
        mlp_dropout: float = 0.0,
        mlp_l1_reg: float = 0.0,
        mlp_l2_reg: float = 0.0,
        mlp_use_skip: bool = True,
        use_leaky_gate: bool = True,
        weighted_sum: bool = True,
        loss_fn: Union[str, Callable] = "auto",
        seed: Union[int, None] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__(
            embedding_num=embedding_num,
            embedding_cat=embedding_cat,
            embedding_l1_reg=embedding_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            attn_embedding_size=attn_embedding_size,
            attn_num_layers=attn_num_layers,
            attn_num_heads=attn_num_heads,
            attn_activation=attn_activation,
            attn_use_residual=attn_use_residual,
            attn_dropout=attn_dropout,
            attn_normalize=attn_normalize,
            attn_use_mlp=attn_use_mlp,
            mlp_hidden_sizes=mlp_hidden_sizes,
            mlp_activation=mlp_activation,
            mlp_use_bn=mlp_use_bn,
            mlp_bn_momentum=mlp_bn_momentum,
            mlp_ghost_batch=mlp_ghost_batch,
            mlp_dropout=mlp_dropout,
            mlp_l1_reg=mlp_l1_reg,
            mlp_l2_reg=mlp_l2_reg,
            mlp_use_skip=mlp_use_skip,
            use_leaky_gate=use_leaky_gate,
            weighted_sum=weighted_sum,
            loss_fn=loss_fn,
            seed=seed,
            device=device,
        )
        self._model_class = AutoInt
        self._require_numeric_embedding = True

    __init__.__doc__ = INIT_DOC

## download dataset

In [22]:
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
datapath = Path('../data/forest_cover/forest-cover-type.csv')

In [23]:
if datapath.exists():
    print("File already exists.")
else:
    print("Downloading file...")
    datapath.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(url, stream=True)
    data = zlib.decompress(response.content, zlib.MAX_WBITS|32)
    with open(datapath, 'wb') as outfile:
        outfile.write(data)

Downloading file...


In [24]:
target = "Covertype"

soil_types = [f"Soil_Type{i}" for i in range(1, 41)]

bool_columns = [
    "Wilderness_Area1",
    "Wilderness_Area2",
    "Wilderness_Area3",
    "Wilderness_Area4",
] + soil_types

int_columns = [
    "Elevation",
    "Aspect",
    "Slope",
    "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
]

feature_columns = int_columns + bool_columns + [target]

In [25]:
train = pd.read_csv(datapath, header=None, names=feature_columns)
train.head()

Unnamed: 0,Elevation,Aspect,Slope,Horizontal_Distance_To_Hydrology,Vertical_Distance_To_Hydrology,Horizontal_Distance_To_Roadways,Hillshade_9am,Hillshade_Noon,Hillshade_3pm,Horizontal_Distance_To_Fire_Points,Wilderness_Area1,Wilderness_Area2,Wilderness_Area3,Wilderness_Area4,Soil_Type1,Soil_Type2,Soil_Type3,Soil_Type4,Soil_Type5,Soil_Type6,Soil_Type7,Soil_Type8,Soil_Type9,Soil_Type10,Soil_Type11,Soil_Type12,Soil_Type13,Soil_Type14,Soil_Type15,Soil_Type16,Soil_Type17,Soil_Type18,Soil_Type19,Soil_Type20,Soil_Type21,Soil_Type22,Soil_Type23,Soil_Type24,Soil_Type25,Soil_Type26,Soil_Type27,Soil_Type28,Soil_Type29,Soil_Type30,Soil_Type31,Soil_Type32,Soil_Type33,Soil_Type34,Soil_Type35,Soil_Type36,Soil_Type37,Soil_Type38,Soil_Type39,Soil_Type40,Covertype
0,2596,51,3,258,0,510,221,232,148,6279,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5
1,2590,56,2,212,-6,390,220,235,151,6225,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5
2,2804,139,9,268,65,3180,234,238,135,6121,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
3,2785,155,18,242,118,3090,238,238,122,6211,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2
4,2595,45,2,153,-1,391,220,234,150,6172,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5


In [26]:
X_num = train[int_columns + bool_columns]
y = train[[target]]

In [27]:
X_num_train, X_num_valid, y_train, y_valid = train_test_split(
    X_num.values, y.values, test_size=0.2, random_state=0
)

X_cat_train, X_cat_valid = None, None

In [28]:
mean = X_num_train.mean(axis=0, keepdims=True)
stdv = np.sqrt(X_num_train.var(ddof=1, axis=0, keepdims=True))

X_num_train = (X_num_train - mean) / stdv
X_num_valid = (X_num_valid - mean) / stdv

In [29]:
X_num_train.shape, X_num_valid.shape

((464809, 54), (116203, 54))

## create dataloader

In [63]:
def _validate_x(X, y, X_name, device):
    if isinstance(X, (Tensor, np.ndarray)):
        if not X.shape[0] == y.shape[0]:
            raise ValueError(
                f"shape mismatch; got y.shape[0] == {y.shape[0]}, "
                f"{X_name}.shape[0] == {X.shape[0]}"
            )
        if len(X.shape) != 2:
            raise ValueError(
                f"{X_name} should be 2-d; got shape {X.shape}"
            )
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).to(dtype=torch.float32)
    elif X is None:
        X = torch.empty((y.shape[0], 0))
    else:
        raise TypeError(f"input {X_name} should be Tensor, NumPy array, or None")
    return X


def _validate_y(y, task, device):
    if isinstance(y, (Tensor, np.ndarray)):
        if any(size == 0 for size in y.shape):
            raise ValueError(f"y has a zero-sized dimension; got shape {y.shape}")

        if task == "regression" and len(y.shape) == 1:
            y = y.reshape((-1, 1))
        elif task == "classification" and len(y.shape) == 2:
            if y.shape[1] != 1:
                raise ValueError("for classification y must be 1-d or 2-d with one column")
            y = y.reshape((-1,))
        elif len(y.shape) > 2:
            raise ValueError(f"y has too many dimensions; got shape {y.shape}")

        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y).to(dtype=torch.float32)
    else:
        raise TypeError("y should be Tensor or NumPy array")
    return y

In [64]:
class TabularDataLoader:
    """
    A DataLoader-like class that aims to be faster for tabular data.
    Based on `FastTensorDataLoader` by Jesse Mu
    https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """
    def __init__(
        self,
        task: str,
        X_num: Optional[Union[np.ndarray, Tensor]],
        X_cat: Optional[Union[np.ndarray, Tensor]],
        y: Union[np.ndarray, Tensor],
        batch_size: int = 32,
        shuffle: bool = False,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Parameters
        ----------
        task : {"regression", "classification"}
        X_num : PyTorch Tensor, NumPy array, or None
            numeric input fields
        X_cat : PyTorch Tensor, NumPy array, or None
            categorical input fields (represented as numeric values)
        y : PyTorch Tensor, NumPy array, or None
            target field
        batch_size : int, optional
            default is 32
        shuffle : bool, optional
            default is False
        device : string or torch.device, optional
            default is "cpu"
        """
        if X_num is None and X_cat is None:
            raise TypeError("X_num and X_cat cannot both be None")

        self.y = _validate_y(y, task, device)
        self.X_num = _validate_x(X_num, self.y, "X_num", device)
        self.X_cat = _validate_x(X_cat, self.y, "X_cat", device)
        self.dataset_len = y.shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.device = device

        # Calculate # batches
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            self.indices = torch.randperm(self.dataset_len)
        else:
            self.indices = None
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        if self.indices is not None:
            indices = self.indices[self.i:self.i+self.batch_size]
            batch = (
                torch.index_select(self.X_num, 0, indices).to(device=self.device),
                torch.index_select(self.X_cat, 0, indices).to(device=self.device),
                torch.index_select(self.y, 0, indices).to(device=self.device),
            )
        else:
            batch = (
                self.X_num[self.i:self.i+self.batch_size].to(device=self.device),
                self.X_cat[self.i:self.i+self.batch_size].to(device=self.device),
                self.y[self.i:self.i+self.batch_size].to(device=self.device),
            )
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

## define train function

In [68]:
def train(
    model: BaseNN,
    train_data: DataLoader,
    val_data: Optional[Union[DataLoader, Iterable[DataLoader]]] = None,
    num_epochs: int = 5,
    max_grad_norm: float = float("inf"),
    extra_metrics: Optional[List[Tuple[str, Callable]]] = None,
    scheduler_step: str = "epoch",
    early_stopping_metric: str = "val_loss",
    early_stopping_patience: Union[int, float] = float("inf"),
    early_stopping_mode: str = "min",
    early_stopping_window: int = 1,
    param_path: Optional[str] = None,
    callback: Optional[Callable] = None,
    verbose: bool = False,
):
    """
    Train the given model.
    Optimizer and optional scheduler should be already set with
    `model.set_optimizer()` and initialized with `model.configure_optimizer`.
    Parameters
    ----------
    model : BaseNN
        any PyTorch model from this package
    train_data : PyTorch DataLoader
    val_data : PyTorch DataLoader, iterable of DataLoader, or None; optional
        default is None
    num_epochs : int, optional
        default is 5
    max_grad_norm : float, optional
        value to clip gradient norms to; default is float("inf") (no clipping)
    extra_metrics : list of (str, callable) tuples or None, optional
        default is None
    scheduler_step : {"epoch", "batch"}, optional
        whether the scheduler step should be called each epoch or each batch;
        if "batch", the scheduler won't have access to validation metrics;
        default is "epoch"
    early_stopping_metric : str, optional
        should be "val_loss" or one of the passed `extra_metrics`;
        default is "val_loss"
    early_stopping_patience : int, float; optional
        default is float("inf") (no early stopping)
    early_stopping_mode : {"min", "max"}, optional
        use "min" if smaller values are better; default is "min"
    early_stopping_window : int, optional
        number of consecutive epochs to average to determine best;
        default is 1
    param_path : str or None, optional
        specify this to have the best parameters reloaded at end of training;
        default is None
    callback : callable or None, optional
        function to call after each epoch; the function will be passed a list
        of dictionaries, one dictionary for each epoch; default is None
    verbose : boolean, optional
        default is False
    Return
    ------
    list of dictionaries, one dictionary for each epoch
    """

    if isinstance(val_data, DataLoader):
        val_data = [val_data]

    if extra_metrics is None:
        extra_metrics = []

    val_metric_names = ["val_loss"] + [name for name, _ in extra_metrics]

    # check early stopping values
    if early_stopping_patience < float("inf"):
        if not val_data:
            raise ValueError("early_stopping_patience given without validation sets")
        if early_stopping_metric not in val_metric_names:
            raise ValueError(
                f"early_stopping_metric {repr(early_stopping_metric)} "
                "is not 'val_loss' and is not one of the extra_metrics"
            )
        if early_stopping_mode not in ("min", "max"):
            raise ValueError(
                "early_stopping_mode needs to be 'min' or 'max'; "
                f"got {repr(early_stopping_mode)}"
            )
        if not isinstance(early_stopping_window, int) or early_stopping_window <= 0:
            raise ValueError(
                "early_stopping_window needs to be a positive integer; "
                f"got {repr(early_stopping_window)}"
            )

    # check if model's sheduler needs to monitor a validation metric,
    # and check if the metric is in the validation metrics
    if model.scheduler is not None and "monitor" in model.scheduler:
        if not val_data:
            raise ValueError(
                "the model's scheduler expected to monitor "
                f"\'{model.scheduler['monitor']}\', but there is no validation data"
            )
        if model.scheduler["monitor"] not in val_metric_names:
            raise ValueError(
                f"scheduler monitor \'{model.scheduler['monitor']}\' "
                "not found in validation metrics"
            )

    if verbose:
        tmplt_main, tmplt_xtra, _ = _print_header(
            model=model,
            has_validation=bool(val_data),
            extra_metrics=extra_metrics
        )
    else:
        tmplt_main, tmplt_xtra = "", ""

    log_info = []
    es_count = 0
    es_best = float("inf") if early_stopping_mode == "min" else float("-inf")
    for _ in range(num_epochs):
        epoch_log_info = _epoch(
            model=model,
            train_data=train_data,
            val_data=val_data,
            max_grad_norm=max_grad_norm,
            extra_metrics=extra_metrics,
            scheduler_step=scheduler_step,
            verbose=verbose,
            tmplt_main=tmplt_main,
            tmplt_xtra=tmplt_xtra,
        )
        log_info.extend(epoch_log_info)
        es_best, es_count = _evaluate(
            model=model,
            metric=early_stopping_metric,
            patience=early_stopping_patience,
            mode=early_stopping_mode,
            window=early_stopping_window,
            best=es_best,
            count=es_count,
            log_info=log_info,
            param_path=param_path,
        )
        if callback is not None:
            callback(log_info)
        if es_count >= early_stopping_patience + 1:
            if verbose:
                best_epoch = log_info[-1]['epoch'] - es_count - early_stopping_window // 2
                print(
                    "Stopping early. "
                    f"Best epoch: {best_epoch}. "
                    f"Best {early_stopping_metric}: {es_best:11.6g}"
                )
            break

    if param_path:
        model.load_state_dict(torch.load(param_path))

    return log_info

In [71]:
LogInfo = Dict[str, Union[str, int, float, bool]]

In [None]:
def _print_header(
    model: nn.Module,
    has_validation: bool,
    extra_metrics: Iterable[Tuple[str, Callable]],
) -> Tuple[str, str, str]:
    top = "epoch  lrn rate"
    bar = "───────────────"
    tmplt_main = "{epoch:>5}  {lr:>#8.3g}"
    tmplt_xtra = "               "

    if hasattr(model, "mix") and model.mix is not None:
        top += "  non-mlp"
        bar += "─────────"
        tmplt_main += "  {mix:>#7.2g}"
        tmplt_xtra += " " * 9

    top += "  train loss"
    bar += "────────────"
    tmplt_main += "  {train_loss:>#10.4g}"
    tmplt_xtra += " " * 12

    if has_validation:
        top += "   val loss"
        bar += "───────────"
        tmplt_main += "  {val_loss:>#9.4g}"
        tmplt_xtra += "  {val_loss:>#9.4g}"
        for name, _ in extra_metrics:
            width = max(len(name), 9)
            precision = width - 5
            fmt = f"  {{{name}:>#{width}.{precision}g}}"
            top += " " * (2 + width - len(name)) + name
            bar += "─" * (width + 2)
            tmplt_main += fmt
            tmplt_xtra += fmt

    print(f"{top}\n{bar}", flush=True)

    return tmplt_main, tmplt_xtra, bar

In [None]:
def _scheduler_step(model: BaseNN, log_info: List[LogInfo]):
    if not model.scheduler:
        return
    if "monitor" in model.scheduler:
        metric = log_info[0][model.scheduler["monitor"]]
        model.scheduler["scheduler"].step(metric)
    else:
        model.scheduler["scheduler"].step()

In [None]:
def _train_batch(
    model: BaseNN,
    batch: List[Tensor],
    batch_idx: int,
    max_grad_norm: float,
    scheduler_step: str,
) -> float:
    model.optimizer.zero_grad(set_to_none=True)
    info = model.training_step(batch, batch_idx)
    info["loss"].backward()
    if max_grad_norm != float("inf"):
        clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    model.optimizer.step()
    if scheduler_step == "batch":
        _scheduler_step(model, [])
    return info["loss"].item()

In [None]:
def _val_epoch(
    model: BaseNN,
    loader: DataLoader,
    extra_metrics: Iterable[Tuple[str, Callable]],
) -> LogInfo:
    pbar = tqdm(
        enumerate(loader),
        leave=False,
        file=sys.stdout,
        total=len(loader),
    )
    pbar.set_description(f"Eval {model.num_epochs}")
    ypairs = []
    for batch_idx, batch in pbar:
        ypair = model.validation_step(batch, batch_idx)
        val_loss = model.loss_fn(ypair[0], ypair[1]).item()
        pbar.set_postfix({"Loss": f"{val_loss:#.2g}"})
        ypairs.append(ypair)

    val_info = {}
    metric_info = model.custom_val_epoch_end(ypairs, extra_metrics)
    for name, value in metric_info.items():
        if "_step" in name:
            name = name.replace("_step", "")
        val_info[name] = value

    return val_info

In [None]:
def _epoch_info(
    model: BaseNN,
    log_info: List[LogInfo],
    val_data: Optional[Iterable[DataLoader]],
    extra_metrics: Iterable[Tuple[str, Callable]],
    verbose: bool,
    tmplt_main: str,
    tmplt_xtra: str,
) -> List[LogInfo]:

    for param_group in model.optimizer.param_groups:
        log_info[0]["lr"] = param_group['lr']
        break

    if hasattr(model, "mix") and model.mix is not None:
        log_info[0]["mix"] = expit(model.mix.item())

    if val_data:
        model.eval()
        with torch.no_grad():
            for i, loader in enumerate(val_data):
                if i == 0:
                    tmplt = tmplt_main
                else:
                    tmplt = tmplt_xtra
                    log_info.append({})
                val_info = _val_epoch(model, loader, extra_metrics)
                log_info[-1].update(val_info)
                if verbose:
                    print(tmplt.format(**log_info[-1]), flush=True)
    elif verbose:
        print(tmplt_main.format(**log_info[-1]), flush=True)

    return log_info

In [None]:
def _epoch(
    model: BaseNN,
    train_data: DataLoader,
    val_data: Optional[Iterable[DataLoader]],
    max_grad_norm: float,
    extra_metrics: Iterable[Tuple[str, Callable]],
    scheduler_step: str,
    verbose: bool,
    tmplt_main: str,
    tmplt_xtra: str,
):
    model.train()

    log_info: List[LogInfo] = [{"epoch": model.num_epochs, "time": now()}]

    pbar = tqdm(
        enumerate(train_data),
        leave=False,
        file=sys.stdout,
        total=len(train_data),
    )
    pbar.set_description(f"Train {model.num_epochs}")
    for batch_idx, batch in pbar:
        loss = _train_batch(model, batch, batch_idx, max_grad_norm, scheduler_step)
        pbar.set_postfix({"Loss": f"{loss:#.2g}"})

    log_info[0]["train_loss"] = loss

    log_info = _epoch_info(
        model, log_info, val_data, extra_metrics, verbose, tmplt_main, tmplt_xtra
    )

    if scheduler_step == "epoch":
        _scheduler_step(model, log_info)

    model.num_epochs += 1

    return log_info

In [72]:
def _evaluate(model, metric, patience, mode, window, best, count, log_info, param_path):
    if (patience == float("inf") and not param_path) or not log_info:
        # either nothing requested or don't have the necessary information
        return best, count
    if len(log_info) < window:
        # not enough values to calculate best yet
        return best, count
    if metric not in log_info[0]:
        raise IndexError(f"cannot find early_stopping_metric '{metric}' in validation info")
    value = np.mean([info[metric] for info in log_info[-window:]])
    if (mode == "min" and value < best) or (mode == "max" and value > best):
        best = value
        count = 0
        if param_path:
            torch.save(model.state_dict(), param_path)
    else:
        count += 1
    return best, count

In [30]:
def accuracy(y_pred, y_true):
    y_pred = torch.argmax(y_pred, dim=1)
    acc = torch.eq(y_pred, y_true).to(dtype=torch.int).sum()
    return 100 * acc / y_pred.shape[0]

In [31]:
def now() -> str:
    """
    Return string representing current time
    Returns
    -------
    string with format '%Y-%m-%d %H:%M:%S'
    """
    timestamp = time.time()
    value = datetime.datetime.fromtimestamp(timestamp)
    return value.strftime('%Y-%m-%d %H:%M:%S')

In [36]:
EmbeddingInfo = namedtuple("EmbeddingInfo", ["num_fields", "output_size"])
UniformEmbeddingInfo = namedtuple(
    "EmbeddingInfo", ["num_fields", "embedding_size", "output_size"]
)

In [39]:
def _check_is_uniform(embedding, name):
    if embedding is None:
        return
    if not isinstance(embedding, UniformBase):
        raise TypeError(
            "only 'uniform' embeddings are allowed for this model; "
            f"{name} is not a uniform embedding"
        )

In [40]:
def check_uniform_embeddings(
    embedding_num: Optional[EmbeddingBase],
    embedding_cat: Optional[EmbeddingBase],
) -> EmbeddingInfo:
    """
    Check that embeddings are uniform, are not both None, and have same
    embedding_size
    Parameters
    ----------
    embedding_num : XyNN embedding or None
    embedding_cat : XyNN embedding or None
    Return
    ------
    UniformEmbeddingInfo NamedTuple containing
    - num_fields
    - embedding_size
    - output_size = num_fields * embedding_size
    """
    # check embedding sizes and get derived values
    if embedding_num is None and embedding_cat is None:
        raise ValueError("embedding_num and embedding_cat cannot both be None")

    _check_is_uniform(embedding_num, "embedding_num")
    _check_is_uniform(embedding_cat, "embedding_cat")

    if (
        embedding_num is not None
        and embedding_cat is not None
        and not embedding_num.embedding_size == embedding_cat.embedding_size
    ):
        raise ValueError(
            "embedding sizes must be the same for numeric and catgorical; got "
            f"{embedding_num.embedding_size} and {embedding_cat.embedding_size}"
        )

    num_fields = 0
    if embedding_num is not None:
        num_fields += embedding_num.num_fields
        embedding_size = embedding_num.embedding_size

    if embedding_cat is not None:
        num_fields += embedding_cat.num_fields
        embedding_size = embedding_cat.embedding_size

    return UniformEmbeddingInfo(num_fields, embedding_size, num_fields * embedding_size)

In [55]:
def _initialized_tensor(*sizes):
    weight = nn.Parameter(torch.Tensor(*sizes))
    nn.init.kaiming_uniform_(weight)
    return weight

## train model

In [None]:
model = AutoIntClassifier(
    attn_activation=None,
    attn_dropout=0.0,
    attn_normalize=False,
    mlp_hidden_sizes=(256, 192, 128, 64),
    mlp_activation=nn.LeakyReLU,
    mlp_use_bn=True,
    mlp_dropout=0.0,
    mlp_use_skip=True,
    use_leaky_gate=True,
    seed=SEED,
    device="cuda",
)

In [87]:
model.fit(
    X_num=X_num_train,
    X_cat=X_cat_train,
    y=y_train,
    optimizer=torch.optim.Adam,
    opt_kwargs={"lr": 1e-2},
    scheduler=torch.optim.lr_scheduler.StepLR,
    sch_kwargs={"step_size": 5, "gamma": 0.1 ** 0.125},
    val_sets=[[X_num_valid, X_cat_valid, y_valid]],
    extra_metrics=[("accuracy", accuracy)],
    num_epochs=100,
    batch_size=2048,
    early_stopping_patience=10,
    early_stopping_metric="accuracy",
    early_stopping_mode="max",
    #log_path=f"autoint_forest_log_seed{SEED}.txt",  # save epoch info to file
    #param_path=f"autoint_forest_seed{SEED}.pkl",  # auto-restore best model
    verbose=True,
)

epoch  lrn rate  non-mlp  train loss   val loss   accuracy
──────────────────────────────────────────────────────────


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    0    0.0100     0.44      0.3439     0.3647      84.87


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    1    0.0100     0.44      0.2680     0.2671      89.24


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    2    0.0100     0.43      0.2091     0.2580      89.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    3    0.0100     0.43      0.1674     0.2120      91.28


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    4    0.0100     0.43      0.1806     0.1966      92.03


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    5   0.00750     0.43      0.1302     0.1552      93.69


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    6   0.00750     0.43      0.1551     0.1492      94.04


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    7   0.00750     0.43      0.1405     0.1521      93.94


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    8   0.00750     0.43      0.1383     0.1435      94.27


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

    9   0.00750     0.43      0.1227     0.1445      94.22


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   10   0.00562     0.43      0.1174     0.1269      94.91


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   11   0.00562     0.43      0.1067     0.1184      95.24


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   12   0.00562     0.44      0.1165     0.1153      95.37


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   13   0.00562     0.44     0.08442     0.1118      95.59


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   14   0.00562     0.44      0.1012     0.1158      95.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   15   0.00422     0.44      0.1009     0.1018      95.92


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   16   0.00422     0.44     0.08654    0.09768      96.10


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   17   0.00422     0.44     0.08180     0.1025      96.05


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   18   0.00422     0.44     0.07438    0.09992      96.11


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   19   0.00422     0.44     0.08750    0.09885      96.12


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   20   0.00316     0.45     0.06623    0.09260      96.37


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   21   0.00316     0.45     0.06047    0.08991      96.53


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   22   0.00316     0.45     0.06317    0.09402      96.41


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   23   0.00316     0.45     0.06363    0.09163      96.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   24   0.00316     0.45     0.06928    0.09245      96.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   25   0.00237     0.45     0.04949    0.08544      96.78


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   26   0.00237     0.45     0.05488    0.08691      96.77


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   27   0.00237     0.46     0.06272    0.08703      96.75


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   28   0.00237     0.46     0.05094    0.08535      96.80


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   29   0.00237     0.46     0.06651    0.08683      96.82


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   30   0.00178     0.46     0.04953    0.08370      96.97


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   31   0.00178     0.46     0.03509    0.08243      97.01


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   32   0.00178     0.46     0.03549    0.08247      97.01


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   33   0.00178     0.46     0.04647    0.08590      96.94


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   34   0.00178     0.46     0.04866    0.08914      96.80


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   35   0.00133     0.47     0.03072    0.08198      97.08


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   36   0.00133     0.47     0.02793    0.08028      97.21


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   37   0.00133     0.47     0.03127    0.08120      97.20


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   38   0.00133     0.47     0.02825    0.08296      97.12


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   39   0.00133     0.47     0.03109    0.08360      97.13


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   40   0.00100     0.47     0.02281    0.08080      97.26


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   41   0.00100     0.47     0.03187    0.08281      97.20


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   42   0.00100     0.48     0.02911    0.08388      97.25


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   43   0.00100     0.48     0.02204    0.08367      97.27


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   44   0.00100     0.48     0.03115    0.08470      97.22


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   45  0.000750     0.48     0.02412    0.08143      97.40


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   46  0.000750     0.48     0.02111    0.08278      97.32


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   47  0.000750     0.48     0.02403    0.08407      97.31


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   48  0.000750     0.48     0.02247    0.08476      97.29


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   49  0.000750     0.48     0.03090    0.08550      97.31


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   50  0.000562     0.48     0.01749    0.08276      97.37


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   51  0.000562     0.48     0.02336    0.08529      97.36


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   52  0.000562     0.48     0.02138    0.08595      97.35


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   53  0.000562     0.49     0.01625    0.08587      97.35


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   54  0.000562     0.49     0.01718    0.08588      97.40


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   55  0.000422     0.49     0.02079    0.08570      97.39


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   56  0.000422     0.49     0.02219    0.08564      97.40


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   57  0.000422     0.49     0.02007    0.08774      97.39


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   58  0.000422     0.49     0.02703    0.08679      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   59  0.000422     0.49     0.01889    0.08846      97.40


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   60  0.000316     0.49     0.02194    0.08724      97.42


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   61  0.000316     0.49     0.01587    0.08738      97.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   62  0.000316     0.49     0.01131    0.08936      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   63  0.000316     0.49     0.01155    0.08899      97.42


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   64  0.000316     0.49     0.01301    0.08866      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   65  0.000237     0.49    0.007999    0.08897      97.45


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   66  0.000237     0.49     0.01272    0.08906      97.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   67  0.000237     0.49     0.01498    0.08954      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   68  0.000237     0.49     0.01109    0.09040      97.46


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   69  0.000237     0.49     0.01288    0.09058      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   70  0.000178     0.49     0.01599    0.09037      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   71  0.000178     0.49     0.01543    0.09150      97.43


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   72  0.000178     0.49     0.02014    0.09114      97.46


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   73  0.000178     0.49    0.008304    0.09163      97.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   74  0.000178     0.49     0.01466    0.09182      97.42


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   75  0.000133     0.49     0.01081    0.09094      97.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   76  0.000133     0.49     0.01631    0.09144      97.45


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   77  0.000133     0.49     0.01172    0.09209      97.44


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   78  0.000133     0.49     0.01096    0.09208      97.46


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   79  0.000133     0.49    0.009646    0.09284      97.46


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   80  0.000100     0.49     0.01522    0.09250      97.45


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   81  0.000100     0.49     0.01409    0.09262      97.45


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   82  0.000100     0.49    0.007342    0.09307      97.45


HBox(children=(FloatProgress(value=0.0, max=227.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))

   83  0.000100     0.50    0.009062    0.09332      97.45
Stopping early. Best epoch: 72. Best accuracy:     97.4631
