In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv


In [48]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

# Step 1: Generate a 1D Lattice Graph with 3 neighbors on each side
def create_1d_lattice_graph(lattice_size=149, neighbors=3):
    edges = []
    for i in range(lattice_size):
        edges.append([i,i]) # self loop
        for j in range(1, neighbors + 1):
            edges.append([i, (i - j) % lattice_size])  # Connect to left neighbors
            edges.append([i, (i + j) % lattice_size])  # Connect to right neighbors

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index

# Step 2: Define the GNN Model
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        # First GCN layer + ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Second GCN layer
        x = self.conv2(x, edge_index)
        return x

def generate_data(lattice_size=149, sample_size = 100000):
    edge_index = create_1d_lattice_graph(lattice_size=lattice_size)
    
    # Generate random input states for each node
    x = torch.randint(0, 2, (lattice_size, 1), dtype=torch.float)
    
    # Define the binary density classification based on the mean of cell states
    y = (x.mean() > 0.5).long().repeat(lattice_size)  # 0 for < 0.5 density, 1 for >= 0.5, repeated to match lattice size
    
    # Return graph data
    return Data(x=x, edge_index=edge_index, y=y)

# Step 3: Create the data and model
lattice_size = 149
edge_index = create_1d_lattice_graph(lattice_size=lattice_size)

train_data = generate_data()

# Define model, optimizer, and loss function
model = GCN(input_dim=1, hidden_dim=16, output_dim=2)  # 2 for binary classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Step 4: Training Loop
def train(data, model, optimizer, criterion, epochs=400):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)  # Forward pass
        loss = criterion(out, data.y)         # Compute the loss
        loss.backward()                       # Backpropagation
        optimizer.step()                      # Update the parameters

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

# Step 4: Testing Function
def test(model, test_data):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation for inference
        out = model(test_data.x, test_data.edge_index)  # Forward pass
        pred = out.argmax(dim=1)  # Get the predicted class (0 or 1)
        correct = (pred == test_data.y).sum().item()  # Count correct predictions
        accuracy = correct / len(test_data.y)  # Compute accuracy
        print(f'Accuracy: {accuracy * 100:.2f}%')
        
        return accuracy
    
# Step 5: Run training
train(train_data, model, optimizer, criterion, epochs=1000)


Epoch 0, Loss: 0.6945456862449646
Epoch 100, Loss: 0.00317948660813272
Epoch 200, Loss: 0.0010363091714680195
Epoch 300, Loss: 0.0005152528756298125
Epoch 400, Loss: 0.00030919379787519574
Epoch 500, Loss: 0.00020651282102335244
Epoch 600, Loss: 0.0001477911719121039
Epoch 700, Loss: 0.0001109632576117292
Epoch 800, Loss: 8.632027311250567e-05
Epoch 900, Loss: 6.898836727486923e-05


In [29]:

test_data = generate_data(lattice_size=149)
test_accuracy = test(model, test_data)

Accuracy: 100.00%


In [59]:
edge_index = create_1d_lattice_graph(lattice_size=lattice_size)
print(f"Length of index: {len(edge_index[0])}")
edge_index[0]

Length of index: 1043


tensor([  0,   0,   0,  ..., 148, 148, 148])

In [65]:
sample = generate_data(lattice_size=149)
sample.x

tensor([[0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
      