In [None]:
import abc
from typing import Any, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
from sklearn.utils.multiclass import type_of_target as _type_of_target

__all__ = [
    "type_of_target",
    "TargetType",
    "ContinuousTargetType",
    "ContinuousMultioutputTargetType",
    "BinaryTargetType",
    "MulticlassTargetType",
    "MulticlassMultioutputTargetType",
    "MultilabelIndicatorTargetType",
    "UnknownTargetType",
    # Module-level helper that merges several target-type labels into one.
    "resolve_target_type",
]

# NOTE(review): T is not referenced in the visible code; remove it if nothing else uses it.
T = TypeVar("T")


class TargetType(abc.ABC):
    """Abstract description of a supervised-learning target.

    Concrete subclasses provide two class attributes, ``label`` and ``description``,
    and implement the three abstract methods below.
    """

    label: str  # short identifier, e.g. "binary" or "continuous"
    description: Optional[str]  # human-readable summary, shown by __repr__

    @abc.abstractmethod
    def is_multioutput(self) -> bool:
        """Report whether the target has more than one output."""

    @abc.abstractmethod
    def is_continuous(self) -> bool:
        """Report whether the target values are continuous (regression)."""

    @abc.abstractmethod
    def validate(self, y: Union[np.ndarray, list], **kwargs) -> np.ndarray:
        """Check ``y`` against this target type and return the (possibly converted) target."""

    def __repr__(self) -> str:
        # Every subclass reports itself under the base-class name.
        template = 'TargetType(label="{}", description="{}")'
        return template.format(self.label, self.description)


class ContinuousTargetType(TargetType):
    """1-D numeric regression target (sklearn label ``"continuous"``)."""

    label = "continuous"
    description = "Regression (1D continuous targets)"

    def is_multioutput(self) -> bool:
        return False

    def is_continuous(self) -> bool:
        return True

    def validate(self, y: Union[np.ndarray, list], **kwargs) -> np.ndarray:
        """Check that ``y`` is a 1-D numeric array and return it as an ndarray.

        Args:
            y: Array-like target values.
            **kwargs: Ignored. Accepted so that this method matches the abstract
                ``TargetType.validate(y, **kwargs)`` signature; generic calls such as
                ``validate(y, classes_=...)`` then do not raise TypeError for this type.

        Returns:
            ``np.asarray(y)``.

        Raises:
            ValueError: If ``y`` is not 1-dimensional, or its dtype is not a numpy
                number type (boolean and string arrays are rejected).
        """
        y = np.asarray(y)
        if y.ndim != 1:
            raise ValueError("continuous target should be 1-dimensional")
        if not np.issubdtype(y.dtype, np.number):
            raise ValueError("continuous target should contain numeric values")
        return y


class ContinuousMultioutputTargetType(TargetType):
    """2-D numeric regression target, one column per output (sklearn label ``"continuous-multioutput"``)."""

    label = "continuous-multioutput"
    description = "Multioutput Regression (2D continuous targets)"

    def is_multioutput(self) -> bool:
        return True

    def is_continuous(self) -> bool:
        return True

    def validate(self, y: Union[np.ndarray, list], **kwargs) -> np.ndarray:
        """Check that ``y`` is a 2-D numeric array and return it as an ndarray.

        Args:
            y: Array-like target values.
            **kwargs: Ignored. Accepted so that this method matches the abstract
                ``TargetType.validate(y, **kwargs)`` signature; generic calls such as
                ``validate(y, classes_=...)`` then do not raise TypeError for this type.

        Returns:
            ``np.asarray(y)``.

        Raises:
            ValueError: If ``y`` is not 2-dimensional, or its dtype is not a numpy
                number type.
        """
        y = np.asarray(y)
        if y.ndim != 2:
            raise ValueError("continuous multioutput target should be 2-dimensional")
        if not np.issubdtype(y.dtype, np.number):
            raise ValueError(
                "continuous multioutput target should contain numeric values"
            )
        return y


class BinaryTargetType(TargetType):
    """1-D target with exactly two distinct labels (sklearn label ``"binary"``)."""

    label = "binary"
    description = "Binary Classification (1D binary targets)"

    def is_multioutput(self) -> bool:
        return False

    def is_continuous(self) -> bool:
        return False

    def validate(
        self,
        y: Union[np.ndarray, list],
        classes_: Union[List[str], Tuple[str], None] = None,
    ) -> np.ndarray:
        """Check a binary target and optionally encode it as class indices.

        Args:
            y: 1-D array-like of labels; must contain exactly two distinct values.
            classes_: Optional pair of labels. When given, every label in ``y`` must be
                one of them and is replaced by its position in ``classes_``.

        Returns:
            ``np.asarray(y)`` when ``classes_`` is None, otherwise an integer array of
            positions into ``classes_``.

        Raises:
            ValueError: If ``y`` is not 1-D or does not have exactly 2 distinct values,
                if ``classes_`` does not have length 2, or if ``y`` contains a label
                not in ``classes_``.
        """
        labels = np.asarray(y)
        if labels.ndim != 1:
            raise ValueError("binary target should be 1-dimensional")

        observed = np.unique(labels)
        if observed.size != 2:
            raise ValueError("binary target should have exactly 2 unique values")

        if classes_ is None:
            return labels

        if len(classes_) != 2:
            raise ValueError("classes_ must contain exactly 2 unique values")
        if not set(observed) <= set(classes_):
            raise ValueError("y contains classes not in classes_")

        position = dict(zip(classes_, range(len(classes_))))
        return np.array([position[value] for value in labels])


class MulticlassTargetType(TargetType):
    """1-D discrete target with several classes (sklearn label ``"multiclass"``)."""

    label = "multiclass"
    description = "Multiclass Classification (>2 discrete classes, 1D)"

    def is_multioutput(self) -> bool:
        return False

    def is_continuous(self) -> bool:
        return False

    def validate(
        self,
        y: Union[np.ndarray, list],
        classes_: Union[List[str], Tuple[str], None] = None,
    ) -> np.ndarray:
        """Check a multiclass target and optionally encode it as class indices.

        Args:
            y: 1-D array-like of labels with at least two distinct values.
            classes_: Optional sequence of distinct labels. When given, every label in
                ``y`` must appear in it and is replaced by its position in ``classes_``.

        Returns:
            ``np.asarray(y)`` when ``classes_`` is None, otherwise an integer array of
            positions into ``classes_``.

        Raises:
            ValueError: If ``y`` is not 1-D or has fewer than 2 distinct values, if
                ``classes_`` has fewer than 2 entries or repeats a label, or if ``y``
                contains a label not in ``classes_``.
        """
        y = np.asarray(y)
        if y.ndim != 1:
            raise ValueError("multiclass target should be 1-dimensional")
        unique_values = np.unique(y)
        if len(unique_values) < 2:
            raise ValueError("multiclass target should have at least 2 unique values")
        if classes_ is None:
            return y
        if len(classes_) < 2:
            raise ValueError("classes_ must contain at least 2 unique values")
        if len(set(classes_)) != len(classes_):
            # A repeated label would be mapped to the position of its last occurrence,
            # silently shifting the encoding; reject it instead.
            raise ValueError("classes_ must not contain duplicate values")
        if not set(unique_values).issubset(set(classes_)):
            raise ValueError("y contains classes not in classes_")
        cls_to_idx = {cls: i for i, cls in enumerate(classes_)}
        return np.array([cls_to_idx[v] for v in y])


class MulticlassMultioutputTargetType(TargetType):
    """2-D discrete target, one class label per output column (sklearn label ``"multiclass-multioutput"``)."""

    label = "multiclass-multioutput"
    description = "Multiclass-Multioutput (>2 discrete classes, 2D)"

    def is_multioutput(self) -> bool:
        return True

    def is_continuous(self) -> bool:
        return False

    def validate(
        self,
        y: Union[np.ndarray, list],
        classes_: Union[Sequence[Sequence[str]], None] = None,
    ) -> np.ndarray:
        """Check a 2-D multiclass-multioutput target and optionally index-encode it per column.

        Args:
            y: Array-like with one column per output.
            classes_: Optional per-column class lists; ``classes_[j]`` lists the allowed
                labels of column ``j``. When given, each label in column ``j`` is
                replaced by its position in ``classes_[j]``.

        Returns:
            ``np.asarray(y)`` when ``classes_`` is None, otherwise an integer array of the
            same shape holding per-column class positions.

        Raises:
            ValueError: If ``y`` is not 2-D, a column has fewer than 2 distinct values,
                ``classes_`` does not have one entry per column, a column's class list
                repeats a label, or a column contains a label missing from its list.
        """
        y = np.asarray(y)
        if y.ndim != 2:
            raise ValueError("multiclass-multioutput target should be 2-dimensional")
        if classes_ is None:
            if any(len(np.unique(column)) < 2 for column in y.T):
                msg = (
                    "multiclass-multioutput target should have at "
                    "least 2 unique values in each column"
                )
                raise ValueError(msg)
            return y
        if len(classes_) != y.shape[1]:
            raise ValueError(f"classes_ must contain {y.shape[1]} tasks")
        for i, (task_y, task_classes) in enumerate(zip(y.T, classes_)):
            unique_values = np.unique(task_y)
            if len(unique_values) < 2:
                raise ValueError(f"column {i} should have at least 2 unique values")
            if len(set(task_classes)) != len(task_classes):
                # A repeated label would be mapped to the position of its last
                # occurrence, silently shifting this column's encoding.
                raise ValueError(f"classes_ for column {i} must not contain duplicate values")
            if not set(unique_values).issubset(set(task_classes)):
                raise ValueError(f"column {i} contains classes not in classes_")
        cls_to_idx = [
            {cls: idx for idx, cls in enumerate(column_classes)}
            for column_classes in classes_
        ]
        return np.array([[cls_to_idx[j][v] for j, v in enumerate(row)] for row in y])


class MultilabelIndicatorTargetType(TargetType):
    """2-D 0/1 indicator matrix, one column per label (sklearn label ``"multilabel-indicator"``)."""

    label = "multilabel-indicator"
    description = "Multilabel Classification (2D binary indicator matrix)"

    def is_multioutput(self) -> bool:
        return True

    def is_continuous(self) -> bool:
        return False

    def validate(self, y: Union[np.ndarray, list], **kwargs) -> np.ndarray:
        """Check that ``y`` is a 2-D matrix of 0/1 indicators.

        Integer, boolean and float arrays are accepted as long as every entry equals
        0 or 1.

        Args:
            y: Array-like indicator matrix.
            **kwargs: Ignored. Accepted so that this method matches the abstract
                ``TargetType.validate(y, **kwargs)`` signature.

        Returns:
            Integer-typed input is returned as ``np.asarray(y)`` unchanged. Boolean and
            float input is converted with ``astype(int)``.

        Raises:
            ValueError: If ``y`` is not 2-D, has a non-numeric/non-boolean dtype, or
                contains any value other than 0 and 1 (including NaN).
        """
        y = np.asarray(y)
        if y.ndim != 2:
            msg = "multilabel-indicator target should be 2-dimensional"
            raise ValueError(msg)
        # sklearn also reports boolean and integral-float 0/1 matrices as
        # "multilabel-indicator" (e.g. a float torch.Tensor of labels). Accept those
        # dtypes here; the 0/1 value check below still rejects e.g. 0.5 or NaN.
        if y.dtype.kind not in "biuf":
            msg = "multilabel-indicator target should contain integer values"
            raise ValueError(msg)
        if not np.all(np.isin(y, [0, 1])):
            msg = "multilabel-indicator target should contain only 0 and 1"
            raise ValueError(msg)
        if y.dtype.kind in "iu":
            return y
        # Normalise boolean / float indicators to integers.
        return y.astype(int)


class UnknownTargetType(TargetType):
    """Fallback returned by ``type_of_target`` for any label without a dedicated subclass."""

    label = "unknown"
    description = "Unknown (possibly 3D array, sequence of sequences, or array of non-sequence objects)"

    def is_multioutput(self) -> bool:
        """Always raises ValueError: the output structure cannot be determined."""
        raise ValueError(
            "cannot determine if target is multioutput for unknown target type"
        )

    def is_continuous(self) -> bool:
        """Always raises ValueError: continuity cannot be determined."""
        raise ValueError(
            "cannot determine if target is continuous for unknown target type"
        )

    def validate(self, y: Union[np.ndarray, list], **kwargs) -> np.ndarray:
        """Pass ``y`` through untouched.

        There is no check and no ndarray conversion; ``kwargs`` are ignored.
        """
        return y


def type_of_target(y: Any) -> TargetType:
    """Classify ``y`` with sklearn's ``type_of_target`` and wrap the result.

    Returns a new instance of the TargetType subclass whose ``label`` equals the
    sklearn label. Any other label yields ``UnknownTargetType()``.
    """
    sklearn_label: str = _type_of_target(y)
    target_classes = {
        "continuous": ContinuousTargetType,
        "continuous-multioutput": ContinuousMultioutputTargetType,
        "binary": BinaryTargetType,
        "multiclass": MulticlassTargetType,
        "multiclass-multioutput": MulticlassMultioutputTargetType,
        "multilabel-indicator": MultilabelIndicatorTargetType,
    }
    return target_classes.get(sklearn_label, UnknownTargetType)()


# Ordering of target-type labels from most specific to most general. When
# resolve_target_type combines two labels, the higher-ranked one wins unless an
# explicit override applies.
TARGET_TYPE_RANK = {
    target_label: rank
    for rank, target_label in enumerate(
        (
            "binary",
            "multiclass",
            "multilabel-indicator",
            "multiclass-multioutput",
            "continuous",
            "continuous-multioutput",
            "unknown",
        )
    )
}


def resolve_target_type(*ps: Optional[str]) -> str:
    """Combine several target-type labels into a single label.

    ``None`` entries are ignored. A single remaining label is returned as-is.
    Otherwise labels are folded pairwise from left to right. Each pair is ordered
    by ``TARGET_TYPE_RANK``, and the higher-ranked label wins unless the pair
    appears in the override table below.

    Args:
        *ps: Target-type labels (keys of ``TARGET_TYPE_RANK``) or ``None``.

    Returns:
        The resolved target-type label.

    Raises:
        ValueError: If no labels are given, or every given label is ``None``.
        KeyError: If, when two or more labels are combined, one of them is not a
            key of ``TARGET_TYPE_RANK``.
    """
    labels = [p for p in ps if p is not None]
    if not labels:
        # Checked after filtering, so an all-None call gets this message instead of
        # an opaque "not enough values to unpack" error.
        msg = "at least one target type must be provided"
        raise ValueError(msg)
    if len(labels) == 1:
        return labels[0]
    if len(labels) > 2:
        resolved = labels[0]
        for label in labels[1:]:
            resolved = resolve_target_type(resolved, label)
        return resolved
    low, high = sorted(labels, key=lambda label: TARGET_TYPE_RANK[label])
    # NOTE(review): binary + multilabel-indicator resolves to "unknown", but
    # multiclass + multilabel-indicator falls through to "multilabel-indicator";
    # confirm this asymmetry is intended.
    overrides = {
        ("binary", "multilabel-indicator"): "unknown",
        ("binary", "multiclass-multioutput"): "unknown",
        ("multilabel-indicator", "continuous"): "continuous-multioutput",
        ("multiclass-multioutput", "continuous"): "continuous-multioutput",
    }
    return overrides.get((low, high), high)


In [None]:
if __name__ == "__main__":
    # Example targets covering every TargetType subclass; each is classified and
    # validated below. Later cells reuse `test_cases`.
    test_cases = {
        "Regression": [0.1, 0.2, 0.3, 0.4],
        "Multioutput Regression": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        "Binary Classification [str]": ["yes", "no", "yes", "no"],
        "Binary Classification [bool]": [True, False, True, True],
        # Keys must be distinct: a repeated dict key silently drops the earlier case.
        "Binary Classification [int 0/1]": [1, 0, 1, 1, 1],
        "Binary Classification [int 5/6]": [5, 6, 5, 6, 6],
        "Multiclass Classification": ["cat", "dog", "fish", "dog", "cat"],
        "Multiclass-Multioutput": [
            ["sunny", "warm"],
            ["rainy", "cold"],
            ["sunny", "cold"],
        ],
        "Multilabel Classification": [[0, 1], [1, 1], [0, 0]],
        "Unknown": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
    }
    for name, y in test_cases.items():
        target = type_of_target(y)
        print(f"\n{name}:")
        print(f"\t{target}")
        print(f"\tValidated: {target.validate(y)}")

In [None]:
# Encode the multiclass-multioutput example against explicit per-column class lists.
# The header label is set explicitly here; the loop variable `name` from the previous
# cell would still hold the last key of `test_cases`.
case_name = "Multiclass-Multioutput"
y = test_cases[case_name]
target = type_of_target(y)
print(f"\n{case_name}:")
print(f"\t{target}")
classes_ = [["rainy", "sunny"], ["warm", "cold"]]
print(f"\tValidated: {target.validate(y, classes_=classes_)}")

In [None]:
# Look at the multiclass-multioutput example one output column at a time.
y = np.asarray(test_cases["Multiclass-Multioutput"])
for column in y.T:
    print(column)

In [None]:
# Per-column flag: True where a column has fewer than 2 distinct values.
[np.unique(column).size < 2 for column in y.T]

In [None]:
# Probe: how is a 2-D integer array with values {1, 2} (rather than {0, 1}) classified?
type_of_target(y=[[1, 2], [2, 1]])

In [None]:
import torch

# Check which numpy dtype class a torch integer tensor converts to.
y = torch.tensor([1, 2, 3])
type(np.asarray(y).dtype)

In [None]:
# y_true = torch.Tensor([1, 2, 0, 2, 1, 1, 0])
# y_true = torch.Tensor([[0, 1, 0], [1, 0, 0]])
# y_true = torch.Tensor([[0, 1, 0], [1, 1, -1]])

In [None]:
# Two samples, each holding two one-hot rows over three classes.
y_pred = [
    [[0, 0, 1], [0, 1, 0]],
    [[0, 1, 0], [1, 0, 0]],
]

y_true = torch.Tensor(y_pred)
y_true.shape

In [None]:

def is_one_hot(arr):
    """Return True if ``arr`` looks one-hot encoded.

    Every entry must be 0 or 1. For inputs that are not 1-D, each slice along the
    last dimension must additionally sum to exactly 1. 1-D inputs only get the 0/1
    check. Non-tensor inputs are converted with ``torch.tensor``.
    """
    tensor = arr if torch.is_tensor(arr) else torch.tensor(arr)
    only_zeros_and_ones = bool(((tensor == 0) | (tensor == 1)).all())
    if tensor.dim() == 1:
        return only_zeros_and_ones
    if not only_zeros_and_ones:
        return False
    return bool((tensor.sum(dim=-1) == 1).all())

In [None]:
# 3-D case: two samples, each with two one-hot rows over three classes.
y_pred = [
    [[0, 0, 1], [0, 1, 0]],
    [[0, 1, 0], [1, 0, 0]],
]

is_one_hot(y_pred)

In [None]:
# 2-D case: one-hot rows over three classes for class indices 2, 1, 1, 0.
y_pred = np.eye(3, dtype=int)[[2, 1, 1, 0]].tolist()

is_one_hot(y_pred)

In [None]:
# 1-D case: only the 0/1 check applies; there is no per-row sum check.
y_pred = [1, 0, 1, 0]

is_one_hot(y_pred)