# FNet

## 0. Paper

### Info
* Title: FNet: Mixing Tokens with Fourier Transforms
* Authors: James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontañón
* Task: Natural Language Processing
* Link: https://arxiv.org/abs/2105.03824


### Features
* Dataset: NSMC (Naver Sentiment Movie Corpus), [link](https://github.com/e9t/nsmc)
* Tokenizer: SKT KoBERT, [link](https://github.com/SKTBrain/KoBERT)

### Reference
* https://github.com/rishikksh20/FNet-pytorch
* https://github.com/codertimo/BERT-pytorch


## 1. Setup

In [1]:
# Install the SentencePiece runtime needed by Tokenizer (IPython shell escape).
!pip install -q sentencepiece

[?25l[K     |▎                               | 10kB 26.9MB/s eta 0:00:01[K     |▌                               | 20kB 19.6MB/s eta 0:00:01[K     |▉                               | 30kB 15.6MB/s eta 0:00:01[K     |█                               | 40kB 14.3MB/s eta 0:00:01[K     |█▍                              | 51kB 7.5MB/s eta 0:00:01[K     |█▋                              | 61kB 8.8MB/s eta 0:00:01[K     |██                              | 71kB 8.4MB/s eta 0:00:01[K     |██▏                             | 81kB 9.2MB/s eta 0:00:01[K     |██▌                             | 92kB 8.7MB/s eta 0:00:01[K     |██▊                             | 102kB 7.3MB/s eta 0:00:01[K     |███                             | 112kB 7.3MB/s eta 0:00:01[K     |███▎                            | 122kB 7.3MB/s eta 0:00:01[K     |███▌                            | 133kB 7.3MB/s eta 0:00:01[K     |███▉                            | 143kB 7.3MB/s eta 0:00:01[K     |████                   

In [2]:
import os
from glob import glob
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchsummary import torchsummary

import sentencepiece as spm

In [3]:
class CONFIG:
    """Hyper-parameters and paths shared by the cells of this notebook.

    Used as a plain attribute namespace (read as ``CONFIG.<name>``).
    """

    # Data location
    dataset = 'nsmc'
    base_dir = '/content/drive/Shared drives/Yoon/Project/Doing/Deep Learning Paper Implementation'

    # Model size
    model_dim = 256   # width of every token representation
    hidden_dim = 256  # inner width of each feed-forward sub-layer
    n_layer = 6       # number of stacked FNet blocks
    maxlen = 100      # ids per example after truncation / padding

    # Optimisation
    batch_size = 128
    epoch_size = 10   # number of training epochs

## 2. Data

In [4]:
class Tokenizer:
    """Turn raw text into a fixed-length list of KoBERT SentencePiece ids.

    The [CLS] id is prepended, then the sequence is truncated or right-padded
    with the pad id so that it is exactly ``maxlen`` ids long.
    """

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self._tokenizer = spm.SentencePieceProcessor()
        # Model file is loaded from the current working directory.
        self._tokenizer.Load('kobert_news_wiki_ko_cased-ae5711deb3.spiece')
        # Special-token ids are hard-coded here; presumably they match the
        # KoBERT vocabulary — verify against the model file.
        self.pad_id = 1
        self.cls_id = 2
        self.sep_id = 3
        self.mask_id = 4

    def __call__(self, text):
        ids = [self.cls_id, *self._tokenizer.encode(text)]
        shortfall = self.maxlen - len(ids)
        if shortfall <= 0:
            # Too long (or exactly full): keep only the first maxlen ids.
            return ids[:self.maxlen]
        return ids + [self.pad_id] * shortfall


class Dataset(torch.utils.data.Dataset):
    """One NSMC ratings split (``mode`` is e.g. ``'train'`` or ``'test'``), pre-tokenised.

    Reads the tab-separated file ``/content/data/ratings_{mode}.txt``, drops
    every row containing a missing value, and stores the fixed-length id list
    for each ``document`` in a new ``token`` column.
    """

    def __init__(self, mode, maxlen):
        frame = pd.read_table(f'/content/data/ratings_{mode}.txt')
        frame.dropna(inplace=True)
        self.data = frame
        self.tokenizer = Tokenizer(maxlen)
        self.data['token'] = self.data['document'].apply(self.tokenizer)
        self.vocab_size = self.tokenizer._tokenizer.vocab_size()

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # ``idx`` is positional (iloc), so gaps left by dropna do not matter.
        row = self.data.iloc[idx]
        return np.array(row['token']), row['label']

In [5]:
# Download the KoBERT SentencePiece model that Tokenizer loads by relative path
# from the working directory (IPython shell escape).
!wget https://kobert.blob.core.windows.net/models/kobert/tokenizer/kobert_news_wiki_ko_cased-ae5711deb3.spiece

--2021-05-24 12:58:32--  https://kobert.blob.core.windows.net/models/kobert/tokenizer/kobert_news_wiki_ko_cased-ae5711deb3.spiece
Resolving kobert.blob.core.windows.net (kobert.blob.core.windows.net)... 52.239.190.132
Connecting to kobert.blob.core.windows.net (kobert.blob.core.windows.net)|52.239.190.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 371427 (363K) [application/octet-stream]
Saving to: ‘kobert_news_wiki_ko_cased-ae5711deb3.spiece’


2021-05-24 12:58:33 (740 KB/s) - ‘kobert_news_wiki_ko_cased-ae5711deb3.spiece’ saved [371427/371427]



In [6]:
# Extract the dataset archive from Drive into ./data. `unzip` also tries the
# given name with a `.zip` suffix, so the extension may be omitted here.
# NOTE(review): Dataset reads /content/data/..., so this presumably assumes the
# working directory is /content (Colab default) — confirm.
data_path = os.path.join(CONFIG.base_dir, 'data', CONFIG.dataset)
!unzip -q "{data_path}" -d 'data'

In [7]:
# Hold out a fixed-size validation subset of the training file. The permutation
# is drawn from a seeded generator so the split (and therefore which checkpoint
# counts as "best") is reproducible across notebook reruns.
n_valid = 30000
split_seed = 42

train_data = Dataset('train', maxlen=CONFIG.maxlen)

indices = np.random.default_rng(split_seed).permutation(len(train_data))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[n_valid:])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:n_valid])
train_loader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=CONFIG.batch_size)
valid_loader = torch.utils.data.DataLoader(train_data, sampler=valid_sampler, batch_size=CONFIG.batch_size)

# The separate test file is iterated in its stored order.
test_data = Dataset('test', maxlen=CONFIG.maxlen)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=CONFIG.batch_size, shuffle=False)

In [8]:
# Peek at the first rows of the training frame, including the new `token` column.
train_data.data.head(n=5)

Unnamed: 0,id,document,label,token
0,9976970,아 더빙.. 진짜 짜증나네요 목소리,0,"[2, 3093, 1698, 6456, 54, 54, 4368, 4396, 7316..."
1,3819312,흠...포스터보고 초딩영화줄....오버연기조차 가볍지 않구나,1,"[2, 517, 7989, 55, 7728, 6686, 6366, 4501, 595..."
2,10265843,너무재밓었다그래서보는것을추천한다,0,"[2, 1458, 7191, 0, 6888, 5540, 6553, 6369, 539..."
3,9045019,교도소 이야기구먼 ..솔직히 재미는 없다..평점 조정,0,"[2, 1103, 5859, 6607, 3714, 5495, 6184, 517, 5..."
4,6483659,사이몬페그의 익살스런 연기가 돋보였던 영화!스파이더맨에서 늙어보이기만 했던 커스틴 ...,1,"[2, 2618, 6220, 7712, 5538, 7095, 3757, 6519, ..."


In [9]:
# Sanity-check one batch: id tensor of shape (batch_size, maxlen) and one label per row.
x, y = next(iter(train_loader))
x.shape, y.shape

(torch.Size([128, 100]), torch.Size([128]))

## 3. Model

In [10]:
class Embedding(nn.Module):
    """Sum of token embeddings and learned absolute position embeddings, then dropout.

    Args:
        vocab_size: number of rows in the token lookup table.
        model_dim: width of every embedding vector.
        maxlen: number of position rows; inputs may be at most this long.
        drop_rate: dropout probability applied to the summed embedding.
    """

    def __init__(self, vocab_size, model_dim, maxlen, drop_rate=0.1):
        super().__init__()
        # Row 1 (the pad id hard-coded in Tokenizer) starts at zero and gets no gradient.
        self.token_embedding = nn.Embedding(vocab_size, model_dim, padding_idx=1)
        # Learned positional table, initialised uniformly in [0, 1).
        self.pos_embedding = nn.Parameter(torch.rand(maxlen, model_dim))
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        seq_len = x.size(1)
        # Leading singleton axis broadcasts the positions over the batch.
        positions = self.pos_embedding[:seq_len].unsqueeze(0)
        return self.dropout(self.token_embedding(x) + positions)

class Fourier(nn.Module):
    """Parameter-free FNet token mixing.

    Applies a 1-D DFT over the last dimension, then another over the
    second-to-last dimension, and keeps only the real part. For the
    (batch, seq_len, model_dim) tensors produced by Embedding, that is the
    feature axis followed by the sequence axis.
    """

    def forward(self, x):
        over_features = torch.fft.fft(x, dim=-1)
        over_tokens = torch.fft.fft(over_features, dim=-2)
        return over_tokens.real

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        model_dim: input and output width.
        hidden_dim: width of the intermediate layer.
        drop_rate: dropout probability used after the activation and after
            the output projection.
    """

    def __init__(self, model_dim, hidden_dim, drop_rate):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_dim, model_dim),
            nn.Dropout(drop_rate),
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension, so every position is
        # transformed independently. The MLP output must be returned; handing
        # back ``x`` would bypass the layer and leave its weights untrained.
        return self.net(x)

class FNetBlock(nn.Module):
    """One FNet encoder layer with post-LayerNorm residual sub-layers.

    Fourier mixing and the feed-forward MLP are each added back to their own
    input, and each sum is layer-normalised.
    """

    def __init__(self, model_dim, hidden_dim, drop_rate=0.1):
        super().__init__()
        self.fourier = Fourier()
        self.norm1 = nn.LayerNorm(model_dim)
        self.ff = FeedForward(model_dim, hidden_dim, drop_rate)
        self.norm2 = nn.LayerNorm(model_dim)

    def forward(self, x):
        mixed = self.norm1(x + self.fourier(x))
        return self.norm2(mixed + self.ff(mixed))

class FNet(nn.Module):
    """FNet encoder with a linear classification head on the first position.

    Args:
        model_dim, hidden_dim: widths passed to every FNetBlock.
        vocab_size, maxlen: sizes of the token and position tables.
        n_layer: number of stacked FNetBlocks.
        n_class: number of output logits.
        drop_rate: dropout probability used throughout.
    """

    def __init__(self, model_dim, hidden_dim, vocab_size, maxlen, n_layer, n_class, drop_rate=0.1):
        super().__init__()
        self.embedding = Embedding(vocab_size, model_dim, maxlen, drop_rate)
        layers = [FNetBlock(model_dim, hidden_dim, drop_rate) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.classifier = nn.Linear(model_dim, n_class)

    def forward(self, x):
        hidden = self.blocks(self.embedding(x))
        # Position 0 carries the [CLS] id that Tokenizer prepends to every example.
        cls_state = hidden[:, 0]
        return self.classifier(cls_state)

## 4. Experiment

In [11]:
class AverageMeter:
    """Running weighted mean of one scalar metric, labelled with ``name``."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Forget every value seen so far."""
        self.sum = self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return f'{self.name:10s} {self.avg:.3f}'


class ProgressMeter(object):
    """Group of named AverageMeters that mirrors each average as an attribute.

    ``progress.<name>`` always holds the current average of the meter called
    ``<name>`` (e.g. ``progress.valid_acc``). It is 0 before any update and
    again after ``reset()``.
    """

    def __init__(self, meters):
        self.meters = [AverageMeter(m) for m in meters]
        # Publish the attributes up front so reading them before the first
        # update (e.g. after an empty loader) gives 0 instead of AttributeError.
        self._sync()

    def _sync(self):
        """Copy every meter's current average onto ``self.<meter name>``."""
        for m in self.meters:
            setattr(self, m.name, m.avg)

    def reset(self):
        """Clear all meters and their mirrored attributes."""
        for m in self.meters:
            m.reset()
        # Without this the mirrored attributes would keep the pre-reset averages.
        self._sync()

    def update(self, values, n=1):
        """Feed ``values`` (in meter order) with weight ``n``; extra values are ignored."""
        for m, v in zip(self.meters, values):
            m.update(v, n)
            setattr(self, m.name, m.avg)

    def log(self):
        """Return one line with every meter's name and average, separated by ' | '."""
        return ' | '.join(str(meter) for meter in self.meters)


def accuracy(logits, targets):
    """Return the fraction of rows whose largest logit (along dim 1) is at the target index.

    The result is a 0-dim float tensor; call ``.item()`` for a Python float.
    """
    predicted = logits.max(dim=1).indices
    hits = predicted.eq(targets).float()
    return hits.mean()

In [12]:
class Trainer(object):
    """Train / validate / test loop that checkpoints the best validation accuracy.

    Args:
        model: module mapping a batch of inputs to class logits.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler stepped once per training epoch, or None.
        device: device the model, criterion and every batch are moved to.
        ckpt_path: file the best-validation checkpoint is written to and
            reloaded from by ``test``.
    """

    def __init__(self, model, criterion, optimizer, scheduler, device, ckpt_path='ckpt.pt'):
        self.model = model.to(device)
        self.criterion = criterion.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.ckpt_path = ckpt_path
        self.best_epoch, self.best_score = 0, 0

    def train(self, train_loader, epoch):
        """Run one training epoch.

        Returns:
            ``{'train_loss': ..., 'train_acc': ...}`` sample-weighted averages.
        """
        progress = ProgressMeter(["train_loss", "train_acc"])
        self.model.train()

        pbar = tqdm(train_loader)
        pbar.set_description(f'TRAIN {epoch:03d}')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            acc = accuracy(outputs, targets).item()
            progress.update([loss.item(), acc], n=inputs.size(0))
            pbar.set_postfix(log=progress.log())

        if self.scheduler is not None:
            self.scheduler.step()
        return self._summary(progress)

    def validate(self, valid_loader, epoch):
        """Evaluate on ``valid_loader`` and checkpoint the model on a new best accuracy.

        Returns:
            ``{'valid_loss': ..., 'valid_acc': ...}`` sample-weighted averages.
        """
        metrics = self._evaluate(valid_loader, "valid", f'VALID {epoch:03d}')
        if metrics["valid_acc"] > self.best_score:
            self.best_epoch = epoch
            self.best_score = metrics["valid_acc"]
            ckpt = {
                'best_epoch': self.best_epoch,
                'best_score': self.best_score,
                'model_state_dict': self.model.state_dict(),
            }
            torch.save(ckpt, self.ckpt_path)
        return metrics

    def test(self, test_loader):
        """Reload the best validation checkpoint and evaluate it on ``test_loader``.

        Returns:
            ``{'test_loss': ..., 'test_acc': ...}`` sample-weighted averages.
        """
        # map_location lets a checkpoint written on GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        return self._evaluate(test_loader, "test", 'TEST')

    def _evaluate(self, loader, prefix, description):
        """Run the model in eval mode without gradients over ``loader``; return averaged metrics."""
        progress = ProgressMeter([f"{prefix}_loss", f"{prefix}_acc"])
        self.model.eval()

        pbar = tqdm(loader)
        pbar.set_description(description)
        with torch.no_grad():
            for inputs, targets in pbar:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets).item()
                acc = accuracy(outputs, targets).item()
                progress.update([loss, acc], n=inputs.size(0))
                pbar.set_postfix(log=progress.log())
        return self._summary(progress)

    @staticmethod
    def _summary(progress):
        """Map each meter name in ``progress`` to its running average."""
        return {m.name: m.avg for m in progress.meters}

In [13]:
# Binary sentiment classifier (labels 0/1) trained with Adam at PyTorch defaults.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = FNet(
    model_dim=CONFIG.model_dim,
    hidden_dim=CONFIG.hidden_dim,
    vocab_size=train_data.vocab_size,
    maxlen=CONFIG.maxlen,
    n_layer=CONFIG.n_layer,
    n_class=2,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [14]:
# No LR scheduler: Trainer.train only steps a scheduler when one is given.
trainer = Trainer(model, criterion, optimizer, scheduler=None, device=device)

In [15]:
# One training pass and one validation pass per epoch; validate() checkpoints
# the model whenever validation accuracy improves.
divider = '=' * 100
for ep in range(CONFIG.epoch_size):
    print(divider)
    trainer.train(train_loader, ep)
    trainer.validate(valid_loader, ep)



HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))




In [16]:
# Reload the best-validation checkpoint and evaluate it on the test split.
trainer.test(test_loader=test_loader)

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




In [None]:
|