In [1]:
from RCSYS_utils import *
from RCSYS_models import *
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Change here to process another benchmark.
# The graph is loaded from this path, and the prompt-augmented graph is later
# saved back to this same path.
benchmark_path = '../processed_data/benchmark_macro.pt'

In [30]:
# Load the benchmark graph plus the user, food and FNDDS tables that the
# prompt text is generated from.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load benchmark files that come from a trusted source.
# NOTE(review): newer PyTorch releases may default to `weights_only=True`,
# which could refuse to unpickle a full graph object - confirm against the
# installed version.
data = torch.load(benchmark_path)
df_user = pd.read_csv('../processed_data/user_tagging.csv')
df_food = pd.read_csv('../processed_data/food_tagging.csv')
df_fndds = pd.read_csv('../processed_data/fndds.csv')

In [7]:
# Lookup tables that translate the integer codes stored in the user table
# into the human-readable labels used in the prompt text.
gender_dict = {
    1: 'male',
    2: 'female',
}
race_dict = {
    0: 'Missing',
    1: 'Mexican American',
    2: 'Other Hispanic',
    3: 'White',
    4: 'Black',
    5: 'Other race',
}
education_dict = {
    0: 'Missing',
    1: 'Less than 9th grade',
    2: '9-11th grade',
    3: 'GED or equivalent',
    4: 'Some college or AA degree',
    5: 'College graduate or above',
    7: 'Refused',
    9: "Don't know",
}


def create_prompt(row):
    """Render one user record as a single descriptive sentence.

    Parameters
    ----------
    row : mapping-like
        Must support ``row[key]`` for 'SEQN', 'gender', 'age', 'race',
        'household_income' and 'education' (e.g. a DataFrame row).

    Returns
    -------
    str
        Text of the form "User Node <SEQN>: The user information is as
        follows: ...".

    Raises
    ------
    KeyError
        If the gender, race or education code has no entry in the
        corresponding lookup table.
    """
    # Resolve the coded fields first; an unknown code fails here with KeyError.
    gender_label = gender_dict[row['gender']]
    user_age = row['age']
    race_label = race_dict[row['race']]
    income_level = row['household_income']
    education_label = education_dict[row['education']]

    details = ', '.join([
        gender_label,
        f"age {user_age}",
        race_label,
        f"household income level (the higher the better): {income_level}",
        f"education status: {education_label}",
    ])
    return f"User Node {row['SEQN']}: The user information is as follows: {details}."

# User-table flag columns that are turned into health tags, in output order.
nutrition_columns = [
    'user_low_carb',
    'user_low_phosphorus',
    'user_low_calorie',
    'user_high_calorie',
    'user_high_potassium',
    'user_low_sodium',
    'user_low_cholesterol',
    'user_low_saturated_fat',
    'user_low_protein',
    'user_high_protein',
    'user_low_sugar',
    'user_high_fiber',
    'user_high_iron',
    'user_high_folate_acid',
    'user_high_vitamin_b12',
    'user_high_calcium',
    'user_high_vitamin_d',
    'user_high_vitamin_c',
]


def create_health_tag_prompt(row):
    """Return the user's active health tags as a comma-separated string.

    A column from ``nutrition_columns`` counts as active when ``row[col] == 1``.
    Each active column name is turned into a tag by dropping 'user_' and
    replacing underscores with spaces. Returns '' when no flag is set.
    """
    active_columns = (name for name in nutrition_columns if row[name] == 1)
    return ', '.join(
        name.replace('user_', '').replace('_', ' ') for name in active_columns
    )

# Add 'prompt_health': the comma-separated health tags for each user row.
df_user['prompt_health'] = df_user.apply(create_health_tag_prompt, axis=1)

# Add 'prompt': the descriptive sentence built by create_prompt for each user row.
df_user['prompt'] = df_user.apply(create_prompt, axis=1)

In [9]:
# Export the per-user prompt columns together with SEQN; the DataFrame index is not written.
df_user[['SEQN', 'prompt', 'prompt_health']].to_csv('../processed_data/user_prompt.csv', index=False)

In [13]:
# Food-table flag columns that are turned into nutrition tags, in output order.
food_nutrition_columns = [
    'low_calorie', 'high_calorie',
    'low_protein', 'high_protein',
    'low_carb', 'high_carb',
    'low_sugar', 'high_sugar',
    'low_fiber', 'high_fiber',
    'low_saturated_fat', 'high_saturated_fat',
    'low_cholesterol', 'high_cholesterol',
    'low_sodium', 'high_sodium',
    'low_calcium', 'high_calcium',
    'low_phosphorus', 'high_phosphorus',
    'low_potassium', 'high_potassium',
    'low_iron', 'high_iron',
    'low_folic_acid', 'high_folic_acid',
    'low_vitamin_c', 'high_vitamin_c',
    'low_vitamin_d', 'high_vitamin_d',
    'low_vitamin_b12', 'high_vitamin_b12',
]


def create_food_tag_prompt(row):
    """Return the food's active nutrition tags as a comma-separated string.

    A column from ``food_nutrition_columns`` counts as active when
    ``row[col] == 1``; its name becomes a tag with underscores replaced by
    spaces. Returns '' when no flag is set.
    """
    active_columns = (name for name in food_nutrition_columns if row[name] == 1)
    return ', '.join(name.replace('_', ' ') for name in active_columns)

# Add 'prompt_health': the comma-separated nutrition tags for each food row.
df_food['prompt_health'] = df_food.apply(create_food_tag_prompt, axis=1)

In [19]:
def create_food_prompt(group):
    """Build the descriptive prompt for one food from its FNDDS rows.

    Parameters
    ----------
    group : pandas.DataFrame
        All rows belonging to a single food. Needs the columns 'food_desc',
        'WWEIA_desc' and 'ingredient_desc'. The food id is read from the
        'food_id' column when present, otherwise from ``group.name`` (the
        group key that ``groupby(...).apply`` attaches to each group).

    Returns
    -------
    pandas.Series
        One entry, 'food_prompt_text', holding the prompt string.
    """
    # Fall back to the group key so the function keeps working when pandas
    # does not hand the grouping column to the applied function.
    if 'food_id' in group.columns:
        food_id = group['food_id'].iloc[0]
    else:
        food_id = group.name

    food_desc = group['food_desc'].iloc[0]
    food_category = group['WWEIA_desc'].iloc[0]

    # Skip missing ingredient descriptions: str.join raises TypeError on NaN.
    ingredient_names = group['ingredient_desc'].dropna().astype(str)
    ingredients = ', '.join(ingredient_names)

    prompt = (f"Food Node {food_id}: The food description is: {food_desc}. "
              f"This food belongs to the category: {food_category}. "
              f"The ingredients in this food are: {ingredients}.")
    return pd.Series({'food_prompt_text': prompt})

# Build one prompt per food: group the FNDDS rows by food_id, render each
# group with create_food_prompt, then reset_index() so food_id becomes a
# regular column again.
# NOTE(review): newer pandas releases warn about (and may change) whether the
# grouping column is passed to the applied function - confirm against the
# installed pandas version.
df_prompts = df_fndds.groupby('food_id').apply(create_food_prompt).reset_index()

In [24]:
# Attach the generated prompt text to the food table (left join keeps every
# food, leaving the text missing where no FNDDS prompt exists).
# Any 'food_prompt_text' column left over from a previous run of this cell is
# dropped first; otherwise a re-run would produce 'food_prompt_text_x' /
# 'food_prompt_text_y' and break the cells that read 'food_prompt_text'.
df_food = (
    df_food
    .drop(columns=['food_prompt_text'], errors='ignore')
    .merge(df_prompts, on='food_id', how='left')
)

In [25]:
# Export the per-food prompt columns together with food_id; the DataFrame index is not written.
df_food[['food_id', 'food_prompt_text', 'prompt_health']].to_csv('../processed_data/food_prompt.csv', index=False)

In [31]:
# Build id -> text lookups for users (keyed by SEQN).
# Missing text is normalised to "" so every prompt entry is a string.
user_prompt_dict = df_user.set_index('SEQN')['prompt'].fillna('').to_dict()
user_prompt_health_dict = df_user.set_index('SEQN')['prompt_health'].fillna('').to_dict()

# Same for foods (keyed by food_id). 'food_prompt_text' comes from a left
# merge, so foods without a generated prompt hold NaN there; without the
# fillna those NaN floats would be copied into the prompt list.
food_prompt_dict = df_food.set_index('food_id')['food_prompt_text'].fillna('').to_dict()
food_prompt_health_dict = df_food.set_index('food_id')['prompt_health'].fillna('').to_dict()

# Convert the node ids to plain Python values once instead of calling
# .item() for every node.
user_node_ids = data['user'].node_id.tolist()
food_node_ids = data['food'].node_id.tolist()

# One prompt per node, in node order; nodes without a table entry get "".
user_prompt = [user_prompt_dict.get(seqn, "") for seqn in user_node_ids]
user_prompt_health = [user_prompt_health_dict.get(seqn, "") for seqn in user_node_ids]
food_prompt = [food_prompt_dict.get(food_id, "") for food_id in food_node_ids]
food_prompt_health = [food_prompt_health_dict.get(food_id, "") for food_id in food_node_ids]

# The prompt lists must line up one-to-one with the graph's nodes.
assert len(user_prompt) == data['user'].num_nodes
assert len(food_prompt) == data['food'].num_nodes

# Add prompt features to HeteroData
data['user'].prompt = user_prompt
data['user'].prompt_health = user_prompt_health
data['food'].prompt = food_prompt
data['food'].prompt_health = food_prompt_health

In [32]:
# Display the graph summary to confirm the prompt attributes were attached.
data

HeteroData(
  user={
    x=[8170, 38],
    node_id=[8170],
    num_nodes=8170,
    tags=[8170, 14],
    prompt=[8170],
    prompt_health=[8170],
  },
  food={
    x=[6769, 66],
    node_id=[6769],
    num_nodes=6769,
    tags=[6769, 14],
    prompt=[6769],
    prompt_health=[6769],
  },
  (user, eats, food)={
    edge_index=[2, 314224],
    edge_label_index=[2, 122009],
  }
)

In [33]:
# Persist the prompt-augmented graph. This writes to `benchmark_path`, i.e. it
# overwrites the same file the graph was loaded from.
torch.save(data, benchmark_path)