In [1]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from model import PnlpMixerSeqCls, PnlpMixerTokenCls
from mixer import FFFTrainFixed
from dataset import PnlpMixerDataModule
from run import PnlpMixerSeqClsTrainModule, PnlpMixerTokenClsTrainModule
import os
# Hide every CUDA device from this process so the notebook runs on CPU.
# NOTE(review): this is set after `import torch`; it should still take effect
# as long as CUDA has not been initialised yet, but moving it above the torch
# import would make that unconditional -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import time
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check: report whether torch can see a CUDA device in this session.
torch.cuda.is_available()

False

In [3]:
# Load the experiment configuration (vocab / train / model sections are read below).
cfg = OmegaConf.load('cfg/imdb_base.yml')

def get_module_cls(type: str):
    """Return the training-module class that matches a dataset type.

    Args:
        type: Dataset type name. ``'mtop'`` selects the token-classification
            module; ``'matis'`` and ``'imdb'`` select the
            sequence-classification module.

    Returns:
        The training-module class (not an instance) for ``type``.

    Raises:
        ValueError: If ``type`` is not one of the supported dataset types.
            Previously the function fell through and returned ``None``,
            which only failed later when the result was used.
    """
    if type == 'mtop':
        return PnlpMixerTokenClsTrainModule
    if type in ('matis', 'imdb'):
        return PnlpMixerSeqClsTrainModule
    raise ValueError(
        f"Unsupported dataset type {type!r}; expected 'mtop', 'matis' or 'imdb'"
    )

# Pick the training-module class for the dataset type named in the config.
module_cls = get_module_cls(cfg.train.dataset_type)

In [4]:
# Restore the module from the saved checkpoint, handing the optimizer and
# model sections of the config to the loader.
checkpoint_path = 'to_cpp/ffft_256_4/model.ckpt'
orig_module = module_cls.load_from_checkpoint(
    checkpoint_path, optimizer_cfg=cfg.train.optimizer, model_cfg=cfg.model
)

{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}


In [5]:
# Build the data module from the vocab / train / projection config sections
# and prepare its test split.
data_module = PnlpMixerDataModule(cfg.vocab, cfg.train, cfg.model.projection)
data_module.setup('test')
test_dataloader = data_module.test_dataloader()

# Put the loaded module into evaluation mode before any inference.
orig_module.eval()

# Take the first test batch and move each entry onto the device that holds
# the module's parameters.
batch = next(iter(test_dataloader))
device = next(orig_module.parameters()).device
batch = dict((key, value.to(device)) for key, value in batch.items())

In [7]:
# import pickle
# with open("batch.pkl", "wb") as f:
#     pickle.dump(batch, f)

In [8]:
# import pickle
# with open("batch.pkl", "rb") as f:
#     batch = pickle.load(f)

In [6]:
# Time one forward pass over the batch with autograd disabled.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
with torch.no_grad():
    start = time.perf_counter()
    logits = orig_module.model(batch['inputs'])
    end = time.perf_counter()
    predictions = torch.argmax(logits, dim=1)

# Flip to True to print the predicted vs. actual label of every sample.
# (The loop used to run unconditionally while all of its prints were
# commented out, so it only computed values that were thrown away.)
show_samples = False
if show_samples:
    for i, (pred, target) in enumerate(zip(predictions, batch['targets'])):
        predicted_label = cfg.train.labels[pred.item()]
        true_label = cfg.train.labels[target.item()]
        print(f"Sample {i}:")
        print(f"  Predicted: {predicted_label}")
        print(f"  Actual: {true_label}")
        print()

# Fraction of samples in this batch whose argmax prediction equals the target.
accuracy = (predictions == batch['targets']).float().mean()
print(f"Batch Accuracy: {accuracy.item():.4f}")
print(f"Time: {end-start}")

TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
Batch Accuracy: 0.8828
Time: 2.0069186687469482


In [11]:
# import torch
# import torch.nn as nn
# from FFF import FFFInference
# from mixer import FFFTrainFixed

# def replace_fff_layers(module):
#     for name, child in module.named_children():
#         if isinstance(child, FFFTrainFixed):
#             new_layer = FFFInference(child)
#             setattr(module, name, new_layer)
#         elif isinstance(child, nn.Module):
#             replace_fff_layers(child)

In [12]:
import torch
import torch.nn as nn
from FFF import FFFInference
from mixer import FFFTrainFixed

def replace_fff_layers(module):
    """Recursively swap every ``FFFTrainFixed`` child for a new ``FFFTrainFixed``.

    Walks the direct children of ``module``. A child that is an
    ``FFFTrainFixed`` is replaced, under the same attribute name, by a newly
    constructed ``FFFTrainFixed`` with the child's ``input_width`` and
    ``output_width`` and a fixed third argument of ``4``. Any other child is
    descended into. The module tree is modified in place; nothing is returned.

    NOTE(review): the replacement is freshly constructed and nothing here
    copies parameters from the layer it replaces -- confirm that discarding
    the existing weights is intended.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, FFFTrainFixed):
            # Not an FFF layer itself: look for FFF layers nested inside it.
            if isinstance(submodule, nn.Module):
                replace_fff_layers(submodule)
            continue
        replacement = FFFTrainFixed(submodule.input_width, submodule.output_width, 4)
        setattr(module, child_name, replacement)

In [13]:
# The four FFF blocks of the loaded model: mlp_1 / mlp_2 of each of the two
# mixer layers.
_mixers = orig_module.model.pnlp_mixer.mixer.mixers
mlp_1_1, mlp_1_2 = _mixers[0].mlp_1, _mixers[0].mlp_2
mlp_2_1, mlp_2_2 = _mixers[1].mlp_1, _mixers[1].mlp_2

# Wrap each original block in FFFInference.
# NOTE(review): these wrappers are built but not installed into the model in
# this cell.
new_mlp_1_1, new_mlp_1_2, new_mlp_2_1, new_mlp_2_2 = (
    FFFInference(block) for block in (mlp_1_1, mlp_1_2, mlp_2_1, mlp_2_2)
)

# Freshly constructed FFFTrainFixed layers with the same input/output widths.
# No weights are copied from the original blocks here.
train_mlp_1_1, train_mlp_1_2, train_mlp_2_1, train_mlp_2_2 = (
    FFFTrainFixed(block.input_width, block.output_width, 4)
    for block in (mlp_1_1, mlp_1_2, mlp_2_1, mlp_2_2)
)

# Install the fresh layers in place of the originals.
_mixers[0].mlp_1 = train_mlp_1_1
_mixers[0].mlp_2 = train_mlp_1_2
_mixers[1].mlp_1 = train_mlp_2_1
_mixers[1].mlp_2 = train_mlp_2_2



In [14]:
# The layers assigned into the model in the previous cell were constructed
# after the earlier orig_module.eval() call, and a new nn.Module starts in
# training mode. Re-apply eval() so the whole model runs in inference mode.
orig_module.eval()

# Time one forward pass over the batch with autograd disabled.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
with torch.no_grad():
    start = time.perf_counter()
    logits = orig_module.model(batch['inputs'])
    end = time.perf_counter()
    predictions = torch.argmax(logits, dim=1)

# Flip to True to print the predicted vs. actual label of every sample.
# (The loop used to run unconditionally while all of its prints were
# commented out, so it only computed values that were thrown away.)
show_samples = False
if show_samples:
    for i, (pred, target) in enumerate(zip(predictions, batch['targets'])):
        predicted_label = cfg.train.labels[pred.item()]
        true_label = cfg.train.labels[target.item()]
        print(f"Sample {i}:")
        print(f"  Predicted: {predicted_label}")
        print(f"  Actual: {true_label}")
        print()

# Fraction of samples in this batch whose argmax prediction equals the target.
accuracy = (predictions == batch['targets']).float().mean()
print(f"Batch Accuracy: {accuracy.item():.4f}")
print(f"Time: {end-start}")

TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
Batch Accuracy: 0.1484
Time: 37.884618043899536


In [8]:
from FFF import FFFInference
from mixer import MixerLayer
import torch.nn as nn
from omegaconf import OmegaConf

class PnlpMixerSeqClsTrainModule(nn.Module):
    """Thin ``nn.Module`` container that exposes a model as ``self.model``.

    NOTE(review): this definition reuses the name of the
    ``PnlpMixerSeqClsTrainModule`` imported from ``run`` in the first cell, so
    once this cell has run the name refers to this wrapper (which defines no
    ``forward``) -- confirm the shadowing is intended.
    """

    def __init__(self, model):
        """Register ``model`` as the ``model`` submodule."""
        nn.Module.__init__(self)
        self.model = model

def create_pnlp_mixer_seq_cls(orig_module):
    """Build a new train-module wrapper whose mixer MLPs are ``FFFInference``.

    The bottleneck, mixer and classifier-head configs are derived from
    attributes of ``orig_module.model`` and used to construct a fresh
    ``PnlpMixerSeqCls``. Its bottleneck, layer norms and ``seq_cls`` head are
    then replaced by the very same module objects held by ``orig_module``
    (they are shared, not copied), and each mixer layer's ``mlp_1`` / ``mlp_2``
    is replaced by ``FFFInference`` built from the original block.

    Args:
        orig_module: Object whose ``.model`` exposes ``pnlp_mixer.bottleneck``,
            ``pnlp_mixer.mixer.mixers`` and ``seq_cls``.

    Returns:
        A ``PnlpMixerSeqClsTrainModule`` wrapping the newly assembled model.
    """
    source_model = orig_module.model
    source_bottleneck = source_model.pnlp_mixer.bottleneck
    source_mixers = source_model.pnlp_mixer.mixer.mixers
    source_head = source_model.seq_cls
    first_mlp_1 = source_mixers[0].mlp_1
    first_mlp_2 = source_mixers[0].mlp_2

    bottleneck_cfg = OmegaConf.create({
        "window_size": (source_bottleneck.in_features // first_mlp_1.input_width - 1) // 2,
        "feature_size": first_mlp_1.input_width,
        "hidden_dim": source_bottleneck.out_features,
    })
    mixer_cfg = OmegaConf.create({
        "num_mixers": len(source_mixers),
        "max_seq_len": first_mlp_1.input_width,
        "hidden_dim": first_mlp_2.input_width,
        "mlp_hidden_dim": first_mlp_2.output_width,
    })
    seq_cls_cfg = OmegaConf.create({
        "hidden_dim": source_head.feature_proj.in_features,
        "proj_dim": source_head.feature_proj.out_features,
        "num_classes": source_head.cls_proj.out_features,
    })

    new_model = PnlpMixerSeqCls(
        bottleneck_cfg=bottleneck_cfg,
        mixer_cfg=mixer_cfg,
        seq_cls_cfg=seq_cls_cfg,
    )

    # Reuse the original bottleneck object.
    new_model.pnlp_mixer.bottleneck = source_bottleneck

    def _convert_layer(orig_layer):
        # One new MixerLayer per original layer: the layer norms are reused,
        # the two MLPs are wrapped in FFFInference.
        converted = MixerLayer(
            max_seq_len=mixer_cfg.max_seq_len,
            hidden_dim=mixer_cfg.hidden_dim,
            channel_hidden_dim=mixer_cfg.mlp_hidden_dim,
            seq_hidden_dim=mixer_cfg.max_seq_len,
        )
        converted.layer_norm_1 = orig_layer.layer_norm_1
        converted.mlp_1 = FFFInference(orig_layer.mlp_1)
        converted.layer_norm_2 = orig_layer.layer_norm_2
        converted.mlp_2 = FFFInference(orig_layer.mlp_2)
        return converted

    converted_layers = [_convert_layer(layer) for layer in source_mixers]
    new_model.pnlp_mixer.mixer.mixers = nn.Sequential(*converted_layers)

    # Reuse the original classification head object.
    new_model.seq_cls = source_head

    return PnlpMixerSeqClsTrainModule(new_model)

# Usage: build the FFFInference-based module from the current `orig_module`.
new_module = create_pnlp_mixer_seq_cls(orig_module)

{'num_mixers': 2, 'max_seq_len': 1024, 'hidden_dim': 256, 'mlp_hidden_dim': 256}


In [23]:
import pickle
# Reload a batch from batch.pkl in the working directory.
# NOTE(review): the only code in this notebook that writes batch.pkl is
# commented out, so a fresh Restart & Run All depends on a file left over
# from an earlier session -- confirm.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load a batch.pkl that you created yourself.
with open("batch.pkl", "rb") as f:
    batch = pickle.load(f)

In [10]:
# Make sure the module is in evaluation mode before running inference.
orig_module.eval()

# Time one forward pass over the batch with autograd disabled.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
with torch.no_grad():
    start = time.perf_counter()
    logits = orig_module.model(batch['inputs'])
    end = time.perf_counter()
    predictions = torch.argmax(logits, dim=1)

# Flip to True to print the predicted vs. actual label of every sample.
# (The loop used to run unconditionally while all of its prints were
# commented out, so it only computed values that were thrown away.)
show_samples = False
if show_samples:
    for i, (pred, target) in enumerate(zip(predictions, batch['targets'])):
        predicted_label = cfg.train.labels[pred.item()]
        true_label = cfg.train.labels[target.item()]
        print(f"Sample {i}:")
        print(f"  Predicted: {predicted_label}")
        print(f"  Actual: {true_label}")
        print()

# Fraction of samples in this batch whose argmax prediction equals the target.
accuracy = (predictions == batch['targets']).float().mean()
print(f"Batch Accuracy: {accuracy.item():.4f}")
print(f"Time: {end-start}")

TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
TRAIN
torch.Size([256, 256, 1024])
AFTER RESHAPE
torch.Size([65536, 1024])
NEW_LOGITS
torch.Size([65536, 1024])
TRAIN
torch.Size([256, 1024, 256])
AFTER RESHAPE
torch.Size([262144, 256])
NEW_LOGITS
torch.Size([262144, 256])
Batch Accuracy: 0.8828
Time: 27.705037117004395
