In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchtune

from typing import Literal, Optional
from matplotlib import pyplot as plt
from tqdm import tqdm
from torchvision.transforms import v2
from torchsummary import summary
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

In [None]:
# Fix the global RNGs so weight initialisation, dropout masks and DataLoader
# shuffling (which draws from torch's default generator) are reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNGs of all devices, CPU and CUDA
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data Loader

In [None]:
IMG_DIM = 256  # images are resized to IMG_DIM x IMG_DIM
NUM_CHANNELS = 3
BATCH_SIZE = 128
# Widely used ImageNet per-channel mean/std, applied by v2.Normalize.
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)
NUM_CLASSES = 90
NUM_REAL_IMG_PER_CLASS = 60
# The train split takes the first fraction of every real class, the test split the last fraction.
REAL_IMG_TRAIN_PERCENTAGE = 0.5
REAL_IMG_TEST_PERCENTAGE = 0.5

# If the two fractions sum to more than 1 the left/right slices overlap and
# test images leak into training.
assert REAL_IMG_TRAIN_PERCENTAGE + REAL_IMG_TEST_PERCENTAGE <= 1.0, "real train/test splits overlap"

In [None]:
def get_subset_indices(
    percent: float = 1.0,
    side: Literal["left", "right"] = "left",
    num_classes: Optional[int] = None,
    images_per_class: Optional[int] = None,
) -> np.ndarray:
    """
    Select the same fraction of samples from every class of a class-grouped dataset.

    Assumes the dataset stores its samples class by class, with exactly
    ``images_per_class`` consecutive indices per class.

    Args:
        percent: Fraction in [0, 1] of each class to keep. The per-class count is
            floored, with a small tolerance so floating-point error cannot drop an
            image (e.g. ``100 * 0.29 == 28.999999999999996`` still yields 29).
        side: ``"left"`` keeps the first images of each class, ``"right"`` the last.
        num_classes: Number of classes; defaults to ``NUM_CLASSES``.
        images_per_class: Images per class; defaults to ``NUM_REAL_IMG_PER_CLASS``.

    Returns:
        int32 array of dataset indices, ordered class by class.

    Raises:
        ValueError: if ``percent`` is outside [0, 1] or ``side`` is not "left"/"right".
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if images_per_class is None:
        images_per_class = NUM_REAL_IMG_PER_CLASS
    if not 0.0 <= percent <= 1.0:
        raise ValueError(f"percent must be in [0, 1], got {percent}")
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")

    # Floor with tolerance: plain int() would truncate 28.999999999999996 to 28.
    class_size = int(images_per_class * percent + 1e-9)
    start = 0 if side == "left" else images_per_class - class_size

    offsets = np.arange(start, start + class_size)          # positions inside one class
    bases = np.arange(num_classes) * images_per_class       # first index of each class
    # Broadcast to (num_classes, class_size) and flatten row-major -> class-major order.
    return (bases[:, None] + offsets[None, :]).ravel().astype(np.int32)

In [None]:
def _warn_if_unexpected_layout(dataset, root: str) -> None:
    """Warn unless ``dataset`` has NUM_CLASSES classes of NUM_REAL_IMG_PER_CLASS images each.

    ``get_subset_indices`` selects ``class * NUM_REAL_IMG_PER_CLASS + offset``
    positions, which only line up with class boundaries under exactly this layout;
    otherwise the per-class subsets are unbalanced or index past the dataset end.
    """
    import warnings

    counts = np.bincount(np.asarray(dataset.targets, dtype=np.int64), minlength=len(dataset.classes))
    if len(dataset.classes) != NUM_CLASSES or np.any(counts != NUM_REAL_IMG_PER_CLASS):
        warnings.warn(
            f"{root}: expected {NUM_CLASSES} classes x {NUM_REAL_IMG_PER_CLASS} images, "
            f"found {len(dataset.classes)} classes with per-class counts {counts.tolist()}; "
            "per-class subsets will not align with class boundaries.",
            stacklevel=3,
        )


def get_loader(
    real_img_dir: str = "./real_animals",
    ai_img_dir: str = "./ai_animals",
    real_img_percent: float = 1.0,
    ai_img_percent: float = 0.0,
    batch_size: int = BATCH_SIZE,
    num_workers: int = 0,
    shuffle: bool = True,
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Get train/test dataloaders for real and AI-generated images.

    The real-image folder is split per class by position: training uses the first
    ``REAL_IMG_TRAIN_PERCENTAGE * real_img_percent`` fraction of each class, testing
    the last ``REAL_IMG_TEST_PERCENTAGE`` fraction. When ``ai_img_percent > 0`` the
    first ``ai_img_percent`` fraction of each class in ``ai_img_dir`` is appended to
    the training set. The test set always contains real images only.

    Args:
        real_img_dir: ImageFolder root of the real images.
        ai_img_dir: ImageFolder root of the AI-generated images; only read when
            ``ai_img_percent > 0``, so it may be absent otherwise.
        real_img_percent: Fraction of the real training split to keep.
        ai_img_percent: Fraction of each AI class to add to the training set.
        batch_size: Batch size of both loaders.
        num_workers: DataLoader worker processes.
        shuffle: Whether to shuffle the training loader. The test loader is never
            shuffled, since evaluation metrics do not depend on order.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if the AI folder's class-to-index mapping differs from the real
            folder's, which would give the same label id different meanings.
    """
    transform = v2.Compose([
        v2.Resize((IMG_DIM, IMG_DIM)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ])

    # Scan the real folder once; train and test are disjoint left/right slices of it.
    real_dataset = torchvision.datasets.ImageFolder(
        root=real_img_dir,
        transform=transform,
        allow_empty=True,
    )
    _warn_if_unexpected_layout(real_dataset, real_img_dir)
    train_real_img_subset = torch.utils.data.Subset(
        real_dataset,
        indices=get_subset_indices(percent=REAL_IMG_TRAIN_PERCENTAGE * real_img_percent, side="left"),
    )
    test_real_img_subset = torch.utils.data.Subset(
        real_dataset,
        indices=get_subset_indices(percent=REAL_IMG_TEST_PERCENTAGE, side="right"),
    )

    train_dataset = train_real_img_subset
    if ai_img_percent > 0.0:
        ai_dataset = torchvision.datasets.ImageFolder(
            root=ai_img_dir,
            transform=transform,
            allow_empty=True,
        )
        if ai_dataset.class_to_idx != real_dataset.class_to_idx:
            raise ValueError(
                f"class folders of {ai_img_dir!r} do not match those of {real_img_dir!r}; "
                "labels would be inconsistent between the two datasets"
            )
        _warn_if_unexpected_layout(ai_dataset, ai_img_dir)
        ai_img_subset = torch.utils.data.Subset(
            ai_dataset,
            indices=get_subset_indices(percent=ai_img_percent, side="left"),
        )
        train_dataset = torch.utils.data.ConcatDataset([train_real_img_subset, ai_img_subset])

    train_img_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    test_img_dataloader = torch.utils.data.DataLoader(
        dataset=test_real_img_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_img_dataloader, test_img_dataloader

In [None]:
# Real images only: the whole real training split, no AI-generated images.
train_dl, test_dl = get_loader(
    ai_img_percent=0.0,
    real_img_percent=1.0,
)

In [None]:
# Sanity check: print image/label tensor shapes for the first few training batches.
NUM_PREVIEW_BATCHES = 4
for batch_idx, (img, label) in enumerate(train_dl, start=1):
    print(img.shape, label.shape)
    # Break right after the last wanted batch so no extra batch is loaded.
    if batch_idx == NUM_PREVIEW_BATCHES:
        break

# Models

In [None]:
class CNN(nn.Module):
    """Five-stage convolutional image classifier.

    Each stage is a 3x3 convolution (padding 1, so spatial size is kept), leaky
    ReLU, dropout (p=0.1) and a 2x2 max-pool that halves the spatial size. Two
    fully connected layers (with dropout p=0.5 between them) then map the
    flattened features to ``n_classes`` logits.

    Args:
        input_channels: Number of channels of the input images.
        n_classes: Number of output logits.
        img_dim: Side length of the square input images. The default 256 gives
            the 8x8 final feature map the network was designed for.

    Raises:
        ValueError: if ``img_dim`` is smaller than 32 (five halvings would reach 0).
    """

    def __init__(self, input_channels, n_classes, img_dim=256):
        super().__init__()
        if img_dim < 32:
            raise ValueError(f"img_dim must be >= 32 for five 2x2 poolings, got {img_dim}")

        # set metadata
        self.input_channels = input_channels
        self.n_classes = n_classes
        self.img_dim = img_dim
        # Every MaxPool2d(2, 2) floors the size by half; five of them give img_dim // 2**5.
        self.FINAL_LAYER_SIZE = img_dim // 32
        self.final_layer_channels = 40
        self.flatten_layer_size = self.final_layer_channels * self.FINAL_LAYER_SIZE * self.FINAL_LAYER_SIZE

        # dropout layers
        self.dropout50 = nn.Dropout(p=0.5)
        self.dropout10 = nn.Dropout(p=0.1)

        # conv stages (attribute names kept stable so existing state_dicts still load)
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=8, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=32, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=self.final_layer_channels, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # classifier head
        self.fc1 = nn.Linear(self.flatten_layer_size, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Map a batch of images ``(N, input_channels, img_dim, img_dim)`` to ``(N, n_classes)`` logits."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
            (self.conv5, self.pool5),
        )
        for conv, pool in stages:
            x = pool(self.dropout10(torch.nn.functional.leaky_relu(conv(x))))

        # Keep the batch dimension explicit: a wrong input size now fails in fc1
        # instead of being silently folded into the batch by view(-1, ...).
        x = torch.flatten(x, start_dim=1)

        x = self.dropout50(torch.nn.functional.leaky_relu(self.fc1(x)))
        return self.fc2(x)

In [None]:
# Layer-by-layer output shapes and parameter counts of the CNN.
summary(
    CNN(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device),
    (NUM_CHANNELS, IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    device=device.type,
)

In [None]:
class CombinedResNet50(nn.Module):
    """Placeholder for a ResNet-50-based classifier (not implemented yet).

    It currently only stores its configuration; it has no layers or parameters
    and ``forward`` returns its input unchanged.

    Args:
        input_channels: Number of channels of the input images.
        n_classes: Number of output classes the finished model should predict.
    """

    def __init__(self, input_channels, n_classes):
        # Zero-argument super(): this class derives from nn.Module, not from CNN,
        # so super(CNN, self) would raise TypeError.
        super().__init__()

        # set metadata
        self.input_channels = input_channels
        self.n_classes = n_classes

    def forward(self, x):
        # TODO: replace the identity with the actual network.
        return x

In [None]:
# NOTE(review): CombinedResNet50 is still a placeholder without layers, so this
# summary has nothing meaningful to report until the model is implemented.
summary(
    CombinedResNet50(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device),
    (NUM_CHANNELS, IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    device=device.type,
)

# Hyperparameter Tuning

In [None]:
def plot_loss_accuracy(results) -> None:
    """
    Plot loss curves, accuracy curves and per-class precision/recall side by side.

    Args:
        results: Nested sequence indexed as
            ``results[0] = (train_losses, train_accuracies)`` (one value per epoch),
            ``results[1] = (test_losses, test_accuracies)`` (one value per epoch),
            ``results[2] = (test_precision, test_recall, ...)`` (one score in [0, 1]
            per class, e.g. from torchmetrics metrics with ``average='none'``).
            Accuracies are fractions in [0, 1], as returned by
            ``MulticlassAccuracy.compute()``, and are drawn as percentages.
    """
    train_losses, train_accs = results[0][0], results[0][1]
    test_losses, test_accs = results[1][0], results[1][1]
    precision, recall = results[2][0], results[2][1]

    def _epochs(series):
        # 1-based epoch numbers for the x axis.
        return range(1, len(series) + 1)

    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 4))
    fig.subplots_adjust(wspace=0.4)

    # Markers keep single-epoch runs visible (a one-point line draws nothing).
    ax[0].set_title("Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].plot(_epochs(train_losses), train_losses, marker="o", label="Train Loss", color="blue")
    ax[0].plot(_epochs(test_losses), test_losses, marker="o", label="Test Loss", color="red")
    ax[0].legend()
    ax[0].set_ylabel("Cross Entropy Loss")

    ax[1].set_title("Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].plot(_epochs(train_accs), [100 * a for a in train_accs], marker="o", label="Train Accuracy", color="blue")
    ax[1].plot(_epochs(test_accs), [100 * a for a in test_accs], marker="o", label="Test Accuracy", color="red")
    ax[1].legend()
    ax[1].set_ylabel("Accuracy (%)")

    ax[2].set_title("Precision and Recall")
    ax[2].set_xlabel("Class index")
    ax[2].plot(range(len(precision)), precision, marker="o", label="Test Precision", color="cyan")
    ax[2].plot(range(len(recall)), recall, marker="o", label="Test Recall", color="orange")
    ax[2].legend()
    ax[2].set_ylabel("Precision/Recall Score (0-1)")

In [None]:
def plot_improvement_by_class(results) -> None:
    """Placeholder for a per-class improvement plot; currently draws nothing and returns None."""
    # TODO: implement.

In [None]:
def plot_diff_barchart(results) -> None:
    """Placeholder for a difference bar chart; currently draws nothing and returns None."""
    # TODO: implement.

In [None]:
def _train_one_epoch(model, optimizer, loader, loss_function, accuracy_metric):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean batch loss, accuracy over the epoch as a fraction in [0, 1])``.
    """
    model.train()
    accuracy_metric.reset()
    batch_losses = []
    for images, labels in tqdm(loader, total=len(loader)):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = loss_function(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
        # The metric only needs the values, not the autograd graph.
        accuracy_metric.update(outputs.detach(), labels)
    return float(np.mean(batch_losses)), accuracy_metric.compute().item()


@torch.no_grad()
def _evaluate(model, loader, loss_function, accuracy_metric, per_class_metrics=()):
    """Evaluate ``model`` on ``loader`` without building autograd graphs.

    ``accuracy_metric`` and every metric in ``per_class_metrics`` are reset before
    the pass and updated on every batch; only the accuracy is computed here.

    Returns:
        ``(mean batch loss, accuracy as a fraction in [0, 1])``.
    """
    model.eval()
    accuracy_metric.reset()
    for metric in per_class_metrics:
        metric.reset()
    batch_losses = []
    for images, labels in tqdm(loader, total=len(loader)):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        batch_losses.append(loss_function(outputs, labels).item())
        accuracy_metric.update(outputs, labels)
        for metric in per_class_metrics:
            metric.update(outputs, labels)
    return float(np.mean(batch_losses)), accuracy_metric.compute().item()


def train_and_test_model(
        model: nn.Module,
        optimizer: optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        test_loader: torch.utils.data.DataLoader,
        E: int,
        verbose: bool = True,
    ):
    """
    Train ``model`` for ``E`` epochs, evaluating on ``test_loader`` after each epoch.

    Args:
        model: Network to train (already moved to ``device``).
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Loader of ``(images, labels)`` training batches.
        test_loader: Loader of ``(images, labels)`` evaluation batches.
        E: Number of epochs; must be at least 1.
        verbose: Print a one-line summary after every epoch.

    Returns:
        ``((train_losses, train_accuracies), (test_losses, test_accuracies),
        (test_precision, test_recall, test_f1score))``. The first four are
        per-epoch lists (accuracies as fractions in [0, 1]). The last three are
        per-class lists of length ``NUM_CLASSES``, computed on the final epoch.

    Raises:
        ValueError: if ``E < 1``.
    """
    if E < 1:
        raise ValueError(f"E must be >= 1, got {E}")

    # TODO: check whether torch.compile(model) is worth it.
    loss_function = nn.CrossEntropyLoss().to(device)

    # TODO: switch accuracy to average='none' for per-class accuracy.
    accuracy_metric = MulticlassAccuracy(average='micro', num_classes=NUM_CLASSES).to(device)
    precision_metric = MulticlassPrecision(average='none', num_classes=NUM_CLASSES).to(device)
    recall_metric = MulticlassRecall(average='none', num_classes=NUM_CLASSES).to(device)
    f1_metric = MulticlassF1Score(average='none', num_classes=NUM_CLASSES).to(device)
    per_class_metrics = (precision_metric, recall_metric, f1_metric)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    for epoch in range(E):
        train_loss, train_acc = _train_one_epoch(model, optimizer, train_loader, loss_function, accuracy_metric)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        # Per-class scores are only reported for the final model, so skip them before.
        is_last_epoch = epoch == E - 1
        test_loss, test_acc = _evaluate(
            model, test_loader, loss_function, accuracy_metric,
            per_class_metrics if is_last_epoch else (),
        )
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        if verbose:
            print(f"Epoch [{epoch+1}/{E}]: Train Accuracy: {train_acc*100:.2f}%, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc*100:.2f}%, Test Loss: {test_loss:.4f}")

    # average='none' yields one score per class, so use tolist() rather than item().
    test_precision, test_recall, test_f1score = (metric.compute().tolist() for metric in per_class_metrics)

    print(f"\nEvaluation results:\nTrain Accuracy: {train_acc*100:.2f}%, Train Loss: {train_loss:.4f}\nTest Accuracy: {test_acc*100:.2f}%, Test Loss: {test_loss:.4f}")

    return (
        (train_losses, train_accuracies),
        (test_losses, test_accuracies),
        (test_precision, test_recall, test_f1score),
    )

In [None]:
# Loaders for the baseline run: all real training images, no AI images, no worker processes.
train_loader, test_loader = get_loader(
    num_workers=0,
    ai_img_percent=0.0,
    real_img_percent=1.0,
)

In [None]:
# Baseline: the CNN trained on real images only (loaders above use ai_img_percent=0.0).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 1

model = CNN(input_channels=NUM_CHANNELS, n_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
results_1 = train_and_test_model(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    E=NUM_EPOCHS,
    verbose=True,
)

# Evaluation

# Plot Results