In [66]:
import os
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as torchvision_transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.models import resnet18
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

In [67]:
class MultiModalModel(nn.Module):
    """Binary classifier that fuses a ResNet-18 image branch with a BERT text branch.

    The output of the ResNet-18 network and the first-token hidden state of
    BERT are concatenated and passed through a two-layer head followed by a
    sigmoid, so ``forward`` returns one value in (0, 1) per example.
    """

    def __init__(self):
        super().__init__()

        # Pre-trained backbones.
        self.resnet = resnet18(pretrained=True)
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Width of the concatenated feature vector fed to the fusion head.
        image_dim = self.resnet.fc.out_features
        text_dim = self.bert.config.hidden_size

        # Fusion head: (image_dim + text_dim) -> 512 -> 1.
        hidden_layer = nn.Linear(in_features=image_dim + text_dim, out_features=512)
        output_layer = nn.Linear(512, 1)
        self.multi_modal_layers = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

        self.sigmoid = nn.Sigmoid()

    def forward(self, image_inputs, text_inputs):
        """Run both branches and return the sigmoid-activated fused prediction.

        Args:
            image_inputs: batch passed directly to the ResNet.
            text_inputs: mapping that is unpacked as keyword arguments into BERT.

        Returns:
            Tensor of sigmoid outputs with one column per example.
        """
        # Image branch: network output flattened to (batch, features).
        visual = torch.flatten(self.resnet(image_inputs), 1)

        # Text branch: hidden state at sequence position 0 (the [CLS] token).
        bert_output = self.bert(**text_inputs)
        textual = bert_output.last_hidden_state[:, 0, :]

        # Fuse along the feature axis and classify.
        fused = torch.cat((visual, textual), dim=1)
        return self.sigmoid(self.multi_modal_layers(fused))

In [68]:
class MultiModalDataset(Dataset):
    """Dataset yielding ``(title, text, image, label)`` tuples from a dataframe.

    The dataframe must contain the columns 'Text', 'Title', 'Image' and
    'Label'. Rows whose label equals 1 load their image from ``ai_img_dir``;
    every other row loads it from ``real_img_dir``.
    """

    def __init__(self, dataframe, ai_img_dir, real_img_dir, transform=None):
        self.dataframe = dataframe
        self.transform = transform

        # Cache the positional column indices so __getitem__ can use iloc.
        columns = dataframe.columns
        self.text_idx = columns.get_loc('Text')
        self.title_idx = columns.get_loc('Title')
        self.image_idx = columns.get_loc('Image')
        self.label_idx = columns.get_loc('Label')

        self.ai_img_dir = ai_img_dir
        self.real_img_dir = real_img_dir

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the raw title and text, the (optionally transformed) RGB image, and the label."""
        position = idx.tolist() if torch.is_tensor(idx) else idx
        frame = self.dataframe

        text = frame.iloc[position, self.text_idx]
        title = frame.iloc[position, self.title_idx]
        label = frame.iloc[position, self.label_idx]

        # The label decides which directory holds this row's image.
        if label == 1:
            folder = self.ai_img_dir
        else:
            folder = self.real_img_dir
        file_name = str(frame.iloc[position, self.image_idx])
        image = Image.open(os.path.join(folder, file_name)).convert('RGB')

        # Text is returned untokenized; tokenization is left to the consumer.
        if self.transform:
            image = self.transform(image)

        return title, text, image, label

In [69]:
class ImageDataset(Dataset):
    """
    Image-only dataset used to load the images and compute the normalization
    mean and std for the image transforms of the multi-modal model.

    input: dataframe with Image, Label columns
    """
    def __init__(self, dataframe, ai_img_dir, real_img_dir, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.ai_img_dir = ai_img_dir
        self.real_img_dir = real_img_dir

        # Positional column indices for iloc-based access in __getitem__.
        columns = dataframe.columns
        self.image_idx = columns.get_loc('Image')
        self.label_idx = columns.get_loc('Label')

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image of row ``idx``."""
        frame = self.dataframe
        label = frame.iloc[idx, self.label_idx]

        # Label 1 images live in the AI directory, all others in the real one.
        if label == 1:
            folder = self.ai_img_dir
        else:
            folder = self.real_img_dir
        file_name = str(frame.iloc[idx, self.image_idx])
        image = Image.open(os.path.join(folder, file_name)).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [70]:
def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation over all pixels.

    Running sums of the pixel values and of their squares are accumulated
    across batches, so the result is exact for the whole dataset even when
    batches differ in size (e.g. a smaller final batch).

    Args:
        loader: iterable yielding image batches as tensors of shape
            (batch, channels, height, width).

    Returns:
        (mean, std): two 1-D float32 tensors with one entry per channel. The
        std is the population standard deviation over all pixels.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    num_pixels = 0
    channel_sum = None
    channel_sq_sum = None

    for images in loader:
        batch_size, num_channels, height, width = images.shape
        # Accumulate in float64 so that summing many pixels does not lose precision.
        images = images.to(torch.float64)
        if channel_sum is None:
            channel_sum = torch.zeros(num_channels, dtype=torch.float64, device=images.device)
            channel_sq_sum = torch.zeros(num_channels, dtype=torch.float64, device=images.device)

        num_pixels += batch_size * height * width
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))

    if num_pixels == 0:
        raise ValueError("Cannot compute mean/std: the loader yielded no pixels.")

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding errors.
    variance = (channel_sq_sum / num_pixels - mean ** 2).clamp(min=0.0)
    std = variance.sqrt()

    return mean.float(), std.float()

def get_normalization_values(dataframe, ai_img_dir, real_img_dir, batch_size=32):
    """Compute the normalization statistics of the images listed in ``dataframe``.

    The images are resized and cropped exactly as in the training transform
    built by ``get_data`` (Resize(256) -> CenterCrop(224) -> ToTensor), so the
    statistics describe the pixels that are actually normalized later.

    Args:
        dataframe: frame with 'Image' and 'Label' columns.
        ai_img_dir: directory holding the images of rows labelled 1.
        real_img_dir: directory holding the images of all other rows.
        batch_size: number of images loaded per batch while accumulating.

    Returns:
        Whatever ``get_mean_std`` returns for the image loader: ``(mean, std)``.
    """
    preprocessing = torchvision_transforms.Compose([
        torchvision_transforms.Resize(256),
        torchvision_transforms.CenterCrop(224),
        torchvision_transforms.ToTensor(),
    ])

    dataset = ImageDataset(dataframe, ai_img_dir, real_img_dir, preprocessing)

    # Order is irrelevant for dataset-wide statistics; no shuffling keeps the
    # floating-point accumulation order (and thus the result) reproducible.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    mean, std = get_mean_std(loader)
    return mean, std

In [72]:
# load data
def get_data(ai_dataset_path, real_dataset_path, ai_img_dir, real_img_dir):
    """Load both CSV datasets and return an 80/20 train/test split.

    Args:
        ai_dataset_path: CSV file describing the AI-generated articles.
        real_dataset_path: CSV file describing the real articles.
        ai_img_dir: directory holding the images of rows labelled 1.
        real_img_dir: directory holding the images of all other rows.

    Returns:
        (train_dataset, test_dataset): random subsets of a ``MultiModalDataset``
        built over the concatenation of both CSV files.
    """
    # Get the datasets.
    ai_data = pd.read_csv(ai_dataset_path)
    real_data = pd.read_csv(real_dataset_path)

    # ignore_index avoids duplicate index labels coming from the two CSVs.
    combined_data = pd.concat([ai_data, real_data], ignore_index=True)

    # Shuffle beforehand so AI and real rows are interleaved.
    combined_data_shuffled = combined_data.sample(frac=1).reset_index(drop=True)

    # NOTE(review): the statistics are computed over ALL rows, including the
    # ones that end up in the test split - consider using the train rows only.
    mean, std = get_normalization_values(
        combined_data_shuffled[['Image', 'Label']], ai_img_dir, real_img_dir
    )

    print(mean, std)

    transform = torchvision_transforms.Compose([
        torchvision_transforms.Resize(256),
        torchvision_transforms.CenterCrop(224),
        torchvision_transforms.ToTensor(),
        torchvision_transforms.Normalize(mean=mean, std=std),
    ])

    # Create the dataset over the shuffled frame (it was previously unused).
    dataset = MultiModalDataset(
        dataframe=combined_data_shuffled,
        ai_img_dir=ai_img_dir,
        real_img_dir=real_img_dir,
        transform=transform
    )

    # Split the dataset 80/20.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    return train_dataset, test_dataset

In [2]:
def train(model, loader, criterion, optimizer, config):
    """Train ``model`` for ``config.epochs`` epochs over ``loader``.

    Each batch of raw text is tokenized with the 'bert-base-uncased' tokenizer
    and handed to ``train_batch``; the loss is reported through ``train_log``
    every 25th batch.

    Args:
        model: the multi-modal model to optimize.
        loader: iterable yielding ``(title, text, image, label)`` batches.
        criterion: loss function forwarded to ``train_batch``.
        optimizer: optimizer forwarded to ``train_batch``.
        config: object exposing an ``epochs`` attribute.
    """
    # Load the tokenizer once instead of on every batch.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Make sure dropout / batch-norm layers are in training mode.
    model.train()

    example_ct = 0  # number of examples seen
    batch_ct = 0    # number of batches seen
    for epoch in tqdm(range(config.epochs)):
        for title, text, image, label in loader:
            # The batch of strings is turned into a list before tokenizing.
            encoded_text = tokenizer(
                list(text),
                return_tensors="pt",
                padding='max_length',
                truncation=True,
                max_length=512,
            )

            loss = train_batch(encoded_text, image, label, model, optimizer, criterion)
            example_ct += len(image)
            batch_ct += 1

            # Report metrics every 25th batch
            if batch_ct % 25 == 0:
                train_log(loss, example_ct, epoch)


def train_batch(text, image, label, model, optimizer, criterion):
    """Run one optimization step on a single batch and return its loss.

    Args:
        text: tokenized text batch; must support ``.to(device)``.
        image: image batch tensor.
        label: label tensor with one entry per example (any numeric dtype).
        model: callable as ``model(image, text)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        The batch loss as a detached 0-dim tensor.
    """
    # Use the device the model lives on rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    text, image, label = text.to(target_device), image.to(target_device), label.to(target_device)

    # Forward pass - the model's forward signature is (image_inputs, text_inputs).
    output = model(image, text)

    # The loss needs a target with the dtype and shape of the model output
    # (integer labels of shape (batch,) would be rejected by BCELoss).
    target = label.to(dtype=output.dtype).reshape(output.shape)
    loss = criterion(output, target)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()

    # Detach so the caller does not keep the autograd graph alive.
    return loss.detach()

In [3]:
def train_log(loss, example_ct, epoch):
    """Print the loss together with the zero-padded number of examples seen.

    ``epoch`` is accepted but not used in the printed message.
    """
    examples_seen = str(example_ct).zfill(5)
    print(f"Loss after {examples_seen} examples: {loss:.3f}")

In [None]:
def test(model, test_loader, optimizer=None, save_path='MMM.pth'):
    """Evaluate ``model`` on ``test_loader``, print the accuracy and save a checkpoint.

    Args:
        model: model called as ``model(image, text)`` and returning one
            sigmoid probability per example.
        test_loader: iterable yielding ``(title, text, image, label)`` batches.
        optimizer: optional optimizer; when given, its state is stored in the
            checkpoint next to the model weights.
        save_path: file the checkpoint is written to with ``torch.save``.
    """
    model.eval()

    # Load the tokenizer once instead of on every batch.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Use the device the model lives on rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Run the model on the test examples
    correct, total = 0, 0
    with torch.no_grad():
        for title, text, image, label in test_loader:
            encoded_text = tokenizer(
                list(text),
                return_tensors="pt",
                padding='max_length',
                truncation=True,
                max_length=512,
            )

            encoded_text = encoded_text.to(target_device)
            image = image.to(target_device)
            label = label.to(target_device)

            # The model's forward signature is (image_inputs, text_inputs).
            outputs = model(image, encoded_text)

            # The model emits a single probability per example, so the class
            # is obtained by thresholding (an argmax over one column is always 0).
            predicted = (outputs.reshape(-1) >= 0.5).long()
            total += label.size(0)
            correct += (predicted == label.reshape(-1).long()).sum().item()

    if total > 0:
        print(f"Accuracy of the model on the {total} " +
              f"test images: {correct / total:%}")
    else:
        print("No test examples were provided; accuracy is undefined.")

    # Save a PyTorch checkpoint (state dicts, not an ONNX export).
    checkpoint = {'model_state_dict': model.state_dict()}
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, save_path)

In [4]:
def model_pipeline(config):
    """Build everything from ``config``, train the model, evaluate it and return it."""
    # Model, data loaders, loss and optimizer all come from `make`.
    components = make(config)
    model, train_loader, test_loader, criterion, optimizer = components
    print(model)

    # Fit on the training split ...
    train(model, train_loader, criterion, optimizer, config)

    # ... then measure the final performance on the held-out split.
    test(model, test_loader)

    return model

In [None]:
def make(config):
    """Create the model, both data loaders, the loss and the optimizer from ``config``.

    Reads ``ai_dataset_path``, ``real_dataset_path``, ``ai_img_dir``,
    ``real_img_dir``, ``batch_size`` and ``learning_rate`` from ``config``.

    Returns:
        (model, train_loader, test_loader, criterion, optimizer)
    """
    # Data: train / test splits wrapped in shuffling loaders.
    train_split, test_split = get_data(
        ai_dataset_path=config.ai_dataset_path,
        real_dataset_path=config.real_dataset_path,
        ai_img_dir=config.ai_img_dir,
        real_img_dir=config.real_img_dir,
    )
    loader_options = dict(batch_size=config.batch_size, shuffle=True)
    train_loader = DataLoader(train_split, **loader_options)
    test_loader = DataLoader(test_split, **loader_options)

    # Model, moved to the notebook-level device.
    model = MultiModalModel().to(device)

    # Optimization problem: binary cross-entropy minimized with Adam.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    return model, train_loader, test_loader, criterion, optimizer

In [None]:
# Run configuration. A namespace (not a dict) is used because every consumer
# reads the settings as attributes, e.g. `config.epochs`.
config = SimpleNamespace(
    epochs=5,
    learning_rate=0.001,
    batch_size=32,
    ai_dataset_path="../../data/ai_scraping/newsgpt_dataset.csv",
    real_dataset_path="../../data/real_scraping/cnn_dataset.csv",
    ai_img_dir="../../data/ai_scraping/newsgpt_images",
    real_img_dir="../../data/real_scraping/cnn_images/",
)

# Device used by `make` when placing the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assigning the result keeps the (already printed) model repr from being
# echoed a second time as the cell output.
trained_model = model_pipeline(config)