In [None]:
import torch
import numpy as np
from torchvision import transforms, datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from torch.utils.data import random_split


In [None]:
# Preprocessing pipeline: every image ends up as a flat float vector for sklearn
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize images to 32x32
    transforms.ToTensor(),  # Convert to tensor of shape (3, 32, 32) with values in [0, 1]
    transforms.Lambda(lambda x: x.view(-1))  # Flatten 3D tensor to 1D vector for sklearn
])

# Path to the data directory (ImageFolder derives labels from the sub-folder names)
data_path = 'SENG_474_Dataset_Filtered/SENG_474_Dataset_Filtered'
full_dataset = datasets.ImageFolder(data_path, transform=transform)

# 3-way split: 60% train / 20% validation / remainder (~20%) test.
# int() truncates, so three independently truncated sizes can add up to fewer
# samples than len(full_dataset), in which case random_split raises ValueError.
# Giving the test split whatever is left guarantees the sizes always sum exactly.
n_total = len(full_dataset)
train_size = int(0.6 * n_total)
valid_size = int(0.2 * n_total)
test_size = n_total - train_size - valid_size

# used same seed as CNN AI art detector
seed = 42
generator = torch.Generator().manual_seed(seed)

# split the data
train_dataset, valid_dataset, test_dataset = random_split(full_dataset, [train_size, valid_size, test_size], generator=generator)

# Convert datasets to numpy arrays for sklearn compatibility
def dataset_to_numpy(dataset):
    """Materialise an iterable of (tensor, label) pairs as two numpy arrays.

    Args:
        dataset: Iterable yielding ``(img, label)`` pairs, where ``img``
            exposes ``.numpy()`` (e.g. a torch tensor).

    Returns:
        Tuple ``(X, y)``: ``X`` stacks the per-sample arrays along a new
        first axis, and ``y`` holds the labels in iteration order.
    """
    # Walk the dataset exactly once, converting each tensor as it is produced
    samples = [(img.numpy(), label) for img, label in dataset]
    features = [vector for vector, _ in samples]
    targets = [label for _, label in samples]
    return np.stack(features), np.array(targets)

# Note: no DataLoader is needed here because sklearn works directly on numpy arrays.
# The three splits are converted in order: train, validation, test.
(X_train, y_train), (X_valid, y_valid), (X_test, y_test) = (
    dataset_to_numpy(split)
    for split in (train_dataset, valid_dataset, test_dataset)
)




Training approach inspired by: https://medium.com/@MudSnail/the-importance-of-logistic-regression-in-image-classification-1966d07e7a0c

In [None]:
# Train the logistic regression baseline using sklearn.
# max_iter is raised from the default of 100 to give the solver room to converge.
# The saga solver shuffles the data while optimizing, so random_state is pinned
# to the same seed used for the data split; otherwise the fitted coefficients
# (and the accuracies computed from them) can change from run to run.
clf = LogisticRegression(solver='saga', max_iter=1000, random_state=seed)
clf.fit(X_train, y_train)

In [None]:
# Score the fitted classifier on the two held-out splits (validation first, then test)
val_acc, test_acc = (
    clf.score(features, labels)
    for features, labels in ((X_valid, y_valid), (X_test, y_test))
)

# Assemble the report lines and emit them in one write
report_lines = (
    "\n✅ Logistic Regression Baseline",
    f"Validation Accuracy: {val_acc:.4f}",
    f"Test Accuracy:       {test_acc:.4f}",
)
print("\n".join(report_lines))


✅ Logistic Regression Baseline
Validation Accuracy: 0.5920
Test Accuracy:       0.5803
