In [15]:
import json
import tensorflow as tf
import numpy as np
import random

# Load the preprocessed dataset.
# JSON is UTF-8 by specification; pass the encoding explicitly so non-ASCII
# prompt text decodes identically on every platform (open() otherwise falls
# back to the locale's default encoding, e.g. cp1252 on Windows).
with open('./preproced_new_data.json', encoding='utf-8') as f:
    data = json.load(f)

In [16]:
# Extract the prompts, negative prompts and NSFW labels from each record.
# NOTE(review): a missing or null 'negativePrompt' is mapped to '' because the
# Keras tokenizer cannot process None; 'prompt' and 'nsfw' are still required
# keys — confirm this matches how the dataset was collected.
prompts = [d['prompt'] for d in data['items']]
neg_prompts = [d.get('negativePrompt') or '' for d in data['items']]
labels = [d['nsfw'] for d in data['items']]

# Vocabulary size (tokenizer keeps only the most frequent words) and embedding width.
vocab_size = 10000
embedding_dim = 64

# Tokenize the prompts and negative prompts with one shared vocabulary.
# NOTE(review): the vocabulary is fit on ALL texts, including rows that later
# land in the validation/test split (vocabulary leakage only, no labels) —
# consider fitting on the training rows only.
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=vocab_size, oov_token='<OOV>')
tokenizer.fit_on_texts(prompts + neg_prompts)
prompt_sequences = tokenizer.texts_to_sequences(prompts)
neg_prompt_sequences = tokenizer.texts_to_sequences(neg_prompts)

# Pad/truncate every sequence to a fixed length. 'post' keeps the FIRST
# max_sequence_length tokens and appends 0 (the padding id) at the end;
# inference must use the same padding/truncating settings.
max_sequence_length = 50
prompt_padded = tf.keras.preprocessing.sequence.pad_sequences(prompt_sequences, maxlen=max_sequence_length, truncating='post', padding='post')
neg_prompt_padded = tf.keras.preprocessing.sequence.pad_sequences(neg_prompt_sequences, maxlen=max_sequence_length, truncating='post', padding='post')

# Convert the labels to a numpy array so they can be fancy-indexed below.
labels = np.array(labels)

# Shuffle with a fixed seed so the train/validation split is reproducible
# across kernel restarts. Model weight initialisation is NOT seeded here
# (consider tf.keras.utils.set_random_seed if full determinism is needed).
SEED = 42
indices = np.random.default_rng(SEED).permutation(len(prompts))
prompt_padded = prompt_padded[indices]
neg_prompt_padded = neg_prompt_padded[indices]
labels = labels[indices]

# Split the shuffled rows: first 80% for training, remaining 20% held out.
split = 0.8
split_index = int(len(prompts) * split)
x_train_prompt = prompt_padded[:split_index]
x_train_neg_prompt = neg_prompt_padded[:split_index]
y_train = labels[:split_index]
x_val_prompt = prompt_padded[split_index:]
x_val_neg_prompt = neg_prompt_padded[split_index:]
y_val = labels[split_index:]

In [17]:
# Both inputs are rows of token ids padded/truncated to max_sequence_length.
prompt_input_shape = neg_prompt_input_shape = (max_sequence_length,)

prompt_input_layer = tf.keras.layers.Input(name='prompt_input', shape=prompt_input_shape)
neg_prompt_input_layer = tf.keras.layers.Input(name='neg_prompt_input', shape=neg_prompt_input_shape)

# The embedding and the LSTM are each created once and applied to BOTH inputs,
# so the prompt and negative-prompt branches share their weights.
# mask_zero=True makes downstream layers ignore token id 0 (the padding id).
embedding_layer = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=embedding_dim,
    mask_zero=True,
    name='embedding_layer',
)
lstm_units = 64
lstm_layer = tf.keras.layers.LSTM(units=lstm_units, name='lstm_layer')

# Single sigmoid unit -> probability that the (prompt, negative prompt) pair is NSFW.
output_layer = tf.keras.layers.Dense(units=1, activation='sigmoid', name='output_layer')

# Embed both inputs first, then encode both embedded sequences.
prompt_embedded, neg_prompt_embedded = [
    embedding_layer(token_ids) for token_ids in (prompt_input_layer, neg_prompt_input_layer)
]
prompt_lstm_output, neg_prompt_lstm_output = [
    lstm_layer(embedded) for embedded in (prompt_embedded, neg_prompt_embedded)
]

# Join the two branch encodings side by side and classify.
concatenated_output = tf.keras.layers.concatenate(
    [prompt_lstm_output, neg_prompt_lstm_output], axis=-1
)
model_output = output_layer(concatenated_output)


In [18]:
# Two-input functional model: [prompt ids, negative-prompt ids] -> P(nsfw).
model_inputs = [prompt_input_layer, neg_prompt_input_layer]
model_outputs = model_output
model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs)

# Binary cross-entropy matches the single sigmoid output unit.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [19]:
# Train on the training split only. validation_split=0.2 makes Keras hold out
# the last 20% of these (already shuffled) training rows for per-epoch
# validation; the separate x_val_* split is not passed to fit().
history = model.fit(
    x=[x_train_prompt, x_train_neg_prompt],
    y=y_train,
    batch_size=32,
    epochs=10,
    validation_split=0.2,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [20]:
# Score the held-out split (x_val_* / y_val), whose rows were not passed to fit().
loss, accuracy = model.evaluate(
    x=[x_val_prompt, x_val_neg_prompt],
    y=y_val,
    batch_size=32,
)
print("Test loss:", loss)
print("Test accuracy:", accuracy)

Test loss: 0.454784095287323
Test accuracy: 0.8308905959129333


In [21]:
# Persist the trained model; the '.h5' extension selects the (legacy) HDF5
# format. Reload with tf.keras.models.load_model('nsfw_classifier.h5').
tf.keras.models.save_model(model, 'nsfw_classifier.h5')

In [164]:
import pickle

# Persist the fitted tokenizer so inference reproduces the exact word -> id
# mapping used in training. Only unpickle this file from a trusted source
# (pickle.load can execute arbitrary code).
with open('nsfw_classifier_tokenizer.pickle', mode='wb') as f:
    pickle.dump(tokenizer, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Import here so this cell does not rely on another cell having run first.
import pickle

# NOTE(review): pickling a Keras model is not a supported serialisation path
# in many TensorFlow/Keras versions and can raise (e.g. TypeError on
# unpicklable internals); model.save(...) + tf.keras.models.load_model is the
# supported round-trip. Serialise to memory first so that, if pickling fails,
# no truncated .pickle file is left on disk.
model_bytes = pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL)
with open('nsfw_classifier.pickle', 'wb') as f:
    f.write(model_bytes)

In [22]:
import re
def preprocess(text, isfirst = True):
    """Normalise a prompt (presumably Stable-Diffusion style) into a ", "-joined tag string.

    Steps, in order:
      1. remove every ``<...>`` span (no newlines inside);
      2. collapse runs of ``(`` / ``)`` to a single paren, then unwrap each
         ``(...)`` span (non-greedy, i.e. up to the first ``)``), keeping its
         contents;
      3. treat newlines and ``|`` as tag separators (``,``);
      4. (top-level call only) split on ``,``, strip whitespace, drop empty
         tags and rejoin with ``", "``.

    Args:
        text: A prompt string, or (top-level call only) a list/tuple of
            prompts, which is processed element-wise.
        isfirst: True for the public, top-level call. Internal recursive calls
            pass False to get the cleaned text back without the final
            split/strip/join step.

    Returns:
        The cleaned string, or a list of cleaned strings when ``text`` is a
        list/tuple.
    """
    if isfirst and isinstance(text, (list, tuple)):
        return [preprocess(item) for item in text]

    # Raw strings: '\(' in a normal literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'\(+', '(', text)
    text = re.sub(r'\)+', ')', text)

    for group in re.findall(r'\(.*?\)', text):
        # Replace "(inner)" with the cleaned "inner", i.e. drop the parens.
        text = text.replace(group, preprocess(group[1:-1], isfirst=False))

    text = text.replace('\n', ',').replace('|', ',')

    if isfirst:
        tags = (tag.strip() for tag in text.split(','))
        return ', '.join(tag for tag in tags if tag != '')

    return text

In [162]:
def postprocess(prompts, negative_prompts, outputs, print_percentage = True):
    """Pretty-print each prompt together with its negative prompt and prediction.

    Args:
        prompts: Iterable of prompt strings.
        negative_prompts: Indexable sequence aligned position-by-position with
            ``prompts``.
        outputs: Indexable sequence aligned with ``prompts``. Element ``[0]`` of
            each entry is printed as the prediction; element ``[1]`` is printed
            as a percentage only when ``print_percentage`` is truthy.
        print_percentage: Whether to append `` --<value>%`` to the prediction line.
    """
    banner = '*****************************************************************'
    for position, prompt_text in enumerate(prompts):
        print(banner)
        negative_text = negative_prompts[position]
        result = outputs[position]
        summary = f"prompt: {prompt_text}\nnegative_prompt: {negative_text}\npredict: {result[0]}"
        if print_percentage:
            summary += f" --{result[1]}%"
        print(summary)


In [163]:
# Make predictions on new data
prompt = ["a landscape with trees and mountains in the background", 'nude, sexy, 1girl, nsfw']
negative_prompt = ["nsfw", 'worst quality']

# Inference must mirror training: same preprocessing, same tokenizer, same
# max_sequence_length and the same 'post' padding/truncation used for the
# training arrays. With the pad_sequences defaults ('pre'), prompts longer
# than max_sequence_length would keep their LAST tokens instead of the first
# ones the model was trained on.
x_new = tokenizer.texts_to_sequences(preprocess(prompt))
z_new = tokenizer.texts_to_sequences(preprocess(negative_prompt))
x_new = tf.keras.preprocessing.sequence.pad_sequences(x_new, maxlen=max_sequence_length, truncating='post', padding='post')
z_new = tf.keras.preprocessing.sequence.pad_sequences(z_new, maxlen=max_sequence_length, truncating='post', padding='post')
y_new = model.predict([x_new, z_new])

NSFW_THRESHOLD = 0.5  # sigmoid probability above which a prompt is labelled NSFW


def to_label(row):
    """Map one model output row to (label, confidence in %, rounded to 2 decimals).

    The confidence is the probability of the predicted class: ``p`` for NSFW,
    ``1 - p`` for SFW, expressed as a percentage.
    """
    probability = row[0]
    if probability > NSFW_THRESHOLD:
        return "NSFW", float(f"{probability * 100:.2f}")
    return "SFW", float(f"{100 - probability * 100:.2f}")


y_new = [to_label(row) for row in y_new]

print("Prediction:", y_new)
postprocess(prompt, negative_prompt, y_new, print_percentage=True)

Prediction: [('SFW', 100.0), ('NSFW', 99.44)]
*****************************************************************
prompt: a landscape with trees and mountains in the background
negative_prompt: nsfw
predict: SFW --100.0%
*****************************************************************
prompt: nude, sexy, 1girl, nsfw
negative_prompt: worst quality
predict: NSFW --99.44%
