In [None]:
import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import clip
from transformers import CLIPImageProcessor, CLIPModel
import pandas as pd
from tqdm import tqdm

from pathlib import Path

In [None]:
fake = [{'image' : str(s), 'label' : 1} for s in Path("train/fake").glob("*")]
real = [{'image' : str(s), 'label' : 0} for s in Path("train/real").glob("*")]
data = fake + real
pd.DataFrame(data).to_json("train.json", orient="records", lines=True)

adversarial = [{'image' : str(s), 'label' : 0} for s in Path("train/adversarial").glob("*")]
pd.DataFrame(adversarial).to_json("train_adv.json", orient="records", lines=True)

In [None]:
class RealFakeDataset(torch.utils.data.Dataset):
    def __init__(self, json_path, processor):
        # Initialize image paths and corresponding texts
        self.jsons = pd.read_json(json_path, lines=True)
        self.processor = processor

    def __len__(self):
        return len(self.jsons)

    def __getitem__(self, idx):
        # Preprocess image using CLIP's preprocessing function
        image = self.processor(Image.open(self.jsons.iloc[idx]['image']))
        label = torch.tensor(self.jsons.iloc[idx]['label'])
        return image, label

class Classifier(nn.Module):
    def __init__(self, ckpt="openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(ckpt).to(torch.float16)
        self.linear = nn.Linear(512, 2, dtype=torch.float16)
    def forward(self, x):
        image_features = self.model.get_image_features(**x)
        x = self.linear(image_features)
        return x
    def to(self, device):
        self.model.to(device)
        self.linear.to(device)
        return self
    
def collate_fn(batch):
    images = [torch.tensor(x[0]['pixel_values'][0]) for x in batch]
    labels = [x[1] for x in batch]
    images = {'pixel_values': torch.stack(images)}
    labels = torch.stack(labels)
    return images, labels

In [None]:
model = Classifier().to("cuda")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
loss = nn.CrossEntropyLoss()


json_path = "train.json" 
train_dataset = RealFakeDataset(json_path, processor)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [None]:
num_epochs = 30
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in pbar:
        optimizer.zero_grad()
        images, labels = batch 
        images['pixel_values'] = images['pixel_values'].to("cuda")
        labels = labels.to("cuda")
        # Forward pass
        output = model(images)

        # Compute loss
        cls_loss = loss(output, labels)

        # Backward pass
        cls_loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {cls_loss.item():.4f}")

torch.save(model.state_dict(), "model.pth")

In [None]:
classifier = Classifier()
classifier.load_state_dict(torch.load("model.pth"))
classifier = classifier.to("cuda")

json_path = "train_adv.json" 
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
train_dataset = RealFakeDataset(json_path, processor)
train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=collate_fn)

class Noise(nn.Module):
    def __init__(self, zeros, scale=1):
        super().__init__()
        self.scale = scale
        self.noise = nn.Parameter(zeros)
    def forward(self, x):
        return x + self.scale * self.noise
    
noise = Noise(torch.zeros([1,3] + 2*[processor.size['shortest_edge']])).to("cuda")
optimizer = torch.optim.Adam(noise.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6, weight_decay=0.2)
loss = nn.CrossEntropyLoss()
num_epochs = 1000

In [None]:
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in pbar:
        optimizer.zero_grad()
        images, labels = batch 
        images['pixel_values'] = images['pixel_values'].to("cuda")
        labels = labels.to("cuda")
        images['pixel_values'] += noise(images['pixel_values'])

        # Forward pass
        output = classifier(images)

        # Compute loss
        cls_loss = loss(output, labels)

        # Backward pass
        cls_loss.backward()
        optimizer.step()
        # update noise with grad
        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {cls_loss.item():.4f}")

In [None]:
# display the noise
import matplotlib.pyplot as plt
plt.imshow(noise.noise[0].detach().cpu().numpy())

In [51]:
# test on fake image

fake = Image.open("train/adversarial/0_7m_5oHoyfSdSYFDB.png")
fake = processor(fake)
fake['pixel_values'] = noise(torch.tensor(fake['pixel_values'][0]).to("cuda"))

output = classifier(fake)
pfake = torch.softmax(output, 1)[0, 1].item()
pfake

0.98681640625