In [1]:
"""Notebook that shares functionality with the Colab shared with CGFP. Used to make sure nothing breaks before updating the Huggingface model."""

'Notebook that shares functionality with the Colab shared with CGFP. Used to make sure nothing breaks before updating the Huggingface model.'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from pathlib import Path

import torch
import yaml
from transformers import AutoTokenizer

from cgfp.inference.inference import inference, inference_handler
from cgfp.training.models import MultiTaskModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Use the first GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
device

'cuda:0'

In [5]:
# Resolved against the kernel's working directory: expects `scripts/` to be a sibling
# of the directory this notebook is run from.
SCRIPT_DIR = Path.cwd().resolve().parent.joinpath("scripts")

In [6]:
# Load the training config (the evaluation checkpoint is taken from its `eval` section).
# Explicit UTF-8 so parsing does not depend on the platform's default locale encoding;
# `safe_load` avoids constructing arbitrary Python objects from the YAML.
with (SCRIPT_DIR / "config_train.yaml").open(encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [7]:
# Checkpoint under evaluation (Hugging Face Hub id or local path), from the config's `eval` section.
eval_config = config["eval"]
CHECKPOINT = eval_config["eval_checkpoint"]
CHECKPOINT

'uchicago-dsi/cgfp-roberta'

In [8]:
# Tokenizer stored with the evaluated checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=CHECKPOINT)



In [9]:
# Load the multi-task classifier from the same checkpoint as the tokenizer.
# NOTE(review): the model is not moved to `device` here; presumably `inference` /
# `inference_handler` handle device placement via their `device` argument — confirm.
model = MultiTaskModel.from_pretrained(CHECKPOINT)

In [10]:
# Classify a single raw product description; returns one prediction per attribute.
# Alternative sample input: text = "frozen peas and carrots"
text = "IW WG BRAN MUFFIN"
result = inference(
    model,
    tokenizer,
    text,
    device,
    assertion=False,
    confidence_score=False,
)
result

{'Food Product Group': 'Condiments & Snacks',
 'Food Product Category': 'Condiments & Snacks',
 'Primary Food Product Category': 'Condiments & Snacks',
 'Basic Type': 'muffin',
 'Flavor/Cut': 'None',
 'Shape': 'None',
 'Skin': 'None',
 'Seed/Bone': 'None',
 'Processing': 'None',
 'Cooked/Cleaned': 'None',
 'WG/WGR': 'whole grain rich',
 'Dietary Concern': 'None',
 'Additives': 'None',
 'Dietary Accommodation': 'None',
 'Frozen': 'None',
 'Packaging': 'ss',
 'Commodity': 'None'}

In [11]:
# Same input, but have `inference` collapse the predictions into a single
# comma-separated name string. Stored under its own name so the per-attribute
# dict held in `result` (previous cell) is not overwritten by a string.
combined_name = inference(
    model,
    tokenizer,
    text,
    device,
    assertion=False,
    confidence_score=False,
    combine_name=True,
)
combined_name

'muffin, whole grain rich, ss'

In [12]:
# Batch-inference inputs.
# NOTE: DATA_DIR is a machine-specific absolute path; adjust it when running elsewhere.
DATA_DIR = "/net/projects/cgfp/data/test/"
FILENAME = "Turner  Dairy dataset.xlsx"  # the double space is part of the actual file name
INPUT_COLUMN = "Product Type"  # spreadsheet column holding the descriptions to classify
# Join with pathlib so the path stays correct even if DATA_DIR loses its trailing slash;
# converted back to str because inference_handler has only been exercised with strings.
INPUT_PATH = str(Path(DATA_DIR) / FILENAME)

In [13]:
# SHEET_NUMBER: worksheet of the input workbook to classify (passed as `sheet_name`).
# ASSERTION: when True, filters out results whose food product group and category don't match.
SHEET_NUMBER, ASSERTION = 0, False

In [15]:
# Classify INPUT_COLUMN of the chosen sheet and save a "<input name>_classified.xlsx"
# copy into DATA_DIR; the returned table is displayed as this cell's output.
# TODO: Add option for output file name
inference_handler(
    model,
    tokenizer,
    input_path=INPUT_PATH,
    save_dir=DATA_DIR,
    device=device,
    sheet_name=SHEET_NUMBER,
    input_column=INPUT_COLUMN,
    assertion=ASSERTION,
)

Classification completed! File saved to /net/projects/cgfp/data/test/Turner  Dairy dataset_classified.xlsx


Unnamed: 0,Product Identifier,Product Type,Food Product Group,Food Product Category,Primary Food Product Category,Basic Type,Sub-Type 1,Sub-Type 2,Sub-Type 3,Flavor/Cut,...,Seed/Bone,Processing,Cooked/Cleaned,WG/WGR,Dietary Concern,Additives,Dietary Accommodation,Frozen,Packaging,Commodity
0,,Whole Milk Gallon,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,,,,,,
1,,Whole Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,,,,,ss,
2,,1% Low-Fat Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,1%,,,,,
3,,1% Strawberry Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,nonfat,,,,ss,
4,,Fat Free Vanilla Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,flavored,...,,,,,nonfat,,,,,
5,,2% Reduced Fat Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,2%,,,,,
6,,Smiley Cookie Hpt,Condiments & Snacks,Condiments & Snacks,Condiments & Snacks,cookie,,,,,...,,,,,,,,,,
7,,Fat Free Chocolate Milk Hpt,Milk & Dairy,Milk,Milk,milk,chocolate,,,,...,,,,,nonfat,,,,ss,
8,,Skim Milk Hpt,Milk & Dairy,Milk,Milk,milk,,,,,...,,,,,nonfat,,,,ss,
9,,1% Chocolate Milk Gallon,Milk & Dairy,Milk,Milk,milk,chocolate,,,,...,,,,,1%,,,,,
