In [2]:
import torch

from typing import Protocol, runtime_checkable

In [3]:
@runtime_checkable
class AddMul(Protocol):
    """Structural type for values supporting ``x + y`` and ``x * factor``.

    Because the class is ``@runtime_checkable``, ``isinstance(obj, AddMul)``
    only checks that ``__add__`` and ``__mul__`` exist on ``obj``; it does not
    verify their signatures or meaning. That is why ``str``, ``list`` and
    ``tuple`` (concatenation / repetition) pass the check just like tensors.
    """

    def __mul__(self, factor: float) -> "AddMul":
        """Return the value scaled by ``factor``."""

    def __add__(self, other: "AddMul") -> "AddMul":
        """Return the sum of this value and ``other``."""

In [4]:
# Empty 1-D tensor of the default float dtype, exactly what the legacy
# constructor torch.Tensor(0) produces (its argument is a size, not a value).
# NOTE: this is NOT the scalar 0; torch.tensor(0) would give that.
t = torch.empty(0)

In [5]:
# The (empty) tensor defines __add__ and __mul__, so it satisfies the protocol.
isinstance(t, AddMul)

True

In [6]:
# str passes as well: a runtime_checkable Protocol only checks that the
# methods exist, and str implements + (concatenation) and * (repetition).
isinstance("a", AddMul)

True

In [10]:
# Same for list: + concatenates and * repeats, which is enough for the check.
isinstance([1], AddMul)

True

In [11]:
def process(value: "AddMul") -> "AddMul":
    """Return ``value * 2 + value * 3``.

    Only the operators declared by ``AddMul`` are used: ``__mul__`` with the
    value on the left-hand side, then ``__add__``. Writing ``2 * value`` would
    instead dispatch to the reflected ``__rmul__``, which the protocol does not
    declare, so types implementing only ``__mul__`` would fail.

    For sequences the operators mean repetition/concatenation, e.g.
    ``process((1,)) == (1, 1, 1, 1, 1)``.
    """
    return value * 2 + value * 3

In [20]:
# For tuples, * repeats and + concatenates: 2 copies + 3 copies -> five 1s.
process((1,))

(1, 1, 1, 1, 1)

In [21]:
# tuple also satisfies the protocol (it has __add__ and __mul__).
isinstance((1,2), AddMul)

True

In [22]:
# set defines neither __add__ nor __mul__, so it does not match the protocol.
isinstance({1,2}, AddMul)

False

In [23]:
# process() is annotated with AddMul but does not check its input at runtime,
# so a set only fails once the arithmetic is attempted. The error is caught and
# shown here so that "Restart & Run All" does not stop at this cell.
try:
    process({1,2})
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: unsupported operand type(s) for *: 'int' and 'set'

In [24]:
# A dict fails the same way (no * / + support). Catch and display the error
# so that "Restart & Run All" continues past this demonstration cell.
try:
    process({1:2})
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: unsupported operand type(s) for *: 'int' and 'dict'

In [25]:
# t is still the empty tensor created above, so the result is empty as well.
process(t)

tensor([])

In [26]:
# Re-bind t to a 0-d tensor holding 1. Lowercase torch.tensor takes data,
# unlike torch.Tensor(0) above, whose argument was a size.
t = torch.tensor(1)
t

tensor(1)

In [27]:
# Two times t plus three times t, with t == 1, gives tensor(5).
process(t)

tensor(5)

In [28]:
from world_machine.train.scheduler import ParameterScheduler, LinearScheduler

In [29]:
# Uses the LinearScheduler imported from world_machine.train.scheduler (the
# local redefinition comes later). Arguments are presumably
# (initial_value, final_value, n_epoch), consistent with the output below.
l = LinearScheduler(0, 1, 5)

In [31]:
# Epoch indices 0..4 ramp linearly from 0.0 to 1.0. Index 5 lies past the
# 5-epoch schedule and is extrapolated (1.25) rather than clamped.
l(0), l(1), l(2), l(3), l(4), l(5)

(0.0, 0.25, 0.5, 0.75, 1.0, 1.25)

In [32]:
# The LinearScheduler instance is recognised as a ParameterScheduler.
isinstance(l, ParameterScheduler)

True

In [33]:
from world_machine_experiments.shared.save_parameters import make_model

In [58]:
import pydantic_core
from pydantic_core import CoreSchema, core_schema

from pydantic import GetCoreSchemaHandler, TypeAdapter

import json





class LinearScheduler(ParameterScheduler):
    """Linearly interpolate a parameter from ``initial_value`` to ``final_value``.

    Epoch ``0`` yields ``initial_value`` and epoch ``n_epoch - 1`` yields
    ``final_value``. Indices outside that range are extrapolated, not clamped.

    NOTE: this cell redefines (and shadows) the ``LinearScheduler`` imported
    from ``world_machine.train.scheduler`` earlier in the notebook.
    """

    def __init__(self, initial_value, final_value, n_epoch: int):
        # Presumably ParameterScheduler stores n_epoch as self._n_epoch, which
        # is read below; its definition is not visible here.
        super().__init__(n_epoch)

        self._initial_value = initial_value
        self._final_value = final_value

    def __call__(self, epoch_index):
        """Return the interpolated value for ``epoch_index``.

        NOTE(review): ``n_epoch == 1`` divides by zero here. Confirm the
        intended value for single-epoch schedules.
        """
        t = epoch_index / (self._n_epoch - 1)

        # Linear interpolation: weight (1 - t) on the start value and t on the
        # end value, so t == 0 gives initial_value and t == 1 gives final_value.
        return (1 - t) * self._initial_value + t * self._final_value

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler) -> CoreSchema:
        """Tell pydantic how to validate and serialize this class.

        Validation accepts only existing ``cls`` instances and passes them
        through unchanged. In JSON mode the scheduler is serialized as a JSON
        *string* holding ``type``/``initial_value``/``final_value``/``n_epoch``,
        so it appears double-encoded inside the surrounding JSON document.
        """
        def serialize(value: "LinearScheduler") -> str:
            return json.dumps({"type": value.__class__.__name__, "initial_value": value._initial_value, "final_value": value._final_value, "n_epoch": value._n_epoch})

        def validate(value: "LinearScheduler") -> "LinearScheduler":
            # is_instance_schema has already checked the type; nothing else to do.
            return value

        return core_schema.no_info_after_validator_function(
            validate,
            core_schema.is_instance_schema(cls),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize, when_used="json"
            ),
        )

In [59]:
from pydantic import BaseModel, ImportString, create_model

# Nested parameter dictionary that make_model() below turns into a pydantic model.
log = {
    "parameters": {
        "scheduler": LinearScheduler(initial_value=0, final_value=1, n_epoch=5),
    },
}

def _make_model(v, name):
    if type(v) is dict:
        return create_model(name, **{k: _make_model(v, k) for k, v in v.items()}), ...
    elif type(v) is type:
        return ImportString, v
    return type(v), v


def make_model(v: dict, name: str):
    """Build a pydantic model class named ``name`` mirroring the nested dict ``v``.

    NOTE: this shadows the ``make_model`` imported from
    ``world_machine_experiments.shared.save_parameters`` earlier in the notebook.
    """
    model_cls, _required_marker = _make_model(v, name)
    return model_cls


# Validate `log` against the generated model and dump it as indented JSON.
# print() rather than a bare expression so the JSON is shown unescaped.
model = make_model(log, "Parameters").model_validate(log)
model_json = model.model_dump_json(indent=4)


print(model_json)

{
    "parameters": {
        "scheduler": "{\"type\": \"LinearScheduler\", \"initial_value\": 0, \"final_value\": 1, \"n_epoch\": 5}"
    }
}
