In [1]:
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
import torch
from iterativenn.utils.DataModules import MNISTRepeatedSequenceDataModule
from iterativenn.nn_modules.Sequential2D import Sequential2D
from iterativenn.lit_modules.MNISTMLP import MNISTModel
from iterativenn.lit_modules import IteratedModel
from iterativenn.utils.logger_factory import LoggerFacade
import logging
import warnings

In [2]:
global_max_epochs = 3
global_optimizer = 'SGD'

In [3]:
def factory_run(cfg, name):
    log_name = name
    logger = TensorBoardLogger("outputs", name=log_name, version='main')
    logger = LoggerFacade(logger, 'tensorboard', 'info')
    sequential2D = Sequential2D.from_config(cfg["sequential2D"])
    callbacks = IteratedModel.ConfigCallbacks(cfg["callbacks"])
    model = IteratedModel.IteratedModel(sequential2D,
                                        callbacks,
                                        optimizer=global_optimizer)
    data_module = MNISTRepeatedSequenceDataModule(min_copies=2, max_copies=2, seed=1234)
    # This can be used to remove all of the extra outut from the training
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
    # Initialize a trainer
    trainer = Trainer(
        accelerator='auto',
        devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
        max_epochs=global_max_epochs,
        log_every_n_steps=1,
        enable_progress_bar=False,
        logger=logger,
    )

    with torch.no_grad():
        data_module.prepare_data()
        data_module.setup('fit')
        batch = next(iter(data_module.train_dataloader()))
        loss = model.training_step(batch, 0, do_logging=False)
        print(f"loss before training: {loss}")

    with warnings.catch_warnings():
        # There are warning that I dont' care about at this moment and are not relevant to the example.
        warnings.simplefilter("ignore")
        trainer.fit(model, data_module)

    with torch.no_grad():
        data_module.prepare_data()
        data_module.setup('fit')
        batch = next(iter(data_module.train_dataloader()))
        loss = model.training_step(batch, 0, do_logging=False)
        print(f"loss after training: {loss}")

    return sequential2D

In [4]:
cfg = {
    "sequential2D": {
        "in_features_list": [28*28, 110, 10], 
        "out_features_list": [28*28, 110, 10], 

        "block_types": [
            [None, 'Linear', None],
            [None, None, 'Linear'],
            [None, None, None],
        ],
        "block_kwargs": [
            [None, None, None],
            [None, None, None],
            [None, None, None],
        ]
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+110, 28*28+110+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+110+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,            
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+110, 28*28+110+10)
        },
        "optimization": {
            "func": "customLR",
            "block_lr": [
                [0.0, 0.02, 0.0],
                [0.0, 0.0, 0.02],
                [0.0, 0.0, 0.0],
        ],
        },
    }
}

bigrat0 = factory_run(cfg, "factory_MLP_tmp")

loss before training: 1.1714972257614136
loss after training: 1.1306164264678955


In [5]:
rat1 = factory_run(cfg, "factory_MLP")

loss before training: 1.1547160148620605
loss after training: 1.11774742603302


In [6]:
torch.save(rat1, "previous_model.pt")
rat1 = torch.load("previous_model.pt")
rat1_copy = torch.load("previous_model.pt")

In [7]:
default_block_kwargs = {'block_type':'W', 'initialization_type':'G=0.0,0.0', 'trainable':True, 'bias':False}


cfg = {
    "sequential2D": {
        "in_features_list": [28*28+100+10, 10], 
        "out_features_list": [28*28+100+10, 10], 
        "block_types": [
            ['Module', 'MaskedLinear.from_description'],
            ['MaskedLinear.from_description', 'MaskedLinear.from_description'],
        ],

        "block_kwargs": [
            [{'module':rat1}, default_block_kwargs],
            [default_block_kwargs, default_block_kwargs],
        ],
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+100, 28*28+100+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+100+10+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+100, 28*28+100+10)
        },
    }
}

In [8]:
rat2 = factory_run(cfg, "grow_MLP_rat2")

AssertionError: 

In [None]:
default_block_kwargs = {'block_type':'W', 'initialization_type':'G=0.0,0.0', 'trainable':True, 'bias':False}

cfg = {
    "sequential2D": {
        "in_features_list": [28*28, 100, 10, 10],
        "out_features_list": [28*28, 100, 10, 10], 
        
        "block_types": [
            [None, 'Module', 'Linear', 'MaskedLinear.from_description'],
            [None, None, 'Module', 'MaskedLinear.from_description'],
            [None, None, None, 'MaskedLinear.from_description'],
            ['MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description']
        ],
        "block_kwargs": [
            [None, {'module': rat1_copy.blocks[f'(0, 1)']}, default_block_kwargs,  default_block_kwargs],
            [None, None, {'module': rat1_copy.blocks[f'(1, 2)']}, default_block_kwargs],
            [None, None, None, default_block_kwargs],
            [default_block_kwargs, default_block_kwargs, default_block_kwargs, default_block_kwargs]
        ]
        
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+100, 28*28+100+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+100+10+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+100, 28*28+100+10)
        },
    }

}

In [None]:
rat3 = factory_run(cfg, "grow_MLP_rat3")

In [None]:
rat1_1 = torch.load("previous_model.pt")

default_block_kwargs = {'block_type':'W', 'initialization_type':'G=0.0,0.1', 'trainable':True, 'bias':False}

cfg = {
    "sequential2D": {
        "in_features_list": [28*28, 100, 10, 10],
        "out_features_list": [28*28, 100, 10, 10], 
        
        "block_types": [
            [None, 'Module', None, 'MaskedLinear.from_description'],
            [None, None, 'Module', 'MaskedLinear.from_description'],
            [None, None, None, 'MaskedLinear.from_description'],
            ['MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description']
        ],
        "block_kwargs": [
            [None, {'module': rat1_1.blocks[f'(0, 1)']}, default_block_kwargs,  default_block_kwargs],
            [None, None, {'module': rat1_1.blocks[f'(1, 2)']}, default_block_kwargs],
            [None, None, None, default_block_kwargs],
            [default_block_kwargs, default_block_kwargs, default_block_kwargs, default_block_kwargs]
        ]
        
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+100, 28*28+100+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+100+10+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+100, 28*28+100+10)
        },
    }

}

rat3_1 = factory_run(cfg, "grow_MLP_rat3_1")

In [None]:
rat2_1 = torch.load("previous_model.pt")
default_block_kwargs = {'block_type':'W', 'initialization_type':'G=0.0,0.1', 'trainable':True, 'bias':False}


cfg = {
    "sequential2D": {
        "in_features_list": [28*28+100+10, 10], 
        "out_features_list": [28*28+100+10, 10], 
        "block_types": [
            ['Module', 'MaskedLinear.from_description'],
            ['MaskedLinear.from_description', 'MaskedLinear.from_description'],
        ],

        "block_kwargs": [
            [{'module':rat2_1}, default_block_kwargs],
            [default_block_kwargs, default_block_kwargs],
        ],
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+100, 28*28+100+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+100+10+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+100, 28*28+100+10)
        },
    }
}
rat2_1 = factory_run(cfg, "grow_MLP_rat2_1")

In [None]:
rat2_1_skip = torch.load("previous_model.pt")

default_block_kwargs = {'block_type':'W', 'initialization_type':'G=0.0,0.1', 'trainable':True, 'bias':False}

cfg = {
    "sequential2D": {
        "in_features_list": [28*28, 100, 10, 10],
        "out_features_list": [28*28, 100, 10, 10], 
        
        "block_types": [
            [None, 'Module', 'MaskedLinear.from_description', 'MaskedLinear.from_description'],
            [None, None, 'Module', 'MaskedLinear.from_description'],
            [None, None, None, 'MaskedLinear.from_description'],
            ['MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description', 'MaskedLinear.from_description']
        ],
        "block_kwargs": [
            [None, {'module': rat2_1_skip.blocks[f'(0, 1)']}, default_block_kwargs,  default_block_kwargs],
            [None, None, {'module': rat2_1_skip.blocks[f'(1, 2)']}, default_block_kwargs],
            [None, None, None, default_block_kwargs],
            [default_block_kwargs, default_block_kwargs, default_block_kwargs, default_block_kwargs]
        ]
        
    },
    "callbacks": {
        "loss": {
            "func": "CrossEntropyLoss",
            "idx_list" : range(28*28+100, 28*28+100+10),
            "sequence_position": 'last',
        },
        "initialization": {
            "func": "zeros",
            "size": 28*28+100+10+10,
        },
        "data": {
            "func": "insert",
            "idx_list": range(28*28),
            "flatten_input": True,
        },
        "output": {
            "func": "max",
            "idx_list" : range(28*28+100, 28*28+100+10)
        },
    }

}


rat2_1_1 = factory_run(cfg, "grow_MLP_rat2_1_skip")