In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv


In [15]:
torch.cuda.is_available()

False

In [16]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader

# Step 1: Generate a 1D Lattice Graph with 3 neighbors on each side
def create_1d_lattice_graph(lattice_size=149, neighbors=3):
    edges = []
    for i in range(lattice_size):
        edges.append([i, i])  # self loop
        for j in range(1, neighbors + 1):
            edges.append([i, (i - j) % lattice_size])  # Connect to left neighbors
            edges.append([i, (i + j) % lattice_size])  # Connect to right neighbors

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index

# Step 2: Define the GNN Model
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        # First GCN layer + ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Second GCN layer
        x = self.conv2(x, edge_index)
        return x

def generate_data(lattice_size=149, sample_size=100000):
    edge_index = create_1d_lattice_graph(lattice_size=lattice_size)
    data_list = []
    
    for _ in range(sample_size):
        # Generate random input states for each node
        x = torch.randint(0, 2, (lattice_size, 1), dtype=torch.float)
        
        # Define the binary density classification based on the mean of cell states
        y = (x.mean() > 0.5).long().repeat(lattice_size)  # 0 for < 0.5 density, 1 for >= 0.5, repeated to match lattice size
        
        # Create a Data object and append to the list
        data_list.append(Data(x=x, edge_index=edge_index, y=y))
    
    return data_list

# Step 3: Create the data and model
lattice_size = 149
train_data_list = generate_data(lattice_size=lattice_size, sample_size=100000)
test_data_list = generate_data(lattice_size=lattice_size, sample_size=100000)

train_loader = DataLoader(train_data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data_list, batch_size=32, shuffle=False)

# Define model, optimizer, and loss function
model = GCN(input_dim=1, hidden_dim=16, output_dim=2)  # 2 for binary classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Step 4: Training Loop
def train(loader, model, optimizer, criterion, epochs=400):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)  # Forward pass
            loss = criterion(out, data.y)         # Compute the loss
            loss.backward()                       # Backpropagation
            optimizer.step()                      # Update the parameters
            total_loss += loss.item()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(loader)}')

# Step 5: Testing Loop
def test(loader, model):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
    accuracy = correct / (len(loader.dataset) * lattice_size)
    print(f'Accuracy: {accuracy:.4f}')

# Train the model
train(train_loader, model, optimizer, criterion, epochs=10)

# Test the model
test(test_loader, model)



Epoch 1/10, Loss: 0.6733839413261413


In [29]:

test_data = generate_data(lattice_size=149)
test_accuracy = test(model, test_data)

Accuracy: 100.00%


In [59]:
edge_index = create_1d_lattice_graph(lattice_size=lattice_size)
print(f"Length of index: {len(edge_index[0])}")
edge_index[0]

Length of index: 1043


tensor([  0,   0,   0,  ..., 148, 148, 148])

In [65]:
sample = generate_data(lattice_size=149)
sample.x

tensor([[0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
      