# Label changes

A simple workflow for labeling whether a Snapchat lens has actually changed an image.
The labels can be used to build up a small dataset (a few hundred images) for train_change_clf.py.

In [None]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import shutil
from PIL import Image

In [2]:
# Directory of source images; load_images() reads `<input_dir>/<name>` from here.
input_dir = "/Volumes/MLData2/laion-icons-selfies/filtered/download"
# Directory of processed images; a file here is paired with the file of the
# same name in input_dir.
output_dir = "/Volumes/MLData2/outputs/cartoon_kid"
# Root directory for labels: save_label() creates one sub-directory per image
# name containing label.json, input.png and output.png.
label_dir = 'change_labeled_cartoon'


In [4]:
def load_images(name):
    """Load the input/output image pair stored under the filename `name`.

    Args:
        name: Filename that exists in both `input_dir` and `output_dir`.

    Returns:
        A `(input_image, output_image)` tuple of RGB PIL images.
    """
    # The input image is opened and converted first, then the output image.
    before, after = (
        Image.open(os.path.join(folder, name)).convert('RGB')
        for folder in (input_dir, output_dir)
    )
    return before, after

def side_by_side(name):
    """Build a preview with the input image on the left and the output on the right.

    Args:
        name: Filename that exists in both `input_dir` and `output_dir`.

    Returns:
        A PIL image of the two pictures joined horizontally, shrunk to fit
        within 512x512 pixels.

    Raises:
        ValueError: From `np.concatenate` when the two images do not have the
            same height.
    """
    pixels = [np.array(picture) for picture in load_images(name)]
    # axis=1 is the width axis of an (H, W, C) array, i.e. a horizontal join.
    combined = Image.fromarray(np.concatenate(pixels, axis=1))
    # thumbnail() resizes in place and keeps the aspect ratio.
    combined.thumbnail((512, 512))
    return combined

In [5]:
def save_label(name, label):
    """Record a label for the image pair `name`.

    Creates `<label_dir>/<name>/` holding `label.json` plus copies of the
    input and output images as `input.png` and `output.png`. The image files
    are copied as-is, not re-encoded.

    Args:
        name: Filename that exists in both `input_dir` and `output_dir`.
        label: JSON-serializable label value (the labeling loop passes a bool).

    Raises:
        FileExistsError: If the label directory for `name` already exists.
    """
    entry_dir = os.path.join(label_dir, name)
    os.makedirs(entry_dir)
    with open(os.path.join(entry_dir, 'label.json'), 'w') as label_file:
        json.dump(label, label_file)
    copies = ((input_dir, 'input.png'), (output_dir, 'output.png'))
    for source_root, target_name in copies:
        shutil.copyfile(
            os.path.join(source_root, name),
            os.path.join(entry_dir, target_name),
        )

def unlabel(name):
    """Delete the label directory for `name`, including the copied images.

    Args:
        name: Name of a previously labeled image pair.
    """
    entry_dir = os.path.join(label_dir, name)
    shutil.rmtree(entry_dir)

In [23]:
# Filenames to present for labeling. os.listdir() returns entries in arbitrary
# order, so sort for a reproducible labeling order, and drop hidden entries
# (names starting with '.') the same way the labeled-name listing does.
out_names = sorted(x for x in os.listdir(output_dir) if not x.startswith('.'))

In [26]:
# Interactive labeling loop.
#
# For each not-yet-labeled name the side-by-side preview is shown and the user
# answers:
#   y - the filter changed the image (saved as True)
#   n - the filter did not change the image (saved as False)
#   x - delete the label saved just before this image, then ask again
# Any other input repeats the prompt.
prev_name = None
for name in out_names:
    # Already-labeled names are skipped, so the cell can be re-run to resume.
    if os.path.exists(os.path.join(label_dir, name)):
        continue
    try:
        display(side_by_side(name))
    except Exception as exc:
        # Best-effort: pairs that cannot be loaded or displayed are skipped,
        # but the reason is reported. Catching Exception rather than using a
        # bare `except` lets KeyboardInterrupt stop the cell.
        print('skipping', name, '-', repr(exc))
        continue
    while True:
        label = input('Did the filter work? y/n:').strip()
        if label in ('y', 'n'):
            save_label(name, label == 'y')
            prev_name = name
            break
        if label == 'x':
            if prev_name is None:
                # Nothing was labeled since the start (or since the last
                # undo), so there is no label directory to remove.
                print('nothing to delete')
            else:
                print('deleting', prev_name)
                unlabel(prev_name)
                # Forget the deleted name so a second 'x' does not try to
                # remove the same directory again.
                prev_name = None

In [27]:
class Model(nn.Module):
    """Small CNN that scores an (input, output) image pair with one logit.

    The two RGB images are stacked along the channel axis (6 channels), passed
    through two stride-2 3x3 convolutions, averaged over all spatial positions
    and mapped from the 32 pooled features to a single value per pair.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and layer positions determine the state_dict keys
        # used by saved checkpoints, so they are kept stable.
        self.in_layers = nn.Sequential(
            nn.Conv2d(6, 32, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, stride=2),
            nn.ReLU(),
        )
        self.out_layers = nn.Sequential(
            nn.Linear(32, 1),
        )

    def apply_to_pair(self, img1: "Image.Image", img2: "Image.Image") -> torch.Tensor:
        """Run the model on a single pair of PIL images.

        Args:
            img1: The input image.
            img2: The output image.

        Returns:
            A tensor of shape (1, 1) holding the logit for the pair.

        Raises:
            ValueError: If the two images do not have the same size after
                being shrunk to fit within 256x256.
        """
        # convert() returns a new image (per the Pillow docs), so the in-place
        # thumbnail() calls below do not modify the caller's images.
        img1 = img1.convert('RGB')
        img2 = img2.convert('RGB')
        img1.thumbnail((256, 256))
        img2.thumbnail((256, 256))
        if img1.size != img2.size:
            # Channel-wise stacking needs identical height and width; fail
            # with a clear message instead of NumPy's generic one.
            raise ValueError(
                f'image sizes differ after thumbnail: {img1.size} vs {img2.size}'
            )
        arr = np.concatenate([np.array(img1), np.array(img2)], axis=-1)
        # (H, W, 6) uint8 -> (1, 6, H, W) float scaled from [0, 255] to [-1, 1].
        tensor = torch.from_numpy(arr).permute(2, 0, 1)[None].float() / 127.5 - 1
        # Follow the model's device so this also works after model.to(...).
        device = next(self.parameters()).device
        return self(tensor.to(device))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Map a batch of stacked image pairs to one logit each.

        Args:
            images: Tensor of shape (N, 6, H, W).

        Returns:
            Tensor of shape (N, 1).
        """
        spatial_out = self.in_layers(images)
        # Global average pool: flatten H and W together, then take the mean.
        spatial_out = spatial_out.flatten(2).mean(-1)
        return self.out_layers(spatial_out)

In [9]:
def load_pair(name: str) -> tuple[Image.Image, Image.Image, bool]:
    """Read one labeled example from `label_dir`.

    Args:
        name: Name of the example's sub-directory inside `label_dir`.

    Returns:
        `(input_image, output_image, label)` where the images are opened from
        `input.png` and `output.png` and the label is the parsed content of
        `label.json`.
    """
    pair_dir = os.path.join(label_dir, name)
    before = Image.open(os.path.join(pair_dir, 'input.png'))
    after = Image.open(os.path.join(pair_dir, 'output.png'))
    with open(os.path.join(pair_dir, 'label.json'), 'r') as label_file:
        changed = json.load(label_file)
    return before, after, changed

In [6]:
# All labeled example names, ignoring hidden entries (names starting with '.').
names = [entry for entry in os.listdir(label_dir) if not entry.startswith('.')]
# Shuffled without a fixed seed, so the split differs between runs.
random.shuffle(names)
# The first `num_test` shuffled names are held out for testing; the rest train.
num_test = 40
test_names, train_names = names[:num_test], names[num_test:]

In [None]:
def eval_loss(model, name: str) -> tuple[torch.Tensor, bool]:
    """Compute the loss and correctness of `model` on one labeled example.

    Args:
        model: Object with an `apply_to_pair(img1, img2)` method returning a
            (1, 1) logit tensor.
        name: Name of the labeled example to load with `load_pair`.

    Returns:
        `(loss, correct)`: the binary cross-entropy (with logits) between the
        prediction and the label, and whether thresholding the logit at 0
        reproduces the label.
    """
    img1, img2, label = load_pair(name)
    pred = model.apply_to_pair(img1, img2)
    # Build the target on the prediction's device so the loss also works when
    # the model does not live on the CPU.
    target = torch.tensor([float(label)], device=pred.device)
    loss = F.binary_cross_entropy_with_logits(pred[0], target)
    # A positive logit means the model predicts a True label.
    correct = (pred.item() > 0) == label
    return loss, correct

In [None]:
# Train a fresh model with Adam, one example per optimizer step, and write a
# checkpoint after every epoch.
model = Model()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
os.makedirs('ckpt_cartoon', exist_ok=True)

for epoch in range(100):
    train_losses, train_corr = [], []
    for name in tqdm(train_names):
        # eval_loss() loads the pair itself; loading it here as well would
        # only repeat the disk reads.
        loss, correct = eval_loss(model, name)
        loss.backward()
        opt.step()
        opt.zero_grad()
        train_losses.append(loss.item())
        train_corr.append(correct)
    test_losses, test_corr = [], []
    # No gradients are needed for the held-out evaluation.
    with torch.no_grad():
        for name in tqdm(test_names):
            loss, correct = eval_loss(model, name)
            test_losses.append(loss.item())
            test_corr.append(correct)
    print(f'epoch {epoch}: test_acc={np.mean(test_corr)} train_acc={np.mean(train_corr)} test_loss={np.mean(test_losses)} train_loss={np.mean(train_losses)}')
    with open(f'ckpt_cartoon/model_{epoch}.pt', 'wb') as f:
        torch.save(model.state_dict(), f)

In [None]:
# Reload the last checkpoint written by the training cell, which saves
# model_0.pt .. model_99.pt under ckpt_cartoon/.
model = Model()
model.load_state_dict(torch.load('ckpt_cartoon/model_99.pt'))

In [None]:
# Print every labeled example (train and test alike) that the loaded model
# gets wrong, e.g. to re-check its label by hand.
with torch.no_grad():
    for name in tqdm(names):
        # eval_loss() takes the model as its first argument.
        _, correct = eval_loss(model, name)
        if not correct:
            print('name', name)