# In [None]:
from tensorflow.keras.layers import Input, Embedding, Flatten, Reshape, LSTM, Concatenate, TimeDistributed, Dense, RepeatVector
from tensorflow.keras.losses import MeanAbsoluteError, CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.data import Dataset
from tensorflow import TensorSpec, float32, int32
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import tkinter as tk
from tkinter import filedialog
from tkinter import messagebox
import tensorflow as tf


def rand_img(size):
    """Return a random image of shape *size* as float32 values in [0, 1].

    Pixel intensities are drawn uniformly from the integers 0..255 and then
    scaled by 1/255.
    """
    raw_pixels = np.random.randint(0, 256, size)
    scaled = raw_pixels.astype(np.float32)
    scaled /= 255.0
    return scaled

def rand_sentence(length, max_words):
    """Return *length* random word ids drawn uniformly from ``[0, max_words)`` as int32."""
    word_ids = np.random.randint(low=0, high=max_words, size=length)
    return word_ids.astype(np.int32)

def one_hot_encoding(sentence, max_words):
    """One-hot encode a sequence of word ids.

    Returns a float32 array of shape ``(len(sentence), max_words)`` whose row
    ``i`` is all zeros except for a 1.0 at column ``sentence[i]``.
    """
    encoded = np.zeros((len(sentence), max_words), dtype=np.float32)
    # Iterating a 2-D array yields writable row views, so this fills `encoded` in place.
    for row, word_id in zip(encoded, sentence):
        row[word_id] = 1.0
    return encoded

def data_generator(image_size, sentence_length, sentence_max_word, batch_size=32):
    """Yield an endless stream of random auto-encoding batches.

    Each item is ``((x_img, x_sen), (y_img, y_sen))``:
      * ``x_img`` / ``y_img``: float32, ``(batch_size, *image_size[:3])``,
        identical random images (the target equals the input);
      * ``x_sen``: int32, ``(batch_size, sentence_length)`` random word ids;
      * ``y_sen``: float32, ``(batch_size, sentence_length, sentence_max_word)``,
        the one-hot encoding of ``x_sen``.
    """
    while True:
        batch_img_shape = (batch_size, image_size[0], image_size[1], image_size[2])
        imgs_in = np.zeros(batch_img_shape, dtype=np.float32)
        words_in = np.zeros((batch_size, sentence_length), dtype=np.int32)
        imgs_out = np.zeros(batch_img_shape, dtype=np.float32)
        words_out = np.zeros((batch_size, sentence_length, sentence_max_word), dtype=np.float32)
        for slot in range(batch_size):
            # Draw the image before the sentence so the RNG stream is consumed in a fixed order.
            picture = rand_img(image_size)
            words = rand_sentence(sentence_length, sentence_max_word)
            words_one_hot = one_hot_encoding(words, sentence_max_word)
            imgs_in[slot] = imgs_out[slot] = picture
            words_in[slot] = words
            words_out[slot] = words_one_hot
        yield (imgs_in, words_in), (imgs_out, words_out)

def get_model(image_shape, sentence_length, max_word):
    """Build and compile the joint image/sentence auto-encoder.

    Args:
        image_shape: ``(height, width, channels)`` of the cover image.
        sentence_length: number of word ids per input sentence.
        max_word: vocabulary size (embedding rows and softmax width).

    Returns:
        ``(model, encoder_model, decoder_model)``:
          * ``model`` maps ``[image, sentence]`` to ``[reconstructed image,
            per-position word probabilities]``. It is compiled with MAE on the
            image and categorical cross-entropy on the sentence.
          * ``encoder_model`` shares ``model``'s layers but outputs only the image.
          * ``decoder_model`` is the ``"sentence_reconstruction"`` sub-model.
            Its input is the encoder's LSTM state repeated ``sentence_length``
            times, shape ``(sentence_length, hidden_units)``. It is NOT an image.
    """
    # Width shared by both LSTMs and the decoder's input; these must agree.
    hidden_units = 64
    height, width = image_shape[0], image_shape[1]
    # Size the image head from image_shape so Reshape(image_shape) below is valid
    # for any channel count (previously hard-coded to 3 channels).
    num_image_values = int(np.prod(image_shape))

    input_img = Input(image_shape)
    input_sen = Input((sentence_length,))

    # Sentence branch: embed word ids, summarise them with an LSTM, then project
    # to a 3-channel map with the image's spatial size so it can be stacked.
    embed_sen = Embedding(max_word, 50)(input_sen)
    lstm_emb_sen = LSTM(hidden_units, return_sequences=False)(embed_sen)
    flat_emb_sen = Dense(height * width * 3)(lstm_emb_sen)
    flat_emb_sen = Reshape((height, width, 3))(flat_emb_sen)

    # Stack the sentence map and the cover image along the channel axis.
    enc_input = Concatenate(axis=-1)([flat_emb_sen, input_img])

    # Read the stacked tensor as a sequence of 3-value steps and encode it.
    # NOTE(review): Reshape((-1, 3)) needs height*width*(3 + channels) to be a
    # multiple of 3, and yields a very long sequence (20000 steps for 100x100x3).
    lstm_out = LSTM(hidden_units, return_sequences=False)(Reshape((-1, 3))(enc_input))
    out_img = Dense(num_image_values, activation='sigmoid')(lstm_out)
    out_img = Reshape(image_shape)(out_img)

    # Sentence head: decodes from the encoder state, not from out_img.
    decoder_model = Sequential(name="sentence_reconstruction")
    decoder_model.add(Input(shape=(sentence_length, hidden_units)))
    decoder_model.add(LSTM(hidden_units, return_sequences=True))
    decoder_model.add(TimeDistributed(Dense(max_word, activation="softmax")))

    # Repeat the encoder state once per output position.
    out_sen = decoder_model(RepeatVector(sentence_length)(lstm_out))

    model = Model(inputs=[input_img, input_sen], outputs=[out_img, out_sen])
    model.compile(optimizer='adam', loss=[MeanAbsoluteError(), CategoricalCrossentropy()],
                  metrics={'sentence_reconstruction': CategoricalAccuracy()})
    encoder_model = Model(inputs=[input_img, input_sen], outputs=[out_img])
    return model, encoder_model, decoder_model

def ascii_encode(message, sentence_length):
    """Encode *message* as a ``(1, sentence_length)`` int32 array of ASCII codes.

    Slots after the end of the message are left as 0 (padding).

    Raises:
        UnicodeEncodeError: if *message* contains non-ASCII characters
            (this is a subclass of ValueError).
        ValueError: if *message* is longer than *sentence_length* characters.
    """
    codes = message.encode("ascii")
    # Validate up front instead of failing mid-loop with an IndexError.
    if len(codes) > sentence_length:
        raise ValueError(
            f"message has {len(codes)} characters but at most {sentence_length} fit"
        )
    sen = np.zeros((1, sentence_length), dtype=np.int32)
    sen[0, :len(codes)] = list(codes)
    return sen

def ascii_decode(message):
    """Decode the first batch entry of per-position class scores into text.

    *message* is indexed as ``message[0]`` and arg-maxed over its last axis.
    Each winning class id becomes one character. Trailing ``"\\x00"`` characters
    are stripped, because they correspond to id 0, which ``ascii_encode`` uses
    as padding.
    """
    decoded = ''.join(chr(int(a)) for a in message[0].argmax(-1))
    return decoded.rstrip("\x00")

def main():
    """Build or restore the model, then run the Tk encode/decode GUI.

    Weights are loaded from ``best_weights.weights.h5``. If that fails, the
    model is trained for 5 epochs of 100 steps on random data. The best
    weights (by training loss) are checkpointed to that file. The loaded
    image and the three models are kept in the module-level globals ``img``,
    ``model``, ``encoder`` and ``decoder``.
    """
    global img, model, encoder, decoder

    root = tk.Tk()
    root.title("Image Encoder/Decoder")

    img = None
    image_shape = (100, 100, 3)
    sentence_len = 100
    max_word = 256
    batch_size = 64

    def load_image():
        """Prompt for an image file, resize it to ``image_shape`` and preview it."""
        global img
        file_path = filedialog.askopenfilename()
        if not file_path:
            return  # dialog was cancelled
        try:
            # The context manager closes the file once the pixels are copied out.
            with Image.open(file_path) as source:
                # PIL sizes are (width, height); image_shape is (height, width, channels).
                res = source.convert("RGB").resize((image_shape[1], image_shape[0]))
        except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
            messagebox.showerror("Load Failed", f"Could not open image: {err}")
            return
        img = np.expand_dims(img_to_array(res) / 255.0, axis=0)
        messagebox.showinfo("Image Loaded", "Image loaded successfully!")
        plt.imshow(img[0], interpolation='nearest')
        plt.show()

    def encode_message():
        """Hide the entered text in the loaded image and show the round trip."""
        if img is None:
            messagebox.showwarning("No Image", "Please load an image first!")
            return

        text_to_encode = message_entry.get()
        # ascii_encode stores one ASCII code per slot, so longer or non-ASCII
        # text cannot be represented.
        if not text_to_encode.isascii() or len(text_to_encode) > sentence_len:
            messagebox.showwarning(
                "Invalid Message",
                f"Please enter at most {sentence_len} ASCII characters.",
            )
            return
        sen = ascii_encode(text_to_encode, sentence_len)

        # `decoder` expects the encoder's LSTM state repeated sentence_len times
        # (shape (sentence_len, 64)), not an image, so decoder.predict(image)
        # fails on the shape mismatch. The full model returns both the encoded
        # image and the reconstructed sentence in one pass.
        # NOTE(review): the sentence is therefore recovered from the encoder's
        # internal state, not from the encoded image alone.
        encoded_img, sentence_probs = model.predict([img, sen])
        plt.imshow(encoded_img[0], interpolation='nearest')
        plt.title("Encoded Image")
        plt.show()

        decoded_message = ascii_decode(sentence_probs)
        messagebox.showinfo("Decoded Message", f"The decoded message is: {decoded_message}")

    output_signature = (
        (TensorSpec(shape=(batch_size, *image_shape), dtype=float32),
         TensorSpec(shape=(batch_size, sentence_len), dtype=int32)),
        (TensorSpec(shape=(batch_size, *image_shape), dtype=float32),
         TensorSpec(shape=(batch_size, sentence_len, max_word), dtype=float32))
    )

    # Give tf.data a fresh generator each time it starts iterating, instead of
    # one shared generator object created up front.
    dataset = Dataset.from_generator(
        lambda: data_generator(image_shape, sentence_len, max_word, batch_size),
        output_signature=output_signature,
    )

    model, encoder, decoder = get_model(image_shape, sentence_len, max_word)

    try:
        model.load_weights("best_weights.weights.h5")
    except (OSError, ValueError) as err:
        # Missing/unreadable file (OSError) or, typically, weights that do not
        # match the current architecture (ValueError): train from scratch.
        # Unlike a bare `except:`, this does not swallow KeyboardInterrupt.
        print(f"Could not load saved weights ({err}); training a new model.")
        model.fit(dataset, epochs=5, steps_per_epoch=100, callbacks=[
            ModelCheckpoint("best_weights.weights.h5", monitor="loss",
                            verbose=1,
                            save_weights_only=True,
                            save_best_only=True)]
        )

    tk.Label(root, text="Enter the message you want to encode:").pack()
    message_entry = tk.Entry(root, width=50)
    message_entry.pack()

    load_button = tk.Button(root, text="Load Image", command=load_image)
    load_button.pack()

    encode_button = tk.Button(root, text="Encode and Decode Message", command=encode_message)
    encode_button.pack()

    root.mainloop()

# Run the GUI only when executed as a script (or notebook cell), not on import.
# The dunder names need double underscores; `_name_` is an undefined name.
if __name__ == "__main__":
    main()