In [2]:
import torch

In [None]:
import random   
import numpy as np
from collections import defaultdict


class mini_batches_fast:
    def __init__(self, data, unique_list, sample_size, edge_type, full_data, citation_dict, all_papers):
        self.data = data
        self.sample_size = sample_size
        self.edge_type = edge_type
        self.unique_list = unique_list
        self.full_data = full_data
        self.device = self.data.device if isinstance(self.data, torch.Tensor) else torch.device("cpu")
        self.citation_dict = citation_dict
        self.all_papers = all_papers
        self.remaining_papers = set(all_papers)
        self.set_unique_list(unique_list)  # initialize tensor once

    
    def set_unique_list(self, unique_list):
        if isinstance(unique_list, torch.Tensor):
            self.unique_tensor = unique_list.to(self.device)
        else:
            self.unique_tensor = torch.tensor(unique_list, dtype=torch.long, device=self.device)

    def get_batch(self):
        unique_tensor = self.unique_tensor

        if len(unique_tensor) < self.sample_size:
            sample_tensor = unique_tensor
            sample_tensor_sorted, _ = sample_tensor.sort()
            idx = torch.searchsorted(sample_tensor_sorted, self.data[0])
            idx = idx.clamp(max=sample_tensor_sorted.size(0) - 1)
            mask = sample_tensor_sorted[idx] == self.data[0]
            filtered_data = self.data[:, mask]
            return filtered_data, sample_tensor.tolist(), []

        # Efficient random sample and sort once
        rand_idx = torch.randperm(len(unique_tensor), device=self.device)[:self.sample_size]
        sample_tensor = unique_tensor[rand_idx]
        sample_tensor_sorted, _ = sample_tensor.sort()

        # Compute membership using searchsorted
        idx = torch.searchsorted(sample_tensor_sorted, unique_tensor)
        idx = idx.clamp(max=sample_tensor_sorted.size(0) - 1)
        isin_mask = sample_tensor_sorted[idx] == unique_tensor
        remaining_tensor = unique_tensor[~isin_mask]

        # Filter edges
        idx = torch.searchsorted(sample_tensor_sorted, self.data[0])
        idx = idx.clamp(max=sample_tensor_sorted.size(0) - 1)
        mask = sample_tensor_sorted[idx] == self.data[0]
        filtered_data = self.data[:, mask]

        return filtered_data, sample_tensor.tolist(), remaining_tensor
    
    def data_matrix(self):
        data = self.full_data
        edge_entities = {
            'paper': 0,
            'author': 1,
            'institution': 2,
            'field_of_study': 3,
            'venue': 4,
        }

        tensor, random_sample, unique_tensor = self.get_batch()

        if tensor.shape[1] == 0:
            result_tensor = torch.empty((0, 5), dtype=torch.long, device=self.device)
        else:
            edge_type1 = edge_entities[self.edge_type[0]]
            edge_type2 = edge_entities[self.edge_type[2]]
            ones = torch.ones(tensor.shape[1], device=self.device, dtype=torch.long)
            result_tensor = torch.stack([
                ones,
                tensor[0, :],
                tensor[1, :],
                torch.full((tensor.shape[1],), edge_type1, device=self.device),
                torch.full((tensor.shape[1],), edge_type2, device=self.device)
            ], dim=1).long()

        paper_venues = data['y_dict']['paper']
        random_sample_tensor = torch.tensor(random_sample, device=self.device).long()
        venue_targets = paper_venues[random_sample_tensor]

        venues_tensor = torch.stack([
            torch.ones(len(random_sample), device=self.device, dtype=torch.long),
            random_sample_tensor,
            venue_targets.flatten(),
            torch.full((len(random_sample),), edge_type1, device=self.device),
            torch.full((len(random_sample),), edge_entities['venue'], device=self.device)
        ], dim=1).long()

        if tensor.shape[1] > 0:
            unique_targets = tensor[1].unique()
            i_grid, j_grid = torch.meshgrid(random_sample_tensor, unique_targets, indexing='ij')
            i_vals = i_grid.flatten()
            j_vals = j_grid.flatten()

            existing_edges = result_tensor[:, 1:3]
            max_node_id = max(i_vals.max().item(), j_vals.max().item(), existing_edges.max().item()) + 1

            packed_existing = (existing_edges[:, 0] * max_node_id + existing_edges[:, 1])
            packed_pairs = i_vals * max_node_id + j_vals

            # Use torch.isin (vectorized, GPU) instead of slow CPU loop
            exists = torch.isin(packed_pairs, packed_existing)
            mask = ~exists & (i_vals != j_vals)
            non_edges_pairs = torch.stack((i_vals[mask], j_vals[mask]), dim=1)

            if non_edges_pairs.shape[0] > 0:
                non_edges_tensor = torch.cat([
                    torch.zeros((non_edges_pairs.shape[0], 1), device=self.device, dtype=torch.long),
                    non_edges_pairs,
                    torch.full((non_edges_pairs.shape[0], 1), edge_type1, device=self.device),
                    torch.full((non_edges_pairs.shape[0], 1), edge_type2, device=self.device)
                ], dim=1)
            else:
                non_edges_tensor = torch.empty((0, 5), dtype=torch.long, device=self.device)
        else:
            non_edges_tensor = torch.empty((0, 5), dtype=torch.long, device=self.device)

        comb_r, comb_j = torch.combinations(random_sample_tensor, r=2).unbind(1)
        r_venue = paper_venues[comb_r]
        j_venue = paper_venues[comb_j]
        unequal_mask = (r_venue != j_venue).flatten().nonzero(as_tuple=True)[0]

        if unequal_mask.numel() > 0:
            comb_r = comb_r[unequal_mask].squeeze()
            comb_j = comb_j[unequal_mask].squeeze()
            r_venue = r_venue[unequal_mask]
            j_venue = j_venue[unequal_mask]

            venue_non_edges = torch.cat([
                torch.zeros((comb_r.shape[0]*2, 1), device=self.device, dtype=torch.long),
                torch.cat([comb_r.unsqueeze(1), comb_j.unsqueeze(1)], dim=0),
                torch.cat([j_venue, r_venue], dim=0),
                torch.full((comb_r.shape[0]*2, 1), edge_entities['paper'], device=self.device),
                torch.full((comb_r.shape[0]*2, 1), edge_entities['venue'], device=self.device)
            ], dim=1)
        else:
            venue_non_edges = torch.empty((0, 5), dtype=torch.long, device=self.device)

        data_matrix = torch.cat((result_tensor, non_edges_tensor, venues_tensor, venue_non_edges), dim=0)
        return data_matrix, unique_tensor, random_sample

In [1]:
import torch

class LossFunction:
    def __init__(self, alpha=1.0, eps=1e-8, use_regularization=False, lam=0.01, weight = 0.01):
        """
        Initialize the loss function with given parameters.
        
        Args:
            alpha (float): Scaling parameter for edge probability.
            eps (float): Small value to prevent log(0).
            use_regularization (bool): Whether to include Gaussian regularization.
        """
        self.alpha = alpha
        self.eps = eps
        self.use_regularization = use_regularization
        self.lam = lam
        self.weight = weight

    def edge_probability(self, z_i, z_j):
        """Compute the probability of an edge existing between two embeddings."""
        dist_sq = torch.sum((z_i - z_j) ** 2, dim=1)  # Squared Euclidean distance (batch-wise)
        # return 1 / (1 + torch.exp(-self.alpha + dist_sq))  # Logistic function, element-wise
        return torch.sigmoid(-dist_sq + self.alpha)

    def link_loss(self, label, z_u, z_v):
        """Compute the loss for a single edge."""
        prob = self.edge_probability(z_u, z_v)  # Compute edge probabilities (batch-wise)
        prob = torch.clamp(prob, self.eps, 1 - self.eps)  # Numerical stability

        # Compute the loss for each edge
        return label * torch.log(prob) + self.weight * (1 - label) * torch.log(1 - prob)

    def compute_loss(self, z, datamatrix_tensor):
        """Compute the total loss for the dataset."""
        # Extract labels, u_idx, and v_idx in a vectorized way
        labels = datamatrix_tensor[:, 0].float()
        u_idx = datamatrix_tensor[:, 1].long()
        v_idx = datamatrix_tensor[:, 2].long()
        pv1_idx = datamatrix_tensor[:, 3].long()
        pv2_idx = datamatrix_tensor[:, 4].long()

        edge_entities = {
            0: 'paper',
            1: 'author',
            2: 'institution',
            3: 'field_of_study',
            4: 'venue'
        }

        # Get embeddings for u_idx and v_idx
        z_u = torch.stack([
            z[edge_entities[j.item()]][i.item()]
            for i, j in zip(u_idx, pv1_idx)
        ])
        z_v = torch.stack([
            z[edge_entities[j.item()]][i.item()]
            for i, j in zip(v_idx, pv2_idx)
        ])

        # Compute link loss for all edges in the batch
        link_loss = self.link_loss(labels, z_u, z_v)  # shape (B,)

        # Mean loss over the batch
        loss = -torch.mean(link_loss)

        # Optionally add regularization
        if self.use_regularization:
            regularization = self.lam * torch.sum(z ** 2)
            loss += regularization

        return loss


In [10]:
class NodeEmbeddingTrainer:
    def __init__(self, device=None):        # Initialize input data, parameters, and setup
        self.device = device or torch.device("cpu")

        # Optimizers
        # self.optimizer = torch.optim.Adam([], lr=self.lr) # KOM TILBAGE

        # Loss function (assumed to be defined elsewhere)
        # self.loss_function = LossFunction(alpha=self.alpha, eps=self.eps, use_regularization=True, lam=self.lam)

    def save_checkpoint(self, path):
        checkpoint = {
            'collected_embeddings': self.collected_embeddings.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(checkpoint, path)

    @staticmethod
    def load_checkpoint(path, *args, **kwargs):
        obj = NodeEmbeddingTrainer(*args, **kwargs)
        checkpoint = torch.load(path)
        obj.papernode_embeddings.load_state_dict(checkpoint['papernode_embeddings'])
        obj.venuenode_embeddings.load_state_dict(checkpoint['venuenode_embeddings'])
        obj.optimizer.load_state_dict(checkpoint['optimizer'])
        obj.venue_optimizer.load_state_dict(checkpoint['venue_optimizer'])
        obj.specific_papernode_indices = checkpoint['specific_papernode_indices']
        obj.specific_venuenode_indices = checkpoint['specific_venuenode_indices']
        return obj

In [None]:
### paper_c_paper_train
num_papers = 200     # total number of papers (IDs from 0 to 749999)
num_edges = 1000        # number of citation edges

# Randomly generate citing and cited paper IDs
citing = torch.randint(0, num_papers, (num_edges,))
cited = torch.randint(0, num_papers, (num_edges,))

# Stack to form a 2 x num_edges tensor
paper_c_paper_train = torch.stack([citing, cited])

### data and venue_value

# Generate a tensor of N random integers between 0 and 50
num_values = num_papers  # You can change this number
random_values = torch.randint(0, 51, (num_values, 1))

# Create the dictionary
data = {
    'y_dict': {
        'paper': random_values
    }
}

tensor_values = data['y_dict']['paper']
venue_value = {i: tensor_values[i] for i in range(tensor_values.size(0))}

In [None]:
### data and venue_value

# Generate a tensor of N random integers between 0 and 50
num_values = num_papers  # You can change this number
random_values = torch.randint(0, 51, (num_values, 1))

# Create the dictionary
data = {
    'y_dict': {
        'paper': random_values
    }
}

tensor_values = data['y_dict']['paper']
venue_value = {i: tensor_values[i] for i in range(tensor_values.size(0))}

venues_values = torch.unique(data['y_dict']['paper'])

collected_embeddings = {
    'paper': {},
    'venue': {}
}

embedding_dim = 2
a = -100
b = -a
# Venue embeddings
embed = torch.nn.Embedding(len(venues_values), embedding_dim)
torch.nn.init.uniform_(embed.weight, a, b)

venue_id_to_idx = {venue_id.item(): idx for idx, venue_id in enumerate(venues_values)}

indices = torch.tensor([venue_id_to_idx[venue_id.item()] for venue_id in venues_values], dtype=torch.long)
embeddings = embed(indices)

for venue_id in venues_values:
    collected_embeddings['venue'][venue_id.item()] = embeddings[venue_id_to_idx[venue_id.item()]]

# Paper embeddings
unique_paper_ids = torch.unique(paper_c_paper_train)
embed = torch.nn.Embedding(len(unique_paper_ids), embedding_dim)
torch.nn.init.uniform_(embed.weight, a, b)
paper_id_to_idx = {pid.item(): idx for idx, pid in enumerate(unique_paper_ids)}

indices = torch.tensor([paper_id_to_idx[pid.item()] for pid in paper_c_paper_train.flatten()], dtype=torch.long)
embeddings = embed(indices)

for pid, emb in zip(paper_c_paper_train.flatten(), embeddings):
    collected_embeddings['paper'][pid.item()] = emb

if collected_embeddings:
    torch.save(collected_embeddings, f"/mnt/c/Users/Bruger/Desktop/Bachelor/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/collected_embeddings_{embedding_dim}_spread_{b}_synt.pt")
    print("embeddings saved")


# Load the collected embeddings dictionary
collected_embeddings = torch.load(f"/mnt/c/Users/Bruger/Desktop/Bachelor/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/collected_embeddings_{embedding_dim}_spread_{b}_synt.pt")



embeddings saved


  collected_embeddings = torch.load(f"/mnt/c/Users/Bruger/Desktop/Bachelor/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/collected_embeddings_{embedding_dim}_spread_{b}_synt.pt")


In [None]:
embed_dict = collected_embeddings

In [51]:
import torch
import copy
import sys
import os
import gc
import wandb
from datetime import datetime
import argparse
import numpy as np
from collections import defaultdict

batch_size = 7
num_epochs = 1
lr = 0.1
alpha = 0.1
lam = 0.01
embedding_dim = 2

loss_function = LossFunction()
N_emb = NodeEmbeddingTrainer()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("starting")


citation_dict = defaultdict(list)
for src, tgt in zip(paper_c_paper_train[0], paper_c_paper_train[1]):
    citation_dict[src.item()].append(tgt.item())

all_papers = list(citation_dict.keys())

num_iterations = int(len(embed_dict['venue']) + len(embed_dict['paper'])) # we need to be able to look at the complete dataset

# num_iterations = 2


params = []
for subdict in embed_dict.values():
    params.extend(subdict.values())
loss_pr_epoch = []

for i in range(num_epochs):
    print(f"Epoch {i + 1}/{num_epochs}")
    l_prev = list(paper_c_paper_train.unique().numpy())  # Initial list of nodes
    optimizer = torch.optim.Adam(params, lr=lr)
    loss_pr_iteration = []

    # import time
    # start = time.time()
    # dm, unique_list, random_sample = mini_b.data_matrix()
    # print("Batch gen time:", time.time() - start)

    mini_b = mini_batches_fast(paper_c_paper_train, l_prev, batch_size, ('paper', 'cites', 'paper'), data, citation_dict, all_papers)

    for j in range(num_iterations):
        mini_b.set_unique_list(l_prev)  # Update only the node list
        dm, l_next, random_sample = mini_b.data_matrix()
        # print(dm)

    # for j in range(num_iterations):

        # Generate mini-batches
        # mini_b = mini_batches_fast(paper_c_paper_train, l_prev, batch_size, ('paper', 'cites', 'paper'), data)
        # dm, l_next, random_sample = mini_b.data_matrix()

        # Move data to GPU
        dm = dm.to(device)
        optimizer.zero_grad()
        loss = loss_function.compute_loss(embed_dict, dm)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.detach().item()}")
        # Update node list for the next iteration
        loss_pr_iteration.append(loss.detach().item())

        l_prev = l_next
        

        if len(l_next) == 0:
            print("No more nodes to process. Exiting.")
            print(loss_pr_iteration)
            loss_pr_epoch.append(np.mean(loss_pr_iteration))
            print(f"loss_epoch: {loss_pr_epoch[i]}")
            break

        # Cleanup
        if (i + 1) % 5 == 0:  # Or do it every iteration if memory is super tight
            import gc
            gc.collect()
            torch.cuda.empty_cache()

Using device: cpu
starting
Epoch 1/1
Loss: 2.642009973526001
Loss: 2.973162889480591
Loss: 2.921341896057129
Loss: 2.777721643447876
Loss: 2.631525754928589
Loss: 2.912360429763794
Loss: 2.7863221168518066
Loss: 2.8194921016693115
Loss: 3.055563449859619
Loss: 2.7700271606445312
Loss: 2.959388494491577
Loss: 2.9187705516815186
Loss: 2.8522348403930664
Loss: 2.801301956176758
Loss: 2.9689011573791504
Loss: 2.7890467643737793
Loss: 2.887988328933716
Loss: 3.0414206981658936
Loss: 2.859528064727783
Loss: 2.8570852279663086
Loss: 2.720592975616455
Loss: 2.649127960205078
Loss: 2.8607873916625977
Loss: 3.0326731204986572
Loss: 3.231698751449585
Loss: 2.7435061931610107
Loss: 2.8153586387634277
Loss: 2.804795026779175
Loss: 4.792372703552246
No more nodes to process. Exiting.
[2.642009973526001, 2.973162889480591, 2.921341896057129, 2.777721643447876, 2.631525754928589, 2.912360429763794, 2.7863221168518066, 2.8194921016693115, 3.055563449859619, 2.7700271606445312, 2.959388494491577, 2.9187

In [None]:
venues_values = torch.unique(data['y_dict']['paper'])

collected_embeddings = {
    'paper': {},
    'venue': {}
}

embedding_dim = 2
a = -100
b = -a
# Venue embeddings
embed = torch.nn.Embedding(len(venues_values), embedding_dim)
torch.nn.init.uniform_(embed.weight, a, b)

venue_id_to_idx = {venue_id.item(): idx for idx, venue_id in enumerate(venues_values)}

indices = torch.tensor([venue_id_to_idx[venue_id.item()] for venue_id in venues_values], dtype=torch.long)
embeddings = embed(indices)

for venue_id in venues_values:
    collected_embeddings['venue'][venue_id.item()] = embeddings[venue_id_to_idx[venue_id.item()]]

# Paper embeddings
unique_paper_ids = torch.unique(paper_c_paper_train)
embed = torch.nn.Embedding(len(unique_paper_ids), embedding_dim)
torch.nn.init.uniform_(embed.weight, a, b)
paper_id_to_idx = {pid.item(): idx for idx, pid in enumerate(unique_paper_ids)}

indices = torch.tensor([paper_id_to_idx[pid.item()] for pid in paper_c_paper_train.flatten()], dtype=torch.long)
embeddings = embed(indices)

for pid, emb in zip(paper_c_paper_train.flatten(), embeddings):
    collected_embeddings['paper'][pid.item()] = emb

if collected_embeddings:
    torch.save(collected_embeddings, f"/mnt/c/Users/Bruger/Desktop/Bachelor/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/collected_embeddings_{embedding_dim}_spread_{b}_synt.pt")
    print("embeddings saved")


# Load the collected embeddings dictionary
collected_embeddings = torch.load(f"/mnt/c/Users/Bruger/Desktop/Bachelor/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/collected_embeddings_{embedding_dim}_spread_{b}_synt.pt")

embed_dict = collected_embeddings