In [1]:
import os
import wandb
import transformers
import json
import random
import math
import pickle
import pandas as pd
import numpy as np
import re

from typing import Dict, List, Optional, Tuple, Union

from dataclasses import dataclass
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
from sklearn.model_selection import train_test_split

from transformers import Trainer, AutoTokenizer, Mamba2ForCausalLM, TrainingArguments, DataCollatorWithPadding, Mamba2Config, GenerationMixin, DataCollatorForMultipleChoice, DataCollatorForLanguageModeling
from transformers.models.mamba2.modeling_mamba2 import Mamba2PreTrainedModel, Mamba2Model, Mamba2Cache, Mamba2CausalLMOutput
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, add_code_sample_docstrings
import evaluate

from peft import PeftModel, LoraConfig, get_peft_model
import peft.tuners.lora.layer as pl

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.loss import CrossEntropyLoss

import pennylane as qml

from trl import SFTTrainer

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from tqdm.notebook import tqdm
import time
from IPython.display import clear_output

# Docstring placeholders in the style of transformers' modeling_mamba2 (consumed there by
# add_code_sample_docstrings). NOTE(review): not referenced by any cell shown in this notebook.
_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"

In [2]:
# Fine-tuned checkpoint directory to evaluate.
model_name = "mamba2_cp/0_iot_fpft_mamba2_130m"

# Base tokenizer repo id. NOTE(review): the tokenizer cell loads from `model_name`, not from this.
tokenizer_model_name = "AntonV/mamba2-130m-hf"

# JSON file that receives per-question results plus a final summary record.
output_file = os.path.join("output", "0_iot_fpft_mamba2_130m_result.json")

In [3]:
# Load the tokenizer stored alongside the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the RIGHT. This only matters for batched inputs; run_mamba encodes a single
# prompt per call, so no padding happens during generation. NOTE(review): switch to
# "left" if generation is ever batched (decoder-only models need left padding there).
tokenizer.padding_side = "right"

In [4]:
# DatasetDict previously written with save_to_disk; the "test" split is used below.
CACHE_PATH = "iot_datasets/iot_resource_allocation_hf_dataset"
ds_dict = load_from_disk(dataset_path=CACHE_PATH)

In [5]:
# The "test" split is already a datasets.Dataset; use it directly rather than
# re-materializing every row through Dataset.from_list.
iot_test_hf_dataset = ds_dict["test"]
print(iot_test_hf_dataset[0])

{'text': '{"Timestamp": "2024-03-25 14:13:20", "Device_ID": "D2", "Temp": "32°C", "Humidity": "47%", "Workload_Type": "Image Processing", "Processing_Tier": "Fog", "CPU_Usage(%)": 30, "Memory_Usage(MB)": 3059, "Network_Latency(ms)": 33, "Network_Jitter(ms)": 8, "Task_Execution_Time(ms)": 60, "Predicted_Resource_Allocation(%)": 55}\n###\nPredict the Actual Resource Allocation(%)', 'label': '59'}


In [None]:
def length_only(example):
    """Map function: token count of example["text"] (special tokens added, no padding/truncation)."""
    token_ids = tokenizer(
        example["text"],
        add_special_tokens=True,
        padding=False,
        truncation=False,
    )["input_ids"]
    return dict(seq_len=len(token_ids))

# NOTE: despite its name, `train_with_len` is computed over the *test* split.
train_with_len = ds_dict["test"].map(length_only)
MAX_LEN = int(np.max(train_with_len["seq_len"]))
print("Raw longest length:", MAX_LEN)

# Cap at the tokenizer's advertised maximum sequence length.
MAX_LEN = min(MAX_LEN, tokenizer.model_max_length)
print("Effective MAX_LEN:", MAX_LEN)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Raw longest length: 141
Effective MAX_LEN: 141


In [7]:
# Load the fine-tuned Mamba2 checkpoint straight onto the GPU and show its architecture.
model = Mamba2ForCausalLM.from_pretrained(model_name, device_map="cuda")
print(model)

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(50277, 768)
    (layers): ModuleList(
      (0-23): 24 x Mamba2Block(
        (norm): Mamba2RMSNorm()
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(1792, 1792, kernel_size=(4,), stride=(1,), padding=(3,), groups=1792)
          (in_proj): Linear(in_features=768, out_features=3352, bias=False)
          (norm): MambaRMSNormGated()
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): Mamba2RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50277, bias=False)
)


In [None]:
def run_mamba(model, context, max_new_tokens=10):
    """Generate a short continuation of `context` and return its first line.

    The prompt is tokenized with the module-level ``tokenizer``. Only the tokens
    produced by ``model.generate`` after the prompt are decoded, so the prompt
    never leaks into the answer. Stripping the prompt out of the fully decoded
    text breaks whenever decode(encode(text)) != text.

    Args:
        model: causal LM exposing ``generate`` and ``device`` (here Mamba2ForCausalLM).
        context: prompt string, formatted like the dataset's "text" field.
        max_new_tokens: maximum number of tokens to generate (default 10, matching
            the previous ``max_length = prompt_len + 10``).

    Returns:
        The first line of the first paragraph of the generated continuation,
        whitespace-stripped; "" if nothing printable was generated.
    """
    prompt = str(context)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    attention_mask = torch.ones_like(input_ids)  # single unpadded prompt: attend everywhere

    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        # The pad token is EOS (set in the tokenizer cell); passing it explicitly
        # avoids generate() warning about a missing pad_token_id.
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = out[0, input_ids.shape[1]:]
    continuation = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # The model keeps generating past the answer: keep the first paragraph, then its
    # first line. Leading whitespace is stripped first, so a continuation that starts
    # with a blank line does not yield an empty answer.
    answer = continuation.strip().split("\n\n")[0].strip()
    lines = answer.splitlines()
    return lines[0].strip() if lines else answer

In [9]:
# Hand-written sanity-check prompt in the dataset's "text" format: a JSON record,
# then "\n###\n", then the instruction. Single-quoted pieces avoid the escaping
# mistakes of the previous literal (a missing quote and an invalid "\P" escape).
context = (
    '{"Timestamp": "2024-03-25 12:00:00", "Device_ID": "D14", "Temp": "32°C", '
    '"Humidity": "72%", "Workload_Type": "Data Analytics", "Processing_Tier": "Device", '
    '"CPU_Usage(%)": 24, "Memory_Usage(MB)": 2276, "Network_Latency(ms)": 29, '
    '"Network_Jitter(ms)": 8, "Task_Execution_Time(ms)": 189, '
    '"Predicted_Resource_Allocation(%)": 54}'
    "\n###\n"
    "Predict the Actual Resource Allocation(%)"
)
# Quick single-prompt sanity check before running the full evaluation loop.
print(run_mamba(model, context))

51


In [None]:
def write_results(results, output_file):
    """Write `results` to `output_file` as pretty-printed UTF-8 JSON.

    The evaluation loop calls this repeatedly as a checkpoint. The data is first
    written to a sibling ``.tmp`` file and then moved into place with
    ``os.replace``, so an interrupted write cannot leave a truncated JSON file
    behind. Missing parent directories (e.g. ``output/``) are created.

    Args:
        results: JSON-serializable object (a list of per-question dicts, plus a
            summary dict at the end of the run).
        output_file: destination path.

    Raises:
        Whatever ``json.dump`` / file I/O raises; the temporary file is removed first.
    """
    parent = os.path.dirname(output_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = f"{output_file}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_file)
    except BaseException:
        # Don't leave a half-written temp file lying around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Evaluate the model on every test example. A prediction counts as correct when it
# is within TOLERANCE of the label (labels are resource-allocation percentages).
TOLERANCE = 3        # absolute tolerance, in the label's units
SAVE_EVERY = 20      # checkpoint results to disk every N questions
REPORT_EVERY = 100   # report running accuracy every N questions

# Signed integers or decimals: "59", "-5", "59.5", ".5". The sign applies to both forms.
_NUMBER_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)")


def _parse_prediction(guess):
    """Return the last number in the model output `guess` as a float, or None if none is found."""
    matches = _NUMBER_RE.findall(guess)
    return float(matches[-1]) if matches else None


results = []
num_correct = 0

start_time_total = time.time()

all_data = iot_test_hf_dataset
total_qs = len(all_data)

with tqdm(total=total_qs, desc=f"Processing Questions: 1/{total_qs}", dynamic_ncols=True) as pbar:
    for i, data in enumerate(all_data):
        start_time = time.time()
        pbar.set_description(f"Processing Questions: {i+1}/{total_qs}")

        text = data["text"]
        y_true = float(data["label"])

        guess = run_mamba(model, text)
        elapsed_time = time.time() - start_time

        # None (JSON null) when the output has no number; np.nan would be
        # serialized as the non-standard JSON token NaN.
        pred_value = _parse_prediction(guess)
        is_correct = pred_value is not None and abs(pred_value - y_true) <= TOLERANCE
        if is_correct:
            num_correct += 1

        tqdm.write(f"\nQuestion {i+1}/{total_qs}")
        tqdm.write(f"Context: {text}")
        tqdm.write(f"y_true: {y_true}")
        tqdm.write(f"pred: {pred_value}")
        tqdm.write("✅ Correct" if is_correct else "❌ Incorrect")
        tqdm.write(f"Time taken: {elapsed_time:.2f}s")
        tqdm.write("=" * 80)

        results.append({
            "idx": i,
            "question": text,
            # The prompt actually sent to the model. Previously this recorded the
            # global sanity-check `context` for every row.
            "context": text,
            "raw_output": guess,  # unparsed model output, useful for debugging parse failures
            "answer": y_true,
            "guess": pred_value,
            "is_correct": is_correct,
            "time": elapsed_time,
        })

        if len(results) % SAVE_EVERY == 0:
            write_results(results, output_file)

        if (i + 1) % REPORT_EVERY == 0:
            accuracy = num_correct / (i + 1) * 100
            tqdm.write(f"Current progress: {i+1}/{total_qs}, Accuracy: {accuracy:.2f}%")
            tqdm.write("=" * 80)
            pbar.set_postfix({"Accuracy": f"{accuracy:.2f}%", "Time/Question": f"{elapsed_time:.2f}s"})

        pbar.update(1)

total_runtime = time.time() - start_time_total
hours = total_runtime / 3600

# Guard against an empty test split.
accuracy = num_correct / total_qs * 100 if total_qs else 0.0
summary = {
    "total_questions": total_qs,
    "correct_answers": num_correct,
    "accuracy": accuracy,
    "run_time": hours,  # hours
}
results.append(summary)

write_results(results, output_file)

tqdm.write(f"Number of correct answers: {num_correct}, Total questions: {total_qs}, Accuracy rate: {accuracy:.2f}%, Total runtime: {total_runtime:.2f}s ({hours:.2f} hours)")

Processing Questions: 1/200:   0%|                                                                     | 0/200…


Question 1/200
Context: {"Timestamp": "2024-03-25 14:13:20", "Device_ID": "D2", "Temp": "32°C", "Humidity": "47%", "Workload_Type": "Image Processing", "Processing_Tier": "Fog", "CPU_Usage(%)": 30, "Memory_Usage(MB)": 3059, "Network_Latency(ms)": 33, "Network_Jitter(ms)": 8, "Task_Execution_Time(ms)": 60, "Predicted_Resource_Allocation(%)": 55}
###
Predict the Actual Resource Allocation(%)
y_true: 59.0
pred: 79.0
❌ Incorrect
Time taken: 0.11s

Question 2/200
Context: {"Timestamp": "2024-03-25 14:13:30", "Device_ID": "D19", "Temp": "36°C", "Humidity": "48%", "Workload_Type": "Image Processing", "Processing_Tier": "Fog", "CPU_Usage(%)": 43, "Memory_Usage(MB)": 5126, "Network_Latency(ms)": 43, "Network_Jitter(ms)": 4, "Task_Execution_Time(ms)": 72, "Predicted_Resource_Allocation(%)": 88}
###
Predict the Actual Resource Allocation(%)
y_true: 92.0
pred: 90.0
✅ Correct
Time taken: 0.03s

Question 3/200
Context: {"Timestamp": "2024-03-25 14:13:40", "Device_ID": "D11", "Temp": "22°C", "Humidi