In [1]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from model import PnlpMixerSeqCls, PnlpMixerTokenCls
from mixer import FFFTrainFixed
from dataset import PnlpMixerDataModule
from run import PnlpMixerSeqClsTrainModule, PnlpMixerTokenClsTrainModule
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import time
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

False

In [3]:
cfg = OmegaConf.load('cfg/imdb_base.yml')

def get_module_cls(type: str): 
    if type == 'mtop': 
        return PnlpMixerTokenClsTrainModule
    if type == 'matis' or type == 'imdb': 
        return PnlpMixerSeqClsTrainModule

module_cls = get_module_cls(cfg.train.dataset_type)

In [4]:
checkpoint_path = 'to_cpp/ffft_256_4/model.ckpt'
orig_module = module_cls.load_from_checkpoint(
    checkpoint_path,
    optimizer_cfg=cfg.train.optimizer,
    model_cfg=cfg.model
)

{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}


In [5]:
orig_module

PnlpMixerSeqClsTrainModule(
  (model): PnlpMixerSeqCls(
    (pnlp_mixer): PnlpMixer(
      (bottleneck): Linear(in_features=3072, out_features=256, bias=True)
      (mixer): Mixer(
        (mixers): Sequential(
          (0): MixerLayer(
            (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (mlp_1): FFFTrainFixed(
              (linear_in): Linear(in_features=1024, out_features=31, bias=False)
              (linear_out): Linear(in_features=31, out_features=1024, bias=False)
              (activation): GELU()
            )
            (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (mlp_2): FFFTrainFixed(
              (linear_in): Linear(in_features=256, out_features=31, bias=False)
              (linear_out): Linear(in_features=31, out_features=256, bias=False)
              (activation): GELU()
            )
          )
          (1): MixerLayer(
            (layer_norm_1): LayerNorm((256,), eps=1e-05, elementw

In [6]:
data_module = PnlpMixerDataModule(cfg.vocab, cfg.train, cfg.model.projection)
data_module.setup('test')

orig_module.eval()

test_dataloader = data_module.test_dataloader()

batch = next(iter(test_dataloader))

device = next(orig_module.parameters()).device
batch = {k: v.to(device) for k, v in batch.items()}

In [7]:
# import pickle
# with open("batch.pkl", "wb") as f:
#     pickle.dump(batch, f)

In [8]:
# import pickle
# with open("batch.pkl", "rb") as f:
#     batch = pickle.load(f)

In [9]:
batch["inputs"].shape

torch.Size([256, 1024, 3072])

In [10]:
with torch.no_grad():
    start = time.time()
    logits = orig_module.model(batch['inputs'])
    end = time.time()
    predictions = torch.argmax(logits, dim=1)

for i, (pred, target) in enumerate(zip(predictions, batch['targets'])):
    predicted_label = cfg.train.labels[pred.item()]
    true_label = cfg.train.labels[target.item()]
    # print(f"Sample {i}:")
    # print(f"  Predicted: {predicted_label}")
    # print(f"  Actual: {true_label}")
    # print()

accuracy = (predictions == batch['targets']).float().mean()
print(f"Batch Accuracy: {accuracy.item():.4f}")
print(f"Time: {end-start}")

TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
Batch Accuracy: 0.8828
Time: 2.285902500152588


In [11]:
# import torch
# import torch.nn as nn
# from FFF import FFFInference
# from mixer import FFFTrainFixed

# def replace_fff_layers(module):
#     for name, child in module.named_children():
#         if isinstance(child, FFFTrainFixed):
#             new_layer = FFFInference(child)
#             setattr(module, name, new_layer)
#         elif isinstance(child, nn.Module):
#             replace_fff_layers(child)

In [12]:
import torch
import torch.nn as nn
from FFF import FFFInference
from mixer import FFFTrainFixed

def replace_fff_layers(module):
    for name, child in module.named_children():
        if isinstance(child, FFFTrainFixed):
            fixed = FFFTrainFixed(child.input_width, child.output_width, 4,)
            setattr(module, name, fixed)
        elif isinstance(child, nn.Module):
            replace_fff_layers(child)

In [13]:
mlp_1_1 = orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_1
mlp_1_2 = orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_2
mlp_2_1 = orig_module.model.pnlp_mixer.mixer.mixers[1].mlp_1
mlp_2_2 = orig_module.model.pnlp_mixer.mixer.mixers[1].mlp_2

new_mlp_1_1 = FFFInference(mlp_1_1)
new_mlp_1_2 = FFFInference(mlp_1_2)
new_mlp_2_1 = FFFInference(mlp_2_1)
new_mlp_2_2 = FFFInference(mlp_2_2)

train_mlp_1_1 = FFFTrainFixed(mlp_1_1.input_width, mlp_1_1.output_width, 4,)
train_mlp_1_2 = FFFTrainFixed(mlp_1_2.input_width, mlp_1_2.output_width, 4,)
train_mlp_2_1 = FFFTrainFixed(mlp_2_1.input_width, mlp_2_1.output_width, 4,)
train_mlp_2_2 = FFFTrainFixed(mlp_2_2.input_width, mlp_2_2.output_width, 4,)

orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_1 = train_mlp_1_1
orig_module.model.pnlp_mixer.mixer.mixers[0].mlp_2 = train_mlp_1_2
orig_module.model.pnlp_mixer.mixer.mixers[1].mlp_1 = train_mlp_2_1
orig_module.model.pnlp_mixer.mixer.mixers[1].mlp_2 = train_mlp_2_2


In [14]:
with torch.no_grad():
    start = time.time()
    logits = orig_module.model(batch['inputs'])
    end = time.time()
    predictions = torch.argmax(logits, dim=1)

for i, (pred, target) in enumerate(zip(predictions, batch['targets'])):
    predicted_label = cfg.train.labels[pred.item()]
    true_label = cfg.train.labels[target.item()]
    # print(f"Sample {i}:")
    # print(f"  Predicted: {predicted_label}")
    # print(f"  Actual: {true_label}")
    # print()

accuracy = (predictions == batch['targets']).float().mean()
print(f"Batch Accuracy: {accuracy.item():.4f}")
print(f"Time: {end-start}")

TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
Batch Accuracy: 0.1484
Time: 37.884618043899536
