Connect GPU

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Import Libraries

In [2]:
import os
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
from peft import PeftModel

Paths [Dataset, Model, etc.]

In [None]:
model_path = "/storage/home/sriramk/BTP_sriramk/pixtral/models--mistral-community--pixtral-12b/snapshots/c2756cbbb9422eba9f6c5c439a214b0392dfc998"
# lora_model_path = "/storage/home/sriramk/BTP_sriramk/trained/pixtral-lora-finetuned"    #Used when have finetuned adaptor at a path
images_path = "/storage/home/sriramk/BTP_sriramk/1D_networks/rates"
dataset_path = "/storage/home/sriramk/BTP_sriramk/dataset.jsonl"     #dataset path
out_dir = "/storage/home/sriramk/BTP_sriramk/trained"   #Path to store the finetuned adaptor

# Load Model and Processor

Option A - Without Fine-Tuned Adaptor / First Iteration

In [None]:
processor = AutoProcessor.from_pretrained(model_path)

model = LlavaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map={"": 0} 
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Option B - With Fine-Tuned Adaptor

In [None]:
# base_model = LlavaForConditionalGeneration.from_pretrained(     #loading base model path
#     model_path,
#     torch_dtype=torch.bfloat16,
#     device_map={"": 0} 
# )
# processor = AutoProcessor.from_pretrained(lora_model_path)          #processor loaded using lora path

# key_mapping = {             
#     "base_model.model.model.multi_modal_projector": "model.multi_modal_projector"           #setting value of multi_model_projector of base model to finetuned one
# }

# model = PeftModel.from_pretrained(              #connecting the adaptor to the model
#     base_model,
#     lora_model_path,
#     key_mapping=key_mapping
# )

Equate Pad Token = EOS Token

In [5]:
processor.tokenizer.pad_token = processor.tokenizer.eos_token

Model Configuration

In [11]:
print(model)

LlavaForConditionalGeneration(
  (model): LlavaModel(
    (vision_tower): PixtralVisionModel(
      (patch_conv): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16), bias=False)
      (ln_pre): PixtralRMSNorm((1024,), eps=1e-05)
      (transformer): PixtralTransformer(
        (layers): ModuleList(
          (0-23): 24 x PixtralAttentionLayer(
            (attention_norm): PixtralRMSNorm((1024,), eps=1e-05)
            (feed_forward): PixtralMLP(
              (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
              (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
              (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
              (act_fn): SiLUActivation()
            )
            (attention): PixtralAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
              (q_proj): Linear(in_feature

Validating model chat template on the example image

In [6]:
image_path = "/storage/home/sriramk/BTP_sriramk/images/network_16.png"

messages = [
    {"role": "user", "content": [
        {"type": "text", "text": "Describe the image."},
        {"type": "image", "image": Image.open(image_path).convert("RGB")}  # optional
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "The image shows a biological network."}
    ]}
]


formatted_text = processor.apply_chat_template(messages, add_generation_prompt = False)
print(formatted_text)

<s>[INST]Describe the image.[IMG][/INST]The image shows a biological network.</s>


Load the dataset

In [7]:
from datasets import load_dataset
ds = load_dataset("json", data_files=dataset_path)

train_dataset = ds["train"]
eval_dataset = None

In [9]:
print(train_dataset[0])

{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': 'Analyze the biological network image and extract nodes, edges, reactions, ODEs, and Jacobian. Output strictly in JSON format.', 'image': None}, {'type': 'image', 'text': None, 'image': '/storage/home/sriramk/BTP_sriramk/1D_networks/rates/network_1.png'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': '{"nodes": [{"id": "A"}], "edges": [{"source": "φ", "target": "A", "type": "inhibition", "label": "k1"}, {"source": "A", "target": "A", "type": "activation", "label": "k2"}, {"source": "φ", "target": "A", "type": "activation", "label": "k3"}], "valid": false, "reason": "inhibitory interaction detected between ϕ and another node in either direction. Any inhibitory regulation involving ϕ (ϕ → X or X → ϕ) is not permitted.", "message": "Invalid network; ODEs, Jacobian, and stability analysis omitted."}', 'image': None}]}]}


# Training Section

Data Collator

In [None]:
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        assistant_responses = []
        for example in examples:
            messages = example["messages"]

            user_contents = messages[0]["content"]
            image_path = None
            for item in user_contents:
                if item["type"] == "image" and item["image"] is not None:
                    image_path = item["image"]
                    break

            # load image
            image = Image.open(image_path).convert("RGB")

            assistant_response = ""
            for c in messages[1]["content"]:
                if c["type"] == "text":
                    assistant_response = c["text"]
                    break

            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)

            texts.append(text.strip())
            images.append([image])
            assistant_responses.append(assistant_response)

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()

        for i, (input_ids, assistant_response) in enumerate(zip(batch["input_ids"], assistant_responses)):
            assistant_tokens = self.processor.tokenizer(
                assistant_response,
                return_tensors="pt",
                add_special_tokens=False
            )["input_ids"].squeeze(0)

            start_idx = self.find_subsequence(input_ids, assistant_tokens)
            if start_idx is not None:
                labels[i, :start_idx] = -100
                labels[i, start_idx + len(assistant_tokens):] = -100

        batch["labels"] = labels


        # Debugging Code

        print("\n================ DEBUG ================")

        for i in range(len(examples)):
            print(f"\n--- SAMPLE {i} ---")

            print("\n[Formatted Chat Template]\n")
            print(texts[i])

            print("\n[Assistant Response]\n")
            print(assistant_responses[i])

            print("\n[Tokenized Input IDs Decoded]\n")
            print(self.processor.tokenizer.decode(batch["input_ids"][i]))

            assistant_tokens = self.processor.tokenizer(
                assistant_responses[i],
                return_tensors="pt",
                add_special_tokens=False
            )["input_ids"].squeeze(0)

            print("\n[Assistant Tokens Decoded]\n")
            print(self.processor.tokenizer.decode(assistant_tokens))

            start_idx = self.find_subsequence(batch["input_ids"][i], assistant_tokens)
            print("\n[Start Index]\n", start_idx)

            print("\n[Decoded Labels After Masking]\n")
            decoded_labels = batch["labels"][i][batch["labels"][i] != -100]
            print(self.processor.tokenizer.decode(decoded_labels))

        print("\n[Mask Ratio]")
        print((batch["labels"] == -100).sum().item(), "/", batch["labels"].numel())
        print("========================================")




        return batch
    
    def find_subsequence(self, sequence, subsequence):
        seq_len = len(sequence)
        sub_len = len(subsequence)

        for i in range(seq_len - sub_len + 1):
            if torch.equal(sequence[i:i + sub_len], subsequence):
                return i
        return None

In [None]:
data_collator = MyDataCollator(processor)

Testing on sample dataset

In [10]:
sample_batch = [train_dataset[i] for i in range(5)]

processed_batch = data_collator(sample_batch)

print("Processed batch keys:", processed_batch.keys())

print("\nTokenized input IDs (before padding):")
print(processed_batch["input_ids"])

print("\nLabels before masking: ")
print(processed_batch["labels"])

print("\nDecoded input texts:")
for input_id in processed_batch["input_ids"]:
    print(processor.tokenizer.decode(input_id, skip_special_tokens=False))



--- SAMPLE 0 ---

[Formatted Chat Template]

<s>[INST]Analyze the biological network image and extract nodes, edges, reactions, ODEs, and Jacobian. Output strictly in JSON format.[IMG][/INST]{"nodes": [{"id": "A"}], "edges": [{"source": "φ", "target": "A", "type": "inhibition", "label": "k1"}, {"source": "A", "target": "A", "type": "activation", "label": "k2"}, {"source": "φ", "target": "A", "type": "activation", "label": "k3"}], "valid": false, "reason": "inhibitory interaction detected between ϕ and another node in either direction. Any inhibitory regulation involving ϕ (ϕ → X or X → ϕ) is not permitted.", "message": "Invalid network; ODEs, Jacobian, and stability analysis omitted."}</s>

[Assistant Response]

{"nodes": [{"id": "A"}], "edges": [{"source": "φ", "target": "A", "type": "inhibition", "label": "k1"}, {"source": "A", "target": "A", "type": "activation", "label": "k2"}, {"source": "φ", "target": "A", "type": "activation", "label": "k3"}], "valid": false, "reason": "inhibi

LORA Configuration

In [None]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,     # 8 or 16
    lora_alpha=32,
    use_rslora=True,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention
        "gate_proj", "up_proj", "down_proj"      # MLP
    ],
    modules_to_save=["multi_modal_projector"],
    lora_dropout=0.2,   #0.1
    bias="none",
    task_type="CAUSAL_LM"
)

Applying LORA configuration to model

In [13]:
from peft import get_peft_model

model = get_peft_model(model, lora_config)

In [None]:
model.print_trainable_parameters()          # Checking the trainable Parameters

trainable params: 97,527,808 || all params: 12,780,267,520 || trainable%: 0.7631


Training Configuration

In [None]:
from transformers import TrainingArguments, Trainer

epochs = 10            # Number of iterations
lr = 2e-5              # Learning Rate
schedule="cosine"      # constant is default

# Checking Logs
run_name = f"trelis-chess-{lr}_lr-{epochs}_epochs-{schedule}_schedule-completions-only"             

# Training Args
training_args = TrainingArguments(
    #max_steps = 1,
    num_train_epochs=epochs,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    # warmup_steps=20,
    learning_rate=lr,
    # weight_decay=0.01,                //commented later
    logging_steps=5,    #originally was 0.1
    output_dir=out_dir,
    lr_scheduler_type=schedule,
    bf16=True,
    remove_unused_columns=False,
    report_to=[],
    run_name=run_name,
    logging_dir=f"./logs/{run_name}",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant':True}
)

The model is already on multiple devices. Skipping the move to device specified in `args`.


Applying Training Configuration to model

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

Training

In [None]:
trainer.train()



--- SAMPLE 0 ---

[Formatted Chat Template]

<s>[INST]Analyze the biological network image and extract nodes, edges, reactions, ODEs, and Jacobian. Output strictly in JSON format.[IMG][/INST]{"nodes": [{"id": "A"}], "edges": [{"source": "A", "target": "φ", "type": "activation", "label": "k1"}, {"source": "A", "target": "A", "type": "inhibition", "label": "k2"}, {"source": "A", "target": "φ", "type": "inhibition", "label": "k3"}, {"source": "φ", "target": "A", "type": "activation", "label": "k4"}], "valid": false, "reason": "inhibitory interaction detected between ϕ and another node in either direction. Any inhibitory regulation involving ϕ (ϕ → X or X → ϕ) is not permitted.", "message": "Invalid network; ODEs, Jacobian, and stability analysis omitted."}</s>

[Assistant Response]

{"nodes": [{"id": "A"}], "edges": [{"source": "A", "target": "φ", "type": "activation", "label": "k1"}, {"source": "A", "target": "A", "type": "inhibition", "label": "k2"}, {"source": "A", "target": "φ", "ty

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Saving Trained Adaptor and Modified Processor

In [None]:
model.save_pretrained("/storage/home/sriramk/BTP_sriramk/trained/pixtral-lora-finetuned-2")
processor.save_pretrained("/storage/home/sriramk/BTP_sriramk/trained/pixtral-lora-finetuned-2")


['/storage/home/sriramk/BTP_sriramk/trained/pixtral-lora-finetuned-1/processor_config.json']