### Agent Code - Code executed by a sweep agent

In [None]:
# importing required libraries for the notebook
import lightning as lt
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
import wandb
from language import *
from dataset_dataloader import *
from encoder_decoder import *
from runner import Runner

In [None]:
# Report which accelerator PyTorch can see ('cuda' when a GPU is available,
# otherwise 'cpu'). Informational only: `device` is NOT USED for training any
# more since the notebook switched to lightning.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

#### Defining the source and target languages; loading data to create Language objects

In [None]:
# language codes for the two sides of the task: inputs are in SOURCE,
# outputs are in TARGET
SOURCE, TARGET = 'eng', 'tam'

In [None]:
# read the train / valid / test splits for the TARGET language; every call to
# load_data yields an (inputs, outputs) pair
(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = (
    load_data(TARGET, split) for split in ('train', 'valid', 'test')
)

# report how many samples each split contains
for _split, _inputs in (('train', x_train), ('valid', x_valid), ('test', x_test)):
    print(f'Number of {_split} samples = {len(_inputs)}')

In [None]:
# one Language object per side; each holds the vocabulary plus the
# index2sym / sym2index mappings
SRC_LANG, TAR_LANG = Language(SOURCE), Language(TARGET)

# vocabularies are built from the training split only (valid/test are never seen here)
for _lang, _words in ((SRC_LANG, x_train), (TAR_LANG, y_train)):
    _lang.create_vocabulary(*_words)

# derive the character <-> number mappings from the vocabularies
for _lang in (SRC_LANG, TAR_LANG):
    _lang.generate_mappings()

# show the resulting vocabularies and mappings, source first, then target
for _label, _lang in (('Source', SRC_LANG), ('Target', TAR_LANG)):
    print(f'{_label} Vocabulary Size = {len(_lang.symbols)}')
    print(f'{_label} Vocabulary = {_lang.symbols}')
    print(f'{_label} Mapping {_lang.index2sym}')

#### Sweep agent function definition

In [None]:
# log in to Weights & Biases so the sweep agent below can create runs;
# wandb is already imported in the first cell, so this import is redundant but harmless
import wandb
wandb.login()

# Code executed by a sweep agent
def agent_code():
    """Train one model using the hyperparameters of the current sweep run.

    Called without arguments by `wandb.agent`. The function:
      1. starts a W&B run and reads its hyperparameters from `wandb.config`,
      2. renames the run so the configuration is readable from the run name,
      3. builds a `Runner` from the module-level SOURCE, TARGET, SRC_LANG and
         TAR_LANG plus the sweep hyperparameters,
      4. trains it with early stopping and checkpointing on `val_acc`,
      5. uploads the best checkpoint as a W&B artifact of type 'model'.

    The run is used as a context manager so that it is always finished, even
    when training or the upload raises; the exception still propagates.
    """
    with wandb.init(project='cs6910-assignment3', entity='cs19b021') as wdbrun:
        wconfig = wandb.config
        # rename run to identify config easily from it
        # NOTE(review): the 'ereemb' prefix looks like a typo for 'emb'; it is
        # left unchanged so run/checkpoint names stay the same as before.
        wdbrun.name = f'ereemb={wconfig["embedding_size"]}_layers={wconfig["number_of_layers"]}_hid={wconfig["hidden_size"]}'
        wdbrun.name += f'_cell={wconfig["cell"]}_bidirectional={wconfig["bidirectional"]}_dr={wconfig["dropout"]}'
        wdbrun.name += f'_itfr={wconfig["initial_tf_ratio"]}_bsize={wconfig["batch_size"]}_att={wconfig["attention"]}'
        wdbrun.name += f'_opt={wconfig["optimizer"]}_lr={wconfig["learning_rate"]}'

        # dictionary to pass to a model (instance of Runner Class)
        rdict = dict(
            SOURCE=SOURCE,
            TARGET=TARGET,
            src_lang=SRC_LANG,
            tar_lang=TAR_LANG,
            common_embed_size=wconfig['embedding_size'],
            common_num_layers=wconfig['number_of_layers'],
            common_hidden_size=wconfig['hidden_size'],
            common_cell_type=wconfig['cell'],
            init_tf_ratio=wconfig['initial_tf_ratio'],
            enc_bidirect=wconfig['bidirectional'],
            attention=wconfig['attention'],
            dropout=wconfig['dropout'],
            opt_name=wconfig['optimizer'],
            learning_rate=wconfig['learning_rate'],
            batch_size=wconfig['batch_size'],
        )
        runner = Runner(**rdict)

        # early stop when val_acc does not improve by at least `min_delta_imp`
        # for `patience` checks in a row; both thresholds come from the sweep config
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=wconfig['min_delta_imp'],
            patience=wconfig['patience'],
            verbose=True,
            mode="max",
        )
        # checkpoint on val_acc (higher is better) into the working directory,
        # using the run name as the file name
        chkCallback = ModelCheckpoint(dirpath='./', filename=f'{wdbrun.name}', monitor='val_acc', mode='max')
        trainer = lt.Trainer(
            min_epochs=wconfig['min_epochs'],
            max_epochs=wconfig['max_epochs'],
            callbacks=[chkCallback, early_stop_callback],
        )
        trainer.fit(runner)

        # log the checkpoint on wandb so that we can test later by loading it directly
        best_ckpt_path = chkCallback.best_model_path
        if best_ckpt_path:
            # '=' is replaced by '-' when building the artifact name
            artifact = wandb.Artifact(f'{wdbrun.name}_best_ckpt'.replace("=", "-"), type='model')
            artifact.add_file(best_ckpt_path)
            wdbrun.log_artifact(artifact)
        else:
            # best_model_path is empty when no checkpoint was written (e.g.
            # training was interrupted before the first validation pass);
            # add_file('') would raise, so skip the upload instead.
            print('No checkpoint was saved for this run; skipping artifact upload.')
    # leaving the `with` block finishes the run

In [None]:
# Launch a sweep agent: it pulls configs from the sweep server for `sweep-id`
# and calls agent_code once per config. The number of runs is capped at
# count=10 to avoid crossing kaggle usage limits.
wandb.agent(
    sweep_id='sweep-id',
    function=agent_code,
    entity='cs19b021',
    project='cs6910-assignment3',
    count=10,
)