In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import networkx as nx
import cv2
from PIL import Image
from skimage.segmentation import slic
from skimage.color import rgb2lab
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load dataset paths
def load_dataset_paths(root_dir):
    """Collect labelled image paths from a one-folder-per-class directory tree.

    Args:
        root_dir: Directory whose immediate sub-directories each hold the
            images of one class.

    Returns:
        A tuple ``(image_paths, class_to_label, label_to_class)``:
        ``image_paths`` is a list of ``(path, label)`` pairs,
        ``class_to_label`` maps class folder name -> integer label and
        ``label_to_class`` is the inverse mapping.
    """
    image_extensions = (".jpg", ".png", ".jpeg")

    # Only sub-directories are classes. Counting stray files (e.g. a README
    # or Thumbs.db) would consume label indices and leave gaps in the labels.
    class_names = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )
    class_to_label = {class_name: idx for idx, class_name in enumerate(class_names)}
    label_to_class = {idx: class_name for class_name, idx in class_to_label.items()}

    image_paths = []
    for class_name in class_names:
        class_dir = os.path.join(root_dir, class_name)
        label = class_to_label[class_name]
        # os.listdir order is unspecified; sorting keeps any later seeded
        # split reproducible across machines and file systems.
        for filename in sorted(os.listdir(class_dir)):
            # Compare case-insensitively so ".JPG" / ".PNG" are not skipped.
            if filename.lower().endswith(image_extensions):
                image_paths.append((os.path.join(class_dir, filename), label))

    return image_paths, class_to_label, label_to_class

# Build Tissue Graph
def build_tissue_graph(image, segments):
    """Build a region adjacency graph over the superpixels of an image.

    Every distinct label in ``segments`` becomes one node, numbered
    ``0..K-1`` in ascending label order and carrying the mean CIELAB colour
    of its pixels as the ``mean_color`` attribute. Two nodes are connected
    when their regions touch horizontally or vertically.

    Args:
        image: RGB image array accepted by ``skimage.color.rgb2lab``.
        segments: 2-D integer label array with the same height and width as
            ``image``.

    Returns:
        A tuple ``(rag, node_map)`` where ``rag`` is an undirected
        ``networkx.Graph`` whose nodes are the indices ``0..K-1`` and
        ``node_map`` maps each original segment label to its node index.
    """
    rag = nx.Graph()
    image_lab = rgb2lab(image)

    # np.unique returns the labels sorted; the inverse array is the same
    # label image re-expressed as dense node indices.
    regions, inverse = np.unique(segments, return_inverse=True)
    node_ids = inverse.reshape(segments.shape)
    node_map = {region: idx for idx, region in enumerate(regions)}

    for idx in node_map.values():
        mean_color = np.mean(image_lab[node_ids == idx], axis=0)
        rag.add_node(idx, mean_color=mean_color)

    # Compare every pixel with its right and bottom neighbour in one shot
    # instead of visiting each pixel in a Python loop.
    left, right = node_ids[:, :-1], node_ids[:, 1:]
    top, bottom = node_ids[:-1, :], node_ids[1:, :]
    horizontal = left != right
    vertical = top != bottom
    pairs = np.concatenate([
        np.stack([left[horizontal], right[horizontal]], axis=1),
        np.stack([top[vertical], bottom[vertical]], axis=1),
    ])
    if pairs.size:
        # Order each pair and drop duplicates; the graph is undirected.
        unique_pairs = np.unique(np.sort(pairs, axis=1), axis=0)
        rag.add_edges_from(map(tuple, unique_pairs.tolist()))

    return rag, node_map

# Extract Node Features
# Lazily-built (backbone, preprocess) pair shared by every call, so the
# pretrained weights are loaded once instead of once per image.
_NODE_FEATURE_EXTRACTOR = None


def _get_node_feature_extractor():
    """Return the cached ResNet-34 backbone and its preprocessing pipeline."""
    global _NODE_FEATURE_EXTRACTOR
    if _NODE_FEATURE_EXTRACTOR is None:
        resnet = models.resnet34(weights="IMAGENET1K_V1")
        backbone = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC layer
        backbone.eval()
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        _NODE_FEATURE_EXTRACTOR = (backbone, preprocess)
    return _NODE_FEATURE_EXTRACTOR


def extract_node_features(image, segments, node_map):
    """Compute one ResNet-34 embedding per superpixel.

    For each region the image is masked so that only that region's pixels
    remain, resized to 224x224, normalised and passed through a ResNet-34
    with its final fully connected layer removed.

    Args:
        image: RGB image as a uint8 array (it is masked with
            ``cv2.bitwise_and`` and converted through ``ToPILImage``).
        segments: 2-D label array aligned with ``image``.
        node_map: Mapping from segment label to node index.

    Returns:
        Dict mapping node index -> 1-D NumPy feature vector.
    """
    backbone, preprocess = _get_node_feature_extractor()

    node_features = {}
    # The backbone is a frozen feature extractor: skip autograd bookkeeping.
    with torch.no_grad():
        for region, idx in node_map.items():
            mask = (segments == region).astype(np.uint8)
            masked_image = cv2.bitwise_and(image, image, mask=mask)
            embedding = backbone(preprocess(masked_image).unsqueeze(0))
            node_features[idx] = embedding.flatten().numpy()
    return node_features

# Dataset Class
class BreastCancerDataset(torch.utils.data.Dataset):
    """Dataset that turns each labelled image into a superpixel graph.

    Each item is built on the fly: the image is segmented with SLIC, a
    region adjacency graph is built over the superpixels, and every
    superpixel receives a CNN feature vector as its node feature.
    """

    def __init__(self, image_paths):
        """Store the list of ``(image_path, label)`` pairs to serve."""
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it as a ``Data`` graph.

        Returns:
            ``Data`` with ``x`` (one feature row per node, ordered by node
            index), ``edge_index`` (both directions of every adjacency) and
            ``y`` (the integer class label).
        """
        image_path, label = self.image_paths[idx]
        # Close the file handle as soon as the pixels are in memory.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))

        segments = slic(image, n_segments=100, compactness=10)
        tissue_graph, node_map = build_tissue_graph(image, segments)
        node_features = extract_node_features(image, segments, node_map)

        # Stack first: building a tensor from a list of arrays is slow.
        feature_rows = np.stack([node_features[i] for i in sorted(node_features)])
        x = torch.as_tensor(feature_rows, dtype=torch.float)

        # The graph's nodes are already node indices. Looking them up in
        # node_map (keyed by segment label) again would drop or shift edges
        # whenever the segment labels do not start at 0.
        edge_index = self._build_edge_index(tissue_graph.edges())

        return Data(x=x, edge_index=edge_index, y=torch.tensor(label))

    @staticmethod
    def _build_edge_index(edges):
        """Convert undirected ``(u, v)`` pairs to a ``[2, 2 * E]`` index tensor.

        Each pair is emitted in both directions so that message passing
        treats the adjacency as undirected.
        """
        edges = list(edges)
        if not edges:
            return torch.empty((2, 0), dtype=torch.long)
        forward = torch.tensor(edges, dtype=torch.long).t().contiguous()
        return torch.cat([forward, forward.flip(0)], dim=1)

    @staticmethod
    def collate_fn(batch):
        """Identity collate: return the list of samples unchanged.

        NOTE(review): the training and evaluation loops call
        ``batch.to(device)`` and read ``batch.x``, which a plain list does
        not support, so this is only usable with a loader whose consumer
        expects a list - confirm before relying on it.
        """
        return batch

# Graph Neural Network Model
class CGTModel(torch.nn.Module):
    """Graph classifier: two GCN layers, mean pooling, then a linear head."""

    def __init__(self, in_features=512, hidden_dim=256, num_classes=9):
        """Create the layers.

        Args:
            in_features: Size of each input node feature vector.
            hidden_dim: Width of both graph convolution layers.
            num_classes: Number of output logits.
        """
        super().__init__()
        gnn = torch_geometric.nn
        self.conv1 = gnn.GCNConv(in_features, hidden_dim)
        self.conv2 = gnn.GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return one row of class logits per graph.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor passed to both GCN layers.
            batch: Node-to-graph assignment used by the mean pooling.
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        hidden = torch.relu(self.conv2(hidden, edge_index))
        pooled = torch_geometric.nn.global_mean_pool(hidden, batch)
        return self.fc(pooled)

# Train Model
def train_model(train_loader, save_dir, num_epochs=10):
    """Train a fresh ``CGTModel`` and checkpoint it after every epoch.

    Args:
        train_loader: Loader yielding batches that expose ``x``,
            ``edge_index``, ``batch`` and ``y`` and support ``.to(device)``.
        save_dir: Directory for the ``model_epoch_<n>.pth`` state dicts;
            created if it does not exist.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The trained model, left on the device it was trained on.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail before writing anything; otherwise the average-loss division
        # below would raise ZeroDivisionError after a checkpoint is saved.
        raise ValueError("train_loader yields no batches; nothing to train on")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CGTModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()
            outputs = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(outputs, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        torch.save(model.state_dict(), os.path.join(save_dir, f"model_epoch_{epoch+1}.pth"))
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/num_batches}")

    return model

# Load Dataset and Split into Train/Test
dataset_path = r"E:\generate_images"
image_paths, class_to_label, label_to_class = load_dataset_paths(dataset_path)

train_paths, test_paths = train_test_split(image_paths, test_size=0.2, random_state=42)

train_dataset = BreastCancerDataset(train_paths)
test_dataset = BreastCancerDataset(test_paths)

# No collate_fn is passed: the training and evaluation loops call
# batch.to(device) and read batch.x / batch.batch, so they need the loader's
# own graph batching rather than an identity collate that returns a list.
# The training loader reshuffles every epoch; the test loader keeps a fixed
# order so evaluation output is stable.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4)

# Train the Model
# NOTE(review): train_model builds CGTModel() with its default num_classes;
# confirm that it matches len(class_to_label) for this dataset.
train_model(train_loader, save_dir="E:/gan/checkpoints")

In [None]:
# Load Trained Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CGTModel().to(device)

# Pick the newest checkpoint instead of hard-coding an epoch number: the
# training cell writes model_epoch_<n>.pth for n = 1..num_epochs, so a fixed
# "epoch 20" file does not exist when training ran for fewer epochs.
checkpoint_dir = "E:/gan/checkpoints"
checkpoint_prefix, checkpoint_suffix = "model_epoch_", ".pth"
checkpoints_by_epoch = {
    int(name[len(checkpoint_prefix):-len(checkpoint_suffix)]): name
    for name in os.listdir(checkpoint_dir)
    if name.startswith(checkpoint_prefix)
    and name.endswith(checkpoint_suffix)
    and name[len(checkpoint_prefix):-len(checkpoint_suffix)].isdigit()
}
if not checkpoints_by_epoch:
    raise FileNotFoundError(f"No model_epoch_*.pth checkpoint found in {checkpoint_dir}")
checkpoint_path = os.path.join(checkpoint_dir, checkpoints_by_epoch[max(checkpoints_by_epoch)])

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Classify Test Samples & Calculate Accuracy
def evaluate_model(model, test_loader, label_to_class):
    """Print a prediction for every test sample and report the accuracy.

    Args:
        model: Trained model called as ``model(x, edge_index, batch)``.
        test_loader: Loader yielding batches that expose ``x``,
            ``edge_index``, ``batch`` and ``y`` and support ``.to(device)``.
        label_to_class: Mapping from integer label to class name, used for
            the printed report.

    Returns:
        The accuracy over all samples as a fraction in ``[0, 1]``.
    """
    # Run on whatever device the model lives on rather than relying on a
    # notebook-level ``device`` global being defined.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Make sure inference-mode layer behaviour is active even if the caller
    # forgot to switch the model out of training mode.
    model.eval()

    true_labels = []
    predicted_labels = []

    print("\n Predictions:")
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(model_device)
            outputs = model(batch.x, batch.edge_index, batch.batch)
            predicted = torch.argmax(outputs, dim=1).cpu().numpy()
            true = batch.y.cpu().numpy()

            true_labels.extend(true)
            predicted_labels.extend(predicted)

            for pred, actual in zip(predicted, true):
                print(f"Predicted: {label_to_class[pred]} | Actual: {label_to_class[actual]}")

    accuracy = accuracy_score(true_labels, predicted_labels)
    print(f"\nAccuracy: {accuracy * 100:.2f}%")
    return accuracy

# Run Evaluation: score the loaded model on the held-out test split
evaluate_model(
    model=model,
    test_loader=test_loader,
    label_to_class=label_to_class,
)