from micropython import const
import time

# === Message type constants ===
# Each constant is the status-byte value that identifies one kind of message.
# Values from NOTE_OFF up to (but not including) SYSEX are treated as channel
# messages by _is_channel_message(); for those, Message.__bytes__ ORs the
# channel number into the status byte and MIDI.receive splits the status byte
# back into type (high nibble) and channel (low nibble).
NOTE_OFF = const(0x80)
NOTE_ON = const(0x90)
AFTERTOUCH = const(0xA0)
CONTROLLER_CHANGE = CC = const(0xB0)
PROGRAM_CHANGE = const(0xC0)
CHANNEL_PRESSURE = const(0xD0)
PITCH_BEND = const(0xE0)
SYSTEM_EXCLUSIVE = SYSEX = const(0xF0)
SONG_POSITION = const(0xF2)
SONG_SELECT = const(0xF3)
BUS_SELECT = const(0xF5)
TUNE_REQUEST = const(0xF6)
SYSEX_END = const(0xF7)
CLOCK = const(0xF8)
TICK = const(0xF9)
START = const(0xFA)
CONTINUE = const(0xFB)
STOP = const(0xFC)
ACTIVE_SENSING = const(0xFE)
SYSTEM_RESET = const(0xFF)

# Message types grouped by the number of data bytes that follow the status
# byte. Message.__bytes__ and MIDI.receive consult only the 1- and 2-byte sets
# and treat every other type as carrying no data bytes.
# NOTE(review): _LEN_0_MESSAGES is not referenced anywhere in the visible
# code - confirm whether anything outside this file relies on it.
_LEN_0_MESSAGES = set([
    TUNE_REQUEST, SYSEX, SYSEX_END, CLOCK, TICK,
    START, CONTINUE, STOP, ACTIVE_SENSING, SYSTEM_RESET
])
_LEN_1_MESSAGES = set([PROGRAM_CHANGE, CHANNEL_PRESSURE, SONG_SELECT, BUS_SELECT])
_LEN_2_MESSAGES = set([NOTE_OFF, NOTE_ON, AFTERTOUCH, CC, PITCH_BEND, SONG_POSITION])

# Human-readable names used by Message.__str__; a type that is missing from
# this table is printed as "Unknown".
_MSG_TYPE_NAMES = {
    NOTE_OFF: "NoteOff",
    NOTE_ON: "NoteOn",
    AFTERTOUCH: "Aftertouch",
    CC: "CC",
    PROGRAM_CHANGE: "ProgramChange",
    CHANNEL_PRESSURE: "ChannelPressure",
    PITCH_BEND: "PitchBend",
    SYSEX: "Sysex",
    SONG_POSITION: "SongPosition",
    SONG_SELECT: "SongSelect",
    BUS_SELECT: "BusSelect",
    TUNE_REQUEST: "TuneRequest",
    SYSEX_END: "SysexEnd",
    CLOCK: "Clock",
    TICK: "Tick",
    START: "Start",
    CONTINUE: "Continue",
    STOP: "Stop",
    ACTIVE_SENSING: "ActiveSensing",
    SYSTEM_RESET: "SystemReset",
}

def _is_channel_message(status_byte):
    """Tell whether *status_byte* lies in the channel-message range.

    The range starts at NOTE_OFF (inclusive) and stops at SYSEX (exclusive).
    """
    if status_byte < NOTE_OFF:
        return False
    return status_byte < SYSEX

# --- Read helper with timeout ---
def _read_byte(port, timeout_ms=2):
    """Poll *port* for a single byte, giving up once the timeout has elapsed.

    ``port.read(1)`` is called repeatedly until it returns something truthy
    or more than ``timeout_ms`` milliseconds have passed on
    ``time.monotonic()`` since the call started. The bound on total time only
    holds if ``port.read(1)`` itself returns promptly.

    Returns:
        The byte as an int, or None when the timeout expired first.
    """
    started = time.monotonic()
    received = port.read(1)
    while not received:
        # Nothing available yet: check the clock before polling again.
        if (time.monotonic() - started) > (timeout_ms / 1000.0):
            return None
        received = port.read(1)
    return received[0]


# === Message class ===
class Message:
    """One MIDI message: a type, a channel and up to two data values.

    Attributes:
        type: Message type (one of the module-level status constants).
        channel: Channel number; only used for channel messages.
        data0: First data value.
        data1: Second data value.
        pitch_bend: Only created when the message is constructed with
            ``mtype == PITCH_BEND`` and ``data1 == 0``; it then holds
            ``data0``. The attribute does not exist in any other case.
    """

    def __init__(self, mtype=SYSTEM_RESET, data0=0, data1=0, channel=0):
        """Store the fields; see the class docstring for ``pitch_bend``."""
        self.type, self.channel = mtype, channel
        self.data0, self.data1 = data0, data1
        if data1 == 0 and mtype == PITCH_BEND:
            self.pitch_bend = data0

    def __bytes__(self):
        """Serialize the message to its byte representation.

        The first byte is the type, with the channel OR-ed in for channel
        messages. It is followed by ``data0`` and ``data1`` for two-data-byte
        types, by ``data0`` alone for one-data-byte types, and by nothing
        for every other type.
        """
        kind = self.type
        lead = (kind | self.channel) if _is_channel_message(kind) else kind
        if kind in _LEN_2_MESSAGES:
            payload = [lead, self.data0, self.data1]
        elif kind in _LEN_1_MESSAGES:
            payload = [lead, self.data0]
        else:
            payload = [lead]
        return bytes(payload)

    def __str__(self):
        """Return a short description; data is shown for channel messages only."""
        label = _MSG_TYPE_NAMES.get(self.type, "Unknown")
        if not _is_channel_message(self.type):
            return f"Message({label})"
        return f"Message({label} ch:{self.channel} {self.data0} {self.data1})"


# === MIDI class ===
class MIDI:
    """Non-blocking MIDI reader/writer built on a pair of port objects.

    Incoming bytes are accumulated in an internal buffer and parsed one
    message per ``receive()`` call. Outgoing messages are serialized with
    ``Message.__bytes__`` and handed to the output port in a single write.

    Args:
        midi_in: Input port. Must provide ``read(n)`` returning a bytes-like
            object, or a falsy value when nothing is available.
        midi_out: Output port. Must provide ``write(buf, length)``.
        enable_running_status: When true, data bytes that arrive without a
            leading status byte are interpreted using the status of the most
            recent channel message.
    """

    def __init__(self, midi_in=None, midi_out=None, enable_running_status=False):
        self._in_port = midi_in
        self._out_port = midi_out
        self._running_status_enabled = enable_running_status
        self._running_status = None
        self._error_count = 0
        # NOTE(review): _read_buf and _partial are not used by any method of
        # this class; they are kept in case code elsewhere touches them.
        self._read_buf = bytearray(1)
        self._partial = None
        # Received bytes that have not been turned into a message yet.
        # Must always stay a bytearray because receive() calls extend() on it.
        self._rxbuf = bytearray()

    @property
    def error_count(self):
        """Value of the internal error counter.

        The counter starts at 0 and nothing in this class increments it.
        """
        return self._error_count

    def receive(self):
        """Return the next complete MIDI message, or None if none is ready.

        Each call reads up to 64 bytes from the input port into the internal
        buffer and then parses at most one message from the front of that
        buffer. Bytes belonging to an incomplete message stay buffered until
        a later call can complete it.

        Returns:
            A ``Message``, or None when the buffer holds no complete message.
        """
        # 1. Append whatever the port has available right now.
        chunk = self._in_port.read(64)
        if chunk:
            self._rxbuf.extend(chunk)

        buf = self._rxbuf
        if not buf:
            return None

        # Status to reuse for data bytes that arrive without their own status.
        running = self._running_status if self._running_status_enabled else None

        # 2. Without a running status, leading data bytes cannot belong to
        #    any message: drop them all with a single slice.
        if not running:
            skip = 0
            total = len(buf)
            while skip < total and not (buf[skip] & 0x80):
                skip += 1
            if skip:
                buf = self._rxbuf = buf[skip:]
            if not buf:
                return None

        # 3. Work out the status byte. It is only peeked at here; nothing is
        #    removed from the buffer until the whole message is known to be
        #    present, so the buffer never needs to be re-assembled (and never
        #    stops being a bytearray).
        if buf[0] & 0x80:
            status = buf[0]
            header = 1
        else:
            # Only reachable with an active running status (see step 2).
            status = running
            header = 0

        msg = Message(status)
        if _is_channel_message(status):
            self._running_status = status
            msg.type = status & 0xF0
            msg.channel = status & 0x0F
        elif status < CLOCK:
            # System Exclusive / System Common (0xF0-0xF7) cancel running
            # status; otherwise the data bytes that follow (e.g. a SysEx
            # payload) would be parsed as channel messages. Real-time
            # statuses (0xF8 and above) leave running status untouched.
            self._running_status = None

        # 4. How many data bytes does this message type carry?
        if msg.type in _LEN_2_MESSAGES:
            need = 2
        elif msg.type in _LEN_1_MESSAGES:
            need = 1
        else:
            need = 0

        # 5. Incomplete message: leave everything buffered for the next call.
        end = header + need
        if len(buf) < end:
            return None

        # 6. Extract the data bytes and consume the message from the buffer.
        if need > 0:
            msg.data0 = buf[header]
        if need > 1:
            msg.data1 = buf[header + 1]
        self._rxbuf = buf[end:]
        return msg

    def send(self, msg, channel=None):
        """Send one MIDI message, or every message of an iterable.

        Args:
            msg: A ``Message`` or an iterable of ``Message`` objects.
            channel: When not None, it is stored into the ``channel``
                attribute of every message (mutating it) before the message
                is serialized.

        All messages are concatenated and written with a single
        ``write(data, len(data))`` call on the output port.
        """
        if isinstance(msg, Message):
            if channel is not None:
                msg.channel = channel
            data = msg.__bytes__()
        else:
            data = bytearray()
            for each in msg:
                if channel is not None:
                    each.channel = channel
                data.extend(each.__bytes__())
        self._out_port.write(data, len(data))