Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: scaffolding to support custom context in extensions #816

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

noahnu
Copy link
Collaborator

@noahnu noahnu commented Sep 27, 2023

NOTE: Since syrupy v4 migrated from instance methods to classmethods, this new context is not actual usable. This lays the groundwork for a switch back to instance methods though (if we continue along this path).

Related to #814, this PR lays the groundwork to switch back to instance-based extensions (reverting an earlier decision to move to class methods for easier pytest-xdist compatibility).

NOTE: Since syrupy v4 migrated from instance methods to classmethods, this new context is not actual usable. This lays the groundwork for a switch back to instance methods though (if we continue along this path).
@atharva-2001
Copy link
Contributor

Thank you @noahnu for this PR! Would you mind letting me know how long this PR would take to merge? This feature was critical in some of our tests in TARDIS. If there is something I can do to help, please let me know!!

@noahnu
Copy link
Collaborator Author

noahnu commented Sep 29, 2023

@atharva-2001 I can't give an ETA. Could you describe what you're trying to do in your project (possibly with an example)? I may be able to recommend a workaround.

@atharva-2001
Copy link
Contributor

atharva-2001 commented Oct 2, 2023

I see. Here is some example code. Since most of my code deals with NumPy arrays and Pandas dataframes, I want to send in additional assertion options, for example, rtol, the assertion function etc. I don't want to create multiple fixtures by typing them out. Is it possible to at least create them programatically? Thanks for all the help!

from typing import Any

import numpy as np
import pytest

from syrupy.data import SnapshotCollection
from syrupy.extensions.single_file import SingleFileSnapshotExtension
from syrupy.types import SerializableData


class NumpySnapshotExtenstion(SingleFileSnapshotExtension):
    _file_extension = "dat"

    def matches(self, *, serialized_data, snapshot_data, **kwargs):
        print(kwargs, "kwargs inside matches")
        try:
            if (
                # Allow multiple assertion methodds here, for example- assert_almost_equal
                # allow relative and default tolerance
                np.testing.assert_allclose(
                    np.array(snapshot_data), np.array(serialized_data), **kwargs
                )
                is not None
            ):
                return False
            else:
                return True

        except:
            return False

    def _read_snapshot_data_from_location(
        self, *, snapshot_location: str, snapshot_name: str, session_id: str
    ):
        # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139
        try:
            return np.loadtxt(snapshot_location).tolist()
        except OSError:
            return None

    @classmethod
    def _write_snapshot_collection(
        cls, *, snapshot_collection: SnapshotCollection
    ) -> None:
        # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161

        filepath, data = (
            snapshot_collection.location,
            next(iter(snapshot_collection)).data,
        )
        np.savetxt(filepath, data)

    def serialize(self, data: SerializableData, **kwargs: Any) -> str:
        return data


@pytest.fixture
def snapshot_numpy(snapshot):
    options = dict(matcher_options=dict(rtol=1, atol=0))
    return snapshot.with_defaults(extension_class=NumpySnapshotExtenstion)


def test_np(snapshot_numpy):
    x = [1e-5, 1e-3, 1e-1]
    # ideally-
    # from numpy.testing import assert_allclose
    # assert snapshot_numpy(matcher = assert_allclose, rtol=1e6...)
    assert snapshot_numpy == x

@noahnu
Copy link
Collaborator Author

noahnu commented Oct 10, 2023

@atharva-2001 Does something like this work until syrupy has built-in support?

import pytest

class NumpySnapshotExtension(SingleFileSnapshotExtension):
    _file_extension = "dat"

    rtol = 0
    atol = 0

    def with_kwargs(**kwargs):
        class MyCopy(NumpySnapshotExtension):
            rtol = kwargs["rtol"]
            atol = kwargs["atol"]
        return MyCopy

@pytest.fixture
def snapshot(snapshot):
    def factory(**kwargs):
        _class = NumpySnapshotExtension.with_kwargs(**kwargs)
        return snapshot.with_defaults(extension_class=_class)
    return factory

def test_np(snapshot):
    assert snapshot(rtol=1, atol=0) == [1e-5, 1e-3]

(not tested)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants