In [1]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# SimpleConv brain-encoder configuration. The settings are spelled out as one
# mapping, grouped by concern, and unpacked into SimpleConvConfig as keywords.
model_config = SimpleConvConfig(
    **{
        # --- Conditioning: condition name -> list of possible values ---
        "mel_normalization": False,
        "conditions": {
            "study": [],
            "subject": [],
        },
        # --- Channel sizes ---
        "in_channels": 208,
        "out_channels": 128,
        "hidden_dim": 384,
        "dropout": 0.2,
        "initial_batch_norm": True,
        # --- Sensor layout ---
        "layout_dim": 2,
        "layout_proj": True,
        "layout_scaling": "minmax",
        # --- Merger with spatial attention (switched off here) ---
        "merger": False,
        "merger_emb_type": None,
        "merger_emb_dim": 0,
        "merger_channels": 0,
        "merger_dropout": False,
        "merger_conditional": None,
        # --- Initial layers ---
        "initial_linear": 384,
        "initial_depth": 1,
        # --- Conditional layers (switched off here) ---
        "conditional_layers": False,
        "conditional_layers_dim": None,  # "input" or "hidden_dim" when enabled
        # --- Overall structure of the convolutional stack ---
        "depth": 6,
        "kernel_size": 3,
        "growth": 1.0,
        "dilation_growth": 2,
        "dilation_period": 5,
        "glu": 1,
        "conv_dropout": 0.2,
        "dropout_input": 0.2,
        "batch_norm": True,
        "half": True,
        "cnn_pos_encoding": False,
        # --- Quantizer (switched off here) ---
        "quantizer": False,
        "num_codebooks": 0,
        "codebook_size": 0,
        "quantizer_commitment": 0,
        "quantizer_temp_init": 0,
        "quantizer_temp_min": 0,
        "quantizer_temp_decay": 0,
        # --- Transformer encoder (zero layers / heads) ---
        "transformer_input": None,
        "transformer_encoder_emb": None,
        "transformer_encoder_layers": 0,
        "transformer_encoder_heads": 0,
        # --- Transformer decoder (zero layers / heads) ---
        "transformer_decoder_emb": None,
        "transformer_decoder_layers": 0,
        "transformer_decoder_heads": 0,
        "transformer_decoder_dim": 0,
    }
)

In [None]:
import typing as tp
import gc
import random
import time
from tqdm import tqdm
import typing as tp
import json
from torch.optim import AdamW
import os
import torch

from config import SimpleConvConfig, TrainingConfigV1
from train.training_session import TrainingSession
from models.whisper_alignment import WhisperAlignment

# Device string used by this notebook (TrainingSessionV1 passes it to the
# GradScaler). Prefer the GPU, but fall back to the CPU so the cell still runs
# on machines where CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TrainingSessionV1(TrainingSession):
    """Training session that builds a WhisperAlignment model with an AdamW
    optimizer and an AMP gradient scaler on top of the base TrainingSession."""

    def __init__(
        self,
        config: tp.Optional[TrainingConfigV1] = None,
        studies: tp.Optional[tp.Dict[str, str]] = None,
        data_path: str = "/home/ubuntu/brain-decoding/data",
        save_path: str = "/home/ubuntu/brain-decoding/saves",
        clear_cache: bool = False,
        max_cache_size: int = 100,
    ):
        """
        Initializes a training session with the provided configuration and data.
        This version deals with audio batches for Whisper latent alignment,
        architecture exploration, and dataset integration.

        Arguments:
            config -- The configuration for the training session. Required: the
                    model and optimizer are built from its fields, so passing
                    None raises ValueError.
            studies -- dict of studies, batch type. Partition policy determined in TrainingConfig
                    Batch type determines how to load data from study.

            data_path -- The path to the data directory.
            save_path -- The path to the directory where the model and logs will be saved.
            clear_cache -- Whether to clear the cache for the studies.
            max_cache_size -- The maximum number of stimuli in cache.

        Raises:
            ValueError -- If config is None.
        """
        # Fail fast with a clear message: config is dereferenced below, so a
        # None default would otherwise surface as an AttributeError only after
        # the base class initializer has already run.
        if config is None:
            raise ValueError(
                "TrainingSessionV1 requires a TrainingConfigV1 instance; got config=None"
            )

        super().__init__(
            config=config,
            studies=studies,
            data_path=data_path,
            save_path=save_path,
            clear_cache=clear_cache,
            cache_enabled=True,
            max_cache_size=max_cache_size,
        )

        # MODEL
        self.model = WhisperAlignment(
            brain_module_config=config.brain_encoder_config,
            adalora_config=config.adalora_config,
            layers_to_align=config.latent_alignment_layers,
            use_compile=False,
        )

        # Optimizer covers every parameter exposed by the alignment model.
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        # Gradient scaler for mixed precision; `device` is the notebook-level
        # device string defined above this class.
        self.scaler = torch.amp.GradScaler(device=device)

        # NOTE: add loss dictionaries
        self.losses = {}

    # TODO: training loop not implemented yet; sketch of the intended signature
    # and the per-step AdaLoRA rank update call kept for reference.
    # def train(
    #     self,
    #     device: str,
    #     buffer_size: int,
    #     num_workers: int,
    #     max_cache_size: int,
    #     current_epoch: int = 0,
    # ):
    #     self.model.encoder.update_and_allocate(step)

In [2]:
from peft import AdaLoraConfig

# AdaLoRA adapter settings, written as one mapping and unpacked into
# AdaLoraConfig as keyword arguments.
adalora_config = AdaLoraConfig(
    **{
        # --- Adapter identity and placement ---
        "peft_type": "ADALORA",
        # NOTE(review): confirm this task type is the right one for the
        # Whisper-based model this config is attached to.
        "task_type": "CAUSAL_LM",
        "target_modules": ["q_proj", "v_proj"],  # modules that receive adapters
        # --- Rank budget ---
        "init_r": 12,  # starting rank
        "target_r": 8,  # average rank to end at
        # --- Rank re-allocation schedule (in training steps) ---
        "tinit": 100,  # no rank updates during the first 100 steps
        # NOTE(review): PEFT documents tfinal as the length of the final
        # fine-tuning phase, so with total_step=2000 rank updates would stop
        # at step 1000 -- confirm against the installed PEFT version.
        "tfinal": 1000,
        "deltaT": 100,  # interval between rank re-allocations
        "total_step": 2000,  # total number of training steps
        # --- LoRA scaling and regularisation ---
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    }
)