In [52]:
import bptf
from bptf import BPTF
import numpy as np
import pandas as pd
import sparse
import os
import shutil
from tqdm import tqdm
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed
import scipy.stats as st
import matplotlib.pyplot as plt
import torch
import tensorly
import cupy
import multiprocessing
from joblib import Parallel, delayed
from tqdm.contrib.concurrent import process_map

# Show the working directory: the data folder below is resolved relative to it.
working_dir = os.getcwd()
working_dir

'c:\\Users\\luiyu\\OneDrive - The University of Chicago\\UChicago_MS_Stat\\aaron_schein\\bptf_new'

# Utility functions

In [53]:
def dataframe_to_sparse_tensor(data, country_indices, date_indices, cameo_col='CAMEO_Code', events_col='Num_Events', n_actions=20):
    """
    Converts a pandas DataFrame of dyadic event counts into a sparse COO tensor.

    Each row of ``data`` becomes one entry at
    (source country index, target country index, CAMEO code - 1, date index)
    holding that row's event count.

    Parameters:
        data (pd.DataFrame): Event rows. Must contain 'Source_Country_Code',
            'Target_Country_Code', 'formatteddate', ``cameo_col`` and ``events_col``.
        country_indices (pd.DataFrame): Lookup table; its 'country' column holds the
            country codes and its second column (position 1) the tensor index.
        date_indices (pd.DataFrame): Lookup table; its 'date' column holds the date
            labels and its second column (position 1) the tensor index.
        cameo_col (str): Name of the CAMEO code column in 'data'. Defaults to 'CAMEO_Code'.
        events_col (str): Name of the column containing event counts in 'data'. Defaults to 'Num_Events'.
        n_actions (int): Size of the action mode; CAMEO codes must be 1-based
            integers in [1, n_actions]. Defaults to 20.

    Returns:
        sparse.COO: Tensor of shape (V, V, n_actions, T), where V = len(country_indices)
        and T = len(date_indices), with float64 values.

    Raises:
        KeyError: If a country code or date in ``data`` is missing from its lookup table.
        ValueError: If a CAMEO code is missing or maps outside [0, n_actions).
    """
    V = len(country_indices)
    T = len(date_indices)
    shape = (V, V, n_actions, T)

    # Build each lookup once instead of scanning the tables for every row.
    country_lookup = _first_match_lookup(country_indices, 'country')
    date_lookup = _first_match_lookup(date_indices, 'date')

    source_index = _lookup_positions(data['Source_Country_Code'], country_lookup, 'Source country code')
    target_index = _lookup_positions(data['Target_Country_Code'], country_lookup, 'Target country code')
    date_index = _lookup_positions(data['formatteddate'], date_lookup, 'Date')

    # Adjust CAMEO code to a 0-based index; astype(int) truncates toward zero like int().
    raw_action = np.asarray(data[cameo_col].to_numpy() - 1, dtype=float)
    if not np.all(np.isfinite(raw_action)):
        raise ValueError(f"Column '{cameo_col}' contains missing or non-finite CAMEO codes.")
    action_index = raw_action.astype(int)
    if ((action_index < 0) | (action_index >= n_actions)).any():
        raise ValueError(f"CAMEO codes in '{cameo_col}' must lie in [1, {n_actions}].")

    # float64 values, matching the previous np.append-based accumulation.
    vals = data[events_col].to_numpy(dtype=float)
    subs = (source_index, target_index, action_index, date_index)

    return sparse.COO(coords=subs, data=vals, shape=shape)


def _first_match_lookup(table, key_col):
    """Map each value of ``table[key_col]`` to the value in column position 1.

    Only the first row per key is kept. This mirrors taking ``.iloc[0, 1]`` of the
    matching rows, and it keeps the index unique as ``Series.map`` requires.
    """
    first_rows = ~table[key_col].duplicated(keep='first').to_numpy()
    deduped = table.iloc[first_rows]
    return pd.Series(deduped.iloc[:, 1].to_numpy(), index=deduped[key_col].to_numpy())


def _lookup_positions(values, lookup, label):
    """Translate every entry of ``values`` through ``lookup`` into an int array.

    Raises KeyError listing (up to five of) the entries that have no match.
    """
    mapped = values.map(lookup)
    unknown = mapped.isna().to_numpy()
    if unknown.any():
        missing = pd.unique(values.to_numpy()[unknown])[:5]
        raise KeyError(f"{label} not found in lookup table: {list(missing)}")
    return mapped.to_numpy().astype(int)

def add_sparse_tensors(tensor1, tensor2):
    """Return the element-wise sum of two tensors that share a shape.

    Raises AssertionError (via an ``assert``) when the shapes differ.
    """
    same_shape = tensor1.shape == tensor2.shape
    assert same_shape, "Tensors must have the same shape to be added."
    return tensor1 + tensor2


def sum_sparse_tensor_list(tensor_list):
    """Add up a non-empty sequence of tensors.

    The first tensor is the starting accumulator, so no scalar 0 is ever added.
    A single-element list returns that element itself.

    Raises:
        ValueError: If ``tensor_list`` is empty.
    """
    if not tensor_list:
        raise ValueError("The list of tensors is empty.")

    first, remainder = tensor_list[0], tensor_list[1:]
    return sum(remainder, first)

# Get and process data

In [54]:
# Read every per-batch CSV in the data folder and stack them into one frame.
# os.path.join keeps the path portable (hard-coded "\\" separators only work on Windows).
# Sorting makes the read order deterministic, since os.listdir order is arbitrary.
data_filepath = os.path.join(os.getcwd(), "dyadic_data_2000_2020_inclusive")
files = sorted(os.listdir(data_filepath))
# Collect the frames and concatenate once. Calling pd.concat inside the loop would
# re-copy the growing frame on every iteration, which is quadratic in total rows.
frames = [
    pd.read_csv(os.path.join(data_filepath, filename))
    for filename in tqdm(files, desc='Reading data')
]
data = pd.concat(frames, ignore_index=True)
del frames

# Filter by date to get a smaller dataset to play around with: 2000-2012 inclusive
# (13 years), which is about what fits in laptop memory. Keep only ICEWS rows.
years = tuple(str(2000 + x) for x in range(0, 13))
data = data[
    data['formatteddate'].str.startswith(years, na=False)
    & (data['Database'] == 'ICEWS')
]

# Collapse by month. Truncate dates to 'YYYY-MM' (parsed below with format '%Y-%m'),
# then sum the event counts in each (source, target, CAMEO, month, database) cell.
# .assign avoids writing into a filtered slice (SettingWithCopyWarning).
# Summing only Num_Events keeps any stray non-key columns from being aggregated.
# No pre-sort is needed: groupby orders its output by the keys.
data = (
    data.assign(formatteddate=data['formatteddate'].str[:7])
    .groupby(
        ['Source_Country_Code', 'Target_Country_Code', 'CAMEO_Code', 'formatteddate', 'Database'],
        as_index=False,
    )['Num_Events']
    .sum()
)

# Give every country that appears as a source or target a tensor index,
# in order of first appearance.
countries = pd.unique(pd.concat([data['Source_Country_Code'], data['Target_Country_Code']]))
country_indices = pd.DataFrame({
    'country' : countries,
    'index' : range(len(countries))
})
del countries

# There's no need to get a list of actions since it's just 1 to 20.

# Index every month between the first and last observed month, including months with
# no events. Change the parse format and frequency if you change the time unit.
months = pd.to_datetime(data['formatteddate'], format='%Y-%m')
month_labels = pd.date_range(start=months.min(), end=months.max(), freq='MS').strftime('%Y-%m').to_list()
date_indices = pd.DataFrame({
    'date' : month_labels,
    'index' : range(len(month_labels))
})
del months, month_labels

Reading data: 100%|██████████| 85/85 [00:55<00:00,  1.52it/s]


# Convert dataframe to sparse tensor

In [55]:
# Split the month-collapsed frame into `num_of_batches` contiguous row batches, so the
# sparse-tensor conversion below can run in parallel.
# np.array_split on a DataFrame goes through pandas' deprecated DataFrame.swapaxes.
# That is most likely the source of the `return bound(*args, **kwds)` warning this
# cell printed. Splitting the positional row indices and slicing with .iloc gives the
# same batch sizes without the deprecated path.
num_of_batches = 10000
row_batches = np.array_split(np.arange(len(data)), num_of_batches)
data = [data.iloc[rows] for rows in row_batches]
del row_batches

  return bound(*args, **kwds)


In [56]:
# Sanity check: (rows, columns) of the first row batch.
first_batch = data[0]
first_batch.shape

(150, 6)

In [57]:
# Y = []
# progressbar = tqdm(range(len(data)))
# for batch in progressbar:
#     progressbar.set_description(f'Now convert batch {batch} out of total {len(data)} batches.')
#     Y.append(dataframe_to_sparse_tensor(data[batch], country_indices, date_indices))
# Y = sum(Y)
# del data

In [58]:
# Convert every DataFrame batch into a sparse tensor in parallel, one job per CPU core.
# `data` is rebound from the list of frames to the list of per-batch tensors; the next
# cell sums them into a single tensor.
# NOTE: tqdm wraps the *input* iterable, so the bar advances as joblib pulls batches
# (dispatch), not as conversions finish.
num_cores = multiprocessing.cpu_count()

if __name__ == '__main__':
    conversion_tasks = (
        delayed(dataframe_to_sparse_tensor)(frame_batch, country_indices, date_indices)
        for frame_batch in tqdm(data, desc='Converting to sparse tensors')
    )
    data = Parallel(n_jobs=num_cores)(conversion_tasks)

Converting to sparse tensors: 100%|██████████| 10000/10000 [02:13<00:00, 74.64it/s]


In [59]:
# Sum the per-batch tensors with a tree reduction. Adding groups of `n` at a time keeps
# operands of similar size, instead of repeatedly merging one small batch into an
# ever-growing accumulator.
# sum_sparse_tensor_list starts from the first tensor (no scalar 0 is added) and raises
# on an empty list, rather than silently producing the integer 0.
n = 10
assert n >= 2, "n must be at least 2, otherwise the reduction never shrinks the list."
while len(data) > n:
    data = [sum_sparse_tensor_list(data[i:i + n]) for i in range(0, len(data), n)]
data = sum_sparse_tensor_list(data)

In [61]:
# Fit BPTF with 100 latent components for up to 100 iterations.
n_components = 100
bptf_ICEWS = BPTF(data_shape=data.shape, n_components=n_components)
bptf_ICEWS.fit(data, max_iter=100, verbose=True)

# Each mode's G_DK_M factor matrix should be (mode size, n_components).
for mode, mode_size in enumerate(data.shape):
    assert bptf_ICEWS.G_DK_M[mode].shape == (mode_size, n_components)

ITERATION 0:	Time: 0.000000	Objective: -19949045278.59	Change: nan	


  1%|          | 1/100 [00:23<38:57, 23.61s/it]

ITERATION 1:	Time: 23.611221	Objective: -10026732.82	Change: 9.99497e-01	


  2%|▏         | 2/100 [00:46<38:05, 23.33s/it]

ITERATION 2:	Time: 23.124402	Objective: -9625551.53	Change: 4.00112e-02	


  3%|▎         | 3/100 [01:11<38:43, 23.95s/it]

ITERATION 3:	Time: 24.694375	Objective: -7682125.80	Change: 2.01903e-01	


  4%|▍         | 4/100 [01:35<38:36, 24.13s/it]

ITERATION 4:	Time: 24.411948	Objective: -3710340.70	Change: 5.17016e-01	


  5%|▌         | 5/100 [01:58<37:15, 23.54s/it]

ITERATION 5:	Time: 22.476031	Objective: -79640.94	Change: 9.78535e-01	


  6%|▌         | 6/100 [02:20<36:15, 23.14s/it]

ITERATION 6:	Time: 22.381518	Objective: 2025645.55	Change: 2.64347e+01	


  7%|▋         | 7/100 [02:43<35:56, 23.19s/it]

ITERATION 7:	Time: 23.280911	Objective: 3088040.00	Change: 5.24472e-01	


  8%|▊         | 8/100 [03:06<35:14, 22.99s/it]

ITERATION 8:	Time: 22.556463	Objective: 3651142.77	Change: 1.82350e-01	


  9%|▉         | 9/100 [03:28<34:33, 22.78s/it]

ITERATION 9:	Time: 22.329627	Objective: 3980061.69	Change: 9.00866e-02	


 10%|█         | 10/100 [03:51<34:17, 22.86s/it]

ITERATION 10:	Time: 23.018692	Objective: 4198885.00	Change: 5.49799e-02	


 11%|█         | 11/100 [04:13<33:32, 22.61s/it]

ITERATION 11:	Time: 22.060891	Objective: 4354922.37	Change: 3.71616e-02	


 12%|█▏        | 12/100 [04:35<32:52, 22.41s/it]

ITERATION 12:	Time: 21.958150	Objective: 4462658.04	Change: 2.47388e-02	


 13%|█▎        | 13/100 [04:58<32:21, 22.32s/it]

ITERATION 13:	Time: 22.105462	Objective: 4543219.00	Change: 1.80522e-02	


 14%|█▍        | 14/100 [05:20<31:51, 22.23s/it]

ITERATION 14:	Time: 22.017010	Objective: 4612636.16	Change: 1.52793e-02	


 15%|█▌        | 15/100 [05:42<31:43, 22.39s/it]

ITERATION 15:	Time: 22.761750	Objective: 4677690.29	Change: 1.41035e-02	


 16%|█▌        | 16/100 [06:05<31:27, 22.47s/it]

ITERATION 16:	Time: 22.647037	Objective: 4733758.04	Change: 1.19862e-02	


 17%|█▋        | 17/100 [06:26<30:41, 22.19s/it]

ITERATION 17:	Time: 21.545974	Objective: 4772430.45	Change: 8.16949e-03	


 18%|█▊        | 18/100 [06:48<30:10, 22.07s/it]

ITERATION 18:	Time: 21.799945	Objective: 4802663.34	Change: 6.33491e-03	


 19%|█▉        | 19/100 [07:11<29:58, 22.21s/it]

ITERATION 19:	Time: 22.525180	Objective: 4829161.09	Change: 5.51730e-03	


 20%|██        | 20/100 [07:33<29:26, 22.08s/it]

ITERATION 20:	Time: 21.783268	Objective: 4854702.26	Change: 5.28895e-03	


 21%|██        | 21/100 [07:54<28:57, 22.00s/it]

ITERATION 21:	Time: 21.799693	Objective: 4878320.29	Change: 4.86498e-03	


 22%|██▏       | 22/100 [08:16<28:29, 21.92s/it]

ITERATION 22:	Time: 21.746418	Objective: 4896880.83	Change: 3.80470e-03	


 23%|██▎       | 23/100 [08:38<28:17, 22.05s/it]

ITERATION 23:	Time: 22.331833	Objective: 4914010.64	Change: 3.49811e-03	


 24%|██▍       | 24/100 [09:01<28:12, 22.27s/it]

ITERATION 24:	Time: 22.777851	Objective: 4929872.83	Change: 3.22795e-03	


 25%|██▌       | 25/100 [09:25<28:22, 22.70s/it]

ITERATION 25:	Time: 23.708015	Objective: 4942554.64	Change: 2.57244e-03	


 26%|██▌       | 26/100 [09:49<28:31, 23.13s/it]

ITERATION 26:	Time: 24.138059	Objective: 4953434.91	Change: 2.20134e-03	


 27%|██▋       | 27/100 [10:13<28:23, 23.34s/it]

ITERATION 27:	Time: 23.828304	Objective: 4963214.19	Change: 1.97424e-03	


 28%|██▊       | 28/100 [10:36<27:49, 23.18s/it]

ITERATION 28:	Time: 22.819065	Objective: 4972179.45	Change: 1.80634e-03	


 29%|██▉       | 29/100 [10:59<27:27, 23.20s/it]

ITERATION 29:	Time: 23.248264	Objective: 4980552.27	Change: 1.68393e-03	


 30%|███       | 30/100 [11:23<27:14, 23.36s/it]

ITERATION 30:	Time: 23.710367	Objective: 4988442.90	Change: 1.58429e-03	


 31%|███       | 31/100 [11:47<27:01, 23.49s/it]

ITERATION 31:	Time: 23.810648	Objective: 4995840.52	Change: 1.48295e-03	


 32%|███▏      | 32/100 [12:10<26:30, 23.39s/it]

ITERATION 32:	Time: 23.134916	Objective: 5002765.86	Change: 1.38622e-03	


 33%|███▎      | 33/100 [12:34<26:17, 23.54s/it]

ITERATION 33:	Time: 23.896374	Objective: 5009307.27	Change: 1.30756e-03	


 34%|███▍      | 34/100 [13:00<26:45, 24.33s/it]

ITERATION 34:	Time: 26.163125	Objective: 5015486.80	Change: 1.23361e-03	


 35%|███▌      | 35/100 [13:24<26:25, 24.39s/it]

ITERATION 35:	Time: 24.528555	Objective: 5021345.38	Change: 1.16810e-03	


 36%|███▌      | 36/100 [13:47<25:24, 23.81s/it]

ITERATION 36:	Time: 22.470080	Objective: 5027099.59	Change: 1.14595e-03	


 37%|███▋      | 37/100 [14:09<24:34, 23.40s/it]

ITERATION 37:	Time: 22.437329	Objective: 5032904.28	Change: 1.15468e-03	


 38%|███▊      | 38/100 [14:32<23:52, 23.11s/it]

ITERATION 38:	Time: 22.435177	Objective: 5038141.00	Change: 1.04050e-03	


 39%|███▉      | 39/100 [14:54<23:12, 22.83s/it]

ITERATION 39:	Time: 22.179299	Objective: 5042456.08	Change: 8.56481e-04	


 40%|████      | 40/100 [15:16<22:42, 22.71s/it]

ITERATION 40:	Time: 22.416006	Objective: 5046318.51	Change: 7.65982e-04	


 41%|████      | 41/100 [15:42<23:21, 23.75s/it]

ITERATION 41:	Time: 26.174906	Objective: 5050069.86	Change: 7.43385e-04	


 42%|████▏     | 42/100 [16:06<22:54, 23.70s/it]

ITERATION 42:	Time: 23.585922	Objective: 5053744.30	Change: 7.27602e-04	


 43%|████▎     | 43/100 [16:30<22:40, 23.87s/it]

ITERATION 43:	Time: 24.264279	Objective: 5057176.38	Change: 6.79114e-04	


 44%|████▍     | 44/100 [16:55<22:39, 24.27s/it]

ITERATION 44:	Time: 25.215348	Objective: 5060573.63	Change: 6.71768e-04	


 45%|████▌     | 45/100 [17:20<22:17, 24.31s/it]

ITERATION 45:	Time: 24.408253	Objective: 5063938.18	Change: 6.64856e-04	


 46%|████▌     | 46/100 [17:45<22:01, 24.48s/it]

ITERATION 46:	Time: 24.855718	Objective: 5066629.22	Change: 5.31412e-04	


 47%|████▋     | 47/100 [18:08<21:14, 24.06s/it]

ITERATION 47:	Time: 23.070311	Objective: 5068962.30	Change: 4.60480e-04	


 48%|████▊     | 48/100 [18:32<21:00, 24.23s/it]

ITERATION 48:	Time: 24.643289	Objective: 5071271.04	Change: 4.55466e-04	


 49%|████▉     | 49/100 [18:59<21:07, 24.85s/it]

ITERATION 49:	Time: 26.300213	Objective: 5073631.75	Change: 4.65506e-04	


 50%|█████     | 50/100 [19:22<20:23, 24.47s/it]

ITERATION 50:	Time: 23.577871	Objective: 5076061.49	Change: 4.78895e-04	


 51%|█████     | 51/100 [19:46<19:40, 24.09s/it]

ITERATION 51:	Time: 23.187829	Objective: 5078470.60	Change: 4.74602e-04	


 52%|█████▏    | 52/100 [20:09<19:09, 23.95s/it]

ITERATION 52:	Time: 23.623335	Objective: 5080707.97	Change: 4.40561e-04	


 53%|█████▎    | 53/100 [20:32<18:25, 23.53s/it]

ITERATION 53:	Time: 22.541441	Objective: 5082726.16	Change: 3.97226e-04	


 54%|█████▍    | 54/100 [20:55<18:00, 23.48s/it]

ITERATION 54:	Time: 23.364901	Objective: 5084552.56	Change: 3.59334e-04	


 55%|█████▌    | 55/100 [21:18<17:34, 23.44s/it]

ITERATION 55:	Time: 23.362467	Objective: 5086374.42	Change: 3.58314e-04	


 56%|█████▌    | 56/100 [21:43<17:27, 23.81s/it]

ITERATION 56:	Time: 24.667356	Objective: 5088211.58	Change: 3.61193e-04	


 57%|█████▋    | 57/100 [22:07<17:09, 23.93s/it]

ITERATION 57:	Time: 24.205241	Objective: 5089749.70	Change: 3.02289e-04	


 58%|█████▊    | 58/100 [22:31<16:42, 23.87s/it]

ITERATION 58:	Time: 23.721080	Objective: 5091218.03	Change: 2.88489e-04	


 59%|█████▉    | 59/100 [22:54<16:08, 23.62s/it]

ITERATION 59:	Time: 23.052278	Objective: 5092601.91	Change: 2.71817e-04	


 60%|██████    | 60/100 [23:17<15:39, 23.48s/it]

ITERATION 60:	Time: 23.149564	Objective: 5093881.10	Change: 2.51185e-04	


 61%|██████    | 61/100 [23:43<15:40, 24.11s/it]

ITERATION 61:	Time: 25.576716	Objective: 5095134.58	Change: 2.46076e-04	


 62%|██████▏   | 62/100 [24:06<15:08, 23.91s/it]

ITERATION 62:	Time: 23.430210	Objective: 5096546.90	Change: 2.77189e-04	


 63%|██████▎   | 63/100 [24:29<14:35, 23.67s/it]

ITERATION 63:	Time: 23.126589	Objective: 5097746.96	Change: 2.35466e-04	


 64%|██████▍   | 64/100 [24:52<13:59, 23.33s/it]

ITERATION 64:	Time: 22.512650	Objective: 5098863.06	Change: 2.18939e-04	


 65%|██████▌   | 65/100 [25:14<13:28, 23.10s/it]

ITERATION 65:	Time: 22.561863	Objective: 5100002.06	Change: 2.23384e-04	


 66%|██████▌   | 66/100 [25:38<13:07, 23.16s/it]

ITERATION 66:	Time: 23.306450	Objective: 5101300.97	Change: 2.54688e-04	


 67%|██████▋   | 67/100 [26:01<12:42, 23.12s/it]

ITERATION 67:	Time: 23.015520	Objective: 5102978.73	Change: 3.28888e-04	


 68%|██████▊   | 68/100 [26:23<12:13, 22.93s/it]

ITERATION 68:	Time: 22.491771	Objective: 5104904.86	Change: 3.77453e-04	


 69%|██████▉   | 69/100 [26:46<11:48, 22.86s/it]

ITERATION 69:	Time: 22.684984	Objective: 5106418.95	Change: 2.96594e-04	


 70%|███████   | 70/100 [27:08<11:22, 22.75s/it]

ITERATION 70:	Time: 22.509748	Objective: 5107571.68	Change: 2.25741e-04	


 71%|███████   | 71/100 [27:31<10:59, 22.73s/it]

ITERATION 71:	Time: 22.674995	Objective: 5108719.34	Change: 2.24698e-04	


 72%|███████▏  | 72/100 [27:54<10:39, 22.84s/it]

ITERATION 72:	Time: 23.081287	Objective: 5109847.16	Change: 2.20765e-04	


 73%|███████▎  | 73/100 [28:17<10:15, 22.80s/it]

ITERATION 73:	Time: 22.713844	Objective: 5110818.65	Change: 1.90120e-04	


 74%|███████▍  | 74/100 [28:40<09:54, 22.88s/it]

ITERATION 74:	Time: 23.049861	Objective: 5111706.85	Change: 1.73787e-04	


 75%|███████▌  | 75/100 [29:03<09:32, 22.89s/it]

ITERATION 75:	Time: 22.912882	Objective: 5112500.24	Change: 1.55211e-04	


 76%|███████▌  | 76/100 [29:26<09:12, 23.04s/it]

ITERATION 76:	Time: 23.381680	Objective: 5113203.22	Change: 1.37503e-04	


 77%|███████▋  | 77/100 [29:49<08:46, 22.90s/it]

ITERATION 77:	Time: 22.567465	Objective: 5113826.29	Change: 1.21853e-04	


 78%|███████▊  | 78/100 [30:12<08:28, 23.11s/it]

ITERATION 78:	Time: 23.616904	Objective: 5114406.15	Change: 1.13392e-04	


 79%|███████▉  | 79/100 [30:35<08:02, 22.99s/it]

ITERATION 79:	Time: 22.717536	Objective: 5114967.81	Change: 1.09820e-04	


 80%|████████  | 80/100 [30:58<07:38, 22.93s/it]

ITERATION 80:	Time: 22.780162	Objective: 5115501.53	Change: 1.04343e-04	


 81%|████████  | 81/100 [31:20<07:13, 22.79s/it]

ITERATION 81:	Time: 22.478335	Objective: 5116049.16	Change: 1.07053e-04	


 82%|████████▏ | 82/100 [31:43<06:50, 22.79s/it]

ITERATION 82:	Time: 22.762568	Objective: 5116652.63	Change: 1.17957e-04	


 83%|████████▎ | 83/100 [32:06<06:26, 22.76s/it]

ITERATION 83:	Time: 22.691499	Objective: 5117380.32	Change: 1.42220e-04	


 84%|████████▍ | 84/100 [32:29<06:03, 22.72s/it]

ITERATION 84:	Time: 22.621312	Objective: 5118368.70	Change: 1.93142e-04	


 85%|████████▌ | 85/100 [32:51<05:39, 22.61s/it]

ITERATION 85:	Time: 22.349968	Objective: 5119947.71	Change: 3.08499e-04	


 86%|████████▌ | 86/100 [33:13<05:16, 22.60s/it]

ITERATION 86:	Time: 22.582038	Objective: 5122840.17	Change: 5.64939e-04	


 87%|████████▋ | 87/100 [33:36<04:54, 22.68s/it]

ITERATION 87:	Time: 22.853764	Objective: 5128209.31	Change: 1.04808e-03	


 88%|████████▊ | 88/100 [34:01<04:40, 23.37s/it]

ITERATION 88:	Time: 24.994695	Objective: 5135271.31	Change: 1.37709e-03	


 89%|████████▉ | 89/100 [34:24<04:15, 23.27s/it]

ITERATION 89:	Time: 23.015908	Objective: 5139548.73	Change: 8.32948e-04	


 90%|█████████ | 90/100 [34:47<03:50, 23.09s/it]

ITERATION 90:	Time: 22.689899	Objective: 5141672.01	Change: 4.13126e-04	


 91%|█████████ | 91/100 [35:09<03:25, 22.84s/it]

ITERATION 91:	Time: 22.253678	Objective: 5143126.93	Change: 2.82965e-04	


 92%|█████████▏| 92/100 [35:32<03:02, 22.76s/it]

ITERATION 92:	Time: 22.574051	Objective: 5144233.05	Change: 2.15068e-04	


 93%|█████████▎| 93/100 [35:54<02:37, 22.55s/it]

ITERATION 93:	Time: 22.043955	Objective: 5145071.40	Change: 1.62970e-04	


 94%|█████████▍| 94/100 [36:16<02:15, 22.54s/it]

ITERATION 94:	Time: 22.516852	Objective: 5145728.63	Change: 1.27739e-04	


 95%|█████████▌| 95/100 [36:39<01:52, 22.55s/it]

ITERATION 95:	Time: 22.577482	Objective: 5146261.29	Change: 1.03515e-04	


 95%|█████████▌| 95/100 [37:01<01:56, 23.39s/it]

ITERATION 96:	Time: 22.362030	Objective: 5146714.30	Change: 8.80275e-05	





In [62]:
# Baseline: non-negative CP decomposition (tensorly) with the same rank as BPTF,
# started from a random non-negative CP initialisation.
# NOTE(review): initialize_cp comes from the private module tensorly.decomposition._cp
# and may move between tensorly releases. No random_state is passed, so this baseline
# changes run to run. TODO: seed it.
random_cp = tensorly.decomposition._cp.initialize_cp(
    data,
    rank=n_components,
    init='random',
    non_negative=True,
)
cp_init = tensorly.cp_tensor.CPTensor(random_cp)
tensor_mu, _ = tensorly.decomposition.non_negative_parafac(
    data,
    init=cp_init,
    rank=n_components,
    return_errors=True,
)

In [67]:
# Sanity check: the BPTF reconstruction has the same shape as the observed tensor.
# Compare against data.shape directly. data.todense() would materialise the full dense
# tensor just to read its shape.
assert data.shape == bptf_ICEWS.reconstruct(drop_diag=True, style='geometric').shape

In [78]:
# Relative Frobenius reconstruction error, ||Y - Y_hat||_F / ||Y||_F, for both models.
# Densify the observed tensor once and reuse it and its norm; previously it was
# densified four times, each a full dense allocation.
# NOTE(review): drop_diag=True presumably removes the source == target diagonal from the
# BPTF reconstruction, while the NNCP reconstruction keeps it. If `data` contains
# self-dyad events, the two errors are not directly comparable. Confirm before drawing
# conclusions.
dense_data = torch.from_numpy(data.todense())
data_norm = torch.norm(dense_data, 'fro')

bptf_recon = torch.from_numpy(
    bptf_ICEWS.reconstruct(drop_diag=True, style='post_pred', n_samples=100)
)
bptf_error = torch.norm(dense_data - bptf_recon, 'fro') / data_norm

nncp_recon = torch.from_numpy(tensorly.cp_to_tensor(tensor_mu))
nonneg_cp = torch.norm(dense_data - nncp_recon, 'fro') / data_norm

# Release the large dense intermediates; only the scalar errors are used below.
del dense_data, bptf_recon, nncp_recon

sampling: 100%|██████████| 100/100 [02:27<00:00,  1.47s/it]


In [79]:
# Report both relative reconstruction errors side by side.
error_summary = f'BPTF error = {bptf_error}, Non-negative CP error = {nonneg_cp}'
print(error_summary)

BPTF error = 0.9594820676916082, Non-negative CP error = 0.3618835681327173
