In [None]:
import socket
import argparse
from PIL import Image
import numpy as np
from lib import dwt, ezw, size_amp, huffman, enc_dec
import sys # For sys.exit
import matplotlib.pyplot as plt # Added import

In [None]:
# Side length the encoder works at; inputs of any other size are resized to this.
_TARGET_SIZE = 512
# Number of DWT decomposition levels passed to dwt.decomposition_to_specify_level.
_DECOMP_LEVELS = 5
# Side of the top-left block the code treats as the coarsest band (512 >> 5 == 16).
_LL_SIZE = _TARGET_SIZE >> _DECOMP_LEVELS


def _load_grayscale_image(image_path, size=_TARGET_SIZE):
    """Load ``image_path`` as 8-bit grayscale ('L' mode), resizing to ``size`` x ``size`` if needed.

    Returns:
        (img_array, resized): the pixel array and whether a resize was performed.
    """
    # The context manager closes the file Pillow opened; convert() returns a new, loaded image.
    with Image.open(image_path) as src:
        img = src.convert('L')
    img_array = np.array(img)
    height, width = img_array.shape
    resized = (height, width) != (size, size)
    if resized:
        print(f"Warning: Image size is {width}x{height}. Resizing to {size}x{size}.")
        img = img.resize((size, size), Image.Resampling.LANCZOS)  # or BILINEAR / BICUBIC
        img_array = np.array(img)
        height, width = img_array.shape
        print(f"Image resized to {width}x{height}.")
    return img_array, resized


def _show_grayscale(img_array, title):
    """Render ``img_array`` as a grayscale figure titled ``title`` with axes hidden."""
    fig, ax = plt.subplots()
    ax.imshow(img_array, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
    plt.show()


def _build_ezw_roots(coeffs, ll_size=_LL_SIZE):
    """Build an ``ll_size`` x ``ll_size`` grid of ``ezw.EZWTree`` roots over ``coeffs``.

    Root (i, j) wraps ``coeffs[i, j]`` and gets three children from ``ezw.build_tree``, rooted at
    (i, j+ll_size), (i+ll_size, j) and (i+ll_size, j+ll_size). The fourth ``build_tree`` argument
    is 1, 3 and 4 respectively (presumably subband identifiers -- TODO confirm against lib.ezw).
    All three children go into the root's child list. Only truthy children get their ``parent``
    linked back to the root.
    """
    roots = [[None] * ll_size for _ in range(ll_size)]
    for i in range(ll_size):
        for j in range(ll_size):
            children = [
                ezw.build_tree(coeffs, 1, (i, j + ll_size), 1, None),
                ezw.build_tree(coeffs, 1, (i + ll_size, j), 3, None),
                ezw.build_tree(coeffs, 1, (i + ll_size, j + ll_size), 4, None),
            ]
            root = ezw.EZWTree(coeffs[i, j], 0, 2, (i, j), children, None)  # type: ignore
            for child in children:
                if child:
                    child.parent = root  # type: ignore
            roots[i][j] = root
    return roots


def _frame_segments(segments, prefix_bits=32):
    """Concatenate bit-string ``segments``, each preceded by its length as a ``prefix_bits``-wide binary field.

    Raises:
        ValueError: if a segment is longer than ``2**prefix_bits - 1`` bits. ``format`` would
            otherwise emit a wider-than-``prefix_bits`` field. That breaks fixed-width framing
            (presumably the decoder reads exactly ``prefix_bits`` bits per length -- verify).
    """
    max_len = (1 << prefix_bits) - 1
    parts = []
    for index, segment in enumerate(segments):
        if len(segment) > max_len:
            raise ValueError(
                f"Segment {index} is {len(segment)} bits long; "
                f"a {prefix_bits}-bit length prefix holds at most {max_len}."
            )
        parts.append(format(len(segment), f'0{prefix_bits}b'))
        parts.append(segment)
    return ''.join(parts)


def _bits_to_bytes(bit_str):
    """Pack a string of '0'/'1' characters into bytes, MSB first, zero-padding the tail to a whole byte."""
    num_padding_bits = -len(bit_str) % 8
    padded = bit_str + '0' * num_padding_bits
    if not padded:
        return b''
    # Passing the byte count to to_bytes preserves leading zero bytes.
    return int(padded, 2).to_bytes(len(padded) // 8, 'big')


def _send_bitstream(host, port, bit_length, payload):
    """Connect over TCP and send ``bit_length`` (8-byte unsigned, big-endian), then ``payload``.

    On any failure it prints a message and calls ``sys.exit(1)`` (kept for backward compatibility).
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            print(f"Encoder: Connecting to {host}:{port}...")
            s.connect((host, port))
            print("Encoder: Connected.")
            s.sendall(bit_length.to_bytes(8, 'big'))
            s.sendall(payload)
            print(f"Encoder: Sent {bit_length} bits ({len(payload)} bytes) of image data.")
    except ConnectionRefusedError:
        print(f"Error: Connection refused. Ensure decoder is running on {host}:{port}.")
        sys.exit(1)
    except Exception as e:
        print(f"Socket communication error: {e}")
        sys.exit(1)


def imageEncode_socket(image_path, quant_step, host, port):
    """Encode a grayscale image with DWT + EZW + size/amplitude + Huffman coding and send it over TCP.

    Pipeline:
    1. Load as grayscale, resizing to 512x512 if needed, and display it.
    2. Apply a 5-level DWT and round ``coeffs / quant_step``.
    3. DPCM-code the top-left 16x16 band.
    4. EZW-code the remaining coefficients via ``lib.ezw``.
    5. Size/amplitude-split both streams; Huffman-code the sizes with one shared table.
    6. Frame the six segments with 32-bit length prefixes and send them, preceded by the bit count.

    Args:
        image_path: Path of the image to encode.
        quant_step: Positive quantization step dividing the DWT coefficients.
        host: Decoder hostname or IP address.
        port: Decoder TCP port.

    Returns:
        Bits per pixel of the framed bitstream (length prefixes + segments, excluding the
        8-byte bit-count header and the final byte-alignment padding).

    Raises:
        ValueError: if ``quant_step`` is not a positive number, or a segment overflows its length prefix.
        SystemExit: (code 1) if connecting to or sending to the decoder fails.
    """
    if not quant_step > 0:  # written this way so NaN is rejected too
        raise ValueError(f"quant_step must be a positive number, got {quant_step!r}")

    img_array, resized = _load_grayscale_image(image_path)
    height, width = img_array.shape
    # Show the image that is about to be encoded (this happens before encoding, not after).
    if resized:
        title = f"Resized Image ({width}x{height})"
    else:
        title = f"Input Image ({width}x{height})"
    _show_grayscale(img_array, title)

    l5_decomp = dwt.decomposition_to_specify_level(img_array, _DECOMP_LEVELS)
    l5_decomp_quant = np.round(l5_decomp / quant_step)

    # Top-left band: first value verbatim, then successive row-major differences.
    top_16_16 = l5_decomp_quant[0:_LL_SIZE, 0:_LL_SIZE].flatten().astype(int)
    top_16_16_diff = np.insert(np.diff(top_16_16), 0, top_16_16[0]).tolist()

    root_nodes = _build_ezw_roots(l5_decomp_quant)
    dpr_list = ezw.enc_dp_sp(root_nodes)

    rs_16_16, ra_16_16 = size_amp.size_amplitude_single_list(top_16_16_diff)
    ra_16_16_str = ''.join(str(k) for k in ra_16_16)

    # extend() avoids re-copying the accumulated lists on every one of the 16x16 iterations.
    result_size = []
    result_amplitude = []
    for i in range(_LL_SIZE):
        for j in range(_LL_SIZE):
            rs, ra = size_amp.size_amplitude_single_list(dpr_list[i][j])
            result_size.extend(rs)
            result_amplitude.extend(ra)
    result_amplitude_str_lf = ''.join(str(k) for k in result_amplitude)

    # One Huffman table, built from both size streams, encodes both of them.
    huffman_dict = huffman.huffman_encode(result_size + rs_16_16)
    huffman_encoded_result_size = huffman.encode_data(result_size, huffman_dict)
    huffman_encoded_rs_16_16 = huffman.encode_data(rs_16_16, huffman_dict)

    code_table_binary = enc_dec.convert_code_table_to_binary(huffman_dict)
    quant_step_binary = enc_dec.float_to_binary_str(float(quant_step))

    # The list order defines the wire format (presumably mirrored by the decoder).
    # pad_encoded_data is applied to the four data segments, left to right, as before.
    segments_data = [
        code_table_binary,
        quant_step_binary,
        enc_dec.pad_encoded_data(huffman_encoded_result_size),
        enc_dec.pad_encoded_data(result_amplitude_str_lf),
        enc_dec.pad_encoded_data(huffman_encoded_rs_16_16),
        enc_dec.pad_encoded_data(ra_16_16_str),
    ]

    combined_data_str = _frame_segments(segments_data, prefix_bits=32)
    original_length = len(combined_data_str)
    byte_array = _bits_to_bytes(combined_data_str)

    _send_bitstream(host, port, original_length, byte_array)

    return original_length / (height * width)


In [None]:
# Encoder configuration -- change these values to match your setup.
image_path_param = "IMG_1601.JPG"  # e.g. "images/lena_512.png"
quant_step_param = 10.0            # step that divides the DWT coefficients before rounding
host_param = "localhost"           # hostname or IP address of the decoder
port_param = 65432                 # TCP port the decoder listens on

# Requires imageEncode_socket, defined in an earlier cell.
print(
    f"Encoder: Starting encoding for image '{image_path_param}' with quant_step={quant_step_param}",
    f"Encoder: Attempting to send to {host_param}:{port_param}",
    sep="\n",
)

bpp = imageEncode_socket(
    image_path=image_path_param,
    quant_step=quant_step_param,
    host=host_param,
    port=port_param,
)
print(f"Encoding complete. BPP: {bpp:.4f}")
