### Reference: [Fine-Tuning Llama 3.1 for Text Classification](https://www.datacamp.com/tutorial/fine-tuning-llama-3-1)

In [None]:
%pip install bitsandbytes
%pip install peft
%pip install trl

In [2]:
from huggingface_hub import login

# Replace with your actual Hugging Face token
login(token="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")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [3]:
import numpy as np
import pandas as pd
import os
import wandb
from tqdm import tqdm
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from datasets import Dataset
from peft import LoraConfig, PeftConfig
from trl import SFTTrainer
from trl import setup_chat_format
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer, 
                          BitsAndBytesConfig, 
                          TrainingArguments, 
                          pipeline, 
                          logging)
from sklearn.metrics import (accuracy_score, 
                             classification_report, 
                             confusion_matrix)
from sklearn.model_selection import train_test_split

In [4]:
# Load the Excel file
df = pd.read_excel('/kaggle/input/news-less-clean/news_less_clean.xlsx').iloc[:, :-2]

# Map labels to numerical values
label_mapping = {
    1: 'Vessel Delay',
    2: 'Vessel Accidents',
    3: 'Maritime Piracy or Terrorism risk',
    4: 'Port or Important Route Congestion',
    5: 'Port Criminal Activities',
    6: 'Cargo Damage and Loss',
    7: 'Inland Transportation Risks',
    8: 'Environmental Impact and Pollution',
    9: 'Natural Extreme Events and Extreme Weather',
    10: 'Cargo or Ship Detainment',
    11: 'Unstable Regulatory and Political Environment',
    12: 'Maritime-related but not covered by existing categories',
    13: 'Non-maritime-related'
}

# Apply the mapping to the dataset
df['Category'] = df['LABEL'].map(label_mapping)

# Inspect the data
print(df.head())

               Date                                                URL  \
0  20240815T010000Z  https://borneobulletin.com.bn/explosions-repor...   
1  20240716T194500Z  https://www.hindustantimes.com/india-news/crew...   
2  20240809T100000Z  https://www.yahoo.com/news/multiple-attacks-ta...   
3  20240717T041500Z  https://timesofoman.com/article/147862-oil-tan...   
4  20240812T201500Z  https://menafn.com/1108546043/Multiple-Attacks...   

                                               Title                 Source  \
0  Explosions reported near two ships off Yemen :...  borneobulletin.com.bn   
1  Crew , including 13 Indians , still missing af...     hindustantimes.com   
2  Multiple attacks target merchant ship off Yeme...              yahoo.com   
3  Oil tanker with 13 Indians on board sinks off ...        timesofoman.com   
4    Multiple Attacks Target Merchant Ship Off Yemen             menafn.com   

         Country  LABEL                           Category  
0         Brunei   

In [5]:
# Split the DataFrame
train_size = 0.8
eval_size = 0.1

# Calculate sizes
train_end = int(train_size * len(df))
eval_end = train_end + int(eval_size * len(df))

# Split the data
X_train = df[:train_end]
X_eval = df[train_end:eval_end]
X_test = df[eval_end:]
predicted_df = df[eval_end:]

# Define the prompt generation functions
def generate_prompt(data_point):
    return f"""
            Classify the text into Vessel_Delay, Vessel_Accidents, Maritime_Piracy_or_Terrorism_risk, Port_or_Important_Route_Congestion, Port_Criminal_Activities, Cargo_Damage_and_Loss, Inland_Transportation_Risks, Environmental_Impact_and_Pollution, Natural_Extreme_Events_and_Extreme_Weather, Cargo_or_Ship_Detainment, Unstable_Regulatory_and_Political_Environment, Maritime-related_but_not_covered_by_existing_categories or Non-maritime-related.
text: {data_point["Title"]}
label: {data_point["Category"].replace(' ', '_')}""".strip()

def generate_test_prompt(data_point):
    return f"""
            Classify the text into Vessel_Delay, Vessel_Accidents, Maritime_Piracy_or_Terrorism_risk, Port_or_Important_Route_Congestion, Port_Criminal_Activities, Cargo_Damage_and_Loss, Inland_Transportation_Risks, Environmental_Impact_and_Pollution, Natural_Extreme_Events_and_Extreme_Weather, Cargo_or_Ship_Detainment, Unstable_Regulatory_and_Political_Environment, Maritime-related_but_not_covered_by_existing_categories or Non-maritime-related.
text: {data_point["Title"]}
label: """.strip()

# Generate prompts for training and evaluation data
X_train.loc[:,'text'] = X_train.apply(generate_prompt, axis=1)
X_eval.loc[:,'text'] = X_eval.apply(generate_prompt, axis=1)

# Generate test prompts and extract true labels
y_true = X_test.loc[:,'Category']
X_test = pd.DataFrame(X_test.apply(generate_test_prompt, axis=1), columns=["text"])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X_train.loc[:,'text'] = X_train.apply(generate_prompt, axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X_eval.loc[:,'text'] = X_eval.apply(generate_prompt, axis=1)


In [6]:
X_train.Category.value_counts()

Category
Maritime-related but not covered by existing categories    120
Vessel Accidents                                           102
Maritime Piracy or Terrorism risk                           98
Port or Important Route Congestion                          88
Unstable Regulatory and Political Environment               42
Cargo Damage and Loss                                       35
Environmental Impact and Pollution                          25
Vessel Delay                                                14
Cargo or Ship Detainment                                    13
Non-maritime-related                                        11
Natural Extreme Events and Extreme Weather                  10
Inland Transportation Risks                                  7
Port Criminal Activities                                     5
Name: count, dtype: int64

In [7]:
train_data = Dataset.from_pandas(X_train[["text"]])
eval_data = Dataset.from_pandas(X_eval[["text"]])

In [8]:
train_data['text'][3]

'Classify the text into Vessel_Delay, Vessel_Accidents, Maritime_Piracy_or_Terrorism_risk, Port_or_Important_Route_Congestion, Port_Criminal_Activities, Cargo_Damage_and_Loss, Inland_Transportation_Risks, Environmental_Impact_and_Pollution, Natural_Extreme_Events_and_Extreme_Weather, Cargo_or_Ship_Detainment, Unstable_Regulatory_and_Political_Environment, Maritime-related_but_not_covered_by_existing_categories or Non-maritime-related.\ntext: Oil tanker with 13 Indians on board sinks off Oman coast\nlabel: Vessel_Accidents'

In [9]:
base_model_name = "meta-llama/Llama-3.2-1B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype="float16",
    quantization_config=bnb_config, 
)

model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

tokenizer.pad_token_id = tokenizer.eos_token_id

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

In [10]:
def predict(test, model, tokenizer, categories):
    y_pred = []
    
    for i in tqdm(range(len(test))):
        prompt = test.iloc[i]["text"]
        pipe = pipeline(task="text-generation", 
                        model=model, 
                        tokenizer=tokenizer, 
                        max_new_tokens=10, 
                        temperature=0.1)
        
        result = pipe(prompt)
        answer = result[0]['generated_text'].split("label:")[1].split("\n")[0].replace('_', ' ').strip()
        print(answer)
        # Determine the predicted category
        for category in categories:
            if category.lower() in answer.lower():
                y_pred.append(category)
                break
        else:
            y_pred.append("none")
    
    return y_pred

y_pred = predict(X_test, model, tokenizer, df['Category'].unique())

  1%|▏         | 1/72 [00:01<01:52,  1.58s/it]

Yemen Houthi rebels strike a Norwegian - flagged


  3%|▎         | 2/72 [00:01<01:01,  1.15it/s]

Maritime Piracy or Terrorism Risk


  4%|▍         | 3/72 [00:02<00:44,  1.53it/s]

Maritime Piracy or Terrorism Risk


  6%|▌         | 4/72 [00:02<00:36,  1.85it/s]

Maritime Piracy or Terrorism Risk


  7%|▋         | 5/72 [00:03<00:31,  2.11it/s]

US Navy saves tanker in Middle East from attackers


  8%|▊         | 6/72 [00:03<00:29,  2.27it/s]

Maritime Piracy or Terrorism Risk


 10%|▉         | 7/72 [00:03<00:27,  2.40it/s]

GT united states Canaveral Cargo Terminal records he


 11%|█         | 8/72 [00:04<00:25,  2.50it/s]

Maritime Piracy or Terrorism Risk


 12%|█▎        | 9/72 [00:04<00:24,  2.56it/s]

Yemen Houthi group says 10 fighters killed


 14%|█▍        | 10/72 [00:04<00:24,  2.58it/s]

Vessel Delay


 15%|█▌        | 11/72 [00:05<00:23,  2.61it/s]

Vessel Delay


 17%|█▋        | 12/72 [00:05<00:22,  2.65it/s]

Multiple commercial vessels attacked in Red Sea


 18%|█▊        | 13/72 [00:06<00:22,  2.66it/s]

Former Port of Indiana - Burns Harbor director returns to


 19%|█▉        | 14/72 [00:06<00:21,  2.68it/s]

Somalia : Experts Warn That Somali Pirates May Cooperate


 21%|██        | 15/72 [00:06<00:21,  2.69it/s]

Maritime Piracy or Terrorism Risk


 22%|██▏       | 16/72 [00:07<00:20,  2.69it/s]

Port of Oakland September container volume holds steady


 24%|██▎       | 17/72 [00:07<00:20,  2.65it/s]

Maritime Piracy or Terrorism Risk


 25%|██▌       | 18/72 [00:07<00:20,  2.67it/s]

Haiphong orders 30 Mitsui cranes


 26%|██▋       | 19/72 [00:08<00:19,  2.69it/s]

Maritime Piracy or Terrorism Risk


 28%|██▊       | 20/72 [00:08<00:19,  2.67it/s]

Vessel Delay


 29%|██▉       | 21/72 [00:09<00:18,  2.70it/s]

US warship shot down Houthi drone attacking


 31%|███       | 22/72 [00:09<00:18,  2.71it/s]

U. S. Intercepts Houthi Ball


 32%|███▏      | 23/72 [00:09<00:18,  2.71it/s]

Kenya, EAC stare at costly imports as shipping


 33%|███▎      | 24/72 [00:10<00:17,  2.72it/s]

Houthis attack container ship, Maersk hal


 35%|███▍      | 25/72 [00:10<00:17,  2.72it/s]

Clive Palmer $40 million superyacht freed


 36%|███▌      | 26/72 [00:10<00:16,  2.73it/s]

2 attacks launched from rebel - held Yemen strike


 38%|███▊      | 27/72 [00:11<00:16,  2.73it/s]

Hellenic Shipping News Worldwide


 39%|███▉      | 28/72 [00:11<00:16,  2.72it/s]

8 crew members missing from Chinese fishing boat after


 40%|████      | 29/72 [00:11<00:15,  2.72it/s]

US Navy Helicopters Kill Iran - Backed


 42%|████▏     | 30/72 [00:12<00:15,  2.71it/s]

Vessel Delay


 43%|████▎     | 31/72 [00:12<00:15,  2.72it/s]

Maersk Halts Ship Passage Via Red Sea


 44%|████▍     | 32/72 [00:13<00:14,  2.72it/s]

Houthi rebels in Yemen accidentally strike Norwegian tanker


 46%|████▌     | 33/72 [00:13<00:14,  2.72it/s]

US military sinks Huthi vessels that attacked cargo


 47%|████▋     | 34/72 [00:13<00:13,  2.72it/s]

Vessel Delay


 49%|████▊     | 35/72 [00:14<00:13,  2.73it/s]

US says it sank Houthi vessels that attacked


 50%|█████     | 36/72 [00:14<00:13,  2.73it/s]

Vessel Delay


 51%|█████▏    | 37/72 [00:14<00:12,  2.73it/s]

Maritime Piracy or Terrorism Risk


 53%|█████▎    | 38/72 [00:15<00:12,  2.73it/s]

Maritime Piracy or Terrorism Risk


 54%|█████▍    | 39/72 [00:15<00:12,  2.73it/s]

Maersk pauses Red Sea sailings after H


 56%|█████▌    | 40/72 [00:15<00:11,  2.73it/s]

Maritime Security


 57%|█████▋    | 41/72 [00:16<00:11,  2.73it/s]

Vizhinjam


 58%|█████▊    | 42/72 [00:16<00:11,  2.73it/s]

Red Sea attacks : Ballistic missile from rebel -


 60%|█████▉    | 43/72 [00:17<00:10,  2.73it/s]

Ballistic missile fired from rebel - held Yemen strikes


 61%|██████    | 44/72 [00:17<00:10,  2.64it/s]

Vessel Delay


 62%|██████▎   | 45/72 [00:17<00:10,  2.66it/s]

Maritime Piracy or Terrorism Risk


 64%|██████▍   | 46/72 [00:18<00:09,  2.64it/s]

Port records high cargo volumes despite harsh economic times


 65%|██████▌   | 47/72 [00:18<00:09,  2.59it/s]

Panama Canal drought hits new crisis level amid severe El


 67%|██████▋   | 48/72 [00:19<00:09,  2.63it/s]

Maritime Piracy or Terrorism Risk


 68%|██████▊   | 49/72 [00:19<00:08,  2.65it/s]

Maritime Piracy or Terrorism Risk


 69%|██████▉   | 50/72 [00:19<00:08,  2.66it/s]

Maersk suspends ship passage via Red Sea


 71%|███████   | 51/72 [00:20<00:07,  2.69it/s]

Port Of Rotterdam Offers  Substantial  Port Fee


 72%|███████▏  | 52/72 [00:20<00:07,  2.71it/s]

Maritime Piracy or Terrorism Risk


 74%|███████▎  | 53/72 [00:20<00:07,  2.70it/s]

Union Minister Sonowal reviews progress of proposed International


 75%|███████▌  | 54/72 [00:21<00:06,  2.70it/s]

Maritime Piracy or Terrorism Risk


 76%|███████▋  | 55/72 [00:21<00:06,  2.67it/s]

Gujarat : Mundra Port achieves historic milestone ; handles


 78%|███████▊  | 56/72 [00:21<00:05,  2.68it/s]

Yemen : Maersk halts Red Sea shipping


 79%|███████▉  | 57/72 [00:22<00:05,  2.68it/s]

Maritime Piracy or Terrorism Risk


 81%|████████  | 58/72 [00:22<00:05,  2.70it/s]

Shipping firms suspend Red Sea traffic after Yemen rebel strikes


 82%|████████▏ | 59/72 [00:23<00:04,  2.71it/s]

Red Sea | Maersk pauses Red Sea sail


 83%|████████▎ | 60/72 [00:23<00:04,  2.71it/s]

Clive Palmer $40m superyacht runs


 85%|████████▍ | 61/72 [00:23<00:04,  2.70it/s]

2 Pinoy seamen safe after foiled


 86%|████████▌ | 62/72 [00:24<00:03,  2.71it/s]

US Military Says It Sank Huthi V


 88%|████████▊ | 63/72 [00:24<00:03,  2.72it/s]

Lagos to Host World Largest Container - RORO


 89%|████████▉ | 64/72 [00:24<00:02,  2.72it/s]

2 attacks launched by Yemen Houthi rebels


 90%|█████████ | 65/72 [00:25<00:02,  2.73it/s]

US forces repel Houthi attack on Ma


 92%|█████████▏| 66/72 [00:25<00:02,  2.73it/s]

Maersk pauses Red Sea sailings after H


 93%|█████████▎| 67/72 [00:26<00:01,  2.73it/s]

Yemen Houthi rebels attack container ships in vital


 94%|█████████▍| 68/72 [00:26<00:01,  2.73it/s]

Maritime Piracy or Terrorism Risk


 96%|█████████▌| 69/72 [00:26<00:01,  2.73it/s]

Protesters demanding ceasefire in Israel - Hamas war gather


 97%|█████████▋| 70/72 [00:27<00:00,  2.75it/s]

Vessel Delay


 99%|█████████▊| 71/72 [00:27<00:00,  2.69it/s]

Maritime Piracy or Terrorism Risk | Maritime Piracy


100%|██████████| 72/72 [00:27<00:00,  2.58it/s]

Maritime Piracy or Terrorism Risk





In [11]:
def evaluate(y_true, y_pred, labels):
    mapping = {label: idx for idx, label in enumerate(labels)}
    
    def map_func(x):
        return mapping.get(x, -1)  # Map to -1 if not found, but should not occur with correct data
    
    y_true_mapped = np.vectorize(map_func)(y_true)
    y_pred_mapped = np.vectorize(map_func)(y_pred)
    
    # Calculate accuracy
    accuracy = accuracy_score(y_true=y_true_mapped, y_pred=y_pred_mapped)
    print(f'Accuracy: {accuracy:.3f}')
    
    # Generate accuracy report
    unique_labels = set(y_true_mapped)  # Get unique labels
    
    for label in unique_labels:
        label_indices = [i for i in range(len(y_true_mapped)) if y_true_mapped[i] == label]
        label_y_true = [y_true_mapped[i] for i in label_indices]
        label_y_pred = [y_pred_mapped[i] for i in label_indices]
        label_accuracy = accuracy_score(label_y_true, label_y_pred)
        print(f'Accuracy for label {labels[label]}: {label_accuracy:.3f}')
        
    # Generate classification report
    class_report = classification_report(y_true=y_true_mapped, y_pred=y_pred_mapped, target_names=labels, labels=list(range(len(labels))))
    print('\nClassification Report:')
    print(class_report)
    
    # Generate confusion matrix
    conf_matrix = confusion_matrix(y_true=y_true_mapped, y_pred=y_pred_mapped, labels=list(range(len(labels))))
    print('\nConfusion Matrix:')
    print(conf_matrix)

evaluate(y_true, y_pred, df['Category'].unique())

Accuracy: 0.222
Accuracy for label Vessel Accidents: 0.000
Accuracy for label Maritime Piracy or Terrorism risk: 0.372
Accuracy for label Environmental Impact and Pollution: 0.000
Accuracy for label Non-maritime-related: 0.000
Accuracy for label Unstable Regulatory and Political Environment: 0.000
Accuracy for label Maritime-related but not covered by existing categories: 0.000
Accuracy for label Vessel Delay: 0.000
Accuracy for label Port Criminal Activities: 0.000

Classification Report:
                                                         precision    recall  f1-score   support

                                       Vessel Accidents       0.00      0.00      0.00         2
                      Maritime Piracy or Terrorism risk       0.84      0.37      0.52        43
                     Environmental Impact and Pollution       0.00      0.00      0.00         1
                                   Non-maritime-related       0.00      0.00      0.00         5
          Unstable 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [12]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16 bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
modules = find_all_linear_names(model)
modules

['v_proj', 'down_proj', 'q_proj', 'k_proj', 'up_proj', 'o_proj', 'gate_proj']

In [13]:
output_dir="meta-llama/Llama-3.2-1B-finetuned"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules,
)

training_arguments = TrainingArguments(
    output_dir=output_dir,                    # directory to save and repository id
    num_train_epochs=1,                       # number of training epochs
    per_device_train_batch_size=1,            # batch size per device during training
    gradient_accumulation_steps=8,            # number of steps before performing a backward/update pass
    gradient_checkpointing=True,              # use gradient checkpointing to save memory
    optim="paged_adamw_32bit",
    logging_steps=1,                         
    learning_rate=2e-4,                       # learning rate, based on QLoRA paper
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,                        # max gradient norm based on QLoRA paper
    max_steps=-1,
    warmup_ratio=0.03,                        # warmup ratio based on QLoRA paper
    group_by_length=False,
    lr_scheduler_type="cosine",               # use cosine learning rate scheduler
    report_to="wandb",                  # report metrics to w&b
    eval_strategy="steps",              # save checkpoint every epoch
    eval_steps = 0.2
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=512,
    packing=False,
    dataset_kwargs={
    "add_special_tokens": False,
    "append_concat_token": False,
    }
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/570 [00:00<?, ? examples/s]

Map:   0%|          | 0/71 [00:00<?, ? examples/s]

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [14]:
trainer.train()
#d4f751c4d547a1999f19e94a8904ca831b02050c

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112598277778336, max=1.0…

  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss
15,0.7557,0.625915
30,0.4954,0.466
45,0.3865,0.418909
60,0.4163,0.414541




TrainOutput(global_step=71, training_loss=0.7883249370145126, metrics={'train_runtime': 250.9129, 'train_samples_per_second': 2.272, 'train_steps_per_second': 0.283, 'total_flos': 480914370293760.0, 'train_loss': 0.7883249370145126, 'epoch': 0.9964912280701754})

In [15]:
wandb.finish()
model.config.use_cache = True

VBox(children=(Label(value='0.028 MB of 0.028 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/loss,█▃▁▁
eval/runtime,▁▁█▅
eval/samples_per_second,██▁▄
eval/steps_per_second,██▁▃
train/epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█
train/grad_norm,▂▂▂▂█▄▃▃▃▂▂▁▂▁▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▆████▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
train/loss,██▇▅▅▃▂▂▂▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂

0,1
eval/loss,0.41454
eval/runtime,4.9668
eval/samples_per_second,14.295
eval/steps_per_second,1.812
total_flos,480914370293760.0
train/epoch,0.99649
train/global_step,71.0
train/grad_norm,0.29146
train/learning_rate,0.0
train/loss,0.5797


In [16]:
# Save trained model and tokenizer
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)



('meta-llama/Llama-3.2-1B/tokenizer_config.json',
 'meta-llama/Llama-3.2-1B/special_tokens_map.json',
 'meta-llama/Llama-3.2-1B/tokenizer.json')

In [17]:
y_pred = predict(X_test, model, tokenizer, df['Category'].unique())
evaluate(y_true, y_pred, df['Category'].unique())

  0%|          | 0/72 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)
  1%|▏         | 1/72 [00:00<01:05,  1.08it/s]

Maritime Piracy or Terrorism risk


  3%|▎         | 2/72 [00:01<01:02,  1.11it/s]

Maritime Piracy or Terrorism risk


  4%|▍         | 3/72 [00:02<01:01,  1.13it/s]

Maritime Piracy or Terrorism risk


  6%|▌         | 4/72 [00:03<01:00,  1.13it/s]

Maritime Piracy or Terrorism risk


  7%|▋         | 5/72 [00:04<00:58,  1.15it/s]

Maritime Piracy or Terrorism risk


  8%|▊         | 6/72 [00:05<00:57,  1.14it/s]

Maritime Piracy or Terrorism risk


 10%|▉         | 7/72 [00:06<00:56,  1.14it/s]

Cargo Damage and Loss


 11%|█         | 8/72 [00:07<00:56,  1.14it/s]

Maritime Piracy or Terrorism risk


 12%|█▎        | 9/72 [00:07<00:55,  1.14it/s]

Maritime Piracy or Terrorism risk


 14%|█▍        | 10/72 [00:08<00:54,  1.14it/s]

Maritime Piracy or Terrorism risk


 15%|█▌        | 11/72 [00:09<00:53,  1.14it/s]

Maritime Piracy or Terrorism risk


 17%|█▋        | 12/72 [00:10<00:51,  1.15it/s]

Maritime Piracy or Terrorism risk


 18%|█▊        | 13/72 [00:11<00:51,  1.15it/s]

Port or Important Route Congestion


 19%|█▉        | 14/72 [00:12<00:50,  1.15it/s]

Maritime Piracy or Terrorism risk


 21%|██        | 15/72 [00:13<00:49,  1.14it/s]

Maritime Piracy or Terrorism risk


 22%|██▏       | 16/72 [00:14<00:48,  1.15it/s]

Port or Important Route Congestion


 24%|██▎       | 17/72 [00:14<00:47,  1.15it/s]

Maritime Piracy or Terrorism risk


 25%|██▌       | 18/72 [00:15<00:47,  1.14it/s]

Port or Important Route Congestion


 26%|██▋       | 19/72 [00:16<00:46,  1.14it/s]

Maritime Piracy or Terrorism risk


 28%|██▊       | 20/72 [00:17<00:45,  1.14it/s]

Maritime Piracy or Terrorism risk


 29%|██▉       | 21/72 [00:18<00:44,  1.14it/s]

Maritime Piracy or Terrorism risk


 31%|███       | 22/72 [00:19<00:43,  1.14it/s]

Maritime Piracy or Terrorism risk


 32%|███▏      | 23/72 [00:20<00:42,  1.14it/s]

Maritime-related but not covered by existing categories


 33%|███▎      | 24/72 [00:21<00:42,  1.14it/s]

Maritime Piracy or Terrorism risk


 35%|███▍      | 25/72 [00:21<00:41,  1.15it/s]

Vessel Accidents


 36%|███▌      | 26/72 [00:22<00:40,  1.14it/s]

Maritime Piracy or Terrorism risk


 38%|███▊      | 27/72 [00:23<00:39,  1.14it/s]

Maritime-related but not covered by existing categories


 39%|███▉      | 28/72 [00:24<00:38,  1.14it/s]

Maritime Piracy or Terrorism risk


 40%|████      | 29/72 [00:25<00:37,  1.14it/s]

Maritime Piracy or Terrorism risk


 42%|████▏     | 30/72 [00:26<00:36,  1.14it/s]

Maritime Piracy or Terrorism risk


 43%|████▎     | 31/72 [00:27<00:35,  1.14it/s]

Maritime-related but not covered by existing categories


 44%|████▍     | 32/72 [00:28<00:35,  1.14it/s]

Maritime Piracy or Terrorism risk


 46%|████▌     | 33/72 [00:28<00:34,  1.14it/s]

Maritime Piracy or Terrorism risk


 47%|████▋     | 34/72 [00:29<00:33,  1.14it/s]

Maritime Piracy or Terrorism risk


 49%|████▊     | 35/72 [00:30<00:32,  1.14it/s]

Maritime Piracy or Terrorism risk


 50%|█████     | 36/72 [00:31<00:31,  1.14it/s]

Maritime Piracy or Terrorism risk


 51%|█████▏    | 37/72 [00:32<00:30,  1.14it/s]

Maritime Piracy or Terrorism risk


 53%|█████▎    | 38/72 [00:33<00:29,  1.14it/s]

Maritime Piracy or Terrorism risk


 54%|█████▍    | 39/72 [00:34<00:28,  1.14it/s]

Maritime Piracy or Terrorism risk


 56%|█████▌    | 40/72 [00:35<00:28,  1.14it/s]

Environmental Impact and Pollution


 57%|█████▋    | 41/72 [00:35<00:27,  1.14it/s]

Port or Important Route Congestion


 58%|█████▊    | 42/72 [00:36<00:26,  1.14it/s]

Maritime Piracy or Terrorism risk


 60%|█████▉    | 43/72 [00:37<00:25,  1.14it/s]

Maritime Piracy or Terrorism risk


 61%|██████    | 44/72 [00:38<00:24,  1.14it/s]

Maritime Piracy or Terrorism risk


 62%|██████▎   | 45/72 [00:39<00:23,  1.14it/s]

Maritime Piracy or Terrorism risk


 64%|██████▍   | 46/72 [00:40<00:22,  1.15it/s]

Cargo Damage and Loss


 65%|██████▌   | 47/72 [00:41<00:21,  1.15it/s]

Port or Important Route Congestion


 67%|██████▋   | 48/72 [00:42<00:20,  1.15it/s]

Maritime Piracy or Terrorism risk


 68%|██████▊   | 49/72 [00:42<00:20,  1.14it/s]

Maritime Piracy or Terrorism risk


 69%|██████▉   | 50/72 [00:43<00:19,  1.14it/s]

Maritime Piracy or Terrorism risk


 71%|███████   | 51/72 [00:44<00:18,  1.15it/s]

Port or Important Route Congestion


 72%|███████▏  | 52/72 [00:45<00:17,  1.15it/s]

Maritime Piracy or Terrorism risk


 74%|███████▎  | 53/72 [00:46<00:16,  1.15it/s]

Port or Important Route Congestion


 75%|███████▌  | 54/72 [00:47<00:15,  1.15it/s]

Maritime Piracy or Terrorism risk


 76%|███████▋  | 55/72 [00:48<00:14,  1.14it/s]

Port or Important Route Congestion


 78%|███████▊  | 56/72 [00:49<00:14,  1.14it/s]

Maritime Piracy or Terrorism risk


 79%|███████▉  | 57/72 [00:49<00:13,  1.14it/s]

Maritime Piracy or Terrorism risk


 81%|████████  | 58/72 [00:50<00:12,  1.14it/s]

Maritime Piracy or Terrorism risk


 82%|████████▏ | 59/72 [00:51<00:11,  1.14it/s]

Maritime Piracy or Terrorism risk


 83%|████████▎ | 60/72 [00:52<00:10,  1.14it/s]

Vessel Accidents


 85%|████████▍ | 61/72 [00:53<00:09,  1.14it/s]

Vessel Accidents


 86%|████████▌ | 62/72 [00:54<00:08,  1.14it/s]

Maritime Piracy or Terrorism risk


 88%|████████▊ | 63/72 [00:55<00:07,  1.14it/s]

Port or Important Route Congestion


 89%|████████▉ | 64/72 [00:56<00:07,  1.14it/s]

Maritime Piracy or Terrorism risk


 90%|█████████ | 65/72 [00:56<00:06,  1.14it/s]

Maritime Piracy or Terrorism risk


 92%|█████████▏| 66/72 [00:57<00:05,  1.14it/s]

Maritime Piracy or Terrorism risk


 93%|█████████▎| 67/72 [00:58<00:04,  1.14it/s]

Maritime Piracy or Terrorism risk


 94%|█████████▍| 68/72 [00:59<00:03,  1.14it/s]

Maritime Piracy or Terrorism risk


 96%|█████████▌| 69/72 [01:00<00:02,  1.14it/s]

Maritime-related but not covered by existing categories


 97%|█████████▋| 70/72 [01:01<00:01,  1.15it/s]

Cargo Damage and Loss


 99%|█████████▊| 71/72 [01:02<00:00,  1.15it/s]

Maritime-related but not covered by existing categories


100%|██████████| 72/72 [01:03<00:00,  1.14it/s]

Maritime Piracy or Terrorism risk
Accuracy: 0.625
Accuracy for label Vessel Accidents: 0.500
Accuracy for label Maritime Piracy or Terrorism risk: 0.953
Accuracy for label Environmental Impact and Pollution: 1.000
Accuracy for label Non-maritime-related: 0.000
Accuracy for label Unstable Regulatory and Political Environment: 0.000
Accuracy for label Maritime-related but not covered by existing categories: 0.167
Accuracy for label Vessel Delay: 0.000
Accuracy for label Port Criminal Activities: 0.000

Classification Report:
                                                         precision    recall  f1-score   support

                                       Vessel Accidents       0.33      0.50      0.40         2
                      Maritime Piracy or Terrorism risk       0.80      0.95      0.87        43
                     Environmental Impact and Pollution       1.00      1.00      1.00         1
                                   Non-maritime-related       0.00      0.00      


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [18]:
predicted_df["Predicted"] = y_pred
predicted_df.to_csv('predictions.csv', index=False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  predicted_df["Predicted"] = y_pred
