In [20]:
# 导入pandas模块
import pandas as pd
import datetime as dt

# Column names for the header-less, tab-separated check-in file.
col_names = ["user_id", "POI_id", "POI_catid", "poi_cat_name", "latitude", "longitude", "timezone", "UTC_time"]

# Inclusive activity thresholds: users / POIs with fewer check-ins are dropped.
# Named constants keep the filters and their description from drifting apart.
MIN_USER_CHECKINS = 5
MIN_POI_CHECKINS = 6

# Read the raw check-ins and attach the column names.
df = pd.read_csv("./dataset/dataset_tsmc2014/dataset_TSMC2014_NYC.txt", names=col_names,sep='\t')

# Number of check-ins made by each user.
user_counts = df.groupby("user_id").size()

# Users with at least MIN_USER_CHECKINS check-ins.
valid_users = user_counts[user_counts >= MIN_USER_CHECKINS].index

# Keep only the check-ins of those users.
df = df[df["user_id"].isin(valid_users)]

# Number of check-ins received by each POI, counted after the user filter.
poi_counts = df.groupby("POI_id").size()

# POIs with at least MIN_POI_CHECKINS check-ins.
valid_pois = poi_counts[poi_counts >= MIN_POI_CHECKINS].index

# Keep only the check-ins at those POIs. The user filter is not re-applied
# afterwards, so a user may be left with fewer than MIN_USER_CHECKINS rows.
df = df[df["POI_id"].isin(valid_pois)]
def get_local_fraction(utc_time, timezone_offset):
    """Return how far into the local day a check-in happened, as a fraction.

    Args:
        utc_time: Time string in the form ``"Sat Apr 07 17:42:24 +0000 2012"``
            (parsed with ``"%a %b %d %H:%M:%S %z %Y"``).
        timezone_offset: Offset in minutes that is added to the parsed time
            to obtain the local time (e.g. ``-240``).

    Returns:
        ``(hour + minute / 60) / 24`` of the shifted time, rounded to two
        decimals. Seconds are ignored. Because of the rounding, times in the
        last few minutes before midnight come out as ``1.0``.
    """
    # Parse the string and shift it by the offset in one step.
    local_moment = dt.datetime.strptime(utc_time, "%a %b %d %H:%M:%S %z %Y") + dt.timedelta(minutes=timezone_offset)
    # Hours elapsed since local midnight, with minutes as the fractional part.
    hours_elapsed = local_moment.hour + local_moment.minute / 60
    return round(hours_elapsed / 24, 2)
# Fraction of the local day at which each check-in happened.
df["norm_in_day_time"] = df.apply(lambda row: get_local_fraction(row["UTC_time"], row["timezone"]), axis=1)

# Check-in time as POSIX seconds (float). Dividing the distance from the epoch
# by one second does not depend on the datetime resolution or on the
# platform's default integer width, unlike `.astype(int) / 10**9`.
checkin_time = pd.to_datetime(df["UTC_time"], format="%a %b %d %H:%M:%S %z %Y", utc=True)
df["timestamp"] = (checkin_time - pd.Timestamp("1970-01-01", tz="UTC")) / pd.Timedelta(seconds=1)

# A gap longer than this many seconds between two consecutive check-ins of the
# same user starts a new trajectory.
MAX_GAP_SECONDS = 24 * 60 * 60

# Order the check-ins by user and then chronologically. Rows keep their
# original index labels.
ordered = df.sort_values(["user_id", "timestamp"])

# Seconds since the same user's previous check-in; NaN on each user's first row.
gap = ordered.groupby("user_id")["timestamp"].diff()

# A trajectory starts at a user's first check-in and after every long gap.
starts_trajectory = gap.isna() | (gap > MAX_GAP_SECONDS)

# Running count of trajectory starts within each user: 1, 2, 3, ...
traj_number = starts_trajectory.astype("int64").groupby(ordered["user_id"]).cumsum()

# trajectory_id has the form "<user_id>_<n>"; it is appended as the last column.
result_df = ordered.assign(
    trajectory_id=ordered["user_id"].astype(str) + "_" + traj_number.astype(str)
)

# Show the resulting table.
print(result_df)



        user_id                    POI_id                 POI_catid  \
3660          1  4d4ac10da0ef54814b6ffff6  4bf58dd8d48988d157941735   
5603          1  4db44994cda1c57c82583709  4bf58dd8d48988d1f1931735   
5783          1  4a541923f964a52008b31fe3  4bf58dd8d48988d14e941735   
6696          1  40f1d480f964a5205b0a1fe3  4bf58dd8d48988d143941735   
7666          1  3fd66200f964a52094e41ee3  4bf58dd8d48988d1cc941735   
...         ...                       ...                       ...   
220748     1083  49f4dca6f964a520626b1fe3  4bf58dd8d48988d1c1941735   
220756     1083  40b68100f964a52085001fe3  4bf58dd8d48988d116941735   
224179     1083  4eda64ced5fb8f213a5d740e  4bf58dd8d48988d176941735   
224647     1083  51140198e4b0874a568cde81  4bf58dd8d48988d162941735   
225359     1083  4a53d9a7f964a520c7b21fe3  4bf58dd8d48988d124941735   

                 poi_cat_name   latitude  longitude  timezone  \
3660      American Restaurant  40.784018 -73.974524      -240   
5603    General E

In [21]:
len(result_df['user_id'].unique()),len(result_df['POI_id'].unique())

(1083, 8446)

In [22]:
result_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,poi_cat_name,latitude,longitude,timezone,UTC_time,norm_in_day_time,timestamp,trajectory_id
3660,1,4d4ac10da0ef54814b6ffff6,4bf58dd8d48988d157941735,American Restaurant,40.784018,-73.974524,-240,Sat Apr 07 17:42:24 +0000 2012,0.57,1333821000.0,1_1
5603,1,4db44994cda1c57c82583709,4bf58dd8d48988d1f1931735,General Entertainment,40.739398,-73.99321,-240,Sun Apr 08 18:20:29 +0000 2012,0.6,1333909000.0,1_2
5783,1,4a541923f964a52008b31fe3,4bf58dd8d48988d14e941735,American Restaurant,40.785677,-73.976498,-240,Sun Apr 08 20:02:10 +0000 2012,0.67,1333915000.0,1_2
6696,1,40f1d480f964a5205b0a1fe3,4bf58dd8d48988d143941735,Breakfast Spot,40.719929,-74.008532,-240,Mon Apr 09 16:20:52 +0000 2012,0.51,1333988000.0,1_2
7666,1,3fd66200f964a52094e41ee3,4bf58dd8d48988d1cc941735,Steakhouse,40.734276,-73.993525,-240,Tue Apr 10 00:24:31 +0000 2012,0.85,1334017000.0,1_2
11804,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Thu Apr 12 17:19:21 +0000 2012,0.55,1334251000.0,1_3
14856,1,46ea2358f964a520cf4a1fe3,4bf58dd8d48988d11d941735,Bar,40.760667,-73.994948,-240,Sat Apr 14 01:11:20 +0000 2012,0.88,1334366000.0,1_4
15082,1,4d081fb700e6b1f7d4060cd7,4bf58dd8d48988d113941735,Korean Restaurant,40.764104,-73.986725,-240,Sat Apr 14 03:07:56 +0000 2012,0.96,1334373000.0,1_4
15249,1,40fb0f00f964a520d90a1fe3,4bf58dd8d48988d11b941735,Bar,40.760645,-73.986065,-240,Sat Apr 14 04:45:13 +0000 2012,0.03,1334379000.0,1_4
15947,1,428d2880f964a520b5231fe3,4bf58dd8d48988d1fa931735,Hotel,40.756731,-73.97407,-240,Sat Apr 14 17:45:23 +0000 2012,0.57,1334426000.0,1_4


In [23]:
# Copy the category id into a separate POI_catid_code column (a plain
# duplicate; no re-encoding is applied here).
result_df['POI_catid_code']=result_df['POI_catid']

In [13]:
result_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,poi_cat_name,latitude,longitude,timezone,UTC_time,timestamp,trajectory_id,POI_catid_code
2454,1,4abc1f51f964a520798620e3,4bf58dd8d48988d1ce941735,Seafood Restaurant,40.781558,-73.975792,-240,Wed Apr 04 23:31:31 +0000 2012,1333582000.0,1_1,4bf58dd8d48988d1ce941735
3660,1,4d4ac10da0ef54814b6ffff6,4bf58dd8d48988d157941735,American Restaurant,40.784018,-73.974524,-240,Sat Apr 07 17:42:24 +0000 2012,1333821000.0,1_2,4bf58dd8d48988d157941735
5603,1,4db44994cda1c57c82583709,4bf58dd8d48988d1f1931735,General Entertainment,40.739398,-73.99321,-240,Sun Apr 08 18:20:29 +0000 2012,1333909000.0,1_3,4bf58dd8d48988d1f1931735
5783,1,4a541923f964a52008b31fe3,4bf58dd8d48988d14e941735,American Restaurant,40.785677,-73.976498,-240,Sun Apr 08 20:02:10 +0000 2012,1333915000.0,1_3,4bf58dd8d48988d14e941735
6696,1,40f1d480f964a5205b0a1fe3,4bf58dd8d48988d143941735,Breakfast Spot,40.719929,-74.008532,-240,Mon Apr 09 16:20:52 +0000 2012,1333988000.0,1_3,4bf58dd8d48988d143941735
7666,1,3fd66200f964a52094e41ee3,4bf58dd8d48988d1cc941735,Steakhouse,40.734276,-73.993525,-240,Tue Apr 10 00:24:31 +0000 2012,1334017000.0,1_3,4bf58dd8d48988d1cc941735
8312,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Tue Apr 10 16:21:48 +0000 2012,1334075000.0,1_3,4bf58dd8d48988d1c1941735
11804,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Thu Apr 12 17:19:21 +0000 2012,1334251000.0,1_4,4bf58dd8d48988d1e0931735
13737,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Fri Apr 13 15:41:41 +0000 2012,1334332000.0,1_4,4bf58dd8d48988d1c1941735
14856,1,46ea2358f964a520cf4a1fe3,4bf58dd8d48988d11d941735,Bar,40.760667,-73.994948,-240,Sat Apr 14 01:11:20 +0000 2012,1334366000.0,1_4,4bf58dd8d48988d11d941735


In [24]:
# Rename the category-name column from "poi_cat_name" to "POI_catname".
result_df = result_df.rename(columns={"poi_cat_name": "POI_catname"})


In [6]:
result_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,POI_catname,latitude,longitude,timezone,UTC_time,timestamp,trajectory_id,POI_catid_code
2454,1,4abc1f51f964a520798620e3,4bf58dd8d48988d1ce941735,Seafood Restaurant,40.781558,-73.975792,-240,Wed Apr 04 23:31:31 +0000 2012,1333582000.0,1_1,4bf58dd8d48988d1ce941735
3660,1,4d4ac10da0ef54814b6ffff6,4bf58dd8d48988d157941735,American Restaurant,40.784018,-73.974524,-240,Sat Apr 07 17:42:24 +0000 2012,1333821000.0,1_2,4bf58dd8d48988d157941735
5603,1,4db44994cda1c57c82583709,4bf58dd8d48988d1f1931735,General Entertainment,40.739398,-73.99321,-240,Sun Apr 08 18:20:29 +0000 2012,1333909000.0,1_3,4bf58dd8d48988d1f1931735
5783,1,4a541923f964a52008b31fe3,4bf58dd8d48988d14e941735,American Restaurant,40.785677,-73.976498,-240,Sun Apr 08 20:02:10 +0000 2012,1333915000.0,1_3,4bf58dd8d48988d14e941735
6696,1,40f1d480f964a5205b0a1fe3,4bf58dd8d48988d143941735,Breakfast Spot,40.719929,-74.008532,-240,Mon Apr 09 16:20:52 +0000 2012,1333988000.0,1_3,4bf58dd8d48988d143941735
7666,1,3fd66200f964a52094e41ee3,4bf58dd8d48988d1cc941735,Steakhouse,40.734276,-73.993525,-240,Tue Apr 10 00:24:31 +0000 2012,1334017000.0,1_3,4bf58dd8d48988d1cc941735
7906,1,42586c80f964a520db201fe3,4bf58dd8d48988d121941735,Bar,40.775986,-73.979528,-240,Tue Apr 10 03:36:56 +0000 2012,1334029000.0,1_3,4bf58dd8d48988d121941735
8312,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Tue Apr 10 16:21:48 +0000 2012,1334075000.0,1_3,4bf58dd8d48988d1c1941735
11804,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Thu Apr 12 17:19:21 +0000 2012,1334251000.0,1_4,4bf58dd8d48988d1e0931735
13737,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Fri Apr 13 15:41:41 +0000 2012,1334332000.0,1_4,4bf58dd8d48988d1c1941735


In [25]:
# Distinct trajectory ids, wrapped in a Series so that they can be sampled.
traj_ids = pd.Series(result_df["trajectory_id"].unique())

# Draw 80% of the trajectories for training; the fixed seed makes the draw repeatable.
train_ids = traj_ids.sample(frac=0.8, random_state=0)

# Rows whose trajectory was drawn form the training set, all remaining rows the test set.
is_train_row = result_df["trajectory_id"].isin(train_ids)
train_df = result_df[is_train_row]
test_df = result_df[~is_train_row]


In [11]:
result_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,POI_catname,latitude,longitude,timezone,UTC_time,timestamp,trajectory_id,POI_catid_code
2454,1,4abc1f51f964a520798620e3,4bf58dd8d48988d1ce941735,Seafood Restaurant,40.781558,-73.975792,-240,Wed Apr 04 23:31:31 +0000 2012,1333582000.0,1_1,4bf58dd8d48988d1ce941735
3660,1,4d4ac10da0ef54814b6ffff6,4bf58dd8d48988d157941735,American Restaurant,40.784018,-73.974524,-240,Sat Apr 07 17:42:24 +0000 2012,1333821000.0,1_2,4bf58dd8d48988d157941735
5603,1,4db44994cda1c57c82583709,4bf58dd8d48988d1f1931735,General Entertainment,40.739398,-73.99321,-240,Sun Apr 08 18:20:29 +0000 2012,1333909000.0,1_3,4bf58dd8d48988d1f1931735
5783,1,4a541923f964a52008b31fe3,4bf58dd8d48988d14e941735,American Restaurant,40.785677,-73.976498,-240,Sun Apr 08 20:02:10 +0000 2012,1333915000.0,1_3,4bf58dd8d48988d14e941735
6696,1,40f1d480f964a5205b0a1fe3,4bf58dd8d48988d143941735,Breakfast Spot,40.719929,-74.008532,-240,Mon Apr 09 16:20:52 +0000 2012,1333988000.0,1_3,4bf58dd8d48988d143941735
7666,1,3fd66200f964a52094e41ee3,4bf58dd8d48988d1cc941735,Steakhouse,40.734276,-73.993525,-240,Tue Apr 10 00:24:31 +0000 2012,1334017000.0,1_3,4bf58dd8d48988d1cc941735
7906,1,42586c80f964a520db201fe3,4bf58dd8d48988d121941735,Bar,40.775986,-73.979528,-240,Tue Apr 10 03:36:56 +0000 2012,1334029000.0,1_3,4bf58dd8d48988d121941735
8312,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Tue Apr 10 16:21:48 +0000 2012,1334075000.0,1_3,4bf58dd8d48988d1c1941735
11804,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Thu Apr 12 17:19:21 +0000 2012,1334251000.0,1_4,4bf58dd8d48988d1e0931735
13737,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Fri Apr 13 15:41:41 +0000 2012,1334332000.0,1_4,4bf58dd8d48988d1c1941735


In [26]:
# Write the training and test splits as comma-separated files, without the index column.
train_df.to_csv('./dataset/dataset_tsmc2014/NYC_train.csv',sep=',',index=False)
test_df.to_csv('./dataset/dataset_tsmc2014/NYC_test.csv',sep=',',index=False)

In [12]:
train_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,POI_catname,latitude,longitude,timezone,UTC_time,timestamp,trajectory_id,POI_catid_code
2454,1,4abc1f51f964a520798620e3,4bf58dd8d48988d1ce941735,Seafood Restaurant,40.781558,-73.975792,-240,Wed Apr 04 23:31:31 +0000 2012,1333582000.0,1_1,4bf58dd8d48988d1ce941735
3660,1,4d4ac10da0ef54814b6ffff6,4bf58dd8d48988d157941735,American Restaurant,40.784018,-73.974524,-240,Sat Apr 07 17:42:24 +0000 2012,1333821000.0,1_2,4bf58dd8d48988d157941735
5603,1,4db44994cda1c57c82583709,4bf58dd8d48988d1f1931735,General Entertainment,40.739398,-73.99321,-240,Sun Apr 08 18:20:29 +0000 2012,1333909000.0,1_3,4bf58dd8d48988d1f1931735
5783,1,4a541923f964a52008b31fe3,4bf58dd8d48988d14e941735,American Restaurant,40.785677,-73.976498,-240,Sun Apr 08 20:02:10 +0000 2012,1333915000.0,1_3,4bf58dd8d48988d14e941735
6696,1,40f1d480f964a5205b0a1fe3,4bf58dd8d48988d143941735,Breakfast Spot,40.719929,-74.008532,-240,Mon Apr 09 16:20:52 +0000 2012,1333988000.0,1_3,4bf58dd8d48988d143941735
7666,1,3fd66200f964a52094e41ee3,4bf58dd8d48988d1cc941735,Steakhouse,40.734276,-73.993525,-240,Tue Apr 10 00:24:31 +0000 2012,1334017000.0,1_3,4bf58dd8d48988d1cc941735
7906,1,42586c80f964a520db201fe3,4bf58dd8d48988d121941735,Bar,40.775986,-73.979528,-240,Tue Apr 10 03:36:56 +0000 2012,1334029000.0,1_3,4bf58dd8d48988d121941735
8312,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Tue Apr 10 16:21:48 +0000 2012,1334075000.0,1_3,4bf58dd8d48988d1c1941735
11804,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Thu Apr 12 17:19:21 +0000 2012,1334251000.0,1_4,4bf58dd8d48988d1e0931735
13737,1,4f3283f0e4b057434d8fdc81,4bf58dd8d48988d1c1941735,Mexican Restaurant,40.717888,-74.005668,-240,Fri Apr 13 15:41:41 +0000 2012,1334332000.0,1_4,4bf58dd8d48988d1c1941735


In [28]:
test_df.head(10)

Unnamed: 0,user_id,POI_id,POI_catid,POI_catname,latitude,longitude,timezone,UTC_time,norm_in_day_time,timestamp,trajectory_id,POI_catid_code
76832,1,4fb62dcb4fc6cfd3cc2c1acc,4bf58dd8d48988d1f1931735,General Entertainment,40.659952,-73.968866,-240,Sat May 19 17:54:31 +0000 2012,0.58,1337450000.0,1_11,4bf58dd8d48988d1f1931735
77205,1,46f52f99f964a520ef4a1fe3,4bf58dd8d48988d14e941735,American Restaurant,40.665308,-73.989401,-240,Sat May 19 20:38:26 +0000 2012,0.69,1337460000.0,1_11,4bf58dd8d48988d14e941735
83327,1,3fd66200f964a52048e81ee3,4bf58dd8d48988d1e0931735,Coffee Shop,40.785889,-73.976859,-240,Sat May 26 01:12:48 +0000 2012,0.88,1337995000.0,1_12,4bf58dd8d48988d1e0931735
83836,1,3fd66200f964a520c5f11ee3,4bf58dd8d48988d14e941735,American Restaurant,40.719607,-73.986764,-240,Sat May 26 22:44:50 +0000 2012,0.78,1338072000.0,1_12,4bf58dd8d48988d14e941735
89016,1,4d27b39755a8b60c0c4bc6c0,4bf58dd8d48988d155941735,Gastropub,40.728007,-73.999143,-240,Tue May 29 21:33:52 +0000 2012,0.73,1338327000.0,1_14,4bf58dd8d48988d155941735
89276,1,4de3e4effa7651589f21983d,4bf58dd8d48988d11e941735,Bar,40.721488,-73.995029,-240,Tue May 29 23:15:10 +0000 2012,0.8,1338333000.0,1_14,4bf58dd8d48988d11e941735
118467,1,49d2b43ef964a520cb5b1fe3,4bf58dd8d48988d1e0931735,Coffee Shop,40.720087,-74.003961,-240,Mon Jul 02 18:16:36 +0000 2012,0.59,1341253000.0,1_20,4bf58dd8d48988d1e0931735
124184,1,3fd66200f964a520caea1ee3,4bf58dd8d48988d1d4941735,Bar,40.730084,-73.989256,-240,Sat Jul 07 00:16:43 +0000 2012,0.84,1341620000.0,1_22,4bf58dd8d48988d1d4941735
124340,1,3fd66200f964a520c7f11ee3,4bf58dd8d48988d1ca941735,Pizza Place,40.731682,-73.996181,-240,Sat Jul 07 02:12:49 +0000 2012,0.92,1341627000.0,1_22,4bf58dd8d48988d1ca941735
125126,1,4530db98f964a520623b1fe3,4bf58dd8d48988d190941735,History Museum,40.792624,-73.95219,-240,Sat Jul 07 19:48:24 +0000 2012,0.66,1341691000.0,1_22,4bf58dd8d48988d190941735


In [27]:
# Write the complete (unsplit) table as a comma-separated file, without the index column.
result_df.to_csv('./dataset/dataset_tsmc2014/NYC.csv',sep=',',index=False)

In [18]:
len(result_df['user_id'].unique())

1083

In [1]:
import random

In [22]:
def sample_from(A, B):
    """Return one random element of ``B`` that does not occur in ``A``.

    The eligible elements are collected in a set first, so each distinct
    value is equally likely regardless of how often it is repeated in ``B``.

    Args:
        A: Container of values to exclude; only ``in`` membership tests are
            performed on it.
        B: Iterable of hashable values to draw from.

    Returns:
        A randomly chosen eligible element, or ``None`` when every element of
        ``B`` also occurs in ``A`` (this includes an empty ``B``).
    """
    # Elements of B that are missing from A, de-duplicated.
    eligible = {item for item in B if item not in A}
    # Nothing left to draw from -> None; otherwise pick uniformly.
    return random.choice(list(eligible)) if eligible else None

# Smoke test: the printed value must be one of 6, 7, 8 or 9, the elements of B
# that are missing from A.
A=[1, 2, 3, 4, 5]
B=[3, 4, 5, 6, 7, 8, 9]
result = sample_from(A, B)
print(result)

7


In [24]:
a=[1,2,3,5]
a[int(len(a)/2):]

[3, 5]

In [31]:
import torch
import torch.nn.utils.rnn as rnn

# A zero-padded batch of shape [B, L, D] = [3, 5, 3]:
x = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0], [0, 0, 0]], # real length 3
    [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [0, 0, 0]], # real length 4
    [[22, 23, 24], [25, 26, 27], [28, 29, 30], [31, 32, 33], [34, 35, 36]] # real length 5
])

# Real (unpadded) length of each sequence in the batch.
lengths = [3, 4, 5]

# Pack the batch so that only the first lengths[i] steps of sequence i are kept.
# No padding value is passed: the lengths alone decide which steps are dropped.
# enforce_sorted=False allows the input to be unsorted; the packed data is
# ordered by decreasing length (see sorted_indices in the printed object).
y = rnn.pack_padded_sequence(x, lengths=lengths, batch_first=True,enforce_sorted=False)

# Show the packed object.
print(y)
# Five label vectors of dimension D, i.e. shape [5, 3].
label = torch.tensor([
    [1, -1,3],
    [-1, -1,3],
    [-1, -1,3],
    [-1, -1,3],
    [-1, -1,3]
])

# y.data is the 2-D tensor of all kept time steps, shape [sum(lengths), D] = [12, 3].
y_data = y.data
print(y_data)

# torch.inner contracts the last dimension of both operands, so the result is a
# [12, 5] matrix with one score per (time step, label row) pair, not a 1-D vector.
scores = torch.inner(y_data, label)

# Show the score matrix.
print(scores)

PackedSequence(data=tensor([[22, 23, 24],
        [10, 11, 12],
        [ 1,  2,  3],
        [25, 26, 27],
        [13, 14, 15],
        [ 4,  5,  6],
        [28, 29, 30],
        [16, 17, 18],
        [ 7,  8,  9],
        [31, 32, 33],
        [19, 20, 21],
        [34, 35, 36]]), batch_sizes=tensor([3, 3, 3, 2, 1]), sorted_indices=tensor([2, 1, 0]), unsorted_indices=tensor([2, 1, 0]))
tensor([[22, 23, 24],
        [10, 11, 12],
        [ 1,  2,  3],
        [25, 26, 27],
        [13, 14, 15],
        [ 4,  5,  6],
        [28, 29, 30],
        [16, 17, 18],
        [ 7,  8,  9],
        [31, 32, 33],
        [19, 20, 21],
        [34, 35, 36]])
tensor([[ 71,  27,  27,  27,  27],
        [ 35,  15,  15,  15,  15],
        [  8,   6,   6,   6,   6],
        [ 80,  30,  30,  30,  30],
        [ 44,  18,  18,  18,  18],
        [ 17,   9,   9,   9,   9],
        [ 89,  33,  33,  33,  33],
        [ 53,  21,  21,  21,  21],
        [ 26,  12,  12,  12,  12],
        [ 98,  36,  36,  36

In [37]:
# Two small [2, 2] integer tensors for the demo.
a = torch.tensor([
    [1, 2],
    [4, 5]
])
b = torch.tensor([
    [1,1],
    [2,2]
])
# 'ij,ij->i' multiplies element-wise and sums over j: a row-wise dot product.
c = torch.einsum('ij,ij->i', a, b)

# Show the row-wise dot products.
print(c)
# Element-wise difference a - b, displayed as the cell's output.
torch.sub(a,b)

tensor([ 3, 18])


tensor([[0, 1],
        [2, 3]])

In [43]:
import torch
import random
import torch.nn as nn
import torch.nn.functional as F

# Toy problem sizes. max_len is used below as the number of triplets generated.
user_count = 10
item_count = 20
hidden_dim = 5
max_len = 3

# Random user / item embedding tables with one extra row for id 0.
# NOTE(review): these two tensors are not used by any code visible in this
# notebook; the BPR module creates its own nn.Embedding tables.
user_emb_w = torch.rand((user_count + 1, hidden_dim))
item_emb_w = torch.rand((item_count + 1, hidden_dim))

# Randomly generate max_len triplets (u, i, j). No seed is set, so the batch
# differs on every run, and nothing prevents j from being equal to i.
batch = []
for _ in range(max_len):
    u = random.randint(1, user_count) # user id
    i = random.randint(1, item_count) # positive item id
    j = random.randint(1, item_count) # negative item id
    batch.append(torch.tensor([u, i, j]))

# Stack the triplets into one [max_len, 3] tensor. Every triplet already has
# length 3, so pad_sequence adds no padding here.
batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)

# Show the batch.
print(batch)


class BPR(nn.Module):
    """Bayesian Personalized Ranking loss over (user, positive, negative) triplets.

    Users and items each get an embedding table with one extra row, so ids may
    range from 0 to ``user_count`` / ``item_count`` inclusive.
    """

    def __init__(self, user_count, item_count, hidden_dim, max_len):
        """Create the embedding tables.

        Args:
            user_count: Largest user id.
            item_count: Largest item id.
            hidden_dim: Embedding dimension.
            max_len: Stored on the instance as ``self.max_len``; not used by
                ``forward``.
        """
        super().__init__()
        # The user table is created before the item table.
        self.user_emb = nn.Embedding(user_count + 1, hidden_dim)
        self.item_emb = nn.Embedding(item_count + 1, hidden_dim)
        self.max_len = max_len

    def forward(self, batch):
        """Return the mean BPR loss of a batch of triplets.

        Args:
            batch: Sequence of 1-D integer tensors ``(u, i, j)``. It is passed
                through ``pad_sequence`` with padding value 0 before use.

        Returns:
            Scalar tensor ``-mean(logsigmoid(<u, i - j>))``.
        """
        # Right-pad with 0 and stack into an [N, 3] tensor.
        triplets = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)
        users, positives, negatives = triplets[:, 0], triplets[:, 1], triplets[:, 2]
        # Score difference between the positive and the negative item per triplet.
        item_delta = self.item_emb(positives) - self.item_emb(negatives)
        margin = (self.user_emb(users) * item_delta).sum(dim=1)
        return -F.logsigmoid(margin).mean()
# Build a model with the toy sizes (10 users, 20 items, dimension 5, max_len 3)
# and evaluate the loss on the sample batch.
# NOTE(review): the recorded run of this cell ends in KeyboardInterrupt rather
# than a loss value; re-run to confirm that it completes.
model=BPR(10,20,5,3)
model(batch)

tensor([[ 4, 16,  5],
        [ 9,  9,  5],
        [10,  2, 14]])


KeyboardInterrupt: 

In [48]:
# Create a random tensor of shape [L*4, D].
L = 3
D = 2
x = torch.randn(L*4, D)
print(x)
# Turn it into shape [4, L, D] with torch.reshape ...
y = torch.reshape(x, (4, L, D))
# ... or, equivalently here, with Tensor.view.
y = x.view(4, L, D)

# Print y itself; its shape is [4, L, D].
print(y)

tensor([[-1.3588e-01,  1.3258e+00],
        [-4.2665e-01,  5.9311e-01],
        [-6.5640e-01,  4.9492e-01],
        [-8.3200e-04,  4.1285e-01],
        [ 1.7160e+00, -4.6392e-01],
        [ 9.2991e-01, -7.8517e-02],
        [-1.3035e-01,  6.2406e-01],
        [ 1.4991e-01, -1.6562e+00],
        [ 7.0649e-01, -5.0739e-01],
        [-2.3706e-01, -1.7558e-01],
        [-5.4446e-01,  2.8186e-01],
        [-1.5732e+00, -1.6817e-02]])
tensor([[[-1.3588e-01,  1.3258e+00],
         [-4.2665e-01,  5.9311e-01],
         [-6.5640e-01,  4.9492e-01]],

        [[-8.3200e-04,  4.1285e-01],
         [ 1.7160e+00, -4.6392e-01],
         [ 9.2991e-01, -7.8517e-02]],

        [[-1.3035e-01,  6.2406e-01],
         [ 1.4991e-01, -1.6562e+00],
         [ 7.0649e-01, -5.0739e-01]],

        [[-2.3706e-01, -1.7558e-01],
         [-5.4446e-01,  2.8186e-01],
         [-1.5732e+00, -1.6817e-02]]])


In [55]:
# Random stand-ins: one preference vector per position, and 4 negative-item
# embeddings per position.
flatten_preference = torch.randn(L, D)    # [L, D]
ne_embeddings = torch.randn(4, L, D)      # [4, L, D]

# Dot product of every negative embedding with the preference vector at the
# same position -> [4, L].
dot_ne = torch.einsum('ij,kij->ki', flatten_preference, ne_embeddings)

# Stand-in for the positive-item scores, one per position -> [L].
dot_pos = torch.randn(L)

# Positive minus negative score; dot_pos broadcasts over the 4 negatives -> [4, L].
sub = torch.sub(dot_pos, dot_ne)

# Mean negative log-likelihood. logsigmoid replaces log(sigmoid(x)), which
# returns -inf once sigmoid(x) underflows to 0 for very negative x.
nll = -torch.nn.functional.logsigmoid(sub).mean()
print(nll)

tensor(1.2389)


In [4]:
import pandas as pd

In [11]:
# Build, for every POI, the number of training check-ins per 2-hour slot of the day.
df=pd.read_csv('dataset/NYC/NYC_train.csv')
# Parse the 'local_time' column as datetimes and store them in 'timestamp'.
# NOTE(review): no cell visible in this notebook writes a 'local_time' column or
# this file path; confirm which preprocessing step produces this CSV.
df['timestamp'] = pd.to_datetime(df['local_time'])

# Hour of day (0-23) of each check-in.
df['hour'] = df['timestamp'].dt.hour
# Collapse the hour into 2-hour slots: integer division by 2 gives slot ids 0-11.
df['hour'] = df['hour'] // 2
# Count the rows for each (POI_id, slot) pair.
df = df.groupby(['POI_id', 'hour']).size().reset_index(name='count')

# Pivot to one row per POI and one column per slot; pairs that never occur get 0.
# Each (POI_id, slot) pair appears once, so the default aggregation leaves the count unchanged.
df = df.pivot_table(index='POI_id', columns='hour', values='count', fill_value=0)


# Convert to a nested dict {POI_id: {slot: count}}. Note that df is a dict, not
# a DataFrame, from here on.
df = df.to_dict(orient='index')

In [12]:
df

{'3fd66200f964a52000e71ee3': {0: 1,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 0,
  7: 0,
  8: 0,
  9: 3,
  10: 2,
  11: 1},
 '3fd66200f964a52001e81ee3': {0: 4,
  1: 1,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 0,
  7: 0,
  8: 0,
  9: 0,
  10: 0,
  11: 1},
 '3fd66200f964a52003e71ee3': {0: 0,
  1: 0,
  2: 0,
  3: 0,
  4: 1,
  5: 1,
  6: 1,
  7: 0,
  8: 2,
  9: 0,
  10: 0,
  11: 1},
 '3fd66200f964a52004e41ee3': {0: 3,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 0,
  7: 0,
  8: 0,
  9: 2,
  10: 3,
  11: 3},
 '3fd66200f964a52004e61ee3': {0: 0,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 3,
  7: 0,
  8: 0,
  9: 2,
  10: 2,
  11: 0},
 '3fd66200f964a52005e81ee3': {0: 0,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 0,
  7: 1,
  8: 0,
  9: 0,
  10: 0,
  11: 2},
 '3fd66200f964a52005eb1ee3': {0: 0,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 0,
  6: 1,
  7: 0,
  8: 0,
  9: 1,
  10: 4,
  11: 2},
 '3fd66200f964a52008e91ee3': {0: 0,
  1: 0,
  2: 0,
  3: 0,
  4: 0,
  5: 2,
  6: 0,
  7: 1,
  8: 1,
  9: 1,
  1

In [42]:
import random
a=[1,2,3,4,5]
w=[9,9,5,1,1]
random.choices(a,k=1,weights=w)

[3]

In [43]:
(-1)%12

11