In [1]:
import os
import wandb
import transformers
import json
import random
import math
import pickle
import pandas as pd
import numpy as np

from typing import Dict, List, Optional, Tuple, Union

from dataclasses import dataclass
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
from sklearn.model_selection import train_test_split

from transformers import Trainer, AutoTokenizer, Mamba2ForCausalLM, TrainingArguments, DataCollatorWithPadding, Mamba2Config, GenerationMixin, DataCollatorForMultipleChoice, DataCollatorForLanguageModeling
from transformers.models.mamba2.modeling_mamba2 import Mamba2PreTrainedModel, Mamba2Model, Mamba2Cache, Mamba2CausalLMOutput
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, add_code_sample_docstrings
import evaluate

from peft import PeftModel, LoraConfig, get_peft_model
import peft.tuners.lora.layer as pl

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.loss import CrossEntropyLoss

import pennylane as qml

from trl import SFTTrainer

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"

In [3]:
# model_name="state-spaces/mamba-130m-hf"
model_name="AntonV/mamba2-130m-hf"
output_dir = "mamba2_cp/0_iot_fpft_mamba2_130m"

lora_r = 2
train_batch_size = 1
eval_batch_size = 1
learning_rate = 5e-4
num_train_epochs = 10

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" #enforce padding side left

In [None]:
df_dataset = pd.read_csv("iot_datasets/iot_resource_allocation_dataset.csv")
def row_to_json_prompt(r):
    payload = {
        "Timestamp":               r["Timestamp"],
        "Device_ID":               r["Device_ID"],
        "Temp":                    r["Sensor_Data"].split(",")[0].split(":")[1].strip(),
        "Humidity":                r["Sensor_Data"].split(",")[1].split(":")[1].strip(),
        "Workload_Type":           r["Workload_Type"],
        "Processing_Tier":         r["Processing_Tier"],
        "CPU_Usage(%)":            r["CPU_Usage(%)"],
        "Memory_Usage(MB)":        r["Memory_Usage(MB)"],
        "Network_Latency(ms)":     r["Network_Latency(ms)"],
        "Network_Jitter(ms)":      r["Jitter(ms)"],
        "Task_Execution_Time(ms)": r["Task_Execution_Time(ms)"],
        "Predicted_Resource_Allocation(%)": r["Predicted_Resource_Allocation(%)"],
    }
    return f"{json.dumps(payload, ensure_ascii=False)}\n###\nPredict the Actual Resource Allocation(%)"

df_dataset["text"]  = df_dataset.apply(row_to_json_prompt, axis=1)
df_dataset["label"] = df_dataset["Actual_Resource_Allocation(%)"].astype(str)          
df = df_dataset[["text", "label"]]    

print(df["text"][0])
print()
print("Actual Resource Allocation(%):")
print(df["label"][0])

{"Timestamp": "2024-03-25 12:00:00", "Device_ID": "D14", "Temp": "32°C", "Humidity": "72%", "Workload_Type": "Data Analytics", "Processing_Tier": "Device", "CPU_Usage(%)": 24, "Memory_Usage(MB)": 2276, "Network_Latency(ms)": 29, "Network_Jitter(ms)": 8, "Task_Execution_Time(ms)": 189, "Predicted_Resource_Allocation(%)": 54}
###
Predict the Actual Resource Allocation(%)

Actual Resource Allocation(%):
51


In [None]:
CACHE_PATH = "iot_datasets/iot_resource_allocation_hf_dataset"

if os.path.exists(CACHE_PATH):
    print("➜  Load the cached dataset …")
    ds_dict = load_from_disk(CACHE_PATH)          # DatasetDict(train, test)
else:
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=False)

    ds_dict = DatasetDict({
        "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
        "test":  Dataset.from_pandas(test_df.reset_index(drop=True))
    })

    ds_dict.save_to_disk(CACHE_PATH)
    print(f"✓ Save to {CACHE_PATH}")

➜  载入已缓存的数据集 …


In [7]:
iot_train_hf_dataset = Dataset.from_list(ds_dict["train"])
print(iot_train_hf_dataset[0])

{'text': '{"Timestamp": "2024-03-25 12:00:00", "Device_ID": "D14", "Temp": "32°C", "Humidity": "72%", "Workload_Type": "Data Analytics", "Processing_Tier": "Device", "CPU_Usage(%)": 24, "Memory_Usage(MB)": 2276, "Network_Latency(ms)": 29, "Network_Jitter(ms)": 8, "Task_Execution_Time(ms)": 189, "Predicted_Resource_Allocation(%)": 54}\n###\nPredict the Actual Resource Allocation(%)', 'label': '51'}


In [None]:
def length_only(example):
    ids = tokenizer(example["text"], add_special_tokens=True,
                    padding=False, truncation=False).input_ids
    return {"seq_len": len(ids)}

train_with_len = ds_dict["train"].map(length_only)
MAX_LEN = int(np.max(train_with_len["seq_len"]))
print("Raw longest length:", MAX_LEN)

MAX_LEN = min(MAX_LEN, tokenizer.model_max_length)            
print("Effective MAX_LEN:", MAX_LEN)

Raw longest length: 141
Effective MAX_LEN: 141


In [None]:
# 1. Tokenize input_ids / attention_mask / labels
def tok_fn(batch):
    prompt_ids = tokenizer(batch["text"], truncation=True, max_length=MAX_LEN).input_ids
    answer_ids = tokenizer(batch["label"], add_special_tokens=False).input_ids
    input_ids  = [p + answer_ids[i] + [tokenizer.eos_token_id] for i, p in enumerate(prompt_ids)]
    labels     = [[-100]*len(p) + answer_ids[i] + [tokenizer.eos_token_id] for i, p in enumerate(prompt_ids)]
    attn_mask  = [[1]*len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}

# 2. DatasetDict
tokd = ds_dict.map(
    tok_fn,
    batched=True,
    remove_columns=["text", "label"]   
)

tokd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"]
)

In [None]:
def collate_fn(features):
    input_ids_batch = []
    attention_mask_batch = []
    labels_batch = []

    for f in features: 
        ids = f["input_ids"]
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().clone()
        else:
            ids = torch.tensor(ids, dtype=torch.long)
        input_ids_batch.append(ids)

        mask = f["attention_mask"]
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().clone()
        else:
            mask = torch.tensor(mask, dtype=torch.long)
        attention_mask_batch.append(mask)

        lab = f["labels"]
        if isinstance(lab, torch.Tensor):
            lab = lab.detach().clone()
        else:
            lab = torch.tensor(lab, dtype=torch.long)
        labels_batch.append(lab)

    input_ids = pad_sequence(
        input_ids_batch, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = pad_sequence(
        attention_mask_batch, batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        labels_batch, batch_first=True, padding_value=-100
    )

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

In [11]:
model = Mamba2ForCausalLM.from_pretrained(
    model_name, 
    device_map="cuda",
    )
print(model)
model.resize_token_embeddings(len(tokenizer))

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(50288, 768)
    (layers): ModuleList(
      (0-23): 24 x Mamba2Block(
        (norm): Mamba2RMSNorm()
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(1792, 1792, kernel_size=(4,), stride=(1,), padding=(3,), groups=1792)
          (in_proj): Linear(in_features=768, out_features=3352, bias=False)
          (norm): MambaRMSNormGated()
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): Mamba2RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50288, bias=False)
)


Embedding(50277, 768)

In [None]:
rmse = evaluate.load("mse")                                     
mae  = evaluate.load("mae")

def post_process(texts):
    outs = []
    for t in texts:
        try:
            outs.append(float(t.strip().split()[-1]))
        except ValueError:
            outs.append(np.nan)
    return np.array(outs)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    decoded = tokenizer.batch_decode(preds, skip_special_tokens=True)
    y_pred  = post_process(decoded)
    y_true  = post_process(tokenizer.batch_decode(labels, skip_special_tokens=True))

    mask = ~np.isnan(y_pred)
    y_pred, y_true = y_pred[mask], y_true[mask]

    rmse_v = rmse.compute(predictions=y_pred, references=y_true, squared=False)["mse"]
    mae_v  = mae.compute(predictions=y_pred, references=y_true)["mae"]
    eff    = (np.abs(y_pred - y_true) <= 2).mean()
    return {"rmse": rmse_v, "mae": mae_v, "efficiency": eff}

In [None]:
# Experiments and trials welcome!
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    save_steps=5000,
    logging_steps=100,
    # logging_first_step=True,
    fp16=False,          
    bf16=False,           
    save_safetensors=False,
    resume_from_checkpoint=True,
    dataloader_drop_last = True,
    remove_unused_columns=False,
    label_names=["labels"],

    per_device_train_batch_size=train_batch_size,
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=tokd["train"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

In [None]:
model = trainer.model
total_params = sum(p.numel() for p in model.parameters())

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = total_params - trainable_params

print(f"Total params: {total_params}")
print(f"Trainable params: {trainable_params}")
print(f"Non-trainable params: {non_trainable_params}")
print(model)

Total params: 128981184
Trainable params: 128981184
Non-trainable params: 0
Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(50277, 768)
    (layers): ModuleList(
      (0-23): 24 x Mamba2Block(
        (norm): Mamba2RMSNorm()
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(1792, 1792, kernel_size=(4,), stride=(1,), padding=(3,), groups=1792)
          (in_proj): Linear(in_features=768, out_features=3352, bias=False)
          (norm): MambaRMSNormGated()
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): Mamba2RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50277, bias=False)
)


In [15]:
torch.cuda.empty_cache()

In [16]:
trainer.train()
# trainer.train(resume_from_checkpoint=True)



Step,Training Loss
100,2.2567
200,2.0466
300,1.6858
400,1.7081
500,1.595
600,1.5774
700,1.5623
800,1.5354
900,1.4634
1000,1.4749


TrainOutput(global_step=8000, training_loss=1.303566107749939, metrics={'train_runtime': 741.2941, 'train_samples_per_second': 10.792, 'train_steps_per_second': 10.792, 'total_flos': 616883943951360.0, 'train_loss': 1.303566107749939, 'epoch': 10.0})

In [17]:
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

('mamba2_cp/0_iot_fpft_mamba2_130m/tokenizer_config.json',
 'mamba2_cp/0_iot_fpft_mamba2_130m/special_tokens_map.json',
 'mamba2_cp/0_iot_fpft_mamba2_130m/tokenizer.json')

In [None]:
def run_mamba(model, context):

    text = f"{context}"
    # print(text)
    # input_ids = torch.LongTensor([tokenizer.encode(text)]).cuda()
    input_ids = tokenizer.encode(text, return_tensors="pt").cuda()
    # print(input_ids)

    attention_mask = torch.ones_like(input_ids).cuda()

    # print("max_length", input_ids.shape[1])

    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=input_ids.shape[1]+10, #max_length_in_dataset
        # max_length=max_length_in_dataset,
        # max_new_tokens=max_length_in_dataset,
        eos_token_id=tokenizer.eos_token_id,
    )

    # print(out)
    decoded = tokenizer.batch_decode(out)[0]
    # print("="*80)
    # print(decoded)

    # out returns the whole sequence plus the original
    cleaned = decoded.replace(text, "")
    # cleaned = decoded[len(text):]
    cleaned = cleaned.replace("<|endoftext|>", "")

    # the model will just keep generating, so only grab the first one
    cleaned = cleaned.split("\n\n")[0].strip()
    lines = cleaned.splitlines()
    if lines:
        cleaned = lines[0].strip()

    # print(answer)
    return cleaned

In [19]:
context="{\"Timestamp\": \"2024-03-25 12:00:00\", \"Device_ID\": \"D14\", \"Temp\": \"32°C\", \"Humidity\": \"72%\", \"Workload_Type\": \"Data Analytics\", \"Processing_Tier\": \"Device\", \"CPU_Usage(%)\": 24, \"Memory_Usage(MB)\": 2276, \"Network_Latency(ms)\": 29, \"Network_Jitter(ms)\": 8, \"Task_Execution_Time(ms)\": 189, \Predicted_Resource_Allocation(%)\": 54}\n###\nPredict the Actual Resource Allocation(%)"
print(run_mamba(model, context=context))

51
