In [None]:
# --- Dataset and checkpoint locations -------------------------------------
# Image folders for the two classes.
# NOTE(review): train_fake_img_dir currently points at the same folder as
# train_real_img_dir even though it is tagged "generated" -- confirm the path.
train_real_img_dir = '/mnt/d/data/learning/ML/defake/mscoco/train2014'
train_fake_img_dir = '/mnt/d/data/learning/ML/defake/mscoco/train2014'   # generated

# Prompt files: one prompt per line (they are read line by line by the dataset classes).
# NOTE(review): the two "fake_prompt" files are not referenced by the cells
# visible in this notebook -- confirm whether they are still needed.
real_prompt_file = '/mnt/d/data/learning/ML/defake/mscoco/real_prompts.txt' 
real_img_fake_prompt_file = '/mnt/d/data/learning/ML/defake/mscoco/real_img_fake_prompts.txt' # blip generated
fake_img_fake_prompt_file = '/mnt/d/data/learning/ML/defake/mscoco/fake_img_fake_prompts.txt' # blip generated

# Directory that receives periodic checkpoints and the best model during training.
model_save_dir = '/home/b11115030/cybersecurity/De-Fake/checkpoints/hybrid_generated'

In [None]:
# %%
import torch
import torch.nn as nn
from torchvision.models import resnet18
import clip

class ImageOnlyDetector(nn.Module):
	"""ResNet-18 classifier that decides real vs. fake from the image alone.

	Args:
		pretrained: when True, start from torchvision's pretrained ResNet-18
			weights and attach a freshly initialised classification head.
		num_classes: number of output logits (2 by default).
	"""

	def __init__(self, pretrained=False, num_classes=2):
		super().__init__()
		if pretrained:
			# The pretrained checkpoint ships its own classification head, so
			# load the stock network first and then swap in a new head of the
			# requested size.
			self.model = resnet18(pretrained=True)
			self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
		else:
			# Same construction as before, but honouring the caller's num_classes.
			self.model = resnet18(pretrained=False, num_classes=num_classes)

	# input shape: (batch_size, num_channels, height, width)
	def forward(self, img):
		"""Return the raw class logits, shape (batch_size, num_classes)."""
		return self.model(img)


class HybridDetector(nn.Module):
	"""Real/fake classifier over concatenated CLIP image and text embeddings.

	The CLIP encoder is only used for feature extraction (its forward pass runs
	under ``torch.no_grad()``); the small MLP head is the trainable part.

	Args:
		device: device passed to ``clip.load``; defaults to CUDA when available.
	"""

	def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
		super().__init__()
		self.clip_encoder, self.preprocess = clip.load("ViT-B/32", device=device)
		# 1024 inputs = image embedding + text embedding concatenated along dim 1
		# (assumes 512 features each for ViT-B/32 -- TODO confirm).
		self.mlp = nn.Sequential(
			nn.Linear(1024, 512),
			nn.ReLU(),
			nn.Linear(512, 256),
			nn.ReLU(),
			nn.Linear(256, 2)
		)

	def forward(self, img, text):
		"""Classify a batch of (image, prompt) pairs.

		Args:
			img: image batch tensor; ``self.preprocess`` is NOT applied here, so
				the caller is expected to have prepared the images already.
			text: a string or list of strings, one prompt per image.

		Returns:
			Logits of shape (batch_size, 2).
		"""
		with torch.no_grad():
			# truncate=True: without it clip.tokenize raises on prompts longer
			# than CLIP's context length, which would abort a whole batch.
			text = clip.tokenize(text, truncate=True).to(img.device)
			img_emb = self.clip_encoder.encode_image(img)
			text_emb = self.clip_encoder.encode_text(text)
		# Cast to float32 so the embeddings match the MLP's parameter dtype.
		emb = torch.cat((img_emb, text_emb), dim=1).float()
		return self.mlp(emb)


In [None]:
import os
from natsort import natsorted
from torch.utils.data import Dataset
from PIL import Image

class RealDataset(Dataset):
    """Images from ``image_dir`` paired with prompts, all labelled 0 (real).

    Files are taken in natural-sort order and the i-th file is paired with the
    i-th line of ``prompts_file``. If there are fewer prompts than images the
    remaining images get an empty prompt.

    Args:
        image_dir: directory whose entries are all treated as image files.
        prompts_file: text file with one prompt per line.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_dir, prompts_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_filenames = natsorted(os.listdir(image_dir))
        self.prompts = self.load_prompts(prompts_file)

    def load_prompts(self, prompts_file):
        """Return the lines of ``prompts_file`` with surrounding whitespace removed.

        Blank lines are kept (as empty strings) so line numbers stay aligned
        with image indices.
        """
        # Explicit encoding so the result does not depend on the platform locale.
        with open(prompts_file, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, prompt, label)`` for the idx-th file; label is 0."""
        filename = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, filename)
        # Context manager closes the underlying file handle deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = 0
        prompt = self.prompts[idx] if idx < len(self.prompts) else ""

        return image, prompt, label

class FakeDataset(Dataset):
    """Images from ``image_dir`` paired with prompts, all labelled 1 (fake).

    Files are taken in natural-sort order and the i-th file is paired with the
    i-th line of ``prompts_file``. If there are fewer prompts than images the
    remaining images get an empty prompt.

    Args:
        image_dir: directory whose entries are all treated as image files.
        prompts_file: text file with one prompt per line.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_dir, prompts_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_filenames = natsorted(os.listdir(image_dir))
        self.prompts = self.load_prompts(prompts_file)

    def load_prompts(self, prompts_file):
        """Return the lines of ``prompts_file`` with surrounding whitespace removed.

        Blank lines are kept (as empty strings) so line numbers stay aligned
        with image indices.
        """
        # Explicit encoding so the result does not depend on the platform locale.
        with open(prompts_file, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, prompt, label)`` for the idx-th file; label is 1."""
        filename = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, filename)
        # Context manager closes the underlying file handle deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = 1
        prompt = self.prompts[idx] if idx < len(self.prompts) else ""

        return image, prompt, label

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed = 42            # seed for torch's global RNG
train_ratio = 0.8    # fraction of the combined dataset used for training
val_ratio = 0.1      # fraction used for validation; the remainder is the test set
batch_size = 32
n_epochs = 50
learning_rate = 3e-4

save_interval = 10 # epoch

In [None]:
import torch
# Seed torch's global RNG before any dataset split or weight initialisation
# so those steps are repeatable across "Restart & Run All".
torch.manual_seed(seed)

In [None]:
from torchvision import transforms

# Resize every image to 224x224, convert to a tensor and map each channel
# from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

real_imgs = RealDataset(train_real_img_dir, real_prompt_file, transform)
# The fake set must read from the fake-image folder; it previously reused
# train_real_img_dir, which made the dedicated train_fake_img_dir setting dead.
# NOTE(review): the prompt file is left as real_prompt_file -- confirm whether
# fake_img_fake_prompt_file was intended here.
fake_imgs = FakeDataset(train_fake_img_dir, real_prompt_file, transform)

In [None]:
import torch
from torch.utils.data import DataLoader, random_split, ConcatDataset
from tqdm import tqdm
import os
import numpy as np


# Merge both classes (real = 0, fake = 1) and split into train / val / test.
dataset = ConcatDataset([real_imgs, fake_imgs]) # combine 
size = len(dataset)
train_size = int(size * train_ratio)
val_size = int(size * val_ratio)
# The test set takes the remainder so the three parts always sum to `size`.
test_size = size - train_size - val_size

train_set, val_set, test_set = random_split(dataset, lengths=[train_size, val_size, test_size])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(model_save_dir, exist_ok=True)

model = ImageOnlyDetector().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

best_val_loss = float('inf')

for epoch in range(1, n_epochs + 1):
	print(f'Epoch: {epoch}/{n_epochs}: ')

	# ---- training pass ----
	model.train()
	train_bar = tqdm(train_loader, desc='Training')
	for images, prompts, labels in train_bar:
		# Prompts are part of every sample but unused by the image-only model.
		images, labels = images.to(device), labels.to(device)

		optimizer.zero_grad()
		outputs = model(images)
		loss = criterion(outputs, labels)
		loss.backward()
		optimizer.step()

		train_bar.set_postfix(loss=loss.item())

	# Periodic checkpoint: written once per qualifying epoch, after the last
	# batch (it used to be rewritten after every single batch of that epoch).
	if epoch % save_interval == 0:
		torch.save(model.state_dict(), os.path.join(model_save_dir, f'checkpoint_{epoch}.pt'))

	# ---- validation pass ----
	model.eval()
	val_loss_sum = 0.0
	val_samples = 0
	with torch.no_grad():
		val_bar = tqdm(val_loader, desc='Validation')
		for images, prompts, labels in val_bar:
			images, labels = images.to(device), labels.to(device)
			outputs = model(images)
			loss = criterion(outputs, labels)
			val_bar.set_postfix(loss=loss.item())

			# Weight by batch size so a short final batch does not skew the mean.
			n_batch = labels.size(0)
			val_loss_sum += loss.item() * n_batch
			val_samples += n_batch

	# Pick the best model by mean validation loss over the whole epoch rather
	# than by whichever single batch happened to have the lowest loss.
	if val_samples > 0:
		val_loss = val_loss_sum / val_samples
		print(f'Validation loss: {val_loss:.4f}')
		if val_loss < best_val_loss:
			best_val_loss = val_loss
			torch.save(model.state_dict(), os.path.join(model_save_dir, 'best_model.pt'))
	

In [None]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix

# Evaluate the model currently in memory on the held-out test split.
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
	test_bar = tqdm(test_loader, desc='Testing')
	for images, prompts, labels in test_bar:
		images = images.to(device)
		outputs = model(images)
		# Predicted class = index of the largest logit in each row.
		preds = outputs.argmax(dim=1)

		all_preds.extend(preds.cpu().tolist())
		# Labels are only compared on the host, so they never need to visit the device.
		all_targets.extend(labels.tolist())

all_preds = np.array(all_preds)
all_targets = np.array(all_targets)

# average='weighted' averages the per-class scores weighted by class support.
accuracy = accuracy_score(all_targets, all_preds)
recall = recall_score(all_targets, all_preds, average='weighted')
precision = precision_score(all_targets, all_preds, average='weighted')
f1 = f1_score(all_targets, all_preds, average='weighted')

print(f'Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1 Score: {f1:.4f}')