Skip to content

Commit

Permalink
[fix] BrokenPipeError in socket
Browse files Browse the repository at this point in the history
  • Loading branch information
killerdbob committed Aug 14, 2023
1 parent e7cf5b8 commit 442d972
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 65 deletions.
39 changes: 22 additions & 17 deletions skyplane/gateway/operators/gateway_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ def make_socket(self, dst_host):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return sock

def send_data(self, dst_host, header, data):
    """Send one chunk (header followed by payload) to ``dst_host``.

    A socket is created through ``self.make_socket`` when
    ``self.destination_ports`` has no entry for ``dst_host``; the socket is
    cached in ``self.destination_sockets`` and reused by later calls.

    Args:
        dst_host: Key identifying the destination in the port/socket caches;
            also passed through to ``self.make_socket``.
        header: Object exposing ``to_socket(sock)``, which writes the chunk
            header to the socket.
        data: Payload handed to ``sock.sendall``.

    Returns:
        True if both the header and the payload were sent; False if a
        ``socket.error`` was raised while sending. On failure the cached
        connection is closed and forgotten so the next call opens a new one.

    Note:
        Exceptions raised by ``self.make_socket`` itself are not caught here
        and propagate to the caller.
    """
    # Contact the server to set up a socket connection if none is registered.
    # NOTE(review): presumably make_socket records dst_host in
    # destination_ports -- confirm against its full definition.
    if self.destination_ports.get(dst_host) is None:
        self.destination_sockets[dst_host] = self.make_socket(dst_host)
    sock = self.destination_sockets[dst_host]

    try:
        header.to_socket(sock)
        sock.sendall(data)
    except socket.error as e:
        print(e)
        # Forget the port so the next call reconnects. pop() instead of del
        # so a missing entry cannot raise KeyError from inside this handler.
        self.destination_ports.pop(dst_host, None)
        # Also drop and close the broken socket; otherwise every failed send
        # would leak a file descriptor.
        stale_sock = self.destination_sockets.pop(dst_host, None)
        if stale_sock is not None:
            try:
                stale_sock.close()
            except socket.error:
                # Best-effort cleanup: the connection is already unusable.
                pass
        return False
    # If successful, return True.
    return True


# send chunks to other instances
def process(self, chunk_req: ChunkRequest, dst_host: str):
"""Send list of chunks to gateway server, pipelining small chunks together into a single socket stream."""
Expand Down Expand Up @@ -307,15 +324,6 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
print(f"[{self.handle}:{self.worker_id}] Error registering chunks {chunk_ids} to {dst_host}: {e}")
raise e

# contact server to set up socket connection
if self.destination_ports.get(dst_host) is None:
print(f"[sender-{self.worker_id}]:{chunk_ids} creating new socket")
self.destination_sockets[dst_host] = retry_backoff(
partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout
)
print(f"[sender-{self.worker_id}]:{chunk_ids} created new socket")
sock = self.destination_sockets[dst_host]

# TODO: cleanup so this isn't a loop
for idx, chunk_req in enumerate(chunk_reqs):
# self.chunk_store.state_start_upload(chunk_id, f"sender:{self.worker_id}")
Expand Down Expand Up @@ -347,16 +355,13 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
raw_wire_length=raw_wire_length,
is_compressed=(compressed_length is not None),
)
# print(f"[sender-{self.worker_id}]:{chunk_id} sending chunk header {header}")
header.to_socket(sock)
# print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header")

# send chunk data
assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist"
# file_size = os.path.getsize(chunk_file_path)

with Timer() as t:
sock.sendall(data)

while True:
with Timer() as t:
is_suc = self.send_data(dst_host=dst_host, header=header, data=data)
if is_suc: break

# logger.debug(f"[sender:{self.worker_id}]:{chunk_id} sent at {chunk.chunk_length_bytes * 8 / t.elapsed / MB:.2f}Mbps")
print(f"[sender:{self.worker_id}]:{chunk_id} sent at {wire_length * 8 / t.elapsed / MB:.2f}Mbps")
Expand Down
102 changes: 54 additions & 48 deletions skyplane/gateway/operators/gateway_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,31 +145,31 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
init_space = self.chunk_store.remaining_bytes()
print("Init space", init_space)
while True:
# receive header and write data to file
logger.debug(f"[receiver:{server_port}] Blocking for next header")
chunk_header = WireProtocolHeader.from_socket(conn)
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}")

# TODO: this wont work
# chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)

should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region
should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region

# wait for space
# while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
# print(
# f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
# )
# time.sleep(0.1)

# get data
# self.chunk_store.state_queue_download(chunk_header.chunk_id)
# self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}")
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}")
with Timer() as t:
fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
with fpath.open("wb") as f:
try:
# receive header and write data to file
logger.debug(f"[receiver:{server_port}] Blocking for next header")
chunk_header = WireProtocolHeader.from_socket(conn)
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}")

# TODO: this wont work
# chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)

should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region
should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region

# wait for space
# while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
# print(
# f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
# )
# time.sleep(0.1)

# get data
# self.chunk_store.state_queue_download(chunk_header.chunk_id)
# self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}")
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}")
with Timer() as t:

socket_data_len = chunk_header.data_len
chunk_received_size, chunk_received_size_decompressed = 0, 0
to_write = bytearray(socket_data_len)
Expand Down Expand Up @@ -199,29 +199,35 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
print(
f"[receiver:{server_port}]:{chunk_header.chunk_id} Decompressing {len(to_write)} bytes to {chunk_received_size_decompressed} bytes"
)

# try to write data until successful
while True:
try:
f.seek(0, 0)
f.write(to_write)
f.flush()

# check write succeeds
assert os.path.exists(fpath)

# check size
file_size = os.path.getsize(fpath)
if file_size == chunk_header.raw_data_len:
break
elif file_size >= chunk_header.raw_data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}")
except Exception as e:
print(e)
print(
f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
)
time.sleep(1)
except socket.error as e:
print(e)
# The connection may raise a broken-pipe (or other socket) error here; if so, restart the receive loop.
continue

fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
with fpath.open("wb") as f:
# try to write data until successful
while True:
try:
f.seek(0, 0)
f.write(to_write)
f.flush()

# check write succeeds
assert os.path.exists(fpath)

# check size
file_size = os.path.getsize(fpath)
if file_size == chunk_header.raw_data_len:
break
elif file_size >= chunk_header.raw_data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}")
except Exception as e:
print(e)
print(
f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
)
time.sleep(1)
assert (
socket_data_len == 0 and chunk_received_size == chunk_header.data_len
), f"Size mismatch: got {chunk_received_size} expected {chunk_header.data_len} and had {socket_data_len} bytes remaining"
Expand Down

0 comments on commit 442d972

Please sign in to comment.