In [1]:
import psycopg2, os, subprocess
import pandas as pd

## Notebook Purpose
<b> This is notebook number 1 </b>

This notebook queries the Postgres tables `public."Image"`, `public."TagsOnImage"`, and `public."Tag"` to gather training data for every model in the mixture: the NLP (prompt) transformers, the ViT transformers, and the traditional CV models.

### Notebook Order
1. getData
2. downloadData
3. trainResNetModel | trainPromptTransformerClassifier | trainViTClassifier
4. localMixtureEval

In [2]:
# Path to the shell script
script_path = './load_env.sh'

# Run the script and capture the output
proc = subprocess.Popen(['/bin/bash', script_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()

if proc.returncode != 0:
    print(f"Error sourcing .zshrc: {stderr.decode('utf-8')}")
else:
    # Parse the output and set the environment variables
    for line in stdout.decode('utf-8').splitlines():
        key, _, value = line.partition("=")
        # Remove the surrounding quotes from the value
        if value.startswith('"') and value.endswith('"'):
            value = value[1:-1]
        os.environ[key] = value

# Verify the environment variable is loaded
URL = os.getenv('REMOTE_POSTGRES_URL')  # Replace 'MY_VARIABLE' with your variable name to check

# print(URL)

In [3]:
# Fail fast when the connection string is missing. With no DSN, psycopg2 does
# not raise; libpq falls back to its defaults (PG* env vars / local server),
# which would silently query the wrong database.
if not URL:
    raise RuntimeError("REMOTE_POSTGRES_URL is not set; run the environment-loading cell first.")
conn = psycopg2.connect(URL)
# NOTE(review): the cursor is not used by the cells below; kept for ad-hoc queries.
cur = conn.cursor()

In [7]:
## Tags we want
# Every (style, subject) pair becomes one OR-ed condition in the WHERE clause.
styles = ["anime", "photorealistic", "cartoon", "modern art", "realistic"]
subjects = ["man", "woman", "animal", "child"]

# The NSFW level is aliased as original_level because later cells group by that column.
_BASE_QUERY = """
WITH ImageTags AS (
  SELECT
    toi."imageId",
    string_agg(t.name, ', ') AS tags
  FROM "TagsOnImage" toi
  JOIN "Tag" t ON t.id = toi."tagId"
  WHERE NOT toi.disabled
    AND toi.source != 'Rekognition'
  GROUP BY toi."imageId"
)

SELECT
  CONCAT('https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/', i.url, '/width=450/', i.id, '.jpg') AS download_url,
  i."url",
  i."id",
  CASE
    WHEN i."nsfwLevel" = 1 THEN 'PG'
    WHEN i."nsfwLevel" = 2 THEN 'PG13'
    WHEN i."nsfwLevel" = 4 THEN 'R'
    WHEN i."nsfwLevel" = 8 THEN 'X'
    WHEN i."nsfwLevel" = 16 THEN 'XXX'
  END AS original_level,
  it.tags,
  i.meta->>'prompt' AS prompt
FROM "Image" i
JOIN ImageTags it ON i."id" = it."imageId"
WHERE i.meta->>'prompt' IS NOT NULL
  AND (
"""


def build_tag_query(styles, subjects):
    """Build the SQL selecting download URL, NSFW label, tags and prompt per image.

    An image is selected when it has a prompt and matches at least one
    (style, subject) pair. The style must appear in the image's comma-joined
    tag string. The subject must appear in the tag string, or, for
    ``"child"``, in the prompt text.

    Args:
        styles: style keywords matched against the aggregated tags.
        subjects: subject keywords matched against the tags (or the prompt).

    Returns:
        The complete SQL string. Keywords are inlined rather than bound as
        parameters, so they must be trusted constants containing no quotes.

    Raises:
        ValueError: if ``styles`` or ``subjects`` is empty, since the WHERE
            clause would otherwise be left as an empty ``AND ()``.

    NOTE(review): ``LIKE '%kw%'`` is substring matching on the joined tag
    string. For example, ``'man'`` also matches ``'woman'``/``'human'``, and
    ``'realistic'`` matches ``'photorealistic'``. LIKE is case-sensitive for
    the prompt check. Confirm this breadth is intended.
    """
    clauses = []
    for style in styles:
        for subject in subjects:
            if subject == 'child':
                # 'child' is checked against the prompt text instead of the tags.
                condition = f"i.meta->>'prompt' LIKE '%{subject}%' AND it.tags LIKE '%{style}%'"
            else:
                condition = f"it.tags LIKE '%{style}%' AND it.tags LIKE '%{subject}%'"
            clauses.append(f"({condition})")
    if not clauses:
        raise ValueError("styles and subjects must both be non-empty")
    # Join the pair conditions with OR and close the parenthesis opened by "AND (".
    return _BASE_QUERY + " OR ".join(clauses) + ")"


sql_query = build_tag_query(styles, subjects)

In [8]:
# print() renders the multi-line SQL with real line breaks; the bare string repr shows escaped "\n" on one line.
print(sql_query)

'\nWITH ImageTags AS (\n  SELECT\n    toi."imageId",\n    string_agg(t.name, \', \') AS tags\n  FROM "TagsOnImage" toi\n  JOIN "Tag" t ON t.id = toi."tagId"\n  WHERE NOT toi.disabled\n    AND toi.source != \'Rekognition\'\n  GROUP BY toi."imageId"\n)\n\nSELECT\n  CONCAT(\'https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/\', i.url, \'/width=450/\', i.id, \'.jpg\') AS download_url,\n  i."url",\n  i."id",\n  CASE\n    WHEN i."nsfwLevel" = 1 THEN \'PG\'\n    WHEN i."nsfwLevel" = 2 THEN \'PG13\'\n    WHEN i."nsfwLevel" = 4 THEN \'R\'\n    WHEN i."nsfwLevel" = 8 THEN \'X\'\n    WHEN i."nsfwLevel" = 16 THEN \'XXX\'\n  END AS original_level,\n  it.tags,\n  i.meta->>\'prompt\' AS prompt\nFROM "Image" i\nJOIN ImageTags it ON i."id" = it."imageId"\nWHERE i.meta->>\'prompt\' IS NOT NULL\n  AND (\n(it.tags LIKE \'%anime%\' AND it.tags LIKE \'%man%\') OR (it.tags LIKE \'%anime%\' AND it.tags LIKE \'%woman%\') OR (it.tags LIKE \'%anime%\' AND it.tags LIKE \'%animal%\') OR (i.meta->>\'prompt\' LIKE \'

In [9]:
# Run the single tag-filtered query built above and load the result into one DataFrame.
# NOTE(review): this call emitted a warning (truncated in the saved output); presumably
# pandas' notice about receiving a raw DBAPI connection rather than SQLAlchemy — confirm.
image_prompt_tag_data = pd.read_sql_query(sql_query, conn)


  image_prompt_tag_data = pd.read_sql_query(sql_query, conn)


### Notes about queries

- Use the line below with the original query or `updated_query`. This notebook uses the more advanced tag-filtered query above, so it does not run that line: <br>
`image_with_ids = pd.read_sql_query(sql_query, conn)`

In [10]:
# Show the first five rows of the query result
image_prompt_tag_data.head(n=5)

Unnamed: 0,download_url,url,id,original_level,tags,prompt
0,https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7...,3941c5f2-d694-44bf-a6e3-b2158f7fd038,7482165,XXX,"woman, brown hair, chair, closed eyes, coverin...",iphone photograph of sleepy tired pale freckle...
1,https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7...,15f5d724-8831-4b5f-8e07-491696366155,6429527,PG,"woman, bow, bracelet, breasts, jewelry, long h...","1girl, solo, portrait, simple background, dark..."
2,https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7...,3d64c9d8-d812-4e93-97e1-4a318d01afa4,7483116,X,"woman, breasts, brown eyes, brown hair, long h...","((nude, topless, nipples, white panties, white..."
3,https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7...,de196b97-127f-4616-8133-52bd3b8f8494,7481494,PG13,"woman, black hair, cowboy shot, curly hair, fe...","(sfw:1.2) romanian 1girl, as a curious killer ..."
4,https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7...,7905bc95-2c3b-4c1e-8a3c-c0ab6295d2bf,6424802,R,"partial nudity, black hair, bracelet, dark ski...","cover art by angus mckie, scifi, futuristic, h..."


In [12]:
# Images per NSFW label. dropna=False adds a NaN row for images whose nsfwLevel
# the SQL CASE did not map, instead of silently leaving them out of the counts.
image_prompt_tag_data.groupby('original_level', dropna=False)['id'].count()

original_level
PG      500512
PG13    152409
R       253634
X       339941
XXX     209348
Name: id, dtype: int64

In [13]:
# Save the query result; create ./data first so a fresh checkout does not fail with OSError.
os.makedirs('./data', exist_ok=True)
image_prompt_tag_data.to_csv('./data/image_prompt_tag_data.csv', index=False)