In [1]:
# You need dataset (see dataset_download.txt)
# Create conda env for Python 3.7
# Run pip install -r requirements.txt
# Run pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# Run:    cd to_cpp
# Run     pip install .
# Might have to do conda install mkl mkl-include

In [1]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from model import PnlpMixerSeqCls, PnlpMixerTokenCls
from mixer import FFFTrainFixed
from dataset import PnlpMixerDataModule
from run import PnlpMixerSeqClsTrainModule, PnlpMixerTokenClsTrainModule
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import time
import torch.nn as nn
from FFF import FFFInference
from mixer import MixerLayer
from omegaconf import OmegaConf

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

False

In [3]:
cfg = OmegaConf.load('cfg/imdb_base.yml')

In [4]:
checkpoint_path = 'to_cpp/ffft_256_4/model.ckpt'
orig_module = PnlpMixerSeqClsTrainModule.load_from_checkpoint(
    checkpoint_path,
    optimizer_cfg=cfg.train.optimizer,
    model_cfg=cfg.model
)
orig_module.to("cpu")

{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}


PnlpMixerSeqClsTrainModule(
  (model): PnlpMixerSeqCls(
    (pnlp_mixer): PnlpMixer(
      (bottleneck): Linear(in_features=3072, out_features=256, bias=True)
      (mixer): Mixer(
        (mixers): Sequential(
          (0): MixerLayer(
            (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (mlp_1): FFFTrainFixed(
              (linear_in): Linear(in_features=1024, out_features=31, bias=False)
              (linear_out): Linear(in_features=31, out_features=1024, bias=False)
              (activation): GELU()
            )
            (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (mlp_2): FFFTrainFixed(
              (linear_in): Linear(in_features=256, out_features=31, bias=False)
              (linear_out): Linear(in_features=31, out_features=256, bias=False)
              (activation): GELU()
            )
          )
          (1): MixerLayer(
            (layer_norm_1): LayerNorm((256,), eps=1e-05, elementw

In [5]:
# Create a PnlpMixerDataModule instance
cfg.train.test_batch_size = 32
data_module = PnlpMixerDataModule(cfg.vocab, cfg.train, cfg.model.projection)

# Set up the data module for testing
data_module.setup('test')

# Get the test dataloader
test_dataloader = data_module.test_dataloader()

# Get the first batch from the test dataloader
batch_iterator = iter(test_dataloader)
batch = next(batch_iterator)

# Get the device of the original module
device = next(orig_module.parameters()).device
# Move each item in the batch to the device
new_batch = {}
for key, value in batch.items():
    new_batch[key] = value.to(device)

batch = new_batch



In [6]:
batch["inputs"].shape

torch.Size([32, 1024, 3072])

In [7]:
orig_module.eval()
print()




In [8]:
%%timeit
with torch.no_grad():
    logits = orig_module.model(batch['inputs'])

3.67 s ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
def create_pnlp_mixer_seq_cls(orig_module):

    bottleneck_cfg = OmegaConf.create({
        "in_features": orig_module.model.pnlp_mixer.bottleneck.in_features,
        "out_features": orig_module.model.pnlp_mixer.bottleneck.out_features,
        "window_size": (orig_module.model.pnlp_mixer.bottleneck.in_features // orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_1.input_width - 1) // 2,
        "feature_size": orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_1.input_width,
        "hidden_dim": orig_module.model.pnlp_mixer.bottleneck.out_features
    })

    mixer_cfg = OmegaConf.create({
        "num_mixers": len(orig_module.model.pnlp_mixer.mixer.mixers),
        "max_seq_len": orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_1.input_width,
        "hidden_dim": orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_2.input_width,
        "mlp_hidden_dim": orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_2.output_width
    })

    sequence_cls_cfg = OmegaConf.create({
        "hidden_dim": orig_module.model.seq_cls.feature_proj.in_features,
        "proj_dim": orig_module.model.seq_cls.feature_proj.out_features,
        "num_classes": orig_module.model.seq_cls.cls_proj.out_features
    })

    new_model = PnlpMixerSeqCls(bottleneck_cfg, mixer_cfg, sequence_cls_cfg)

    new_model.load_state_dict(orig_module.model.state_dict())

    for layer in new_model.pnlp_mixer.mixer.mixers:
        layer.mlp_1 = FFFInference(layer.mlp_1)
        layer.mlp_2 = FFFInference(layer.mlp_2)

    new_module = PnlpMixerSeqClsTrainModule(
        optimizer_cfg=orig_module.optimizer_cfg,
        model_cfg=OmegaConf.create({
            "bottleneck": bottleneck_cfg,
            "mixer": mixer_cfg,
            "sequence_cls": sequence_cls_cfg
        })
    )
    new_module.model = new_model
    return new_module

new_module = create_pnlp_mixer_seq_cls(orig_module)

{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}
{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}


In [12]:
new_module.eval()
print()




In [14]:
%%timeit
with torch.no_grad():
    logits = new_module.model(batch['inputs'].to("cpu"))

3.46 s ± 80.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
