In [61]:
import typing
from typing import Optional, Union

from marshmallow import INCLUDE, Schema, ValidationError, fields, post_load


class BaseSchema(Schema):
    """Schema base class that can turn loaded data into a Python object.

    Subclasses set ``_data_class`` to the class that should be built from the
    loaded data. When it is left as ``None`` the loaded data is returned
    unchanged.
    """

    _data_class = None  # class instantiated from loaded data; None keeps the data as-is

    @post_load
    def make_obj(self, data, **kwargs):
        """Build a ``_data_class`` instance from ``data`` after loading.

        Args:
            data: The loaded data; passed to ``_data_class`` as keyword arguments.
            **kwargs: Extra arguments supplied by the ``post_load`` hook (unused).

        Returns:
            ``data`` itself when ``_data_class`` is ``None``, otherwise
            ``_data_class(**data)``.
        """
        target_cls = self._data_class
        if target_cls is not None:
            return target_cls(**data)
        return data


class PolymorphicField(fields.Field):
    """Field that picks a schema per value, based on a discriminator.

    Serialization: the schema is looked up in ``schemas`` using the value's
    ``seri_attr`` attribute (or dict key), or the lower-cased class name of the
    value when ``seri_attr`` is empty. A value with no matching schema raises
    ``ValidationError``.

    Deserialization: the schema is looked up using the ``deseri_attr`` key of
    the incoming dict, first in ``deseri_schemas`` and then in ``schemas``.
    When nothing matches, the dict is returned as-is if ``deseri_as_dict`` is
    true, otherwise ``ValidationError`` is raised.

    All schema keys are compared case-insensitively (they are lower-cased).

    Args:
        schemas: Mapping of discriminator value to a schema instance, a schema
            class, or a zero-argument callable returning a schema. Must not be
            empty.
        seri_attr: Name of the attribute (or dict key) on the value being
            serialized that holds the discriminator. Empty means "use the
            value's class name".
        deseri_attr: Key in the incoming dict that holds the discriminator.
        deseri_schemas: Optional mapping, same shape as ``schemas``, that takes
            precedence over ``schemas`` during deserialization.
        deseri_as_dict: Whether to fall back to a plain dict when no schema
            matches, or when the matched schema's ``load`` raises ``TypeError``.
        **kwargs: Forwarded to ``fields.Field``.

    Raises:
        ValueError: If ``schemas`` is ``None`` or empty.
    """

    def __init__(
        self,
        schemas: dict[str, Union[Schema, type[Schema], typing.Callable[[], Schema]]],
        seri_attr: str = "",
        deseri_attr: str = "",
        deseri_schemas: Optional[dict[str, Union[Schema, type[Schema], typing.Callable[[], Schema]]]] = None,
        deseri_as_dict: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not schemas:
            raise ValueError("serialize schemas can't be empty.")
        self.schemas = self._build_schema_map(schemas)
        self.seri_attr = seri_attr
        self.deseri_attr = deseri_attr
        self.deseri_schemas = self._build_schema_map(deseri_schemas or {})
        self.deseri_as_dict = deseri_as_dict

    @staticmethod
    def _build_schema_map(schemas):
        """Lower-case the keys and instantiate any schema class/factory."""
        return {
            key.lower(): value if isinstance(value, Schema) else value()
            for key, value in schemas.items()
        }

    @staticmethod
    def _load_as_dict(value):
        """Load ``value`` through an empty schema that keeps every key."""
        return Schema.from_dict({})(unknown=INCLUDE).load(value)

    def _serialize(self, value, attr, obj, **kwargs):
        """Dump ``value`` with the schema matching its discriminator.

        Returns:
            ``None`` when ``value`` is ``None``, otherwise the result of the
            matched schema's ``dump``.

        Raises:
            ValidationError: If the discriminator is missing, is not a string,
                or does not match any key in ``schemas``.
        """
        if value is None:
            return None
        if self.seri_attr:
            # Accept both objects and plain dicts as serialization input.
            if isinstance(value, dict):
                raw_key = value.get(self.seri_attr)
            else:
                raw_key = getattr(value, self.seri_attr, None)
            # A missing or non-string discriminator cannot match any schema;
            # report it as a validation problem rather than an AttributeError.
            if not isinstance(raw_key, str):
                raise ValidationError(f"unknown serializes type: {raw_key}")
            schema_key = raw_key.lower()
        else:
            schema_key = type(value).__name__.lower()
        schema = self.schemas.get(schema_key)
        if schema is None:
            raise ValidationError(f"unknown serializes type: {schema_key}")
        return schema.dump(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """Load ``value`` with the schema matching its discriminator.

        Returns:
            ``value`` unchanged when it is not a dict; otherwise the result of
            the matched schema's ``load``, or a plain dict when falling back.

        Raises:
            ValidationError: If no schema matches and ``deseri_as_dict`` is
                false.
            TypeError: Re-raised from the matched schema's ``load`` when
                ``deseri_as_dict`` is false.
        """
        if not isinstance(value, dict):  # non-dict input is passed through untouched
            return value
        raw_key = value.get(self.deseri_attr, "")
        # The discriminator may be null or a number in the incoming data; only
        # strings can match a schema key.
        schema_key = raw_key.lower() if isinstance(raw_key, str) else None
        schema = None
        if schema_key is not None:
            # Deserialization-specific schemas win over the shared ones.
            schema = self.deseri_schemas.get(schema_key, self.schemas.get(schema_key))
        if schema is None:
            if self.deseri_as_dict:
                return self._load_as_dict(value)
            unknown_key = raw_key if schema_key is None else schema_key
            raise ValidationError(f"unknown deserializes type: {unknown_key}")
        try:
            return schema.load(value)
        except TypeError:
            # Raised e.g. when a post_load hook cannot build its object from
            # the loaded keys; fall back to a plain dict when allowed.
            if self.deseri_as_dict:
                return self._load_as_dict(value)
            raise


In [64]:
from dataclasses import dataclass

@dataclass
class Teacher:
    """Sample teacher record used by the polymorphic (de)serialization demo."""

    name: str
    age: int
    # Discriminator value; TSSchema selects the schema by this attribute.
    role: str = "teacher"
@dataclass
class Student:
    """Sample student record used by the polymorphic (de)serialization demo."""

    name: str
    # Discriminator value; TSSchema selects the schema by this attribute.
    role: str = "student"
    # Trailing underscore avoids the reserved word ``class``.
    class_: str = "class 1"

# Sample objects for the demo: one teacher and two students of the same class.
t1 = Teacher("t1", 35)
s1 = Student("s1", class_="class 12")
s2 = Student("s2", class_="class 12")


class TeacherSchema(BaseSchema):
    """Schema for ``Teacher``; loaded data is turned into a ``Teacher`` instance."""

    # Tells BaseSchema.make_obj which class to build after loading.
    _data_class = Teacher

    name = fields.String()
    age = fields.Integer()
    role = fields.String()
class StudentSchema(BaseSchema):
    """Schema for ``Student``; loaded data is turned into a ``Student`` instance."""

    # Tells BaseSchema.make_obj which class to build after loading.
    _data_class = Student

    name = fields.String()
    role = fields.String()
    # Stored on the object as ``class_`` but exposed as ``class`` in the serialized data.
    class_ = fields.String(attribute="class_", data_key="class")

class TSSchema(BaseSchema):
    """Demo schema holding a mixed list of teachers and students."""

    # Each list item is dumped/loaded with a schema chosen by its "role" value.
    list_ = fields.List(
        PolymorphicField(
            # Used for dumping, and for loading when deseri_schemas has no match.
            # `only` restricts which declared fields each schema handles.
            schemas={
                "student": StudentSchema(only=("name", "role"), unknown=INCLUDE),
                "teacher": TeacherSchema(only=("name",), unknown=INCLUDE),
            },
            seri_attr="role",
            deseri_attr="role",
            # Takes precedence over `schemas` when loading: students are loaded
            # with the full StudentSchema.
            deseri_schemas={
                "student": StudentSchema
            },
            # Items whose role matches no schema are returned as plain dicts.
            deseri_as_dict=True,
        )
    )

# Serialize a mixed list to JSON; each item is dumped with the schema matching its role.
TSSchema().dumps({"list_": [s1, t1, s2]})

'{"list_": [{"name": "s1", "role": "student"}, {"name": "t1"}, {"name": "s2", "role": "student"}]}'

In [65]:
# Load a mixed list from JSON; the first item has a role ("none") that matches no schema.
json_str = '{"list_": [{"name": "s1", "role": "none", "class": "class 12"}, {"name": "s1", "role": "student", "class": "class 12"}, {"name": "t1", "age": 35, "role": "teacher"}, {"name": "s2", "role": "student", "class": "class 12"}]}'
TSSchema().loads(json_str)


{'list_': [{'role': 'none', 'name': 's1', 'class': 'class 12'},
  Student(name='s1', role='student', class_='class 12'),
  Teacher(name='t1', age=35, role='teacher'),
  Student(name='s2', role='student', class_='class 12')]}