In [None]:
import os
from torch import optim, nn, utils, Tensor
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from leap_sd import LEAPBuffer
from hflayers import HopfieldLayer

In [None]:
# Width of each random input vector; also handed to the model as its input size.
size_in = 128
# Width of each random target vector; also handed to the model as its output size.
size_out = 500_000
# Number of samples in the synthetic dataset; also used as the Hopfield layer's `quantity`.
amount_data = 1

In [None]:
def get_datamodule(batch_size: int, num_workers: int = 16):
    """Build a LightningDataModule serving one fixed set of random (x, y) pairs.

    The dataset dimensions come from the module-level ``size_in``, ``size_out``
    and ``amount_data`` settings.

    Args:
        batch_size: Batch size used by the training ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes (defaults to 16,
            the value that was previously hard-coded).

    Returns:
        An ``ImageWeights`` datamodule whose ``train_dataloader()`` always
        iterates over the same underlying dataset instance.
    """
    class FakeDataset(Dataset):
        """In-memory dataset of rows drawn uniformly from the range 0 to 1."""

        def __init__(self, amount):
            self.amount = amount
            self.x = torch.zeros(amount, size_in).uniform_(0, 1)
            self.y = torch.zeros(amount, size_out).uniform_(0, 1)

        def __getitem__(self, index):
            return self.x[index, ...], self.y[index, ...]

        def __len__(self):
            return self.amount

    class ImageWeights(pl.LightningDataModule):
        """Datamodule that exposes a single cached ``FakeDataset`` for training."""

        def __init__(self, batch_size: int, num_workers: int):
            super().__init__()
            self.num_workers = num_workers
            self.batch_size = batch_size
            # Created lazily on the first train_dataloader() call and then reused.
            self._train_dataset = None

        def prepare_data(self):
            pass

        def setup(self, stage):
            pass

        def train_dataloader(self):
            # The data is random, so rebuilding the dataset on every call would hand
            # different samples to each caller (e.g. the code that seeds the lookup
            # weights vs. the trainer). Build it once and keep serving the same one.
            if self._train_dataset is None:
                self._train_dataset = FakeDataset(amount_data)
            return DataLoader(
                self._train_dataset,
                num_workers=self.num_workers,
                batch_size=self.batch_size,
            )

        def teardown(self, stage):
            # clean up after fit or test
            # called on every process in DDP
            pass

    dm = ImageWeights(batch_size=batch_size, num_workers=num_workers)

    return dm

@torch.no_grad()
def set_lookup_weights(hopfield, loader):
    """Copy every input batch from ``loader`` into ``hopfield.lookup_weights``.

    The first element of each item yielded by ``loader`` is collected, all of
    them are concatenated along dim 0, a leading dimension of size 1 is added,
    and the result is written in place into ``hopfield.lookup_weights``.
    Runs with gradient tracking disabled.

    Args:
        hopfield: Object exposing a ``lookup_weights`` tensor that supports
            slice assignment.
        loader: Iterable yielding ``(x, _)`` pairs where ``x`` is a tensor.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    inputs = [x for x, _ in loader]
    if not inputs:
        # Without this guard an empty loader fails later with an opaque AttributeError.
        raise ValueError("set_lookup_weights: loader yielded no batches")
    # Concatenate once instead of re-copying the accumulated tensor on every batch.
    Z = torch.cat(inputs, dim=0).unsqueeze(0)
    print("set_lookup_weights > X", Z.shape)
    hopfield.lookup_weights[:] = Z

class TestModel(pl.LightningModule):
    """LightningModule wrapping a single ``HopfieldLayer`` trained with an L1 loss."""

    def __init__(self, size_in: int, size_out: int, hidden_size: int, learning_rate: float):
        super().__init__()
        hopfield_kwargs = dict(
            input_size=size_in,
            output_size=size_out,
            hidden_size=hidden_size,
            num_heads=4,
            quantity=amount_data,
            scaling=8.0,
            dropout=0,
            lookup_weights_as_separated=True,
            lookup_targets_as_trainable=False,
            # Optional switches to skip pre-processing of the layer input
            # (currently left at the layer's defaults):
            # normalize_stored_pattern=False,
            # normalize_stored_pattern_affine=False,
            # normalize_state_pattern=False,
            # normalize_state_pattern_affine=False,
            # normalize_pattern_projection=False,
            # normalize_pattern_projection_affine=False,
        )
        self.buf = HopfieldLayer(**hopfield_kwargs)
        self.criterion = nn.L1Loss()
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        """Run one batch through the layer, log loss and learning rate, return the loss."""
        inputs, targets = batch
        # Insert a dim at position 1 for the layer call and drop it again afterwards
        # (presumably the layer expects an extra axis there -- confirm against hflayers).
        expanded = inputs.unsqueeze(1)
        prediction = self.buf(expanded).squeeze(1)
        loss = self.criterion(prediction, targets)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        first_group = self.trainer.optimizers[0].param_groups[0]
        self.log("lr", first_group['lr'], prog_bar=True, on_step=True)
        return loss

    def configure_optimizers(self):
        """Adam optimizer plus a plateau scheduler driven by the logged ``train_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        scheduler_config = {
            "scheduler": plateau,
            "monitor": "train_loss",
            "interval": "epoch",
        }
        return [optimizer], [scheduler_config]

def train():
    """Build the model and datamodule, seed the lookup weights, and fit on one GPU.

    Enables autograd anomaly detection and 'medium' float32 matmul precision,
    constructs ``TestModel`` with a hidden size of 20 and a learning rate of
    1e-3, copies the training inputs into the model's Hopfield lookup weights,
    and runs ``trainer.fit`` for up to 100 epochs.
    """
    torch.autograd.set_detect_anomaly(True)
    torch.set_float32_matmul_precision('medium')

    hidden_size = 20
    model = TestModel(size_in, size_out, hidden_size, 1e-3)
    dm = get_datamodule(10)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    set_lookup_weights(model.buf, dm.train_dataloader())
    # `auto_lr_find` is deliberately not passed: it only takes effect through
    # `trainer.tune(...)`, which is not called here, and the argument no longer
    # exists in Lightning 2.x (passing it raises a TypeError there).
    trainer = pl.Trainer(
        devices=1,
        accelerator="gpu",
        callbacks=[lr_monitor],
        log_every_n_steps=2,
        max_epochs=100,
    )
    # To log to Weights & Biases instead, import WandbLogger from
    # pytorch_lightning.loggers and pass
    # logger=WandbLogger(project="LEAP_Lora_BufferTest") to the Trainer
    # (the earlier experiment used max_epochs=1000 with that setup).
    trainer.fit(model, dm)

In [None]:
# Kick off the training run defined by train().
train()