In [1]:
import sys
sys.path.append('..')
import json
import random

import numpy as np
import pandas as pd

# One seed shared by every random number generator seeded in this cell;
# later cells reuse the same constant.
SEED = 93
random.seed(SEED)     # Python's built-in generator
np.random.seed(SEED)  # NumPy's global generator

In [None]:
import torch

# Pick the compute device: Apple MPS first, then CUDA, otherwise the CPU.
# `torch.has_mps` is deprecated, so availability is queried through
# `torch.backends.mps.is_available()`. The getattr guard keeps the cell
# working on torch builds that do not expose the MPS backend module.
mps_backend = getattr(torch.backends, "mps", None)

if mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

In [2]:
import meal


MEAL_RAW = "../data/meal-1694543314.json"
meals = []

# The file is read as JSON Lines: one meal object per line. The encoding is
# pinned to UTF-8 because the recipe text contains non-ASCII characters
# (e.g. "½" in ingredient lists) and the platform default is not always UTF-8.
with open(MEAL_RAW, "r", encoding="utf-8") as fp:
    for line in fp:
        # Skip blank lines (such as a trailing newline at end of file), which
        # json.loads would otherwise reject.
        if not line.strip():
            continue
        raw_dict = json.loads(line)
        # Replace the decoded dicts with the objects built by the meal module.
        raw_dict["recipes"] = [meal.RecipeMealInput(**k) for k in raw_dict["recipes"]]
        raw_dict["appropriate"] = meal.MealValidOutput(**raw_dict["appropriate"])
        meals.append(raw_dict)

# Show the formatted recipes of the first meal as a quick sanity check.
print(meal.format_recipes(meals[0]["recipes"]))

Name: Nectarine Bars
Type: Desserts, Cookies, Bar Cookie Recipes
Ingredients: 2 cups all-purpose flour, 1 ½ cups white sugar, divided, ½ cup butter, 3  nectarines, sliced, divided, 1 tablespoon cornstarch, 2  eggs


In [3]:
from datasets import Dataset
from sklearn.model_selection import train_test_split
from datasets import DatasetDict

# TRAIN_SIZE of all meals go to training; the held-out remainder is divided
# between evaluation and test according to EVAL_SIZE and TEST_SIZE.
TRAIN_SIZE = 0.8
EVAL_SIZE = 0.5
TEST_SIZE = 0.5
train_raw, evaltest_raw = train_test_split(meals, train_size=TRAIN_SIZE, random_state=SEED)
# TEST_SIZE is passed explicitly so that the constant actually controls the
# split; scikit-learn raises if EVAL_SIZE + TEST_SIZE exceeds 1 rather than
# silently ignoring one of the two values.
eval_raw, test_raw = train_test_split(
    evaltest_raw, train_size=EVAL_SIZE, test_size=TEST_SIZE, random_state=SEED
)

def process_recipes(meals):
    """Lazily turn raw meal records into ``{"text", "label"}`` rows.

    For every record, the ``"recipes"`` entry is rendered with
    ``meal.format_recipes`` and the ``valid`` attribute of the
    ``"appropriate"`` entry is mapped to 1 when truthy and 0 otherwise.
    """
    yield from (
        {
            "text": meal.format_recipes(entry["recipes"]),
            "label": int(bool(entry["appropriate"].valid)),
        }
        for entry in meals
    )
        
# Materialise each raw split as a Dataset, in train / eval / test order.
train, evaluation, test = (
    Dataset.from_list(list(process_recipes(raw_split)))
    for raw_split in (train_raw, eval_raw, test_raw)
)

# Bundle the three splits under the keys used by the rest of the notebook.
combo_ds = DatasetDict({"train": train, "eval": evaluation, "test": test})
combo_ds

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 6072
    })
    eval: Dataset({
        features: ['text', 'label'],
        num_rows: 759
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 759
    })
})

In [4]:
# Class balance of the training split: label value -> number of rows.
label_series = pd.Series(data=train["label"])
label_series.value_counts()

0    5177
1     895
Name: count, dtype: int64

In [5]:
# Inverse-frequency class weights, rescaled so that they sum to 1.
labels_count = label_series.value_counts()
total_rows = labels_count.sum()
inverse_freq = total_rows / labels_count
labels_weight = inverse_freq / inverse_freq.sum()
labels_weight

0    0.147398
1    0.852602
Name: count, dtype: float64

In [6]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

import torch

checkpoint = "distilbert-base-uncased"
id2label = {0: "Negative", 1: "Positive"}
label2id = {v: k for k, v in id2label.items()}
# LoRA adapters on the query and value projections of the attention layers.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="none",
    target_modules=["q_lin", "v_lin"]
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Seed torch before the model is built. The classification head is not part
# of the checkpoint and is newly initialised, and the LoRA adapter weights are
# created here as well; without this seed those starting weights differ
# between runs even though SEED is fixed for numpy and random.
torch.manual_seed(SEED)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=len(id2label),
    id2label=id2label, label2id=label2id
)

model = get_peft_model(model, peft_config).to(device)
# Tokenise every split; padding is left to the data collator at batch time.
combo_ds = combo_ds.map(lambda d: tokenizer(d["text"], truncation=True))
model.print_trainable_parameters()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/6072 [00:00<?, ? examples/s]

Map:   0%|          | 0/759 [00:00<?, ? examples/s]

Map:   0%|          | 0/759 [00:00<?, ? examples/s]

trainable params: 1,479,172 || all params: 67,842,052 || trainable%: 2.180317305260755


In [7]:
from transformers import Trainer
import torch
from torch import nn


class BalancingBinaryTrainer(Trainer):
    """Trainer whose training loss is a class-weighted cross-entropy.

    Args:
        class_weights: One weight per class id, in class-id order. Defaults
            to ``(1, 1)``, i.e. both classes weighted equally.
        **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, class_weights=(1, 1), **kwargs):
        super().__init__(**kwargs)
        # Stored as floats: integer weights (including the default) would
        # otherwise produce an integer tensor that CrossEntropyLoss rejects.
        self.class_weights = [float(w) for w in class_weights]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted cross-entropy loss for one batch.

        Args:
            model: The model to run the forward pass with.
            inputs: Batch mapping; its ``"labels"`` entry is removed and the
                remaining entries are passed to ``model`` as keyword arguments.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it to ``compute_loss``; not used here.

        Returns:
            The scalar loss, or a ``(loss, outputs)`` tuple when
            ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Build the weight tensor to match the logits' dtype and device. Using
        # the logits rather than `model.device` also works when the model is
        # wrapped in a container that has no `.device` attribute.
        weight = torch.tensor(self.class_weights, dtype=logits.dtype, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        # Flatten to (n_examples, n_classes) and (n_examples,) for the loss.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

In [8]:
from transformers import DataCollatorWithPadding, TrainingArguments
from common import compute_accuracy

import os

from transformers.trainer_utils import get_last_checkpoint

# Hyperparameters for the fine-tuning run.
BATCH_TRAIN = 32
BATCH_EVAL = 64
GRADIENT_STEP = 1
LEARNING_RATE = 1e-4
EPOCHS = 4
LAMBDA = 0.01  # weight decay
OUTPUT_DIR = "meal-distil-bert-lora"

# Pads each batch to its longest sequence at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    gradient_accumulation_steps=GRADIENT_STEP,
    gradient_checkpointing=False,
    num_train_epochs=EPOCHS,
    weight_decay=LAMBDA,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    seed=SEED
)

trainer = BalancingBinaryTrainer(
    model=model,
    args=training_args,
    train_dataset=combo_ds["train"],
    eval_dataset=combo_ds["eval"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_accuracy,
    # .loc[[0, 1]] orders the weights by class id rather than by frequency.
    class_weights=labels_weight.loc[[0, 1]].tolist()
)

# Resume from the newest checkpoint in OUTPUT_DIR when one exists, otherwise
# start from scratch. The checkpoint is looked up explicitly instead of
# catching ValueError around train(), so that a ValueError raised during
# training itself is reported rather than silently triggering a fresh run.
last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
results = trainer.train(resume_from_checkpoint=last_checkpoint)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 