In [1]:
from model import FeatureExtractor, MLP
import numpy as np
from data_manager import setup_training_loader, create_sparse_structure_from_images
from model import create_feature_pairs, modified_sigmoid, create_coo_sparse_matrix
import torch.optim as optim
import torch
from tqdm import tqdm
from scipy.sparse import diags
from scipy.sparse.linalg import eigsh
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import f1_score, confusion_matrix
from utils import correct_pred_sign
import logging
import os
from datetime import datetime
import random

In [2]:
TARGET_CROP = 1  # The crop ID we're training to detect
UNCHANGED_CROPS = [1, 5, 23, 176]  # List of unchanged crops

train_loader = setup_training_loader(
    path_to_train_data='./training_data/train_patches.npy',
    unchanged_crops=UNCHANGED_CROPS,
    target_crops=[TARGET_CROP],
    train_batch_size=4,
    crop_band_index=18,
    device='cuda',
    ignore_crops=None,
    min_ratio=0.1,
    max_ratio=0.9
)

Filtered 1074 patches to 511 good patches (47.58%)
Dataset loaded with 511 patches
Total pixels: 25639936
Positive pixels (+1): 8439620
Negative pixels (-1): 17200316


In [3]:
checkpoint_path = './checkpoints/v2/crop1_vs_all.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

features_extractor = torch.load(checkpoint_path, weights_only=False)
features_extractor.to(device)

FeatureExtractor(
  (cnn): ShallowCNN(
    (block_in): Sequential(
      (0): Conv2d(18, 36, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
      (1): ReLU()
    )
    (blocks_internal): ModuleList(
      (0-3): 4 x Sequential(
        (0): Conv2d(36, 36, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
        (1): ReLU()
        (2): Conv2d(36, 36, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
        (3): ReLU()
      )
    )
    (block_out): Sequential(
      (0): Conv2d(36, 18, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
      (1): ReLU()
    )
  )
)

In [4]:
num_features=18
num_layers=5
mlp = MLP(num_features=num_features, num_layers=num_layers, device=device)

In [5]:
img_height = img_width = 32
window_size = 30
device = 'cuda'
sparse_image_obj = create_sparse_structure_from_images(img_height, img_width, window_size, device)
order = sparse_image_obj['order']
edges = sparse_image_obj['edges']
edges = edges
edge_i, edge_j = edges[:, 0], edges[:, 1]

loss_edges = create_sparse_structure_from_images(img_height, img_width, 3, device)['edges']
loss_edges_i, loss_edges_j = loss_edges[:, 0], loss_edges[:, 1]

In [6]:
# d_star = 1.0

# features_extractor.eval()
# mlp.train()
# epochs = 3

# optimizer = optim.Adam(features_extractor.parameters(), lr=0.001)

# for epoch in range(epochs):

    
#     train_running_loss = []
#     for batch_idx, (bands, label) in enumerate(tqdm(train_loader, desc="Training")):
    
    
#         with torch.no_grad():
#             features = features_extractor(bands)
    
#         self_loops = mlp(features)
    
    
#         batch_loss_list = [] 
#         for b in range(bands.shape[0]):
    
#             for i in range(2):
#                 for j in range(2):
#                     # Extract the 112x112 quadrant
#                     start_h = i * 112
#                     start_w = j * 112
                    
#                     quadrant_features = features[b][start_h:start_h+112, start_w:start_w+112, :]
#                     quadrant_label = label[b][start_h:start_h+112, start_w:start_w+112]
#                     quadrant_self_loops = self_loops[b][start_h:start_h+112, start_w:start_w+112, :]
                    
#                     # Reshape and reorder
#                     quadrant_features = quadrant_features.reshape(-1, quadrant_features.shape[-1])[order, :]
#                     quadrant_label = quadrant_label.reshape(-1)[order]
#                     quadrant_self_loops = quadrant_self_loops.reshape(-1)[order]
                    
#                     features_i, features_j = quadrant_features[edge_i], quadrant_features[edge_j]
#                     distances = ((features_i - features_j) ** 2).sum(dim=1)
#                     weights = modified_sigmoid(distances, d_star, scale=1)
                
#                     num_nodes = 112 ** 2
#                     sparse_coo = torch.sparse_coo_tensor(
#                                             indices=edges.t(),
#                                             values=weights,  # Select weights for batch b
#                                             size=(num_nodes, num_nodes)
#                                             )
#                     sparse_adjacency = sparse_coo + sparse_coo.t()
                    
#                     ones = torch.ones(num_nodes, device=weights.device)
#                     D = torch.diag(sparse_adjacency @ ones)
#                     quadrant_self_loops = torch.diag(quadrant_self_loops)
                    
#                     L = D - sparse_adjacency + quadrant_self_loops
                    
#                     identity = torch.eye(L.shape[0], device=L.device)
#                     shifted_L = (2*num_nodes + quadrant_self_loops.max().detach()) * identity - L
#                     eigenvector = torch.lobpcg(shifted_L, largest=True, k=1, tol=1e-7)[1]
#                     eigenvector = eigenvector.squeeze()
    
#                     loss_edge_weights = quadrant_label[loss_edges_i] * quadrant_label[loss_edges_j]
#                     eigv_i, eigv_j = eigenvector[loss_edges_i], eigenvector[loss_edges_j]
#                     single_loss = (loss_edge_weights * (eigv_i - eigv_j) ** 2).mean()                
    
#                     batch_loss_list.append(single_loss)
                    
#                     # L = D - sparse_adjacency + quadrant_self_loops
                    
#                     # eigenvector = torch.lobpcg(L, largest=False, k=1, tol=1e-6)[1]
                    
#         loss = torch.stack(batch_loss_list).mean()            
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
    
#         train_running_loss.append(loss.item())
        
#         # batch_size = bands.size(0)
#         # patch_size = 56
#         # start_h = torch.randint(0, bands.size(1) - patch_size + 1, (batch_size,)).to(bands.device)
#         # start_w = torch.randint(0, bands.size(2) - patch_size + 1, (batch_size,)).to(bands.device)
                
#         # patches = torch.zeros(batch_size, patch_size, patch_size, bands.size(3), device=bands.device)
#         # patch_labels = torch.zeros(batch_size, patch_size, patch_size, device=bands.device)
                
#         # for i in range(batch_size):
#         #     patches[i] = bands[i, start_h[i]:start_h[i]+patch_size, start_w[i]:start_w[i]+patch_size, :]
#         #     patch_labels[i] = label[i, start_h[i]:start_h[i]+patch_size, start_w[i]:start_w[i]+patch_size]
                
#         # bands = patches
#         # label = patch_labels
    
    
        
#         # features = features.reshape(features.shape[0], -1, features.shape[-1])[:, order, :]
#         # label = label.reshape(label.shape[0], -1)[:, order]
    
    
#         # features_i, features_j = features[:, edge_i, :], features[:, edge_j, :]
#         # distances = ((features_i - features_j) ** 2).sum(dim=-1)
#         # weights = modified_sigmoid(distances, d_star, scale=1)
    
#         # num_nodes = patch_size ** 2
    
#         # # Create a list of sparse tensors for each batch
#         # sparse_adj_matrices = []
#         # for b in range(weights.shape[0]):
#         #     sparse_adj_matrix = torch.sparse_coo_tensor(
#         #         indices=edges.t(),
#         #         values=weights[b],  # Select weights for batch b
#         #         size=(num_nodes, num_nodes)
#         #     )
#         #     sparse_adj_matrices.append(sparse_adj_matrix)
        
#         # break
#     print(np.mean(train_running_loss))

In [6]:
import torch

def power_iteration(A, num_iterations=100, tol=1e-6):
    """
    Compute the dominant eigenvector of matrix A using power iteration.
    
    Args:
        A (torch.Tensor): Square matrix of shape (n, n).
        num_iterations (int): Maximum number of iterations.
        tol (float): Tolerance for convergence.
    
    Returns:
        torch.Tensor: Dominant eigenvector.
    """
    n = A.shape[0]
    # Initialize a random vector
    v = torch.randn(n, device=A.device, dtype=A.dtype, requires_grad=False)
    v = v / torch.norm(v)  # Normalize the initial vector
    
    for _ in range(num_iterations):
        # Matrix-vector multiplication
        v_new = torch.matmul(A, v)
        # Compute the Rayleigh quotient (approximation of eigenvalue)
        eigenvalue = torch.dot(v_new, v)
        # Normalize the new vector
        v_new_norm = torch.norm(v_new)
        v_new = v_new / v_new_norm
        # Check for convergence
        if torch.norm(v_new - v) < tol:
            break
        v = v_new
    
    return v_new

# Example usage
A = torch.tensor([[2.0, 1.0], [1.0, 3.0]], requires_grad=True)
eigenvector = power_iteration(A)
print("Dominant eigenvector:", eigenvector)

Dominant eigenvector: tensor([-0.5257, -0.8507], grad_fn=<DivBackward0>)


In [6]:
d_star = 1.0

features_extractor.eval()
mlp.train()
epochs = 1

optimizer = optim.Adam(features_extractor.parameters(), lr=0.001)


for epoch in range(epochs):

    
    train_running_loss = []
    for batch_idx, (bands, label) in enumerate(tqdm(train_loader, desc="Training")):
    
    
        with torch.no_grad():
            features = features_extractor(bands)
    
        self_loops = mlp(features)
    
    
        batch_loss_list = [] 
        for b in range(bands.shape[0]):
    
            # Select a random 64x64 patch
            patch_size = 32
            max_start_h = bands.size(1) - patch_size
            max_start_w = bands.size(2) - patch_size
            start_h = random.randint(0, max_start_h) if max_start_h > 0 else 0
            start_w = random.randint(0, max_start_w) if max_start_w > 0 else 0
            
            patch_features = features[b][start_h:start_h+patch_size, start_w:start_w+patch_size, :]
            patch_label = label[b][start_h:start_h+patch_size, start_w:start_w+patch_size]
            patch_self_loops = self_loops[b][start_h:start_h+patch_size, start_w:start_w+patch_size, :]
            
            # Reshape and reorder
            patch_features = patch_features.reshape(-1, patch_features.shape[-1])[order, :]
            patch_label = patch_label.reshape(-1)[order]
            patch_self_loops = patch_self_loops.reshape(-1)[order]
            
            features_i, features_j = patch_features[edge_i], patch_features[edge_j]
            distances = ((features_i - features_j) ** 2).sum(dim=1)
            weights = modified_sigmoid(distances, d_star, scale=1)
        
            num_nodes = patch_size ** 2
            sparse_coo = torch.sparse_coo_tensor(
                                    indices=edges.t(),
                                    values=weights,  # Select weights for batch b
                                    size=(num_nodes, num_nodes)
                                    )
            sparse_adjacency = sparse_coo + sparse_coo.t()
            
            ones = torch.ones(num_nodes, device=weights.device)
            D = torch.diag(sparse_adjacency @ ones)
            patch_self_loops = torch.diag(patch_self_loops)
            
            L = D - sparse_adjacency + patch_self_loops
            
            identity = torch.eye(L.shape[0], device=L.device)
            shifted_L = (2*num_nodes + patch_self_loops.max()) * identity - L
            eigenvector = torch.lobpcg(shifted_L, largest=True, k=1, tol=1e-7)[1]
            eigenvector = eigenvector.squeeze()

            loss_edge_weights = patch_label[loss_edges_i] * patch_label[loss_edges_j]
            eigv_i, eigv_j = eigenvector[loss_edges_i], eigenvector[loss_edges_j]
            single_loss = (torch.abs(loss_edge_weights) * (eigv_i - torch.sign(loss_edge_weights) * eigv_j) ** 2).mean()                

            batch_loss_list.append(single_loss)
            
            # L = D - sparse_adjacency + patch_self_loops
            
            # eigenvector = torch.lobpcg(L, largest=False, k=1, tol=1e-6)[1]
            
        loss = torch.stack(batch_loss_list).mean()  
        try:
            loss.backward()
        except:
            print("Failed to do backprob for this batch")
            continue
        # print("succeed in doing backprob for this batch")    
        optimizer.step()
        optimizer.zero_grad()
    
        train_running_loss.append(loss.item())
        
    print(np.mean(train_running_loss))

Training:   3%|█▍                                            | 4/128 [00:22<09:58,  4.83s/it]

Failed to do backprob for this batch


Training:   5%|██▏                                           | 6/128 [00:40<14:45,  7.26s/it]

Failed to do backprob for this batch


Training:   9%|███▊                                         | 11/128 [01:13<11:51,  6.08s/it]

Failed to do backprob for this batch


Training:  15%|██████▋                                      | 19/128 [02:10<13:22,  7.36s/it]

Failed to do backprob for this batch


Training:  17%|███████▋                                     | 22/128 [02:34<13:45,  7.79s/it]

Failed to do backprob for this batch


Training:  19%|████████▍                                    | 24/128 [02:50<13:26,  7.76s/it]

Failed to do backprob for this batch


Training:  22%|█████████▊                                   | 28/128 [03:27<14:30,  8.71s/it]

Failed to do backprob for this batch


Training:  23%|██████████▏                                  | 29/128 [03:32<12:35,  7.63s/it]

Failed to do backprob for this batch


Training:  23%|██████████▌                                  | 30/128 [03:42<13:40,  8.37s/it]

Failed to do backprob for this batch


Training:  24%|██████████▉                                  | 31/128 [03:47<11:55,  7.38s/it]

Failed to do backprob for this batch


Training:  31%|██████████████                               | 40/128 [04:39<09:08,  6.23s/it]

Failed to do backprob for this batch


Training:  37%|████████████████▌                            | 47/128 [05:31<09:21,  6.93s/it]

Failed to do backprob for this batch


Training:  41%|██████████████████▋                          | 53/128 [06:12<09:08,  7.32s/it]

Failed to do backprob for this batch


Training:  47%|█████████████████████                        | 60/128 [06:57<06:53,  6.07s/it]

Failed to do backprob for this batch


Training:  49%|██████████████████████▏                      | 63/128 [07:16<06:25,  5.92s/it]

Failed to do backprob for this batch


Training:  50%|██████████████████████▌                      | 64/128 [07:19<05:18,  4.98s/it]

Failed to do backprob for this batch


Training:  51%|██████████████████████▊                      | 65/128 [07:21<04:31,  4.32s/it]

Failed to do backprob for this batch


Training:  54%|████████████████████████▎                    | 69/128 [07:50<06:31,  6.63s/it]

Failed to do backprob for this batch


Training:  55%|████████████████████████▌                    | 70/128 [08:00<07:21,  7.62s/it]

Failed to do backprob for this batch


Training:  55%|████████████████████████▉                    | 71/128 [08:05<06:36,  6.95s/it]

Failed to do backprob for this batch


Training:  59%|██████████████████████████▎                  | 75/128 [08:32<06:25,  7.27s/it]

Failed to do backprob for this batch


Training:  61%|███████████████████████████▍                 | 78/128 [08:55<06:19,  7.60s/it]

Failed to do backprob for this batch


Training:  70%|███████████████████████████████▎             | 89/128 [10:00<04:06,  6.31s/it]

Failed to do backprob for this batch


Training:  71%|███████████████████████████████▉             | 91/128 [10:06<02:46,  4.49s/it]

Failed to do backprob for this batch


Training:  72%|████████████████████████████████▎            | 92/128 [10:12<02:56,  4.90s/it]

Failed to do backprob for this batch


Training:  73%|████████████████████████████████▋            | 93/128 [10:17<02:55,  5.02s/it]

Failed to do backprob for this batch


Training:  73%|█████████████████████████████████            | 94/128 [10:22<02:52,  5.08s/it]

Failed to do backprob for this batch


Training:  75%|█████████████████████████████████▊           | 96/128 [10:37<03:30,  6.56s/it]

Failed to do backprob for this batch


Training:  77%|██████████████████████████████████▍          | 98/128 [10:45<02:37,  5.23s/it]

Failed to do backprob for this batch


Training:  87%|██████████████████████████████████████▏     | 111/128 [12:27<02:07,  7.49s/it]

Failed to do backprob for this batch


Training:  88%|██████████████████████████████████████▌     | 112/128 [12:35<01:59,  7.46s/it]

Failed to do backprob for this batch


Training:  93%|████████████████████████████████████████▉   | 119/128 [13:20<01:00,  6.76s/it]

Failed to do backprob for this batch


Training:  94%|█████████████████████████████████████████▎  | 120/128 [13:25<00:51,  6.40s/it]

Failed to do backprob for this batch


Training:  98%|███████████████████████████████████████████▎| 126/128 [13:59<00:11,  5.52s/it]

Failed to do backprob for this batch


Training: 100%|████████████████████████████████████████████| 128/128 [14:09<00:00,  6.63s/it]


9.978925226413537e-05


Training:   2%|█                                             | 3/128 [00:21<12:50,  6.16s/it]

Failed to do backprob for this batch


Training:   8%|███▌                                         | 10/128 [01:04<11:07,  5.65s/it]

Failed to do backprob for this batch


Training:   9%|████▏                                        | 12/128 [01:19<12:18,  6.36s/it]

Failed to do backprob for this batch


Training:  13%|█████▉                                       | 17/128 [02:01<14:30,  7.84s/it]

Failed to do backprob for this batch


Training:  14%|██████▎                                      | 18/128 [02:07<12:55,  7.05s/it]

Failed to do backprob for this batch


Training:  15%|██████▋                                      | 19/128 [02:12<11:44,  6.46s/it]

Failed to do backprob for this batch


Training:  16%|███████▍                                     | 21/128 [02:22<10:32,  5.91s/it]

Failed to do backprob for this batch


Training:  17%|███████▋                                     | 22/128 [02:30<11:24,  6.46s/it]

Failed to do backprob for this batch


Training:  20%|████████▊                                    | 25/128 [02:51<11:55,  6.95s/it]

Failed to do backprob for this batch


Training:  21%|█████████▍                                   | 27/128 [02:59<09:19,  5.54s/it]

Failed to do backprob for this batch


Training:  27%|████████████▎                                | 35/128 [03:54<10:42,  6.90s/it]

Failed to do backprob for this batch


Training:  28%|████████████▋                                | 36/128 [04:02<10:51,  7.09s/it]

Failed to do backprob for this batch


Training:  29%|█████████████                                | 37/128 [04:08<10:08,  6.69s/it]

Failed to do backprob for this batch


Training:  33%|██████████████▊                              | 42/128 [04:39<09:07,  6.37s/it]

Failed to do backprob for this batch


Training:  34%|███████████████▍                             | 44/128 [04:52<08:52,  6.34s/it]

Failed to do backprob for this batch


Training:  38%|████████████████▉                            | 48/128 [05:22<09:43,  7.30s/it]

Failed to do backprob for this batch


Training:  40%|█████████████████▉                           | 51/128 [05:42<09:30,  7.41s/it]

Failed to do backprob for this batch


Training:  42%|██████████████████▉                          | 54/128 [06:01<08:43,  7.08s/it]

Failed to do backprob for this batch


Training:  50%|██████████████████████▌                      | 64/128 [07:05<06:37,  6.21s/it]

Failed to do backprob for this batch


Training:  54%|████████████████████████▎                    | 69/128 [07:31<06:12,  6.31s/it]

Failed to do backprob for this batch


Training:  62%|████████████████████████████▏                | 80/128 [08:40<04:15,  5.31s/it]

Failed to do backprob for this batch


Training:  63%|████████████████████████████▍                | 81/128 [08:50<05:26,  6.94s/it]

Failed to do backprob for this batch


Training:  65%|█████████████████████████████▏               | 83/128 [08:58<04:08,  5.52s/it]

Failed to do backprob for this batch


Training:  67%|██████████████████████████████▏              | 86/128 [09:16<04:04,  5.81s/it]

Failed to do backprob for this batch


Training:  77%|██████████████████████████████████▊          | 99/128 [10:47<03:00,  6.24s/it]

Failed to do backprob for this batch


Training:  79%|██████████████████████████████████▋         | 101/128 [10:57<02:38,  5.86s/it]

Failed to do backprob for this batch


Training:  88%|██████████████████████████████████████▊     | 113/128 [12:17<01:38,  6.57s/it]

Failed to do backprob for this batch


Training:  89%|███████████████████████████████████████▏    | 114/128 [12:25<01:35,  6.85s/it]

Failed to do backprob for this batch


Training: 100%|████████████████████████████████████████████| 128/128 [14:00<00:00,  6.56s/it]


5.9967102984046506e-05


Training:  14%|██████▎                                      | 18/128 [01:58<12:13,  6.67s/it]

Failed to do backprob for this batch


Training:  16%|███████                                      | 20/128 [02:16<13:56,  7.74s/it]

Failed to do backprob for this batch


Training:  18%|████████                                     | 23/128 [02:35<11:21,  6.49s/it]

Failed to do backprob for this batch


Training:  20%|████████▊                                    | 25/128 [02:47<10:28,  6.10s/it]

Failed to do backprob for this batch


Training:  21%|█████████▍                                   | 27/128 [02:56<08:02,  4.78s/it]

Failed to do backprob for this batch


Training:  23%|██████████▏                                  | 29/128 [03:06<08:42,  5.28s/it]

Failed to do backprob for this batch


Training:  27%|████████████▎                                | 35/128 [03:43<09:49,  6.34s/it]

Failed to do backprob for this batch


Training:  28%|████████████▋                                | 36/128 [03:50<10:17,  6.71s/it]

Failed to do backprob for this batch


Training:  33%|██████████████▊                              | 42/128 [04:27<09:40,  6.75s/it]

Failed to do backprob for this batch


Training:  38%|█████████████████▏                           | 49/128 [05:17<09:51,  7.49s/it]

Failed to do backprob for this batch


Training:  39%|█████████████████▌                           | 50/128 [05:24<09:41,  7.45s/it]

Failed to do backprob for this batch


Training:  45%|████████████████████                         | 57/128 [06:09<07:13,  6.11s/it]

Failed to do backprob for this batch


Training:  45%|████████████████████▍                        | 58/128 [06:17<07:41,  6.59s/it]

Failed to do backprob for this batch


Training:  52%|███████████████████████▌                     | 67/128 [07:12<07:25,  7.31s/it]

Failed to do backprob for this batch


Training:  53%|███████████████████████▉                     | 68/128 [07:20<07:28,  7.47s/it]

Failed to do backprob for this batch


Training:  55%|████████████████████████▌                    | 70/128 [07:33<06:47,  7.03s/it]

Failed to do backprob for this batch


Training:  57%|█████████████████████████▋                   | 73/128 [07:56<07:01,  7.67s/it]

Failed to do backprob for this batch


Training:  59%|██████████████████████████▎                  | 75/128 [08:12<06:54,  7.82s/it]

Failed to do backprob for this batch


Training:  62%|████████████████████████████▏                | 80/128 [08:53<06:46,  8.47s/it]

Failed to do backprob for this batch


Training:  66%|█████████████████████████████▌               | 84/128 [09:18<05:10,  7.05s/it]

Failed to do backprob for this batch


Training:  66%|█████████████████████████████▉               | 85/128 [09:29<05:49,  8.12s/it]

Failed to do backprob for this batch


Training:  68%|██████████████████████████████▌              | 87/128 [09:46<05:42,  8.35s/it]

Failed to do backprob for this batch


Training:  71%|███████████████████████████████▉             | 91/128 [10:12<04:20,  7.05s/it]

Failed to do backprob for this batch


Training:  72%|████████████████████████████████▎            | 92/128 [10:17<03:54,  6.51s/it]

Failed to do backprob for this batch


Training:  73%|█████████████████████████████████            | 94/128 [10:37<04:40,  8.26s/it]

Failed to do backprob for this batch


Training:  80%|███████████████████████████████████         | 102/128 [11:35<02:38,  6.08s/it]

Failed to do backprob for this batch


Training:  81%|███████████████████████████████████▊        | 104/128 [11:50<02:48,  7.02s/it]

Failed to do backprob for this batch


Training:  82%|████████████████████████████████████        | 105/128 [11:58<02:43,  7.12s/it]

Failed to do backprob for this batch


Training:  87%|██████████████████████████████████████▏     | 111/128 [12:30<01:48,  6.37s/it]

Failed to do backprob for this batch


Training:  93%|████████████████████████████████████████▉   | 119/128 [13:17<00:49,  5.49s/it]

Failed to do backprob for this batch


Training:  95%|█████████████████████████████████████████▉  | 122/128 [13:38<00:35,  5.90s/it]

Failed to do backprob for this batch


Training:  97%|██████████████████████████████████████████▋ | 124/128 [13:48<00:22,  5.58s/it]

Failed to do backprob for this batch


Training:  99%|███████████████████████████████████████████▋| 127/128 [14:09<00:06,  6.29s/it]

Failed to do backprob for this batch


Training: 100%|████████████████████████████████████████████| 128/128 [14:12<00:00,  6.66s/it]

Failed to do backprob for this batch
7.724995168821787e-05





In [9]:
shifted_L.shape

torch.Size([4096, 4096])

In [6]:
import numpy as np
from scipy.sparse import coo_matrix

# Define the data
data = np.array([-1, 3, -2])
row = np.array([0, 1, 2])
col = np.array([1, 2, 0])

# Create the sparse COO matrix
sparse_adjacency = coo_matrix((data, (row, col)), shape=(3, 3))

# Print the matrix (in dense form for visualization)
print(np.abs(sparse_adjacency).toarray())

[[0 1 0]
 [0 0 3]
 [2 0 0]]


In [2]:
sparse_adjacency

<3x3 sparse matrix of type '<class 'numpy.int64'>'
	with 3 stored elements in COOrdinate format>