In [1]:
!pip install google-cloud-storage



In [2]:
!pip install axolotl[ring-flash-attn]

Collecting axolotl[ring-flash-attn]
  Downloading axolotl-0.8.1.tar.gz (275 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.1/275.1 kB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting bitsandbytes==0.45.4 (from axolotl[ring-flash-attn])
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting triton>=3.0.0 (from axolotl[ring-flash-attn])
  Downloading triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)
Collecting liger-kernel==0.5.6 (from axolotl[ring-flash-attn])
  Downloading liger_kernel-0.5.6-py3-none-any.whl.metadata (23 kB)
Collecting packaging==23.2 (from axolotl[ring-flash-attn])
  U

In [3]:
!pip install transformers huggingface_hub




In [4]:
!pip install  transformers datasets peft accelerate bitsandbytes wandb deepspeed

Collecting deepspeed
  Downloading deepspeed-0.16.5.tar.gz (1.5 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m86.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting hjson (from deepspeed)
  Downloading hjson-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting ninja (from deepspeed)
  Using cached ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Downloading hjson-3.1.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.0/54.0 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
Building wheels for collected packages: deepspeed
  Building wheel for deepspeed (setup.py) ... [?25l[?25hdone
  Created wheel for deepspeed

# Load Data

In [5]:
import json
from google.cloud import storage
import os

def save_json(data, filename):
    # Get the directory from the filename
    directory = os.path.dirname(filename)

    # Check if the directory exists, if not, create it
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Save the data to the file
    with open(filename, 'w') as json_file:
        json.dump(data, json_file, indent=4)

def list_files_in_bucket(bucket_name, prefix=""):
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=prefix)

    # Print the list of file names in the bucket
    print("Files in the bucket:")
    for blob in blobs:
        print(blob.name)

def load_json_from_gcs(bucket_name, file_name):
    from google.cloud import storage
    import json

    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(file_name)

    if not file_name.endswith('.jsonl'):  # Ensure it's a JSONL file
        raise ValueError(f"The specified file '{file_name}' is not a JSONL file.")

    concatenated_data = []  # To accumulate JSON objects
    try:
        # Download and decode the file content
        content = blob.download_as_string().decode('utf-8')
        # Split content by lines and load each line as a separate JSON object
        for line in content.splitlines():
            if line.strip():  # Only parse non-empty lines
                json_obj = json.loads(line)
                concatenated_data.append(json_obj)  # Add JSON object to the list
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON in {file_name}: {e}")

    # Return the JSON as a string for output
    return json.dumps(concatenated_data, indent=4)  # Prettify the JSON output

def load_csv_from_gcs(bucket_name, file_name):
    from google.cloud import storage
    import pandas as pd
    from io import StringIO  # Corrected import for StringIO

    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(file_name)

    if not file_name.endswith('.csv'):  # Ensure it's a CSV file
        raise ValueError(f"The specified file '{file_name}' is not a CSV file.")

    try:
        # Download CSV content and load it into a pandas DataFrame
        content = blob.download_as_string().decode('utf-8')
        data = pd.read_csv(StringIO(content))  # Use StringIO to parse the CSV content
    except Exception as e:
        print(f"Error loading CSV file '{file_name}': {e}")
        return None

    return data


def save_csv_to_gcs(bucket_name, file_name, dataframe):
    from google.cloud import storage
    import pandas as pd

    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(file_name)

    if not file_name.endswith('.csv'):
        raise ValueError(f"The specified file '{file_name}' is not a CSV file.")

    try:
        # Convert the DataFrame to CSV and upload it to GCS
        csv_content = dataframe.to_csv(index=False)  # Convert DataFrame to CSV string
        blob.upload_from_string(csv_content, content_type='text/csv')
        print(f"File '{file_name}' successfully saved to bucket '{bucket_name}'.")
    except Exception as e:
        print(f"Error saving CSV file '{file_name}': {e}")

In [6]:
def stratified_sample(df, col1, col2, frac=0.5, random_state=None):
    total_samples = int(len(df) * frac)
    grouped = df.groupby([col1, col2])
    n_groups = len(grouped)
    samples_per_group = total_samples // n_groups

    sampled_df = (
        grouped
        .apply(lambda x: x.sample(n=min(samples_per_group, len(x)), random_state=random_state))
        .reset_index(drop=True)
    )
    return sampled_df

In [7]:
train_data = load_csv_from_gcs("mddi-reach-conversation", "mistral_training_data/mistral_train.csv")
augmented_full_df = load_csv_from_gcs("mddi-reach-conversation", "mistral_training_data/augmented_mistral_train.csv")
test_data = load_csv_from_gcs("mddi-reach-conversation", "mistral_training_data/mistral_test.csv").rename({'key_point': 'stance', 'person_id': 'user'}, axis=1)
vali_data = load_csv_from_gcs("mddi-reach-conversation", "mistral_training_data/mistral_val.csv")
vali_data['prompt'] = vali_data[['prompt', 'label']].apply(lambda x: x[0] + str(x[1]) + "</s>", axis=1)

train_data.tail()

  vali_data['prompt'] = vali_data[['prompt', 'label']].apply(lambda x: x[0] + str(x[1]) + "</s>", axis=1)


Unnamed: 0,group,user,stance,topic_group,human_label,agreement,group_user,label,pid,chat_group_id,contributor,topic,content_concat,prompt
2435,6,dd88a7a7-de66-4503-bfee-dfc0e74f578f,Contributor shared varying perspectives on whe...,rideout,no opinion,unanimous,6dd88a7a7-de66-4503-bfee-dfc0e74f578f,0,dd88a7a7-de66-4503-bfee-dfc0e74f578f,6,16,"📢 *Topic* 📢\n\nIn Parliament today (3 July), S...",Contributor255: https://www.channelnewsasia.co...,<s>[INST]Determine whether Contributor16 holds...
2436,6,dd88a7a7-de66-4503-bfee-dfc0e74f578f,Contributor had differing opinions on whether ...,rideout,no opinion,unanimous,6dd88a7a7-de66-4503-bfee-dfc0e74f578f,0,dd88a7a7-de66-4503-bfee-dfc0e74f578f,6,16,"📢 *Topic* 📢\n\nIn Parliament today (3 July), S...",Contributor255: https://www.channelnewsasia.co...,<s>[INST]Determine whether Contributor16 holds...
2437,6,dd88a7a7-de66-4503-bfee-dfc0e74f578f,Contributor expressed different views on the t...,rideout,no opinion,unanimous,6dd88a7a7-de66-4503-bfee-dfc0e74f578f,0,dd88a7a7-de66-4503-bfee-dfc0e74f578f,6,16,"📢 *Topic* 📢\n\nIn Parliament today (3 July), S...",Contributor255: https://www.channelnewsasia.co...,<s>[INST]Determine whether Contributor16 holds...
2438,6,dd88a7a7-de66-4503-bfee-dfc0e74f578f,Contributor shared differing opinions on wheth...,rideout,no opinion,unanimous,6dd88a7a7-de66-4503-bfee-dfc0e74f578f,0,dd88a7a7-de66-4503-bfee-dfc0e74f578f,6,16,"📢 *Topic* 📢\n\nIn Parliament today (3 July), S...",Contributor255: https://www.channelnewsasia.co...,<s>[INST]Determine whether Contributor16 holds...
2439,6,dd88a7a7-de66-4503-bfee-dfc0e74f578f,Contributor shared differing views on the inde...,rideout,no opinion,unanimous,6dd88a7a7-de66-4503-bfee-dfc0e74f578f,0,dd88a7a7-de66-4503-bfee-dfc0e74f578f,6,16,"📢 *Topic* 📢\n\nIn Parliament today (3 July), S...",Contributor255: https://www.channelnewsasia.co...,<s>[INST]Determine whether Contributor16 holds...


In [8]:
augmented_full_df_sampled = stratified_sample(augmented_full_df, 'stance', 'human_label', frac=0.3, random_state=42)
augmented_full_df_sampled.shape

  .apply(lambda x: x.sample(n=min(samples_per_group, len(x)), random_state=random_state))


(9082, 7)

In [9]:
stratified_test = test_data.copy().sample(frac=0.1)[[ 'stance', 'prompt', 'label']].reset_index(drop=True)
stratified_test["prompt"] = stratified_test[["prompt", "label"]].apply(lambda x: x[0] + str(x[1]) + "</s>", axis=1)
print(stratified_test.prompt.loc[0][-20:])
print(stratified_test.shape)
list(stratified_test.prompt.values)[0][-5:]

selves. [/INST]1</s>
(135, 3)


  stratified_test["prompt"] = stratified_test[["prompt", "label"]].apply(lambda x: x[0] + str(x[1]) + "</s>", axis=1)


'1</s>'

In [10]:
import pandas as pd
df = pd.concat([
    train_data[['chat_group_id', 'label', 'stance', 'prompt', 'human_label', 'topic_group', 'agreement', 'user', 'group_user', 'topic']],
    #augmented_full_df[['chat_group_id',  'stance', 'content', 'class', 'context', 'human_label']].rename({'class': 'label'}, axis=1)
    augmented_full_df_sampled[[  'stance', 'prompt', 'label', 'human_label']],
    stratified_test
    ]
).reset_index(drop=True)

#df = train_data[['chat_group_id', 'label', 'stance', 'prompt', 'human_label', 'topic_group', 'agreement', 'user', 'group_user', 'topic']].loc[:100]
df.head(2)

Unnamed: 0,chat_group_id,label,stance,prompt,human_label,topic_group,agreement,user,group_user,topic
0,1.0,0,Contributor shared different perspectives on p...,<s>[INST]Determine whether Contributor429 hold...,no opinion,national_day,unanimous,04195570-f8b3-4eab-866f-32808d77d8e1,104195570-f8b3-4eab-866f-32808d77d8e1,"*📢 Topic 📢*\nIn his National Day Message, Prim..."
1,1.0,0,Contributor expressed differing levels of trus...,<s>[INST]Determine whether Contributor429 hold...,disagree,national_day,majority,04195570-f8b3-4eab-866f-32808d77d8e1,104195570-f8b3-4eab-866f-32808d77d8e1,"*📢 Topic 📢*\nIn his National Day Message, Prim..."


In [11]:
import json

def mistral_to_chat(entry):
    entry = entry.strip()

    # Check if entry starts with <s>[INST] and ends with </s>
    if not entry.startswith("<s>[INST]") or not entry.endswith("</s>"):
        print("Skipped (bad format):", entry)
        return None

    messages = []
    while True:
        # Look for the next [INST] ... [/INST] ... </s> block
        inst_start = entry.find("[INST]")
        inst_end = entry.find("[/INST]")
        s_end = entry.find("</s>", inst_end)

        if inst_start == -1 or inst_end == -1 or s_end == -1:
            break

        instruction = entry[inst_start + len("[INST]"):inst_end].strip()
        output = entry[inst_end + len("[/INST]"):s_end].strip()

        if instruction:
            messages.append({"role": "user", "content": instruction})
        if output:
            messages.append({"role": "assistant", "content": output})

        # Move to the next segment
        entry = entry[s_end + len("</s>"):].strip()

    if not messages:
        return None

    return {"messages": messages}


def convert_list_to_jsonl(entries, output_path):
    with open(output_path, "w", encoding="utf-8") as outfile:
        for entry in entries:
            chat_format = mistral_to_chat(entry)
            if chat_format:
                json.dump(chat_format, outfile, ensure_ascii=False)
                outfile.write("\n")
train_prompts = list(df.prompt.values)

convert_list_to_jsonl(train_prompts, "train_data.jsonl")


In [12]:
print(len(df))
df.prompt.str.len().describe()

11657


Unnamed: 0,prompt
count,11657.0
mean,7604.036802
std,3066.212165
min,518.0
25%,8081.0
50%,8515.0
75%,8787.0
max,19208.0


# Fine-tuning

In [13]:
%%writefile /content/config.yml
# File: /content/config.yml

base_model: unsloth/Meta-Llama-3.1-8B-Instruct

datasets:
  - path: /content/train_data.jsonl
    ds_type: json
    type: chat_template
    chat_template: tokenizer_default
    field_messages: messages
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
    drop_system_message: true
    roles_to_train: ["assistant"]
    system_prompt: "You are a helpful AI assistant for classifying stance."
    data_files:
      - /content/train_data.jsonl

# test_datasets:
#   - path: /content/val_data.jsonl
#     ds_type: json
#     type:
#       type: chat_template
#       chat_template: tokenizer_default
#     data_files:
#       - /content/val_data.jsonl
#     split: train

dataset_processes: 1

# Output
output_dir: /content/llama-output

# LoRA config (if you still want to use LoRA)
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.1
lora_target_modules:
  - q_proj
  - k_proj
  #- v_proj
  #- o_proj
  #- gate_proj
  #- up_proj
  #- down_proj
lora_modules_to_save:
  - embed_tokens
  - lm_head

# Format
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
sequence_len: 8192
pad_to_sequence_len: true
load_in_8bit: true
#load_in_4bit: true

flash_attention: true
sequence_parallel_degree: 2

# Training
num_epochs: 2
micro_batch_size: 1
gradient_accumulation_steps: 4

# Optimization
learning_rate: 2e-5
lr_scheduler_type: cosine
#weight_decay: 0.001

#Validation
# evaluation_strategy: steps
# #early_stopping_patience: 3
# do_causal_lm_eval: false
# eval_causal_lm_metrics:
#   - sacrebleu
#   - ter
#   - perplexity
# eval_sample_packing: false
# eval_max_new_tokens: 9
# eval_batch_size: 1
# eval_steps: 1000
# logging_steps: 1000

save_steps: 200

# Precision
bf16: true

# Trainer
trainer: AxolotlTrainer

# DeepSpeed
#deepspeed: /content/ds_config_zero3.json

# Wandb
wandb_mode: online
wandb_project: reach-fine-tuning
wandb_name: llama-v0
wandb_run_id: llama-v0
wandb_log_model: end


Writing /content/config.yml


In [14]:
%%writefile /content/ds_config_zero3.json
{
  "train_batch_size": 8,
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 4,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu"
    },
    "offload_param": {
      "device": "cpu"
    },
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "bf16": {
    "enabled": true
  },
  "steps_per_print": 100,
  "wall_clock_breakdown": false
}



Writing /content/ds_config_zero3.json


In [32]:
import torch
# Create a tensor and move it to GPU
x = torch.randn(1, 1, device='cuda')

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

if torch.cuda.is_available():
    print(f"CUDA is available! Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Ensure a GPU is installed and configured.")

# Check the available and total memory on the GPU
allocated_memory = torch.cuda.memory_allocated()  # Memory currently in use
reserved_memory = torch.cuda.memory_reserved()  # Memory reserved by the allocator
free_memory = torch.cuda.memory_reserved() - torch.cuda.memory_allocated()  # Free memory

print(f"Allocated Memory: {allocated_memory / 1024 ** 2:.2f} MB")
print(f"Reserved Memory: {reserved_memory / 1024 ** 2:.2f} MB")
print(f"Free Memory: {free_memory / 1024 ** 2:.2f} MB")

CUDA is available! Using GPU: NVIDIA A100-SXM4-40GB
Allocated Memory: 0.00 MB
Reserved Memory: 2.00 MB
Free Memory: 2.00 MB


In [33]:
import os

# Set the environment variable
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Now, import PyTorch
import torch

# You can now use PyTorch as usual
print(torch.cuda.is_available())  # Check if CUDA is available

True


In [15]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [16]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
!nvidia-smi nvlink --status

GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-a82dd269-9d2e-6d13-818f-001427bc2793)
	 Link 0: 25 GB/s
	 Link 1: 25 GB/s
	 Link 2: 25 GB/s
	 Link 3: 25 GB/s
	 Link 4: 25 GB/s
	 Link 5: 25 GB/s
	 Link 6: 25 GB/s
	 Link 7: 25 GB/s
	 Link 8: 25 GB/s
	 Link 9: 25 GB/s
	 Link 10: 25 GB/s
	 Link 11: 25 GB/s
GPU 1: NVIDIA A100-SXM4-40GB (UUID: GPU-d14257b2-2da5-0825-6052-9757a3ba708b)
	 Link 0: 25 GB/s
	 Link 1: 25 GB/s
	 Link 2: 25 GB/s
	 Link 3: 25 GB/s
	 Link 4: 25 GB/s
	 Link 5: 25 GB/s
	 Link 6: 25 GB/s
	 Link 7: 25 GB/s
	 Link 8: 25 GB/s
	 Link 9: 25 GB/s
	 Link 10: 25 GB/s
	 Link 11: 25 GB/s


In [None]:
!export NCCL_P2P_LEVEL=NVL

In [None]:
!rm -rf /content/mistral-output*
!rm -rf /content/wandb*
!rm -rf /content/last_run_prepared*

In [None]:
import os
print(os.path.exists('/content/train_data.jsonl'))  # Should return True
print(os.path.exists('/content/val_data.jsonl'))    # Should return True
print(os.path.exists('/content/ds_config_zero3.json'))
print(os.path.exists('/content/config.yml'))

True
True
True
True


In [17]:
!accelerate launch -m axolotl.cli.train /content/config.yml

	`--num_processes` was set to a value of `4`
		More than one GPU was found, enabling multi-GPU training.
		If this was unintended please pass in `--num_processes=1`.
	`--num_machines` was set to a value of `1`
	`--mixed_precision` was set to a value of `'no'`
	`--dynamo_backend` was set to a value of `'no'`
2025-04-16 03:25:49.847729: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 03:25:49.847726: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 03:25:49.847729: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You ma

In [None]:
"""
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.to("cuda")

model = prepare_model_for_kbit_training(model)

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="/content/outputs/mistral-qlora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    evaluation_strategy="steps",  # eval every few steps
    eval_steps=20,                # eval every 20 steps
    logging_steps=5,              # log training loss every 5 steps
    save_strategy="steps",
    save_steps=50,                # save model every 50 steps
    learning_rate=2e-4,
    lr_scheduler_type="cosine",   # Cosine decay for LR
    warmup_ratio=0.1,             # 10% warm-up
    bf16=True,
    report_to="wandb",            # Track with WandB
    run_name="mistral-finetune",
    logging_dir="./logs",         # For tracking logs
)

# Step 5: Data Collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Step 6: Early Stopping Callback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=2  # stop if eval loss doesn't improve for 2 evals
)

# Step 7: Compute Perplexity
def compute_metrics(eval_preds):
    loss = eval_preds["loss"] if isinstance(eval_preds, dict) else eval_preds.loss
    perplexity = math.exp(loss)
    return {"eval_loss": loss, "perplexity": perplexity}

# Step 8: Initialize the Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],  # Add early stopping
)

# Step 9: Train the Model
trainer.train()
"""

In [None]:
# Save Model artifact to gcs
!gsutil -m rm -r gs://mddi-reach-conversation/llama-output/**
!gsutil -m cp -r /content/llama-output/ gs://mddi-reach-conversation/

Removing gs://mddi-reach-conversation/mistral-output/README.md#1744695848343972...
Removing gs://mddi-reach-conversation/mistral-output/adapter_config.json#1744695848322978...
Removing gs://mddi-reach-conversation/mistral-output/adapter_model.safetensors#1744695871889325...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/README.md#1744695848539387...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/adapter_config.json#1744695849306100...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/adapter_model.safetensors#1744695872109259...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/optimizer.pt#1744695893551287...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/rng_state_0.pth#1744695848674136...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/rng_state_1.pth#1744695848422612...
Removing gs://mddi-reach-conversation/mistral-output/checkpoint-2600/scheduler.pt#17446958

## Testing

In [None]:
# Load Model
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# Load base Mistral model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


[2025-04-15 15:29:50,833] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:

# Load fine-tuned PEFT adapter
model = PeftModel.from_pretrained(base_model, "/content/mistral-output")

# Merge LoRA weights into base model
merged_model = model.merge_and_unload()

In [None]:
for name, param in model.named_parameters():
    print(f"{name}: mean={param.data.mean().item():.4f}, std={param.data.std().item():.4f}")

base_model.model.model.embed_tokens.weight: mean=-0.0000, std=0.0027
base_model.model.model.layers.0.self_attn.q_proj.weight: mean=0.0000, std=0.0042
base_model.model.model.layers.0.self_attn.k_proj.weight: mean=-0.0000, std=0.0044
base_model.model.model.layers.0.self_attn.v_proj.weight: mean=-0.0000, std=0.0020
base_model.model.model.layers.0.self_attn.o_proj.weight: mean=0.0000, std=0.0020
base_model.model.model.layers.0.mlp.gate_proj.weight: mean=0.0000, std=0.0032
base_model.model.model.layers.0.mlp.up_proj.weight: mean=-0.0000, std=0.0030
base_model.model.model.layers.0.mlp.down_proj.weight: mean=-0.0000, std=0.0029
base_model.model.model.layers.0.input_layernorm.weight: mean=0.0845, std=0.2461
base_model.model.model.layers.0.post_attention_layernorm.weight: mean=0.4258, std=0.0640
base_model.model.model.layers.1.self_attn.q_proj.weight: mean=-0.0000, std=0.0040
base_model.model.model.layers.1.self_attn.k_proj.weight: mean=-0.0000, std=0.0049
base_model.model.model.layers.1.self_a

In [None]:
def mistral_to_llama(entry, system_prompt="You are a helpful AI assistant for classifying stance."):
    entry = entry.strip()

    # Remove leading/trailing <s> or </s> if duplicated or misformatted
    if entry.startswith("<s>"):
        entry = entry[len("<s>"):].strip()
    if entry.endswith("</s>"):
        entry = entry[:-len("</s>")].strip()

    llama_chat = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>"

    while True:
        inst_start = entry.find("[INST]")
        inst_end = entry.find("[/INST]")

        if inst_start == -1 or inst_end == -1:
            break

        # Extract instruction
        instruction = entry[inst_start + len("[INST]"):inst_end].strip()

        # Find the next closing </s> *after* the [/INST]
        s_end = entry.find("</s>", inst_end)
        if s_end == -1:
            output = entry[inst_end + len("[/INST]"):].strip()
            entry = ""  # no more segments
        else:
            output = entry[inst_end + len("[/INST]"):s_end].strip()
            entry = entry[s_end + len("</s>"):].strip()

        if instruction:
            llama_chat += f"<|start_header_id|>user<|end_header_id|>\n{instruction}<|eot_id|>"
        if output:
            llama_chat += f"<|start_header_id|>assistant<|end_header_id|>\n{output}<|eot_id|>"

    return llama_chat if "<|start_header_id|>user<|end_header_id|>" in llama_chat else None

test_data = load_csv_from_gcs("mddi-reach-conversation", "mistral_training_data/mistral_test.csv").rename({'key_point': 'stance', 'person_id': 'user'}, axis=1)

test_data['llama_prompt'] = test_data['prompt'].apply(mistral_to_llama)

In [None]:
# Define an input prompt
input_text = list(test_data.prompt.values)[4]

# Tokenize the input text
inputs = tokenizer(input_text.replace("(1 for agree, 0 for not agree)", "Strictly only output 1 for agree, 0 for not agree").replace("<s><s>", "<s>"), return_tensors="pt", truncation=True).to("cuda")

# Generate text
output = merged_model.generate(inputs["input_ids"], attention_mask=inputs['attention_mask'], pad_token_id=tokenizer.pad_token_id)

# Decode the generated text
generated_text = tokenizer.decode(output[0])

# Print the generated text
print(generated_text.split("[/INST]")[-1], list(test_data.label.values)[4])

 0

Contributor234's comments suggest that they have concerns about the 1


In [None]:
generated_text

'<s><s>[INST]Determine whether Contributor234 holds the same view as this statement: \'Contributors agreed with the need for societal attitudes to evolve alongside structural reforms to reduce academic pressure and broaden definitions of success.’? Based on the following conversation summary, respond with ‘1’ if they share the view, or ‘0’ otherwise. Even if it is implicit, consider it a match. Do not include any additional text. The following are messages by contributor 234 and other contributors: \n Contributor283: Singapore as a society has yet to realise that academic qualifications are no barometer of ability, intelligence, or successI have seen a lot of "scholars" do incredibly stupid things that are completely out of touch with reality\nContributor771: There are people who memorized word to word for exams... But no problem solving brain... Heard that that\'s what foreign students do... Lecturers thought they copied from books or friends.. tested them.. they wrote all word for wo

In [None]:
import warnings
from transformers import logging

# Suppress the specific warning
warnings.filterwarnings("ignore", message="Setting `pad_token_id` to `eos_token_id`")

# Alternatively, set logging to error level to suppress other warnings
logging.set_verbosity_error()


In [None]:
print(len(test_data))

1349


In [None]:
# Format Test Data
from tqdm import tqdm
all_predictions = []
for test_prompt in tqdm(list(test_data.prompt.values)):
    # Tokenize the input text
    inputs = tokenizer(test_prompt.replace("(1 for agree, 0 for not agree)", "You must strictly only output 1 for agree, 0 for not agree").replace("<s><s>", "<s>"), return_tensors="pt", truncation=True).to("cuda")

    # Generate text
    output = merged_model.generate(inputs["input_ids"], attention_mask=inputs['attention_mask'], pad_token_id=tokenizer.pad_token_id)

    # Decode the generated text
    generated_text = tokenizer.decode(output[0] )
    generated_text_cleaned = generated_text.split("[/INST]")[-1]
    if "1" in generated_text_cleaned:
        prediction = 1
    elif "0" in generated_text_cleaned:
        prediction = 0
    else:
        prediction = -1
    all_predictions.append(prediction)
test_data["pred"] = all_predictions


100%|██████████| 1349/1349 [18:53<00:00,  1.19it/s]


In [None]:
print(f"Percentage of incorrectly parsed output: {len(test_data[test_data['pred'] == -1])/len(test_data)}")


Percentage of incorrectly parsed output: 0.005189028910303929


In [None]:
from sklearn.metrics import (
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    average_precision_score,
    auc,
    accuracy_score
)

# Metrics calculation
def compute_metrics(true_labels, prediction):
    metrics = {}
    metrics["accuracy"] = accuracy_score(true_labels, prediction)
    metrics["f1"] = f1_score(true_labels, prediction)
    metrics["precision"] = precision_score(true_labels, prediction)
    metrics["recall"] = recall_score(true_labels, prediction)
    metrics["f1_weighted"] = f1_score(true_labels, prediction, average = "weighted")
    metrics["recall_weighted"] = recall_score(true_labels,prediction, average = "weighted")
    metrics["precision_weighted"] = precision_score(true_labels, prediction, average = "weighted")
    metrics["f1_marco"] = f1_score(true_labels, prediction, average = "macro")
    metrics["precision_marco"] = precision_score(true_labels, prediction, average = "macro")
    metrics["recall_marco"] = recall_score(true_labels, prediction, average = "macro")

    return metrics

In [None]:

pred_to_evaluate = test_data[test_data["pred"]!=-1]


compute_metrics(list(pred_to_evaluate["label"].values), list(pred_to_evaluate["pred"].values))


{'accuracy': 0.5312965722801788,
 'f1': 0.5741367637102234,
 'precision': 0.5280199252801993,
 'recall': 0.629080118694362,
 'f1_weighted': 0.5267179602714996,
 'recall_weighted': 0.5312965722801788,
 'precision_weighted': 0.5320807790760923,
 'f1_marco': 0.5265050015734215,
 'precision_marco': 0.5320990164434392,
 'recall_marco': 0.5308574246166421}

In [None]:
pred_to_evaluate

Unnamed: 0,discussion_id,stance,contributor_name,key_point_id,contributor_id,expressed_view,user,topic,topic_statement,label,content_concat,prompt,pred
0,223,Contributors expressed concern over the securi...,Loh Zheng Han | 97210312,2659,1381,True,1407,The parliamentary session on February 4th will...,Parliament Sitting (4 Feb 2025),1,Contributor726: You are not reading US politic...,<s>[INST]Determine whether Contributor1407 hol...,1
1,223,Contributors suggested that the Government sho...,Ching Jia Alex Chen | 97313019,2660,1552,False,1578,The parliamentary session on February 4th will...,Parliament Sitting (4 Feb 2025),0,Contributor726: You are not reading US politic...,<s>[INST]Determine whether Contributor1578 hol...,1
2,223,Contributors expressed concern about the impac...,Yew Heng Pah,2702,501,True,527,The parliamentary session on February 4th will...,Parliament Sitting (4 Feb 2025),1,Contributor527: Regarding the trump tariffs.. ...,<s>[INST]Determine whether Contributor527 hold...,1
3,227,Contributors agreed with the need for societal...,Liew Poh Eng,2906,18,False,44,The discussion revolves around Minister Chan's...,Speech at MOE x NIE x IPS Lecture,0,Contributor543: I think today's topic ties in ...,<s>[INST]Determine whether Contributor44 holds...,0
4,227,Contributors agreed with the need for societal...,Adam Haziq Mohd Arshad,2906,208,True,234,The discussion revolves around Minister Chan's...,Speech at MOE x NIE x IPS Lecture,1,Contributor283: Singapore as a society has yet...,<s>[INST]Determine whether Contributor234 hold...,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1344,242,Contributors emphasized the need for a strong ...,Matthew Chua,3550,178,True,204,The discussion revolves around Singapore's fis...,Singapore's fiscal planning effectiveness,1,Contributor527: I think this will be the start...,<s>[INST]Determine whether Contributor204 hold...,1
1345,242,Contributors emphasized the need for a strong ...,Jinhui Lee,3550,1193,True,1219,The discussion revolves around Singapore's fis...,Singapore's fiscal planning effectiveness,1,Contributor1219: When the government continuou...,<s>[INST]Determine whether Contributor1219 hol...,1
1346,242,Contributors discussed the importance of havin...,Matthew Chua,3551,178,False,204,The discussion revolves around Singapore's fis...,Singapore's fiscal planning effectiveness,0,Contributor527: I think this will be the start...,<s>[INST]Determine whether Contributor204 hold...,1
1347,242,Contributors discussed the importance of havin...,Adam Haziq Mohd Arshad,3551,208,True,234,The discussion revolves around Singapore's fis...,Singapore's fiscal planning effectiveness,1,Contributor2132: like breaking up opposition s...,<s>[INST]Determine whether Contributor234 hold...,1


In [None]:
# Format Test Data
from tqdm import tqdm
all_base_predictions = []
for train_prompt in tqdm(list(train_data.prompt.values)):
    # Tokenize the input text
    inputs = tokenizer(train_prompt.replace("(1 for agree, 0 for not agree)", "Strictly only output 1 for agree, 0 for not agree"), return_tensors="pt").to("cuda")

    # Generate text
    output = merged_model.generate(inputs['input_ids'], pad_token_id=tokenizer.pad_token_id)

    # Decode the generated text
    generated_text = tokenizer.decode(output[0])
    generated_text_cleaned = generated_text.split("[/INST]")[-1]
    if "1" in generated_text_cleaned:
        prediction = 1
    elif "0" in generated_text_cleaned:
        prediction = 0
    else:
        prediction = -1
    all_base_predictions.append(prediction)
train_data["pred"] = all_base_predictions

In [None]:
pred_to_evaluate = train_data[train_data["pred"]!=-1]
print(len(train_data), len(pred_to_evaluate))
compute_metrics(pred_to_evaluate["label"], pred_to_evaluate["pred"])

In [None]:
# Format Test Data

all_base_predictions = []
for test_prompt in tqdm(list(test_data.prompt.values)):
    # Tokenize the input text
    inputs = tokenizer(test_prompt, return_tensors="pt").to("cuda")

    # Generate text
    output = base_model.generate(inputs['input_ids'])

    # Decode the generated text
    generated_text = tokenizer.decode(output[0])
    generated_text_cleaned = generated_text.split("[/INST]")[-1]
    if "1" in generated_text_cleaned:
        prediction = 1
    elif "0" in generated_text_cleaned:
        prediction = 0
    else:
        prediction = -1
    all_base_predictions.append(prediction)
test_data["base_pred"] = all_base_predictions

100%|██████████| 1349/1349 [22:18<00:00,  1.01it/s]


In [None]:
pred_to_evaluate = test_data[test_data["base_pred"]!=-1]

compute_metrics(pred_to_evaluate["label"], pred_to_evaluate["base_pred"])

{'accuracy': 0.5373271889400921,
 'f1': 0.5611888111888111,
 'precision': 0.5177419354838709,
 'recall': 0.6125954198473282,
 'f1_weighted': 0.5350945225308552,
 'recall_weighted': 0.5373271889400921,
 'precision_weighted': 0.5413705961052475,
 'f1_marco': 0.535955029376082,
 'precision_marco': 0.5405913978494623,
 'recall_marco': 0.539809296376427}

In [None]:
# Format Test Data
from tqdm import tqdm
all_base_predictions = []
for train_prompt in tqdm(list(train_data.prompt.values)):
    # Tokenize the input text
    inputs = tokenizer(train_prompt.replace("(1 for agree, 0 for not agree)", "Strictly only output 1 for agree, 0 for not agree"), return_tensors="pt").to("cuda")

    # Generate text
    output = base_model.generate(inputs['input_ids'], pad_token_id=tokenizer.pad_token_id)

    # Decode the generated text
    generated_text = tokenizer.decode(output[0])
    generated_text_cleaned = generated_text.split("[/INST]")[-1]
    if "1" in generated_text_cleaned:
        prediction = 1
    elif "0" in generated_text_cleaned:
        prediction = 0
    else:
        prediction = -1
    all_base_predictions.append(prediction)
train_data["base_pred"] = all_base_predictions

In [None]:

pred_to_evaluate = train_data[train_data["base_pred"]!=-1]
print(len(train_data), len(pred_to_evaluate))
compute_metrics(pred_to_evaluate["label"], pred_to_evaluate["base_pred"])