In [7]:
import sys
sys.path.append('../')
import torch
from transformers import CamembertTokenizer # note that 'transformers' is HuggingFace package
from transformer.layers import Projection, WordDecoder # here is our 'transformer' (without 's')
from transformer.utils import create_tranformer_model

In [8]:
# Smoke test of the encode -> decode -> project pipeline on random token ids.
#
# Seed PyTorch's global RNG before anything random happens so that the random
# token ids drawn below (and any torch-based weight initialisation or dropout
# inside the model) are reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters handed positionally to create_tranformer_model. The names follow
# the usual Transformer conventions; their exact meaning is defined by
# transformer.utils.create_tranformer_model (not visible here) -- confirm there.
vocab_size_src = 10000
vocab_size_tgt = 10000
d_model = 512
num_layers = 6
h = 8
d_ff = 2048
dropout = 0.1
seq_len = 100
batch_size = 1

# Create the transformer model
transformer = create_tranformer_model(
    vocab_size_src,
    vocab_size_tgt,
    d_model,
    num_layers,
    h,
    d_ff,
    dropout,
    seq_len,
)

# Encode a batch of random source token ids.
src = torch.randint(0, vocab_size_src, (batch_size, seq_len))  # [batch_size, seq_len]
encoder_output = transformer.encode(src)

# Decode a batch of random target token ids against the encoder output.
tgt = torch.randint(0, vocab_size_tgt, (batch_size, seq_len))  # [batch_size, seq_len]
decoder_output = transformer.decode(tgt, encoder_output)

# Project the decoder output; the next cell shows the result has shape
# [batch_size, seq_len, vocab_size_tgt].
projected_output = transformer.project(decoder_output)

In [9]:
# Sanity check: the projection should yield one score per target-vocabulary
# entry for every position of every sequence in the batch.
expected_shape = (batch_size, seq_len, vocab_size_tgt)
assert projected_output.shape == expected_shape, (
    f"unexpected projection shape {tuple(projected_output.shape)}, expected {expected_shape}"
)
print(projected_output.shape)

torch.Size([1, 100, 10000])


## Word decoding

In [10]:
# Load the pretrained CamemBERT tokenizer (fetched from the Hugging Face hub or
# a local cache on first use).
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')

# Every id the model can emit must exist in the tokenizer's vocabulary,
# otherwise it cannot be mapped back to a token. Fail early with a clear
# message instead of deep inside the decoder.
model_vocab_size = projected_output.shape[-1]
assert model_vocab_size <= len(tokenizer), (
    f"model vocabulary ({model_vocab_size}) exceeds tokenizer vocabulary ({len(tokenizer)})"
)

# Create an instance of the WordDecoder class.
# NOTE(review): presumably it picks the highest-scoring id per position and maps
# it through the tokenizer -- confirm in transformer.layers.
word_decoder = WordDecoder(tokenizer)

# Pass the projected scores through the word decoder.
decoded_word = word_decoder(projected_output)

# Despite the variable name, this is not a single word: the recorded output is a
# nested list holding one list of token strings per batch element, one token per
# sequence position. No training happens in this notebook before this point, so
# the tokens themselves are arbitrary.
print(decoded_word)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


[['fille', 'IP', 'IP', 'IP', 'validation', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'fille', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'fille', 'IP', 'IP', 'fille', 'IP', 'IP', 'IP', 'IP', 'IP', 'fille', 'IP', 'IP', 'fille', 'IP', 'IP', 'IP', 'fille', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'fille', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP', 'IP']]
