In [44]:
#|default_exp 51_finetune-llama-for-category-generation

In [45]:
#| export
import torch, pandas as pd, json, numpy as np
from torch.utils.data import Dataset, DataLoader

from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, LlamaForCausalLM, Trainer, TrainingArguments

## Prompt

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

You are given a passage of text and one or more questions that can be answered from the passage. Assume the passage is part of a Wikipedia article.

Your task: Identify at least 5 relevant Wikipedia categories for the article that help connect the passage with the question(s).

Guidelines for categories:
1. Must be broad but still specific (not generic like Thing or Knowledge).
2. Match Wikipedia’s category style:
   - Use plural nouns or topic phrases
   - Reflect standard groupings (topics, time periods, locations, types)
   - No long sentences; keep them as labels or tags
3. Use existing Wikipedia-style categories (choose the most plausible if unsure).
4. Do not include the "Category:" prefix. Output plain names only.

Scoring:
Assign each category a relevance score from {1, 2, 3, 4, 5}
5 = most relevant
1 = irrelevant

Input notes:
If multiple questions exist, they will be separated by " || "

Output format:
Return the result strictly and only in **valid JSON** with this structure:

{
    "Category Name 1":score,
    "Category Name 2":score,
    "Category Name 3":score,
    "Category Name 4":score,
    "Category Name 5":score
}

Rules:
- Do not output anything outside the `{}` block.
- Categories must look like valid Wikipedia labels, not explanations.
- Ensure the output can be parsed by:

```python
import json
content = json.loads(output)
```

Example 1
Passage:
Definition: Acre. An acre is a measure of land area in Imperial units or U.S. customary units. It is equal to 43 560 square feet, 4840 square yards, or 160 square rods. The precise meaning of this depends on the exact definition adopted for a foot: the international acre is 4 046.856 422 4 m (for the UK, see).

Questions:
convert acres to sq. ft.

Expected output:
{
    "Units of area":5,
    "Imperial units":4,
    "Customary units in the United States":4,
    "Agricultural terminology":3,
    "Land measurement systems":3
}

Example 2
Passage:
When will my dog come into heat? What age will my dog come into her first heat? First heat can vary greatly dog to dog. The youngest is about six months of age though sometimes a female will come into season younger. First heat can start as late as 12 or even 14 months of age or later in rare cases. Again, it can vary dog to dog. How often will my dog come into heat? Again, this varies dog-to-dog average is every six months but it could be more or less often.

Questions:
when do female dogs first go into heat

Expected output:
{
    "Dog breeding":5,
    "Animal reproduction":5,
    "Dogs as pets":4,
    "Mammal physiology":3,
    "Veterinary medicine topics":3
}

Example 3
Passage:
Cottonmouth and copperhead bites are immediately painful and signs and symptoms such as those listed below, usually begin immediately: 1  body as a whole swelling. 2  respiratory difficulty breathing. 3  skin discoloration of skin.  gastrointestinal nausea, 1  vomiting. heart and blood vessels weak pulse.

Questions:
copperhead snake bite effects

Expected output:
{
    "Snakebites":5,
    "Toxicology":4,
    "Venomous snakes of North America":4,
    "Copperheads (genus)":4,
    "Emergency medicine cases":4,
    "Cottonmouths (Agkistrodon piscivorus)":3,
    "Toxicology incidents in the United States":3
}

Now you are given this passage:
Passage: "{passage}"

Questions: "{questions}"

Do not output anything outside the `{}` block.
Do not include the "Category:" prefix in the category names. Use plain names only.
Each category must look like a valid Wikipedia-style label, not a free-text explanation.
Remember to provide ONLY the categories in the correct format and no other explanation.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

In [46]:
#| export
# Few-shot prompt in the Llama 3 chat layout. Literal braces are doubled ({{ }}) because the
# template is rendered with str.format; only {passage} and {questions} are substituted.
# The worked examples are written as real JSON (double-quoted keys) so they agree with the
# instruction that the answer must be parsable by json.loads, and the template ends with the
# blank line that follows the assistant header in the Llama 3 chat format.
PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

You are given a passage of text and one or more questions that can be answered from the passage. Assume the passage is part of a Wikipedia article.

Your task: Identify at least 5 relevant Wikipedia categories for the article that help connect the passage with the question(s).

Guidelines for categories:
1. Must be broad but still specific (not generic like Thing or Knowledge).
2. Match Wikipedia’s category style:
   - Use plural nouns or topic phrases
   - Reflect standard groupings (topics, time periods, locations, types)
   - No long sentences; keep them as labels or tags
3. Use existing Wikipedia-style categories (choose the most plausible if unsure).
4. Do not include the "Category:" prefix. Output plain names only.

Scoring:
Assign each category a relevance score from {{1, 2, 3, 4, 5}}
5 = most relevant
1 = irrelevant

Input notes:
If multiple questions exist, they will be separated by " || "

Output format:
Return the result strictly and only in **valid JSON** with this structure:

{{
    "Category Name 1":score,
    "Category Name 2":score,
    "Category Name 3":score,
    "Category Name 4":score,
    "Category Name 5":score
}}

Rules:
- Do not output anything outside the `{{}}` block.
- Categories must look like valid Wikipedia labels, not explanations.
- Ensure the output can be parsed by:

```python
import json
content = json.loads(output)
```

Example 1
Passage:
Definition: Acre. An acre is a measure of land area in Imperial units or U.S. customary units. It is equal to 43 560 square feet, 4840 square yards, or 160 square rods. The precise meaning of this depends on the exact definition adopted for a foot: the international acre is 4 046.856 422 4 m (for the UK, see).

Questions:
convert acres to sq. ft.

Expected output:
{{
    "Units of area":5,
    "Imperial units":4,
    "Customary units in the United States":4,
    "Agricultural terminology":3,
    "Land measurement systems":3
}}

Example 2
Passage:
When will my dog come into heat? What age will my dog come into her first heat? First heat can vary greatly dog to dog. The youngest is about six months of age though sometimes a female will come into season younger. First heat can start as late as 12 or even 14 months of age or later in rare cases. Again, it can vary dog to dog. How often will my dog come into heat? Again, this varies dog-to-dog average is every six months but it could be more or less often.

Questions:
when do female dogs first go into heat

Expected output:
{{
    "Dog breeding":5,
    "Animal reproduction":5,
    "Dogs as pets":4,
    "Mammal physiology":3,
    "Veterinary medicine topics":3
}}

Example 3
Passage:
Cottonmouth and copperhead bites are immediately painful and signs and symptoms such as those listed below, usually begin immediately: 1  body as a whole swelling. 2  respiratory difficulty breathing. 3  skin discoloration of skin.  gastrointestinal nausea, 1  vomiting. heart and blood vessels weak pulse.

Questions:
copperhead snake bite effects

Expected output:
{{
    "Snakebites":5,
    "Toxicology":4,
    "Venomous snakes of North America":4,
    "Copperheads (genus)":4,
    "Emergency medicine cases":4,
    "Cottonmouths (Agkistrodon piscivorus)":3,
    "Toxicology incidents in the United States":3
}}

Now you are given this passage:
Passage: 
"{passage}"

Questions: 
"{questions}"

Do not output anything outside the `{{}}` block.
Do not include the "Category:" prefix in the category names. Use plain names only.
Each category must look like a valid Wikipedia-style label, not a free-text explanation.
Remember to provide ONLY the categories in the correct format and no other explanation.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

In [3]:
# Smoke-test the template: render it with placeholder values and eyeball the result.
print(
    PROMPT_TEMPLATE.format(
        passage="Suchith Prabhu",
        questions="Sup? || Genius!",
    )
)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

You are given a passage of text and one or more questions that can be answered from the passage. Assume the passage is part of a Wikipedia article.

Your task: Identify at least 5 relevant Wikipedia categories for the article that help connect the passage with the question(s).

Guidelines for categories:
1. Must be broad but still specific (not generic like Thing or Knowledge).
2. Match Wikipedia’s category style:
   - Use plural nouns or topic phrases
   - Reflect standard groupings (topics, time periods, locations, types)
   - No long sentences; keep them as labels or tags
3. Use existing Wikipedia-style categories (choose the most plausible if unsure).
4. Do not include the "Category:" prefix. Output plain names only.

Scoring:
Assign each category a relevance score from {1, 2, 3, 4, 5}
5 = most relevant
1 = irrelevant

Input notes:
If multiple 

In [4]:
#| export
def make_prompt(passage, questions):
    """Render PROMPT_TEMPLATE for one passage and its question string.

    Both values are interpolated as given. The template text itself states that
    multiple questions are separated by " || ".
    """
    fields = {"passage": passage, "questions": questions}
    return PROMPT_TEMPLATE.format(**fields)
    

## Load data

In [5]:
# Raw tab-separated training file and the JSON file holding the categories.
input_file, category_file = (
    "/Users/suchith720/Downloads/OneDrive_2_9-13-2025/label_train_exact.raw.txt",
    "/Users/suchith720/Downloads/OneDrive_2_9-13-2025/all_categories.json",
)

In [6]:
# The raw file has no header row, so the three column names are supplied explicitly.
df = pd.read_csv(
    input_file,
    sep="\t",
    header=None,
    names=["identifier", "passage", "questions"],
)

In [7]:
# Preview the first rows of the freshly parsed frame.
df.head()

Unnamed: 0,identifier,passage,questions
0,0,The presence of communication amid scientific ...,)what was the immediate impact of the success ...
1,16,The approach is based on a theory of justice t...,_________ justice is designed to repair the ha...
2,49,"Colorâurine can be a variety of colors, most...",what color is amber urine
3,60,Inborn errors of bile acid synthesis can produ...,is autoimmune hepatitis a bile acid synthesis ...
4,81,What organs are on your left side of body. Cau...,cause of pain above ribs


In [48]:
# Load the identifier -> categories mapping, then serialise each row's entry back to a JSON
# string. Rows whose (stringified) identifier is missing from the mapping get None.
with open(category_file) as file:
    categories = json.load(file)

categories = [
    None if key not in categories else json.dumps(categories[key])
    for key in (str(identifier) for identifier in df['identifier'])
]

In [49]:
# Attach the per-row JSON strings (None where the identifier had no entry) as a new column.
df['categories'] = categories

In [64]:
# Positions of the rows that actually have a categories value.
idx = np.flatnonzero(df['categories'].notna().to_numpy())

In [65]:
# Display only the rows at the positions collected in `idx` (those with categories).
df.iloc[idx]

Unnamed: 0,identifier,passage,questions,categories
0,0,The presence of communication amid scientific ...,)what was the immediate impact of the success ...,"{""Manhattan Project"": 5, ""Nuclear weapons deve..."
1,16,The approach is based on a theory of justice t...,_________ justice is designed to repair the ha...,"{""Restorative justice"": 5, ""Theories of law"": ..."
2,49,"Colorâurine can be a variety of colors, most...",what color is amber urine,"{""Urine"": 5, ""Human physiology"": 4, ""Bodily fl..."
3,60,Inborn errors of bile acid synthesis can produ...,is autoimmune hepatitis a bile acid synthesis ...,"{""Genetic disorders"": 5, ""Liver diseases"": 5, ..."
4,81,What organs are on your left side of body. Cau...,cause of pain above ribs,"{""Human anatomy"": 5, ""Gastrointestinal disorde..."
...,...,...,...,...
523593,8841257,"These nephridia, which are called protonephrid...",what are nephridia?,"{""Excretory system"": 5, ""Invertebrate anatomy""..."
523594,8841362,What is Anterolisthesis? Anterolisthesis is de...,anterolisthesis definition,"{""Spinal disorders"": 5, ""Orthopedic conditions..."
523595,8841547,FSH and LH levels in perimenopause are often f...,what are fsh levels during perimenopause,"{""Menopause"": 5, ""Hormones"": 4, ""Gynaecology"":..."
523596,8841643,Yowie is one of several names given to a homin...,what is a yowie,"{""Cryptids"": 5, ""Mythical creatures of Austral..."


In [8]:
#| export
def load_data(input_file, category_file):
    """Load passages/questions and keep only the rows that have categories.

    Parameters
    ----------
    input_file : path to a header-less, tab-separated file whose three columns are read as
        ``identifier``, ``passage`` and ``questions``.
    category_file : path to a JSON file holding an object keyed by the stringified identifier.

    Returns
    -------
    pandas.DataFrame with columns ``identifier``, ``passage``, ``questions`` and
    ``categories``. ``categories`` holds each row's entry re-serialised as a JSON string.
    Rows without an entry are dropped and the index is renumbered 0..n-1, so positional and
    label-based lookups (``df['passage'][i]``) agree after filtering.
    """
    df = pd.read_csv(input_file, sep='\t', header=None, names=['identifier', 'passage', 'questions'])

    # JSON is read explicitly as UTF-8 so decoding does not depend on the platform default.
    with open(category_file, encoding='utf-8') as file:
        categories = json.load(file)

    # The mapping is keyed by strings, so identifiers are stringified before the lookup.
    df['categories'] = [
        json.dumps(categories[key]) if key in categories else None
        for key in map(str, df['identifier'])
    ]

    # Drop rows without categories and close the gaps this leaves in the index.
    return df.dropna(subset=['categories']).reset_index(drop=True)
    

In [9]:
# Same two paths as in the exploratory cells, spelled out again for the exported loader call.
input_file = (
    "/Users/suchith720/Downloads/OneDrive_2_9-13-2025/"
    "label_train_exact.raw.txt"
)
category_file = (
    "/Users/suchith720/Downloads/OneDrive_2_9-13-2025/"
    "all_categories.json"
)

In [10]:
# Build the training frame: rows of `input_file` that have an entry in `category_file`.
df = load_data(input_file, category_file)

In [75]:
# Persist the filtered frame; index=False keeps the row index out of the CSV.
df.to_csv("/Users/suchith720/Downloads/OneDrive_2_9-13-2025/train.csv", index=False)

In [12]:
# (rows, columns) of the filtered training frame.
df.shape

(522691, 4)

In [11]:
# Render the prompt for the row labelled 0 as a sanity check.
print(
    make_prompt(
        df.loc[0, 'passage'],
        df.loc[0, 'questions'],
    )
)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

You are given a passage of text and one or more questions that can be answered from the passage. Assume the passage is part of a Wikipedia article.

Your task: Identify at least 5 relevant Wikipedia categories for the article that help connect the passage with the question(s).

Guidelines for categories:
1. Must be broad but still specific (not generic like Thing or Knowledge).
2. Match Wikipedia’s category style:
   - Use plural nouns or topic phrases
   - Reflect standard groupings (topics, time periods, locations, types)
   - No long sentences; keep them as labels or tags
3. Use existing Wikipedia-style categories (choose the most plausible if unsure).
4. Do not include the "Category:" prefix. Output plain names only.

Scoring:
Assign each category a relevance score from {1, 2, 3, 4, 5}
5 = most relevant
1 = irrelevant

Input notes:
If multiple 

## Tokenizer

In [20]:
# Checkpoint id; the same name is used below to load the model weights.
mname = "meta-llama/Meta-Llama-3-8B-Instruct"

tokz = AutoTokenizer.from_pretrained(mname)
# Reuse the EOS token for padding so that padding="max_length" can be used later.
# NOTE(review): presumably the checkpoint ships without a dedicated pad token — confirm.
tokz.pad_token = tokz.eos_token

tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

In [28]:
# Pick the row labelled 100 as a worked example: its rendered prompt and its JSON target.
prompt = make_prompt(
    df.loc[100, "passage"],
    df.loc[100, "questions"],
)
labels = df.loc[100, 'categories']

In [33]:
# Tokenise the example prompt together with its target text, padded/truncated to 512 tokens.
# NOTE(review): `text_target` appears to tokenise `labels` as a separate sequence rather than
# appending it to the prompt; for a causal LM the labels normally have to line up with
# input_ids — confirm this layout is intended.
# NOTE(review): the few-shot template is long; with truncation=True and max_length=512 the
# passage at the end of the prompt may be cut off — verify the token count.
output = tokz(
    prompt, 
    text_target=labels,
    max_length=512,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

In [43]:
# Load the base checkpoint in float16; device placement is delegated via device_map="auto".
model = LlamaForCausalLM.from_pretrained(mname, torch_dtype=torch.float16, device_map="auto")

## Training

In [47]:
class CategoryDataset(Dataset):
    """Supervised fine-tuning examples for a causal LM: prompt followed by the JSON answer.

    Each item is a dict of three 1-D ``torch.long`` tensors of length ``max_length``:

    * ``input_ids``      - prompt tokens, then answer tokens, then right padding
    * ``attention_mask`` - 1 for real tokens, 0 for padding
    * ``labels``         - ``IGNORE_INDEX`` on prompt and padding positions, the answer
                           token ids elsewhere, so only the answer contributes to the loss

    The previous implementation passed the answer through ``text_target``, which tokenises
    it as an independent sequence. For a decoder-only model the labels must be aligned with
    ``input_ids``, so the answer is appended to the prompt here instead.

    Parameters
    ----------
    tokenizer : callable tokenizer returning a mapping with ``"input_ids"`` and exposing
        ``pad_token_id`` / ``eos_token_id``.
    data : sequence of mappings, or a pandas DataFrame, with the fields ``passage``,
        ``questions`` and ``categories`` (the latter a JSON string).
    max_length : fixed length of every returned tensor.
    """

    IGNORE_INDEX = -100         # label id skipped by the cross-entropy loss
    END_OF_TURN = "<|eot_id|>"  # same end-of-turn marker the prompt template uses

    def __init__(self, tokenizer, data, max_length=512):
        self.tokz, self.data, self.max_length = tokenizer, data, max_length

    def __len__(self):
        return len(self.data)

    def _record(self, idx):
        """Return the idx-th example; DataFrames are indexed by position, not by label."""
        if hasattr(self.data, "iloc"):
            return self.data.iloc[idx]
        return self.data[idx]

    def _encode(self, text):
        """Tokenise without auto-added special tokens (the prompt already spells them out)."""
        return list(self.tokz(text, add_special_tokens=False)["input_ids"])

    def _pack(self, prompt_ids, answer_ids):
        """Join prompt and answer ids, then truncate/pad to ``max_length`` and build labels."""
        limit = self.max_length

        # The answer is the training signal, so it is kept first; when the prompt does not
        # fit in the remaining room its *tail* is kept, because the passage, the questions
        # and the assistant header sit at the end of the prompt.
        answer_ids = list(answer_ids)[:limit]
        room = limit - len(answer_ids)
        prompt_ids = list(prompt_ids)[-room:] if room > 0 else []

        input_ids = prompt_ids + answer_ids
        labels = [self.IGNORE_INDEX] * len(prompt_ids) + answer_ids
        attention_mask = [1] * len(input_ids)

        # Padding is masked out of attention and loss, so its id is only a placeholder.
        pad_id = self.tokz.pad_token_id
        if pad_id is None:
            pad_id = self.tokz.eos_token_id
        if pad_id is None:
            pad_id = 0

        n_pad = limit - len(input_ids)
        input_ids = input_ids + [pad_id] * n_pad
        attention_mask = attention_mask + [0] * n_pad
        labels = labels + [self.IGNORE_INDEX] * n_pad

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def __getitem__(self, idx):
        record = self._record(idx)

        prompt = make_prompt(record["passage"], record["questions"])
        # Close the assistant turn so the model learns where the answer ends.
        answer = record["categories"] + self.END_OF_TURN

        return self._pack(self._encode(prompt), self._encode(answer))
        

In [None]:
# Base model in half precision. `mname` is the checkpoint id defined in the Tokenizer section
# (the previous `model_name` was never defined in this notebook).
model = LlamaForCausalLM.from_pretrained(
    mname,
    torch_dtype=torch.float16,
    device_map="auto"
)


# Low-rank adapters on the attention query/value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # common for LLaMA
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

# Tiny smoke-test records. They carry the fields CategoryDataset reads ("passage",
# "questions", "categories"), with the target as a JSON string like the real data.
# For a real run, pass the loaded examples instead, e.g. df.to_dict("records").
train_data = [
    {"passage": "The stock market saw heavy trading in the tech sector.",
     "questions": "which sector saw heavy trading",
     "categories": json.dumps({"Stock markets": 5, "Finance": 4, "Technology companies": 3})},
    {"passage": "Cristiano Ronaldo scored a hat-trick in the football match.",
     "questions": "who scored a hat-trick",
     "categories": json.dumps({"Association football": 5, "Sports": 4})}
]

# NOTE(review): the few-shot prompt template alone appears to be far longer than the previous
# max_length=256, so examples were being truncated before the passage; 2048 is a generous
# bound — verify with len(tokz(make_prompt("", ""))["input_ids"]) and tighten if needed.
train_dataset = CategoryDataset(tokz, train_data, max_length=2048)

In [None]:
# Effective batch size per device is 2 * 4 = 8 through gradient accumulation.
# NOTE(review): the model is loaded in float16 and fp16=True is also set here; depending on
# the installed peft/transformers versions this combination may need the trainable adapter
# weights in float32 — confirm on the target environment.
training_args = TrainingArguments(
    output_dir="./llama-category-finetune",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    fp16=True,
    report_to="none"
)

# `tokz` is the tokenizer created in the Tokenizer section; the name `tokenizer` used here
# before was not defined at this point in the notebook.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokz
)

trainer.train()

# Save model and tokenizer side by side so the inference cell can load both from one folder.
trainer.save_model("./llama-category-finetune")
tokz.save_pretrained("./llama-category-finetune")

## Inference

In [None]:
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# -------------------
# 1. Load fine-tuned model
# -------------------
model_path = "./llama-category-finetune"   # folder where you saved the model
# AutoTokenizer picks the tokenizer class recorded in the saved files. The SentencePiece-based
# LlamaTokenizer used before does not match the Llama 3 tokenizer saved by the training cell.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# -------------------
# 2. Function for inference
# -------------------
def generate_categories(passage, max_new_tokens=64, questions=""):
    """Generate the category JSON text for a passage with the fine-tuned model.

    The prompt is built with ``make_prompt`` so inference uses the same template the
    training dataset uses.

    Parameters
    ----------
    passage : passage text to categorise.
    max_new_tokens : upper bound on the number of generated tokens.
    questions : question string for the passage (multiple questions separated by " || ").
        Placed last and defaulted so existing ``generate_categories(passage)`` and
        ``generate_categories(passage, n)`` calls keep working.

    Returns
    -------
    str : the decoded continuation only (prompt excluded), stripped of surrounding whitespace.
    """
    prompt = make_prompt(passage, questions)

    # The template already starts with <|begin_of_text|>, so no extra special tokens are added.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

    # generate() returns prompt + continuation; slice by token count instead of splitting the
    # decoded text on a marker string that may or may not occur in it.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][prompt_length:]
    categories = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return categories

# -------------------
# 3. Example usage
# -------------------
# Two ad-hoc passages to eyeball the generated categories.
test_passage, test_passage2 = (
    "Apple launched its latest iPhone model with advanced AI features.",
    "The Indian cricket team won the World Cup after a thrilling final.",
)
for _passage in (test_passage, test_passage2):
    print("Predicted categories:", generate_categories(_passage))
