In [None]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from model import PnlpMixerSeqCls, PnlpMixerTokenCls
from mixer import FFFTrainFixed
from dataset import PnlpMixerDataModule
from run import PnlpMixerSeqClsTrainModule, PnlpMixerTokenClsTrainModule
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import time
import torch.nn as nn

In [None]:
# Experiment configuration; later cells read cfg.train.optimizer and cfg.model.
config_path = 'cfg/imdb_base.yml'
cfg = OmegaConf.load(config_path)

In [None]:
# Restore the training module from its checkpoint, passing the optimizer and
# model sections of the loaded config as extra constructor arguments.
checkpoint_path = 'to_cpp/ffft_256_4/model.ckpt'
orig_module = PnlpMixerSeqClsTrainModule.load_from_checkpoint(
    checkpoint_path,
    # CUDA devices are hidden through CUDA_VISIBLE_DEVICES in the imports cell,
    # so map every stored tensor to CPU explicitly; without this, a checkpoint
    # saved from a GPU run may fail to deserialize in this CPU-only process.
    map_location='cpu',
    optimizer_cfg=cfg.train.optimizer,
    model_cfg=cfg.model
)

In [None]:
import copy
from FFF import FFFInference
# Work on an independent copy so orig_module stays untouched for comparison.
new_module = copy.deepcopy(orig_module)

# First mixer block of the loaded model; the source of the layer widths below.
orig_first_mixer = orig_module.model.pnlp_mixer.mixer.mixers[0]

# NOTE(review): the FFFTrainFixed layers built here are freshly constructed and
# only reuse the widths of the checkpointed layers; no weights are copied from
# mlp_1_1 / mlp_1_2. Confirm that this is intended.
mlp_1_1 = orig_first_mixer.mlp_1
train_mlp_1_1 = FFFTrainFixed(
    mlp_1_1.input_width,
    mlp_1_1.output_width,
    4,
)
mlp_1_2 = orig_first_mixer.mlp_2
train_mlp_1_2 = FFFTrainFixed(
    mlp_1_2.input_width,
    mlp_1_2.output_width,
    4,
)

# Replace both MLPs of the copied model's first mixer block with FFFInference
# wrappers built from deep copies of the layers created above.
new_first_mixer = new_module.model.pnlp_mixer.mixer.mixers[0]
new_first_mixer.mlp_1 = FFFInference(copy.deepcopy(train_mlp_1_1))
new_first_mixer.mlp_2 = FFFInference(copy.deepcopy(train_mlp_1_2))

In [None]:
class ThreeLayerFFFT(nn.Module):
    """Three ``FFFTrainFixed`` layers applied one after another.

    Args:
        input_size: first constructor argument of the first layer (used
            elsewhere in this notebook as the layer's input width).
        hidden_size: second constructor argument (output width) of every layer,
            and therefore also the input width of the second and third layers.
        num_blocks: forwarded unchanged as the third constructor argument of
            each ``FFFTrainFixed``.
    """

    def __init__(self, input_size, hidden_size, num_blocks):
        super().__init__()
        self.layer1 = FFFTrainFixed(input_size, hidden_size, num_blocks)
        # Layers 2 and 3 receive the hidden_size-wide output of the previous
        # layer, so their input width must be hidden_size, not input_size.
        self.layer2 = FFFTrainFixed(hidden_size, hidden_size, num_blocks)
        self.layer3 = FFFTrainFixed(hidden_size, hidden_size, num_blocks)

    def forward(self, x):
        """Feed ``x`` through layer1, layer2 and layer3 in order; return the result."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
    
def _timed_forward(model, inputs):
    """Run one forward pass of ``model`` on ``inputs`` with gradients disabled.

    Returns:
        Tuple ``(output, elapsed_seconds)``.
    """
    with torch.no_grad():
        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring short durations; time.time() can jump and is coarser.
        t0 = time.perf_counter()
        output = model(inputs)
        elapsed = time.perf_counter() - t0
    return output, elapsed


three_layer_ffft = ThreeLayerFFFT(256, 128, 4)

# NOTE(review): the input has 1024 elements while the model is built with an
# input width of 256 -- confirm how FFFTrainFixed treats this shape.
in_t = torch.rand(1024)

three_layer_ffft.eval()

# Baseline: all three layers are FFFTrainFixed.
# (Czech label, roughly "original random three_layer_ffft".)
res, elapsed = _timed_forward(three_layer_ffft, in_t)
print("původní random three_layer_ffft", elapsed)

# Swap the middle layer for an FFFInference built from it, then time again.
single_layer = FFFInference(three_layer_ffft.layer2)
three_layer_ffft.layer2 = single_layer

# (Czech label, roughly "original ffft again".)
res, elapsed = _timed_forward(three_layer_ffft, in_t)
print("původní ffft znovu", elapsed)