In [None]:
import torch
import torch.nn as nn

# Example embedding: vocab size 1000, embedding dim 136
embedding = nn.Embedding(1, 136)

# Masking function for gradient hook
def mask_gradients(grad):
    # grad shape: [vocab_size, 136]
    mask = torch.zeros_like(grad)
    mask[:, :8] = 1  # only keep gradients for first 8 dims
    return grad * mask

# Register the hook
embedding.weight.register_hook(mask_gradients)

# Example forward and backward pass

embedded = embedding(input_ids)

# Assume some loss computed on the full 136-dimensional vectors
loss = embedded.pow(2).sum()
loss.backward()

# optimizer.step() would now only update the first 8 dims
# Print the gradients to verify
print(embedding.weight.grad)  # Should show gradients only for the first 8 dimensions


IndexError: index out of range in self

In [None]:
import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from Packages.mini_batches import mini_batches_code
from Packages.loss_function import LossFunction
from Packages.data_divide import paper_c_paper_train, paper_c_paper_valid
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

emb_matrix = torch.load("dataset/ogbn_mag/processed/hpc/emb_matrix_8_125_epoch.pt", map_location=device)
# paper_c_paper_valid = torch.load("dataset/ogbn_mag/processed/paper_c_paper_valid.pt", map_location=device)
data, _ = torch.load(r"dataset/ogbn_mag/processed/geometric_data_processed.pt", weights_only=False)

valid_dict = {}

# Get unique node IDs from both train and valid edges
unique_train = set(paper_c_paper_train.flatten().unique().tolist())
unique_valid = set(paper_c_paper_valid.flatten().unique().tolist())

# Keep only validation nodes that do not appear in training edges
valid_exclusive = unique_valid - unique_train

# Initial list of nodes for iterations
l_prev = list(valid_exclusive)
num_iterations = int(len(l_prev)-1)

sample = 1


for i in range(num_iterations):
    print(f"Iteration {i+1}")

    # Generate mini-batches
    mini_b_new = mini_batches_code(paper_c_paper_valid, l_prev, sample, ('paper', 'cites', 'paper'),data)
    dm_new,l_next,remapped_datamatrix_tensor_new,random_sample = mini_b_new.node_mapping()

    dm_new = dm_new.to(device)
    remapped_datamatrix_tensor_new = remapped_datamatrix_tensor_new.to(device)

    new_datamatrix = dm_new[torch.all(dm_new[:, 4:] != 4, dim=1)]
    new_remapped_datamatrix_tensor_new = remapped_datamatrix_tensor_new[torch.all(remapped_datamatrix_tensor_new[:, 4:] != 4, dim=1)]

    loss_function = LossFunction(alpha=10, eps=1e-10, use_regularization=True, lam=0.001)
    for j in range(sample):

        new_embedding = torch.nn.Embedding(sample, 8).to(device)
        valid_dict[random_sample[j]] = new_embedding
        new_embedding.weight = torch.mean(emb_matrix[dm_new[:, 0] == 1], dim=0 and dm_new[:, 1] == random_sample[j]).unsqueeze(0)

        



    new_optimizer = torch.optim.Adam(new_embedding.parameters(), lr=0.001)

    num_epochs = 30
    
        # Training loop
    for epoch in range(num_epochs):
        new_optimizer.zero_grad()

        # Concatenate the embeddings
        temp_embed = torch.cat([emb_matrix, new_embedding.weight], dim=0)
        types = new_datamatrix[:, 3:]
        loss = loss_function.compute_loss(temp_embed, new_remapped_datamatrix_tensor_new[:, :3])  # Compute loss
        
        # Backpropagation and optimization
        loss.backward()
        new_optimizer.step()

        # Print loss every 10 epochs
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

    # Update node list for the next iteration
    l_prev = l_next

    # valid_dict[random_sample[0]] = new_embedding.weight.detach().cpu().clone()

    # Cleanup
    if (i + 1) % 10 == 0:
        gc.collect()
        torch.cuda.empty_cache()

torch.save(valid_dict, "dataset/ogbn_mag/processed/hpc/valid_dict_8.pt")

print('embed_valid done')
        

In [1]:
import torch 


torch.load("/home/william/Documents/DTU/GraphML_Bachelorprojekt/dataset/ogbn_mag/processed/venue_embeddings.pt")

{0: tensor([-1.2961, -1.0552], requires_grad=True),
 1: tensor([-0.5786, -0.5141], requires_grad=True),
 2: tensor([-0.4999,  1.4018], requires_grad=True),
 3: tensor([0.5730, 0.5671], requires_grad=True),
 4: tensor([ 1.1350, -0.0300], requires_grad=True),
 5: tensor([ 1.7031, -0.8159], requires_grad=True),
 6: tensor([ 0.6272, -0.6556], requires_grad=True),
 7: tensor([0.3111, 0.0467], requires_grad=True),
 8: tensor([ 1.3843, -1.2562], requires_grad=True),
 9: tensor([0.1344, 0.5471], requires_grad=True),
 10: tensor([1.0425, 1.1302], requires_grad=True),
 11: tensor([ 0.0474, -0.1102], requires_grad=True),
 12: tensor([ 1.6026, -0.9757], requires_grad=True),
 13: tensor([-0.4701, -0.8018], requires_grad=True),
 14: tensor([0.2122, 0.1687], requires_grad=True),
 15: tensor([1.0261, 0.6671], requires_grad=True),
 16: tensor([-0.8706, -0.6984], requires_grad=True),
 17: tensor([ 0.5884, -1.9676], requires_grad=True),
 18: tensor([ 0.7311, -0.7635], requires_grad=True),
 19: tensor([0.