In [1]:
import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import clip
from transformers import CLIPImageProcessor, CLIPModel
import pandas as pd
from tqdm import tqdm

from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fake = [{'image' : str(s), 'label' : 1} for s in Path("train/fake").glob("*")]
real = [{'image' : str(s), 'label' : 0} for s in Path("train/real").glob("*")]
data = fake + real
pd.DataFrame(data).to_json("train.json", orient="records", lines=True)

In [128]:
class RealFakeDataset(torch.utils.data.Dataset):
    def __init__(self, json_path, processor):
        # Initialize image paths and corresponding texts
        self.jsons = pd.read_json(json_path, lines=True)
        self.processor = processor

    def __len__(self):
        return len(self.jsons)

    def __getitem__(self, idx):
        # Preprocess image using CLIP's preprocessing function
        image = self.processor(Image.open(self.jsons.iloc[idx]['image']))
        label = torch.tensor(self.jsons.iloc[idx]['label'])
        return image, label

class Classifier(nn.Module):
    def __init__(self, ckpt="openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(ckpt).to(torch.float16)
        self.linear = nn.Linear(512, 2, dtype=torch.float16)
    def forward(self, x):
        image_features = self.model.get_image_features(**x)
        x = self.linear(image_features)
        return x
    def to(self, device):
        self.model.to(device)
        self.linear.to(device)
        return self

In [129]:
model = Classifier().to("cuda")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [130]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
loss = nn.CrossEntropyLoss()
json_path = "train.json" 
train_dataset = RealFakeDataset(json_path, processor)
def collate_fn(batch):
    images = [torch.tensor(x[0]['pixel_values'][0]) for x in batch]
    labels = [x[1] for x in batch]
    images = {'pixel_values': torch.stack(images)}
    labels = torch.stack(labels)
    return images, labels

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [131]:
num_epochs = 30
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in pbar:
        optimizer.zero_grad()
        images, labels = batch 
        images['pixel_values'] = images['pixel_values'].to("cuda")
        labels = labels.to("cuda")
        # Forward pass
        linear_output = model(images)

        # Compute loss
        cls_loss = loss(linear_output, labels)

        # Backward pass
        cls_loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {cls_loss.item():.4f}")

Epoch 0/30, Loss: 0.3193: 100%|██████████| 22/22 [00:01<00:00, 11.04it/s]
Epoch 1/30, Loss: 0.8389: 100%|██████████| 22/22 [00:02<00:00, 10.24it/s]
Epoch 2/30, Loss: 0.7407: 100%|██████████| 22/22 [00:02<00:00, 10.85it/s]
Epoch 3/30, Loss: 0.5903: 100%|██████████| 22/22 [00:02<00:00, 10.67it/s]
Epoch 4/30, Loss: 0.3643: 100%|██████████| 22/22 [00:02<00:00, 10.76it/s]
Epoch 5/30, Loss: 0.6782: 100%|██████████| 22/22 [00:02<00:00, 10.82it/s]
Epoch 6/30, Loss: 0.9692: 100%|██████████| 22/22 [00:02<00:00, 10.62it/s]
Epoch 7/30, Loss: 0.6157: 100%|██████████| 22/22 [00:02<00:00, 10.73it/s]
Epoch 8/30, Loss: 0.3064:  82%|████████▏ | 18/22 [00:01<00:00, 11.35it/s]