In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [2]:
base_model = T5ForConditionalGeneration.from_pretrained('google/c-t5-base')
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')

input_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(input_text, return_tensors='pt')

input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

embedding_output = base_model.encoder.embed_tokens(input_ids)

encoder_layer = base_model.encoder.block[0]
attention_Layer = encoder_layer.layer[0]
ff_layer = encoder_layer.layer[1]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [8]:
input_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(input_text, return_tensors='pt')

data_name = {'rte', 'mnli', 'squad'}
lora_layers = {'attention_Layer':[], 'ff_layer': []}
for d in data_name:
    lora_path = f"./lora_trained/lora-t5-{d}"
    lora_base_model = T5ForConditionalGeneration.from_pretrained(lora_path)
    lora_encoder_layer = lora_base_model.encoder.block[0]
    lora_layers['attention_Layer'] += [lora_encoder_layer.layer[0]]
    lora_layers['ff_layer'] += [lora_encoder_layer.layer[1]]
# print(lora_layers)

Some weights of the model checkpoint at ./lora_trained/lora-t5-squad were not used when initializing T5ForConditionalGeneration: ['decoder.block.0.layer.0.SelfAttention.q.base_layer.weight', 'decoder.block.0.layer.0.SelfAttention.q.lora_A.default.weight', 'decoder.block.0.layer.0.SelfAttention.q.lora_B.default.weight', 'decoder.block.0.layer.0.SelfAttention.v.base_layer.weight', 'decoder.block.0.layer.0.SelfAttention.v.lora_A.default.weight', 'decoder.block.0.layer.0.SelfAttention.v.lora_B.default.weight', 'decoder.block.0.layer.1.EncDecAttention.q.base_layer.weight', 'decoder.block.0.layer.1.EncDecAttention.q.lora_A.default.weight', 'decoder.block.0.layer.1.EncDecAttention.q.lora_B.default.weight', 'decoder.block.0.layer.1.EncDecAttention.v.base_layer.weight', 'decoder.block.0.layer.1.EncDecAttention.v.lora_A.default.weight', 'decoder.block.0.layer.1.EncDecAttention.v.lora_B.default.weight', 'decoder.block.1.layer.0.SelfAttention.q.base_layer.weight', 'decoder.block.1.layer.0.SelfAtte

In [14]:
d = 768
R = 8
N = len(data_name)
L = inputs['input_ids'].shape[1]

e = nn.Parameter(torch.randn(N *L * d, N))


with torch.no_grad():
    attention_output = attention_Layer(embedding_output,attention_mask)[0]   # Equation 5 of the paper.
    feed_forward_output =  ff_layer(attention_output)[0]       # Equation 6 of paper


F_theta_x =   feed_forward_output

# Apply LoRA experts
lora_outputs = torch.zeros((N,L,d))
for i in range(N):
    lora_attention_output = lora_layers['attention_Layer'][i](embedding_output, attention_mask)[0]    # equation 7 of paper
    lora_ff_output = lora_layers['ff_layer'][i](lora_attention_output)[0]   # equation 8 of paper
    lora_attention_output = torch.squeeze(lora_attention_output)
    
    layer_norm = nn.LayerNorm(lora_attention_output.shape)

    E_omega_x_normalize = layer_norm(lora_ff_output * lora_attention_output )  # equation 9 of paper.
    lora_outputs[i]= E_omega_x_normalize


E_omega_x_flattened = lora_outputs.flatten()
epsilon = torch.matmul(E_omega_x_flattened, e)    # equation 10 of paper

# Softmax to compute gating  meaning gate value of each lora
#  learnable parameter of temperature 
temperature = 1.0 
gated_value = F.softmax(epsilon / temperature, dim=0) # equation 11 of paper

final_output_E_Omega_x = [ gated_value[x] * lora_outputs[x] for x in range(lora_outputs.shape[0])] # equation 12 of paper


# Combine the LoRA outputs based on gating values
# slightly confuse on this part.
final_output = torch.zeros((L,d))
for i in range(N):
    final_output += final_output_E_Omega_x[i] + F_theta_x # equation 13 of paper


print("Final output shape:", final_output.shape)
print(final_output)

Final output shape: torch.Size([13, 768])
tensor([[-21.4644,  11.3576,   0.7138,  ...,  -4.5001, -16.1802, -23.4917],
        [ -9.5499,  10.4079,   5.3058,  ...,  10.5751,   9.5109, -18.7083],
        [ 12.4027,   6.2433,  35.0161,  ...,  22.5951, -30.5896,  40.4467],
        ...,
        [-29.0437,  34.4034, -21.8787,  ...,  36.0988,  21.4280,  -5.9321],
        [  6.0815,   6.6027,  -0.5446,  ...,   3.2604,  11.2176,  -2.8747],
        [133.7380,  78.3760,  54.8937,  ...,  43.1555, -59.7224,  93.7194]],
       grad_fn=<AddBackward0>)
