diff --git a/rtmp_protocol.py b/rtmp_protocol.py index b681f00..efb665d 100644 --- a/rtmp_protocol.py +++ b/rtmp_protocol.py @@ -4,11 +4,11 @@ """ import pyamf.amf0 -import pyamf.util +import pyamf.util.pure import rtmp_protocol_base import socket -class FileDataTypeMixIn(pyamf.util.DataTypeMixIn): +class FileDataTypeMixIn(pyamf.util.pure.DataTypeMixIn): """ Provides a wrapper for a file object that enables reading and writing of raw data types for the file. @@ -16,7 +16,7 @@ class FileDataTypeMixIn(pyamf.util.DataTypeMixIn): def __init__(self, fileobject): self.fileobject = fileobject - pyamf.util.DataTypeMixIn.__init__(self) + pyamf.util.pure.DataTypeMixIn.__init__(self) def read(self, length): return self.fileobject.read(length) @@ -32,6 +32,7 @@ def at_eof(self): class DataTypes: """ Represents an enumeration of the RTMP message datatypes. """ + NONE = -1 USER_CONTROL = 4 WINDOW_ACK_SIZE = 5 SET_PEER_BANDWIDTH = 6 @@ -48,6 +49,16 @@ class SOEventTypes: DELETE = 9 USE_SUCCESS = 11 +class UserControlTypes: + """ Represents an enumeration of the user control event types. """ + STREAM_BEGIN = 0 + STREAM_EOF = 1 + STREAM_DRY = 2 + SET_BUFFER_LENGTH = 3 + STREAM_IS_RECORDED = 4 + PING_REQUEST = 6 + PING_RESPONSE = 7 + class RtmpReader: """ This class reads RTMP messages from a stream. """ @@ -72,6 +83,10 @@ def next(self): message_body = [] msg_body_len = 0 header = rtmp_protocol_base.header_decode(self.stream) + # FIXME: this should be really implemented inside header_decode + if header.datatype == DataTypes.NONE: + header = self.prv_header + self.prv_header = header while True: read_bytes = min(header.bodyLength - msg_body_len, self.chunk_size) message_body.append(self.stream.read(read_bytes)) @@ -124,6 +139,9 @@ def next(self): while not body_stream.at_eof(): commands.append(decoder.readElement()) ret['command'] = commands + #elif ret['msg'] == DataTypes.NONE: + # print 'WARNING: message with no datatype received.', header + # return self.next() else: assert False, header @@ -210,7 +228,7 @@ def write(self, message): for command in message['command']: encoder.writeElement(command) elif datatype == DataTypes.SHARED_OBJECT: - encoder.writeString(message['obj_name'],writeType=False) + encoder.serialiseString(message['obj_name']) body_stream.write_ulong(message['curr_version']) body_stream.write(message['flags']) @@ -236,7 +254,7 @@ def write_shared_object_event(self, event, body_stream): elif event_type == SOEventTypes.CHANGE: for attrib_name in event['data']: attrib_value = event['data'][attrib_name] - encoder.writeString(attrib_name,writeType=False) + encoder.serialiseString(attrib_name) encoder.writeElement(attrib_value) elif event['type'] == SOEventTypes.CLEAR: assert event['data'] == '', event['data'] @@ -444,7 +462,7 @@ def handle_message_pre_connect(self, msg): assert msg['window_ack_size'] == 2500000, msg assert msg['limit_type'] == 2, msg elif msg['msg'] == DataTypes.USER_CONTROL: - assert msg['event_type'] == 0, msg + assert msg['event_type'] == UserControlTypes.STREAM_BEGIN, msg assert msg['event_data'] == '\x00\x00\x00\x00', msg else: assert False, msg @@ -476,10 +494,30 @@ def handle_messages(self): """ Start the message handling loop. """ while True: msg = self.reader.next() - handled = False + + handled = self.handle_simple_message(msg) + + if handled: + continue + for so in self.shared_objects: if so.handle_message(msg): handled = True break if not handled: assert False, msg + + def handle_simple_message(self, msg): + """ Handle simple messages, e.g. ping requests. """ + if msg['msg'] == DataTypes.USER_CONTROL and msg['event_type'] == \ + UserControlTypes.PING_REQUEST: + resp = { + 'msg':DataTypes.USER_CONTROL, + 'event_type':UserControlTypes.PING_RESPONSE, + 'event_data':msg['event_data'], + } + self.writer.write(resp) + self.writer.flush() + return True + + return False