In [113]:
import pickle as pkl
from collections import Counter
from typing import Literal, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtune
import torchvision
from matplotlib import pyplot as plt
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score
from torchsummary import summary
from torchvision.transforms import v2
from tqdm import tqdm

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


# Data Loader

In [56]:
# --- Image preprocessing ---
IMG_DIM = 128      # images are resized to IMG_DIM x IMG_DIM
NUM_CHANNELS = 3
BATCH_SIZE = 512
# Standard ImageNet per-channel mean/std, consumed by v2.Normalize.
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)

# --- Dataset layout: every class folder is expected to hold exactly this many images ---
NUM_CLASSES = 90
NUM_REAL_IMG_PER_CLASS = 60
NUM_AI_IMG_PER_CLASS = 30

# Per-class split of the real images: training takes the leading fraction of each
# class, testing the trailing fraction.
REAL_IMG_TRAIN_PERCENTAGE = 0.5
REAL_IMG_TEST_PERCENTAGE = 0.5
# If the two fractions sum past 1 the slices overlap and test images leak into training.
assert REAL_IMG_TRAIN_PERCENTAGE + REAL_IMG_TEST_PERCENTAGE <= 1.0, "real train/test splits overlap"

# --- Reproducibility: seed the RNGs behind weight init, dropout and DataLoader shuffling ---
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [57]:
def get_subset_indices(
    num_img_per_class: int,
    percent: float = 1.0,
    side: Literal["left", "right"] = "left",
    num_classes: Optional[int] = None,
) -> np.ndarray:
    """
    Build flat dataset indices that select a contiguous slice of every class.

    Assumes the dataset is stored class by class with exactly ``num_img_per_class``
    samples per class, so class ``i`` occupies indices
    ``[i * num_img_per_class, (i + 1) * num_img_per_class)``.

    Args:
        num_img_per_class: number of samples stored for each class.
        percent: fraction in [0, 1] of each class to keep. The per-class count is
            floor(num_img_per_class * percent), after rounding away float noise.
        side: "left" keeps the first samples of each class, "right" keeps the last.
        num_classes: number of classes; defaults to the global NUM_CLASSES.

    Returns:
        int32 array of length ``num_classes * class_size``, ordered class-major and
        ascending within each class.

    Raises:
        ValueError: if ``percent`` is outside [0, 1] or ``side`` is not "left"/"right".
    """
    if not 0.0 <= percent <= 1.0:
        # percent > 1 would spill into the next class's indices.
        raise ValueError(f"percent must be in [0, 1], got {percent}")
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    if num_classes is None:
        num_classes = NUM_CLASSES

    # Round before flooring: e.g. 100 * 0.29 == 28.999999999999996 would otherwise lose an image.
    class_size = int(round(num_img_per_class * percent, 6))
    start = 0 if side == "left" else num_img_per_class - class_size

    # (num_classes, 1) + (1, class_size) broadcast -> one row of indices per class.
    class_starts = np.arange(num_classes) * num_img_per_class + start
    offsets = np.arange(class_size)
    return (class_starts[:, None] + offsets[None, :]).ravel().astype(np.int32)

In [58]:
def _check_class_layout(dataset, num_img_per_class: int, root: str) -> None:
    """
    Verify an ImageFolder matches the layout that get_subset_indices relies on.

    get_subset_indices addresses samples as ``class * num_img_per_class + offset``.
    That is only valid when there are exactly NUM_CLASSES classes, each holding
    exactly ``num_img_per_class`` images.

    Raises:
        ValueError: if the class count or any per-class image count differs.
    """
    if len(dataset.classes) != NUM_CLASSES:
        raise ValueError(
            f"{root}: expected {NUM_CLASSES} class folders, found {len(dataset.classes)}"
        )
    counts = np.bincount(np.asarray(dataset.targets, dtype=np.int64), minlength=NUM_CLASSES)
    mismatched = {
        dataset.classes[i]: int(count)
        for i, count in enumerate(counts)
        if count != num_img_per_class
    }
    if mismatched:
        raise ValueError(
            f"{root}: expected {num_img_per_class} images per class; mismatched classes: {mismatched}"
        )


def get_loader(
    real_img_dir: str = "./real_animals",
    ai_img_dir: str = "./ai_animals",
    real_img_percent: float = 1.0,
    ai_img_percent: float = 0.0,
    batch_size: int = BATCH_SIZE,
    num_workers: int = 0,
    shuffle: bool = True,
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Get train/test dataloaders for real and AI-generated images.

    The real images are split per class:
    - Training gets the first ``REAL_IMG_TRAIN_PERCENTAGE * real_img_percent`` of each class.
    - Testing gets the last ``REAL_IMG_TEST_PERCENTAGE`` of each class.

    When ``ai_img_percent > 0``, the first ``ai_img_percent`` of each AI-generated
    class is appended to the training set. The test set contains real images only.

    Args:
        real_img_dir: ImageFolder root of the real images.
        ai_img_dir: ImageFolder root of the AI images. Only read when ai_img_percent > 0.
        real_img_percent: fraction of the real training slice to use.
        ai_img_percent: fraction of each AI class to add to training.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes.
        shuffle: whether to shuffle the training loader. The test loader is never
            shuffled, so evaluation order is deterministic.

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if either of these holds:
            - a folder does not hold NUM_CLASSES classes with the expected image
              count per class;
            - the real and AI folders map class names to different label indices.
    """
    transform = v2.Compose([
        v2.Resize((IMG_DIM, IMG_DIM)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ])

    # A single directory scan serves both real splits; Subset only reads from it.
    real_dataset = torchvision.datasets.ImageFolder(
        root=real_img_dir,
        transform=transform,
        allow_empty=True,
    )
    _check_class_layout(real_dataset, NUM_REAL_IMG_PER_CLASS, real_img_dir)

    train_real_img_subset = torch.utils.data.Subset(
        real_dataset,
        indices=get_subset_indices(NUM_REAL_IMG_PER_CLASS, percent=REAL_IMG_TRAIN_PERCENTAGE * real_img_percent, side="left"),
    )
    test_real_img_subset = torch.utils.data.Subset(
        real_dataset,
        indices=get_subset_indices(NUM_REAL_IMG_PER_CLASS, percent=REAL_IMG_TEST_PERCENTAGE, side="right"),
    )

    if ai_img_percent > 0.0:
        ai_dataset = torchvision.datasets.ImageFolder(
            root=ai_img_dir,
            transform=transform,
            allow_empty=True,
        )
        _check_class_layout(ai_dataset, NUM_AI_IMG_PER_CLASS, ai_img_dir)
        # ImageFolder derives label indices from each root's own sorted folder names.
        # If the two roots disagree, AI images would be trained under the wrong labels.
        if ai_dataset.class_to_idx != real_dataset.class_to_idx:
            raise ValueError(
                f"class folders in {ai_img_dir} do not match those in {real_img_dir}"
            )
        ai_img_subset = torch.utils.data.Subset(
            ai_dataset,
            indices=get_subset_indices(NUM_AI_IMG_PER_CLASS, percent=ai_img_percent, side="left"),
        )
        train_dataset = torch.utils.data.ConcatDataset([train_real_img_subset, ai_img_subset])
    else:
        train_dataset = train_real_img_subset

    train_img_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    test_img_dataloader = torch.utils.data.DataLoader(
        dataset=test_real_img_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_img_dataloader, test_img_dataloader

In [59]:
# Train on every real training image plus every AI-generated image; test on real images.
train_dl, test_dl = get_loader(
    ai_img_percent=1.0,
    real_img_percent=1.0,
)

In [60]:
num_train_batches = len(train_dl)  # batches per training epoch
num_train_batches

11

In [61]:
# One full pass over the training loader, only to time data loading.
train_batches = tqdm(train_dl, total=len(train_dl))
for img, label in train_batches:
    continue

100%|██████████| 11/11 [00:50<00:00,  4.60s/it]


In [50]:
# Count training samples per class label to sanity-check class balance.
label_counter = Counter()
for _, label in tqdm(train_dl, total=len(train_dl)):
    label_counter.update(label.tolist())
# Sort by class index so the display is easy to scan.
class_counts = dict(sorted(label_counter.items()))
class_counts

100%|██████████| 22/22 [00:50<00:00,  2.30s/it]

{54: 60, 63: 60, 64: 60, 89: 60, 33: 60, 51: 60, 25: 60, 77: 60, 19: 60, 41: 60, 22: 60, 49: 60, 74: 60, 78: 60, 17: 60, 79: 60, 86: 60, 23: 60, 50: 60, 73: 60, 52: 60, 32: 60, 26: 60, 31: 60, 76: 60, 60: 60, 45: 60, 46: 60, 35: 60, 67: 60, 12: 60, 75: 60, 37: 60, 55: 60, 62: 60, 7: 60, 5: 60, 40: 60, 6: 60, 47: 60, 30: 60, 65: 60, 53: 60, 85: 60, 71: 60, 3: 60, 83: 60, 43: 60, 72: 60, 10: 60, 61: 60, 36: 60, 27: 60, 87: 60, 21: 60, 8: 60, 20: 60, 84: 60, 9: 60, 70: 60, 81: 60, 13: 60, 42: 60, 82: 60, 88: 60, 34: 60, 15: 60, 28: 60, 44: 60, 66: 60, 29: 60, 48: 60, 58: 60, 1: 60, 0: 60, 68: 60, 59: 60, 4: 60, 18: 60, 16: 60, 69: 60, 24: 60, 11: 60, 39: 60, 56: 60, 2: 60, 80: 60, 14: 60, 38: 60, 57: 60}





# Models

In [62]:
class CNN(nn.Module):
    """
    Five-stage convolutional classifier for square images.

    Architecture:
    - Five stages, each Conv2d(3x3, padding=1) -> LeakyReLU -> Dropout(0.1) ->
      MaxPool2d(2, 2). Channels go input -> 8 -> 16 -> 24 -> 32 -> 40.
    - The final feature map is flattened per sample and fed to
      Linear(flatten, 128) -> LeakyReLU -> Dropout(0.5) -> Linear(128, n_classes).

    Args:
        input_channels: number of channels in the input images.
        n_classes: number of output logits.
        img_dim: height/width of the square inputs; determines the flattened feature
            size. The default of 128 gives a 4x4 final feature map.

    Raises:
        ValueError: if img_dim is too small to survive five 2x2 poolings.
    """

    def __init__(self, input_channels, n_classes, img_dim=128):
        super().__init__()

        # set metadata
        self.input_channels = input_channels
        self.n_classes = n_classes
        self.final_layer_channels = 40
        # The 3x3/padding-1 convs keep spatial size; each of the five pools floors it by half.
        final_size = img_dim
        for _ in range(5):
            final_size //= 2
        if final_size < 1:
            raise ValueError(f"img_dim={img_dim} is too small for five 2x2 poolings")
        self.FINAL_LAYER_SIZE = final_size
        self.flatten_layer_size = self.final_layer_channels * self.FINAL_LAYER_SIZE * self.FINAL_LAYER_SIZE

        # dropout layers
        self.dropout50 = nn.Dropout(p=0.5)
        self.dropout10 = nn.Dropout(p=0.1)

        # conv stages (attribute names kept so existing state_dicts still load)
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=8, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=32, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=self.final_layer_channels, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # classifier head
        self.fc1 = nn.Linear(self.flatten_layer_size, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return (batch, n_classes) logits for a (batch, input_channels, img_dim, img_dim) input."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
            (self.conv5, self.pool5),
        )
        for conv, pool in stages:
            x = pool(self.dropout10(nn.functional.leaky_relu(conv(x))))

        # Flatten per sample. Unlike view(-1, N), this keeps the batch dimension, so a
        # wrongly sized input fails in fc1 instead of silently reshaping the batch.
        x = torch.flatten(x, start_dim=1)

        x = self.dropout50(nn.functional.leaky_relu(self.fc1(x)))
        return self.fc2(x)

In [63]:
# Layer-by-layer output shapes and parameter counts for the CNN at the configured input size.
summary(
    CNN(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device),
    (NUM_CHANNELS, IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    device=device.type,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [512, 8, 128, 128]             224
           Dropout-2         [512, 8, 128, 128]               0
         MaxPool2d-3           [512, 8, 64, 64]               0
            Conv2d-4          [512, 16, 64, 64]           1,168
           Dropout-5          [512, 16, 64, 64]               0
         MaxPool2d-6          [512, 16, 32, 32]               0
            Conv2d-7          [512, 24, 32, 32]           3,480
           Dropout-8          [512, 24, 32, 32]               0
         MaxPool2d-9          [512, 24, 16, 16]               0
           Conv2d-10          [512, 32, 16, 16]           6,944
          Dropout-11          [512, 32, 16, 16]               0
        MaxPool2d-12            [512, 32, 8, 8]               0
           Conv2d-13            [512, 40, 8, 8]          11,560
          Dropout-14            [512, 4

In [None]:
class CombinedResNet50(nn.Module):
    """
    Placeholder for a ResNet-50-based classifier; the network is not built yet.

    It currently only records its configuration, and ``forward`` returns its input
    unchanged.

    TODO: construct the actual network from input_channels / n_classes.

    Args:
        input_channels: number of channels in the input images.
        n_classes: number of output classes.
    """

    def __init__(self, input_channels, n_classes):
        super().__init__()

        # set metadata
        self.input_channels = input_channels
        self.n_classes = n_classes

    def forward(self, x):
        # TODO: replace this identity pass-through with the real forward pass.
        return x

In [None]:
# Summary of the (placeholder) ResNet-50 model at the configured input size.
summary(
    CombinedResNet50(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device),
    (NUM_CHANNELS, IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    device=device.type,
)

# Hyperparameter Tuning

In [None]:
def plot_loss_accuracy(results) -> None:
    """
    Plot the curves stored in a result dict returned by ``train_and_test_model``.

    Panels:
    1. Train/test cross-entropy loss per epoch.
    2. Train/test macro accuracy per epoch, in percent.
    3. Final-epoch per-class test precision and recall.

    Args:
        results: dict with these keys:
            - "train_losses", "test_losses": per-epoch lists.
            - "avg_train_accuracies", "avg_test_accuracies": per-epoch lists of
              fractions in [0, 1].
            - "test_precision", "test_recall": per-class arrays from the last epoch.
    """
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 4))
    fig.subplots_adjust(wspace=0.4)
    epochs = np.arange(1, len(results["train_losses"]) + 1)

    ax[0].set_title("Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].plot(epochs, results["train_losses"], label="Train Loss", color="blue")
    ax[0].plot(epochs, results["test_losses"], label="Test Loss", color="red")
    ax[0].set_ylabel("Cross Entropy Loss")
    ax[0].legend()

    # Stored accuracies are fractions; scale to percent for the y-axis label.
    ax[1].set_title("Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].plot(epochs, np.asarray(results["avg_train_accuracies"]) * 100, label="Train Accuracy", color="blue")
    ax[1].plot(epochs, np.asarray(results["avg_test_accuracies"]) * 100, label="Test Accuracy", color="red")
    ax[1].set_ylabel("Accuracy (%)")
    ax[1].legend()

    # Precision/recall are only kept for the final epoch, one value per class.
    precision = np.atleast_1d(np.asarray(results["test_precision"]))
    recall = np.atleast_1d(np.asarray(results["test_recall"]))
    class_idx = np.arange(len(precision))
    ax[2].set_title("Precision and Recall (final epoch)")
    ax[2].set_xlabel("Class index")
    ax[2].plot(class_idx, precision, label="Test Precision", color="cyan", marker=".", linestyle="none")
    ax[2].plot(class_idx, recall, label="Test Recall", color="orange", marker=".", linestyle="none")
    ax[2].set_ylabel("Precision/Recall Score (0-1)")
    ax[2].legend()

In [None]:
def plot_improvement_by_class(results) -> None:
    """
    Placeholder that is not implemented yet; it draws nothing and returns None.

    TODO: implement. Judging only by the name, it is presumably meant to show
    per-class changes between training runs; confirm the intended input format.
    """

In [None]:
def plot_diff_barchart(results) -> None:
    """
    Placeholder that is not implemented yet; it draws nothing and returns None.

    TODO: implement. The name suggests a bar chart of differences between
    results; confirm the intended input format.
    """

In [117]:
def _train_one_epoch(model, optimizer, loader, loss_function, accuracy_metric, device, show_progress):
    """
    Run one optimisation pass over ``loader``.

    Resets ``accuracy_metric``, then updates it with the training-mode predictions.
    Returns the sample-weighted mean loss of the epoch, or NaN for an empty loader.
    """
    model.train()
    accuracy_metric.reset()
    loss_sum, num_samples = 0.0, 0
    for images, labels in tqdm(loader, total=len(loader), disable=not show_progress):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = loss_function(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages within a batch; weight by batch size so a
        # smaller final batch does not skew the epoch mean.
        loss_sum += loss.item() * labels.size(0)
        num_samples += labels.size(0)
        accuracy_metric.update(outputs.detach(), labels)
    return loss_sum / num_samples if num_samples else float("nan")


def _evaluate(model, loader, loss_function, metrics, device, show_progress):
    """
    Evaluate ``model`` on ``loader`` in eval mode without tracking gradients.

    Resets, then updates, every metric in ``metrics``. Returns the sample-weighted
    mean loss, or NaN for an empty loader.
    """
    model.eval()
    for metric in metrics:
        metric.reset()
    loss_sum, num_samples = 0.0, 0
    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.no_grad():
        for images, labels in tqdm(loader, total=len(loader), disable=not show_progress):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += loss_function(outputs, labels).item() * labels.size(0)
            num_samples += labels.size(0)
            for metric in metrics:
                metric.update(outputs, labels)
    return loss_sum / num_samples if num_samples else float("nan")


def train_and_test_model(
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    E: int,
    verbose: Literal["none", "prints", "epoch_tqdm", "loader_tqdm"] = "epoch_tqdm",
    num_classes: Optional[int] = None,
    device: Optional[torch.device] = None,
):
    """
    Train and test the given model with the given parameters.

    Each epoch trains on ``train_loader`` and then evaluates on ``test_loader``.

    Args:
        model: the network to train (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: training batches of (images, labels).
        test_loader: evaluation batches of (images, labels).
        E: number of epochs; must be at least 1.
        verbose: controls progress output.
            - "epoch_tqdm": a progress bar over epochs.
            - "loader_tqdm": a bar per loader pass.
            - "prints": one line per epoch.
            - "none": suppresses all output, including the final summary.
        num_classes: number of classes for the metrics; defaults to NUM_CLASSES.
        device: where batches and metrics are placed; defaults to the device of
            the model's first parameter.

    Returns:
        A dict with these entries:
        - Per-epoch lists: "train_losses", "avg_train_accuracies", "test_losses",
          "avg_test_accuracies".
        - Final-epoch per-class numpy arrays: "test_accuracies", "test_precision",
          "test_recall", "test_f1score".
        - Their unweighted class means: "avg_test_precision", "avg_test_recall",
          "avg_test_f1score".

        Metric values are fractions in [0, 1].

    Raises:
        ValueError: if E < 1.
    """
    if E < 1:
        raise ValueError(f"E must be at least 1, got {E}")
    if num_classes is None:
        num_classes = NUM_CLASSES
    if device is None:
        device = next(model.parameters()).device

    loss_function = nn.CrossEntropyLoss()
    accuracy_metric = MulticlassAccuracy(average='none', num_classes=num_classes).to(device)
    precision_metric = MulticlassPrecision(average='none', num_classes=num_classes).to(device)
    recall_metric = MulticlassRecall(average='none', num_classes=num_classes).to(device)
    f1_metric = MulticlassF1Score(average='none', num_classes=num_classes).to(device)
    # Precision/recall/F1 are only reported for the final epoch, so only then are they accumulated.
    final_epoch_metrics = [accuracy_metric, precision_metric, recall_metric, f1_metric]

    show_loader_progress = verbose == "loader_tqdm"
    train_losses, avg_train_accuracies = [], []
    test_losses, avg_test_accuracies = [], []

    for epoch in tqdm(range(E), total=E, disable=verbose != "epoch_tqdm"):
        # TRAINING
        train_loss = _train_one_epoch(
            model, optimizer, train_loader, loss_function, accuracy_metric, device, show_loader_progress
        )
        train_losses.append(train_loss)
        avg_train_accuracies.append(accuracy_metric.compute().mean().item())

        # TESTING
        eval_metrics = final_epoch_metrics if epoch == E - 1 else [accuracy_metric]
        test_loss = _evaluate(model, test_loader, loss_function, eval_metrics, device, show_loader_progress)
        test_losses.append(test_loss)
        test_acc = accuracy_metric.compute()
        avg_test_accuracies.append(test_acc.mean().item())

        if verbose == "prints":
            print(f"Epoch [{epoch+1}/{E}]: Train Accuracy: {avg_train_accuracies[-1]*100:.2f}%, Train Loss: {train_loss:.4f}, Test Accuracy: {avg_test_accuracies[-1]*100:.2f}%, Test Loss: {test_loss:.4f}")

    # The loop always ends on the final epoch, so every metric now holds final-epoch state.
    test_accuracies = test_acc.cpu().numpy()
    test_precision = precision_metric.compute().cpu().numpy()
    test_recall = recall_metric.compute().cpu().numpy()
    test_f1score = f1_metric.compute().cpu().numpy()

    if verbose != "none":
        print(f"\nEvaluation results:\nTrain Accuracy: {avg_train_accuracies[-1]*100:.2f}%, Train Loss: {train_loss:.4f}\nTest Accuracy: {avg_test_accuracies[-1]*100:.2f}%, Test Loss: {test_loss:.4f}")

    return {
        "train_losses": train_losses,
        "avg_train_accuracies": avg_train_accuracies,
        "test_losses": test_losses,
        "test_accuracies": test_accuracies,
        "avg_test_accuracies": avg_test_accuracies,
        "test_precision": test_precision,
        "test_recall": test_recall,
        "test_f1score": test_f1score,
        "avg_test_precision": float(test_precision.mean()),
        "avg_test_recall": float(test_recall.mean()),
        "avg_test_f1score": float(test_f1score.mean()),
    }

In [118]:
# Baseline data: real images only; num_workers=0 loads batches in the main process.
train_loader, test_loader = get_loader(
    real_img_percent=1.0,
    ai_img_percent=0.0,
    num_workers=0,
)

In [119]:
model = CNN(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Baseline run: one epoch of training, then evaluation on the real test split.
results_1 = train_and_test_model(
    model, optimizer, train_loader, test_loader,
    E=1,
    verbose="epoch_tqdm",
)

100%|██████████| 1/1 [00:54<00:00, 54.32s/it]


Evaluation results:
Train Accuracy: 1.11%, Train Loss: 4.5019
Test Accuracy: 1.11%, Test Loss: 4.4992





In [120]:
# Persist the metrics dict from the run above to disk.
with open("results_1.pkl", mode="wb") as results_file:
    pkl.dump(results_1, results_file)

# Evaluation

# Plot Results