In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# !git clone https://github.com/waileuklo/torchquantum.git
!git clone https://github.com/mit-han-lab/torchquantum.git
%cd torchquantum
!pip install -c '../drive/MyDrive/Documents/HKUST/BDT/Courses/MSBD 5002/Project/Code/gistfile1.txt' --editable .
%env PYTHONPATH=.
!pip install -c '../drive/MyDrive/Documents/HKUST/BDT/Courses/MSBD 5002/Project/Code/gistfile1.txt' qiskit_ibm_runtime
!pip install qiskit.aer

Cloning into 'torchquantum'...
remote: Enumerating objects: 15103, done.[K
remote: Counting objects: 100% (1787/1787), done.[K
remote: Compressing objects: 100% (448/448), done.[K
remote: Total 15103 (delta 1486), reused 1441 (delta 1338), pack-reused 13316[K
Receiving objects: 100% (15103/15103), 97.86 MiB | 47.65 MiB/s, done.
Resolving deltas: 100% (8587/8587), done.
/content/torchquantum
Obtaining file:///content/torchquantum
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dill==0.3.4 (from torchquantum==0.1.8)
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.9/86.9 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
Collecting nbsphinx (from torchquantum==0.1.8)
  Downloading nbsphinx-0.9.3-py3-none-any.whl (31 kB)
Collecting pathos>=0.2.7 (from torchquantum==0.1.8)
  Downloading pathos-0.3.2-py3-none-any.whl (82 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.1/82.1 kB[0m 

In [5]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch.optim as optim
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.operator import op_name_dict
from torchquantum.layer import Op1QAllLayer, Op2QAllLayer
from torchquantum.plugin import (tq2qiskit_expand_params,
                                  tq2qiskit,
                                  tq2qiskit_measurement,
                                  qiskit_assemble_circs)
from qiskit import IBMQ
import random

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
n_epochs = 500
n_wires = 13
steps = 2
bsz = 32

path = r'../drive/MyDrive/Documents/HKUST/BDT/Courses/MSBD 5002/Project/Code/'
data = 'd2w'
source_edge_file = 'wb-net.txt'
target_edge_file = 'db-net.txt'
percentage = 10

In [6]:
data_path   = path + f'networks/{data}/'
output_path = path + f'output/{data}/'
vst_gt_path = data_path + f'{data}_gt'
vst_path    = data_path + f'{data}_{percentage}'
es_path     = data_path + source_edge_file
et_path     = data_path + target_edge_file

In [11]:
def load_graph_file(vst_gt_path, vst_path, es_path, et_path, device):
    vst_gt = pd.read_csv(vst_gt_path, sep=' ', names=['Vt', 'Vs'], usecols=[0, 1])
    vst = pd.read_csv(vst_path, sep=' ', names=['Vt', 'Vs'], usecols=[0, 1])
    es = pd.read_csv(es_path, sep=' ', names=['Vs_1', 'Vs_2'], usecols=[0, 1])
    et = pd.read_csv(et_path, sep=' ', names=['Vt_1', 'Vt_2'], usecols=[0, 1])

    shift = len(vst_gt)
    vst_gt['Vt'] = vst_gt['Vt'] + shift
    vst['Vt'] = vst['Vt'] + shift
    et = et + shift

    ################################################# Process vertices ###################################################################################
    # Concat source graph vertices and those target graph vertices not in the anchor target vertices
    v = pd.concat([vst_gt['Vs'], vst_gt['Vt'][vst_gt['Vt'].map(lambda x: x not in vst['Vt'].values)]], ignore_index=True)
    v_idx = pd.DataFrame(columns=['V', 'idx']) # Create a dataframe storing the indices of the vertices
    v_idx['V'], v_idx['idx'] = v.sort_values(), range(len(v))

    vst_gt_train = vst_gt.join(v_idx.set_index('V'), on='Vs') # Store the indices of the source vertices in a new dataframe
    vst_gt_train = vst_gt_train.join(v_idx.set_index('V'), on='Vt', rsuffix='_t') # Likewise for the target vertices
    vst_gt_train['idx_t'] = vst_gt_train['idx_t'].fillna(vst_gt_train['idx']).astype(np.int64) # Use the index of the source vertices
                                                                                               # for anchor vertices
    Vs_col = vst_gt_train['Vs']
    Vt_col = vst_gt_train['Vt']
    idx_s_col = vst_gt_train['idx']
    idx_t_col = vst_gt_train['idx_t']

    idx_dict = {idx_s: idx_t for idx_s, idx_t in zip(idx_s_col, idx_t_col)} # A dictionary storing index correspondence
    idx_dict.update({idx_t: idx_s for idx_t, idx_s in zip(idx_t_col, idx_s_col)})

    v_idx_dict = {(idx_s, idx_t): (v_s, v_t) for v_s, v_t, idx_s, idx_t in zip(Vs_col, Vt_col, idx_s_col, idx_t_col)} # A dictionary storing index-vertex correspondence
    v_idx_dict.update({(idx_t, idx_s): (v_s, v_t) for v_s, v_t, idx_s, idx_t in zip(Vs_col, Vt_col, idx_s_col, idx_t_col)})

    idx_s_test = idx_s_col[Vs_col.map(lambda x: x not in vst['Vs'].values)].to_list() # A list storing indices for non-anchor source vertices
    v_idx_t_dict = {v_t: idx_t for v_t, idx_t in zip(Vt_col, idx_t_col)} # A dictionary storing vertex-index correspondence for target vertices
    idx_t_test = list(v_idx_t_dict.values())

    v_s_train = vst['Vs'].to_list()
    v_s_test = Vs_col[Vs_col.map(lambda x: x not in vst['Vs'].values)].to_list()
    v_dict = {v_s: v_t for v_s, v_t in vst_gt.values}
    v_dict.update({v_t: v_s for v_s, v_t in vst_gt.values})

    #################################################### Process edges #######################################################################################
    et_join_1 = et.join(vst.set_index('Vt'), on='Vt_1') # Create target edge dataframe joined with source vertices,
                                                        # based on the first vertex of the target edge
    et_join_2 = et.join(vst.set_index('Vt'), on='Vt_2') # Create target edge dataframe joined with source vertices,
                                                        # based on the second vertex of the target edge
    et_temp = pd.DataFrame(columns=['Vs_1', 'Vs_2'])    # Create a temporary target edge dataframe with the target vertices
                                                        # mapped to the source vertices if possible, and otherwise mapped to itself
    et_temp['Vs_1'] = et_join_1['Vs'].fillna(et_join_1['Vt_1']).astype(np.int64)
    et_temp['Vs_2'] = et_join_2['Vs'].fillna(et_join_2['Vt_2']).astype(np.int64)
    e = pd.concat([es, et_temp], ignore_index=True).drop_duplicates() # Create an edge dataframe with edges from source edges
                                                                      # and temporary target edges
    e['edge'] = 1
    A = e.pivot(index='Vs_1', columns='Vs_2').fillna(0).astype(pd.SparseDtype("float", 0))
    A = torch.Tensor(np.array(A)).to(device)

    return idx_dict, v_idx_dict, idx_s_test, idx_t_test, v_idx_t_dict, v_s_train, v_s_test, v_dict, e, es, et, A

In [12]:
idx_dict, v_idx_dict,\
idx_s_test, idx_t_test, v_idx_t_dict,\
v_s_train, v_s_test, v_dict,\
e, es, et, A = load_graph_file(vst_gt_path, vst_path, es_path, et_path, device)

In [9]:
def DTQW_biased(A, steps, starting):
    N = A.shape[1] # number of vertices
    D = A.sum(axis=1).reshape([N, 1]) # neighbor counts
    DD = D.expand(-1, N) # neighbor counts copied row-wise

    psi = torch.zeros(A.shape).to(A.device)
    psi[starting] = (A[starting] / D[starting]).sqrt()
    prob_vec_steps = torch.zeros(steps, N).to(A.device)

    for i in range(steps):
        T1 = (2 / DD - 1) * psi.permute([1, 0]) # first term in QW
        T2 = 2 / DD * psi.permute([1, 0]) # transpose psi for calculation of second term
        T2 = T2.sum(axis=1).reshape([N, 1]) * A - T2 # each entry of each row of the resulting psi
                                                     # receives the sum of contributions all terms
                                                     # from T2, except the transpose term, so need
                                                     # to subtract T2
        psi = T1 + T2
        prob_vec = (psi.abs() ** 2).sum(axis=1) # sum row-wise to get probability at a certain vertex
        prob_vec_steps[i] = prob_vec

    return prob_vec_steps.sum(axis=0) # sum column-wise to get probability at a certain vertex over all steps

def load_or_write_states(A, steps):
    N = A.shape[0]
    qw_path = f'{output_path}{data}_{percentage}_qw_{steps}steps.pt'
    if os.path.exists(qw_path):
        states = torch.load(qw_path)
    else:
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        states = torch.zeros(A.shape).to(A.device)
        for i in range(N):
            states[i] = DTQW_biased(A, steps, i)
        torch.save(states, qw_path)
    return states

class CustomTensorDataset(Dataset):
    def __init__(self, tensor):
        self.tensor = tensor

    def __getitem__(self, index):
        data = self.tensor[index]
        return data, index

    def __len__(self):
        return len(self.tensor)

# def initial_state(A, B, i):
#     mask = torch.zeros(A.shape)
#     mask[i] = 1 / B[i].sqrt()
#     return A * mask

# def rw(state, A, B):
#     diagonal_term = 2 * A.matmul(state).diag() / B
#     diagonal_term = diagonal_term.reshape((diagonal_term.numel(), 1))
#     transpose_term = A * state.permute([1, 0])
#     return diagonal_term * A - transpose_term

# def RW(A, B, T):
#     for i in range(A.shape[0]):
#         state = initial_state(A, B, i)
#         for t in range(T):
#             state = rw(state, A, B)
#         yield state

In [10]:
class UnitaryEntanglementLayer(tq.QuantumModule):
    def __init__(self, n_wires):
        super().__init__()
        self.rz1 = Op1QAllLayer(op_name_dict['rz'], n_wires=n_wires, has_params=True, trainable=True)
        self.ry = Op1QAllLayer(op_name_dict['ry'], n_wires=n_wires, has_params=True, trainable=True)
        self.rz2 = Op1QAllLayer(op_name_dict['rz'], n_wires=n_wires, has_params=True, trainable=True)
        self.ent = Op2QAllLayer(op_name_dict['cnot'], n_wires=n_wires, jump=1, circular=True)

    def forward(self, q_device):
        self.rz1(q_device)
        self.ry(q_device)
        self.rz2(q_device)
        self.ent(q_device)

class VQNE(tq.QuantumModule):
    def __init__(self, n_wires):
        super().__init__()
        self.n_wires = n_wires
        self.uelayers = tq.QuantumModuleList([UnitaryEntanglementLayer(n_wires=self.n_wires) for _ in range(8)])

    def forward(self, states):
        bsz = states.shape[0]
        q_device = tq.QuantumDevice(n_wires=self.n_wires, bsz=bsz)
        q_device.set_states(states)
        # print(q_device)

        for uelayer in self.uelayers:
            uelayer(q_device)

        return q_device.get_states_1d()

In [55]:
def forward_one_step(states, model, idx_pos, idx_neg):
    states_emb = model(states)

    states_neg = states_all[idx_neg]
    states_neg = model(states_neg)
    states_neg = states_neg.reshape([states_neg.shape[0] // states_emb.shape[0], states_emb.shape[0], -1])
    inner_prod_neg = (states_emb.conj() * states_neg).real.sum(-1)

    states_pos = states_all[idx_pos]
    states_pos = model(states_pos)
    inner_prod_pos = (states_emb.conj() * states_pos).real.sum(-1)
    inner_prod_pos = inner_prod_pos.expand([states_neg.shape[0], -1])

    zeros = torch.zeros(inner_prod_neg.shape).to(states_emb.device)
    loss_MT = torch.maximum(inner_prod_neg - inner_prod_pos + 0.07, zeros)
    loss_MT = loss_MT.sum(-1).mean()
    loss_EPE = (states_emb - states_pos).norm() ** 2

    # print(f'states_emb:\n {states_emb}')
    # print(f'states_neg:\n {states_neg}')
    # print(f'states_pos:\n {states_pos}')
    # print(f'loss_MT: {loss_MT}')
    # print(f'loss_EPE: {loss_EPE}\n')

    return loss_MT + 0.02 * loss_EPE

def shift_and_run(func, states, model, idx_pos, idx_neg):
    param_list = []
    grad_list = []
    for param in model.parameters():
        param_list.append(param)
    for param in param_list:
        param.copy_(param + np.pi * 0.5)
        loss1 = func(states, model, idx_pos, idx_neg)
        param.copy_(param - np.pi)
        loss2 = func(states, model, idx_pos, idx_neg)
        param.copy_(param + np.pi * 0.5)
        grad = 0.5 * (loss1 - loss2)
        grad_list.append(grad)
    return func(states, model, idx_pos, idx_neg), grad_list

def train(dataloader, model, optimizer):
    idx_all = np.array(range(states_all.shape[0]))

    for states, idx in dataloader:
        idx_pos = [idx_dict[id.item()] for id in idx]
        idx_neg_candidate = np.delete(idx_all, idx.tolist() + idx_pos)
        idx_neg = np.random.choice(idx_neg_candidate, states.shape[0] * 10)

        with torch.no_grad():
            loss, grad_list = shift_and_run(forward_one_step, states, model, idx_pos, idx_neg)

        optimizer.zero_grad()

        for i, param in enumerate(model.parameters()):
            param.grad = grad_list[i].to(dtype=torch.float32, device=param.device).view(param.shape)

        optimizer.step()

        print(f'\rloss: {loss.item()}', end='')

def train_backprop(dataloader, model, optimizer):
    idx_all = np.array(range(states_all.shape[0]))

    for states, idx in dataloader:
        idx_pos = [idx_dict[id.item()] for id in idx]
        idx_neg_candidate = np.delete(idx_all, idx.tolist() + idx_pos)
        idx_neg = np.random.choice(idx_neg_candidate, states.shape[0] * 10)

        loss = forward_one_step(states, model, idx_pos, idx_neg)

        optimizer.zero_grad()
        loss.backward()

        grad_list = []
        for param in model.parameters():
            grad_list.append(param.grad)

        optimizer.step()

        print(f'\rloss: {loss.item()}', end='')

def validate(states_source, dataloader, model):
    # count_10 = 0
    # count_5  = 0
    # count_1  = 0
    # vst_10   = []
    # vst_5    = []
    # vst_1    = []
    with torch.no_grad():
        states_source_emb = model(states_source)
        states_source_emb = states_source_emb.reshape([states_source_emb.shape[0], 1, states_source_emb.shape[1]]) # shape (num of test vertices, num of total vertices)

        # idx_t_test_to_idx = {idx_t: idx for idx, idx_t in enumerate(idx_t_test)} # index the target vertex indices, where idx will be used in states_target_emb
        # idx_to_idx_t_test = {idx: idx_t for idx, idx_t in enumerate(idx_t_test)} # revert the indexing
        # idx_t_test_pos_list = [idx_t_test_to_idx[idx_dict[i]] for i in idx_s_test] # select those idx in idx_t_test_dict that correspond to test source vertices
        # idx_t_test_pos_tensor = torch.Tensor(idx_t_test_pos_list).to(states_all.device)

        sim = []
        for states_target, _ in dataloader:
            states_target_emb = model(states_target)
            sim.append((states_source_emb.conj() * states_target_emb).real.sum(-1)) # shape (num of test vertices, bsz)
        sim = torch.concat(sim, axis=1) # shape (num of test vertices, num of total vertices)
        sim = sim.sort()

        for k in [10, 5, 1]:
            mask_k = (sim.indices[:, :k] == idx_t_test_pos.unsqueeze(1)).any(1) # reshape the idx_t_test_tensor to (num of test vertices, 1) to broadcast for comparison
            count_k = mask_k.sum()
            idx_t_test_pos_selected = idx_t_test_pos[mask_k]
            v_t_test_pos_selected = [idx_v_t_test[i.item()] for i in idx_t_test_pos_selected]
            vst_k = [(v_dict[i], i) for i in v_t_test_pos_selected]
            precision_k = count_k / len(idx_s_test)
            print(f'Precision@{k}: {precision_k}')
            yield precision_k, vst_k


    #     for idx in idx_s_test:
    #         sim = []
    #         state = states_all[[idx]]
    #         state_emb = model(state)
    #         for states, _ in dataloader:
    #             states_emb = model(states)
    #             sim.append((state_emb.conj() * states_emb).real.sum(-1))
    #             del states, states_emb

    #         sim = torch.concat(sim)
    #         sim = sim.sort()

    #         idx_pos = idx_dict[idx]
    #         if idx_pos in sim.indices[:10]:
    #             count_10 += 1
    #             vst_10.append(v_idx_dict[(idx, idx_pos)])
    #             if idx_dict[idx] in sim.indices[:5]:
    #                 count_5 += 1
    #                 vst_5.append(v_idx_dict[(idx, idx_pos)])
    #                 if idx_dict[idx] in sim.indices[:1]:
    #                     count_1 += 1
    #                     vst_1.append(v_idx_dict[(idx, idx_pos)])
    # precision_10 = count_10 / len(idx_s_test)
    # precision_5  = count_5  / len(idx_s_test)
    # precision_1  = count_1  / len(idx_s_test)
    # print(f'Precision@10: {precision_10}')
    # print(f'Precision@5:  {precision_5}')
    # print(f'Precision@1:  {precision_1}')
    # return precision_10, precision_5, precision_1, vst_10, vst_5, vst_1

def analyse(vertex_pair_list):
    vertex_pair_data = dict()
    for vertex_pair in vertex_pair_list:
        vs = vertex_pair[0]
        vt = vertex_pair[1]
        neighbor_s = es[es['Vs_1'] == vs]['Vs_2']
        neighbor_t = et[et['Vt_1'] == vt]['Vt_2']
        degree_s = len(neighbor_s)
        degree_t = len(neighbor_t)

        common_neighbor = []
        common_anchor_neighbor = []
        for neighbor in neighbor_s:
            if v_dict[neighbor] in neighbor_t.values:
                common_neighbor.append((neighbor, v_dict[neighbor]))
                if neighbor in v_s_train:
                    common_anchor_neighbor.append((neighbor, v_dict[neighbor]))
        vertex_pair_data[vertex_pair] = [degree_s, degree_t, common_neighbor, common_anchor_neighbor]
    return vertex_pair_data

In [None]:
seed = 42
G = torch.Generator()
G.manual_seed(seed)

states_all = load_or_write_states(A, steps=steps)
pad = (0, 2 ** n_wires - states_all.shape[1])
states_all = F.pad(states_all, pad, 'constant', 0)
states_all = states_all.to(torch.complex64)
dataset = CustomTensorDataset(tensor=states_all)
sampler_train = RandomSampler(dataset, generator=G)
dataloader_train = DataLoader(dataset, sampler=sampler_train, batch_size=bsz)

states_source = states_all[idx_s_test]
states_target = states_all[idx_t_test]
dataset_val = CustomTensorDataset(tensor=states_target)
sampler_val = SequentialSampler(dataset_val)
dataloader_val = DataLoader(dataset_val, sampler=sampler_val, batch_size=bsz)

v_idx_t_test = {v_t: idx for idx, v_t in enumerate(v_idx_t_dict.keys())} # index the target vertices, where idx are used in states_target_emb
idx_v_t_test = {idx: v_t for idx, v_t in enumerate(v_idx_t_dict.keys())} # revert the indexing
idx_t_test_pos = [v_idx_t_test[v_dict[i]] for i in v_s_test]
idx_t_test_pos = torch.Tensor(idx_t_test_pos).to(states_all.device)

np.random.seed(42)
model = VQNE(n_wires=n_wires).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.1)

for epoch in range(1, n_epochs + 1):
    print(f'Epoch: {epoch}')

    # train(dataloader_train, model, optimizer)
    train_backprop(dataloader_train, model, optimizer)

    best_precision = 0
    best_precision_list = []
    (precision_10, vst_10), (precision_5, vst_5), (precision_1, vst_1) = validate(states_source, dataloader_val, model)
    if precision_1 >= best_precision:
        best_precision = precision_1
        best_precision_list = vst_1
        print('Saving model')
        torch.save(model.state_dict(), f'{output_path}{data}_{percentage}_qw_{steps}steps_model_weights.pt')

Epoch: 1
loss: 0.7556607127189636Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 2
loss: 0.7537938952445984Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 3
loss: 0.9742156267166138Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 4
loss: 0.8598842620849609Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 5
loss: 0.9707614183425903Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 6
loss: 0.7547695636749268Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 7
loss: 0.965024471282959Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 8
loss: 0.9653204679489136Precision@10: 0.00035223670420236886
Precision@5: 0.0
Precision@1: 0.0
Saving model
Epoch: 9
loss: 3.3353610038757324

In [None]:
t1 = torch.Tensor([[1, 1, 1, 1], [-1, -2, -1, -1]])
t2 = torch.Tensor([[1, 2, 3, 4], [13, 14, 15, 16], [5, 6, 7, 8], [9, 10, 11, 12], [17, 18, 19, 20], [21, 22, 23, 24]])
print(f'original t1:\n {t1}')
print(f'original t1 shape: {t1.shape}\n')
print(f'original t2:\n {t2}')
print(f'original t2 shape: {t2.shape}\n')

t1 = t1.reshape([t1.shape[0], 1, t1.shape[1]])
print(f't1 after adding middle dimension:\n {t1}')
print(f't1 shape after adding middle dimension: {t1.shape}\n')

# t1 = t1.expand([-1, t2.shape[0] // t1.shape[0], -1])
# print(f't1 after expansion:\n {t1}')
# print(f't1 shape after expansion: {t1.shape}\n')

# t1 = t1.reshape(t2.shape)
# print(f't1 after reshaping to t2 shape:\n {t1}')
# print(f't1 shape after reshaping to t2 shape: {t1.shape}\n')

######################################################################
# t2 = t2.reshape([t2.shape[0] // t1.shape[0], t1.shape[0], -1])
# print(f't2 after reshaping:\n {t2}')
# print(f't2 shape after reshaping: {t2.shape}')
######################################################################

t3 = t1 * t2
print(f't1 * t2:\n {t3}')
print(f't1 * t2 shape: {t3.shape}')

t3 = t3.sum(-1)
print(f't1 * t2 summed along the last dimension:\n {t3}')

t3 = t3.sort()
print(f't1 * t2 sorted:\n {t3}')

original t1:
 tensor([[ 1.,  1.,  1.,  1.],
        [-1., -2., -1., -1.]])
original t1 shape: torch.Size([2, 4])

original t2:
 tensor([[ 1.,  2.,  3.,  4.],
        [13., 14., 15., 16.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.]])
original t2 shape: torch.Size([6, 4])

t1 after adding middle dimension:
 tensor([[[ 1.,  1.,  1.,  1.]],

        [[-1., -2., -1., -1.]]])
t1 shape after adding middle dimension: torch.Size([2, 1, 4])

t1 * t2:
 tensor([[[  1.,   2.,   3.,   4.],
         [ 13.,  14.,  15.,  16.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.],
         [ 17.,  18.,  19.,  20.],
         [ 21.,  22.,  23.,  24.]],

        [[ -1.,  -4.,  -3.,  -4.],
         [-13., -28., -15., -16.],
         [ -5., -12.,  -7.,  -8.],
         [ -9., -20., -11., -12.],
         [-17., -36., -19., -20.],
         [-21., -44., -23., -24.]]])
t1 * t2 shape: torch.Size([2, 6, 4])
t1 * t2 summed al

In [None]:
t4 = torch.Tensor([0, 3])
t4 = t4.reshape(t4.shape[0], -1)
print(f't4:\n {t4}')
t5 = t3.indices[:, :2]
print(f't5:\n {t5}')

t6 = (t5 - t4)#.prod(1)#.count_nonzero()
print(f't6:\n {t6}')

t7 = t6.count_nonzero()
print(f't7:\n {t7}')

t6.numel() - t7

t4:
 tensor([[0.],
        [3.]])
t5:
 tensor([[0, 2],
        [5, 4]])
t6:
 tensor([[0., 2.],
        [2., 1.]])
t7:
 3


tensor(1)

In [None]:
t1 = torch.Tensor([[1, 1, 1, 1]])
t2 = torch.Tensor([[5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]])
print(t1.shape)
print(t2.shape)
t3 = (t1 * t2).sum(-1).sort()

torch.Size([1, 4])
torch.Size([6, 4])


In [None]:
t3.indices[:1]

tensor([[0, 2, 3, 1, 4, 5]])