# **Useless Notebook - Only used for testing random bits of code**

In [1]:
import cv2
import numpy as np
from sklearn.feature_extraction.image import grid_to_graph
from skimage.segmentation import slic
from skimage.util import img_as_float
from scipy.sparse import coo_matrix

def image_to_graph(image_path):
    # Load and normalize the image
    image = cv2.imread(image_path)
    image = img_as_float(image)
    
    # Segment the image using SLIC
    segments = slic(image, n_segments=100, compactness=10, sigma=3)
    
    # Create a graph from the segmented image
    graph = grid_to_graph(*segments.shape)
    
    # Extract features for each segment (mean color here)
    features = []
    for segment_id in np.unique(segments):
        mask = segments == segment_id
        features.append(np.mean(image[mask], axis=0))
    features = np.array(features)
    
    # Adjust edge_index to ensure indices are within bounds
    max_node_idx = features.shape[0] - 1
    graph.row = np.clip(graph.row, 0, max_node_idx)
    graph.col = np.clip(graph.col, 0, max_node_idx)
    
    return graph, features


In [2]:
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import torch
import os

class HistopathologyDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.graphs = []
        self.labels = []

        # Assuming your directory structure is as described
        for label in ['0', '1']:
            label_dir = os.path.join(root_dir, label)
            for image_path in os.listdir(label_dir)[:1000]:
                graph, features = image_to_graph(os.path.join(label_dir, image_path))
                edge_index = torch.tensor(graph.nonzero(), dtype=torch.long)
                x = torch.tensor(features, dtype=torch.float)
                y = torch.tensor(int(label), dtype=torch.long)
                data = Data(x=x, edge_index=edge_index, y=y)
                self.graphs.append(data)
                self.labels.append(y)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

# Initialize the dataset
full_dataset = HistopathologyDataset('Root')

train_size = int(0.8 * len(full_dataset)) # Split data count into training and validation splits in the ratio 80% to 20%
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) # Split the data

batch_size = 32  #Set batch size
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 


  edge_index = torch.tensor(graph.nonzero(), dtype=torch.long)


In [3]:
import torch
from torch import nn
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, in_channels: int = 3, hidden_channels: int = 152, num_classes: int = 2):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels*4)
        self.conv2 = GCNConv(hidden_channels*4, hidden_channels*2)
        self.conv3 = GCNConv(hidden_channels*2, hidden_channels)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, 256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, num_classes),
        )

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # Pooling
        x = self.classifier(x)
        return x

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [5]:
for epoch in range(100):
    total_train_loss = 0
    total_val_loss = 0

    correct_train_predictions = 0
    correct_val_predictions = 0

    total_train_predictions = 0
    total_val_predictions = 0
    
    model.train()
    for batch in train_dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        _, preds = torch.max(out, dim=1)  # Get the predicted class labels
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        # Calculate accuracy
        correct_train_predictions += (preds == batch.y).sum().item()
        total_train_predictions += batch.y.size(0)

    model.eval()
    for batch in val_dataloader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        _, preds = torch.max(out, dim=1)  # Get the predicted class labels
        loss = criterion(out, batch.y)
        total_val_loss += loss.item()

        # Calculate accuracy
        correct_val_predictions += (preds == batch.y).sum().item()
        total_val_predictions += batch.y.size(0)
    
    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = correct_train_predictions / total_train_predictions
    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = correct_val_predictions / total_val_predictions
    print(f'Epoch {epoch}, Train Loss: {avg_train_loss}, Train Accuracy: {train_accuracy}, Val Loss: {avg_val_loss}, Val Accuracy: {val_accuracy}')


# Save the trained model
# torch.save(model.state_dict(), 'gcn_model.pth')
