<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/modeling_FNet/test_sample_FNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Print the NVIDIA driver/GPU status so the reader can see which accelerator
# (if any) this runtime has before training starts.
! nvidia-smi

In [None]:
! pip install datasets
! pip install transformers

In [None]:
# Start from a clean checkout: delete any previous clone, fetch the repository
# again, then make its root the working directory.
# NOTE(review): the `toolkit` and `modeling_FNet` imports in the next cell are
# presumably resolved relative to this directory — confirm.
! rm -rf PyTorch-Architectures/
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/

In [2]:
import time
from datasets import load_dataset
import torch
from toolkit.custom_dataset import DataLoaderTextClassification
from toolkit.metrics import compute_accuracy
from transformers import DistilBertTokenizer
from modeling_FNet.model import FNetClassify
from modeling_FNet.config import FNetConfig

In [3]:
# HyperParameter section
# Everything a reader might want to tune lives here.
MAX_INP_LEN = 32  # passed to the data loaders as `max_input_length`
BS = 64           # batch size used by both the train and valid loaders
LR = 3e-5         # learning rate handed to the optimizer
EPOCHS = 5        # number of full passes over the training loader

# Seed PyTorch's global RNG so that weight initialisation and batch shuffling
# are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [4]:
# Load the pretrained DistilBERT tokenizer by checkpoint name; it is handed to
# the data loaders below.
tokenizer_checkpoint = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_checkpoint)

In [None]:
# Build the two batch iterators used for training and evaluation. Both share
# the same tokenizer and maximum input length; only the training loader is
# shuffled.
# NOTE(review): `train=False` presumably selects the validation split — confirm
# against toolkit.custom_dataset.
shared_args = dict(tokenizer=tokenizer, max_input_length=MAX_INP_LEN)

train_data = DataLoaderTextClassification(**shared_args)
train_loader = train_data.return_dataloader(batch_size=BS, shuffle=True)

valid_data = DataLoaderTextClassification(train=False, **shared_args)
valid_loader = valid_data.return_dataloader(batch_size=BS, shuffle=False)

In [6]:
# Report how many batches each loader yields per pass.
print('Length of Train Loader: ', len(train_loader))
print('Length of Valid Loader: ', len(valid_loader))

# Sanity check: take the first training batch only and confirm it carries
# exactly one label per input row. `sample` is reused by the forward-pass
# check further down.
for sample in train_loader:
  num_inputs = sample['input_ids'].size(0)
  num_labels = sample['labels'].size(0)
  assert num_inputs == num_labels
  break

Length of Train Loader:  1053
Length of Valid Loader:  14


In [None]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# Instantiate the classifier from its default configuration and move it to the
# chosen device. The bare last expression also displays the model in the cell
# output.
config = FNetConfig()
model = FNetClassify(config)
model.to(device)

In [8]:
# Count the scalar weights that will receive gradients, skipping any
# parameter whose `requires_grad` flag is off.
params = 0
for weight in model.parameters():
  if weight.requires_grad:
    params += weight.numel()
print('Trainable Parameters: ', params)

Trainable Parameters:  49640706


In [9]:
# AdamW over every model parameter, using the learning rate from the
# hyperparameter cell.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

In [10]:
# Smoke-test the forward pass on the batch captured by the loader sanity check
# (`sample`): run it in eval mode with gradient tracking off, then print the
# scalar first output (used as the loss in the training loop) and the shape of
# the second output.
model.eval()
with torch.no_grad():
  input_ids = sample['input_ids'].to(device)
  labels = sample['labels'].to(device)
  outputs = model(input_ids=input_ids, labels=labels)
  first_value = outputs[0].item()
  second_shape = outputs[1].shape
  print(first_value, second_shape)

64.25020599365234 torch.Size([64, 2])


In [11]:
# Train for EPOCHS passes over `train_loader`. After every pass the model is
# scored on both loaders, and the time elapsed since training began is printed.
start_time = time.time()
num_batches = len(train_loader)  # constant across epochs; used only for logging

for epoch in range(EPOCHS):
  model.train()
  for idx, sample in enumerate(train_loader):
    # Drop gradients from the previous step before computing new ones.
    optimizer.zero_grad()

    input_ids = sample['input_ids'].to(device)
    labels = sample['labels'].to(device)
    outputs = model(input_ids=input_ids, labels=labels)

    # The first element of the model output is the loss to optimise.
    loss = outputs[0]
    loss.backward()
    optimizer.step()

    # Only every 500th batch (including batch 0) is logged.
    if idx % 500 != 0:
      continue
    progress = (epoch+1, EPOCHS, idx, num_batches, loss.item())
    print('Epoch: %04d/%04d || Batch: %04d/%04d || Loss: %.2f' % progress)

  # End-of-epoch evaluation: eval mode, no gradient tracking.
  model.eval()
  with torch.no_grad():
    train_acc = compute_accuracy(model, train_loader, device)
    valid_acc = compute_accuracy(model, valid_loader, device)
  accuracies = (train_acc.item(), valid_acc.item())
  print('Train Accuracy: %.2f%% || Valid Accuracy: %.2f%%' % accuracies)

  # Measured from `start_time`, so this is the cumulative time up to the end
  # of this epoch, not the duration of the epoch alone.
  epoch_elapsed_time = (time.time() - start_time) / 60
  print('Epoch Elapsed Time: %.2f min' % (epoch_elapsed_time))

total_training_time = (time.time() - start_time) / 60
print('Total Training Time: %.2f min' % (total_training_time))

Epoch: 0001/0005 || Batch: 0000/1053 || Loss: 69.94
Epoch: 0001/0005 || Batch: 0500/1053 || Loss: 27.37
Epoch: 0001/0005 || Batch: 1000/1053 || Loss: 12.73
Train Accuracy: 59.08% || Valid Accuracy: 55.62%
Epoch Elapsed Time: 2.96 min
Epoch: 0002/0005 || Batch: 0000/1053 || Loss: 14.11
Epoch: 0002/0005 || Batch: 0500/1053 || Loss: 3.01
Epoch: 0002/0005 || Batch: 1000/1053 || Loss: 0.75
Train Accuracy: 58.54% || Valid Accuracy: 54.13%
Epoch Elapsed Time: 5.93 min
Epoch: 0003/0005 || Batch: 0000/1053 || Loss: 1.27
Epoch: 0003/0005 || Batch: 0500/1053 || Loss: 0.71
Epoch: 0003/0005 || Batch: 1000/1053 || Loss: 0.71
Train Accuracy: 59.47% || Valid Accuracy: 56.88%
Epoch Elapsed Time: 8.91 min
Epoch: 0004/0005 || Batch: 0000/1053 || Loss: 0.68
Epoch: 0004/0005 || Batch: 0500/1053 || Loss: 0.65
Epoch: 0004/0005 || Batch: 1000/1053 || Loss: 0.72
Train Accuracy: 58.03% || Valid Accuracy: 54.01%
Epoch Elapsed Time: 11.89 min
Epoch: 0005/0005 || Batch: 0000/1053 || Loss: 0.66
Epoch: 0005/0005 || 