In [9]:
# !pip install flax

In [10]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2" # use GPU 0 and 1

In [11]:
from transformers import FlaxAutoModelForSeq2SeqLM
from transformers import AutoTokenizer

MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)


In [32]:
import pandas as pd
import numpy as np
from collections import defaultdict
# from transformers import T5Tokenizer
# from transformers import T5Config, T5ForConditionalGeneration
# import torch
# from torch.utils.data import Dataset, DataLoader
# from transformers import AdamW, get_scheduler
# from tqdm.auto import tqdm
# import glob, json

In [34]:
import json
from collections import defaultdict

mapping = defaultdict(dict)

with open("/data/prateek/github/see-food/images_10k/layer1_small_updated.json", "r") as f:
    data1 = json.load(f)
    
with open("/data/prateek/github/see-food/images_10k/det_ingrs_small_updated.json", "r") as f:
    data2 = json.load(f)
    
for element in data1:
    id = element["id"]
    recipe = element["instructions"]
    title = element["title"]
    recipe_final = []
    for x in recipe:
        recipe_final.append(x["text"])
    mapping[id]["recipe"] = "Recipe: " + " ".join(recipe_final)
    mapping[id]["title"] = "Title: "+ title


for element in data2:
    id = element["id"]
    ingrd = element["ingredients"]
    ingrd_final = []
    for x in ingrd:
        ingrd_final.append(x["text"])
    mapping[id]["ingredients"] = "Ingredients: " + ", ".join(ingrd_final)
    
    

import pandas as pd
recipe_df = pd.DataFrame.from_dict(mapping, orient='index', columns=['recipe', 'ingredients', 'title']).reset_index()
recipe_df = recipe_df.rename(columns={'index': 'id'})
recipe_df = recipe_df.drop('id', axis=1)

recipe_df['recipe'] = recipe_df['title'] + '\n' + recipe_df['recipe']

# assume df is your dataframe
n = len(recipe_df) # get the number of rows
split_idx = int(n * 0.8) # define the split index
recipe_df_train = recipe_df[:split_idx] # get the first part
recipe_df_test = recipe_df[split_idx:] # get the second part

print(len(recipe_df), len(recipe_df_train), len(recipe_df_test))

3582 2865 717


In [14]:
prefix = "items: "
# generation_kwargs = {
#     "max_length": 512,
#     "min_length": 64,
#     "no_repeat_ngram_size": 3,
#     "early_stopping": True,
#     "num_beams": 5,
#     "length_penalty": 1.5,
# }
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95
}


special_tokens = tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
def skip_special_tokens(text, special_tokens):
    for token in special_tokens:
        text = text.replace(token, "")

    return text

def target_postprocessing(texts, special_tokens):
    if not isinstance(texts, list):
        texts = [texts]
    
    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)

        for k, v in tokens_map.items():
            text = text.replace(k, v)

        new_texts.append(text)

    return new_texts

def generation_function(texts):
    _inputs = texts if isinstance(texts, list) else [texts]
    inputs = [prefix + inp for inp in _inputs]
    inputs = tokenizer(
        inputs, 
        max_length=256, 
        padding="max_length", 
        truncation=True, 
        return_tensors="jax"
    )

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    output_ids = model.generate(
        input_ids=input_ids, 
        attention_mask=attention_mask,
        **generation_kwargs
    )
    generated = output_ids.sequences
    generated_recipe = target_postprocessing(
        tokenizer.batch_decode(generated, skip_special_tokens=False),
        special_tokens
    )
    return generated_recipe

In [15]:
items = [
    "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn",
    "provolone cheese, bacon, bread, ginger",
    "frozen chopped broccoli, cooked rice, shredded cheddar cheese, shredded cheddar cheese, eggs, butter, milk, onion, garlic powder, basil, oregano, salt and pepper"
]

In [49]:
ingredients_path = "/data/prateek/github/see-food/TEST_DATASET/PRED-vit100k/ingredients/"

from glob import glob
all_files = glob(ingredients_path + "*txt")


recipe_id_list = []
ingredients_list = []
for fl in all_files:
    with open(fl, 'r') as f:
        data_inst = f.readlines()
    recipe_id = fl.split("/")[-1].split(".")[0]
    data_inst = [i[:-1] for i in data_inst]
    data_inst = ", ".join(data_inst)
    recipe_id_list.append(recipe_id)
    ingredients_list.append(data_inst)

In [50]:
len(recipe_id_list), len(ingredients_list)

(517, 517)

In [51]:
generated = generation_function(ingredients_list)

In [52]:
from tqdm import tqdm

recipe_details = defaultdict(dict)
for e, text in enumerate(generated):
    sections = text.split("\n")
    title = sections[0].strip().replace("title:", "")
    ingredients = sections[1].strip().replace("ingredients:", "").replace("--", "")
    recipe = sections[2].strip().replace("directions:", "").replace("--", "")
    
    k = recipe_id_list[e]
    recipe_details[k]["title"] = title
    recipe_details[k]["ingredients"] = ingredients
    recipe_details[k]["recipe"] = recipe

In [53]:
# import json

# # assume my_dict is your dictionary
# with open('my_dict.json', 'w+') as f:
#     json.dump(recipe_details, f)


In [68]:
for k, v in recipe_details.items():
    file_name = k + ".txt"
    title = v["title"].strip()
    recipe = v["recipe"].strip()
    
    with open("/data/prateek/github/see-food/TEST_DATASET/PRED-vit100k/instructions/" + file_name, "w") as f:
        valss = recipe.split(".")
        for val in valss:
            # print(val.strip())
            f.write(val.strip() +"\n")
    # valss = recipe.split(".")
    # print(valss)
    # break
    
    with open("/data/prateek/github/see-food/TEST_DATASET/PRED-vit100k/instructions/" + file_name, 'r') as f:
        lines = f.readlines()

    if lines[-1] == '\n':
        lines = lines[:-1]

    with open("/data/prateek/github/see-food/TEST_DATASET/PRED-vit100k/instructions/" + file_name, 'w') as f:
        f.writelines(lines)

517

In [54]:
recipe_details

defaultdict(dict,
            {'02115914de': {'title': ' fish salad',
              'ingredients': ' 1 lb. firm fish fillets 1/2 tsp. pepper 1/2 c. mayonnaise 1/2 c. oil 1/2 tsp. salt 2 tbsp. vinegar 2 tbsp. juice 1 small onion, chopped',
              'recipe': ' cut fish into serving size pieces. combine ingredients and pour over fish. toss until well coated. chill.'},
             '00afbdbb79': {'title': ' teriyaki sauce',
              'ingredients': ' 2 tablespoons sugar 2 tablespoons honey 3 tablespoons soy_sauce 1/2 teaspoon salt 1 tablespoon vinegar 1 teaspoon ginger 3 tablespoons oil 1 garlic clove 1 clove 2 tablespoons butter',
              'recipe': ' in a saucepan over low flame, mix sugar, honey, soy sauce, salt, vinegar, ginger, oil, garlic, clove and butter. stir until all ingredients are thoroughly blended. serve with sliced pork.'},
             '004628ff8d': {'title': ' honey praline',
              'ingredients': ' 3/4 cup sugar 3 tablespoons honey 1 tablespoon oil 