In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# Report how many CUDA devices this PyTorch build can see.
# Related checks that can be run by hand: torch.cuda.is_available() and
# torch.cuda.get_device_name(torch.cuda.current_device()).
visible_gpu_count = torch.cuda.device_count()
print(visible_gpu_count)


1


In [9]:
# Print the PyTorch version string followed by torch.version.cuda, one per line.
for build_detail in (torch.__version__, torch.version.cuda):
    print(build_detail)


2.4.1+cu121
12.1


In [16]:


# Step 1: Generate a 1D Lattice Graph with 3 neighbors on each side
def create_1d_lattice_graph(lattice_size=149, neighbors=3):
    """Build the edge list of a ring lattice with self-loops.

    Every node ``i`` is connected to itself and to the ``neighbors`` nearest
    nodes on each side, with indices wrapping around modulo ``lattice_size``.

    Args:
        lattice_size: Number of nodes on the ring.
        neighbors: How many nodes to connect on each side of a node.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, num_edges)``. Row 0
        holds node ``i`` and row 1 holds the node it is connected to. An empty
        lattice yields a tensor of shape ``(2, 0)``.
    """
    edges = []
    for i in range(lattice_size):
        targets = [i]  # self loop first
        for j in range(1, neighbors + 1):
            targets.append((i - j) % lattice_size)  # left neighbour
            targets.append((i + j) % lattice_size)  # right neighbour
        # On small rings (2 * neighbors >= lattice_size) the wrap-around makes
        # left/right targets coincide; dict.fromkeys drops those repeats while
        # keeping the original emission order, so no edge is listed twice.
        edges.extend([i, target] for target in dict.fromkeys(targets))

    if not edges:
        # torch.tensor([]) would be 1-D; keep the (2, num_edges) contract.
        return torch.empty((2, 0), dtype=torch.long)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index

# Step 2: Define the GNN Model
class GCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers with a ReLU in between.

    Args:
        input_dim: Feature width given to the first ``GCNConv``.
        hidden_dim: Output width of the first layer / input width of the second.
        output_dim: Output width of the second ``GCNConv``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the raw conv2 output."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

def generate_data(lattice_size=149, sample_size=100000, prob=0.5, neighbors=3):
    """Generate random binary lattice states labelled by majority density.

    Args:
        lattice_size: Number of cells (graph nodes) per sample.
        sample_size: Number of samples to generate.
        prob: Probability that a cell is 1. Must lie in ``[0, 1]``.
        neighbors: Number of lattice neighbours on each side, forwarded to
            ``create_1d_lattice_graph``.

    Returns:
        A list of ``sample_size`` ``Data`` objects. Each has ``x`` of shape
        ``(lattice_size, 1)`` with float 0/1 entries, the lattice
        ``edge_index``, and ``y`` of shape ``(lattice_size,)`` holding the
        sample's class repeated once per cell.

    Raises:
        ValueError: If ``prob`` is outside ``[0, 1]``.
    """
    if prob < 0 or prob > 1:
        raise ValueError('Probability must be between 0 and 1.')

    # Built once; every Data object below references this same tensor.
    edge_index = create_1d_lattice_graph(lattice_size=lattice_size, neighbors=neighbors)
    data_list = []

    for _ in range(sample_size):
        if prob == 0.5:
            # Fair coin per cell
            x = torch.randint(0, 2, (lattice_size, 1), dtype=torch.float)
        elif prob == 0.0:
            # Generate all zeros
            x = torch.zeros(lattice_size, 1)
        elif prob == 1.0:
            # Generate all ones
            x = torch.ones(lattice_size, 1)
        else:
            x = torch.bernoulli(prob * torch.ones(lattice_size, 1))

        # Label is 1 only when the density is strictly above 0.5; a density of
        # exactly 0.5 (possible for even lattice sizes) is labelled 0. The label
        # is repeated so every cell carries the sample's class.
        y = (x.mean() > 0.5).long().repeat(lattice_size)

        # Create a Data object and append to the list
        data_list.append(Data(x=x, edge_index=edge_index, y=y))

    return data_list

# Step 3: Create the data and model
# Seed PyTorch's global RNG first so that data generation, weight
# initialisation and DataLoader shuffling repeat on a fresh kernel run.
torch.manual_seed(0)

lattice_size = 149
prob = 0.49
train_data_list = generate_data(lattice_size=lattice_size, sample_size=20000, prob=prob)
test_data_list = generate_data(lattice_size=lattice_size, sample_size=20000, prob=prob)

# NOTE(review): batch_size=4098 may have been meant as 4096 -- confirm; value left unchanged.
train_loader = DataLoader(train_data_list, batch_size=4098, shuffle=True)
test_loader = DataLoader(test_data_list, batch_size=4098, shuffle=False)

# Define model, optimizer, and loss function
model = GCN(input_dim=1, hidden_dim=32, output_dim=2)  # 2 for binary classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.014)
criterion = torch.nn.CrossEntropyLoss()

# Step 4: Training Loop
def train(loader, model, optimizer, criterion, epochs=400, patience=40):
    """Train ``model`` on ``loader`` with moving-average early stopping.

    After each epoch the mean batch loss is printed and recorded. Once
    ``patience`` epochs have been seen, the mean loss over the most recent
    ``patience`` epochs is compared with the same mean one epoch earlier;
    training stops as soon as that moving average goes up.

    Args:
        loader: Sized iterable of batches exposing ``x``, ``edge_index`` and ``y``.
        model: Module called as ``model(x, edge_index)``.
        optimizer: Optimizer over the model's parameters.
        criterion: Callable ``criterion(out, y)`` returning a scalar loss tensor.
        epochs: Maximum number of epochs.
        patience: Number of epochs in the moving-average window (minimum 1).

    Returns:
        List with the mean batch loss of every epoch that was run.

    Raises:
        ValueError: If ``loader`` contains no batches.
    """
    from collections import deque

    num_batches = len(loader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError on the first epoch.
        raise ValueError('loader contains no batches; nothing to train on.')

    model.train()
    # A bounded deque keeps exactly the last `patience` epoch losses.
    recent_losses = deque(maxlen=max(int(patience), 1))
    previous_window_mean = None  # only the previous moving average is needed
    loss_history = []
    for epoch in range(epochs):
        total_loss = 0.0
        for data in loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)  # Forward pass
            loss = criterion(out, data.y)         # Compute the loss
            loss.backward()                       # Backpropagation
            optimizer.step()                      # Update the parameters
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}')
        loss_history.append(avg_loss)

        # Early stopping if the moving average over the last `patience` epochs rises
        recent_losses.append(avg_loss)
        if len(recent_losses) == recent_losses.maxlen:
            window_mean = sum(recent_losses) / len(recent_losses)
            if previous_window_mean is not None and previous_window_mean < window_mean:
                print(f'Early stopping at epoch {epoch+1}')
                break
            previous_window_mean = window_mean

    return loss_history
            

# Step 5: Testing Loop
def test(loader, model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
    accuracy = correct / (len(loader.dataset) * lattice_size)
    print(f'Accuracy: {accuracy:.4f}')

# Train the model
# train(train_loader, model, optimizer, criterion, epochs=400)



In [5]:
# import matplotlib.pyplot as plt

# probabilities = [i * 0.01 for i in range(101)]
# accuracies = []

# for prob in probabilities:
#     sample_test_data = generate_data(lattice_size=lattice_size, sample_size=10000, prob=prob)
#     sample_test_loader = DataLoader(sample_test_data, batch_size=4098, shuffle=False)
#     correct = 0
#     model.eval()
#     for data in sample_test_loader:
#         out = model(data.x, data.edge_index)
#         pred = out.argmax(dim=1)
#         correct += (pred == data.y).sum().item()
#     accuracy = correct / (len(sample_test_loader.dataset) * lattice_size)
#     accuracies.append(accuracy)


In [13]:

# plt.plot(probabilities, accuracies)
# plt.xlabel('Probability')
# plt.ylabel('Accuracy')
# plt.title('Accuracy vs Probability')
# plt.grid(True)
# plt.show()

[0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99, 1.0]
[0.   0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1  0.11 0.12 0.13
 0.14 0.15 0.16 0.17 0.18 0.19 0.2  0.21 0.22 0.23 0.24 0.25 0.26 0.27
 0.28 0.29 0.3  0.31 0.32 0.33 0.34 0.35 0.36 0.37 0.38 0.39 0.4  0.41
 0.42 0.43 0.44 0.45 0.46 0.47 0.48 0.49 0.5  0

In [17]:
# Sweep: train a fresh model at each training density, then score it on every test density.
# training_probabilities = [0.4, 0.41 , 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6]
training_probabilities = np.linspace(0, 1, 101)
probabilities = np.linspace(0, 1, 101)
# accuracies[training_prob] -> list of accuracies, aligned index-for-index with `probabilities`.
accuracies = {}

for prob in training_probabilities:
    train_data_list = generate_data(lattice_size=lattice_size, sample_size=20000, prob=float(prob))
    train_loader = DataLoader(train_data_list, batch_size=4098, shuffle=True)
    model = GCN(input_dim=1, hidden_dim=32, output_dim=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.014)
    criterion = torch.nn.CrossEntropyLoss()
    train(train_loader, model, optimizer, criterion, epochs=400)

    model.eval()
    per_test_accuracy = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for probability in probabilities:
            sample_test_data = generate_data(lattice_size=lattice_size, sample_size=10000, prob=float(probability))
            sample_test_loader = DataLoader(sample_test_data, batch_size=4098, shuffle=False)
            correct = 0
            total = 0
            for data in sample_test_loader:
                out = model(data.x, data.edge_index)
                pred = out.argmax(dim=1)
                correct += (pred == data.y).sum().item()
                total += data.y.numel()
            per_test_accuracy.append(correct / total)
    # Store the whole curve once per training density; assigning a scalar inside
    # the inner loop would keep only the last test probability's accuracy.
    accuracies[prob] = per_test_accuracy


Epoch 1/400, Loss: 0.6657489180564881
Epoch 2/400, Loss: 0.6003825545310975
Epoch 3/400, Loss: 0.5404214859008789
Epoch 4/400, Loss: 0.4860027253627777
Epoch 5/400, Loss: 0.43708619475364685
Epoch 6/400, Loss: 0.3934718072414398
Epoch 7/400, Loss: 0.35483341813087466
Epoch 8/400, Loss: 0.32076109051704405
Epoch 9/400, Loss: 0.2908013999462128
Epoch 10/400, Loss: 0.26449201703071595
Epoch 11/400, Loss: 0.2413867622613907
Epoch 12/400, Loss: 0.22107152342796327
Epoch 13/400, Loss: 0.2031730592250824
Epoch 14/400, Loss: 0.18736161291599274
Epoch 15/400, Loss: 0.17335015237331391
Epoch 16/400, Loss: 0.1608915477991104
Epoch 17/400, Loss: 0.14977436661720275
Epoch 18/400, Loss: 0.13981837928295135
Epoch 19/400, Loss: 0.13087020218372344
Epoch 20/400, Loss: 0.122799614071846
Epoch 21/400, Loss: 0.11549558341503144
Epoch 22/400, Loss: 0.10886370837688446
Epoch 23/400, Loss: 0.10282309800386429
Epoch 24/400, Loss: 0.09730463176965713
Epoch 25/400, Loss: 0.09224874526262283
Epoch 26/400, Loss: 

In [13]:

# One line per training probability. The entry is looked up with the training
# probability itself; indexing by the test probabilities would draw the same
# line for every label.
fig, ax = plt.subplots()
for prob in training_probabilities:
    # accuracies[prob] may be a sequence aligned with `probabilities` or a single
    # number; broadcasting draws a curve in the first case and a flat line in the second.
    curve = np.broadcast_to(np.asarray(accuracies[prob], dtype=float), np.shape(probabilities))
    ax.plot(probabilities, curve, label=f'p={prob}')
ax.set_xlabel('Probability')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy vs Probability')
ax.legend()
ax.grid(True)
plt.show()

In [14]:
# sample = generate_data(lattice_size=149, sample_size=1)
# sample[0].y