In [4]:

from transformers import (
    MBartForConditionalGeneration, MBartTokenizer, 
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    MBartTokenizerFast, MBartModel, MBartConfig,
    
  )
import random
from transformers.models.mbart.modeling_mbart import MBartAttention
from pathlib import Path
from typing import Union, Dict
import pathlib
from typing import List, Dict, Tuple, Optional
import os
import torch
from torch.utils.data import random_split

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer
import numpy as np

import gc
import torch
import copy
from transformers import AutoTokenizer

import math
import torch
from torch import nn
from torch.nn import functional as F
from torch import Tensor
import tabulate
from accelerate import Accelerator
# Create the Hugging Face Accelerate helper once, at import time.
# NOTE(review): `accelerator` is not referenced in the cells visible here — confirm it is used later.
accelerator = Accelerator()
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably to silence the tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [5]:
'''
Create MBart model 
'''
# Sequence-length limits.
# NOTE(review): not referenced in the cells visible here — presumably consumed by a later tokenization step; confirm.
max_input_length = 128
max_target_length = 128
# Training hyperparameters.
# NOTE(review): not referenced in the cells visible here — presumably passed to the trainer arguments later; confirm.
per_device_train_batch_size=16
per_device_eval_batch_size=16
learning_rate= 3e-4
num_train_epochs=2
lr_scheduler_type = 'cosine'
warmup_ratio = 0.03
# LoRA placement switches, read as module-level globals by replace_linear_with_lora_recursive:
#   lora_head      -> wrap model.lm_head
#   lora_attention -> also wrap the attention q/k/v/out projections (fc1/fc2 are wrapped regardless)
#   lora_encoder   -> add adapters to the encoder layers
#   lora_decoder   -> add adapters to the decoder layers
lora_head = False
lora_attention = True
lora_encoder = True
lora_decoder = False 
# Adapter hyperparameters handed to replace_linear_with_lora_recursive as rank / alpha / dropout.
lora_rank = 4
lora_alpha = 8
lora_dropout = 0.0
# NOTE(review): these language codes are not used in the visible cells — confirm how they map to mBART language tags.
source_lang = "en"
target_lang = "vi"
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "facebook/mbart-large-50-many-to-many-mmt"
# Load the tokenizer and the pretrained weights for the checkpoint named above.
tokenizer_en2vi = AutoTokenizer.from_pretrained(model_name)
model_en2vi = MBartForConditionalGeneration.from_pretrained(model_name)
# Freeze every parameter of the base model; parameters created afterwards (e.g. LoRA A/B) are unaffected.
for param in model_en2vi.parameters():
    param.requires_grad = False



In [6]:
'''Test out normal MBart model first '''
# Bare expression: the notebook renders the module tree of the un-modified base model.
model_en2vi

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        

In [7]:
'''Define LoRA layer'''
class LoRALayer(nn.Module):
    """Low-rank adapter that computes ``dropout(alpha * (x @ A @ B))``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from a standard normal
    scaled by ``1 / sqrt(rank)``. ``B`` has shape ``(rank, out_dim)`` and
    starts at zero, so a freshly constructed layer contributes nothing.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation.
        alpha: Multiplier applied to the low-rank product.
        lora_dropout: Probability used by the dropout applied to the output.
    """

    def __init__(self, in_dim, out_dim, rank, alpha, lora_dropout=0.0):
        super().__init__()
        init_scale = 1 / torch.tensor(rank).float().sqrt()
        down_factor = torch.randn(in_dim, rank) * init_scale
        up_factor = torch.zeros(rank, out_dim)
        # Registration order (A, then B) is kept so parameter ordering is unchanged.
        self.A = nn.Parameter(down_factor)
        self.B = nn.Parameter(up_factor)
        self.alpha = alpha
        self.lora_dropout = nn.Dropout(lora_dropout)

    def forward(self, x):
        """Return the alpha-scaled low-rank projection of ``x`` after dropout."""
        scaled_update = self.alpha * (x @ self.A @ self.B)
        return self.lora_dropout(scaled_update)

class LinearWithLoRAMerged(nn.Module):
    """Wrap a linear layer with a LoRA adapter, folding the adapter into the weight.

    The effective weight is ``W + alpha * (A @ B).T``; the wrapped layer's bias
    is reused unchanged.

    Args:
        linear: The linear layer to wrap; its ``in_features`` / ``out_features``
            size the adapter.
        rank: Rank of the low-rank update.
        alpha: Multiplier applied to the low-rank update.
        lora_dropout: Dropout probability applied to the adapter output while
            training.
    """

    def __init__(self, linear, rank, alpha, lora_dropout=0.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha, lora_dropout
        )

    def forward(self, x):
        """Apply the wrapped linear layer plus the LoRA update to ``x``."""
        # Dropout acts on the adapter's output, which a merged weight cannot
        # express. When it is active (training mode and p > 0) fall back to the
        # unmerged form so the configured dropout is honoured, not ignored.
        if self.training and self.lora.lora_dropout.p > 0:
            return self.linear(x) + self.lora(x)
        # Otherwise dropout is the identity, so a single matmul with the
        # combined weight is equivalent.
        lora = self.lora.A @ self.lora.B
        combined_weight = self.linear.weight + self.lora.alpha * lora.T
        return F.linear(x, combined_weight, self.linear.bias)

In [8]:
def freeze_linear_layers(model):
    """Turn off gradients for every ``nn.Linear`` reachable from ``model``.

    ``model`` itself is included, so passing a bare ``nn.Linear`` freezes it
    too. Parameters owned by non-linear modules are left untouched.

    Args:
        model: Root ``nn.Module`` to walk.
    """
    # `modules()` yields the root and all descendants, so the root-is-Linear
    # case is covered without hand-written recursion.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            for param in module.parameters():
                param.requires_grad = False


def replace_linear_with_lora_recursive(model, rank, alpha, dropout):
    """Swap selected linear layers of an mBART-style model for LoRA wrappers, in place.

    Which layers are wrapped is controlled by the module-level flags
    ``lora_encoder``, ``lora_decoder``, ``lora_attention`` and ``lora_head``:
    the feed-forward ``fc1`` / ``fc2`` of every selected layer are always
    wrapped, the attention projections only when ``lora_attention`` is set.

    Args:
        model: Model exposing ``model.encoder.layers``, ``model.decoder.layers``
            and ``lm_head``.
        rank: LoRA rank passed to each wrapper.
        alpha: LoRA scaling passed to each wrapper.
        dropout: LoRA dropout passed to each wrapper.
    """
    attention_projections = ("q_proj", "k_proj", "v_proj", "out_proj")
    feed_forward = ("fc1", "fc2")

    def wrap(parent, attr_names):
        # Replace each named attribute with a LoRA wrapper around the current value.
        for attr_name in attr_names:
            wrapped = LinearWithLoRAMerged(getattr(parent, attr_name), rank, alpha, dropout)
            setattr(parent, attr_name, wrapped)

    if lora_encoder:
        for encoder_block in model.model.encoder.layers:
            if lora_attention:
                wrap(encoder_block.self_attn, attention_projections)
            wrap(encoder_block, feed_forward)

    if lora_decoder:
        for decoder_block in model.model.decoder.layers:
            if lora_attention:
                wrap(decoder_block.self_attn, attention_projections)
                wrap(decoder_block.encoder_attn, attention_projections)
            wrap(decoder_block, feed_forward)

    if lora_head:
        wrap(model, ("lm_head",))
        
# Work on a deep copy so `model_en2vi` keeps its original, un-wrapped layers.
model_lora = copy.deepcopy(model_en2vi)
# Wrap the layers selected by the module-level lora_* flags with LoRA adapters (mutates model_lora).
replace_linear_with_lora_recursive(model_lora,lora_rank, lora_alpha, lora_dropout)
# Move to `device` after wrapping so the newly created LoRA parameters are moved as well.
model_lora.to(device)
# Turn off gradients on every nn.Linear, including the `.linear` held inside each LoRA wrapper.
freeze_linear_layers(model_lora)
# Print the requires_grad flag of every parameter to verify what is frozen and what is trainable.
for name, param in model_lora.named_parameters():
    print(f"{name}: {param.requires_grad}")

model.shared.weight: False
model.encoder.embed_positions.weight: False
model.encoder.layers.0.self_attn.k_proj.linear.weight: False
model.encoder.layers.0.self_attn.k_proj.linear.bias: False
model.encoder.layers.0.self_attn.k_proj.lora.A: True
model.encoder.layers.0.self_attn.k_proj.lora.B: True
model.encoder.layers.0.self_attn.v_proj.linear.weight: False
model.encoder.layers.0.self_attn.v_proj.linear.bias: False
model.encoder.layers.0.self_attn.v_proj.lora.A: True
model.encoder.layers.0.self_attn.v_proj.lora.B: True
model.encoder.layers.0.self_attn.q_proj.linear.weight: False
model.encoder.layers.0.self_attn.q_proj.linear.bias: False
model.encoder.layers.0.self_attn.q_proj.lora.A: True
model.encoder.layers.0.self_attn.q_proj.lora.B: True
model.encoder.layers.0.self_attn.out_proj.linear.weight: False
model.encoder.layers.0.self_attn.out_proj.linear.bias: False
model.encoder.layers.0.self_attn.out_proj.lora.A: True
model.encoder.layers.0.self_attn.out_proj.lora.B: True
model.encoder.lay

In [9]:

 
def format_with_underscore(n):
    """Render ``n`` as text using ``_`` as the thousands separator (e.g. ``1_234_567``)."""
    return format(n, "_")
def parameter_count_table(model, output_file_path=None, output_print=True, add_dtypes=False, show_nograd_paras=False):
    """Build a per-parameter size table for ``model`` and print and/or save it.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        output_file_path: If not ``None``, the rendered table is written to this
            path, overwriting any existing file.
        output_print: If true, print the table followed by an empty line.
        add_dtypes: If true, add a third column holding each parameter's dtype.
        show_nograd_paras: If true, also list parameters whose ``requires_grad``
            is false; by default only trainable parameters are listed and counted.

    Returns:
        None.
    """
    header = ["Module", "Parameters", "dtype"] if add_dtypes else ["Module", "Parameters"]
    table = [header]
    total_params = 0
    max_len = 0
    for name, parameter in model.named_parameters():
        if (not parameter.requires_grad) and (not show_nograd_paras):
            continue
        params = parameter.numel()
        formatted_params = format_with_underscore(params)
        max_len = max(max_len, len(formatted_params))
        if add_dtypes:
            table.append([str(name), formatted_params, parameter.dtype])
        else:
            table.append([str(name), formatted_params])
        total_params += params

    table.append(tabulate.SEPARATING_LINE)

    formatted_total = format_with_underscore(total_params)
    max_len = max(max_len, len(formatted_total))
    # The TOTAL row must have as many cells as the header: pad the dtype column
    # only when that column exists (the original had these branches swapped).
    if add_dtypes:
        table.append(["TOTAL", formatted_total, ''])
    else:
        table.append(["TOTAL", formatted_total])

    # Right align the numbers in the table (the separator marker is not a real row).
    for row in table[1:]:
        if row is not tabulate.SEPARATING_LINE:
            row[1] = row[1].rjust(max_len)

    tabulated_table = tabulate.tabulate(table, headers="firstrow")
    if output_file_path is not None:
        with open(output_file_path, 'w') as f:
            f.write(tabulated_table)
    if output_print:
        print(tabulated_table)
        print("")
# With the default flags only parameters that still require gradients are listed and summed.
parameter_count_table(model_lora)

Module                                               Parameters
-------------------------------------------------  ------------
model.encoder.layers.0.self_attn.k_proj.lora.A            4_096
model.encoder.layers.0.self_attn.k_proj.lora.B            4_096
model.encoder.layers.0.self_attn.v_proj.lora.A            4_096
model.encoder.layers.0.self_attn.v_proj.lora.B            4_096
model.encoder.layers.0.self_attn.q_proj.lora.A            4_096
model.encoder.layers.0.self_attn.q_proj.lora.B            4_096
model.encoder.layers.0.self_attn.out_proj.lora.A          4_096
model.encoder.layers.0.self_attn.out_proj.lora.B          4_096
model.encoder.layers.0.fc1.lora.A                         4_096
model.encoder.layers.0.fc1.lora.B                        16_384
model.encoder.layers.0.fc2.lora.A                        16_384
model.encoder.layers.0.fc2.lora.B                         4_096
model.encoder.layers.1.self_attn.k_proj.lora.A            4_096
model.encoder.layers.1.self_attn.k_proj.

In [7]:
import torch
from transformers import MBartForConditionalGeneration

# Load the pretrained mBART-50 many-to-many checkpoint into a separate `mbart_model` object.
# NOTE(review): this is the same checkpoint string as `model_name` in the configuration cell, so the
# weights are loaded again alongside `model_en2vi` — confirm the extra copy is intended.
mbart_model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50-many-to-many-mmt')

# Function to count total parameters, trainable parameters, and frozen parameters
def count_parameters(model):
    """Print how many parameters of ``model.model`` exist, are trainable, and are frozen.

    A parameter counts as trainable when its ``requires_grad`` flag is set;
    the frozen count is the remainder.
    """
    total_params = 0
    trainable_params = 0
    # Single pass: accumulate both tallies together.
    for weight in model.model.parameters():
        size = weight.numel()
        total_params += size
        if weight.requires_grad:
            trainable_params += size
    frozen_params = total_params - trainable_params

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Frozen Parameters: {frozen_params:,}")

# Report parameter counts for the freshly loaded model. Nothing in this cell freezes
# `mbart_model`; the recorded output accordingly shows zero frozen parameters.
count_parameters(mbart_model)

Total Parameters: 610,879,488
Trainable Parameters: 610,879,488
Frozen Parameters: 0


In [6]:
import torch
from transformers import MBartForConditionalGeneration

# Load the pretrained mBART-50 many-to-many checkpoint.
# NOTE(review): another cell in this notebook assigns `mbart_model` with an identical
# `from_pretrained` call; running both loads the same weights twice — confirm whether one load can be reused.
mbart_model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50-many-to-many-mmt')
# Function to count total parameters, trainable parameters, and frozen parameters in the encoder
def count_encoder_parameters(model):
    """Print how many parameters of ``model.model.encoder`` exist, are trainable, and are frozen.

    A parameter counts as trainable when its ``requires_grad`` flag is set;
    the frozen count is the remainder.
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the encoder's parameters, accumulating both tallies.
    for weight in model.model.encoder.parameters():
        size = weight.numel()
        total_params += size
        if weight.requires_grad:
            trainable_params += size
    frozen_params = total_params - trainable_params

    print(f"Encoder - Total Parameters: {total_params:,}")
    print(f"Encoder - Trainable Parameters: {trainable_params:,}")
    print(f"Encoder - Frozen Parameters: {frozen_params:,}")

# Report parameter counts for the encoder of the freshly loaded model. Nothing in this cell
# freezes `mbart_model`; the recorded output accordingly shows zero frozen encoder parameters.
count_encoder_parameters(mbart_model)

Encoder - Total Parameters: 408,264,704
Encoder - Trainable Parameters: 408,264,704
Encoder - Frozen Parameters: 0
