In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Dataset locations: one root folder per split
train_dir = r"E:\Thesis-8th sem\Dataset\archive\breast-cancer-dataset\Train"
test_dir = r"E:\Thesis-8th sem\Dataset\archive\breast-cancer-dataset\Test"

# Preprocessing pipeline, assembled from three stages that run in this order:
#   1. resize every image to 224x224,
#   2. random geometric / photometric augmentation,
#   3. tensor conversion followed by per-channel normalization (mean 0.5, std 0.5).
# NOTE(review): the backbones used later are ImageNet-pretrained; confirm that
# 0.5/0.5 normalization (rather than ImageNet statistics) is intentional.
_resize_stage = [
    transforms.Resize(size=(224, 224)),
]
_augmentation_stage = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=20),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), shear=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.GaussianBlur(kernel_size=3),
]
_tensor_stage = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_resize_stage + _augmentation_stage + _tensor_stage)

# Custom Dataset class
class BreastCancerDataset(Dataset):
    """Image-folder dataset for binary benign/malignant classification.

    ``directory`` must contain one sub-folder per class. A sub-folder whose
    name contains ``"benign"`` (case-insensitive) is labelled ``0``; every
    other sub-folder is labelled ``1``.

    Args:
        directory: Root folder holding the class sub-folders.
        transform: Optional callable applied to each RGB ``PIL.Image``
            before it is returned.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.image_paths, self.labels = self.load_images()

    def load_images(self):
        """Scan ``self.directory`` and return ``(image_paths, labels)``.

        Folders and files are visited in sorted order so that a given index
        always refers to the same sample, regardless of the order in which
        the operating system lists directory entries.
        """
        image_paths = []
        labels = []
        for label_folder in sorted(os.listdir(self.directory)):
            label_path = os.path.join(self.directory, label_folder)
            if not os.path.isdir(label_path):
                continue
            label = 0 if 'benign' in label_folder.lower() else 1
            for file_name in sorted(os.listdir(label_path)):
                # Compare the lower-cased name so ".PNG" / ".JPG" files are
                # not silently dropped.
                if file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    image_paths.append(os.path.join(label_path, file_name))
                    labels.append(label)
        return image_paths, labels

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform``
        when one was supplied.
        """
        image_path = self.image_paths[idx]
        # The context manager closes the underlying file as soon as the
        # pixel data has been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Evaluation transform: the same resize / tensor conversion / normalization
# as the training pipeline, but without any random augmentation, so that test
# loss and accuracy are deterministic and comparable between epochs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load datasets (only the training split is augmented)
train_dataset = BreastCancerDataset(train_dir, transform=transform)
test_dataset = BreastCancerDataset(test_dir, transform=eval_transform)

# Create DataLoaders (shuffle the training data only)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Define Hybrid CNN + ViT Model
class CNN_ViT_Model(nn.Module):
    """Hybrid classifier that fuses ResNet-18 and ViT-B/16 image features.

    Both backbones start from ImageNet weights and have their classification
    heads replaced by ``nn.Identity`` so they emit feature vectors. The two
    vectors are concatenated and passed to a small MLP classifier.

    Args:
        num_classes: Number of output logits. Defaults to 2
            (benign / malignant).
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # CNN (ResNet-18) as feature extractor. The explicit weights enum
        # replaces the deprecated ``pretrained=True`` flag and selects the
        # same ImageNet checkpoint.
        self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Read the feature width from the head being removed (512 for ResNet-18).
        self.cnn_out_features = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()  # Remove final classification layer

        # ViT (Vision Transformer) for global context learning
        self.vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        # Feature width of the removed head (768 for ViT-B/16).
        self.vit_out_features = self.vit.heads.head.in_features
        self.vit.heads.head = nn.Identity()  # Remove ViT classification head

        # Fusion layer + final classifier over the concatenated features
        self.fc = nn.Sequential(
            nn.Linear(self.cnn_out_features + self.vit_out_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``.

        The same input is fed to both backbones; their feature vectors are
        concatenated along dim 1 before classification.
        """
        cnn_features = self.cnn(x)  # -> (Batch, cnn_out_features)
        vit_features = self.vit(x)  # -> (Batch, vit_out_features)
        combined = torch.cat((cnn_features, vit_features), dim=1)
        return self.fc(combined)

# Initialize model: instantiate the hybrid ResNet-18 + ViT-B/16 classifier
# with its default constructor arguments.
model = CNN_ViT_Model()

# Custom Focal Loss (Handles Class Imbalance Better)
class FocalLoss(nn.Module):
    """Focal loss built on top of per-sample cross-entropy.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true class, and the
    scaled values are averaged over the batch. ``alpha`` is a single scalar
    applied to every sample.

    Args:
        alpha: Constant multiplier applied to every sample's loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against ``targets``."""
        per_sample_ce = nn.functional.cross_entropy(logits, targets, reduction='none')
        # Cross-entropy is -log(p_t), so exp(-ce) recovers p_t.
        true_class_prob = (-per_sample_ce).exp()
        modulating_factor = (1 - true_class_prob).pow(self.gamma)
        weighted = self.alpha * modulating_factor * per_sample_ce
        return weighted.mean()

# Select the compute device (GPU when available) and place the model on it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function: focal loss with its default alpha / gamma
criterion = FocalLoss()

# Optimizer: AdamW over all model parameters
optimizer = optim.AdamW(params=model.parameters(), lr=5e-4, weight_decay=1e-4)

# Learning rate scheduler: cosine annealing with T_max=10 scheduler steps
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Training with Early Stopping (Patience = 5)
#
# NOTE(review): the test split doubles as the validation split here. It
# drives both checkpoint selection and early stopping, so the reported test
# metrics are optimistically biased. Consider holding out a validation set.
num_epochs = 50
best_loss = float("inf")
patience = 5
counter = 0  # consecutive epochs without a new best test loss

train_losses = []
test_losses = []

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    train_samples = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` returns a batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        train_samples += batch_size

    avg_train_loss = running_loss / train_samples
    train_losses.append(avg_train_loss)

    # ---- evaluation pass (no gradients, eval-mode layers) ----
    model.eval()
    test_loss = 0.0
    test_samples = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            test_loss += loss.item() * batch_size
            test_samples += batch_size

            # Predicted class = index of the largest logit
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

    avg_test_loss = test_loss / test_samples
    test_losses.append(avg_test_loss)
    accuracy = accuracy_score(all_labels, all_preds) * 100

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Early stopping condition: checkpoint on improvement, otherwise count
    # towards `patience` and stop once it is exhausted.
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        counter = 0
        torch.save(model.state_dict(), "best_hybrid_model.pth")  # Save the best model
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

    scheduler.step()

# Load the best model and evaluate.
# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written in a CUDA session can still be loaded on a CPU-only one.
model.load_state_dict(torch.load("best_hybrid_model.pth", map_location=device))
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit
        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.tolist())
        all_labels.extend(labels.tolist())

final_accuracy = accuracy_score(all_labels, all_preds) * 100
print(f'Final Test Accuracy: {final_accuracy:.2f}%')

# Plot loss curves (one point per completed epoch)
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label="Train Loss")
plt.plot(test_losses, label="Test Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and Testing Loss")
plt.show()


Epoch [1/50], Train Loss: 0.0371, Test Loss: 0.0353, Accuracy: 75.83%
Epoch [2/50], Train Loss: 0.0264, Test Loss: 0.0089, Accuracy: 96.67%
Epoch [3/50], Train Loss: 0.0215, Test Loss: 0.0593, Accuracy: 67.50%
Epoch [4/50], Train Loss: 0.0177, Test Loss: 0.0089, Accuracy: 94.58%
Epoch [5/50], Train Loss: 0.0159, Test Loss: 0.0185, Accuracy: 89.17%
Epoch [6/50], Train Loss: 0.0096, Test Loss: 0.0024, Accuracy: 98.75%
