In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from datasets import load_dataset

In [2]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

In [3]:
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
# --- 1. Model Setup ---
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


# Load model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": "cuda"}
)

# Enable Gradient Checkpointing (Saves massive VRAM by recomputing parts of the graph)
model.gradient_checkpointing_enable()


# Prepare model for k-bit training (stabilizes norms/layers)
model = prepare_model_for_kbit_training(model)

In [4]:
# --- 2. Apply Standard LoRA to Attention Layers ---
# We keep this to adapt the attention mechanism alongside your MoE MLPs
attn_lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, attn_lora_config)

In [5]:
# --- 3. Define MixLoRA Architecture ---

class TopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, k=2):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, num_experts)
        self.k = k

    def forward(self, x):
        # x: [batch, seq, dim]
        logits = self.linear(x)
        router_probs = F.softmax(logits, dim=-1)

        # Get top-k experts
        # weights: [batch, seq, k], indices: [batch, seq, k]
        weights, indices = router_probs.topk(self.k, dim=-1)
        return weights, indices, router_probs

class LoRAExpert(nn.Module):
    """
    FIX: A true Bottleneck Adapter to save VRAM.
    Structure: Linear(dim->r) -> SiLU -> Linear(r->dim)
    """
    def __init__(self, hidden_dim, r=16, dropout=0.05):
        super().__init__()
        self.lora_A = nn.Linear(hidden_dim, r, bias=False)
        self.act = nn.SiLU()
        self.lora_B = nn.Linear(r, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Result is [batch, seq, dim]
        return self.lora_B(self.act(self.lora_A(self.dropout(x))))

# --- Redefine this class with the casting fix ---
class MixLoRAFFN_V2(nn.Module):
    def __init__(self, base_ffn, hidden_dim, num_experts=8, k=2, r=16):
        super().__init__()
        self.base_ffn = base_ffn
        self.router = TopKRouter(hidden_dim, num_experts, k)

        self.experts = nn.ModuleList([
            LoRAExpert(hidden_dim, r=r) for _ in range(num_experts)
        ])

        self.latest_router_probs = None

    def forward(self, x):
        # FIX: Force input to match the router's dtype (bfloat16)
        # This fixes the mismatch when Gradient Checkpointing passes float32

        x = x.to(self.router.linear.weight.dtype)

        # 1. Compute Base Output (Frozen)
        with torch.no_grad():
            base_out = self.base_ffn(x)

        # 2. Routing
        weights, indices, router_probs = self.router(x)
        self.latest_router_probs = router_probs

        # 3. Compute Expert Mixture
        expert_out = torch.zeros_like(base_out)

        for i, expert in enumerate(self.experts):
            is_selected = (indices == i).any(dim=-1, keepdim=True)

            if is_selected.any():
                expert_weight = (weights * (indices == i).float()).sum(dim=-1, keepdim=True)
                current_expert_out = expert(x)
                expert_out += is_selected * expert_weight * current_expert_out

        return base_out + expert_out

In [6]:
 # --- 4. INJECT MixLoRA into the Model (FIXED) ---
print("Injecting MixLoRA layers...")

HIDDEN_DIM = model.config.hidden_size
print(f"Detected Hidden Dimension: {HIDDEN_DIM}")

NUM_EXPERTS = 8
TOP_K = 2

# Helper to find layers
def get_model_layers(model):
    try: return model.base_model.model.model.layers
    except AttributeError: return model.model.layers

layers = get_model_layers(model)

for layer in layers:
    original_mlp = layer.mlp

    # Create wrapper
    mix_lora_layer = MixLoRAFFN_V2(
        original_mlp,   # base_ffn
        HIDDEN_DIM,     # hidden_dim
        NUM_EXPERTS,    # num_experts
        TOP_K,          # k
        16 
    )

    # FIX: Move the new layer to GPU ('cuda') AND set to bfloat16
    mix_lora_layer.to(device="cuda", dtype=torch.bfloat16)

    # Ensure gradients are on
    mix_lora_layer.train()
    for param in mix_lora_layer.experts.parameters():
        param.requires_grad = True
    for param in mix_lora_layer.router.parameters():
        param.requires_grad = True

    layer.mlp = mix_lora_layer
    # After: layer.mlp = mix_lora_layer

    for p in layer.mlp.router.parameters():
        p.requires_grad = True
    
    for p in layer.mlp.experts.parameters():
        p.requires_grad = True


print("Injection complete. Experts moved to GPU.")


Injecting MixLoRA layers...
Detected Hidden Dimension: 2048
Injection complete. Experts moved to GPU.


In [7]:
 # --- 5. Data Processing ---
psy_dataset = load_dataset("Amod/mental_health_counseling_conversations")

def tokenize_fn(example):
    text = example["Context"] + tokenizer.eos_token + example["Response"]
    tokens = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=512
    )
    tokens["labels"] = [
        -100 if t == tokenizer.pad_token_id else t
        for t in tokens["input_ids"]
    ]
    return tokens

tokenized_dataset = psy_dataset.map(tokenize_fn, remove_columns=["Context", "Response"])
tokenized_dataset.set_format("torch")

train_dataloader = DataLoader(
    tokenized_dataset["train"],
    batch_size=4, # Reduced batch size for safety
    shuffle=True
)

In [8]:
# --- 6. Training Loop (Fixed Layer Access) ---

# Re-define optimizer to be safe
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-4)

# Load Balance Loss Function
def load_balance_loss(router_probs):
    # router_probs: [batch, seq, num_experts]
    mean_prob = router_probs.mean(dim=(0, 1)) # Avg over batch and seq
    return (mean_prob * mean_prob).sum() * num_experts

num_experts = NUM_EXPERTS
aux_loss_weight = 0.01

model.train()
print("Starting training...")
print("model.training =", model.training)


# Helper to find layers safely (Works for Peft + QLoRA)
def get_layers_for_loss(model):
    try:
        # Standard QLoRA/Peft path
        return model.base_model.model.model.layers
    except AttributeError:
        # Fallback
        return model.model.layers

for step, batch in enumerate(train_dataloader):
    batch = {k: v.to(model.device) for k, v in batch.items()}

    # Forward pass
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )

    main_loss = outputs.loss

    # FIX: Access layers safely for Aux Loss
    aux_loss = torch.tensor(0.0, device=main_loss.device)

    for layer in layers:
        if hasattr(layer, "mlp") and isinstance(layer.mlp, MixLoRAFFN_V2):
            router_probs = layer.mlp.latest_router_probs
    
            if router_probs is not None:
                # 1Ô∏è‚É£ Load-balance loss
                aux_loss += load_balance_loss(router_probs)
    
                # 2Ô∏è‚É£ üî• Entropy regularization (ADD HERE)
                entropy = -(
                    router_probs * torch.log(router_probs + 1e-8)
                ).sum(dim=-1).mean()
    
                aux_loss += 0.01 * entropy

    aux_loss = torch.clamp(aux_loss, max=10.0)
    total_loss = main_loss + (aux_loss_weight * aux_loss)

    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)

    
    optimizer.step()
    optimizer.zero_grad()

    if step % 10 == 0:
       print(f"Step {step}: Loss {main_loss.item():.4f} | Aux Loss {aux_loss.item():.4f}")


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Starting training...
model.training = True
Step 0: Loss 2.6937 | Aux Loss 10.0000
Step 10: Loss 2.7693 | Aux Loss 10.0000
Step 20: Loss 2.8881 | Aux Loss 10.0000
Step 30: Loss 2.8520 | Aux Loss 10.0000
Step 40: Loss 3.1615 | Aux Loss 10.0000
Step 50: Loss 2.9552 | Aux Loss 10.0000
Step 60: Loss 3.0282 | Aux Loss 10.0000
Step 70: Loss 3.0262 | Aux Loss 10.0000
Step 80: Loss 2.7985 | Aux Loss 10.0000
Step 90: Loss 2.9236 | Aux Loss 10.0000
Step 100: Loss 2.9167 | Aux Loss 10.0000
Step 110: Loss 3.4290 | Aux Loss 10.0000
Step 120: Loss 3.0029 | Aux Loss 10.0000
Step 130: Loss 3.2111 | Aux Loss 10.0000
Step 140: Loss 3.4685 | Aux Loss 10.0000
Step 150: Loss 3.6568 | Aux Loss 10.0000
Step 160: Loss 4.7006 | Aux Loss 10.0000
Step 170: Loss 6.1372 | Aux Loss 10.0000
Step 180: Loss 6.3940 | Aux Loss 10.0000
Step 190: Loss 5.8888 | Aux Loss 10.0000
Step 200: Loss 5.5446 | Aux Loss 10.0000
Step 210: Loss 5.8338 | Aux Loss 10.0000
Step 220: Loss 5.8722 | Aux Loss 10.0000
Step 230: Loss 7.3577 | A

In [12]:
import torch

model.eval()

prompt = "I'm not feeling good. What should I do?"

inputs = tokenizer(
    prompt,
    return_tensors="pt"
)

device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1
    )

text = tokenizer.decode(outputs[0])
print(text)


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|>I'm not feeling good. What should I do? My friend is my parents, but he was the past with the same. This‚Äôs a few years and have a therapist and I am going to get me for my dad. I don't know it like the couple girl. I can't see the current man.   Do I feel that I did this about him. I think I have been an 12 disorder and are not never know to a past with her. But I had so afraid of the lot, but they do this. We're very sex? I‚Äôm been just almost normal and she's been much in anxiety.
 I don‚Äôt be a long life with a boyfriend and does he has been terrible. He doesn't never never have no way back and we have a history with our
