In [None]:
from google.colab import drive
# Mount Google Drive so the dataset / checkpoint paths under ./drive/MyDrive resolve.
drive.mount(mountpoint="/content/drive")

In [None]:
!pip install datasets

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch, torch.nn as nn
from transformers import TrainingArguments

BASE = "Qwen/Qwen2.5-0.5B-Instruct"
LORA_TOP_LAYERS = 4  # LoRA adapters go on q/v projections of the last N decoder layers

tok  = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)
lm   = AutoModelForCausalLM.from_pretrained(
        BASE, torch_dtype=torch.bfloat16, trust_remote_code=True,
        device_map="auto")

# NOTE: no quantization_config is passed above, so the model is NOT loaded in 4/8-bit.
# prepare_model_for_kbit_training freezes the base weights and (per PEFT's implementation)
# upcasts bf16/fp16 parameters to float32.
# NOTE(review): PEFT only enables gradient checkpointing here for quantized models, so it is
# presumably off for this bf16 model -- confirm against the installed PEFT version.
lm = prepare_model_for_kbit_training(lm)

n_layers = lm.config.num_hidden_layers
lora_cfg = LoraConfig(
    r=4, lora_alpha=16,
    target_modules=[
        f"layers.{i}.self_attn.{proj}"
        for i in range(n_layers - LORA_TOP_LAYERS, n_layers)
        for proj in ("q_proj", "v_proj")
    ],
    lora_dropout=0.05, task_type="CAUSAL_LM")
lm = get_peft_model(lm, lora_cfg)

In [None]:
# Work-arrangement classes; the list order defines the class ids.
LABELS = [
    "Onsite",
    "Remote",
    "Hybrid",
]
NUM = len(LABELS)  # number of classes

class MultiTaskQwen(nn.Module):
    """Causal LM with an auxiliary classification head sharing one forward pass.

    ``base_lm`` must expose ``config.hidden_size``, a ``.model`` callable returning an
    object with ``hidden_states``, and an ``lm_head`` (a PEFT-wrapped HF causal LM in this
    notebook). ``forward`` returns ``{"loss": total, "logits": lm_logits}``.
    """

    def __init__(self, base_lm, tokenizer, labels):
        super().__init__()
        self.lm  = base_lm
        self.tok = tokenizer
        self.labels = labels

        hs = base_lm.config.hidden_size
        self.class_head = nn.Linear(hs, len(labels))   # hidden state -> class logits
        self.id2vec     = nn.Embedding(len(labels), hs)  # class id -> additive hidden-state offset

        # Token ids of each label string (kept for callers; not used by forward).
        self.label_tok_ids = [tokenizer(l, add_special_tokens=False)["input_ids"]
                              for l in labels]

    @staticmethod
    def _cls_positions(attention_mask, labels=None):
        """Return, per row, the index of the token whose hidden state feeds the classifier.

        With ``labels`` (-100 = unsupervised) this is the position just before the first
        supervised label, i.e. the last prompt token -- the state from which the model
        predicts the answer. Taking the final token instead would let the classifier read
        the answer tokens themselves. Rows without any supervised label, or calls without
        ``labels``, fall back to the last attended token (assumes right padding).
        """
        pos = attention_mask.long().sum(-1) - 1
        if labels is not None:
            supervised = labels != -100
            first_sup = supervised.long().argmax(-1)
            pos = torch.where(supervised.any(-1), first_sup - 1, pos)
        return pos.clamp(min=0)

    def forward(self,
        input_ids,
        attention_mask,
        labels      = None,
        cls_labels     = None,
        task_type_id  = None,
        lambda_cls=1.0,
        lambda_gen=1.0):
        """Compute the weighted multi-task loss and the LM logits.

        Args:
            input_ids, attention_mask: token ids / mask, assumed (B, L) and right-padded.
            labels: LM targets (B, L), -100 on unsupervised positions; shifted by one
                internally. Also used to locate the classification position.
            cls_labels: class ids (B,), -100 for rows without a class.
            task_type_id: (B,), 1 marks classification rows.
            lambda_cls, lambda_gen: weights of the classification and LM losses.
        """
        outputs = self.lm.model(
            input_ids, attention_mask,
            output_hidden_states=True, use_cache=False
        )
        h_last = outputs.hidden_states[-1]     # (B, L, H)
        rows = torch.arange(h_last.size(0), device=h_last.device)

        cls_loss = torch.tensor(0., device=h_last.device)
        if task_type_id is not None and cls_labels is not None and (task_type_id == 1).any():
            pos = self._cls_positions(attention_mask, labels)
            seq_repr = h_last[rows, pos]

            cls_logits = self.class_head(seq_repr.to(self.class_head.weight.dtype))
            loss_fct   = nn.CrossEntropyLoss(ignore_index=-100)
            cls_loss   = loss_fct(cls_logits, cls_labels)

            # Add the predicted class embedding at the last prompt position so it conditions
            # the prediction of the first answer token; generation rows are left untouched.
            is_cls = (task_type_id == 1).unsqueeze(-1)
            delta  = self.id2vec(cls_logits.argmax(-1)) * is_cls
            # Out-of-place: h_last may have been saved by autograd for backward.
            h_last = h_last.index_put((rows, pos), seq_repr + delta.to(h_last.dtype))

        # ---------- LM Head ----------
        lm_logits = self.lm.lm_head(h_last)

        lm_loss = torch.tensor(0., device=h_last.device)
        if labels is not None:
            shift_log  = lm_logits[..., :-1, :].contiguous()
            shift_lab  = labels[..., 1:].contiguous()
            loss_fct   = nn.CrossEntropyLoss(ignore_index=-100)
            lm_loss    = loss_fct(shift_log.view(-1, shift_log.size(-1)),
                                  shift_lab.view(-1))

        total = lambda_cls * cls_loss + lambda_gen * lm_loss
        return {"loss": total, "logits": lm_logits}
# Wrap the LoRA model with the classification head and class-embedding table.
model = MultiTaskQwen(base_lm=lm, tokenizer=tok, labels=LABELS)

In [None]:
# Sequence-length budget (in tokens).
MAX_LEN = 1024                 # every sample is padded / trimmed to this length
TAIL_LEN = 128                 # tokens kept from the end when trimming (keeps the target)
HEAD_LEN = MAX_LEN - TAIL_LEN  # tokens kept from the start when trimming
STRIDE = 896                   # step between overlapping windows of long generation samples
LABEL2ID = dict(zip(LABELS, range(len(LABELS))))  # class name -> class id

def trim_keep_tail(ids):
    """Return ``ids`` itself if it fits in MAX_LEN tokens; otherwise its first HEAD_LEN
    tokens followed by its last TAIL_LEN tokens (the middle is dropped)."""
    if len(ids) > MAX_LEN:
        head, tail = ids[:HEAD_LEN], ids[-TAIL_LEN:]
        return head + tail
    return ids

def pad_to(ids, pad_id):
    """Right-pad ``ids`` with ``pad_id`` up to MAX_LEN.

    Returns ``(padded_ids, attention_mask)``; the mask is 1 for real tokens and 0 for
    padding. Lists already longer than MAX_LEN come back unpadded (not truncated).
    """
    n_real = len(ids)
    n_pad = MAX_LEN - n_real
    mask = [1] * n_real + [0] * n_pad
    padded = ids + [pad_id] * n_pad
    return padded, mask

def build_cls_sample(text_ids, tgt_ids, label_id):
    """Build one padded classification sample (``task_type_id == 1``).

    Args:
        text_ids: prompt token ids.
        tgt_ids: answer token ids (label text + EOS), appended after the prompt.
        label_id: class id stored as ``cls_labels``.

    The concatenation is trimmed with ``trim_keep_tail`` and right-padded to MAX_LEN.
    Only the answer tokens are supervised for the LM loss; every prompt and padding
    position gets label -100.
    """
    pad_id = tok.pad_token_id
    ids = trim_keep_tail(text_ids + tgt_ids)

    # trim_keep_tail keeps the tail, so the answer is always the last len(tgt_ids) tokens and
    # everything before it is prompt -- also when the prompt exceeds HEAD_LEN or was cut in
    # the middle. (Assumes len(tgt_ids) <= TAIL_LEN, presumably true for short label answers.)
    n_prompt = max(len(ids) - len(tgt_ids), 0)

    ids, attn = pad_to(ids, pad_id)

    lm_labels  = [-100] * n_prompt + ids[n_prompt:]
    lm_labels += [-100] * (MAX_LEN - len(lm_labels))
    lm_labels  = [-100 if t == pad_id else t for t in lm_labels]

    if len(ids) != len(attn) or len(attn)!= len(lm_labels) or len(lm_labels) != MAX_LEN:
        print(f"Warning: input_ids: {len(ids)}, attention_mask: {len(attn)}, labels: {len(lm_labels)}")

    return {
        "input_ids":     ids,
        "attention_mask": attn,
        "labels":        lm_labels,
        "cls_labels":    label_id,
        "task_type_id":  1,
    }

def build_gen_sample(prompt_ids, chunk_ids, is_first_chunk):
    """Build one padded generation sample (``task_type_id == 0``) from a token chunk.

    The chunk is right-padded to MAX_LEN. When ``is_first_chunk`` is true, the first
    ``len(prompt_ids)`` positions (capped at the padded length) are excluded from the LM
    loss; positions holding the pad id always get label -100.
    """
    pad_id = tok.pad_token_id
    padded_ids, mask = pad_to(chunk_ids, pad_id)

    n_masked = min(len(prompt_ids), len(padded_ids)) if is_first_chunk else 0
    targets = [-100] * n_masked + padded_ids[n_masked:]
    targets.extend([-100] * (MAX_LEN - len(targets)))
    targets = [-100 if tok_id == pad_id else tok_id for tok_id in targets]

    lengths_ok = len(padded_ids) == len(mask) == len(targets) == MAX_LEN
    if not lengths_ok:
        print(f"Warning: input_ids: {len(padded_ids)}, attention_mask: {len(mask)}, labels: {len(targets)}")

    return {
        "input_ids": padded_ids,
        "attention_mask": mask,
        "labels": targets,
        "cls_labels": -100,
        "task_type_id": 0,
    }

def encode_single(ex):
    """Tokenize one raw example into training sample(s).

    ``work_arrangement`` rows become one classification sample (``build_cls_sample``).
    Every other row is a generation sample; when prompt+completion exceeds MAX_LEN tokens
    it is split into MAX_LEN-token windows starting every STRIDE tokens.

    Returns a dict for a single sample, or a list of dicts for a chunked example.
    Raises KeyError if a classification completion is not a key of LABEL2ID.
    """
    task = ex["task_type"].lower()

    if task == "work_arrangement":
        text_ids  = tok(ex["prompt"] + tok.eos_token)["input_ids"]
        # strip() so surrounding whitespace (e.g. " remote\n") still maps to a class name.
        label_str = ex["complete"].strip().capitalize()
        tgt_ids   = tok(" " + label_str, add_special_tokens=False)["input_ids"] \
                    + [tok.eos_token_id]
        return build_cls_sample(text_ids, tgt_ids, LABEL2ID[label_str])

    prompt_ids = tok(ex["prompt"], add_special_tokens=False)["input_ids"]
    full_ids   = tok(ex["prompt"] + ex["complete"] + tok.eos_token)["input_ids"]

    if len(full_ids) <= MAX_LEN:
        return build_gen_sample(prompt_ids, full_ids, is_first_chunk=True)

    overlap = MAX_LEN - STRIDE
    outs = []
    for st in range(0, len(full_ids), STRIDE):
        chunk = full_ids[st : st + MAX_LEN]
        # A window no longer than the overlap lies entirely inside the previous window.
        if st > 0 and len(chunk) <= overlap:
            break
        # Prompt tokens that fall inside this window must stay unsupervised, including when
        # a long prompt spills past the first window; build_gen_sample masks that many
        # leading tokens. NOTE(review): tokens in the overlap are supervised in both windows.
        prompt_in_chunk = prompt_ids[st:]
        outs.append(build_gen_sample(prompt_in_chunk, chunk,
                                     is_first_chunk=bool(prompt_in_chunk)))
    return outs   # list[dict]

In [None]:
def encode_batch(batch):
    """Batched ``datasets.map`` function: encode every example and flatten the results.

    batch: dict mapping column name -> list of values (needs prompt, complete, task_type).
    Returns a dict mapping each output field -> list with one entry per produced sample;
    an example that encode_single splits into several samples contributes several rows.
    """
    fields = ("input_ids", "attention_mask", "labels", "cls_labels", "task_type_id")
    out = {name: [] for name in fields}

    rows = zip(batch["prompt"], batch["complete"], batch["task_type"])
    for prompt, complete, task_type in rows:
        encoded = encode_single({"prompt": prompt, "complete": complete, "task_type": task_type})
        samples = encoded if isinstance(encoded, list) else [encoded]
        for sample in samples:
            for name in fields:
                out[name].append(sample[name])

    return out

In [None]:
from datasets import load_dataset

raw_ds = load_dataset(
    path="json",
    data_files="./drive/MyDrive/datasetfiles/combined_prompts_completev2.json",
)

# One raw row can expand into several training rows (chunked long examples), so the
# original columns are dropped and replaced by the encoded fields.
encoded_train = raw_ds["train"].map(
    encode_batch,
    batched=True,
    batch_size=512,
    remove_columns=raw_ds["train"].column_names,
)

In [None]:
class MultiTaskCollator:
    """Collate encoded samples into a dict of LongTensors.

    Sequence fields are right-padded to the longest ``input_ids`` in the batch:
    input_ids with the tokenizer's pad id, attention_mask with 0, labels with
    ``label_pad``. Missing fields get defaults (attention_mask: ones over the sample's
    input_ids; labels: empty, i.e. fully padded; cls_labels / task_type_id: -100).
    The input feature dicts are not modified.
    """

    SEQ_KEYS = ("input_ids", "attention_mask", "labels")
    SCALAR_KEYS = ("cls_labels", "task_type_id")

    def __init__(self, tok, label_pad=-100):
        self.tok = tok
        self.label_pad = label_pad

    @staticmethod
    def _seq(feat, key):
        """Return a copy of sequence field ``key`` of ``feat``, or its default if absent."""
        if key in feat:
            return list(feat[key])
        if key == "attention_mask":
            return [1] * len(feat.get("input_ids", []))
        return []

    def __call__(self, feats):
        pad_values = {
            "input_ids": self.tok.pad_token_id,
            "attention_mask": 0,  # padded positions must not be attended to
            "labels": self.label_pad,
        }
        seqs = {k: [self._seq(f, k) for f in feats] for k in self.SEQ_KEYS}
        max_len = max(len(ids) for ids in seqs["input_ids"])

        batch = {}
        for k, rows in seqs.items():
            batch[k] = torch.tensor(
                [r + [pad_values[k]] * (max_len - len(r)) for r in rows],
                dtype=torch.long)
        for k in self.SCALAR_KEYS:
            batch[k] = torch.tensor([f.get(k, -100) for f in feats], dtype=torch.long)

        return batch
collate = MultiTaskCollator(tok, label_pad=-100)  # used as data_collator by both trainers

In [None]:

# Stage 1 trainables: classification head, class-embedding table and LoRA adapters only.
for _, param in model.named_parameters():
    param.requires_grad_(False)

for module in (model.class_head, model.id2vec):
    for param in module.parameters():
        param.requires_grad_(True)

for name, param in model.named_parameters():
    if "lora_" in name:
        param.requires_grad_(True)

# NOTE(review): 5e-7 is a very small learning rate for a freshly initialised
# classification head -- confirm it is intended.
args1 = TrainingArguments(
    output_dir="./stage1_cls",
    per_device_train_batch_size=4,
    num_train_epochs=2,
    learning_rate=5e-7,
    logging_steps=20,
    fp16=True,
    report_to="none",
    save_safetensors=False,
    save_strategy="no",
)

In [None]:
# Stage-1 data: classification samples only.
cls_ds = encoded_train.filter(
    lambda row: row["task_type_id"] == 1,
    load_from_cache_file=False,
)

In [None]:
from transformers import Trainer

class MultiTaskTrainer(Trainer):
    """Trainer whose loss is the model's weighted multi-task loss.

    The loss weights passed to the model's forward come from the class attributes
    ``lambda_cls`` / ``lambda_gen``; the defaults (1.0 / 0.0) train the classification
    objective only. Subclasses can change the mix by overriding the attributes.
    """

    lambda_cls = 1.0
    lambda_gen = 0.0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        # num_items_in_batch / **kwargs are accepted (and ignored) so Trainer versions that
        # pass extra keyword arguments keep working; the model returns its own loss.
        out = model(**inputs,
                    lambda_cls=self.lambda_cls,
                    lambda_gen=self.lambda_gen)
        loss = out["loss"]
        return (loss, out) if return_outputs else loss

In [None]:
# Stage 1: classification-only training on the classification subset.
trainer1 = MultiTaskTrainer(
    model=model,
    args=args1,
    train_dataset=cls_ds,
    data_collator=collate,
)
trainer1.train()
trainer1.save_model("stage1_cls_ckpt")

In [None]:
class JointTrainer(MultiTaskTrainer):
    """Stage-2 trainer: classification and generation losses weighted equally (1.0 each)."""

    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        result = model(**inputs, lambda_cls=1.0, lambda_gen=1.0)
        if return_outputs:
            return result["loss"], result
        return result["loss"]
# Stage 2: joint training on every encoded sample (classification + generation).
full_ds = encoded_train
args2 = TrainingArguments(
    output_dir="./stage2_joint",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=4,
    learning_rate=5e-5,
    logging_steps=100,
    fp16=True,
    report_to="none",
    save_safetensors=False,
    save_strategy="no",
)

In [None]:
# Stage 2 run; the final model is written to Google Drive.
trainer2 = JointTrainer(
    model=model,
    args=args2,
    train_dataset=full_ds,
    data_collator=collate,
)
trainer2.train()
trainer2.save_model("./drive/MyDrive/model/stage2_joint_ckpt")