<a href="https://colab.research.google.com/github/rahulsm27/ML/blob/main/Text_Generation_using_Fnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Install libraries and import

In [1]:
!pip install datasets
!pip install torch[transformers]

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.5-py3-none-any.whl (7.8 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.5


In [2]:
import torch

## 2. Loading Data

In [3]:
from datasets import load_dataset
# Load the 'wikitext' dataset using its 'wikitext-2-raw-v1' configuration.
# Later cells index the result by split name: 'train', 'test', 'validation'.
datasets = load_dataset('wikitext','wikitext-2-raw-v1')

Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [4]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

## 3. PREPROCESSING DATA

In [5]:
import re
def preprocess_text(sentence):
  """Normalize the 'text' field of one example in place and return the example.

  Steps, in order:
    1. lower-case the text;
    2. replace every character that is not a-z or one of ``? ! . ,`` with a space;
    3. collapse each run of two or more whitespace characters into one space.

  Args:
    sentence: mapping with a string under the key 'text'.

  Returns:
    The same mapping object, with 'text' replaced by the normalized string.
  """
  text = sentence['text'].lower()  # lower-case first so the a-z class below covers every letter
  # Raw strings keep regex escapes out of Python's own string-escape handling
  # ('\s' in a normal string literal is an invalid escape sequence).
  text = re.sub(r'[^a-z?!.,]', ' ', text)  # keep only letters and basic punctuation
  text = re.sub(r'\s\s+', ' ', text)  # collapse runs of whitespace left by the substitution above
  sentence['text'] = text
  return sentence

In [6]:
# Apply the text normalization to each split, in the same order as before.
for split_name in ('train', 'test', 'validation'):
  datasets[split_name] = datasets[split_name].map(preprocess_text)

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [7]:
# Keep only rows whose cleaned text is longer than 20 characters, split by split.
for split_name in ('train', 'test', 'validation'):
  datasets[split_name] = datasets[split_name].filter(lambda x : len(x['text']) > 20)

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [8]:
datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 2312
    })
    train: Dataset({
        features: ['text'],
        num_rows: 18794
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1988
    })
})

## 3. TOKENIZATION

In [9]:
from transformers import AutoTokenizer

# Load the pretrained tokenizer published under the name 'bert-base-uncased'.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [10]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [11]:
# Number of entries in the tokenizer's vocabulary mapping.
# Presumably meant for the `vocab_size` argument of EncoderFnet — confirm where the model is built.
vocab_size = len(tokenizer.vocab)

In [12]:
def tokenize(sentence):
  """Tokenize the 'text' field of one example with the notebook's tokenizer.

  Truncation is enabled; the tokenizer's output is returned as-is.
  """
  return tokenizer(sentence['text'], truncation=True)

# Run `tokenize` over every example of the training split.
tokenized_inputs = datasets['train'].map(tokenize)

Map:   0%|          | 0/18794 [00:00<?, ? examples/s]

In [13]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

batch = 16

# DataCollatorWithPadding pads mapping-style features (dicts keyed by
# 'input_ids', 'attention_mask', ...). Zipping the columns into tuples drops
# those keys, so the collator cannot pad them. Hand it the tokenized dataset
# itself, restricted to the two columns the original cell selected; the raw
# 'text' column must go because strings cannot be turned into tensors.
model_columns = ['input_ids', 'attention_mask']
collator_inputs = tokenized_inputs.remove_columns(
    [name for name in tokenized_inputs.column_names if name not in model_columns])

# Pads every batch to the length of its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
dataloader = DataLoader(collator_inputs, batch_size=batch, collate_fn=data_collator)

## 4. MODEL

In [15]:
import torch.nn as nn
import torch.fft as fft

In [None]:
class EncoderFnet(nn.Module):
  """Single FNet-style encoder block.

  Token ids are embedded and summed with learned positional embeddings, mixed
  with a 2D Fourier transform (real part only) in place of self-attention, and
  then passed through a position-wise feed-forward network. Both sub-layers use
  a residual connection followed by layer normalization.

  Args:
    vocab_size: number of rows in the token embedding table.
    max_seq_len: number of rows in the positional embedding table, i.e. the
      longest sequence `forward` accepts.
    embed_dim: size of each embedding vector and of the block's output.
    dense_dim: hidden size of the feed-forward network.
    num_heads: stored on the instance but not used by any layer here; Fourier
      mixing has no attention heads.
  """

  def __init__(self,vocab_size,max_seq_len, embed_dim, dense_dim, num_heads):
    super().__init__()
    self.max_seq_len = max_seq_len
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads

    self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.dense_proj = nn.Sequential(
        nn.Linear(embed_dim, dense_dim),
        nn.ReLU(),
        nn.Linear(dense_dim, embed_dim),
    )

    self.layernorm_1 = nn.LayerNorm(embed_dim)
    self.layernorm_2 = nn.LayerNorm(embed_dim)

  def forward(self,inputs):
    """Encode a batch of token ids.

    Args:
      inputs: integer tensor of token ids; dimension 1 is treated as the
        sequence dimension (expected layout: (batch, seq_len)).

    Returns:
      Float tensor with a trailing `embed_dim` dimension appended to the
      shape of `inputs`.

    Raises:
      ValueError: if the sequence is longer than `max_seq_len`.
    """
    seq_length = inputs.size(1)
    if seq_length > self.max_seq_len:
      # Fail with a readable message instead of an out-of-range embedding lookup.
      raise ValueError(
          f"sequence length {seq_length} exceeds max_seq_len {self.max_seq_len}")

    # Embedding the inputs
    embedded_inputs = self.embedding(inputs)

    # Adding positional embeddings; the leading dim of size 1 broadcasts over the batch
    positions = torch.arange(0, seq_length, device=inputs.device).unsqueeze(0)
    positional_embeddings = self.positional_embedding(positions)

    # Summing token embeddings and positional embeddings
    embedded_inputs = embedded_inputs + positional_embeddings

    # Fourier mixing: fft2 transforms the last two dims (sequence and embedding).
    # It accepts real input directly, so no complex cast is needed —
    # torch.view_as_complex would require a trailing dimension of size 2.
    fft_result = fft.fft2(embedded_inputs)

    # Taking the real part, which has the same shape as the embeddings
    fft_real = fft_result.real

    # Residual + layer norm around the mixing step, then around the feed-forward step
    proj_input = self.layernorm_1(embedded_inputs + fft_real)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input + proj_output)



