<a href="https://colab.research.google.com/github/rahulsm27/ML/blob/main/Text_Generation_using_Fnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Install libraries and import

In [1]:
!pip install datasets
!pip install torch[transformers]

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.5-py3-none-any.whl (7.8 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.5


In [2]:
import torch

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


## 2. Loading Data

In [4]:
from datasets import load_dataset
datasets = load_dataset('wikitext','wikitext-2-raw-v1')

Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [5]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

## 3. PREPROCESSING DATA

In [6]:
import re
def preprocess_text(sentence):
  text = sentence['text'].lower() # lowering the sentence and storing in text vaiable
  text = re.sub('[^a-z?!.,]',' ',text) # removing other than characters and punctuations
  text = re.sub('\s\s+',' ',text) # removing double spaces
  sentence['text'] = text
  return sentence

In [7]:
datasets['train'] = datasets['train'].map(preprocess_text)
datasets['test'] = datasets['test'].map(preprocess_text)
datasets['validation'] = datasets['validation'].map(preprocess_text)

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [8]:
datasets['train'] = datasets['train'].filter(lambda x : len(x['text']) > 20)
datasets['test'] = datasets['test'].filter(lambda x : len(x['text']) > 20)
datasets['validation'] = datasets['validation'].filter(lambda x : len(x['text']) > 20)

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [9]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 2312
    })
    train: Dataset({
        features: ['text'],
        num_rows: 18794
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1988
    })
})

## 3. TOKENIZATION

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [11]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [12]:
vocab_size = len(tokenizer.vocab)

In [13]:
def tokenize(sentence):
  sentence = tokenizer(sentence['text'],truncation = True)

  return sentence

tokenized_inputs = datasets['train'].map(tokenize)

Map:   0%|          | 0/18794 [00:00<?, ? examples/s]

In [60]:
tokenized_inputs = tokenized_inputs.remove_columns(['text','token_type_ids'])

In [61]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

batch = 16

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,return_tensors ="pt")
dataloader = DataLoader(tokenized_inputs,batch_size=batch,collate_fn=data_collator)

In [34]:
type(dataloader)


torch.utils.data.dataloader.DataLoader

## 4. MODEL

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft

In [79]:
class PositionalEmbedding(nn.Module):
  def __init__(self, sequence_length, vocab_size, embed_dim):
    super(PositionalEmbedding, self).__init__()
    self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.position_embeddings = nn.Embedding(sequence_length, embed_dim)

  def forward(self, inputs):
    length = inputs.size(-1)
    positions = torch.arange(0, length, device=inputs.device).unsqueeze(0)
    embedded_tokens = self.token_embeddings(inputs)
    embedded_positions = self.position_embeddings(positions)
    return embedded_tokens + embedded_positions


In [80]:
class FNetEncoder(nn.Module):

  def __init__(self,embed_dim, dense_dim):
    super(FNetEncoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.dense_proj = nn.Sequential(nn.Linear(self.embed_dim,self.dense_dim), nn.ReLU(), nn.Linear(self.dense_dim,self.embed_dim))

    self.layernorm_1 = nn.LayerNorm(self.embed_dim)
    self.layernorm_2 = nn.LayerNorm(self.embed_dim)

  def forward(self,inputs):

    fft_result = fft.fft2(inputs)

    #taking real part
    fft_real = fft_result.real.float()

    proj_input = self.layernorm_1 (inputs + fft_real)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input +proj_output)





In [81]:
class FNetDecoder(nn.Module):

  def __init__(self,embed_dim,dense_dim,num_heads):
    super(FNetDecoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads

    self.attention_1 = nn.MultiheadAttention(embed_dim,num_heads)
    self.attention_2 = nn.MultiheadAttention(embed_dim,num_heads)

    self.dense_proj = nn.Sequential(nn.Linear(embed_dim, latent_dim),nn.ReLU(),nn.Linear(latent_dim, embed_dim))

    self.layernorm_1 = nn.LayerNorm(embed_dim)
    self.layernorm_2 = nn.LayerNorm(embed_dim)
    self.layernorm_3 = nn.LayerNorm(embed_dim)

  def forward(self, inputs, encoder_outputs, mask=None):
    causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(0)).to(device)

    if mask is not None:
        padding_mask = mask.unsqueeze(1)
        padding_mask = torch.min(padding_mask, causal_mask)

    attention_output_1, _ = self.attention_1(inputs, inputs, inputs, attn_mask=causal_mask)
    out_1 = self.layernorm_1(inputs + attention_output_1)

    attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs, attn_mask=causal_mask)
    out_2 = self.layernorm_2(out_1 + attention_output_2)

    proj_output = self.dense_proj(out_2)
    return self.layernorm_3(out_2 + proj_output)



In [86]:
class FNetModel(nn.Module):
    def __init__(self, max_length, vocab_size, embed_dim, latent_dim, num_heads):
        super(FNetModel, self).__init__()

        self.encoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.encoder = FNetEncoder(embed_dim, latent_dim)

        self.decoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.decoder_state_inputs = nn.Linear(embed_dim, latent_dim)
        self.decoder = FNetDecoder(embed_dim, latent_dim, num_heads)

        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(latent_dim, vocab_size)

    def forward(self, encoder_inputs, decoder_inputs):
        x_encoder = self.encoder_inputs(encoder_inputs)
        x_encoder = self.encoder(x_encoder)

        x_decoder = self.decoder_inputs(decoder_inputs)
        x_decoder_state = self.decoder_state_inputs(x_encoder[:, -1, :])
        x_decoder = self.decoder(x_decoder, x_decoder_state)
        x_decoder = self.dropout(x_decoder)

        decoder_outputs = self.dense(x_decoder)
        return F.softmax(decoder_outputs, dim=-1)

In [87]:
# Assuming your constants are defined like this:
MAX_LENGTH = 512
VOCAB_SIZE = len(tokenizer.vocab)
EMBED_DIM = 256
LATENT_DIM = 100
NUM_HEADS = 4

# Create an instance of the model
fnet_model = FNetModel(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM, LATENT_DIM, NUM_HEADS)

# # Define your optimizer and loss function
# optimizer = torch.optim.Adam(fnet_model.parameters())
# criterion = nn.CrossEntropyLoss()

# # Convert your data to PyTorch tensors
# # Example:
# # encoder_inputs_tensor = torch.tensor(encoder_inputs_data)
# # decoder_inputs_tensor = torch.tensor(decoder_inputs_data)
# # target_tensor = torch.tensor(target_data)

# # Training loop
# epochs = 10
# for epoch in range(epochs):
#     optimizer.zero_grad()
#     outputs = fnet_model(encoder_inputs_tensor, decoder_inputs_tensor)
#     loss = criterion(outputs.view(-1, VOCAB_SIZE), target_tensor.view(-1))
#     loss.backward()
#     optimizer.step()


NameError: ignored

In [88]:
import torch
import torch.nn.functional as F

#VOCAB = vectorizer.get_vocabulary()
MAX_LENGTH =10 # your MAX_LENGTH value

def decode_sentence(input_sentence, fnet_model):
    # Mapping the input sentence to tokens and adding start and end tokens
    tokenized_input_sentence = tokenizer("[start] " + preprocess_text(input_sentence)['text'] + " [end]")
    # )
     # Initializing the initial sentence consisting of only the start token.
    tokenized_target_sentence = torch.tensor(tokenizer.vocab["start"])
    decoded_sentence = ""

    for i in range(MAX_LENGTH):
        # Get the predictions
        with torch.no_grad():
            predictions = fnet_model(
                tokenized_input_sentence,
                tokenized_target_sentence,
            )
        # Calculating the token with maximum probability and getting the corresponding word
        sampled_token_index = torch.argmax(predictions[0, i, :]).item()
        sampled_token = tokenizer.vocab[sampled_token_index]
        # If sampled token is the end token then stop generating and return the sentence
        if sampled_token == "[end]":
            break
        decoded_sentence += sampled_token + " "
        tokenized_target_sentence = torch.cat(
            [tokenized_target_sentence, torch.tensor([sampled_token_index])], 0
        )

    return decoded_sentence

# Assuming you have a PyTorch model named fnet_model
# You need to pass fnet_model as an argument to the function
#decode fnet_model = YourPyTorchModel()
decode_sentence({'text':'Where have you been all this time?'}, fnet_model)

AttributeError: ignored

In [71]:
tokenizer.vocab["start"]

2707

In [19]:



class FNetTextGenerator(nn.Module):
  def __init__(self, embed_dim, latent_dim, vocab_size, max_seq_len, num_heads):
    super(FNetTextGenerator, self).__init__()
    self.positional = PositionalEmbedding(max_seq_len, vocab_size, embed_dim)
    self.encoder = FNetEncoder(embed_dim,latent_dim)
    self.decoder = FNetDecoder(embed_dim, latent_dim, num_heads)

  def forward(self, inputs, target=None):
    positional_inputs = self.positional(inputs)
    encoder_output = self.encoder(positional_inputs)
#    print(encoder_output.size())
    if target is not None:
      decoder_output = self.decoder(target, encoder_output)
      return decoder_output

    # If no target is provided, generate autoregressively
    batch_size, seq_len = inputs.size()
  #  print(inputs.size())
    generated_sequence = torch.zeros(batch_size, seq_len, dtype=torch.long, device=inputs.device)

        # Initial input for autoregressive decoding
    input_token = inputs[:, 0].unsqueeze(1).float()
 #   print(input_token)

# ...

    for t in range(1, seq_len):
      decoder_output = self.decoder(positional_inputs, encoder_output)

    # Use torch.argmax directly to get the indices of the maximum values
      predicted_token = torch.argmax(F.softmax(decoder_output, dim=-1), dim=-1)

          # Check the shape of predicted_token before updating the generated sequence
      print(f"Shape of predicted_token before update: {predicted_token.shape}")


    # Ensure that predicted_token has the correct shape
      predicted_token = predicted_token.unsqueeze(1)

    # Update the generated sequence
      generated_sequence[:, t] = predicted_token.squeeze()

    # Update input_token for the next iteration
      input_token = predicted_token



    return generated_sequence

# Example usage
embed_dim = 256
latent_dim = 512
vocab_size = 10000
max_seq_len = 50
num_heads = 8



model = FNetTextGenerator(embed_dim, latent_dim, vocab_size, max_seq_len, num_heads).to(device)

# Dummy input
inputs = torch.randint(0, vocab_size, (16, max_seq_len)).long().to(device)


# Forward pass for text generation
generated_sequence = model(inputs)
print("Generated Sequence:", generated_sequence)

Shape of predicted_token before update: torch.Size([16, 50])


RuntimeError: ignored

In [None]:
inputs = torch.randint(0, vocab_size, (16, max_seq_len)).to(device)

In [None]:
inputs.dtype

In [None]:
class FNetTextGenerator(nn.Module):
    def __init__(self, embed_dim, latent_dim, vocab_size, max_seq_len, num_heads, num_layers):
        super(FNetTextGenerator, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.transformer_decoder_layer = nn.TransformerDecoderLayer(embed_dim, num_heads)
        self.transformer_decoder = nn.TransformerDecoder(self.transformer_decoder_layer, num_layers=num_layers)

        self.transformer_encoder_layer = FNetEncoder(embed_dim, latent_dim)
      #  self.transformer_encoder = torch.nn.TransformerEncoder(self.transformer_encoder_layer, num_layers = num_layers)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs, target=None):
        seq_length = inputs.size(1)

        embedded_tokens = self.embedding(inputs)
        positions = torch.arange(0, seq_length, device=inputs.device).unsqueeze(0)
        embedded_positions = self.positional_embedding(positions)
        encoder_input = embedded_tokens + embedded_positions

        encoder_output = self.transformer_encoder_layer(encoder_input)

        if target is not None:
            target = target.permute(1, 0, 2)  # Adjust the shape for transformer decoder
            decoder_output = self.transformer_decoder(target, encoder_output)
            return decoder_output.permute(1, 0, 2)  # Adjust the shape back to batch-first

        # If no target is provided, generate autoregressively
        batch_size, seq_len = inputs.size()
        generated_sequence = torch.zeros(batch_size, seq_len, dtype=torch.long, device=inputs.device)

        # Initial input for autoregressive decoding
        input_token = inputs[:, 0].unsqueeze(1)

        for t in range(1, seq_len):
            embedded_input_token = self.embedding(input_token) + self.positional_embedding(torch.tensor([[t]]).to(inputs.device))
            decoder_output = self.transformer_decoder(embedded_input_token, encoder_output)
            logits = self.fc(decoder_output[-1])
            _, predicted_token = torch.max(F.softmax(logits, dim=-1), dim=-1)
            generated_sequence[:, t] = predicted_token.squeeze()

            input_token = predicted_token.unsqueeze(1)

        return generated_sequence



In [None]:
# Example usage
embed_dim = 256
latent_dim = 512
vocab_size = 10000
max_seq_len = 50
num_heads = 8
num_layers = 6

model = FNetTextGenerator(embed_dim, latent_dim, vocab_size, max_seq_len, num_heads, num_layers).to(device)

# Dummy input
inputs = torch.randint(0, vocab_size, (2, max_seq_len)).to(device)

# Forward pass for text generation
generated_sequence = model(inputs)
print("Generated Sequence:", generated_sequence)

TypeError: ignored

TypeError: ignored