## 1. Install libraries and import

In [1]:
!pip install datasets
!pip install torch[transformers]



In [2]:
import torch

In [3]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


## 2. Loading Data

In [4]:
from datasets import load_dataset
datasets = load_dataset('wikitext','wikitext-2-raw-v1')

Downloading builder script:   0%|          | 0.00/2.03k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.90 MiB, post-processed: Unknown size, total: 17.40 MiB) to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

## 3. PREPROCESSING DATA

In [6]:
import re
def preprocess_text(sentence):
  text = sentence['text'].lower() # lowering the sentence and storing in text vaiable
  text = re.sub('[^a-z?!.,]',' ',text) # removing other than characters and punctuations
  text = re.sub('\s\s+',' ',text) # removing double spaces
  sentence['text'] = text
  return sentence

In [7]:
datasets['train'] = datasets['train'].map(preprocess_text)
datasets['test'] = datasets['test'].map(preprocess_text)
datasets['validation'] = datasets['validation'].map(preprocess_text)

  0%|          | 0/36718 [00:00<?, ?ex/s]

  0%|          | 0/4358 [00:00<?, ?ex/s]

  0%|          | 0/3760 [00:00<?, ?ex/s]

In [8]:
datasets['train'] = datasets['train'].filter(lambda x : len(x['text']) > 20)
datasets['test'] = datasets['test'].filter(lambda x : len(x['text']) > 20)
datasets['validation'] = datasets['validation'].filter(lambda x : len(x['text']) > 20)

  0%|          | 0/37 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

In [9]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 2312
    })
    train: Dataset({
        features: ['text'],
        num_rows: 18794
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1988
    })
})

## 3. TOKENIZATION

In [117]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('vocab-transformers/distilbert-word2vec_256k-MLM_best')

Downloading tokenizer_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/6.20M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [11]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [12]:
vocab_size = len(tokenizer.vocab)

In [13]:
def tokenize(sentence):
  sentence = tokenizer(sentence['text'],truncation = True)

  return sentence

#tokenized_inputs = datasets['train'].map(tokenize)
tokenized_inputs = datasets['test'].map(tokenize)

  0%|          | 0/2312 [00:00<?, ?ex/s]

In [14]:
tokenized_inputs = tokenized_inputs.remove_columns(['text','token_type_ids'])

In [15]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

batch = 16

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,return_tensors ="pt")
dataloader = DataLoader(tokenized_inputs,batch_size=batch,collate_fn=data_collator)



In [16]:
tokenized_inputs

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 2312
})

## 4. MODEL

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft
import numpy as np
import pandas as pd

In [18]:
class PositionalEncoding(torch.nn.Module):
    """
    Pytorch module that creates a positional encoding matrix. This matrix will later be added to the
    transformer's input embeddings to provide a sense of position of the sequence elements.
    """

    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.d_model = d_model
        self.max_sequence_length = max_sequence_length
        self.positional_encoding = self.create_positional_encoding().to(device)

    def create_positional_encoding(self):
        """
        Creates a positional encoding matrix of size (max_sequence_length, d_model).
        """

        # Initialize positional encoding matrix
        positional_encoding = np.zeros((self.max_sequence_length, self.d_model))

        # Calculate positional encoding for each position and each dimension
        for pos in range(self.max_sequence_length):
            for i in range(0, self.d_model, 2):
                # Apply sin to even indices in the array; indices in Python start at 0 so i is even.
                positional_encoding[pos, i] = np.sin(pos / (10000 ** ((2 * i) / self.d_model)))

                if i + 1 < self.d_model:
                    # Apply cos to odd indices in the array; we add 1 to i because indices in Python start at 0.
                    positional_encoding[pos, i + 1] = np.cos(pos / (10000 ** ((2 * i) / self.d_model)))

        # Convert numpy array to PyTorch tensor and return it
        return torch.from_numpy(positional_encoding).float()

    def forward(self, x):
        """
        Adds the positional encoding to the input embeddings at the corresponding positions.
        """
        # Add positional encodings to input embeddings. The ":" indexing ensures we only add positional encodings up
        # to the length of the sequence in the batch. x.size(0) is the batch size, so this is a way to make sure
        # we're not adding extra positional encodings.
        expanded_tensor = torch.unsqueeze(self.positional_encoding, 0).expand(x.size(0), -1, -1).to(device)

        return x.to(device) + expanded_tensor[:,:x.size(1), :]





In [19]:
class PositionalEmbedding(nn.Module):
  def __init__(self, sequence_length, vocab_size, embed_dim):
    super(PositionalEmbedding, self).__init__()
    self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.position_embeddings = PositionalEncoding(embed_dim,sequence_length)

  def forward(self, inputs):
    embedded_tokens = self.token_embeddings(inputs).to(device)
    embedded_positions = self.position_embeddings(embedded_tokens).to(device)
    return embedded_positions.to(device)

In [20]:
class FNetEncoder(nn.Module):

  def __init__(self,embed_dim, dense_dim):
    super(FNetEncoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.dense_proj = nn.Sequential(nn.Linear(self.embed_dim,self.dense_dim), nn.ReLU(), nn.Linear(self.dense_dim,self.embed_dim))

    self.layernorm_1 = nn.LayerNorm(self.embed_dim)
    self.layernorm_2 = nn.LayerNorm(self.embed_dim)

  def forward(self,inputs):

    fft_result = fft.fft2(inputs)

    #taking real part
    fft_real = fft_result.real.float()

    proj_input = self.layernorm_1 (inputs + fft_real)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input +proj_output)





In [70]:
class FNetDecoder(nn.Module):

  def __init__(self,embed_dim,dense_dim,num_heads):
    super(FNetDecoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads

    self.attention_1 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)
    self.attention_2 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)

    self.dense_proj = nn.Sequential(nn.Linear(embed_dim, dense_dim),nn.ReLU(),nn.Linear(dense_dim, embed_dim))

    self.layernorm_1 = nn.LayerNorm(embed_dim)
    self.layernorm_2 = nn.LayerNorm(embed_dim)
    self.layernorm_3 = nn.LayerNorm(embed_dim)

  def forward(self, inputs, encoder_outputs, mask=None):
    causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(1)).to(device)
 #   print(causal_mask.size())

    attention_output_1, _ = self.attention_1(inputs, inputs, inputs, attn_mask=causal_mask)
    out_1 = self.layernorm_1(inputs + attention_output_1)
#    print (out_1.size(),encoder_outputs.size())
    if mask != None:
      attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs, key_padding_mask =torch.transpose(mask, 0, 1).to(device))
    else:
      attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs)
    out_2 = self.layernorm_2(out_1 + attention_output_2)

    proj_output = self.dense_proj(out_2)
    return self.layernorm_3(out_2 + proj_output)



In [71]:
class FNetModel(nn.Module):
    def __init__(self, max_length, vocab_size, embed_dim, latent_dim, num_heads):
        super(FNetModel, self).__init__()

        self.encoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.encoder1 = FNetEncoder(embed_dim, latent_dim)
        self.encoder2 = FNetEncoder(embed_dim, latent_dim)
        self.encoder3 = FNetEncoder(embed_dim, latent_dim)
        self.encoder4 = FNetEncoder(embed_dim, latent_dim)


        self.decoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.decoder1 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder2 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder3 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder4 = FNetDecoder(embed_dim, latent_dim, num_heads)


        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(embed_dim, vocab_size)
        
    def encoder(self,encoder_inputs):
        x_encoder = self.encoder_inputs(encoder_inputs)
        x_encoder = self.encoder1(x_encoder)
        x_encoder = self.encoder2(x_encoder)
        x_encoder = self.encoder3(x_encoder)
        x_encoder = self.encoder4(x_encoder)
        return x_encoder
    
    def decoder(self,decoder_inputs,encoder_output,att_mask):
        x_decoder = self.decoder_inputs(decoder_inputs)
        x_decoder = self.decoder1(x_decoder, encoder_output,att_mask) ## HERE for inference
        x_decoder = self.decoder2(x_decoder, encoder_output,att_mask) ## HERE for inference
        x_decoder = self.decoder3(x_decoder, encoder_output,att_mask) ## HERE for inference
        x_decoder = self.decoder4(x_decoder, encoder_output,att_mask) ## HERE for inference
        decoder_outputs = self.dense(x_decoder)   
    
        return decoder_outputs

    def forward(self, encoder_inputs, decoder_inputs,att_mask = None):
        encoder_output = self.encoder(encoder_inputs)
        decoder_output = self.decoder(decoder_inputs,encoder_output,att_mask=None)
        return decoder_output


In [72]:
# Assuming your constants are defined like this:
MAX_LENGTH = 512
VOCAB_SIZE = len(tokenizer.vocab)
EMBED_DIM = 256
LATENT_DIM = 100
NUM_HEADS = 4

# Create an instance of the model
fnet_model = FNetModel(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM, LATENT_DIM, NUM_HEADS).to(device)



In [115]:
# # Define your optimizer and loss function
optimizer = torch.optim.Adam(fnet_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

epochs = 10
for epoch in range(epochs):
    train_loss = 0
    for batch in dataloader:
        encoder_inputs_tensor = batch['input_ids'][:,:-1].to(device)
        decoder_inputs_tensor = batch['input_ids'][:,1:].to(device)
      #  print(encoder_inputs_tensor)
       # print(decoder_inputs_tensor)
        att_mask = batch['attention_mask'][:,:-1].to(device).to(dtype=bool)
        optimizer.zero_grad()
        outputs = fnet_model(encoder_inputs_tensor, decoder_inputs_tensor,att_mask)
#         print(outputs.size())
#         print(outputs.view(-1, VOCAB_SIZE).size())
#         print(decoder_inputs_tensor.size())
#         print(decoder_inputs_tensor.view(-1).size())
#       #  break
        decoder_inputs_tensor.masked_fill(batch['attention_mask'][:,1:].ne(1).to(device), -100).to(device)

        loss = criterion(outputs.view(-1, VOCAB_SIZE), decoder_inputs_tensor.view(-1))
    #    loss = criterion(input = outputs, target= decoder_inputs_tensor)
        train_loss = train_loss + loss.item()
        loss.backward()
        optimizer.step()
    print (f"train_loss : {train_loss}")

train_loss : 0.6688331579875921
train_loss : 0.05544484220445156
train_loss : 0.01615841467582868
train_loss : 15.153724075993523
train_loss : 18.581757632637164
train_loss : 0.7130553867173148
train_loss : 0.5788254780454736
train_loss : 0.22383555448323023
train_loss : 0.05019339832870173
train_loss : 0.3218885977730679


In [116]:
import torch
import torch.nn.functional as F

#VOCAB = vectorizer.get_vocabulary()
MAX_LENGTH =100 # your MAX_LENGTH value

def decode_sentence(input_sentence, fnet_model):
    fnet_model.eval()

    with torch.no_grad():
        tokenized_input_sentence = torch.tensor(tokenizer(preprocess_text(input_sentence)['text'])['input_ids']).to(device)# 
        tokenzied_target_sentence = torch.tensor([101]).to(device) # '[CLS]' token  
        current_text = preprocess_text(input_sentence)['text']
        for i in range(MAX_LENGTH):
            predictions = fnet_model(tokenized_input_sentence[:-1].unsqueeze(0),tokenzied_target_sentence.unsqueeze(0))
            predicted_index = torch.argmax(predictions[0, -1, :]).item()
            predicted_token = tokenizer.decode(predicted_index)
            if predicted_token == "[SEP]":  # Assuming [end] is the end token
              break
            current_text += " "+ predicted_token
            tokenized_target_sentence = torch.cat([tokenzied_target_sentence, torch.tensor([predicted_index]).to(device)], 0).to(device)
            tokenized_input_sentence = torch.tensor(tokenizer(current_text)['input_ids']).to(device)
        return current_text
decode_sentence({'text': 'How are you ?'}, fnet_model)

'how are you ? next ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##wg ##isan ##wg ##isan ##isan ##isan ball ##wg ##wg ##isan ball ##isan ##isan ball ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ball ball ball ball ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan had ##wg ##isan ##isan ##isan ##isan ##wg ##isan ball ball ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ##isan ball ##isan ##isan ##isan ##isan ##isan ##wg ##isan ##isan ##isan ##isan'

In [164]:
tokenizer('')

{'input_ids': [101, 102], 'token_type_ids': [0, 0], 'attention_mask': [1, 1]}

In [110]:
for batchs in dataloader:
    print(batchs)
    break

{'input_ids': tensor([[ 101, 2728, 8945,  ...,    0,    0,    0],
        [ 101, 1999, 1010,  ...,    0,    0,    0],
        [ 101, 1999, 8945,  ...,    0,    0,    0],
        ...,
        [ 101, 2010, 2269,  ...,    0,    0,    0],
        [ 101, 1999, 1996,  ...,    0,    0,    0],
        [ 101, 1999, 1010,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [29]:
tokenizer.encode('[CLS]')

[101, 101, 102]

In [None]:
[-16.4226, -15.4476,  17.9977, -15.4673, -15.5945, -15.4213, -16.0374,
        -14.9229, -14.6924, -15.9143]

In [96]:
tokenizer.decode(tokenized_inputs[0]['input_ids'])

'[CLS] robert boulter is an english film, television and theatre actor. he had a guest starring role on the television series the bill in. this was followed by a starring role in the play herons written by simon stephens, which was performed in at the royal court theatre. he had a guest role in the television series judge john deed in. in boulter landed a role as craig in the episode teddy s story of the television series the long firm he starred alongside actors mark strong and derek jacobi. he was cast in the theatre productions of the philip ridley play mercury fur, which was performed at the drum theatre in plymouth and the menier chocolate factory in london. he was directed by john tiffany and starred alongside ben whishaw, shane zaza, harry kent, fraser ayres, sophie stanton and dominic hall. [SEP]'

In [109]:
tokenizer('[hi this]',padding=True,batch=12)

TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'batch'

In [None]:
import torch
import torch.nn.functional as F

#VOCAB = vectorizer.get_vocabulary()
MAX_LENGTH =10 # your MAX_LENGTH value

def decode_sentence(input_sentence, fnet_model):
    fnet_model.eval()

    with torch.no_grad():
        # Mapping the input sentence to tokens and adding start and end tokens
        tokenized_input_sentence = torch.tensor(tokenizer(preprocess_text(input_sentence)['text'])['input_ids']).to(device)# )
        inital_state = fent_model.encoder(tokenized_input_sentence)[:,-1,:]
        
        inital_target = torch.tensor([101]).to(device) # start token [CLS]

        current_text = '' 

        for i in range(MAX_LENGTH):
          # Get the predictions
            if i == 0:
                predictions = fnet_model.decoder(inital_state,inital_target)
            predicted_index = torch.argmax(predictions[0, -1, :]).item()
        
               
            predicted_token = tokenizer.decode(predicted_index)

            if predicted_token == "[sep]":  # [sep] is the end token
              break
            current_text += " " + predicted_token
            initial_target = torch.tensor(tokenizer(current_text)['input_ids'][:-1]).to(device)

        return current_text

#         # Calculating the token with maximum probability and getting the corresponding word
#         sampled_token_index = torch.argmax(predictions[0, 0, :]).item()
#  #       sampled_token = tokenizer.vocab[sampled_token_index]
#         # If sampled token is the end token then stop generating and return the sentence
#       #  if sampled_token == "[end]":
#        #     break
#         decoded_sentence += str(sampled_token_index) + " "
#         tokenized_target_sentence = torch.cat(
#             [tokenized_target_sentence, torch.tensor([sampled_token_index])], 0
#         )

#     return decoded_sentence

# Assuming you have a PyTorch model named fnet_model
# You need to pass fnet_model as an argument to the function
#decode fnet_model = YourPyTorchModel()
decode_sentence({'text':'Where have  you all been '}, fnet_model)

In [19]:



class FNetTextGenerator(nn.Module):
  def __init__(self, embed_dim, latent_dim, vocab_size, max_seq_len, num_heads):
    super(FNetTextGenerator, self).__init__()
    self.positional = PositionalEmbedding(max_seq_len, vocab_size, embed_dim)
    self.encoder = FNetEncoder(embed_dim,latent_dim)
    self.decoder = FNetDecoder(embed_dim, latent_dim, num_heads)

  def forward(self, inputs, target=None):
    positional_inputs = self.positional(inputs)
    encoder_output = self.encoder(positional_inputs)
#    print(encoder_output.size())
    if target is not None:
      decoder_output = self.decoder(target, encoder_output)
      return decoder_output

    # If no target is provided, generate autoregressively
    batch_size, seq_len = inputs.size()
  #  print(inputs.size())
    generated_sequence = torch.zeros(batch_size, seq_len, dtype=torch.long, device=inputs.device)

        # Initial input for autoregressive decoding
    input_token = inputs[:, 0].unsqueeze(1).float()
 #   print(input_token)

# ...

    for t in range(1, seq_len):
      decoder_output = self.decoder(positional_inputs, encoder_output)

    # Use torch.argmax directly to get the indices of the maximum values
      predicted_token = torch.argmax(F.softmax(decoder_output, dim=-1), dim=-1)

          # Check the shape of predicted_token before updating the generated sequence
      print(f"Shape of predicted_token before update: {predicted_token.shape}")


    # Ensure that predicted_token has the correct shape
      predicted_token = predicted_token.unsqueeze(1)

    # Update the generated sequence
      generated_sequence[:, t] = predicted_token.squeeze()

    # Update input_token for the next iteration
      input_token = predicted_token



    return generated_sequence

# Example usage
embed_dim = 256
latent_dim = 512
vocab_size = 10000
max_seq_len = 50
num_heads = 8



model = FNetTextGenerator(embed_dim, latent_dim, vocab_size, max_seq_len, num_heads).to(device)

# Dummy input
inputs = torch.randint(0, vocab_size, (16, max_seq_len)).long().to(device)


# Forward pass for text generation
generated_sequence = model(inputs)
print("Generated Sequence:", generated_sequence)

Shape of predicted_token before update: torch.Size([16, 50])


RuntimeError: ignored

In [None]:
inputs = torch.randint(0, vocab_size, (16, max_seq_len)).to(device)

In [None]:
inputs.dtype

In [None]:
class FNetTextGenerator(nn.Module):
    def __init__(self, embed_dim, latent_dim, vocab_size, max_seq_len, num_heads, num_layers):
        super(FNetTextGenerator, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.transformer_decoder_layer = nn.TransformerDecoderLayer(embed_dim, num_heads)
        self.transformer_decoder = nn.TransformerDecoder(self.transformer_decoder_layer, num_layers=num_layers)

        self.transformer_encoder_layer = FNetEncoder(embed_dim, latent_dim)
      #  self.transformer_encoder = torch.nn.TransformerEncoder(self.transformer_encoder_layer, num_layers = num_layers)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs, target=None):
        seq_length = inputs.size(1)

        embedded_tokens = self.embedding(inputs)
        positions = torch.arange(0, seq_length, device=inputs.device).unsqueeze(0)
        embedded_positions = self.positional_embedding(positions)
        encoder_input = embedded_tokens + embedded_positions

        encoder_output = self.transformer_encoder_layer(encoder_input)

        if target is not None:
            target = target.permute(1, 0, 2)  # Adjust the shape for transformer decoder
            decoder_output = self.transformer_decoder(target, encoder_output)
            return decoder_output.permute(1, 0, 2)  # Adjust the shape back to batch-first

        # If no target is provided, generate autoregressively
        batch_size, seq_len = inputs.size()
        generated_sequence = torch.zeros(batch_size, seq_len, dtype=torch.long, device=inputs.device)

        # Initial input for autoregressive decoding
        input_token = inputs[:, 0].unsqueeze(1)

        for t in range(1, seq_len):
            embedded_input_token = self.embedding(input_token) + self.positional_embedding(torch.tensor([[t]]).to(inputs.device))
            decoder_output = self.transformer_decoder(embedded_input_token, encoder_output)
            logits = self.fc(decoder_output[-1])
            _, predicted_token = torch.max(F.softmax(logits, dim=-1), dim=-1)
            generated_sequence[:, t] = predicted_token.squeeze()

            input_token = predicted_token.unsqueeze(1)

        return generated_sequence



In [None]:
# Example usage
embed_dim = 256
latent_dim = 512
vocab_size = 10000
max_seq_len = 50
num_heads = 8
num_layers = 6

model = FNetTextGenerator(embed_dim, latent_dim, vocab_size, max_seq_len, num_heads, num_layers).to(device)

# Dummy input
inputs = torch.randint(0, vocab_size, (2, max_seq_len)).to(device)

# Forward pass for text generation
generated_sequence = model(inputs)
print("Generated Sequence:", generated_sequence)

TypeError: ignored

TypeError: ignored

In [126]:
type()

torch.Tensor