In [1]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import os
from tqdm import tqdm

In [8]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import os
from tqdm import tqdm

# Função para carregar os grafos salvos
def carregar_grafos_visibilidade(file_path):
    data = torch.load(file_path)
    grafos = data['grafos']
    dataset = []
    for grafo in grafos:
        src = grafo['src']
        dst = grafo['dst']
        edge_index = torch.stack([src, dst], dim=0)  # Constrói o edge_index
        num_nodes = max(torch.max(src), torch.max(dst)) + 1  # Calcula o número de nós
        x = torch.rand((num_nodes, 1))  # Inicializa features dos nós aleatoriamente
        dataset.append(Data(x=x, edge_index=edge_index))
    return dataset

# Modelo GAT-AE
class GATAutoencoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, heads=2):
        super(GATAutoencoder, self).__init__()
        self.heads = heads
        self.out_channels = out_channels
        self.encoder = GATConv(in_channels, out_channels, heads=heads, concat=True)
        self.decoder = torch.nn.Linear(out_channels * heads * 2, 1)  # Ajuste para concatenação de embeddings

    def forward(self, x, edge_index):
        z = self.encoder(x, edge_index)
        z = F.relu(z)
        row, col = edge_index
        edge_features = torch.cat([z[row], z[col]], dim=1)  # Concatena embeddings das arestas
        adj_reconstructed = self.decoder(edge_features).squeeze()
        return z, adj_reconstructed

# Função de perda
def loss_function(reconstructed, edge_index, num_nodes):
    true_adj = torch.zeros((num_nodes, num_nodes), device=reconstructed.device)
    true_adj[edge_index[0], edge_index[1]] = 1
    pred_adj = torch.sigmoid(reconstructed)
    return F.binary_cross_entropy(pred_adj, true_adj[edge_index[0], edge_index[1]])

# Métricas de avaliação
def calcular_metricas(reconstructed, edge_index, embeddings):
    # Arestas verdadeiras reconstruídas corretamente
    true_positive = (reconstructed > 0.5).sum().item()  # Soma os valores reconstruídos como positivos
    total_edges = edge_index.size(1)  # Total de arestas
    precision = true_positive / total_edges  # Precisão
    # Dispersão das features aprendidas (variância média)
    feature_variance = embeddings.var(dim=0).mean().item()
    return precision, feature_variance, total_edges, (reconstructed > 0.5).sum().item()


# Função para treinar o modelo com métricas
def treinar_gat_ae(train_loader, val_loader, in_channels=1, out_channels=4, epochs=10, batch_size=32, lr=0.01, save_path="model/gat_ae_model.pth"):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GATAutoencoder(in_channels=in_channels, out_channels=out_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        precisions = []
        variances = []
        for data in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            data = data.to(device)
            optimizer.zero_grad()
            z, reconstructed_adj = model(data.x, data.edge_index)
            loss = loss_function(reconstructed_adj, data.edge_index, data.num_nodes)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
            # Calcular métricas no conjunto de treino
            precision, feature_variance, _, _ = calcular_metricas(reconstructed_adj, data.edge_index, z)
            precisions.append(precision)
            variances.append(feature_variance)

        # Agregar métricas por época
        avg_precision = sum(precisions) / len(precisions)
        avg_variance = sum(variances) / len(variances)
        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {total_loss:.4f}, Precision: {avg_precision:.4f}, Variance: {avg_variance:.4f}")

        # Avaliação no conjunto de validação
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_precisions = []
            val_variances = []
            for data in val_loader:
                data = data.to(device)
                z, reconstructed_adj = model(data.x, data.edge_index)
                loss = loss_function(reconstructed_adj, data.edge_index, data.num_nodes)
                val_loss += loss.item()
                precision, feature_variance, _, _ = calcular_metricas(reconstructed_adj, data.edge_index, z)
                val_precisions.append(precision)
                val_variances.append(feature_variance)

            avg_val_precision = sum(val_precisions) / len(val_precisions)
            avg_val_variance = sum(val_variances) / len(val_variances)
            print(f"Epoch {epoch + 1}/{epochs} - Val Loss: {val_loss:.4f}, Precision: {avg_val_precision:.4f}, Variance: {avg_val_variance:.4f}")
        model.train()

    # Salvar o modelo treinado
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Modelo salvo em: {save_path}")
    return model

# Função para carregar o modelo salvo
def carregar_modelo(path, in_channels, out_channels):
    model = GATAutoencoder(in_channels, out_channels)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

# Exemplo de uso
if __name__ == "__main__":
    # Caminho do arquivo de grafos gerados
    file_path = '/scratch/arturxavier/Clustering-Paper/Grafo/af.pt'

    # Carregar os grafos de visibilidade
    print("Carregando grafos de visibilidade...")
    dataset = carregar_grafos_visibilidade(file_path)
    print(f"Total de grafos carregados: {len(dataset)}")

    # Divisão dos dados em treino e validação
    train_dataset = dataset[:int(len(dataset) * 0.8)]
    val_dataset = dataset[int(len(dataset) * 0.8):]

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Treinar o modelo
    print("Treinando o modelo GAT-AE...")
    modelo_treinado = treinar_gat_ae(train_loader, val_loader, in_channels=1, out_channels=4, epochs=10, batch_size=32, lr=0.01)

    # Carregar o modelo salvo
    modelo_salvo = carregar_modelo("model/gat_ae_model.pth", in_channels=1, out_channels=4)
    print("Modelo carregado com sucesso.")


Carregando grafos de visibilidade...


  data = torch.load(file_path)


Total de grafos carregados: 9886
Treinando o modelo GAT-AE...


Epoch 1/10: 100%|██████████| 248/248 [00:02<00:00, 84.23it/s]


Epoch 1/10 - Train Loss: 17.7718, Precision: 0.9444, Variance: 0.0031
Epoch 1/10 - Val Loss: 0.0800, Precision: 1.0000, Variance: 0.0052


Epoch 2/10: 100%|██████████| 248/248 [00:02<00:00, 85.26it/s]


Epoch 2/10 - Train Loss: 0.1717, Precision: 1.0000, Variance: 0.0059
Epoch 2/10 - Val Loss: 0.0234, Precision: 1.0000, Variance: 0.0066


Epoch 3/10: 100%|██████████| 248/248 [00:02<00:00, 84.95it/s]


Epoch 3/10 - Train Loss: 0.0649, Precision: 1.0000, Variance: 0.0069
Epoch 3/10 - Val Loss: 0.0113, Precision: 1.0000, Variance: 0.0074


Epoch 4/10: 100%|██████████| 248/248 [00:02<00:00, 86.02it/s]


Epoch 4/10 - Train Loss: 0.0347, Precision: 1.0000, Variance: 0.0076
Epoch 4/10 - Val Loss: 0.0067, Precision: 1.0000, Variance: 0.0080


Epoch 5/10: 100%|██████████| 248/248 [00:02<00:00, 86.98it/s]


Epoch 5/10 - Train Loss: 0.0216, Precision: 1.0000, Variance: 0.0081
Epoch 5/10 - Val Loss: 0.0044, Precision: 1.0000, Variance: 0.0085


Epoch 6/10: 100%|██████████| 248/248 [00:02<00:00, 86.85it/s]


Epoch 6/10 - Train Loss: 0.0147, Precision: 1.0000, Variance: 0.0086
Epoch 6/10 - Val Loss: 0.0031, Precision: 1.0000, Variance: 0.0089


Epoch 7/10: 100%|██████████| 248/248 [00:02<00:00, 86.88it/s]


Epoch 7/10 - Train Loss: 0.0106, Precision: 1.0000, Variance: 0.0090
Epoch 7/10 - Val Loss: 0.0023, Precision: 1.0000, Variance: 0.0093


Epoch 8/10: 100%|██████████| 248/248 [00:02<00:00, 87.18it/s]


Epoch 8/10 - Train Loss: 0.0080, Precision: 1.0000, Variance: 0.0093
Epoch 8/10 - Val Loss: 0.0017, Precision: 1.0000, Variance: 0.0096


Epoch 9/10: 100%|██████████| 248/248 [00:02<00:00, 83.65it/s]


Epoch 9/10 - Train Loss: 0.0061, Precision: 1.0000, Variance: 0.0096
Epoch 9/10 - Val Loss: 0.0014, Precision: 1.0000, Variance: 0.0099


Epoch 10/10: 100%|██████████| 248/248 [00:02<00:00, 86.09it/s]


Epoch 10/10 - Train Loss: 0.0048, Precision: 1.0000, Variance: 0.0099
Epoch 10/10 - Val Loss: 0.0011, Precision: 1.0000, Variance: 0.0101
Modelo salvo em: model/gat_ae_model.pth
Modelo carregado com sucesso.


  model.load_state_dict(torch.load(path))
