# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import timm
import random
import numpy as np
import cv2
from skimage.feature import hog, local_binary_pattern
from skimage import color

# ---- Experiment configuration ---------------------------------------------
BASE_DIR = "cifakemini"  # dataset root; the train/ and test/ sub-folders are read below
BATCH_SIZE = 32
EPOCHS = 15
NUM_CLASSES = 2
IMG_SIZE = 224  # square side length of the image fed to the CNN and ViT branches

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Resize to the square model input, convert to a [0, 1] tensor, then map every
# channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): torchvision's ImageNet ResNet weights are usually paired with
# ImageNet mean/std rather than 0.5/0.5 — confirm the intended normalisation
# for the CNN branch, since both backbones share this transform.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

def extract_hog_features(image, pixels_per_cell=(8, 8), cells_per_block=(2, 2)):
    """Return the flattened HOG descriptor of a BGR image.

    The image is converted to grayscale first. HOG uses scikit-image defaults
    for everything except the cell and block sizes.

    Args:
        image: BGR image in OpenCV channel order (uint8, H x W x 3 assumed).
        pixels_per_cell: HOG cell size in pixels.
        cells_per_block: number of cells per normalisation block.

    Returns:
        1-D numpy array whose length depends on the image size.
    """
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # visualize stays at its default (False). The rendered HOG image was never
    # used, and drawing it is extra work on every sample.
    return hog(image_gray, pixels_per_cell=pixels_per_cell, cells_per_block=cells_per_block)

def extract_lbp_features(image):
    """Return the per-pixel uniform LBP codes of a BGR image as a 1-D array.

    Uses 8 neighbours at radius 1 on the grayscale image. The output length
    equals the number of pixels, so it depends on the image size.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    codes = local_binary_pattern(gray, P=8, R=1, method='uniform')
    return codes.reshape(-1)

def extract_color_moments(image):
    """Return [mean, std, skewness] for each R, G, B channel of a BGR image.

    The result is a flat list of 9 values, ordered R, G, B. Skewness is 0 for a
    channel with zero standard deviation.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    features = []
    for ch in range(3):
        plane = rgb[..., ch]
        mu = np.mean(plane)
        sigma = np.std(plane)
        if sigma == 0:
            skewness = 0
        else:
            skewness = np.mean((plane - mu) ** 3) / sigma ** 3
        features += [mu, sigma, skewness]
    return features


class CustomImageFolder(ImageFolder):
    """``ImageFolder`` that also returns hand-crafted image features.

    Each item is ``(img_tensor, handcrafted, label)``. ``handcrafted`` is a
    float32 tensor that concatenates the outputs of ``extract_hog_features``,
    ``extract_lbp_features`` and ``extract_color_moments``.

    The features come from the same decoded image that feeds ``transform``.
    Both the CNN/ViT input and the hand-crafted features therefore describe
    identical pixels, and no file is decoded twice.
    """

    # Optional side length (pixels) to resize to before feature extraction.
    # None keeps the native resolution. HOG and LBP vectors grow with image
    # size, so set an int when the dataset mixes resolutions; otherwise samples
    # get different feature lengths and cannot be collated into one batch.
    feature_size = None

    def __getitem__(self, index):
        """Return ``(img_tensor, handcrafted, label)`` for sample *index*."""
        img_path, label = self.samples[index]
        pil_img = self.loader(img_path)

        img_tensor = self.transform(pil_img) if self.transform is not None else pil_img
        if self.target_transform is not None:
            label = self.target_transform(label)

        handcrafted = torch.tensor(self._handcrafted_features(pil_img), dtype=torch.float32)
        return img_tensor, handcrafted, label

    def _handcrafted_features(self, pil_img):
        """Compute the concatenated HOG + LBP + colour-moment vector of a PIL image."""
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        # The extract_* helpers expect OpenCV (BGR) channel order, so reverse RGB.
        bgr = np.ascontiguousarray(np.asarray(pil_img)[:, :, ::-1])
        if self.feature_size is not None:
            bgr = cv2.resize(bgr, (self.feature_size, self.feature_size))
        return np.concatenate([
            extract_hog_features(bgr),
            extract_lbp_features(bgr),
            extract_color_moments(bgr),
        ])


# Validation data drives both the BOA learning-rate search and per-epoch
# monitoring, so it must not come from the test directory. Choosing the
# learning rate on test data would leak it into the reported test metrics.
# BASE_DIR/val is used when it exists; otherwise a seeded VAL_FRACTION slice
# of the training folder is held out.
VAL_FRACTION = 0.1
SPLIT_SEED = 42

train_dir = os.path.join(BASE_DIR, "train")
val_dir = os.path.join(BASE_DIR, "val")

if os.path.isdir(val_dir):
    train_dataset = CustomImageFolder(root=train_dir, transform=transform)
    val_dataset = CustomImageFolder(root=val_dir, transform=transform)
else:
    full_train_dataset = CustomImageFolder(root=train_dir, transform=transform)
    n_val = max(1, int(len(full_train_dataset) * VAL_FRACTION))
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset,
        [len(full_train_dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )
test_dataset = CustomImageFolder(root=os.path.join(BASE_DIR, "test"), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Width of the hand-crafted feature vector, read from one real batch so the
# model head matches whatever the feature extractors produce.
_, sample_handcrafted, _ = next(iter(train_loader))
handcrafted_dim = sample_handcrafted.shape[1]

print(f"Actual handcrafted dimension from DataLoader: {handcrafted_dim}")

class CNNViT(nn.Module):
    """Late-fusion classifier over ResNet-18, ViT-B/16 and optional hand-crafted features.

    The same image batch ``x`` goes through both backbones. Their pooled
    feature vectors are concatenated with ``handcrafted`` (if given) and passed
    to a two-layer MLP head that outputs ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        handcrafted_dim: width of the hand-crafted feature vector passed to
            ``forward``. 0 means the model is used without hand-crafted features.
    """

    def __init__(self, num_classes=2, handcrafted_dim=0):
        super().__init__()
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour of
        # ``weights=``; kept for compatibility with older installations.
        backbone = models.resnet18(pretrained=True)
        # Width entering ResNet's classifier, i.e. the pooled CNN feature size.
        cnn_out_dim = backbone.fc.in_features
        # Drop ResNet's own avgpool + fc; pooling is redone by self.avgpool.
        self.cnn = nn.Sequential(*list(backbone.children())[:-2])
        # num_classes=0 removes timm's classifier head, leaving feature output.
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.handcrafted_dim = handcrafted_dim

        vit_out_dim = self.vit.num_features

        print(f"CNN output dim: {cnn_out_dim}")
        print(f"ViT output dim: {vit_out_dim}")
        print(f"Handcrafted features dim: {handcrafted_dim}")

        total_feature_dim = cnn_out_dim + vit_out_dim + handcrafted_dim
        print(f"Total feature dim: {total_feature_dim}")

        self.fc = nn.Sequential(
            nn.Linear(total_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, handcrafted=None):
        """Return class logits of shape ``(N, num_classes)``.

        Args:
            x: image batch fed to both the ResNet and ViT branches.
            handcrafted: ``(N, handcrafted_dim)`` features, or ``None`` when the
                model was built with ``handcrafted_dim=0``.

        Raises:
            ValueError: if the hand-crafted feature width does not match
                ``handcrafted_dim`` (including a missing tensor when one is expected).
        """
        got = 0 if handcrafted is None else handcrafted.shape[-1]
        if got != self.handcrafted_dim:
            raise ValueError(
                f"expected {self.handcrafted_dim} hand-crafted features, got {got}"
            )

        cnn_feat = torch.flatten(self.avgpool(self.cnn(x)), 1)
        vit_feat = self.vit(x)

        feats = [cnn_feat, vit_feat]
        if handcrafted is not None:
            feats.append(handcrafted)
        return self.fc(torch.cat(feats, dim=1))


def evaluate(loader, model):
    """Evaluate *model* on *loader* and return ``(accuracy, precision, recall, f1)``.

    ``loader`` must yield ``(imgs, handcrafted, labels)`` batches. Precision,
    recall and F1 are support-weighted averages over classes. A class that is
    never predicted contributes 0 instead of emitting an UndefinedMetricWarning.
    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    all_preds, all_labels = [], []
    try:
        with torch.no_grad():
            for imgs, handcrafted, labels in loader:
                outputs = model(imgs.to(DEVICE), handcrafted.to(DEVICE))
                all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                # Labels are only compared on the CPU; no need to move them to DEVICE.
                all_labels.extend(labels.cpu().numpy())
    finally:
        model.train(was_training)

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    # Recall must compare predictions with labels; labels vs. labels is always 1.0.
    rec = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    return acc, prec, rec, f1

def _lr_fitness(model_class, lr, train_loader, val_loader):
    """Return the validation F1 of a fresh model after one Adam step at learning rate *lr*.

    This is a cheap, noisy proxy. A single batch from ``train_loader`` drives
    one optimisation step, followed by a full ``val_loader`` evaluation.
    """
    model = model_class(num_classes=NUM_CLASSES, handcrafted_dim=handcrafted_dim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(lr))
    criterion = nn.CrossEntropyLoss()

    model.train()
    imgs, handcrafted, labels = next(iter(train_loader))
    imgs, handcrafted, labels = imgs.to(DEVICE), handcrafted.to(DEVICE), labels.to(DEVICE)
    optimizer.zero_grad()
    loss = criterion(model(imgs, handcrafted), labels)
    loss.backward()
    optimizer.step()

    _, _, _, f1 = evaluate(val_loader, model)
    return f1


def boa_optimize(model_class, train_loader, val_loader, population_size=5, iterations=5):
    """Search a learning rate in [1e-6, 1e-3] with the Butterfly Optimization Algorithm.

    Each butterfly's position is a learning rate and its stimulus intensity is
    the validation F1 returned by ``_lr_fitness``. Fragrance is
    ``f = c * I**a``. With probability 0.8 a butterfly takes a global step
    toward the best position found so far (``x += (r**2 * g_best - x) * f``);
    otherwise it takes a local step built from two random butterflies
    (``x += (r**2 * x_j - x_k) * f``). Positions are clipped to the bounds.

    Args:
        model_class: callable accepting ``num_classes`` and ``handcrafted_dim``.
        train_loader: source of the single training batch per candidate.
        val_loader: loader used to compute each candidate's F1.
        population_size: number of butterflies (must be >= 1).
        iterations: number of update sweeps over the population.

    Returns:
        The learning rate with the highest validation F1 seen during the search.

    Raises:
        ValueError: if ``population_size`` is less than 1.
    """
    if population_size < 1:
        raise ValueError("population_size must be at least 1")

    lb, ub = 1e-6, 1e-3
    switch_prob = 0.8       # probability of a global step toward the best position
    sensory_modality = 0.1  # c in the fragrance f = c * I**a
    power_exponent = 0.5    # a in the fragrance f = c * I**a

    population = np.random.uniform(lb, ub, population_size)
    fitness = np.array(
        [_lr_fitness(model_class, lr, train_loader, val_loader) for lr in population]
    )

    best_idx = int(np.argmax(fitness))
    best_lr, best_f1 = population[best_idx], fitness[best_idx]

    for _ in range(iterations):
        for i in range(population_size):
            fragrance = sensory_modality * fitness[i] ** power_exponent
            r = np.random.rand()
            # Steps are built from positions (learning rates), not from fitness
            # values, so their size scales with the [lb, ub] search range.
            if np.random.rand() < switch_prob:
                step = (r ** 2 * best_lr - population[i]) * fragrance
            else:
                j, k = np.random.randint(population_size, size=2)
                step = (r ** 2 * population[j] - population[k]) * fragrance
            population[i] = np.clip(population[i] + step, lb, ub)
            fitness[i] = _lr_fitness(model_class, population[i], train_loader, val_loader)

        sweep_best = int(np.argmax(fitness))
        if fitness[sweep_best] > best_f1:
            best_lr, best_f1 = population[sweep_best], fitness[sweep_best]

    return best_lr


print("Starting BOA for learning rate tuning...")
best_lr = boa_optimize(CNNViT, train_loader, val_loader)
# "\U0001F4CC" is the pushpin emoji as a single code point. The surrogate-pair
# spelling "\ud83d\udccc" yields two lone surrogates, which cannot be encoded
# to UTF-8 and make print() raise UnicodeEncodeError.
print(f"\U0001F4CC Best LR from BOA: {best_lr:.8f}")


# Final training run: a fresh model trained for EPOCHS epochs at the BOA-selected rate.
model = CNNViT(num_classes=NUM_CLASSES, handcrafted_dim=handcrafted_dim).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    all_preds, all_labels = [], []
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, handcrafted, labels in progress:
        imgs = imgs.to(DEVICE)
        handcrafted = handcrafted.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs, handcrafted)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # epoch_loss is the sum (not the mean) of the per-batch losses.
        epoch_loss += loss.item()
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    val_acc, val_prec, val_rec, val_f1 = evaluate(val_loader, model)

    summary = (
        f" Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Train Acc: {train_acc*100:.2f}% | "
        f"Val Acc: {val_acc*100:.2f}% | Precision: {val_prec:.2f} | Recall: {val_rec:.2f} | F1: {val_f1:.2f}"
    )
    print(summary)


# Evaluate the final model on the held-out test loader, then persist its weights.
test_acc, test_prec, test_rec, test_f1 = evaluate(test_loader, model)
print(
    f"\n Final Test Accuracy: {test_acc*100:.2f}% | Precision: {test_prec:.2f} | "
    f"Recall: {test_rec:.2f} | F1-Score: {test_f1:.2f}"
)

torch.save(model.state_dict(), "cnn_vit_classifier.pth")
print(" Model saved as cnn_vit_classifier.pth")