In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from transformers import BertModel, BertConfig

In [6]:
# Define the Learnable Masking Layer
class LearnableMaskingLayer(nn.Module):
    def __init__(self, input_size):
        super(LearnableMaskingLayer, self).__init__()
        self.masking_layer = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        mask = self.masking_layer(x)
        return x * mask

# Define the Autoencoder with Learnable Masking
class MaskedAutoencoder(nn.Module):
    def __init__(self, input_size=1024, latent_size=512):
        super(MaskedAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, latent_size),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, input_size),
            nn.ReLU()
        )
        self.learnable_masking = LearnableMaskingLayer(input_size)

    def forward(self, x):
        x_masked = self.learnable_masking(x)
        encoded = self.encoder(x_masked)
        decoded = self.decoder(encoded)
        return decoded

In [7]:
# Example Usage
autoencoder = MaskedAutoencoder()
input_matrix = torch.rand(1, 1024)  # Example input
reconstructed = autoencoder(input_matrix)

# Loss Function
criterion = nn.MSELoss()
loss = criterion(reconstructed, input_matrix)  # Compare with the original matrix

# Print the loss
print(loss.item())


0.3056860864162445


In [8]:
class GraphMaskingModel(torch.nn.Module):
    def __init__(self, num_features, num_nodes):
        super(GraphMaskingModel, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, num_nodes)  # num_nodes is the output dimension

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(self.conv2(x, edge_index))  # Sigmoid for masking
        return x


In [None]:
# Example usage
num_nodes = 1024
num_features = 3 # Define the number of features per node
edges = 4 # Define the edges of your graph

# Create a PyG data object
data = Data(x=torch.randn(num_nodes, num_features), edge_index=edges)

model = GraphMaskingModel(num_features, num_nodes)
mask = model(data)

In [10]:
class ConnectivityMaskingTransformer(nn.Module):
    def __init__(self, num_features, seq_length):
        super(ConnectivityMaskingTransformer, self).__init__()
        config = BertConfig(hidden_size=num_features, num_attention_heads=12, num_hidden_layers=6)
        self.transformer = BertModel(config)
        self.linear = nn.Linear(num_features, seq_length)

    def forward(self, x):
        # Assuming x is of shape (batch_size, seq_length, num_features)
        transformer_output = self.transformer(inputs_embeds=x).last_hidden_state
        mask = torch.sigmoid(self.linear(transformer_output))
        return mask

In [12]:
# Example usage
seq_length = 1024
num_features = 12 # Define the number of features

input_seq = torch.randn(1, seq_length, num_features)  # Example input
model = ConnectivityMaskingTransformer(num_features, seq_length)
mask = model(input_seq)


RuntimeError: The expanded size of the tensor (1024) must match the existing size (512) at non-singleton dimension 1.  Target sizes: [1, 1024].  Tensor sizes: [1, 512]

In [22]:

# Example of converting a numpy array to a PyTorch tensor
# Replace this with your actual data loading and preprocessing
def load_and_preprocess_data():
    # Load your brain connectivity matrices here
    # For example, let's create a dummy tensor
    data = torch.rand((100, 100))  # 100 matrices of size 100x100
    return data

class BrainConnectivityTransformer(nn.Module):
    def __init__(self, matrix_size, num_layers, num_heads):
        super(BrainConnectivityTransformer, self).__init__()
        self.matrix_size = matrix_size
        self.embedding = nn.Linear(matrix_size, matrix_size)
        transformer_layer = nn.TransformerEncoderLayer(d_model=matrix_size, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(matrix_size, matrix_size)

    def forward(self, x):
        x = x.view(-1, self.matrix_size)  # Reshape to [batch_size*sequence_length, features]
        x = self.embedding(x)
        x = x.view(-1, 1, self.matrix_size)  # Reshape to [sequence_length, batch_size, features]
        x = self.transformer_encoder(x)
        x = x.view(-1, self.matrix_size)  # Flatten for the output layer
        x = self.output_layer(x)
        return x



In [23]:
def apply_mask(matrix, mask_percentage=0.1):
    # Assuming matrix is 2D (sequence length x features)
    # Create a mask with the same dimensions as the matrix
    mask = torch.rand(matrix.shape) < mask_percentage  # This creates a boolean mask

    # Apply the mask to the matrix
    masked_matrix = matrix.clone()  # Clone to avoid modifying the original matrix
    masked_matrix[mask] = 0  # Set masked elements to zero

    return masked_matrix

def train_model(model, data, num_epochs):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        for matrix in data:
            # Reshape matrix to [sequence length, 1, features]
            # Assuming each matrix is square with dimensions matrix_size x matrix_size
            matrix = matrix.view(-1, model.matrix_size)

            # Apply mask
            masked_matrix = apply_mask(matrix)

            optimizer.zero_grad()

            # Forward pass
            output = model(masked_matrix.view(-1, 1, model.matrix_size))
            loss = criterion(output, matrix.view(-1, model.matrix_size))  # Compare with the original matrix
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")


# Example usage
data = load_and_preprocess_data()
model = BrainConnectivityTransformer(matrix_size=100, num_layers=2, num_heads=4)
train_model(model, data, num_epochs=10)

Epoch 1, Loss: 0.11654290556907654
Epoch 2, Loss: 0.07916422188282013
Epoch 3, Loss: 0.08018319308757782
Epoch 4, Loss: 0.07989799976348877
Epoch 5, Loss: 0.0734124407172203
Epoch 6, Loss: 0.06462837755680084
Epoch 7, Loss: 0.07987625151872635
Epoch 8, Loss: 0.05594448000192642
Epoch 9, Loss: 0.06626898795366287
Epoch 10, Loss: 0.055277567356824875


In [24]:
def predict_mask(model, matrix, threshold=0.1):
    # Assuming matrix is already in the correct shape: [sequence length, 1, features]
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():  # No need to track gradients
        # Forward pass through the model
        output = model(matrix)

    # Calculate the difference between the output and the original matrix
    difference = torch.abs(output - matrix)

    # Determine the mask based on a threshold
    mask = difference > threshold

    return mask


In [28]:
# Example usage
new_matrix = torch.rand((100, 100))  # Replace with your new matrix
new_matrix = new_matrix.view(-1, 1, 100)  # Reshape to [sequence length, 1, features]

mask = predict_mask(model, new_matrix, threshold=0.4)
mask

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False,  True,  True],
         [False, False, False,  ..., False,  True,  True],
         [False,  True, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False,  True,  True],
         [False, False, False,  ..., False,  True,  True],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ...,  True, False,  True],
         [False, False, False,  ..., False, False,  True],
         [False, False, False,  ..., False, False, False],
         ...,
         [False,  True, False,  ...,  True, False,  True],
         [

In [None]:
import torch
from torch_geometric.data import Data

def matrix_to_graph(matrix):
    # Assuming `matrix` is your connectivity matrix

    # Create edge index and edge attributes
    edge_index = []
    edge_attr = []
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i, j] != 0:  # assuming a non-zero value indicates an edge
                edge_index.append([i, j])
                edge_attr.append(matrix[i, j])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    # Create a graph data object
    data = Data(edge_index=edge_index, edge_attr=edge_attr)

    return data


In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNNMaskingModel(torch.nn.Module):
    def __init__(self, num_features, hidden_dim):
        super(GNNMaskingModel, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, 1)  # Output one value per node

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Apply GNN layers
        x = F.relu(self.conv1(x, edge_index, edge_attr=edge_attr))
        x = self.conv2(x, edge_index, edge_attr=edge_attr)

        return x


In [None]:
def train_model(model, graph_data, num_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss()  # or any other suitable loss function

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(graph_data)
        loss = criterion(out, graph_data.y)  # Compare with target values
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')


In [None]:
def predict_mask(model, graph_data):
    model.eval()
    with torch.no_grad():
        out = model(graph_data)

    # Assuming that higher values indicate more importance
    mask = out > some_threshold  # Define some_threshold based on your requirements

    return mask
