"""
CST820 driver
David Hofr, 2025
Based on https://github.com/lvgl-micropython/lvgl_micropython/blob/main/api_drivers/common_api_drivers/indev/cst328.py
"""
from micropython import const
import pointer_framework
import time
import machine

# I2C bus address of the CST820. Not referenced elsewhere in this file;
# presumably read by the code that creates the I2C device -- confirm.
I2C_ADDR = 0x15
# Not referenced elsewhere in this file; presumably the register-address
# width (in bits) to use when creating the I2C device -- TODO confirm
# against the caller.
BITS = 8


class CST820(pointer_framework.PointerDriver):
    """Touch input driver for the CST820 controller.

    The controller is accessed through ``device``, an object providing
    ``write(buf)`` and ``write_readinto(tx_buf, rx_buf)``. Register access
    uses a one-byte register address followed by the data bytes.
    """

    # Register addresses / values used by this driver.
    _REG_TOUCH_DATA = 0x02     # touch count, X high, X low, Y high, Y low
    _REG_CHIP_ID = 0xA7        # chip identification register
    _REG_DISABLE_SLEEP = 0xFE  # written with 0xFF at startup to disable sleep
    _CHIP_ID = 0xB7            # ID value the startup check expects

    def __init__(self, device, reset_pin=None, int_pin=None, touch_cal=None,
                 startup_rotation=pointer_framework.lv.DISPLAY_ROTATION._0,
                 debug=False):
        """Reset the controller, report its chip ID and disable sleep.

        :param device: bus device used for all register reads and writes.
        :param reset_pin: pin number (int) or pin object for the reset line,
            or None to skip the hardware reset.
        :param int_pin: pin number (int) or pin object for the interrupt
            line, or None.
        :param touch_cal: forwarded unchanged to the base class.
        :param startup_rotation: forwarded unchanged to the base class.
        :param debug: forwarded unchanged to the base class.
        """
        # Preallocated transfer buffers; the memoryviews allow slicing
        # without copying on every transaction.
        self._tx_buf = bytearray(3)
        self._tx_mv = memoryview(self._tx_buf)
        self._rx_buf = bytearray(6)
        self._rx_mv = memoryview(self._rx_buf)

        self._device = device
        self._int_pin = int_pin
        self._reset_pin = reset_pin

        # NOTE(review): the interrupt pin is configured as an output driven
        # high and is not used anywhere else in this class -- confirm that
        # driving this line from the MCU side is intended.
        if isinstance(int_pin, int):
            self._int_pin = machine.Pin(int_pin, machine.Pin.OUT, value=1)

        if isinstance(reset_pin, int):
            self._reset_pin = machine.Pin(reset_pin, machine.Pin.OUT, value=1)

        self.hw_reset()

        # check chip id (reported only; a mismatch does not abort startup)
        self._read_reg(self._REG_CHIP_ID, 1)
        chip_id = self._rx_buf[0]
        print('CST820 id: 0x{:02X} {}'.format(
            chip_id, 'ok' if chip_id == self._CHIP_ID else 'fail'))

        # disable sleep
        self._write_reg(self._REG_DISABLE_SLEEP, 0xFF)

        self._touch_thresh = 10

        super().__init__(touch_cal=touch_cal,
                         startup_rotation=startup_rotation, debug=debug)

    def _read_reg(self, reg, num_bytes):
        """Read ``num_bytes`` bytes starting at register ``reg``.

        The result is left in ``self._rx_buf[0:num_bytes]``.
        """
        self._tx_buf[0] = reg & 0xFF

        # Clear the destination in place so stale bytes from a previous read
        # cannot be mistaken for fresh data. Zeroing through the view avoids
        # allocating a temporary list/bytearray on every poll.
        rx = self._rx_mv[:num_bytes]
        for i in range(len(rx)):
            rx[i] = 0

        self._device.write_readinto(self._tx_mv[:1], rx)

    def _write_reg(self, reg, value=None):
        """Write register ``reg``.

        With ``value`` None only the register address byte is sent;
        otherwise the address byte is followed by the single byte ``value``.
        """
        self._tx_buf[0] = reg & 0xFF

        if value is None:
            self._device.write(self._tx_mv[:1])
        else:
            self._tx_buf[1] = value
            self._device.write(self._tx_mv[:2])

    @property
    def touch_threshold(self):
        """Stored touch threshold value (1..255).

        The value is only kept on the instance; nothing in this class
        writes it to the controller.
        """
        return self._touch_thresh

    @touch_threshold.setter
    def touch_threshold(self, value):
        # Clamp into the 1..255 range.
        if value < 1:
            value = 1
        elif value > 255:
            value = 255
        self._touch_thresh = value

    def hw_reset(self):
        """Pulse the reset pin low for 10 ms, then wait 50 ms.

        Does nothing when no reset pin was supplied.
        """
        if self._reset_pin is not None:
            self._reset_pin(0)
            time.sleep_ms(10)
            self._reset_pin(1)
            time.sleep_ms(50)

    def _get_coords(self):
        """Read the current touch state.

        :return: ``(state, x, y)`` where ``state`` is ``self.PRESSED`` when
            the controller reports at least one touch and ``self.RELEASED``
            otherwise.
        """
        self._read_reg(self._REG_TOUCH_DATA, 5)
        buf = self._rx_buf

        touch_count = buf[0]
        # Coordinates are 12 bits wide. In the CST8xx register layout the
        # upper nibble of each "high" byte carries non-coordinate bits
        # (event flag / touch ID), so it is masked off before combining.
        x = ((buf[1] & 0x0F) << 8) | buf[2]
        y = ((buf[3] & 0x0F) << 8) | buf[4]

        state = self.PRESSED if touch_count else self.RELEASED
        return state, x, y
