In [None]:
import time
import torch
import numpy as np
import random
from torch.utils.data import DataLoader

from data_prepare import data_partition,NeighborFinder
from model import PTGCN
from modules import TimeEncode,MergeLayer,time_encoding
import pandas as pd

In [2]:
class Config(object):
    """config."""
    data = 'Moivelens'
    batch_size = 64
    n_degree = [20,50]  #'Number of neighbors to sample'
    n_head = 4  #'Number of heads used in attention layer'
    n_epoch = 50 #'Number of epochs'
    n_layer = 2 #'Number of network layers'
    lr = 0.0001  #'Learning rate'
    patience = 25  #'Patience for early stopping'
    drop_out = 0.1  #'Dropout probability'
    gpu = 0,  #'Idx for the gpu to use'
    node_dim = 160  #'Dimensions of the node embedding'
    time_dim = 160  #'Dimensions of the time embedding'
    embed_dim = 160 #'Dimensions of the hidden embedding'
    is_GPU = True
    temperature = 0.07

In [None]:
def evaluate(model, ratings, items, dl, adj_user_edge, adj_item_edge, adj_user_time, adj_item_time, device):

    torch.cuda.empty_cache()
    NDCG5 = 0.0
    NDCG10 = 0.0
    recall5 = 0.0
    recall10 =0.0
    num_sample = 0
    
    with torch.no_grad():
        model = model.eval()
        
        for ix,batch in enumerate(dl):
            #if ix%100==0:
               # print('batch:',ix)
            count = len(batch)
            num_sample = num_sample + count
            b_user_edge = find_latest_1D(np.array(ratings.iloc[batch]['user_id']), adj_user_edge, adj_user_time, ratings.iloc[batch]['timestamp'].tolist())
            b_user_edge = torch.from_numpy(b_user_edge).to(device)
            b_users = torch.from_numpy(np.array(ratings.iloc[batch]['user_id'])).to(device) 
            
            b_item_edge = find_latest_1D(np.array(ratings.iloc[batch]['item_id']), adj_item_edge, adj_item_time, ratings.iloc[batch]['timestamp'].tolist())
            b_item_edge = torch.from_numpy(b_item_edge).to(device)
            b_items = torch.from_numpy(np.array(ratings.iloc[batch]['item_id'])).to(device)
            timestamps = torch.from_numpy(np.array(ratings.iloc[batch]['timestamp'])).to(device)
            
            negative_samples = sampler(items, adj_user, ratings.iloc[batch]['user_id'].tolist() ,100)  
            neg_edge = find_latest(negative_samples, adj_item_edge, adj_item_time, ratings.iloc[batch]['timestamp'].tolist())
            negative_samples = torch.from_numpy(np.array(negative_samples)).to(device)
            item_set = torch.cat([b_items.view(-1,1),negative_samples], dim=1) #batch, 101
            timestamps_set = timestamps.unsqueeze(1).repeat(1,101)
            neg_edge = torch.from_numpy(neg_edge).to(device)
            edge_set = torch.cat([b_item_edge.view(-1,1),neg_edge], dim=1) #batch, 101
            
            user_embeddings = model(b_users, b_user_edge,timestamps, config.n_layer, nodetype='user')
            itemset_embeddings = model(item_set.flatten(), edge_set.flatten(), timestamps_set.flatten(), config.n_layer, nodetype='item')
            itemset_embeddings = itemset_embeddings.view(count, 101, -1)
            
            logits = torch.bmm(user_embeddings.unsqueeze(1), itemset_embeddings.permute(0,2,1)).squeeze(1) # [count,101]
            logits = -logits.cpu().numpy()
            rank = logits.argsort().argsort()[:,0]
            
            recall5 += np.array(rank<5).astype(float).sum()
            recall10 += np.array(rank<10).astype(float).sum()
            NDCG5 += (1 / np.log2(rank + 2))[rank<5].sum()
            NDCG10 += (1 / np.log2(rank + 2))[rank<10].sum()
            
        recall5 = recall5/num_sample
        recall10 = recall10/num_sample
        NDCG5 = NDCG5/num_sample
        NDCG10 = NDCG10/num_sample
            
        print("===> recall_5: {:.10f}, recall_10: {:.10f}, NDCG_5: {:.10f}, NDCG_10: {:.10f}, time:{}".format(recall5, recall10, NDCG5, NDCG10, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))))

    return recall5, recall10, NDCG5, NDCG10

In [None]:
def sampler(items, adj_user, b_users, size):
    negs = []
    for user in b_users:      
        houxuan = list(set(items)-set(adj_user[user]))
        src_index = random.sample(list(range(len(houxuan))), size)
        negs.append(np.array(houxuan)[src_index])
    negs = np.array(negs)
    return negs

In [None]:
def find_latest(nodes, adj, adj_time, timestamps):
    #negative_samples, [b,size]
    edge = np.zeros_like(nodes)
    for ix in range(nodes.shape[0]):
        for iy in range(nodes.shape[1]):
            node = nodes[ix, iy]
            edge_idx = np.searchsorted(adj_time[node], timestamps[ix])-1
            edge[ix, iy] = np.array(adj[node])[edge_idx]
    return edge

def find_latest_1D(nodes, adj, adj_time, timestamps):
    #negative_samples, [b,size]
    edge = np.zeros_like(nodes)
    for ix in range(nodes.shape[0]):
        node = nodes[ix]
        edge_idx = np.searchsorted(adj_time[node], timestamps[ix])-1
        edge[ix] = np.array(adj[node])[edge_idx]
    return edge

In [None]:
# i want to know the different between high metric and low metric
ratings = []
with open('./data/movielens//ml-1m/ratings.dat', 'r') as f:
    for l in f:
        user_id, item_id, rating, timestamp = [int(_) for _ in l.split('::')]
        ratings.append({
                'user_id': user_id,
                'item_id': item_id,
                'rating': rating,
                'timestamp': timestamp,
                })
print(len(ratings))
print(ratings[0])

In [None]:
import pandas as pd
ratings = pd.DataFrame(ratings)
print(ratings)

In [None]:
users = ratings['user_id'].unique()
items = ratings['item_id'].unique()  
print(len(users),len(items))
print(users)
print(min(items))

In [None]:
ratings['timestamp'] = ratings['timestamp'] - min(ratings['timestamp'])
print(min(ratings['timestamp']))

In [None]:
item_count = ratings['item_id'].value_counts()
item_count.name = 'item_count'
ratings = ratings.join(item_count, on='item_id')

user_count = ratings['user_id'].value_counts()
user_count.name = 'user_count'
ratings = ratings.join(user_count, on='user_id')


In [None]:
print(ratings)


In [None]:
ratings = ratings[(ratings['user_count'] >= 5) & (ratings['item_count'] >= 5)]


In [None]:
print(ratings)
len(ratings['user_id'].unique()) == len(users) and len(ratings['item_id'].unique()) == len(items)

In [None]:
users = ratings['user_id'].unique()
items = ratings['item_id'].unique()
len(users), len(items)

In [None]:
del ratings['user_count']
del ratings['item_count']
print(ratings)

In [None]:
item_count = ratings['item_id'].value_counts()
item_count.name = 'item_count'
ratings = ratings.join(item_count, on='item_id')

user_count = ratings['user_id'].value_counts()
user_count.name = 'user_count'
ratings = ratings.join(user_count, on='user_id')

In [None]:
print(ratings)

In [None]:
ratings = ratings[(ratings['user_count'] >= 5) & (ratings['item_count'] >= 5)]
ratings

In [None]:
len(ratings['user_id'].unique()) == len(users) and len(ratings['item_id'].unique()) == len(items)

In [None]:
del ratings['user_count']
del ratings['item_count']

In [None]:
users = ratings['user_id'].unique()
items = ratings['item_id'].unique()
len(users),len(items)

In [None]:
item_count = ratings['item_id'].value_counts()
item_count.name = 'item_count'
ratings = ratings.join(item_count, on='item_id')

user_count = ratings['user_id'].value_counts()
user_count.name = 'user_count'
ratings = ratings.join(user_count, on='user_id')
ratings = ratings[(ratings['user_count'] >=5) & (ratings['item_count'] >= 5)]
ratings

In [None]:
users = ratings['user_id'].unique()
items = ratings['item_id'].unique()
len(users), len(items)

In [None]:
user_ids_invmap = {id_: i for i, id_ in enumerate(users)}
item_ids_invmap = {id_: i for i, id_ in enumerate(items)}
ratings['user_id'].replace(user_ids_invmap, inplace=True)
ratings['item_id'].replace(item_ids_invmap, inplace=True)

In [None]:
ratings

In [None]:
print('user_count:'+str(len(users))+','+'item_count:'+str(len(items)))
print('avr of user:'+str(ratings['user_id'].value_counts().mean())+'avr of item:'+str(ratings['item_id'].value_counts().mean()))
print(len(ratings))

In [None]:
users = ratings['user_id'].unique()
items = ratings['item_id'].unique()

In [None]:
ratings = ratings.sort_values(by='timestamp',ascending=True)
print(ratings)

In [None]:
ratings = ratings.reset_index(drop=True)
print(ratings)

In [None]:
full_data = []

In [None]:
adj_user = {cur_user:ratings[ratings.user_id == cur_user].index.tolist() for cur_user in users} 
adj_item = {cur_item:ratings[ratings.item_id == cur_item].index.tolist() for cur_item in items}
print(adj_user)
print(adj_item)
    

In [None]:
print(ratings.shape)

In [None]:
print(adj_user[6039].index(0))

In [None]:
for i in range(ratings.shape[0]):  #edge ID
    cur_user = ratings['user_id'].iloc[i]
    cur_item = ratings['item_id'].iloc[i]
    #确保训练集和测试集中的序列至少含有3个邻居
    #这里没看懂，总之就是每个用户或者项目的前三个交互被省略了
    if adj_user[cur_user].index(i)>=3 and adj_item[cur_item].index(i)>=3:
        full_data.append(i)

In [None]:
offset1 = int(len(full_data) * 0.8)
offset2 = int(len(full_data) * 0.9)
print(len(full_data),offset1,offset2)

In [None]:
import random
random.shuffle(full_data)
train_data, valid_data, test_data = full_data[0:offset1], full_data[offset1:offset2], full_data[offset2:len(full_data)]
print(len(train_data),len(valid_data),len(test_data))

In [None]:
del ratings['rating']
del ratings['user_count']
del ratings['item_count']

In [None]:
print(ratings.columns)

In [None]:
print(ratings)

In [None]:
print(valid_data)

In [None]:
import time
import torch
import numpy as np
import random
from torch.utils.data import DataLoader

from data_prepare import data_partition,NeighborFinder
from model import PTGCN
from modules import TimeEncode,MergeLayer,time_encoding
from data_prepare import data_partition,NeighborFinder
ratings, train_data, valid_data, test_data = data_partition('data/movielens/ml-1m')
print(ratings.shape,len(train_data),len(valid_data),len(test_data))
users = ratings['user_id'].unique()
items = ratings['item_id'].unique() 
items_in_data = ratings.iloc[train_data+valid_data+test_data]['item_id'].unique()
print(items_in_data,len(items_in_data))
adj_user = {user: ratings[ratings.user_id == user]['item_id'].tolist() for user in users}
adj_user_edge = {user:ratings[ratings.user_id == user].index.tolist() for user in users}
adj_user_time = {user:ratings[ratings.user_id == user]['timestamp'].tolist() for user in users} 

adj_item_edge = {item:ratings[ratings.item_id == item].index.tolist() for item in items}
adj_item_time = {item:ratings[ratings.item_id == item]['timestamp'].tolist() for item in items} 
num_users = len(users)
num_items = len(items)
print(num_users,num_items)
neighor_finder = NeighborFinder(ratings)
time_encoder = time_encoding(160)
MLPLayer = MergeLayer(160, 160, 160, 1)
a_users = np.array(ratings['user_id'])
a_items = np.array(ratings['item_id'])
edge_idx = np.arange(0, len(a_users))
print(a_users.shape,a_items.shape,edge_idx.shape)

In [None]:
n_neighbors = 50
source_idx = a_users
adj_user = np.zeros((len(edge_idx), 50), dtype=np.int32)
adj_user.shape
user_mask = np.ones((len(edge_idx), n_neighbors), dtype=np.bool)
user_time = np.zeros((len(edge_idx), n_neighbors), dtype=np.int32)  # time matirx，节点与其他max_nodes的时间差
adj_user_edge = np.zeros((len(source_idx), n_neighbors), dtype=np.int32)
user_mask.shape, user_time.shape, adj_user_edge.shape

In [None]:
# print(neighor_finder.user_edgeidx[source_idx[0]])
# print(edge_idx[0])
i = 999610
  

In [None]:
idx = np.searchsorted(neighor_finder.user_edgeidx[source_idx[i]], edge_idx[i]) + 1  
print(idx)
his_len = len(neighor_finder.user_edgeidx[source_idx[i]][:idx])
print(his_len)
used_len = his_len if his_len <= n_neighbors else n_neighbors
print(used_len)
n_ratings = np.array(ratings)
test_list = neighor_finder.user_edgeidx[source_idx[i]]
test_list.shape, idx - used_len, idx, test_list[0:1]
adj_user[i, n_neighbors - used_len:] = n_ratings[:,1][neighor_finder.user_edgeidx[source_idx[i]][idx - used_len:idx]]
user_time[i, n_neighbors - used_len:] = n_ratings[:,2][neighor_finder.user_edgeidx[source_idx[i]][idx - used_len:idx]]
user_mask[i, n_neighbors - used_len:] = 0
adj_user_edge[i, n_neighbors - used_len:] = neighor_finder.user_edgeidx[source_idx[i]][idx - used_len:idx]
print(adj_user[i])


In [None]:
print(a_users[-1])

In [None]:
print(neighor_finder.user_edgeidx[4957])


In [None]:
test_user_ratings = ratings[ratings['user_id']==4957]
print(test_user_ratings)

In [None]:
test_user_adj_50 = test_user_ratings['item_id'].to_list()
test_user_adj_50[-50:]

In [None]:
class Config(object):
    """config."""
    data = 'Moivelens'
    batch_size = 64
    n_degree = [20,50]  #'Number of neighbors to sample'
    n_head = 4  #'Number of heads used in attention layer'
    n_epoch = 50 #'Number of epochs'
    n_layer = 2 #'Number of network layers'
    lr = 0.0001  #'Learning rate'
    patience = 25  #'Patience for early stopping'
    drop_out = 0.1  #'Dropout probability'
    gpu = 0,  #'Idx for the gpu to use'
    node_dim = 160  #'Dimensions of the node embedding'
    time_dim = 160  #'Dimensions of the time embedding'
    embed_dim = 160 #'Dimensions of the hidden embedding'
    is_GPU = True
    temperature = 0.07

In [None]:
import time
import torch
import numpy as np
import random
from torch.utils.data import DataLoader

from data_prepare import data_partition,NeighborFinder
from model import PTGCN
from modules import TimeEncode,MergeLayer,time_encoding
from data_prepare import data_partition,NeighborFinder

config = Config()
checkpoint_dir='/models'  
min_NDCG10 = 1000.0
max_itrs = 0

device_string = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

print("loading the dataset...")

ratings, train_data, valid_data, test_data = data_partition('data/movielens/ml-1m')
print(ratings.shape,len(train_data),len(valid_data),len(test_data))
users = ratings['user_id'].unique()
items = ratings['item_id'].unique() 
items_in_data = ratings.iloc[train_data+valid_data+test_data]['item_id'].unique()
print(items_in_data,len(items_in_data))
adj_user = {user: ratings[ratings.user_id == user]['item_id'].tolist() for user in users}
adj_user_edge = {user:ratings[ratings.user_id == user].index.tolist() for user in users}
adj_user_time = {user:ratings[ratings.user_id == user]['timestamp'].tolist() for user in users} 

adj_item_edge = {item:ratings[ratings.item_id == item].index.tolist() for item in items}
adj_item_time = {item:ratings[ratings.item_id == item]['timestamp'].tolist() for item in items} 
num_users = len(users)
num_items = len(items)
print(num_users,num_items)
neighor_finder = NeighborFinder(ratings)
time_encoder = time_encoding(160)
MLPLayer = MergeLayer(160, 160, 160, 1)
a_users = np.array(ratings['user_id'])
a_items = np.array(ratings['item_id'])
edge_idx = np.arange(0, len(a_users))
print(a_users.shape,a_items.shape,edge_idx.shape)
user_neig50 = neighor_finder.get_user_neighbor_ind(a_users, edge_idx, max(config.n_degree), device)
item_neig50 = neighor_finder.get_item_neighbor_ind(a_items, edge_idx, max(config.n_degree), device)

criterion = torch.nn.CrossEntropyLoss(reduction='sum')

In [None]:
def sampler(items, adj_user, b_users, size):
    negs = []
    for user in b_users:      
        houxuan = list(set(items)-set(adj_user[user]))
        src_index = random.sample(list(range(len(houxuan))), size)
        negs.append(np.array(houxuan)[src_index])
    negs = np.array(negs)
    return negs

def find_latest(nodes, adj, adj_time, timestamps):
    #negative_samples, [b,size]
    edge = np.zeros_like(nodes)
    for ix in range(nodes.shape[0]):
        for iy in range(nodes.shape[1]):
            node = nodes[ix, iy]
            edge_idx = np.searchsorted(adj_time[node], timestamps[ix])-1
            edge[ix, iy] = np.array(adj[node])[edge_idx]
    return edge

def find_latest_1D(nodes, adj, adj_time, timestamps):
    #negative_samples, [b,size]
    edge = np.zeros_like(nodes)
    for ix in range(nodes.shape[0]):
        node = nodes[ix]
        edge_idx = np.searchsorted(adj_time[node], timestamps[ix])-1
        edge[ix] = np.array(adj[node])[edge_idx]
    return edge

In [None]:
print(num_users,num_items,config.n_head,config.drop_out)
model = PTGCN(user_neig50, item_neig50, num_users, num_items,
                 time_encoder, config.n_layer,  config.n_degree, config.node_dim, config.time_dim,
                 config.embed_dim, device, config.n_head, config.drop_out
                 ).to(device)

In [None]:
optim = torch.optim.Adam(model.parameters(),lr=config.lr)

num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(num_params)

In [None]:
dl = DataLoader(train_data, config.batch_size, shuffle=True, pin_memory=True)
itrs = 0
sum_loss=0

In [None]:
for epoch in range(config.n_epoch):
    time1 = 0.0
    x=0.0
    for id,batch in enumerate(dl):
        #print('epoch:',epoch,' batch:',id)
        x=x+1
        # print(batch.shape)
        optim.zero_grad()
        count = len(batch)
        print(count)
        
        b_user_edge = find_latest_1D(np.array(ratings.iloc[batch]['user_id']), adj_user_edge, adj_user_time, ratings.iloc[batch]['timestamp'].tolist())
        negative_samples = sampler(items_in_data, adj_user, ratings.iloc[batch]['user_id'].tolist() ,1) 
        print(b_user_edge.shape)
        user_embeddings = model(b_users, b_user_edge, timestamps, config.n_layer, nodetype='user')
        break
        # optim.zero_grad()
    print(x)
    break

In [None]:
print(len(train_data),len(valid_data),len(test_data))

In [None]:
print(min(train_data),max(train_data),min(valid_data),max(valid_data),min(test_data),max(test_data))

In [None]:


# adj_user[i, n_neighbors - used_len:] = n_ratings[:,1][neighor_finder.user_edgeidx[source_idx[i]][idx - used_len:idx]]       

In [None]:
ratings = []
with open('./data/movielens//ml-1m/ratings.dat', 'r') as f:
    for l in f:
        user_id, item_id, rating, timestamp = [int(_) for _ in l.split('::')]
        ratings.append({
                'user_id': user_id,
                'item_id': item_id,
                'rating': rating,
                'timestamp': timestamp,
                })
ratings = pd.DataFrame(ratings)
users = ratings['user_id'].unique() # 所有用户的id
items = ratings['item_id'].unique() # 所有item的id 
ratings['timestamp'] = ratings['timestamp'] - min(ratings['timestamp']) # 每条边的时间戳相当于初始时间戳过了多少时间
for i in range(1000):
    item_count = ratings['item_id'].value_counts()
    item_count.name = 'item_count'
    ratings = ratings.join(item_count, on='item_id')

    user_count = ratings['user_id'].value_counts()
    user_count.name = 'user_count'
    ratings = ratings.join(user_count, on='user_id')
    ratings = ratings[(ratings['user_count'] >= 5) & (ratings['item_count'] >= 5)]

    if len(ratings['user_id'].unique()) == len(users) and len(ratings['item_id'].unique()) == len(items):
        break
    users = ratings['user_id'].unique()
    items = ratings['item_id'].unique()
    del ratings['user_count']
    del ratings['item_count']

del ratings['user_count']
del ratings['item_count']

users = ratings['user_id'].unique()
items = ratings['item_id'].unique()

item_count = ratings['item_id'].value_counts()
item_count.name = 'item_count'
ratings = ratings.join(item_count, on='item_id')

user_count = ratings['user_id'].value_counts()
user_count.name = 'user_count'
ratings = ratings.join(user_count, on='user_id')
ratings = ratings[(ratings['user_count'] >=5) & (ratings['item_count'] >= 5)]

users = ratings['user_id'].unique()
items = ratings['item_id'].unique()

# 下面这几行代码相当于对用户和项目重新分配了id
# 为了方便理解，可以理解成tx_id2node_id
user_ids_invmap = {id_: i for i, id_ in enumerate(users)}
item_ids_invmap = {id_: i for i, id_ in enumerate(items)}
ratings['user_id'].replace(user_ids_invmap, inplace=True)
ratings['item_id'].replace(item_ids_invmap, inplace=True)

print('user_count:'+str(len(users))+','+'item_count:'+str(len(items)))
print('avr of user:'+str(ratings['user_id'].value_counts().mean())+'avr of item:'+str(ratings['item_id'].value_counts().mean()))
print(len(ratings))

users = ratings['user_id'].unique()
items = ratings['item_id'].unique()

# 这一步主要是根据时间戳进行交互的排序
ratings = ratings.sort_values(by='timestamp',ascending=True)  

ratings = ratings.reset_index(drop=True)


# 这里是记录了每个用户和每个项目各自参与了哪些交互，注意，list中保存的是交互的索引，而不是用户或者项目的id
adj_user = {cur_user:ratings[ratings.user_id == cur_user].index.tolist() for cur_user in users} 
adj_item = {cur_item:ratings[ratings.item_id == cur_item].index.tolist() for cur_item in items}
     

In [None]:
full_data = []
for i in range(ratings.shape[0]):  #edge ID
    cur_user = ratings['user_id'].iloc[i]
    cur_item = ratings['item_id'].iloc[i]
    #确保训练集和测试集中的序列至少含有3个邻居
    #这里没看懂，总之就是每个用户或者项目的前三个交互被省略了
    
    if adj_user[cur_user].index(i)>=3 and adj_item[cur_item].index(i)>=3:
        full_data.append(i)
print(len(full_data))
print(len(ratings))

In [None]:
print(min(full_data))
top3index = [adj_user[cur_user][:3] for cur_user in users]
print(top3index[:10])

In [3]:
import time
import torch
import numpy as np
import random
from torch.utils.data import DataLoader

from data_prepare import data_partition,NeighborFinder
from model import PTGCN
from modules import TimeEncode,MergeLayer,time_encoding
from data_prepare import data_partition,NeighborFinder

config = Config()
checkpoint_dir='/models'  
min_NDCG10 = 1000.0
max_itrs = 0

device_string = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

print("loading the dataset...")

ratings, train_data, valid_data, test_data = data_partition('data/movielens/ml-1m')


loading the dataset...
user_count:6040,item_count:3416
avr of user:165.49850993377484avr of item:292.6261709601874
999611
Index(['user_id', 'item_id', 'timestamp'], dtype='object')


In [6]:
train_time = ratings['timestamp'].iloc[np.array(train_data)].tolist()
valid_time = ratings['timestamp'].iloc[np.array(valid_data)].tolist()
test_time = ratings['timestamp'].iloc[np.array(test_data)].tolist()

In [7]:
print(min(train_time),max(train_time))
print(min(valid_time),max(valid_time))
print(min(test_time),max(test_time))

7750 89750658
7839 89734000
10018 89750350


In [8]:
testintrain = np.searchsorted(np.array(train_time),np.array(test_time))

In [9]:
print(testintrain)

[477190 758415 107072 ... 719833 546383 106318]


In [11]:
print(test_data)

[776191, 927667, 77052, 849557, 12750, 915901, 511968, 622467, 198066, 366836, 634144, 156933, 642339, 599901, 437444, 91308, 73307, 565444, 766233, 911144, 704903, 889824, 279489, 633656, 939101, 491724, 738298, 618296, 634204, 929679, 338679, 108030, 399011, 130718, 448018, 561950, 319209, 666688, 11878, 15383, 426718, 841077, 852344, 258083, 912163, 276666, 388354, 914472, 411457, 712036, 498661, 481116, 291668, 365713, 821799, 318040, 978266, 348584, 533525, 748143, 239310, 853886, 523870, 349534, 977570, 983254, 623680, 371328, 685702, 960561, 252363, 639876, 413175, 40808, 658051, 513900, 55715, 403715, 988037, 312242, 952692, 857129, 253404, 937062, 930817, 833236, 168192, 263592, 725023, 5732, 270966, 136953, 698449, 400222, 32302, 27775, 470680, 245196, 161306, 797932, 247761, 101909, 650557, 398199, 651927, 27983, 488976, 293594, 487767, 415594, 987706, 627214, 60397, 162143, 849799, 622297, 141284, 691698, 739842, 517410, 311403, 710033, 661886, 270336, 706145, 71014, 878099

In [1]:
import torch

In [16]:
embeddings = torch.randn(5,5)


In [19]:
a = torch.tensor([[1,3,2,0,4],[4,2,3,1,0]]).long()
print(a)
a_emb = embeddings[a]
# origin_shape = a.shape
b = torch.split(a,2, dim=1)
for i in b:
    print(i.shape)
items = []
item_idx = []
for i in b:
    i_fl = i.flatten()
    item_emb = embeddings[i_fl]
    items.append(item_emb.view(2, -1, 5))
    item_idx.append(i_fl.view(2,-1))
    print(i)
    print(i_fl)
c = torch.cat(item_idx,dim=1)
embd_all = torch.cat(items, dim = 1)
# c = c.view(*origin_shape)
print(c)
# print(a, a.shape)
# print(c, c.shape)
print(a_emb)
print(embd_all)
print(torch.all(a_emb == embd_all))

tensor([[1, 3, 2, 0, 4],
        [4, 2, 3, 1, 0]])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 1])
tensor([[1, 3],
        [4, 2]])
tensor([1, 3, 4, 2])
tensor([[2, 0],
        [3, 1]])
tensor([2, 0, 3, 1])
tensor([[4],
        [0]])
tensor([4, 0])
tensor([[1, 3, 2, 0, 4],
        [4, 2, 3, 1, 0]])
tensor([[[-0.7365, -0.3737,  0.0200,  0.6635,  0.0228],
         [-0.7230,  1.2125,  0.3219,  1.1276,  0.8007],
         [ 0.0548,  0.1133,  1.3466, -0.4650, -0.4110],
         [ 1.1402,  1.0056, -0.7389, -0.0954,  1.2557],
         [ 0.3674,  1.0609, -0.4308,  0.1871, -1.2656]],

        [[ 0.3674,  1.0609, -0.4308,  0.1871, -1.2656],
         [ 0.0548,  0.1133,  1.3466, -0.4650, -0.4110],
         [-0.7230,  1.2125,  0.3219,  1.1276,  0.8007],
         [-0.7365, -0.3737,  0.0200,  0.6635,  0.0228],
         [ 1.1402,  1.0056, -0.7389, -0.0954,  1.2557]]])
tensor([[[-0.7365, -0.3737,  0.0200,  0.6635,  0.0228],
         [-0.7230,  1.2125,  0.3219,  1.1276,  0.8007],
         [ 0.05