diff --git a/smpplib/client.py b/smpplib/client.py index 5b10fcd..4d0dda5 100644 --- a/smpplib/client.py +++ b/smpplib/client.py @@ -226,20 +226,32 @@ def send_pdu(self, p): return True + def _recv_exact(self, exact_size): + """ + Keep reading from self._socket until exact_size bytes have been read + """ + parts = [] + received = 0 + while received < exact_size: + try: + part = self._socket.recv(exact_size - received) + except socket.timeout: + raise + except socket.error as e: + self.logger.warning(e) + raise exceptions.ConnectionError() + if not part: + raise exceptions.ConnectionError() + received += len(part) + parts.append(part) + return b"".join(parts) + def read_pdu(self): """Read PDU from the SMSC""" self.logger.debug('Waiting for PDU...') - try: - raw_len = self._socket.recv(4) - except socket.timeout: - raise - except socket.error as e: - self.logger.warning(e) - raise exceptions.ConnectionError() - if not raw_len: - raise exceptions.ConnectionError() + raw_len = self._recv_exact(4) try: length = struct.unpack('>L', raw_len)[0] @@ -247,18 +259,7 @@ def read_pdu(self): self.logger.warning('Receive broken pdu... %s', repr(raw_len)) raise exceptions.PDUError('Broken PDU') - raw_pdu = raw_len - while len(raw_pdu) < length: - try: - raw_pdu_part = self._socket.recv(length - len(raw_pdu)) - except socket.timeout: - raise - except socket.error as e: - self.logger.warning(e) - raise exceptions.ConnectionError() - if not raw_pdu_part: - raise exceptions.ConnectionError() - raw_pdu += raw_pdu_part + raw_pdu = raw_len + self._recv_exact(length - 4) self.logger.debug('<<%s (%d bytes)', binascii.b2a_hex(raw_pdu), len(raw_pdu))