# Setup 
- `TA TODO`: Setup based on your environment. Reach out to me if you face any issues. Also, any feedback/improvements to the setup process for the students based on your experience setting up here would be very appreciated! (PS: I am still figuring out the best way to do this.)

## Local 

The assignment is designed so that you can do most of the implementation work locally. We recommend that you pass all the tests locally using the `hw4_data_subset` we've provided before moving to a GPU runtime. To do this:
- Create a new conda environment with `conda create -n hw4 python=3.12.4`
- Activate the conda environment with `conda activate hw4`
- Install the dependencies with `pip install -r requirements.txt`
- Ensure that your notebook is in the same directory as the `handout` folder. (See the expected directory structure in the `README.md`)




## Colab (`TA TODO`)

### Step 1: Get Repo (TA-Only, will be handout for students)

- `INTERNAL TODO`: Need to switch this to handout upload for students.

In [None]:
GITHUB_USERNAME = "puru-samal"
REPO_NAME = "IDL-HW4"
BRANCH_NAME = "TA"
ACCESS_TOKEN = "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"
repo_url = f"https://{GITHUB_USERNAME}:{ACCESS_TOKEN}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git"
!git clone -b {BRANCH_NAME} {repo_url}
#!git clone {repo_url}

#### If I announce a new commit, please delete and re-clone the repo.

In [4]:
!rm -rf IDL-HW4/

### Step 2: Get Data
- `INTERNAL TODO`: Need to switch this to a download from Kaggle.

In [None]:
!gdown 1-0e9Gnl4nm6wbIuE_Yxl2wRZI8yGxHm6 --output hw4_data.tar.gz
!tar -xf hw4_data.tar.gz
!rm -rf hw4_data.tar.gz
!du -h --max-depth=3 hw4_data_kaggle/

### Step 3: Install Dependencies
- `NOTE`: Colab may prompt you to restart your runtime. Do so then proceed to the next step.

In [None]:
%pip install --no-deps -r IDL-HW4/colab_requirements.txt

### Step 4: Move to Project Directory
- `NOTE`: You may have to repeat this on restarting your runtime. You can do a `pwd` to check if you are in the right directory.
- `NOTE`: Your data directory should be one level up from your project directory. Keep this in mind when you are setting your `root` in the config file.
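The directory relationship described in the note above can be sketched as follows. This is only an illustration: the helper name `data_root_for` is hypothetical, and the `hw4_data_kaggle/hw4p2_data` subpath is assumed from the `du` command and the example config in this notebook rather than verified against the archive contents.

```python
import os

def data_root_for(project_dir, data_subpath="hw4_data_kaggle/hw4p2_data"):
    """Return the absolute data path one level above project_dir (hypothetical helper)."""
    # The data is assumed to sit next to the project folder, not inside it.
    parent = os.path.dirname(os.path.abspath(project_dir))
    return os.path.join(parent, *data_subpath.split("/"))

# After os.chdir('IDL-HW4'), the current directory is the project directory.
print(data_root_for(os.getcwd()))
```

The printed path is a candidate value for `root` in the config file; check that it exists before using it.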

In [None]:
import os
os.chdir('IDL-HW4')
!ls

## PSC (`TA TODO`)

### Step 1: Preliminaries

- `Step 0:` ssh into Bridges2 with `ssh username@bridges2.psc.edu`
- `Step 1:` cd into your project directory with `cd $PROJECT`
- `Step 2:` Load the anaconda module with `module load anaconda3`
- `Step 3:` Activate the HW4 environment that was created for you with `conda activate /jet/home/psamal/hw_envs/idl_hw4` (Make sure to deactivate any existing conda environment first with `conda deactivate`)
- `Step 4:` Get a compute node with `interact -p GPU-shared --gres=gpu:v100-32:1 -t 8:00:00`
- `Step 5:` If your conda environment was deactivated by the node allocation, re-activate it with `conda activate /jet/home/psamal/hw_envs/idl_hw4`. Ensure you are in the HW4 environment.
- `Step 6:` Now follow your usual steps to start a jupyter notebook. For me this is:
  - Start a jupyter notebook with `jupyter notebook --no-browser --ip=0.0.0.0` 
  - On a separate terminal, start a tunnel with `ssh -L 8888:{hostname}:{port} bridges2.psc.edu -l username`
  - Select the appropriate kernel on VSCode: Kernel -> Select Another Kernel -> Existing Jupyter Server -> `http://127.0.0.1:{port}/tree?token={token}`
- `Step 7:` Now follow the instructions below.

### Step 2: Get Repo (TA-Only, will be handout for students)

In [None]:
GITHUB_USERNAME = "puru-samal"
REPO_NAME = "IDL-HW4"
BRANCH_NAME = "TA"
ACCESS_TOKEN = "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"
repo_url = f"https://{GITHUB_USERNAME}:{ACCESS_TOKEN}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git"
!git clone -b {BRANCH_NAME} {repo_url} # TA ONLY

#### If I announce a new commit, please delete and re-clone the repo.

In [3]:
!rm -rf IDL-HW4/

### Step 3: Move to Project Directory
- `NOTE`: You may have to repeat this any time you restart your runtime. You can do a `pwd` or `ls` to check if you are in the right directory.

In [None]:
import os
os.chdir('IDL-HW4')
!ls

### Step 4: Get Data
- `NOTE`: We are using `$LOCAL`, the scratch storage on the local disk of the node running a job, to store our data. Disk accesses are much faster than what you would get from `$PROJECT` storage, but `IT IS NOT PERSISTENT`. 
- `NOTE`: Make sure you have a node allocated to you with `interact -p GPU-shared --gres=gpu:v100-32:1 -t 8:00:00`
- Read more about PSC file spaces [here](https://www.psc.edu/resources/bridges-2/user-guide#file-spaces).

In [None]:
!gdown 1-0e9Gnl4nm6wbIuE_Yxl2wRZI8yGxHm6 --output $LOCAL/hw4_data.tar.gz
!ls $LOCAL/
!tar -xf $LOCAL/hw4_data.tar.gz -C $LOCAL/
!rm -rf $LOCAL/hw4_data.tar.gz
!du --max-depth=3 $LOCAL/

# Imports

In [None]:
from hw4lib.data import (
    H4Tokenizer,
    ASRDataset,
    verify_dataloader
)
from hw4lib.model import (
    DecoderOnlyTransformer,
    EncoderDecoderTransformer
)
from hw4lib.utils import (
    create_scheduler,
    create_optimizer,
    plot_lr_schedule
)
from hw4lib.trainers import (
    ASRTrainer,
    ProgressiveTrainer
)
from torch.utils.data import DataLoader
import yaml
import gc
import torch
from torchinfo import summary
import os
import json
import wandb
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Implementations
- `TA TODO`: 
  - `MANDATORY`: Run these cells to verify that the testing works in your chosen environment. Let me know if it doesn't.
  - `OPTIONAL`: Do read through the implementations. Any feedback regarding them would be very appreciated!
- `NOTE`: All of these implementations have detailed specification, implementation details, and hints in their respective source files. Make sure to read all of them in their entirety to understand the implementation details!

## Dataset Implementation
- `TODO`: Implement the `ASRDataset` class in `hw4lib/data/asr_dataset.py`. 
- You will have to implement parts of `__init__` and completely implement the `__len__`, `__getitem__` and `collate_fn` methods. 
- `TODO`: Then run the cell below to check your implementation.


In [None]:
!python -m tests.test_dataset_asr


## Model Implementations

- `TODO`: Implement the `CrossAttentionLayer` class in `hw4lib/model/sublayers.py`.
- `TODO`: Implement the `CrossAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py`.
- `TODO`: Implement the `SelfAttentionEncoderLayer` class in `hw4lib/model/encoder_layers.py`. This will be mostly a copy-paste of the `SelfAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py` with one minor difference: it can attend to all positions in the input sequence.
- `TODO`: Implement the `EncoderDecoderTransformer` class in `hw4lib/model/transformers.py`.
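The difference between decoder and encoder self-attention mentioned above comes down to the attention mask. The sketch below illustrates it with plain Python lists. It assumes the convention that `True` marks a blocked position; the function names are hypothetical and `hw4lib` may build and represent its masks differently.

```python
def causal_mask(seq_len):
    """Decoder-style mask: True marks key positions a query may NOT attend to."""
    # A query at position q is blocked from every key position k > q (the future).
    return [[key > query for key in range(seq_len)] for query in range(seq_len)]

def full_mask(seq_len):
    """Encoder-style mask: no position is blocked, so every query sees every key."""
    return [[False] * seq_len for _ in range(seq_len)]

for row in causal_mask(4):
    print(["x" if blocked else "." for blocked in row])
```

Each printed row corresponds to one query position; the `x` entries above the diagonal are the future positions that only the decoder hides.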

### Transformer Sublayers
- `TODO`: Now, implement the `CrossAttentionLayer` class in `hw4lib/model/sublayers.py`.
- `NOTE`: You should have already implemented the `SelfAttentionLayer`, and `FeedForwardLayer` classes in `hw4lib/model/sublayers.py`.
- `TODO`: Then run the cell below to check your implementation.

In [None]:
!python -m tests.test_sublayer_crossattention

### Transformer Cross-Attention Decoder Layer
- `TODO`: Implement the `CrossAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py`.
- `TODO`: Then run the cell below to check your implementation.


In [None]:
!python -m tests.test_decoderlayer_crossattention

### Transformer Self-Attention Encoder Layer
- `TODO`: Implement the `SelfAttentionEncoderLayer` class in `hw4lib/model/encoder_layers.py`.
- `TODO`: Then run the cell below to check your implementation.




In [None]:
!python -m tests.test_encoderlayer_selfattention

### Encoder-Decoder Transformer

- `TODO`: Implement the  `EncoderDecoderTransformer` class in `hw4lib/model/transformers.py`.
- `TODO`: Then run the cell below to check your implementation.

In [None]:
!python -m tests.test_transformer_encoder_decoder

## Decoding Implementations 
- `TODO`: We highly recommend that you implement the `generate_beam` method of the `SequenceGenerator` class in `hw4lib/decoding/sequence_generator.py`.
- `TODO`: Then run the cell below to check your implementation.
- `NOTE`: This is an optional but highly recommended task for `HW4P2` to ease the journey to high cutoffs!
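To see why beam search can help, here is a toy, framework-free sketch. It is not the `generate_beam` interface: the `step_log_probs` scorer, the token strings, and the probability table are all made up for illustration, and details such as length normalization and batching are left out.

```python
import math

def beam_search(step_log_probs, beam_width, sos, eos, max_len):
    """Toy beam search; step_log_probs(prefix) -> {token: log_prob} is a hypothetical scorer."""
    beams = [([sos], 0.0)]  # (token sequence, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for token, logp in step_log_probs(seq).items():
                candidates.append((seq + [token], score + logp))
        # Keep only the beam_width highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            (finished if seq[-1] == eos else beams).append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses that hit max_len are kept as fallbacks
    return max(finished, key=lambda c: c[1])

def toy_scorer(prefix):
    """Made-up next-token distribution that depends only on the last token."""
    table = {
        "<s>": {"a": 0.6, "b": 0.4},
        "a": {"</s>": 0.4, "a": 0.3, "b": 0.3},
        "b": {"</s>": 0.9, "a": 0.05, "b": 0.05},
    }
    return {tok: math.log(p) for tok, p in table[prefix[-1]].items()}

greedy_seq, greedy_score = beam_search(toy_scorer, 1, "<s>", "</s>", 5)
beam_seq, beam_score = beam_search(toy_scorer, 2, "<s>", "</s>", 5)
print(greedy_seq, round(math.exp(greedy_score), 2))
print(beam_seq, round(math.exp(beam_score), 2))
```

With a beam width of 1 the search is greedy and commits to `a` because it is the most likely first token; with a width of 2 it keeps `b` alive and finds the higher-probability complete sequence.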

In [None]:
!python -m tests.test_decoding --mode beam  

# Experiments
- Please keep an eye out for the `TA TODO`'s in the following cells.

## Config
- `TA TODO`: You can use the `config.yaml` file to set your config for your ablation study.
- `TA TODO`: Remember to change the `root` path!
- `NOTE`: For the values not provided in the ablation sheet, feel free to set as you see fit.
- `NOTE`: If warmup is enabled in `scheduler` section, the warmup phase will happen first before switching to the base scheduler.
- `NOTE`: `warmup` is currently not supported with `ReduceLROnPlateau` scheduler.
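The "warmup first, then base scheduler" behaviour can be sketched in plain Python. This is a minimal illustration, assuming a linear warmup and the closed-form cosine annealing formula with one step per epoch; the function `lr_at_epoch` is hypothetical and is not how `hw4lib` builds its scheduler. The default values are taken from the example config in this notebook.

```python
import math

def lr_at_epoch(epoch, base_lr=4e-4, warmup_epochs=10, start_factor=0.1,
                end_factor=1.0, t_max=15, eta_min=1e-5):
    """Learning rate for a linear warmup followed by cosine annealing (illustrative)."""
    if epoch < warmup_epochs:
        # Linearly interpolate the scaling factor across the warmup epochs.
        progress = epoch / warmup_epochs
        return base_lr * (start_factor + (end_factor - start_factor) * progress)
    # Closed-form cosine annealing, counted from the end of warmup.
    t = epoch - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

for epoch in (0, 5, 10, 25):
    print(epoch, f"{lr_at_epoch(epoch):.6f}")
```

Note that the cosine formula is periodic: past `warmup_epochs + t_max` the rate rises again, which matters if `T_max` is shorter than the remaining training epochs.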

For our purposes, we define the following terms:
- Light SpecAug  : `5 freq_mask_width_range, 1 num_freq_mask, 20 time_mask_width_range, 1 num_time_mask`
- Medium SpecAug : `5 freq_mask_width_range, 2 num_freq_mask, 40 time_mask_width_range, 2 num_time_mask`
- Heavy SpecAug  : `5 freq_mask_width_range, 4 num_freq_mask, 40 time_mask_width_range, 4 num_time_mask`
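As a convenience, the three presets above can be written as a small lookup that produces the `specaug_conf` keys used in the config. The dictionary and helper names are hypothetical; only the numeric values and key names come from this notebook.

```python
SPECAUG_PRESETS = {
    # level: (freq_mask_width_range, num_freq_mask, time_mask_width_range, num_time_mask)
    "light":  (5, 1, 20, 1),
    "medium": (5, 2, 40, 2),
    "heavy":  (5, 4, 40, 4),
}

def specaug_conf_for(level):
    """Return a specaug_conf-style dict for one of the named presets."""
    freq_width, num_freq, time_width, num_time = SPECAUG_PRESETS[level]
    return {
        "apply_freq_mask": True,
        "freq_mask_width_range": freq_width,
        "num_freq_mask": num_freq,
        "apply_time_mask": True,
        "time_mask_width_range": time_width,
        "num_time_mask": num_time,
    }

print(specaug_conf_for("medium"))
```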

- `IMPORTANT`: You are required to run 70 epochs in total. 


### Experimental 
- `NOTE`: There is one experimental setup for the optimizer configuration, i.e. pattern-matching to group parameters by their names and apply different learning rates to them. 
- E.g. `self_attn` will match all parameters containing `self_attn` in their names. 
- See `hw4lib/utils/create_optimizer.py` for more details. Again, experiment with it if you want, but I am still testing it out. 
- The motivation is to use it to set lower learning rates for the `self-attn` and `ffn` modules while initializing an Encoder-Decoder Transformer with weights from a pre-trained Decoder-Only Transformer.
- This is for internal testing only; it won't be available for student use, for simplicity. 
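The pattern-matching idea can be illustrated without PyTorch. The sketch below groups parameter names by substring; it assumes that the first matching group wins and that unmatched names fall into a default group. The function and the example parameter names are hypothetical and may not mirror `hw4lib/utils/create_optimizer.py`.

```python
def group_parameter_names(param_names, param_groups):
    """Assign each name to the first group that has a matching substring pattern."""
    grouped = {group["name"]: [] for group in param_groups}
    grouped["default"] = []
    for name in param_names:
        for group in param_groups:
            if any(pattern in name for pattern in group["patterns"]):
                grouped[group["name"]].append(name)
                break
        else:
            # No group matched, so the base learning rate would apply.
            grouped["default"].append(name)
    return grouped

example_names = [  # made-up names for illustration
    "enc_layers.0.self_attn.mha.in_proj_weight",
    "enc_layers.0.ffn.linear1.weight",
    "final_linear.weight",
]
example_groups = [
    {"name": "self_attn", "patterns": ["self_attn"], "lr": 0.0002},
    {"name": "ffn", "patterns": ["ffn"], "lr": 0.0002},
]
print(group_parameter_names(example_names, example_groups))
```

Under these assumptions, an empty `patterns` list matches nothing, so every parameter stays in the default group.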

In [57]:
%%writefile config.yaml

Name                      : "Puru"

###### Tokenization ------------------------------------------------------------
tokenization:
  token_type                : "1k"       # [char, 1k, 5k, 10k]
  token_map :
      'char': 'hw4lib/data/tokenizer_jsons/tokenizer_char.json'
      '1k'  : 'hw4lib/data/tokenizer_jsons/tokenizer_1000.json'
      '5k'  : 'hw4lib/data/tokenizer_jsons/tokenizer_5000.json'
      '10k' : 'hw4lib/data/tokenizer_jsons/tokenizer_10000.json'

###### Dataset -----------------------------------------------------------------
data:
  root                 : "/local/hw4_data_kaggle/hw4p2_data"  # TODO: Set the root path of your data
  train_partition      : "train-clean-100"  # paired text-speech for ASR pre-training
  val_partition        : "dev-clean"        # paired text-speech for ASR pre-training
  test_partition       : "test-clean"       # paired text-speech for ASR pre-training
  subset               : 1.0                # Load a subset of the data (for debugging, testing, etc.)
  batch_size           : 32           #   
  NUM_WORKERS          : 2            # Set to 0 for CPU
  norm                 : 'global_mvn' # ['global_mvn', 'cepstral', 'none']
  num_feats            : 80

  ###### SpecAugment ---------------------------------------------------------------
  specaug                   : True  # TODO: Set to True if you want to use SpecAugment
  # Light  :  5, 1, 20, 1
  # Medium :  5, 2, 40, 2
  # Heavy  :  5, 4, 40, 4
  # Currently set to Medium
  specaug_conf:
    apply_freq_mask         : True
    freq_mask_width_range   : 5
    num_freq_mask           : 2
    apply_time_mask         : True
    time_mask_width_range   : 40
    num_time_mask           : 2

###### Network Specs -------------------------------------------------------------
model: # Encoder-Decoder Transformer (HW4P2)
  # Speech embedding parameters
  input_dim: 80              # Speech feature dimension
  time_reduction: 2          # Time dimension downsampling factor
  reduction_method: 'lstm'   # The source_embedding reduction method ['lstm', 'conv', 'both']
  
  # Architecture parameters
  d_model: 384            # Model dimension
  num_encoder_layers: 10  # Number of encoder layers
  num_decoder_layers: 6  # Number of decoder layers
  num_encoder_heads: 8   # Number of encoder attention heads
  num_decoder_heads: 8   # Number of decoder attention heads
  d_ff_encoder: 768     # Feed-forward dimension for encoder
  d_ff_decoder: 1536     # Feed-forward dimension for decoder
  skip_encoder_pe: False # Whether to skip positional encoding for encoder
  skip_decoder_pe: False # Whether to skip positional encoding for decoder
  
  # Common parameters
  dropout: 0.0          # Dropout rate
  layer_drop_rate: 0.0  # Layer dropout rate
  weight_tying: False   # Whether to use weight tying
  
###### Common Training Parameters ------------------------------------------------
training:
  use_wandb                   : True
  wandb_run_id                : "none" # "none" or "run_id"
  resume                      : False
  epochs                      : 70
  gradient_accumulation_steps : 1
  wandb_project               : "S25-HW4P2-TA"

###### Loss ----------------------------------------------------------------------
loss: # Just good ol' CrossEntropy
  label_smoothing: 0.0
  ctc_weight: 0.3

###### Optimizer -----------------------------------------------------------------
optimizer:
  name: "adamw" # Options: sgd, adam, adamw
  lr: 0.0004  # Base learning rate

  # Common parameters
  weight_decay: 0.000001

  # Parameter groups
  # You can add more param groups as you want and set their learning rates and patterns
  param_groups:
    - name: self_attn
      patterns: []  # Substring patterns, e.g. ["self_attn"] matches all parameters containing "self_attn"
      lr: 0.0002  # LR for self_attn
      layer_decay:
        enabled: False
        decay_rate: 0.8
    
    - name: ffn
      patterns: []
      lr: 0.0002  # LR for ffn
      layer_decay:
        enabled: False
        decay_rate: 0.8
  
  # Layer-wise learning rates
  layer_decay:
    enabled: False
    decay_rate: 0.75

  # SGD specific parameters
  sgd:
    momentum: 0.9
    nesterov: True
    dampening: 0

  # Adam specific parameters
  adam:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

  # AdamW specific parameters
  adamw:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

###### Scheduler -----------------------------------------------------------------
scheduler:
  name: "cosine"  # Options: reduce_lr, cosine, cosine_warm

  # ReduceLROnPlateau specific parameters
  reduce_lr:
    mode: "min"  # Options: min, max
    factor: 0.1  # Factor to reduce learning rate by
    patience: 10  # Number of epochs with no improvement after which LR will be reduced
    threshold: 0.0001  # Threshold for measuring the new optimum
    threshold_mode: "rel"  # Options: rel, abs
    cooldown: 0  # Number of epochs to wait before resuming normal operation
    min_lr: 0.0000001  # Minimum learning rate
    eps: 1e-8  # Minimal decay applied to lr

  # CosineAnnealingLR specific parameters
  cosine:
    T_max: 15  # Maximum number of iterations
    eta_min: 0.00001  # Minimum learning rate
    last_epoch: -1

  # CosineAnnealingWarmRestarts specific parameters
  cosine_warm:
    T_0: 10    # Number of iterations for the first restart
    T_mult: 10 # Factor increasing T_i after each restart
    eta_min: 0.0000001  # Minimum learning rate
    last_epoch: -1

  # Warmup parameters (can be used with any scheduler)
  warmup:
    enabled: True
    type: "exponential"  # Options: linear, exponential
    epochs: 10
    start_factor: 0.1
    end_factor: 1.0


Overwriting config.yaml


In [58]:
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

## Tokenizer

In [59]:
Tokenizer = H4Tokenizer(
    token_map  = config['tokenization']['token_map'], 
    token_type = config['tokenization']['token_type']
)

                          Tokenizer Configuration (1k)                          
--------------------------------------------------------------------------------
Vocabulary size:     1000

Special Tokens:
PAD:              0
UNK:              1
MASK:             2
SOS:              3
EOS:              4
BLANK:            5

Validation Example:
--------------------------------------------------------------------------------
Input text:  [SOS]HI DEEP LEARNERS[EOS]
Tokens:      ['[SOS]', 'H', 'I', 'ĠDE', 'EP', 'ĠLE', 'AR', 'N', 'ERS', '[EOS]']
Token IDs:   [3, 14, 15, 159, 290, 228, 71, 20, 214, 4]
Decoded:     [SOS]HI DEEP LEARNERS[EOS]


## Datasets

In [60]:
train_dataset = ASRDataset(
    partition=config['data']['train_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=True,
    global_stats=None  # Will compute stats from training data
)

# TODO: Get the computed global stats from training set
global_stats = None
if config['data']['norm'] == 'global_mvn':
    global_stats = (train_dataset.global_mean, train_dataset.global_std)
    print(f"Global stats computed from training set.")

val_dataset = ASRDataset(
    partition=config['data']['val_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=False,
    global_stats=global_stats
)

test_dataset = ASRDataset(
    partition=config['data']['test_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=False,
    global_stats=global_stats
)

gc.collect()

Loading data for train-clean-100 partition...


100%|████████████████████████████████████████████████████████████████████████████████████| 28539/28539 [00:51<00:00, 553.20it/s]


Global stats computed from training set.
Loading data for dev-clean partition...


100%|█████████████████████████████████████████████████████████████████████████████████████| 2703/2703 [00:01<00:00, 1963.99it/s]


Loading data for test-clean partition...


100%|█████████████████████████████████████████████████████████████████████████████████████| 2620/2620 [00:00<00:00, 3537.68it/s]


7413

## Dataloaders

In [61]:
train_loader    = DataLoader(
    dataset     = train_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = True,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = train_dataset.collate_fn   
)

val_loader      = DataLoader(
    dataset     = val_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = False,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = val_dataset.collate_fn   
)

test_loader     = DataLoader(
    dataset     = test_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = False,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = test_dataset.collate_fn   
)

gc.collect()

0

### Dataloader Verification

In [62]:
verify_dataloader(train_loader)

             Dataloader Verification              
Dataloader Partition     : train-clean-100
--------------------------------------------------
Number of Batches        : 892
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 2119, 80]
Shifted Transcript Shape : [32, 99]
Golden Transcript Shape  : [32, 99]
Feature Lengths Shape    : [32]
Transcript Lengths Shape : [32]
--------------------------------------------------
Max Feature Length       : 3066
Max Transcript Length    : 139
Avg. Chars per Token     : 3.13


In [63]:
verify_dataloader(val_loader)

             Dataloader Verification              
Dataloader Partition     : dev-clean
--------------------------------------------------
Number of Batches        : 85
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 3676, 80]
Shifted Transcript Shape : [32, 132]
Golden Transcript Shape  : [32, 132]
Feature Lengths Shape    : [32]
Transcript Lengths Shape : [32]
--------------------------------------------------
Max Feature Length       : 4081
Max Transcript Length    : 179
Avg. Chars per Token     : 3.08


In [64]:
verify_dataloader(test_loader)

             Dataloader Verification              
Dataloader Partition     : test-clean
--------------------------------------------------
Number of Batches        : 82
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 2099, 80]
Feature Lengths Shape    : [32]
--------------------------------------------------
Max Feature Length       : 4370
Max Transcript Length    : 0
Avg. Chars per Token     : 0.00


## Calculate Max Lengths
Calculating the maximum transcript length across your dataset is a crucial step when working with certain transformer models. 
- We'll use sinusoidal positional encodings that must be precomputed up to a fixed maximum length.
- This maximum length is a hyperparameter that determines:
  - How long of a sequence your model can process
  - The size of your positional encoding matrix
  - Memory requirements during training and inference
- `Requirements`: For this assignment, ensure your positional encodings can accommodate at least the longest sequence in your dataset to prevent truncation. However, you can set this value higher if you anticipate using your language model to work with longer sequences in future tasks (hint: this might be useful for P2! 😉).
- `NOTE`: We'll be using the same positional encoding matrix for all sequences in your dataset. Take this into account when setting your maximum length.
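To make the role of the maximum length concrete, here is a minimal pure-Python sketch of a precomputed sinusoidal table, using the standard sine/cosine formulation with base 10000. It assumes an even `d_model`; the function name is hypothetical, and the `PositionalEncoding` module in `hw4lib` may be implemented differently (for example with tensors).

```python
import math

def sinusoidal_positional_encoding(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal encodings (d_model assumed even)."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(0, d_model, 2):
            # Each even/odd pair of dimensions shares one frequency.
            angle = pos / (10000 ** (i / d_model))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table

pe = sinusoidal_positional_encoding(max_len=8, d_model=4)
print(len(pe), len(pe[0]))
```

The table has exactly `max_len` rows, so a sequence longer than `max_len` has no encoding to look up; this is why the value must cover the longest feature or transcript sequence.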

In [65]:
max_feat_len       = max(train_dataset.feat_max_len, val_dataset.feat_max_len, test_dataset.feat_max_len)
max_transcript_len = max(train_dataset.text_max_len, val_dataset.text_max_len, test_dataset.text_max_len)
max_len            = max(max_feat_len, max_transcript_len)

print("="*50)
print(f"{'Max Feature Length':<30} : {max_feat_len}")
print(f"{'Max Transcript Length':<30} : {max_transcript_len}")
print(f"{'Overall Max Length':<30} : {max_len}")
print("="*50)

Max Feature Length             : 4370
Max Transcript Length          : 179
Overall Max Length             : 4370


## Wandb

In [66]:
wandb.login(key="31888e0ba72a18d4a57ea02c19a9687bc4481f37") 

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /jet/home/psamal/.netrc


True

## Training 

You will have to do some minor in-filling for the `ASRTrainer` class in `hw4lib/trainers/asr_trainer.py` before you can use it.
- `TODO`: Fill in the `TODO`s in the `__init__`.
- `TODO`: Fill in the `TODO`s in the `_train_epoch`.
- `TODO`: Fill in the `TODO`s in the `recognize` method.
- `TODO`: Fill in the `TODO`s in the `_validate_epoch`.
- `TODO`: Fill in the `TODO`s in the `train` method.
- `TODO`: Fill in the `TODO`s in the `evaluate` method.

Every time you run the trainer, it will create a new directory in the `expts` folder with the following structure:
```
expts/
    └── {run_name}/
        ├── config.yaml
        ├── model_arch.txt
        ├── checkpoints/
        │   ├── checkpoint-best-metric-model.pth
        │   └── checkpoint-last-epoch-model.pth
        ├── attn/
        │   └── {attention visualizations}
        └── text/
            └── {generated text outputs}
```
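Given the layout above, the path to a saved checkpoint can be assembled with `pathlib`. This is a small sketch; the helper name `checkpoint_path` and its arguments are hypothetical, and only the folder and file names come from the tree shown above.

```python
from pathlib import Path

def checkpoint_path(run_name, kind="best-metric", expts_dir="expts"):
    """Build the checkpoint path for a run; kind is 'best-metric' or 'last-epoch'."""
    return Path(expts_dir) / run_name / "checkpoints" / f"checkpoint-{kind}-model.pth"

path = checkpoint_path("my_run")  # "my_run" is a placeholder run name
print(path, path.exists())
```

The resulting path could then be passed to the trainer's `load_checkpoint` method once the file exists.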


### Training Strategy 1: Cold-Start Trainer

- `TA TODO`: Run this section if you are assigned the `Cold-Start` task. Nothing special here, just the standard training loop.

#### Model Load (Default)

In [None]:
model_config = config['model'].copy()
model_config.update({
    'max_len': max_len,
    'num_classes': Tokenizer.vocab_size
})

model = EncoderDecoderTransformer(**model_config)

# Get some inputs from the train dataloader
for batch in train_loader:
    padded_feats, padded_shifted, padded_golden, feat_lengths, transcript_lengths = batch
    break


model_stats = summary(model, input_data=[padded_feats, padded_shifted, feat_lengths, transcript_lengths])
print(model_stats)

#### Initialize Trainer
- `TA TODO`: Please change the run name to the run name that was assigned to you in the ablation sheet for easy referencing.
- `NOTE`: `optimizer` gets initialized in the `trainer` constructor based on the config.

If you need to reload the model from a checkpoint, you can do so by calling the `load_checkpoint` method.

```python
checkpoint_path = "path/to/checkpoint.pth"
trainer.load_checkpoint(checkpoint_path)
```


In [None]:
trainer = ASRTrainer(
    model=model,
    tokenizer=Tokenizer,
    config=config,
    run_name="Puru-Test-Cold-Start-PSC",
    config_file="config.yaml",
    device=device
)

checkpoint_path = "/ocean/projects/cis220031p/psamal/expts/Puru-Pretrained-Decoder-Test/checkpoints/checkpoint-best-metric-model.pth"
trainer.load_checkpoint(checkpoint_path)

#### Train
- `TA TODO`: You can set your epochs here or in the config. If you set in config, make sure you remove the epoch argument here.
- `NOTE`: A `scheduler` gets initialized in this call based on the config. 

In [None]:
trainer.train(train_loader, val_loader, epochs=60)

#### Evaluate

- `TA TODO`: There will be 3 sequential evaluations here: with greedy decoding and beam search decoding with beam sizes 10 and 20.
- `TA TODO`: Make sure you report the results for each of these cases.

In [None]:
with open("hw4p2_sol.json", "r") as f:
    solution = json.load(f)

results = trainer.evaluate(test_loader, solution, max_length=max_transcript_len)

# Cleanup (Will end wandb run)
trainer.cleanup()

### Training Strategy 2: Progressive Trainer

- `TA TODO`: Run this section if you are assigned the `Progressive-Train` task. This section is a bit more involved. Read carefully. Reach out if you require any clarifications.

In this mode of training, you start with a model with only 1 encoder and 1 decoder layer, and then increase the number of layers after every pretraining iteration. You can optionally freeze the previously trained layers and schedule regularization such as dropout and label smoothing, keeping it low or disabled initially and enabling it later. Finally, you unfreeze all layers and train the full model.



#### Model Load (Default)

In [67]:
model_config = config['model'].copy()
model_config.update({
    'max_len': max_len,
    'num_classes': Tokenizer.vocab_size
})

model = EncoderDecoderTransformer(**model_config)

# Get some inputs from the train dataloader
for batch in train_loader:
    padded_feats, padded_shifted, padded_golden, feat_lengths, transcript_lengths = batch
    break


model_stats = summary(model, input_data=[padded_feats, padded_shifted, feat_lengths, transcript_lengths])
print(model_stats)

Layer (type:depth-idx)                             Output Shape              Param #
EncoderDecoderTransformer                          [32, 93, 1000]            --
├─SpeechEmbedding: 1-1                             [32, 1039, 384]           --
│    └─StackedBLSTMEmbedding: 2-1                  [32, 1039, 384]           --
│    │    └─LSTM: 3-1                              [52153, 384]              420,864
│    │    └─MaxPool1d: 3-2                         [32, 384, 1039]           --
│    │    └─LSTM: 3-3                              [26069, 384]              887,808
│    │    └─MaxPool1d: 3-4                         [32, 384, 1039]           --
│    │    └─Linear: 3-5                            [32, 1039, 384]           147,840
│    │    └─Dropout: 3-6                           [32, 1039, 384]           --
├─PositionalEncoding: 1-2                          [32, 1039, 384]           --
├─Dropout: 1-3                                     [32, 1039, 384]           --
├─ModuleList: 1-4   

#### Initialize Progressive Trainer
- `TA TODO`: Please change the run name to the run name that was assigned to you in the ablation sheet for easy referencing.
- `NOTE`: `optimizer` gets initialized in the `trainer` constructor based on the config.

If you need to reload the model from a checkpoint, you can do so by calling the `load_checkpoint` method.

```python
checkpoint_path = "path/to/checkpoint.pth"
trainer.load_checkpoint(checkpoint_path)
```

In [68]:
trainer = ProgressiveTrainer(
    model=model,
    tokenizer=Tokenizer,
    config=config,
    run_name="asr_1k_09",
    config_file="config.yaml",
    device=device
)

Using device: cuda

🔧 Configuring Optimizer:
├── Type: ADAMW
├── Base LR: 0.0004
├── Weight Decay: 1e-06
├── Parameter Groups:
│   ├── Group: self_attn
│   │   ├── LR: 0.0002
│   │   └── Patterns: []
│   ├── Group: ffn
│   │   ├── LR: 0.0002
│   │   └── Patterns: []
│   └── Default Group (unmatched parameters)
└── AdamW Specific:
    ├── Betas: [0.9, 0.999]
    ├── Epsilon: 1e-08
    └── AMSGrad: False


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


#### `TA TODO`: Define your training stages
The `ProgressiveTrainer` class implements a curriculum learning approach where model complexity and regularization are gradually increased through defined training stages.

##### Stage Configuration

Each stage is defined as a dictionary with the following parameters:
```python
{
    'name': str,                        # Name of the training stage
    'epochs': int,                      # Number of epochs to train in this stage
    'encoder_active_layers': List[int], # Which encoder layers to use
    'decoder_active_layers': List[int], # Which decoder layers to use
    'encoder_freeze': List[bool],       # Whether to freeze each encoder layer
    'decoder_freeze': List[bool],       # Whether to freeze each decoder layer
    'dropout': float,                   # Dropout rate for this stage
    'label_smoothing': float,           # Label smoothing value
    'data_subset': float                # Fraction of training data to use (0.0-1.0)
}
```

It is best understood by an example. Here is a breakdown of an example schedule for a model with 6 encoder and 6 decoder layers (the code cell below applies the same pattern to a model with 10 encoder and 6 decoder layers):

- `Initial (1 layer)`: 
   - This stage starts with a model with only 1 encoder and 1 decoder layer. 
   - No freezing or regularization is applied. 
   - It uses 20% of the training data.
- `2 layers`: 
   - This stage increases the number of layers to 2 for both the encoder and decoder. 
   - The previous layers (encoder layer 1 and decoder layer 1) are frozen. 
   - No regularization is applied. 
   - It uses 20% of the training data.
- `4 layers`: 
   - This stage increases the number of layers to 4 for both the encoder and decoder. 
   - The previous layers (encoder layers 1 and 2 and decoder layers 1 and 2) are frozen. 
   - Dropout is set to 0.05 and label smoothing is set to 0.0. 
   - It uses 20% of the training data.
- `All 6 layers`: 
   - This stage uses all 6 encoder and 6 decoder layers. 
   - The 4 previous layers are frozen and the last 2 layers are trained. 
   - Dropout is set to 0.1 and label smoothing is set to 0.0. 
   - It uses 20% of the training data.
- `Final (with label smoothing)`: 
   - This stage uses all 6 encoder and 6 decoder layers. 
   - All layers are unfrozen and trained. 
   - Dropout is set to 0.1 and label smoothing is set to 0.1. 
   - It uses 20% of the training data.    
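The "freeze what was trained before, train what is new" pattern in this breakdown can be generated rather than typed by hand. The helper below is a hypothetical sketch that produces a dictionary in the stage format described above, assuming contiguous, 0-indexed active layers; the epoch counts in the usage lines are placeholders.

```python
def make_stage(name, epochs, enc_layers, dec_layers, frozen_enc=0, frozen_dec=0,
               dropout=0.0, label_smoothing=0.0, data_subset=0.2):
    """Build one stage dict that freezes the first frozen_* layers on each side."""
    return {
        "name": name,
        "epochs": epochs,
        "encoder_active_layers": list(range(enc_layers)),
        "decoder_active_layers": list(range(dec_layers)),
        # Layers trained in earlier stages are frozen; newly added ones are trainable.
        "encoder_freeze": [i < frozen_enc for i in range(enc_layers)],
        "decoder_freeze": [i < frozen_dec for i in range(dec_layers)],
        "dropout": dropout,
        "label_smoothing": label_smoothing,
        "data_subset": data_subset,
    }

two_layers = make_stage("2 layers", 5, enc_layers=2, dec_layers=2, frozen_enc=1, frozen_dec=1)
final = make_stage("Final (with label smoothing)", 5, 6, 6, dropout=0.1, label_smoothing=0.1)
print(two_layers["encoder_freeze"], final["decoder_freeze"])
```

Non-contiguous active layers, which the trainer also allows, would need the layer lists to be passed in explicitly instead of being built with `range`.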

`TA TODO`: Define your stages here. The design is left to you. Also, the active_layers do not have to be contiguous. For example, for stage 2, you could have had layer 1 and 6 as active layers. 

##### Important Notes
- Ensure `encoder_freeze` and `decoder_freeze` lists match the length of their respective `active_layers`
- `data_subset` should be between 0 and 1
- Stage transitions are handled automatically by the trainer
- The same optimizer and scheduler are used for all stages so keep that in mind while setting the learning rates and other parameters
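The first two notes can be checked mechanically before training starts. The following is a small sketch of such a check; `validate_stages` is a hypothetical helper, not part of `ProgressiveTrainer`, and it assumes a `data_subset` of exactly 0 is not meaningful.

```python
def validate_stages(stages):
    """Raise ValueError on the first inconsistent stage; return the total epoch count."""
    for stage in stages:
        for side in ("encoder", "decoder"):
            active = stage[f"{side}_active_layers"]
            freeze = stage[f"{side}_freeze"]
            if len(active) != len(freeze):
                raise ValueError(
                    f"{stage['name']}: {side}_freeze has {len(freeze)} entries "
                    f"for {len(active)} active layers"
                )
        if not 0.0 < stage["data_subset"] <= 1.0:
            raise ValueError(f"{stage['name']}: data_subset must be in (0, 1]")
    return sum(stage["epochs"] for stage in stages)

example = [{
    "name": "demo stage", "epochs": 5,
    "encoder_active_layers": [0, 1], "decoder_active_layers": [0],
    "encoder_freeze": [False, False], "decoder_freeze": [False],
    "dropout": 0.0, "label_smoothing": 0.0, "data_subset": 0.2,
}]
print("total epochs:", validate_stages(example))
```

The returned total is also a quick way to see how many epochs the stages consume out of the overall epoch budget.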

In [69]:
## Example with a model with 10 encoder and 6 decoder layers
stages = [
            {
                'name': 'Initial (2 Encoder + 1 Decoder layers)',
                'epochs': 5,
                'encoder_active_layers': list(range(2)),  # layers 0 and 1
                'decoder_active_layers': list(range(1)),  # layer 0
                'encoder_freeze': [False, False],
                'decoder_freeze': [False],
                'dropout': 0.0,
                'label_smoothing': 0.0,
                'data_subset': 0.2
            },
            {
                'name': '4 Encoder + 2 Decoder layers',
                'epochs': 5,
                'encoder_active_layers': list(range(4)),
                'decoder_active_layers': list(range(2)),
                'encoder_freeze': [True, True, False, False],
                'decoder_freeze': [True, False],
                'dropout': 0.0,
                'label_smoothing': 0.0,
                'data_subset': 0.2
            },
            {
                'name': '8 Encoder + 4 Decoder layers',
                'epochs': 5,
                'encoder_active_layers': list(range(8)),
                'decoder_active_layers': list(range(4)),
                'encoder_freeze': [True, True, True, True, False, False, False, False],
                'decoder_freeze': [True, True, False, False],
                'dropout': 0.05,
                'label_smoothing': 0.0,
                'data_subset': 0.2
            },
            {
                'name': '10 Encoder + 6 Decoder layers',
                'epochs': 5,
                'encoder_active_layers': list(range(10)),
                'decoder_active_layers': list(range(6)),
                'encoder_freeze': [True, True, True, True, True, True, True, True, False, False],
                'decoder_freeze': [True, True, True, True, False, False],
                'dropout': 0.1,
                'label_smoothing': 0.0,
                'data_subset': 0.2
            },
            {
                'name': 'Final (with label smoothing)',
                'epochs': 5,
                'encoder_active_layers': list(range(10)),
                'decoder_active_layers': list(range(6)),
                'encoder_freeze': [False, False, False, False, False, False, False, False, False, False],
                'decoder_freeze': [False, False, False, False, False, False],
                'dropout': 0.1,
                'label_smoothing': 0.1,
                'data_subset': 0.2
            }
        ]
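
Writing the freeze lists by hand becomes error-prone as the layer count grows. As a sketch (the helper name `make_stage` and its arguments are hypothetical, not part of the handout), the pattern used in the cell above, where layers carried over from the previous stage are frozen and newly added layers are trainable, can be generated:

```python
def make_stage(name, n_enc, n_dec, prev_enc=0, prev_dec=0,
               epochs=5, dropout=0.0, label_smoothing=0.0, data_subset=0.2):
    """Build a stage dict with the first prev_* layers frozen and the rest trainable."""
    return {
        'name': name,
        'epochs': epochs,
        'encoder_active_layers': list(range(n_enc)),
        'decoder_active_layers': list(range(n_dec)),
        # Layers that existed in the previous stage are frozen; new ones are trainable.
        'encoder_freeze': [i < prev_enc for i in range(n_enc)],
        'decoder_freeze': [i < prev_dec for i in range(n_dec)],
        'dropout': dropout,
        'label_smoothing': label_smoothing,
        'data_subset': data_subset,
    }

# Reproduces the second stage from the cell above.
stage_2 = make_stage('4 Encoder + 2 Decoder layers', n_enc=4, n_dec=2, prev_enc=2, prev_dec=1)
```

Leaving `prev_enc` and `prev_dec` at 0 keeps every layer trainable, which matches the first and final stages.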

`TA TODO`: You might want to revisit and change the settings of your optimizer and scheduler. 
- Just go back up, change the config optimizer and scheduler parameters, and return to this cell. 
- The same optimizer and scheduler are used for all stages so keep that in mind while setting the learning rates and other parameters. 
- The example below assumes that the same subset of the training data is used for all stages; this is the easiest way to do it. 
- I would not recommend having variable data subsets for each stage without understanding the `ProgressiveTrainer` and its parent `ASRTrainer` classes.

In [70]:
# Create scheduler before progressive training
trainer.optimizer = create_optimizer(model, config['optimizer'])
subset_train_dataloader = trainer.get_subset_dataloader(train_loader, stages[0]['data_subset'])
trainer.scheduler = create_scheduler(trainer.optimizer, config['scheduler'], subset_train_dataloader, gradient_accumulation_steps=config['training']['gradient_accumulation_steps'])


🔧 Configuring Optimizer:
├── Type: ADAMW
├── Base LR: 0.0004
├── Weight Decay: 1e-06
├── Parameter Groups:
│   ├── Group: self_attn
│   │   ├── LR: 0.0002
│   │   └── Patterns: []
│   ├── Group: ffn
│   │   ├── LR: 0.0002
│   │   └── Patterns: []
│   └── Default Group (unmatched parameters)
└── AdamW Specific:
    ├── Betas: [0.9, 0.999]
    ├── Epsilon: 1e-08
    └── AMSGrad: False

📈 Configuring Learning Rate Scheduler:
├── Type: COSINE
├── Cosine Annealing Settings:
│   ├── T_max: 15 epochs (2685 steps)
│   └── Min LR: 1e-05
├── Warmup Settings:
│   ├── Duration: 10 epochs (1790 steps)
│   ├── Start Factor: 0.1
│   └── End Factor: 1.0
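
The `learning_rate` values in the training log can be sanity-checked against these settings. The sketch below assumes a linear warmup from `Start Factor` to `End Factor` followed by cosine annealing down to `Min LR`, both advanced once per optimizer step. That is an assumption about how `create_scheduler` combines the two phases, not something read from the handout code. `lr_at_step` is a hypothetical name, and the defaults are copied from the configuration printed above.

```python
import math

def lr_at_step(step, base_lr=4e-4, min_lr=1e-5,
               warmup_steps=1790, t_max_steps=2685,
               start_factor=0.1, end_factor=1.0):
    """Learning rate after `step` optimizer steps under warmup + cosine annealing."""
    if step < warmup_steps:
        # Linear ramp of the scaling factor during warmup.
        progress = step / warmup_steps
        return base_lr * (start_factor + (end_factor - start_factor) * progress)
    # Cosine decay from base_lr to min_lr over t_max_steps.
    t = min(step - warmup_steps, t_max_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / t_max_steps))

steps_per_epoch = 1790 // 10  # 179, from "10 epochs (1790 steps)"
for epochs_done in (1, 10, 11):
    print(epochs_done, f"{lr_at_step(epochs_done * steps_per_epoch):.6f}")
```

Under these assumptions the function gives 0.000076 after the first epoch, 0.000400 at the end of warmup, and 0.000396 one epoch into the cosine phase. Warmup (10 epochs) plus `T_max` (15 epochs) also adds up to the 25 epochs summed over the five stages, so this schedule appears to span the whole progressive run.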


#### Train Progressively

In [56]:
trainer.progressive_train(train_loader, val_loader, stages)


             Starting Stage: Initial (2 Encoder + 1 Decoder layers)             

Configuration Details:
├── Data Subset: 20.0% of training data
├── Training Epochs: 5
├── Dropout: 0.0
├── Label Smoothing: 0.0
├── Encoder Layers:
│   ├── Layer 0: Trainable
│   ├── Layer 1: Trainable
├── Decoder Layers:
│   ├── Layer 0: Trainable
├── Frozen Parameters: 0
└── Trainable Parameters: 4,734,336


                                                                                                                                


📊 Metrics (Epoch 0):
├── TRAIN:
│   ├── ce_loss: 6.3148
│   ├── ctc_loss: 12.1573
│   ├── joint_loss: 9.9620
│   ├── perplexity_char: 7.5294
│   └── perplexity_token: 552.6990
└── VAL:
    ├── cer: 179.9351
    ├── wer: 289.5620
    └── word_dist: 145.6313
└── TRAINING:
    └── learning_rate: 0.000076


                                                                                                                                


📊 Metrics (Epoch 1):
├── TRAIN:
│   ├── ce_loss: 5.7497
│   ├── ctc_loss: 6.2098
│   ├── joint_loss: 7.6127
│   ├── perplexity_char: 6.2849
│   └── perplexity_token: 314.1002
└── VAL:
    ├── cer: 96.0998
    ├── wer: 178.7966
    └── word_dist: 77.7688
└── TRAINING:
    └── learning_rate: 0.000112


                                                                                                                                


📊 Metrics (Epoch 2):
├── TRAIN:
│   ├── ce_loss: 5.1317
│   ├── ctc_loss: 6.1621
│   ├── joint_loss: 6.9803
│   ├── perplexity_char: 5.1582
│   └── perplexity_token: 169.3071
└── VAL:
    ├── cer: 73.4322
    ├── wer: 99.7953
    └── word_dist: 59.4500
└── TRAINING:
    └── learning_rate: 0.000148


                                                                                                                                


📊 Metrics (Epoch 3):
├── TRAIN:
│   ├── ce_loss: 4.7906
│   ├── ctc_loss: 6.1225
│   ├── joint_loss: 6.6274
│   ├── perplexity_char: 4.6253
│   └── perplexity_token: 120.3789
└── VAL:
    ├── cer: 72.2892
    ├── wer: 109.1691
    └── word_dist: 58.5062
└── TRAINING:
    └── learning_rate: 0.000184


                                                                                                                                


📊 Metrics (Epoch 4):
├── TRAIN:
│   ├── ce_loss: 4.5998
│   ├── ctc_loss: 6.0631
│   ├── joint_loss: 6.4187
│   ├── perplexity_char: 4.3515
│   └── perplexity_token: 99.4610
└── VAL:
    ├── cer: 72.7139
    ├── wer: 96.2341
    └── word_dist: 58.8563
└── TRAINING:
    └── learning_rate: 0.000220

                  Starting Stage: 4 Encoder + 2 Decoder layers                  

Configuration Details:
├── Data Subset: 20.0% of training data
├── Training Epochs: 5
├── Dropout: 0.0
├── Label Smoothing: 0.0
├── Encoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Frozen
│   ├── Layer 2: Trainable
│   ├── Layer 3: Trainable
├── Decoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Trainable
├── Frozen Parameters: 4,734,336
└── Trainable Parameters: 4,734,336


                                                                                                                                


📊 Metrics (Epoch 5):
├── TRAIN:
│   ├── ce_loss: 4.5517
│   ├── ctc_loss: 6.0508
│   ├── joint_loss: 6.3670
│   ├── perplexity_char: 4.2852
│   └── perplexity_token: 94.7942
└── VAL:
    ├── cer: 75.3939
    ├── wer: 104.8301
    └── word_dist: 61.0125
└── TRAINING:
    └── learning_rate: 0.000256


                                                                                                                                


📊 Metrics (Epoch 6):
├── TRAIN:
│   ├── ce_loss: 4.4495
│   ├── ctc_loss: 6.0220
│   ├── joint_loss: 6.2561
│   ├── perplexity_char: 4.1474
│   └── perplexity_token: 85.5870
└── VAL:
    ├── cer: 73.7566
    ├── wer: 100.9005
    └── word_dist: 59.6875
└── TRAINING:
    └── learning_rate: 0.000292


                                                                                                                                


📊 Metrics (Epoch 7):
├── TRAIN:
│   ├── ce_loss: 4.3569
│   ├── ctc_loss: 5.9691
│   ├── joint_loss: 6.1476
│   ├── perplexity_char: 4.0264
│   └── perplexity_token: 78.0124
└── VAL:
    ├── cer: 72.5749
    ├── wer: 96.7253
    └── word_dist: 58.7562
└── TRAINING:
    └── learning_rate: 0.000328


                                                                                                                                


📊 Metrics (Epoch 8):
├── TRAIN:
│   ├── ce_loss: 4.2450
│   ├── ctc_loss: 5.8258
│   ├── joint_loss: 5.9928
│   ├── perplexity_char: 3.8850
│   └── perplexity_token: 69.7590
└── VAL:
    ├── cer: 79.0624
    ├── wer: 113.7945
    └── word_dist: 64.0188
└── TRAINING:
    └── learning_rate: 0.000364


                                                                                                                                


📊 Metrics (Epoch 9):
├── TRAIN:
│   ├── ce_loss: 4.1173
│   ├── ctc_loss: 5.3329
│   ├── joint_loss: 5.7172
│   ├── perplexity_char: 3.7295
│   └── perplexity_token: 61.3937
└── VAL:
    ├── cer: 73.3472
    ├── wer: 100.5731
    └── word_dist: 59.3750
└── TRAINING:
    └── learning_rate: 0.000400

                  Starting Stage: 8 Encoder + 4 Decoder layers                  

Configuration Details:
├── Data Subset: 20.0% of training data
├── Training Epochs: 5
├── Dropout: 0.05
├── Label Smoothing: 0.0
├── Encoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Frozen
│   ├── Layer 2: Frozen
│   ├── Layer 3: Frozen
│   ├── Layer 4: Trainable
│   ├── Layer 5: Trainable
│   ├── Layer 6: Trainable
│   ├── Layer 7: Trainable
├── Decoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Frozen
│   ├── Layer 2: Trainable
│   ├── Layer 3: Trainable
├── Frozen Parameters: 9,468,672
└── Trainable Parameters: 9,468,672


                                                                                                                                


📊 Metrics (Epoch 10):
├── TRAIN:
│   ├── ce_loss: 4.3103
│   ├── ctc_loss: 4.5608
│   ├── joint_loss: 5.6786
│   ├── perplexity_char: 3.9669
│   └── perplexity_token: 74.4658
└── VAL:
    ├── cer: 73.3550
    ├── wer: 95.6611
    └── word_dist: 59.3937
└── TRAINING:
    └── learning_rate: 0.000396


                                                                                                                                


📊 Metrics (Epoch 11):
├── TRAIN:
│   ├── ce_loss: 4.1412
│   ├── ctc_loss: 3.8875
│   ├── joint_loss: 5.3074
│   ├── perplexity_char: 3.7581
│   └── perplexity_token: 62.8774
└── VAL:
    ├── cer: 72.9533
    ├── wer: 100.6140
    └── word_dist: 59.0375
└── TRAINING:
    └── learning_rate: 0.000383


                                                                                                                                


📊 Metrics (Epoch 12):
├── TRAIN:
│   ├── ce_loss: 3.9692
│   ├── ctc_loss: 3.4727
│   ├── joint_loss: 5.0110
│   ├── perplexity_char: 3.5570
│   └── perplexity_token: 52.9398
└── VAL:
    ├── cer: 72.1347
    ├── wer: 98.4036
    └── word_dist: 58.3937
└── TRAINING:
    └── learning_rate: 0.000363


                                                                                                                                


📊 Metrics (Epoch 13):
├── TRAIN:
│   ├── ce_loss: 3.6521
│   ├── ctc_loss: 3.2021
│   ├── joint_loss: 4.6127
│   ├── perplexity_char: 3.2141
│   └── perplexity_token: 38.5554
└── VAL:
    ├── cer: 70.7136
    ├── wer: 99.5497
    └── word_dist: 57.2875
└── TRAINING:
    └── learning_rate: 0.000335


                                                                                                                                


📊 Metrics (Epoch 14):
├── TRAIN:
│   ├── ce_loss: 3.1466
│   ├── ctc_loss: 3.0071
│   ├── joint_loss: 4.0488
│   ├── perplexity_char: 2.7345
│   └── perplexity_token: 23.2571
└── VAL:
    ├── cer: 64.6046
    ├── wer: 88.6615
    └── word_dist: 52.3125
└── TRAINING:
    └── learning_rate: 0.000302

                 Starting Stage: 10 Encoder + 6 Decoder layers                  

Configuration Details:
├── Data Subset: 20.0% of training data
├── Training Epochs: 5
├── Dropout: 0.1
├── Label Smoothing: 0.0
├── Encoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Frozen
│   ├── Layer 2: Frozen
│   ├── Layer 3: Frozen
│   ├── Layer 4: Frozen
│   ├── Layer 5: Frozen
│   ├── Layer 6: Frozen
│   ├── Layer 7: Frozen
│   ├── Layer 8: Trainable
│   ├── Layer 9: Trainable
├── Decoder Layers:
│   ├── Layer 0: Frozen
│   ├── Layer 1: Frozen
│   ├── Layer 2: Frozen
│   ├── Layer 3: Frozen
│   ├── Layer 4: Trainable
│   ├── Layer 5: Trainable
├── Frozen Parameters: 18,937,344
└── Trainable Param

                                                                                                                                


📊 Metrics (Epoch 15):
├── TRAIN:
│   ├── ce_loss: 3.1709
│   ├── ctc_loss: 2.9915
│   ├── joint_loss: 4.0684
│   ├── perplexity_char: 2.7558
│   └── perplexity_token: 23.8294
└── VAL:
    ├── cer: 59.5690
    ├── wer: 82.7262
    └── word_dist: 48.2500
└── TRAINING:
    └── learning_rate: 0.000265



#### Unfreeze all layers

In [19]:
# Mark every parameter as trainable, including those frozen by earlier stages
for name, param in model.named_parameters():
    param.requires_grad = True
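
To confirm that the loop had the intended effect, one option is to tally parameters by their `requires_grad` flag, similar to the `Frozen Parameters` / `Trainable Parameters` lines in the stage logs. A minimal sketch follows. `count_parameters` is a hypothetical helper, and it only relies on each parameter exposing `requires_grad` and `numel()`, as PyTorch parameters do.

```python
def count_parameters(named_parameters):
    """Return (trainable, frozen) element counts for an iterable of (name, param) pairs."""
    trainable = frozen = 0
    for _, param in named_parameters:
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return trainable, frozen
```

For example, `trainable, frozen = count_parameters(model.named_parameters())` should report `frozen == 0` once every layer is unfrozen.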

#### Reload Optimizer and Scheduler
- `TA TODO`: You might want to revisit and change the settings of your optimizer and scheduler. 
- Just go back up, change the config optimizer and scheduler parameters, and return to this cell. 
- The same optimizer and scheduler are used for all stages so keep that in mind while setting the learning rates and other parameters. 

In [None]:
# Create scheduler before full training
trainer.optimizer = create_optimizer(model, config['optimizer'])
trainer.scheduler = create_scheduler(trainer.optimizer, config['scheduler'], train_loader, gradient_accumulation_steps=config['training']['gradient_accumulation_steps'])

#### Train Full
- `TA TODO`: You can set your epochs here or in the config. If you set in config, make sure you remove the epoch argument here.

In [None]:
trainer.train(train_loader, val_loader, epochs=40)

#### Evaluate

- `TA TODO`: There will be 3 sequential evaluations here: with greedy decoding and beam search decoding with beam sizes 10 and 20.
- `TA TODO`: Make sure you report the results for each of these cases.



In [None]:
with open("hw4p2_sol.json", "r") as f:
    solution = json.load(f)

results = trainer.evaluate(test_loader, solution, max_length=max_transcript_len)

# Cleanup (Will end wandb run)
trainer.cleanup()

### Training Strategy 3: Pretrained Decoder Initialized 

- `TA TODO`: Run this section if you are assigned the `Decoder-Initialized Train` task. This section is a bit more involved. Read carefully. Reach out if you require any clarifications.

In this mode of training, you will: 
- Initialize the Encoder-Decoder Transformer with the shared weights from a pretrained Decoder-Only Transformer (self-attention, feed-forward layers, etc.). 
- First, freeze these pre-trained weights and train only the encoder and the decoder's cross-attention layers on the ASR task. 
- Then unfreeze all weights and train the full model, optionally with a lower learning rate for the pre-trained weights.

`NOTE`: You can get a bit adventurous if you'd like, for example by combining this with the progressive training strategy. This is appreciated but not required.

#### Decoder-Only Initialized Load

- `TA TODO`: Be sure to set the `decoder_checkpoint` below to the path of the `COMPATIBLE` decoder checkpoint you trained during your `HW4P1` ablation study. 







In [None]:
model_config = config['model'].copy()

# TODO: Set the path to the decoder checkpoint.
decoder_checkpoint = "/ocean/projects/cis220031p/psamal/expts/lm_char_02/checkpoints/checkpoint-best-metric-model.pth"
model_config.update({
    'max_len': max_len,
    'num_classes': Tokenizer.vocab_size
})

model, param_info = EncoderDecoderTransformer.from_pretrained_decoder(
    decoder_checkpoint_path=decoder_checkpoint,
    config=model_config,
)

#### Freeze Pre-trained Weights

In [16]:
transferred_params = [name for (name, _) in param_info['transferred']]
for name, param in model.named_parameters():
    if name in transferred_params:
        param.requires_grad = False

#### Initialize Trainer
- `TA TODO`: Please change the run name to the run name that was assigned to you in the ablation sheet for easy referencing.
- `NOTE`: `optimizer` gets initialized in the `trainer` constructor based on the config.

If you need to reload the model from a checkpoint, you can do so by calling the `load_checkpoint` method.

```python
checkpoint_path = "path/to/checkpoint.pth"
trainer.load_checkpoint(checkpoint_path)
```

In [None]:
trainer = ASRTrainer(
    model=model,
    tokenizer=Tokenizer,
    config=config,
    run_name="Puru-Pretrained-Decoder-Test",
    config_file="config.yaml",
    device=device
)

#### Train Encoder and Cross-Attention Decoder Layers with frozen pre-trained weights
- `TA TODO`: You can set your epochs here or in the config. If you set in config, make sure you remove the epoch argument here.
- `TA TODO`: You might want to revisit and change the settings of your optimizer and scheduler. 
- Just go back up, change the config optimizer and scheduler parameters, and return to this cell. 


In [None]:
trainer.train(train_loader, val_loader, epochs=20)

#### Unfreeze all weights

In [26]:
transferred_params = [name for (name, _) in param_info['transferred']]
for name, param in model.named_parameters():
    if name in transferred_params:
        param.requires_grad = True

# Check that all parameters are being trained
for name, param in model.named_parameters():
    assert param.requires_grad, f"{name} is still frozen"

#### Reload Optimizer and Scheduler
- `TA TODO`: You might want to revisit and change the settings of your optimizer and scheduler. 
- Just go back up, change the config optimizer and scheduler parameters, and return to this cell. 
- `TA TODO`: We are creating two separate groups for the pre-trained and new parameters. This is because we will experiment with different learning rates for the pre-trained and new parameters. Set the `lr_factor` below based on your desired ratio of learning rates for the pre-trained and new parameters.
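
Parameters matched by neither group appear to end up in the optimizer's default group (the `Default Group (unmatched parameters)` entry in the optimizer log earlier). It can therefore be worth checking that the two name lists cover every parameter exactly once before building the groups. A minimal sketch, with `check_partition` as a hypothetical helper name:

```python
def check_partition(all_names, transferred_names, new_names):
    """Report names claimed by both groups and names claimed by neither."""
    transferred, new = set(transferred_names), set(new_names)
    return {
        'overlap': sorted(transferred & new),
        'unassigned': sorted(set(all_names) - transferred - new),
    }
```

For instance, `check_partition([n for n, _ in model.named_parameters()], [n for n, _ in param_info['transferred']], [n for n, _ in param_info['new']])` should return two empty lists if the split is clean.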

In [None]:
# Create different groups for the pre-trained and new parameters
transfered_patterns = [name for (name, param) in param_info['transferred']]
new_patterns = [name for (name, param) in param_info['new']]
lr_factor = 0.1 # TODO: Set


optimizer_config = config['optimizer']
optimizer_config['param_groups'] = [
    {
        'name': 'transferred_params',
        'patterns': transfered_patterns,
        'lr': config['optimizer']['lr'] * lr_factor # TODO: Set
    },
    {
        'name': 'new_params',
        'patterns': new_patterns,
        'lr': config['optimizer']['lr']
    }
]
trainer.optimizer = create_optimizer(model, optimizer_config)
trainer.scheduler = create_scheduler(trainer.optimizer, config['scheduler'], train_loader, gradient_accumulation_steps=config['training']['gradient_accumulation_steps'])

#### Train Full
- `TA TODO`: You can set your epochs here or in the config. If you set in config, make sure you remove the epoch argument here.

In [None]:
trainer.train(train_loader, val_loader, epochs=30)

#### Evaluate



In [None]:
with open("hw4p2_sol.json", "r") as f:
    solution = json.load(f)

results = trainer.evaluate(test_loader, solution, max_length=max_transcript_len)

# Cleanup (Will end wandb run)
trainer.cleanup()

## Shallow Fusion Inference

Here you will use an external language model (i.e., the one you trained in HW4P1) to potentially improve the ASR performance of your model. At a high level, each step of beam search involves a log-linear interpolation between the ASR model's logits and the language model's logits, weighted by `lm_weight`.

- `TA TODO`: Set the `path_to_lm_checkpoint` below to the path of the `COMPATIBLE` language model checkpoint you trained during your `HW4P1` ablation study. By compatible, we mean that the tokenization type should match the tokenizer used in the ASR model.
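
To make the interpolation concrete, here is a toy sketch of the scoring for a single decoding step. It is not the handout's beam search: `fuse_scores` and `log_softmax` are hypothetical helpers, and the sketch assumes both models produce logits over the same vocabulary and that scores are combined as `log p_asr + lm_weight * log p_lm`.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def fuse_scores(asr_logits, lm_logits, lm_weight):
    """Per-token shallow-fusion scores: log p_asr + lm_weight * log p_lm."""
    asr_logp = log_softmax(asr_logits)
    lm_logp = log_softmax(lm_logits)
    return [a + lm_weight * b for a, b in zip(asr_logp, lm_logp)]

# Toy 3-token vocabulary: the ASR model is torn between tokens 0 and 1,
# while the LM clearly prefers token 1.
asr_logits = [2.0, 2.0, 0.0]
lm_logits = [0.0, 3.0, 0.0]
fused = fuse_scores(asr_logits, lm_logits, lm_weight=0.5)
best_token = max(range(len(fused)), key=fused.__getitem__)
```

With `lm_weight=0.0` the scores reduce to the ASR model's own log-probabilities. Larger values let the LM break ties or override the acoustic evidence, which is why the weight is worth tuning.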


In [None]:
# TODO: Set the path to your best performing LM checkpoint from HW4P1 ablation study.
# NOTE: The tokenization type should match the tokenizer used in the ASR model.
path_to_lm_checkpoint = "/ocean/projects/cis220031p/psamal/expts/test-lm/checkpoints/checkpoint-best-metric-model.pth"

# Load the LM checkpoint
lm_dict = torch.load(path_to_lm_checkpoint, map_location=trainer.device, weights_only=True)

# Get the model config
lm_model_config = lm_dict['config']['model']
lm_max_len = lm_dict['model_state_dict']['positional_encoding.pe'].shape[1]
lm_model_config.update({
    'max_len': lm_max_len,
    'num_classes': Tokenizer.vocab_size
})

# Initialize the LM model
lm_model = DecoderOnlyTransformer(**lm_model_config)

# Get some inputs from the train dataloader
for batch in train_loader:
    padded_feats, padded_shifted, padded_golden, feat_lengths, transcript_lengths = batch
    break


model_stats = summary(lm_model, input_data=[padded_shifted, transcript_lengths])
print(model_stats)

lm_model.load_state_dict(lm_dict['model_state_dict'], strict=True)
lm_model.eval()  # Inference only: disable dropout in the LM


### Inference
- `TA TODO`: Set the `lm_weight` below to determine the weight given to the LM model's predictions.

In [None]:
with open("hw4p2_sol.json", "r") as f:
    solution = json.load(f)

lm_weight = 0.5 # TODO: Set the weight for the LM model

# Define the recognition config: Beam search with width 10 + LM blending
recognition_config = {
    'num_batches': None,
    'temperature': 1.0,
    'repeat_penalty': 1.0,
    'lm_weight': lm_weight,
    'lm_model': lm_model,
    'beam_width': 10,
}

# Recognize with the shallow fusion config
config_name = "shallow-fusion"
print(f"Evaluating with {config_name} config")
results = trainer.recognize(test_loader, recognition_config, config_name=config_name, max_length=min(max_transcript_len, lm_max_len))
assert len(results) == len(solution)           


# Calculate metrics on the full test set
generated = [r['generated'] for r in results]
metrics = trainer._calculate_asr_metrics(solution, generated)
# Log under a config-specific key; `metrics` itself stays flat for the printout below
trainer._log_metrics({f'test_{config_name}': metrics}, trainer.current_epoch)

print("-"*50)
print(f"Config: {config_name}")
print(f"WER: {metrics['wer']:.2f}%")
print(f"CER: {metrics['cer']:.2f}%")
print(f"Word Distance: {metrics['word_dist']:.2f}")
print("-"*50)
trainer._save_generated_text(results, f'test_{config_name}_results')

# Cleanup (Will end wandb run)
trainer.cleanup()