In [3]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
df = pd.read_csv('./datasets/merged_dataset.csv')
df.head(), df.tail()

(                                              prompt label
 0  Tell me about the hundred years war like you'r...  text
 1  "Suggest in detail three creative activities t...  text
 2  Please articulate the benefit of GPT-4 over Ch...  text
 3  Using C#, write me code for an Android applica...  text
 4  How can I, an average person, realistically ta...  text,
                                                 prompt  label
 370  representation of power and domination in the ...  image
 371  bisley simon mullins, craig bussiere, gaston r...  image
 372  a highly detailed 4 k close up render of bella...  image
 373  jorginho on stage making up the lyrics to the ...  image
 374  kanye west in super mario oufit as mario in ge...  image)

In [24]:
X_train, X_test, y_train, y_test = train_test_split(df["prompt"], df["label"], test_size=0.2)
# X_train, y_train

In [65]:
from nltk.stem import SnowballStemmer
from nltk.tokenize import RegexpTokenizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

In [77]:
vectorizer = CountVectorizer()
X_train_counts = vectorizer.fit_transform(X_train)

nb = MultinomialNB()
nb.fit(X_train_counts, y_train)

In [83]:
def stem_word_list(l):
    stemmer = SnowballStemmer('english')
    return [stemmer.stem(w) for w in l]

image_related_words = set(stem_word_list(['image', 'picture', 'photo', 'visual', 'generate', 'generate', 'create']))

image_style_words = set(stem_word_list([
    'photograph', 'graphic', 'illustration',
    'snapshot', 'rendering', 'drawing', 'diagram', 'sketch', 'portrait', 'canvas',
    'photography', 'visualization', 'artwork', 'snapshot', 'composition', 'scene',
    'display', 'capture', 'view', 'presentation', 'design', 'figure', 'landscape',
    'depiction', 'model', 'render', 'animation', 'digital', 'artistic', 'photogenic',
    'aesthetic', 'icon', 'logo', 'infographic', 'collage', 'poster', 'cover', 'screen',
    'album', 'photojournalism', 'layout', 'frame', 'exposure', 'focus', 'filter', 'format',
    'contrast', 'hue', 'saturation', 'shading', 'resolution', 'pixel', 'blur', 'depth',
    'composition', 'angle', 'panorama', 'hdr', 'retouch', 'montage', 'composite', 'snapshot',
    'thumbnail', 'cinematography', 'visualize', 'scenery', 'scene', 'art', 'media', 'snapshot',
    'animation', 'frame', 'snapshot', 'illustration', 'collage', 'conceptual', 'composition',
    'drawing', 'fine art', 'surreal', 'abstract', 'visualize', 'digital art', 'concept art',
    'impressionism', 'realism', 'pop art', 'modern art', 'contemporary art'
]))

text_generation_related_words = set(stem_word_list([
    'generate', 'create', 'story', 'content',
    'narrate', 'produce', 'write', 'text', 'code', 'program']))

text_style_words = set(stem_word_list([
    'eloquent', 'concise', 'fluent', 'poetic', 'lyrical', 'vivid', 'evocative',
    'articulate', 'coherent', 'expressive', 'nuanced', 'profound', 'engaging',
    'persuasive', 'rhythmic', 'prosaic', 'flowing', 'eloquence', 'clarity', 
    'captivating', 'dramatic', 'melodic', 'creative', 'sophisticated', 'smooth',
    'insightful', 'compelling', 'imagery', 'descriptive', 'inspirational',
    'rich', 'melancholic', 'mellifluous', 'serene', 'captivating', 'inspiring',
    'melancholy', 'dynamic', 'harmonious', 'thoughtful', 'expressive',
    'elegant', 'magnetic', 'whimsical', 'thought-provoking', 'resonant',
    'intuitive', 'impactful', 'polished', 'lush', 'stirring', 'vibrant',
    'dreamy', 'eloquent', 'captivating', 'enchanting', 'inventive', 'sonorous',
    'insightful', 'nuanced', 'provocative', 'evocative', 'mesmerizing',
    'cogent', 'soothing', 'unforgettable', 'haunting', 'compelling', 'help', 'solve'
]))

'portrait' in image_style_words

True

In [112]:
def count_by_list(tokens: list[str], words):
    return sum(1 for token in tokens if token in words)

def count_image_related_words(tokens: list[str]):
    return count_by_list(tokens, image_related_words)

def count_image_style_words(tokens: list[str]):
    return count_by_list(tokens, image_style_words)

def count_text_related_words(tokens: list[str]):
    return count_by_list(tokens, text_generation_related_words)

def count_text_style_words(tokens: list[str]):
    return count_by_list(tokens, text_style_words)

def count_words(tokens: list[str]):
    return len(tokens)

def nb_prediction(tokens: list[str]):
    prompt = ' '.join(tokens)
    prompt_counts = vectorizer.transform([prompt])
    p = nb.predict_log_proba(prompt_counts)[0]
    result_dict = {class_name: prob for class_name, prob in zip(nb.classes_, p)}

    return result_dict

def normalize(prompt: str):
    re = RegexpTokenizer(r'\b[a-z0-9]+\b')
    prompt = prompt.lower()
    tokens = re.tokenize(prompt)
    tokens = stem_word_list(tokens)

    d = nb_prediction(tokens)

    ft = {
        "img_rel": count_image_related_words(tokens),
        "img_style": count_image_style_words(tokens),
        "text_rel": count_text_related_words(tokens),
        "text_style": count_text_related_words(tokens),
        "nb_image": d['image'],
        "nb_text": d['text'],
        "word_count": count_words(tokens)
    }
    
    return ft

In [113]:
X_train, y_train

(335    highly detailed vfx portrait of a man with a m...
 145    Write a sentence that has excatly 12 words in ...
 271    AN 8K RESOLUTION, MATTE PAINTING OF THE WISE A...
 238                        marley bob as einstein albert
 213                        hybrid goat-owl abominable an
                              ...                        
 277                 druillet philippe et mucha by castle
 141    So is the main thing with Plus that the limita...
 374    kanye west in super mario oufit as mario in ge...
 200    gordon brooks bourassa chris 4k, thirds, of ru...
 313    greeting card, love, 2 beautiful dragons, by t...
 Name: prompt, Length: 300, dtype: object,
 335    image
 145     text
 271    image
 238    image
 213    image
        ...  
 277    image
 141     text
 374    image
 200    image
 313    image
 Name: label, Length: 300, dtype: object)

In [114]:
X_train[335]

'highly detailed vfx portrait of a man with a mustache wearing a red shirt, blue overalls, and a red cap, art by capcom, digital illustration, ornate details,'

In [115]:
t = normalize("""highly detailed vfx portrait of a man with a mustache wearing a red shirt, blue overalls, and a red cap, art by capcom, digital illustration, ornate details,""")
t

{'img_rel': 0,
 'img_style': 4,
 'text_rel': 0,
 'text_style': 0,
 'nb_image': -1.1613536798904533e-09,
 'nb_text': -20.57368168303107,
 'word_count': 27}

In [116]:
df = pd.read_csv('./datasets/merged_dataset.csv')

df_features = pd.DataFrame(columns=['img_rel', 'img_style', 'text_rel', 'text_style', 'nb_image', 'nb_text', 'word_count', 'label'])
for i, r in df.iterrows():
    prompt = r['prompt']
    label = r['label']
    n = normalize(prompt)
    n['label'] = label
    df_features.loc[len(df_features)] = n

df_features.head()

Unnamed: 0,img_rel,img_style,text_rel,text_style,nb_image,nb_text,word_count,label
0,0,0,0,0,-17.798519,-1.862951e-08,19,text
1,0,0,0,0,-10.46046,-2.864747e-05,14,text
2,0,0,0,0,-26.25356,-3.964828e-12,22,text
3,0,1,4,4,-85.191018,0.0,76,text
4,0,0,0,0,-9.533721,-7.237248e-05,12,text


In [118]:
X = df_features.drop('label', axis=1)
y = df_features['label']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [119]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

In [120]:
model = LogisticRegression()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f"Accuracy: {accuracy}")
print("Classification Report:\n", classification_report(y_test, predictions))

Accuracy: 0.9866666666666667
Classification Report:
               precision    recall  f1-score   support

       image       1.00      0.97      0.99        40
        text       0.97      1.00      0.99        35

    accuracy                           0.99        75
   macro avg       0.99      0.99      0.99        75
weighted avg       0.99      0.99      0.99        75

