# Filter inappropriate content

- https://github.com/woctezuma/discord-members-metadata

## Install packages

In [None]:
# NOTE(review): torch and torchvision are imported later but not installed here;
# this assumes they are preinstalled (the /content paths below suggest Colab).
# Versions are unpinned, so results may vary across transformers releases.
%pip install --quiet transformers mediapy

## Download the image dataset

In [None]:
%cd /content

for i in range(1, 3):
  fname = f"img_{i}.zip"

  !curl -OL https://github.com/woctezuma/discord-members-metadata/releases/download/img/{fname}
  !unzip -qq {fname}

## Download the text datasets

In [None]:
%cd /content

# Text datasets (bios and member metadata) from the repository's releases.
# NOTE(review): not yet used by the cells below (see the TODO under "Classify texts").
!curl -OL https://github.com/woctezuma/discord-members-metadata/releases/download/bio/bios.json
!curl -OL https://github.com/woctezuma/discord-members-metadata/releases/download/metadata/members.json

## Define utils

In [None]:
import json

from pathlib import Path

def save_to_json(data, fname):
  """Serialize `data` to JSON and write it to the file `fname`.

  The JSON text is produced before the file is opened, so a non-serializable
  `data` raises (e.g. TypeError) without truncating an existing file at `fname`.
  The output uses a one-space indent (what `indent=True` amounted to).
  """
  text = json.dumps(data, indent=1)
  # dumps escapes non-ASCII by default, so the written text is pure ASCII.
  Path(fname).write_text(text, encoding="utf-8")

def load_from_json(fname):
  """Read and parse the JSON file at `fname`.

  The file is decoded as UTF-8, the encoding required for JSON interchange,
  instead of the platform's locale default, so non-ASCII text parses the
  same way on every OS.

  Raises:
    FileNotFoundError: if `fname` does not exist.
    json.JSONDecodeError: if the content is not valid JSON.
  """
  with Path(fname).open(encoding="utf-8") as f:
    data = json.load(f)
  return data

def safe_load_from_json(fname):
  """Load JSON from `fname` like load_from_json, but yield {} if the file is missing.

  Only FileNotFoundError is absorbed; any other error (e.g. invalid JSON)
  propagates to the caller.
  """
  try:
    return load_from_json(fname)
  except FileNotFoundError:
    return {}

In [None]:
from pathlib import Path

def get_member_id(image_path):
  """Return the member ID encoded in an image file name, i.e. the path's stem.

  Example: "img/1234.png" -> "1234".
  """
  member_file = Path(image_path)
  return member_file.stem

In [None]:
def get_output_fname(pipe, suffix="", ext=".json"):
  """Build an output file name from the pipeline's model identifier.

  Every "/" in `pipe.model.name_or_path` (e.g. "org/model") becomes "_", and
  then `suffix` and `ext` are appended: "org_model{suffix}{ext}".
  """
  model_tag = "_".join(pipe.model.name_or_path.split("/"))
  return f"{model_tag}{suffix}{ext}"

## Classify images

Reference:
- https://github.com/woctezuma/stable-diffusion-safety-checker

Dataset

In [None]:
import functools
import os

import numpy as np

from pathlib import Path
from torchvision.datasets.folder import default_loader, is_image_file
from torchvision.transforms.functional import to_pil_image

@functools.lru_cache
def get_image_paths(path):
    """Recursively list the image files under `path`, sorted lexicographically.

    A file is kept when torchvision's `is_image_file` accepts its extension.
    Results are memoized per `path` by `lru_cache`: later calls return the same
    list object and do not see files added since the first call, so callers
    should not mutate the returned list.
    """
    candidates = [
        str(Path(dirpath) / name)
        for dirpath, _subdirs, names in os.walk(path)
        for name in names
    ]
    return sorted(fn for fn in candidates if is_image_file(fn))

class ImageFolder:
    """Minimal map-style dataset over every image file found under `path`.

    Each item is read with `loader`, optionally passed through `transform`, and
    returned as a PIL image. Tensor/ndarray outputs of `transform` are converted
    back with `to_pil_image`.
    """

    def __init__(self, path, transform=None, loader=default_loader):
        # Sorted image paths; the list is shared through get_image_paths' cache.
        self.samples = get_image_paths(path)
        self.loader = loader
        self.transform = transform

    def __getitem__(self, idx: int):
        # Raise IndexError rather than assert: IndexError is what the sequence
        # protocol expects (so `for img in dataset` terminates cleanly), and
        # the check is not stripped by `python -O`.
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for {len(self)} images")
        img = self.loader(self.samples[idx])
        if self.transform:
            img = self.transform(img)
        # to_pil_image only accepts tensors and ndarrays. Without a transform,
        # the loader's output is returned unchanged instead of raising TypeError.
        import torch  # local import: torch is not imported in this cell
        if isinstance(img, (torch.Tensor, np.ndarray)):
            img = to_pil_image(img)
        return img

    def __len__(self):
        return len(self.samples)


Transform

In [None]:
from torchvision import transforms

def get_target_image_size(resize_size=256, keep_ratio=True):
    """Return the `size` argument to give to `transforms.Resize`.

    If `keep_ratio` is true, return a single int, so only the shorter edge is
    matched and the aspect ratio is kept. Otherwise return the square tuple
    `(resize_size, resize_size)`.
    """
    if not keep_ratio:
        return (resize_size, resize_size)
    return resize_size

def get_transform(
    resize_size=256,
    keep_ratio=True,
    interpolation=transforms.InterpolationMode.BICUBIC,
):
    """Build the preprocessing pipeline: resize, then convert to a tensor.

    Args:
        resize_size: target size, passed through get_target_image_size.
        keep_ratio: if True, match only the shorter edge; otherwise resize to a square.
        interpolation: resampling mode used by `transforms.Resize`.
    """
    target_size = get_target_image_size(resize_size, keep_ratio)
    resize = transforms.Resize(target_size, interpolation=interpolation)
    return transforms.Compose([resize, transforms.ToTensor()])

Data loader

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Return the list of samples unchanged instead of stacking it.

    Keeping a plain list allows images of different sizes in the same batch,
    which a stacking collate would reject.
    """
    samples = batch
    return samples

def get_dataloader(
    data_dir,
    transform=get_transform(),
    batch_size=8,
    collate_fn=collate_fn,
):
    """Create a DataLoader over the images found under `data_dir`.

    The default `transform` is built once, when this function is defined.
    Batches come out in dataset order (no shuffling). The workflow relies on
    this to map predictions back to `dataset.samples` by position.
    """
    return DataLoader(
        ImageFolder(data_dir, transform=transform),
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=False,
    )

Pipeline

In [None]:
import torch

from transformers import pipeline

# https://huggingface.co/Falconsai/nsfw_image_detection
# Fall back to the CPU when no GPU is available: much slower, but the notebook
# still runs instead of failing here.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline("image-classification",
                model="Falconsai/nsfw_image_detection",
                device=device)

Apply the workflow

In [None]:
import torch

from tqdm.auto import tqdm

# For my use case, this cell required ~ 25 minutes.
data_path = "img/"
batch_size = 8

sample_fnames = []
safety_scores = []

loader = get_dataloader(data_path, batch_size = batch_size)

with torch.no_grad():
  for images in tqdm(loader):
    out = pipe(images)

    # The loader does not shuffle, so predictions arrive in the same order as
    # loader.dataset.samples. The number of paths collected so far is the
    # offset of the current batch.
    offset = len(sample_fnames)
    sample_fnames += loader.dataset.samples[offset:offset + len(out)]

    for predictions in out:
      # Look up the "normal" score by label, so each image contributes exactly
      # one score. If that label were missing (e.g. if the pipeline only
      # returned its top label), this raises KeyError. The previous filter
      # would instead silently shift safety_scores out of alignment with
      # sample_fnames.
      scores_by_label = {pred["label"]: pred["score"] for pred in predictions}
      safety_scores.append(scores_by_label["normal"])

assert len(safety_scores) == len(sample_fnames) == len(loader.dataset)

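Save the raw scores and image paths. Then pair each member ID with its safety score, i.e. the score of the "normal" label. At the same time, display the worst offenders: images whose score falls below the threshold.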

In [None]:
import torch

print(">>> Saving safety scores...")
# Scores are stored as a float16 tensor, in the same order as sample_fnames.
fname = get_output_fname(pipe, suffix="_safety_scores", ext=".pth")
scores_tensor = torch.asarray(safety_scores, dtype=torch.float16)
torch.save(scores_tensor, fname)

In [None]:
import json

# Imported here so this cell does not depend on Path leaking from other cells.
from pathlib import Path

print(">>> Saving image paths...")
# Same order as the saved safety scores, so index i in both files matches.
fname = "img_list.json"
with Path(fname).open("w") as f:
    json.dump(sample_fnames, f)

In [None]:
import mediapy as media

safety_score_threshold = 0.005
img_size = (128, 128)

# Visit images from the lowest to the highest "normal" score, so the most
# suspicious images come first.
scored_paths = sorted(zip(sample_fnames, safety_scores), key=lambda pair: pair[1])

aggregate = {}
for image_path, safety_score in scored_paths:
  member_id = get_member_id(image_path)
  aggregate[member_id] = safety_score

  if safety_score < safety_score_threshold:
    image = media.resize_image(media.read_image(image_path), img_size)
    print(f"{member_id} {safety_score:.2}")
    media.show_image(image)

# Member ID -> safety score; insertion order follows ascending score.
save_to_json(aggregate, get_output_fname(pipe))

## Inspect the results of the Stable Diffusion Safety Checker

- https://github.com/woctezuma/stable-diffusion-safety-checker

In [None]:
# Output of the Stable Diffusion Safety Checker. The cell below reads it as a
# mapping from image path to a list (presumably of flagged concepts).
!curl -OL https://github.com/woctezuma/discord-members-metadata/releases/download/img/bad_concepts.json

In [None]:
import mediapy as media

NUM_CONCEPTS_THRESHOLD = 5
DISPLAY_SIZE = (128, 128)

# Keys are image paths (they are read with media.read_image below); values are
# lists, presumably of flagged concepts. Keep only non-empty entries.
data = load_from_json("bad_concepts.json")
d = {k:v for k,v in data.items() if len(v)>0}

# Show images from the most to the fewest flagged concepts. The threshold is
# checked before displaying, so only images with at least
# NUM_CONCEPTS_THRESHOLD concepts are shown.
for k, v in sorted(d.items(), key=lambda x: len(x[1]), reverse=True):
  if len(v) < NUM_CONCEPTS_THRESHOLD:
    break
  print(f"{k} {len(v)}")
  media.show_image(media.resize_image(media.read_image(k), DISPLAY_SIZE))

## Classify texts

TODO