In [1]:
# Installing required dependencies required for the assignment

!pip install lightning -qU
!pip install wandb -qU

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.0/819.0 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [2]:
# Importing necessary libraries

import pandas as pd
import numpy as np
import torch 
from torch import nn

import lightning as L
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import wandb

from lightning.pytorch.loggers import WandbLogger


In [3]:
# Using Wandb API key, login to wandb account
from kaggle_secrets import UserSecretsClient
api_key = UserSecretsClient().get_secret("wandb_api")

wandb.login(key=api_key)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrohitrk06[0m ([33mrohitrk06-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
#Since Kaggle is used to train the model, we have upload the data at this address.
#In case you are excuting the code, make sure that you have correct directory path of the dataset
dataset_path = "/kaggle/input/dakshina-dataset-v1-0-hi/dakshina_dataset_v1.0_hi/lexicons"

In [5]:
#Creating a custom dataset class in pytorch to store the transliteration dataset.

class TransliterationDataset(Dataset):
    def __init__(self, dataframe, source_vocab, target_vocab):
        self.dataframe = dataframe
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        src_words = [ self.source_vocab[chr] if chr in self.source_vocab else self.source_vocab["<UNK>"] for chr in self.dataframe.iloc[idx]["source"]]

        # Similarly, for the target words        
        tgt_words = [self.target_vocab[chr] if chr in self.target_vocab else self.target_vocab["<UNK>"] for chr in self.dataframe.iloc[idx]["target"]]
        # Add <SOW> and <EOW> tokens to the target words
        tgt_words = [self.target_vocab["<SOW>"]] + tgt_words + [self.target_vocab["<EOW>"]]

        return torch.LongTensor(src_words), torch.LongTensor(tgt_words) 


In [6]:
#Since we are using lightning module, Let's define the Lightning data module to handle the dataset
class TrasnliterationDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        
        # Load the train dataset
        self.train_df = pd.read_csv(
            f"{self.data_dir}/hi.translit.sampled.train.tsv",
            sep="\t",
            names=["target", "source", "attestations"],
            header=None,
            keep_default_na=False, na_values=[]
        )
        # Drop the attestations column
        self.train_df.drop(columns=["attestations"], inplace=True)

        # Let's load the dev set as well
        # We will use the dev set for validation
        self.dev_df = pd.read_csv(
            f"{self.data_dir}/hi.translit.sampled.dev.tsv",
            sep="\t",
            names=["target", "source", "attestations"],
            header=None,
            keep_default_na=False, na_values=[]
        )
        # Drop the attestations column
        self.dev_df.drop(columns=["attestations"], inplace=True)


        #Let's load the test set as well
        # We will use the test set for evaluation
        self.test_df = pd.read_csv(
            f"{self.data_dir}/hi.translit.sampled.test.tsv",
            sep="\t",
            names=["target", "source", "attestations"],
            header=None,
            keep_default_na=False, na_values=[]
        )
        # Drop the attestations column
        self.test_df.drop(columns=["attestations"], inplace=True)

        # Create vocabularies for source and target languages
        self.source_vocab, self.source_chr_to_idx, self.source_idx_to_char = self.build_vocab(self.train_df['source'].values)
        self.target_vocab, self.target_chr_to_idx, self.target_idx_to_char = self.build_vocab(self.train_df['target'].values)


    def prepare_data(self):
        '''
        According the the Lightning documentation, this method is used to download and prepare the data.
        In our case, we are not downloading any data, dataset can be found at the given data_dir path, but we are preparing the data
        '''
        
        self.train_dataset = self.create_dataset(self.train_df)
        self.dev_dataset = self.create_dataset(self.dev_df)
        self.test_dataset = self.create_dataset(self.test_df)

    def build_vocab(self, words):
        '''
        This method is used to build the vocab for the given data
        :param data: The data to build the vocab for
        :return: The vocab for the given data
        '''
        vocab = set()
        for word in words:
            for char in word:
                vocab.add(char)

        # Adding special tokens in the vocab.
        vocab.add("<UNK>")
        vocab.add("<EOW>")
        vocab.add("<SOW>")

        # Sort the vocab to get the same order every time
        vocab = sorted(vocab)
        
        chr_to_idx_map = { chr : idx+1 for idx, chr in enumerate(vocab) }
        idx_to_chr_map = { idx+1 : chr for idx, chr in enumerate(vocab) }

        chr_to_idx_map["<PAD>"] = 0
        idx_to_chr_map[0] = "<PAD>"

        vocab.append("<PAD>")
        
        return vocab, chr_to_idx_map, idx_to_chr_map

    def create_dataset(self, dataframe):
        '''
        This method is used to create the dataset for the given data
        :param dataframe: The dataframe to create the dataset for
        :return: The dataset for the given data
        '''
        return TransliterationDataset(dataframe, self.source_chr_to_idx, self.target_chr_to_idx)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle=True, collate_fn = self.collate_fn, num_workers=3)
    
    def val_dataloader(self):
        return DataLoader(self.dev_dataset, batch_size = self.batch_size, shuffle=False, collate_fn = self.collate_fn, num_workers=3)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size = self.batch_size, shuffle=False, collate_fn = self.collate_fn, num_workers=3)
    
    def collate_fn(self, batch):
        '''
        This method is used to collate the data into batches
        :param batch: The batch to collate
        :return: The collated batch
        '''
        src_words, tgt_words = zip(*batch)
        
        # Pad the source and target words
        src_words = torch.nn.utils.rnn.pad_sequence(src_words, batch_first=True, padding_value=0)
        tgt_words = torch.nn.utils.rnn.pad_sequence(tgt_words, batch_first=True, padding_value=0)

        return src_words, tgt_words

In [7]:
#In the below cells,we would define the encoder, decoder architecture, 

class Encoder(nn.Module):
    def __init__(self, cell_type, input_embedding_size, embedding_dimension, hidden_layer_size, num_layers, dropout):
        super(Encoder, self).__init__()
        self.hidden_layer_size = hidden_layer_size
        self.num_layers = num_layers
        self.dropout = dropout

        # Embedding layer
        self.embedding = nn.Embedding(input_embedding_size, embedding_dimension, padding_idx=0)

        # Here while defining the embedding layer, we have set the padding_idx to 0, which is the index of the <PAD> token in our vocab.
        # This means that the embedding layer will ignore the padding tokens when computing the embeddings.
        # The embedding layer will learn the embeddings for the input characters.

        # Encoder RNN cell
        self.rnn_cell = {
            "RNN": nn.RNN,
            "LSTM": nn.LSTM,
            "GRU": nn.GRU
        }.get(cell_type)
        if self.rnn_cell is None:
            raise ValueError("Invalid cell type. Choose 'RNN', 'LSTM' or 'GRU'.")
        self.rnn = self.rnn_cell(embedding_dimension, hidden_layer_size, num_layers, dropout=dropout, batch_first=True)

    def forward(self, x):
        input = self.embedding(x)
        output, hidden = self.rnn(input)
        return output, hidden

In [8]:
class AttentionModule(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionModule, self).__init__()
        self.attn = nn.Linear(hidden_size*2,hidden_size)
        self.v = nn.Linear(hidden_size,1,bias=False)

    def forward(self,decoder_hidden,encoder_outputs, mask=None):
        #Here, dimension of decoder_hidden = (batch, 1, hidden)
        # dimension of encoder_outputs: (batch, src_len, hidden)
        src_len = encoder_outputs.size(1)
        decoder_hidden  = decoder_hidden.repeat(1,src_len,1) # this will make the dimension: (batch, src_len, hidden)

        energy = torch.tanh(self.attn(torch.cat((decoder_hidden,encoder_outputs), dim = 2)))
        attention = self.v(energy).squeeze(2)
        if mask is not None:
            attention = attention.masked_fill(mask == 0, torch.finfo(attention.dtype).min)
        return F.softmax(attention,dim=1)

In [9]:
class Decoder(nn.Module):
    def __init__(self, cell_type, output_embedding_size, embedding_dimension, hidden_layer_size, num_layers, dropout):
        super(Decoder, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(output_embedding_size, embedding_dimension, padding_idx=0)

        rnn_cell = {
            "RNN": nn.RNN,
            "LSTM": nn.LSTM,
            "GRU": nn.GRU
        }.get(cell_type)
        if rnn_cell is None:
            raise ValueError("Invalid cell type. Choose 'RNN', 'LSTM' or 'GRU'.")
        
        self.rnn = rnn_cell(embedding_dimension + hidden_layer_size, hidden_layer_size, num_layers, dropout=dropout, batch_first=True)

        # Output layer
        self.fc = nn.Linear(hidden_layer_size * 2, output_embedding_size)

        #attention module
        self.attention = AttentionModule(hidden_layer_size)
        self.cell_type = cell_type
        
    def forward(self, x, hidden, encoder_outputs, mask = None):
        embedded = self.embedding(x)

        if self.cell_type == "LSTM":
            dec_hidden = hidden[0][-1].unsqueeze(1)
        else:
            dec_hidden = hidden[-1].unsqueeze(1)

        attn_weights = self.attention(dec_hidden, encoder_outputs, mask).unsqueeze(1)

        context = torch.bmm(attn_weights, encoder_outputs)

        rnn_input = torch.cat((embedded, context), dim = 2)

        output, hidden = self.rnn(rnn_input, hidden)
        output = output.squeeze(1)
        context = context.squeeze(1)

        output = self.fc(torch.cat((output,context),dim=1))
                
        return output.unsqueeze(1), hidden, attn_weights.squeeze(1)

In [10]:
class Seq2Seq(L.LightningModule):
    def __init__(self, input_embedding_size, output_embedding_size, embedding_dimension, 
                 hidden_layer_size, number_of_layers_encoder, number_of_layers_decoder, 
                 dropout, cell_type, learning_rate, teacher_forcing_ratio=0.5):
        super().__init__()
        self.save_hyperparameters()

        # Initialize encoder and decoder
        self.encoder = Encoder(cell_type, input_embedding_size, embedding_dimension,
                              hidden_layer_size, number_of_layers_encoder, dropout)
        self.decoder = Decoder(cell_type, output_embedding_size, embedding_dimension,
                              hidden_layer_size, number_of_layers_decoder, dropout)
        
        # Loss function ignoring padding
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.learning_rate = learning_rate
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, src, tgt, teacher_forcing_ratio=None):
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len-1, self.hparams.output_embedding_size)

        attn_weights_all = []
        # Encoder forward
        encoder_outputs, hidden = self.encoder(src)
        
        # Initialize decoder hidden state
        if self.hparams.cell_type == 'LSTM':
            decoder_hidden = (hidden[0][:self.hparams.number_of_layers_decoder],
                             hidden[1][:self.hparams.number_of_layers_decoder])
        else:
            decoder_hidden = hidden[:self.hparams.number_of_layers_decoder]
        
        decoder_input = tgt[:, 0].unsqueeze(1)  # Start with SOS token
        mask = (src!=0)
        # Decoder forward
        for t in range(tgt_len-1):
            decoder_output, decoder_hidden, attn_weights = self.decoder(decoder_input, decoder_hidden, encoder_outputs, mask)
            outputs[:, t] = decoder_output.squeeze(1)
            attn_weights_all.append(attn_weights.detach().cpu())
            # Teacher forcing
            tf_ratio = self.teacher_forcing_ratio if teacher_forcing_ratio is None else teacher_forcing_ratio
            teacher_force = torch.rand(1).item() < tf_ratio
            top1 = decoder_output.argmax(2)
            decoder_input = tgt[:, t+1].unsqueeze(1) if teacher_force else top1

        attn_weights_all = torch.stack(attn_weights_all, dim=1)
        return outputs, attn_weights_all

    def __shared_step(self, batch, batch_idx, stage):
        src, tgt = batch
        output,_ = self(src, tgt, teacher_forcing_ratio=0 if stage != 'train' else None)
        tgt = tgt.to(output.device)
        loss = self.criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        self.log(f"{stage}_loss", loss)
        
        # Calculate metrics
        preds = output.argmax(2)
        non_pad = tgt[:, 1:] != 0
        correct = (preds == tgt[:, 1:]) & non_pad
        
        # Token-level accuracy
        token_acc = correct.sum().float() / non_pad.sum()
        self.log(f"{stage}_token_acc", token_acc, prog_bar=True)
        
        # Sequence-level accuracy
        seq_acc = ((preds == tgt[:, 1:]) | ~non_pad).all(dim=1).float().mean()
        self.log(f"{stage}_seq_acc", seq_acc, prog_bar=True)
        
        return loss

    def training_step(self, batch, batch_idx):
        return self.__shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.__shared_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.__shared_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        src, _ = batch  # Don't require target during prediction
        output,_ = self(src, torch.zeros_like(src), teacher_forcing_ratio=0)
        return output.argmax(2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


In [18]:
def create_sweep_name(config):
    return (f"{config.cell_type}_"
            f"encoder_nl_{config.num_layers}_"
            f"decoder_nl_{config.num_layers}_"
            f"embedd_dim_{config.embedd_dim}_"
            f"hidden_size_{config.hidden_size}"
            f"tf_ratio_{config.teacher_force_ratio}_"
            f"lr_{config.lr}_"
            f"dropout_{config.dropout}_"
            f"max_epoches_{config.max_epochs}_"
            f"batch_size_{config.batch_size}")

In [None]:
def main(config = None):
    wandb.init(project = "da6401_assignment3_v1",
               config = config)
    config = wandb.config
    wandb.run.name = create_sweep_name(config)

    wandb_logger = WandbLogger(project = "da6401_assignment3_v1",
                               log_model = True)

    data = TrasnliterationDataModule(dataset_path,batch_size = config.batch_size)
    
    
    model = Seq2Seq(len(data.source_vocab), len(data.target_vocab),embedding_dimension = config.embedd_dim,
                   hidden_layer_size = config.hidden_size, number_of_layers_encoder=config.num_layers, number_of_layers_decoder=config.num_layers, 
                     dropout=config.dropout, cell_type= config.cell_type, learning_rate=config.lr, teacher_forcing_ratio=config.teacher_force_ratio)
    # model = torch.compile(model)

    trainer = L.Trainer(
        logger = wandb_logger,
        # strategy='ddp_spawn',
        max_epochs = config.max_epochs,
        precision="16-mixed",
        # devices = 2,
    )

    trainer.fit(model, data) 
    wandb.finish()

In [None]:
sweep_config = {
    "name": "Hyperparameter Sweep for different rnn cells with attention",
    "method": "bayes",
    "metric": {
        "name": "val_loss",
        "goal": "minimize"
    },
    "parameters": {
        "embedd_dim": {
            "values": [128, 256, 512]
        },
        "hidden_size": {
            "values": [128, 256, 512]
        },
        "num_layers": {
            "values": [3, 5]
        },
        "dropout": {
            "values": [0.3,0.4]
        },
        "teacher_force_ratio": {
            "values": [0.1, 0.3, 0.5]
        },
        "cell_type": {
            "values": ["LSTM"]
        },
        "max_epochs":{
            "values":[10,15]
        },
        "lr": {
            "distribution": "log_uniform_values",
            "min": 1e-4,
            "max": 1e-3
        },
        "batch_size":{
            "values":[64,128,256]
        }
    },
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 3,
        "max_iter": 20,
        "eta": 2
    }
}

In [None]:
sweep_id = wandb.sweep(sweep_config,project="da6401_assignment3_v1")

In [None]:
wandb.agent(sweep_id,main,count=30)

In [11]:
# After performing a hyperparameter sweep, 
# Let's use the best hyperparameters to train the model and evaluate the model on the test set.
batch_size = 128
embbed_dim = 512
hidden_size = 512
num_layers = 3
dropout = 0.4
cell_type = "RNN"
lr = 0.000417
teacher_force_ratio = 0.5
max_epochs = 15

data = TrasnliterationDataModule(dataset_path,batch_size = batch_size)
model = Seq2Seq(len(data.source_vocab), len(data.target_vocab), embedding_dimension = embbed_dim,
               hidden_layer_size= hidden_size, number_of_layers_encoder = num_layers,number_of_layers_decoder=num_layers, 
                     dropout=dropout, cell_type= cell_type, learning_rate=lr, teacher_forcing_ratio=teacher_force_ratio)

wandb_logger = WandbLogger(project = "da6401_assignment3",
                               log_model = True)

trainer = L.Trainer(
    logger = wandb_logger,
    # strategy='ddp_spawn',
    max_epochs = max_epochs,
    precision="16-mixed",
    # devices = 2,
)
trainer.fit(model, data)

INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | encoder   | Encoder          | 1.6 M  | train
1 | decoder   | Decoder          | 2.5 M  | train
2 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.231    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=15` reached.


In [12]:
trainer.test(model,data)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.9642021059989929,
  'test_token_acc': 0.746080219745636,
  'test_seq_acc': 0.4033762812614441}]

In [15]:
# After performing a hyperparameter sweep, 
# Let's use the best hyperparameters to train the model and evaluate the model on the test set.
batch_size = 128
embbed_dim = 512
hidden_size = 512
num_layers = 3
dropout = 0.3
cell_type = "LSTM"
lr = 0.00094
teacher_force_ratio = 0.5
max_epochs = 15

data = TrasnliterationDataModule(dataset_path,batch_size = batch_size)
model = Seq2Seq(len(data.source_vocab), len(data.target_vocab), embedding_dimension = embbed_dim,
               hidden_layer_size= hidden_size, number_of_layers_encoder = num_layers,number_of_layers_decoder=num_layers, 
                     dropout=dropout, cell_type= cell_type, learning_rate=lr, teacher_forcing_ratio=teacher_force_ratio)

wandb_logger = WandbLogger(project = "da6401_assignment3",
                               log_model = True)

trainer = L.Trainer(
    logger = wandb_logger,
    # strategy='ddp_spawn',
    max_epochs = max_epochs,
    precision="16-mixed",
    # devices = 2,
)
trainer.fit(model, data)

INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Chec

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=15` reached.


In [16]:
trainer.test(model,data)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 1.0675084590911865,
  'test_token_acc': 0.7606055736541748,
  'test_seq_acc': 0.4280319809913635}]

In [17]:
def decode_sequence(tensor, chr2idx, idx2char):
    chars = []
    for idx in tensor:
        if idx.item() in [chr2idx['<EOW>'], chr2idx['<PAD>']]:
            break
        chars.append(idx2char.get(idx.item(), '<UNK>'))
    return ''.join(chars)


In [19]:
import os
import random


wandb.init(project="da6401_assignment3_v1", name="predictions_attention", job_type="test_evaluation")
data.prepare_data()
test_dataloader = data.test_dataloader()
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Create predictions directory
os.makedirs('predictions_attention', exist_ok=True)


# For wandb logging
sample_table = wandb.Table(columns=["Source", "Target", "Prediction"])

# Test evaluation
predictions = []
with torch.no_grad(), open('predictions_attention/test_predictions_attention.csv', 'w') as f:
    f.write("Source,Target,Prediction\n")
    
    for batch in test_dataloader:
        src, tgt = batch
        src, tgt = src.to(device), tgt.to(device)
        outputs, _ = model(src, tgt, teacher_forcing_ratio=0)
        preds = outputs.argmax(dim=2)
        
        # Decode sequences
        for i in range(src.size(0)):
            source_str = decode_sequence(src[i], data.source_chr_to_idx, data.source_idx_to_char)
            target_str = decode_sequence(tgt[i][1:], data.target_chr_to_idx, data.target_idx_to_char)  # Skip <SOW>
            pred_str = decode_sequence(preds[i],  data.target_chr_to_idx, data.target_idx_to_char)
            
            f.write(f'"{source_str}","{target_str}","{pred_str}"\n')
            predictions.append((source_str, target_str, pred_str))

            # Add a few (e.g. 20) predictions to W&B table
            if len(sample_table.data) < 40 and random.random()<0.02:
                sample_table.add_data(source_str, target_str, pred_str)

# Log the table to wandb
wandb.log({"Sample Test Predictions": sample_table})
wandb.finish()

wandb.finish()

print("Predictions saved to predictions_attention/test_predictions.csv")

Predictions saved to predictions_attention/test_predictions.csv


In [23]:
import pandas as pd

# Load prediction files
vanilla_df = pd.read_csv("/kaggle/input/test-predictions-vanilla/test_predictions_vanilla.csv")
attention_df = pd.read_csv("/kaggle/working/predictions_attention/test_predictions_attention.csv")

# Mark which predictions are correct
vanilla_df["VanillaCorrect"] = vanilla_df["Target"] == vanilla_df["Prediction"]
attention_df["AttentionCorrect"] = attention_df["Target"] == attention_df["Prediction"]

# Merge the two DataFrames
combined = vanilla_df.copy()
combined["AttentionPrediction"] = attention_df["Prediction"]
combined["AttentionCorrect"] = attention_df["AttentionCorrect"]

# Filter: incorrect in vanilla but correct in attention
corrected_cases = combined[(combined["VanillaCorrect"] == False) & (combined["AttentionCorrect"] == True)]



# Print sample of corrections
print(f"\nTotal corrections made by attention model: {len(corrected_cases)}\n")
corrected_cases[["Source", "Target", "Prediction", "AttentionPrediction"]].head(20)



Total corrections made by attention model: 524



Unnamed: 0,Source,Target,Prediction,AttentionPrediction
0,ank,अंक,एंक,अंक
2,ankit,अंकित,आंकत,अंकित
7,ankor,अंकोर,एंकोर,अंकोर
9,angarak,अंगारक,अंगरक,अंगारक
23,ambaani,अंबानी,अमबानी,अंबानी
24,ambani,अंबानी,अमबानी,अंबानी
51,azhar,अजहर,अजार,अजहर
53,agnat,अज्ञात,अग्ञात,अज्ञात
59,atke,अटके,अटे,अटके
68,atharva,अथर्व,अठर्व,अथर्व


In [27]:
import pandas as pd
wandb.init(project="da6401_assignment3_v1", name="attention_vs_vanilla_analysis")

# Mark which predictions are correct
vanilla_df["VanillaCorrect"] = vanilla_df["Target"] == vanilla_df["Prediction"]
attention_df["AttentionCorrect"] = attention_df["Target"] == attention_df["Prediction"]

# Merge the two DataFrames
combined = vanilla_df.copy()
combined["VanillaPrediction"] = combined["Prediction"]  # rename for clarity
combined["AttentionPrediction"] = attention_df["Prediction"]
combined["AttentionCorrect"] = attention_df["AttentionCorrect"]

# Filter: incorrect in vanilla but correct in attention
corrected_cases = combined[(combined["VanillaCorrect"] == False) & (combined["AttentionCorrect"] == True)]

# Drop the old generic Prediction column if needed
corrected_cases = corrected_cases.drop(columns=["Prediction"])

# Create wandb table
table = wandb.Table(columns=["Source", "Target", "VanillaPrediction", "AttentionPrediction"])
for _, row in corrected_cases.iterrows():
    table.add_data(row["Source"], row["Target"], row["VanillaPrediction"], row["AttentionPrediction"])

# Log the table to wandb
wandb.log({"Corrected_Predictions": table})

# Finish the run
wandb.finish()

# Show corrected predictions
print(f"\nTotal corrections made by attention model: {len(corrected_cases)}\n")
corrected_cases[["Source", "Target", "VanillaPrediction", "AttentionPrediction"]].head(20)



Total corrections made by attention model: 524



Unnamed: 0,Source,Target,VanillaPrediction,AttentionPrediction
0,ank,अंक,एंक,अंक
2,ankit,अंकित,आंकत,अंकित
7,ankor,अंकोर,एंकोर,अंकोर
9,angarak,अंगारक,अंगरक,अंगारक
23,ambaani,अंबानी,अमबानी,अंबानी
24,ambani,अंबानी,अमबानी,अंबानी
51,azhar,अजहर,अजार,अजहर
53,agnat,अज्ञात,अग्ञात,अज्ञात
59,atke,अटके,अटे,अटके
68,atharva,अथर्व,अठर्व,अथर्व


In [39]:
!apt-get install fonts-noto fonts-noto-core fonts-noto-unhinted fonts-noto-ui-core fonts-noto-ui-extra fonts-noto-cjk


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
fonts-noto-cjk is already the newest version (1:20220127+repack1-1).
fonts-noto-cjk set to manually installed.
fonts-noto-core is already the newest version (20201225-1build1).
fonts-noto-core set to manually installed.
fonts-noto-ui-core is already the newest version (20201225-1build1).
fonts-noto-ui-core set to manually installed.
fonts-noto is already the newest version (20201225-1build1).
fonts-noto-ui-extra is already the newest version (20201225-1build1).
fonts-noto-ui-extra set to manually installed.
fonts-noto-unhinted is already the newest version (20201225-1build1).
fonts-noto-unhinted set to manually installed.
0 upgraded, 0 newly installed, 0 to remove and 87 not upgraded.


In [42]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.family'] = [
    'Lohit Devanagari',
    'DejaVu Sans',           # Good for Latin
    'Noto Sans Devanagari',  # Good for Devanagari
    'Arial Unicode MS',      # Good fallback for both
    'sans-serif'
]
matplotlib.rcParams['font.family'] = [
    'Lohit Devanagari'
    'DejaVu Sans',           # Good for Latin
    'Noto Sans Devanagari',  # Good for Devanagari
    'Arial Unicode MS',      # Good fallback for both
    'sans-serif'
]

In [43]:
# import torch
# import matplotlib.pyplot as plt
# import wandb
# import random
# import os

wandb.init(project="da6401_assignment3_v1", name="attention_heatmaps", job_type="attention_viz")

# You must define these functions or use your own
def decode_sequence(tensor, chr_to_idx, idx_to_char):
    tokens = []
    for idx in tensor:
        idx = idx.item()
        if idx == 0:  # padding
            continue
        tokens.append(idx_to_char.get(idx, '?'))
    return tokens

# Create a directory to save optional images
os.makedirs('attention_heatmaps', exist_ok=True)

# Prepare model and device
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

attention_images = []
sampled = 0
max_samples = 9

with torch.no_grad():
    for batch in data.test_dataloader():
        src, tgt = batch
        src, tgt = src.to(device), tgt.to(device)
        outputs, attention_weights_all = model(src, tgt, teacher_forcing_ratio=0)
        preds = outputs.argmax(dim=2)

        batch_size = src.size(0)
        for i in range(batch_size):
            if sampled >= max_samples:
                break

            src_tokens = decode_sequence(src[i], data.source_chr_to_idx, data.source_idx_to_char)
            tgt_tokens = decode_sequence(tgt[i][1:], data.target_chr_to_idx, data.target_idx_to_char)

            attn = attention_weights_all[i].cpu().numpy()[:len(tgt_tokens), :len(src_tokens)]

            fig, ax = plt.subplots(figsize=(6, 4))
            cax = ax.imshow(attn, cmap='viridis', aspect='auto')

            ax.set_xticks(range(len(src_tokens)))
            ax.set_yticks(range(len(tgt_tokens)))
            ax.set_xticklabels(src_tokens, rotation=90)
            ax.set_yticklabels(tgt_tokens)
            ax.set_xlabel("Source Tokens")
            ax.set_ylabel("Target Tokens")
            ax.set_title(f"Attention Heatmap {sampled+1}")

            fig.colorbar(cax, ax=ax)
            plt.tight_layout()

            attention_images.append(wandb.Image(fig, caption=f"{''.join(src_tokens)} → {''.join(tgt_tokens)}"))
            plt.close()
            sampled += 1

        if sampled >= max_samples:
            break

# Log to W&B
wandb.log({"Attention Heatmaps (3x3 Grid)": attention_images})
wandb.finish()


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [57]:
import os
import random
import torch
import wandb
import pandas as pd
import html  # To handle special characters and ensure proper encoding
from torch.utils.data import DataLoader

# === WandB Init ===
wandb.init(project="da6401_assignment3_v1", name="connectivity_attention", job_type="html_vis")
run_name = wandb.run.name

# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()  # Make sure your model variable is defined and loaded
datamodule = TrasnliterationDataModule(data_dir=dataset_path)  # Replace path
datamodule.prepare_data()

# === Color Map ===
def get_shade_color(value):
    # Modify this function to return colors directly from the value
    colors = ['#00fa00', '#00f500', '#00eb00', '#00e000', '#00db00',
              '#00d100', '#00c700', '#00c200', '#00b800', '#00ad00',
              '#00a800', '#009e00', '#009400', '#008f00', '#008500',
              '#007500', '#007000', '#006600', '#006100', '#005c00',
              '#005200', '#004d00', '#004700', '#003d00', '#003800',
              '#003300', '#002900', '#002400', '#001f00', '#001400']
    idx = int((value * 100) / 5)
    return colors[min(max(idx, 0), len(colors)-1)]

# === Decode Sequences ===
def decode_sequence(tensor, idx_to_char):
    return [idx_to_char[idx.item()] for idx in tensor if idx.item() != 0]

# === HTML Builder ===
def create_file(text_colors, input_words, output_words, file_path):
    # Initialize HTML with style
    html_content = "<html><head>"
    
    # Adding the CSS for interactivity
    html_content += """
    <style>
        body {
            font-family: 'Devanagari', sans-serif;
        }
        .token { cursor: pointer; padding: 5px; margin: 2px; }
        .highlight { background-color: yellow; }
        .attention { background-color: lightblue; }
        .tooltip {
            visibility: hidden;
            position: absolute;
            background-color: black;
            color: white;
            text-align: center;
            border-radius: 5px;
            padding: 5px;
            font-size: 12px;
            z-index: 1;
        }
        .token:hover .tooltip {
            visibility: visible;
        }
    </style>
    """
    
    # Adding JavaScript for hover interaction
    html_content += """
    <script>
        function highlightToken(id) {
            var token = document.getElementById(id);
            token.classList.toggle('highlight');
        }

        function showTooltip(event, value) {
            var tooltip = document.getElementById('tooltip');
            tooltip.innerHTML = "Attention Weight: " + value.toFixed(2);
            tooltip.style.left = event.pageX + 5 + "px";
            tooltip.style.top = event.pageY + 5 + "px";
            tooltip.style.visibility = 'visible';
        }

        function hideTooltip() {
            var tooltip = document.getElementById('tooltip');
            tooltip.style.visibility = 'hidden';
        }
    </script>
    """
    
    html_content += "</head><body style='font-family: monospace;'>"
    
    # Adding the tooltip element
    html_content += '<div id="tooltip" class="tooltip"></div>'
    
    # Generating the table with highlighted tokens
    for k in range(len(output_words)):
        # Escape HTML characters (including Devanagari) to prevent rendering issues
        input_sentence = ''.join([html.escape(word) for word in input_words[k]])
        output_sentence = ''.join([html.escape(word) for word in output_words[k]])
        
        html_content += f"<h3>Sample {k+1}: {input_sentence} → {output_sentence}</h3><pre>"
        for i in range(len(output_words[k])):
            html_content += f"<b>{output_words[k][i]}</b>: "
            for j in range(len(input_words[k])):
                attention_value = text_colors[k][i][j]  # This is already a color string
                html_content += f"<span id='src_{k}_{j}' class='token' onmouseover='showTooltip(event, {attention_value})' onmouseout='hideTooltip()' style='background-color:{attention_value};' onclick='highlightToken(\"src_{k}_{j}\")'>{input_words[k][j]}</span> "
            html_content += "<br>"
        html_content += "</pre><hr>"
    
    html_content += "</body></html>"
    
    # Saving the HTML file with UTF-8 encoding to handle Devanagari characters
    out_path = os.path.join(file_path, "connectivity_interactive.html")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    
    return out_path


# === Inference for One Word ===
def run_inference_on_word(src_tensor, model, datamodule):
    src_tensor = src_tensor.unsqueeze(0).to(device)
    tgt_dummy = torch.zeros((1, 20), dtype=torch.long).to(device)
    with torch.no_grad():
        outputs, att_weights = model(src_tensor, tgt_dummy, teacher_forcing_ratio=0)
    
    pred_ids = outputs.argmax(dim=2).squeeze(0)
    att_weights = att_weights[0].cpu().numpy()[:len(pred_ids), :src_tensor.size(1)]

    src_tokens = decode_sequence(src_tensor[0], datamodule.source_idx_to_char)
    tgt_tokens = decode_sequence(pred_ids, datamodule.target_idx_to_char)

    return tgt_tokens, src_tokens, att_weights


# === Randomly Sample 3 Test Words ===
test_set = datamodule.test_dataset
samples = random.sample(range(len(test_set)), 3)

input_words = []
output_words = []
color_list = []

for idx in samples:
    src_tensor, _ = test_set[idx]  # get source
    out_toks, in_toks, att = run_inference_on_word(src_tensor, model, datamodule)

    text_colours = [[get_shade_color(att[i][j]) for j in range(len(in_toks))] for i in range(len(out_toks))]

    input_words.append(in_toks)
    output_words.append(out_toks)
    color_list.append(text_colours)

# === Save HTML ===
output_dir = os.path.join(os.getcwd(), "predictions_attention", run_name)
os.makedirs(output_dir, exist_ok=True)
html_path = create_file(color_list, input_words, output_words, output_dir)

# === Log to W&B ===
wandb.log({"custom_file": wandb.Html(open(html_path))})

wandb.finish()


In [58]:
import os
import torch
import wandb
from matplotlib import font_manager as fm
from IPython.display import display, HTML

# Initialize W&B
wandb.init(project="da6401_assignment3_v1", name="attention_connectivity", job_type="attention_html")

# Decode sequence
def decode_sequence(tensor, chr_to_idx, idx_to_char):
    tokens = []
    for idx in tensor:
        idx = idx.item()
        if idx == 0:
            continue
        tokens.append(idx_to_char.get(idx, '?'))
    return tokens

# Color mapping (dark green = strong attention)
def get_shade_color(value):
    colors = ['#00fa00', '#00f500',  '#00eb00', '#00e000',  '#00db00',  
              '#00d100',  '#00c700',  '#00c200', '#00b800',  '#00ad00',  
              '#00a800',  '#009e00',  '#009400', '#008f00',  '#008500',
              '#007500',  '#007000',  '#006600', '#006100',  '#005c00',  
              '#005200',  '#004d00',  '#004700', '#003d00',  '#003800',  
              '#003300',  '#002900',  '#002400',  '#001f00',  '#001400']
    value = int((value * 100) / 5)
    value = min(max(value, 0), len(colors) - 1)
    return colors[value]

# HTML generation
def create_file(text_colors, input_word, output_word, file_path=os.getcwd()):
    html_text = '''<html><body style="font-family:monospace;">'''
    for k in range(len(output_word)):
        html_text += f"<h4>Output {k+1}: {''.join(output_word[k])}</h4>"
        html_text += "<pre style='line-height: 2;'>"
        for i in range(len(output_word[k])):
            html_text += f"<b>{output_word[k][i]}</b>: "
            for j in range(len(input_word[k])):
                color = text_colors[k][i][j]
                html_text += f"<span style='background-color:{color};'>{input_word[k][j]}</span> "
            html_text += "<br>"
        html_text += "</pre><hr>"
    html_text += "</body></html>"

    fname = os.path.join(file_path, "connectivity.html")
    with open(fname, "w", encoding="utf-8") as f:
        f.write(html_text)
    return fname

# Inference model (based on your syntax)
def inference_model(input_str, rnn_type):
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Encode input
    src = torch.tensor([data.encode_source(input_str)], dtype=torch.long).to(device)

    # Dummy target just for passing shape (no teacher forcing)
    tgt = torch.zeros((1, 20), dtype=torch.long).to(device)  # adjust max_len if needed

    with torch.no_grad():
        outputs, attention_weights_all = model(src, tgt, teacher_forcing_ratio=0)

    preds = outputs.argmax(dim=2)[0]  # batch size = 1
    src_tokens = decode_sequence(src[0], data.source_chr_to_idx, data.source_idx_to_char)
    tgt_tokens = decode_sequence(preds, data.target_chr_to_idx, data.target_idx_to_char)
    attn = attention_weights_all[0].cpu().numpy()[:len(tgt_tokens), :len(src_tokens)]

    return tgt_tokens, src_tokens, outputs, attn

# Connectivity pipeline
def connectivity(input_words, rnn_type, file_path):
    color_list = []
    input_word_list = []
    output_word_list = []

    for word in input_words:
        output_word, input_word, _, att_w = inference_model(word, rnn_type)
        text_colours = []
        for i in range(len(output_word)):
            row = [get_shade_color(att_w[i][j]) for j in range(len(input_word))]
            text_colours.append(row)
        color_list.append(text_colours)
        input_word_list.append(input_word)
        output_word_list.append(output_word)

    html_file_path = create_file(color_list, input_word_list, output_word_list, file_path)
    return html_file_path

# === Run the full pipeline ===
output_dir = os.path.join(os.getcwd(), "predictions_attention", str(run_name))
os.makedirs(output_dir, exist_ok=True)

input_examples = ['anjali', 'underwear', 'agastya']  # Your test samples
html_file = connectivity(input_examples, rnn_type, output_dir)

# === Log HTML file to W&B as artifact ===
artifact = wandb.Artifact("attention_connectivity_html", type="visualization")
artifact.add_file(html_file)
wandb.log_artifact(artifact)

wandb.finish()


In [45]:
src_vocab

{'<EOW>': 1,
 '<SOW>': 2,
 '<UNK>': 3,
 'a': 4,
 'b': 5,
 'c': 6,
 'd': 7,
 'e': 8,
 'f': 9,
 'g': 10,
 'h': 11,
 'i': 12,
 'j': 13,
 'k': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'q': 20,
 'r': 21,
 's': 22,
 't': 23,
 'u': 24,
 'v': 25,
 'w': 26,
 'x': 27,
 'y': 28,
 'z': 29,
 '<PAD>': 0}

In [61]:
import os
import random
import torch
import wandb
import pandas as pd
from torch.utils.data import DataLoader

# === WandB Init ===
wandb.init(project="da6401_assignment3_v1", name="connectivity_attention", job_type="html_vis")
run_name = wandb.run.name

# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()  # Make sure your model variable is defined and loaded
datamodule = TrasnliterationDataModule(data_dir=dataset_path)  # Replace path
datamodule.prepare_data()

# === Color Map ===
def get_shade_color(value):
    colors = ['#00fa00', '#00f500', '#00eb00', '#00e000', '#00db00',
              '#00d100', '#00c700', '#00c200', '#00b800', '#00ad00',
              '#00a800', '#009e00', '#009400', '#008f00', '#008500',
              '#007500', '#007000', '#006600', '#006100', '#005c00',
              '#005200', '#004d00', '#004700', '#003d00', '#003800',
              '#003300', '#002900', '#002400', '#001f00', '#001400']
    idx = int((value * 100) / 5)
    return colors[min(max(idx, 0), len(colors)-1)]

# === Decode Sequences ===
def decode_sequence(tensor, idx_to_char):
    return [idx_to_char[idx.item()] for idx in tensor if idx.item() != 0]

# === HTML Builder ===
def create_file(text_colors, input_words, output_words, file_path):
    html = "<html><body style='font-family: monospace;'>"
    for k in range(len(output_words)):
        html += f"<h3>Sample {k+1}: {''.join(input_words[k])} → {''.join(output_words[k])}</h3><pre>"
        for i in range(len(output_words[k])):
            html += f"<b>{output_words[k][i]}</b>: "
            for j in range(len(input_words[k])):
                html += f"<span style='background-color:{text_colors[k][i][j]};'>{input_words[k][j]}</span> "
            html += "<br>"
        html += "</pre><hr>"
    html += "</body></html>"
    
    out_path = os.path.join(file_path, "connectivity.html")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html)
    return out_path

# === Inference for One Word ===
def run_inference_on_word(src_tensor, model, datamodule):
    src_tensor = src_tensor.unsqueeze(0).to(device)
    tgt_dummy = torch.zeros((1, 20), dtype=torch.long).to(device)
    with torch.no_grad():
        outputs, att_weights = model(src_tensor, tgt_dummy, teacher_forcing_ratio=0)
    
    pred_ids = outputs.argmax(dim=2).squeeze(0)
    att_weights = att_weights[0].cpu().numpy()[:len(pred_ids), :src_tensor.size(1)]

    src_tokens = decode_sequence(src_tensor[0], datamodule.source_idx_to_char)
    tgt_tokens = decode_sequence(pred_ids, datamodule.target_idx_to_char)

    return tgt_tokens, src_tokens, att_weights

# === Randomly Sample 3 Test Words ===
test_set = datamodule.test_dataset
samples = random.sample(range(len(test_set)), 3)

input_words = []
output_words = []
color_list = []

for idx in samples:
    src_tensor, _ = test_set[idx]  # get source
    out_toks, in_toks, att = run_inference_on_word(src_tensor, model, datamodule)

    text_colours = [[get_shade_color(att[i][j]) for j in range(len(in_toks))] for i in range(len(out_toks))]

    input_words.append(in_toks)
    output_words.append(out_toks)
    color_list.append(text_colours)

# === Save HTML ===
output_dir = os.path.join(os.getcwd(), "predictions_attention", run_name)
os.makedirs(output_dir, exist_ok=True)
html_path = create_file(color_list, input_words, output_words, output_dir)

# === Log to W&B ===
# wandb.log({"custom_file": wandb.Html(open(html_path))})
artifact = wandb.Artifact("connectivity_attention_html", type="visualization")
artifact.add_file(html_path)
wandb.log_artifact(artifact)
wandb.finish()


In [36]:
# # import torch
# # import matplotlib.pyplot as plt
# # import seaborn as sns
# # import wandb
# # import random
# # import os

# wandb.init(project="da6401_assignment3_v1", name="attention_heatmaps", job_type="attention_viz")

# # You must define these functions or use your own
# def decode_sequence(tensor, chr_to_idx, idx_to_char):
#     tokens = []
#     for idx in tensor:
#         idx = idx.item()
#         if idx == 0:  # padding
#             continue
#         tokens.append(idx_to_char.get(idx, '?'))
#     return tokens

# # Create a directory to save optional images
# os.makedirs('attention_heatmaps', exist_ok=True)

# # Prepare model and device
# model.eval()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)

# attention_images = []
# sampled = 0
# max_samples = 9

# with torch.no_grad():
#     for batch in data.test_dataloader():
#         src, tgt = batch
#         src, tgt = src.to(device), tgt.to(device)
#         outputs, attention_weights_all = model(src, tgt, teacher_forcing_ratio=0)
#         preds = outputs.argmax(dim=2)

#         batch_size = src.size(0)
#         for i in range(batch_size):
#             if sampled >= max_samples:
#                 break

#             src_tokens = decode_sequence(src[i], data.source_chr_to_idx, data.source_idx_to_char)
#             tgt_tokens = decode_sequence(tgt[i][1:], data.target_chr_to_idx, data.target_idx_to_char)

#             attn = attention_weights_all[i].cpu().numpy()[:len(tgt_tokens), :len(src_tokens)]

#             fig, ax = plt.subplots(figsize=(6, 4))
#             sns.heatmap(attn, xticklabels=src_tokens, yticklabels=tgt_tokens, cmap="viridis", cbar=True, ax=ax)
#             ax.set_xlabel("Source Tokens")
#             ax.set_ylabel("Target Tokens")
#             ax.set_title(f"Attention Heatmap {sampled+1}")
#             plt.tight_layout()

#             attention_images.append(wandb.Image(fig, caption=f"{''.join(src_tokens)} → {''.join(tgt_tokens)}"))
#             plt.close()
#             sampled += 1

#         if sampled >= max_samples:
#             break

# # Log to W&B
# wandb.log({"Attention Heatmaps (3x3 Grid)": attention_images})
# wandb.finish()


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
  util.ens

In [34]:
# # import torch
# # import matplotlib.pyplot as plt
# # import os
# # import random
# # import wandb
# # import matplotlib.font_manager as fm
# # from io import BytesIO
# # from PIL import Image

# # # Load Devanagari font
# # font_path = "/usr/share/fonts/truetype/noto/NotoSansDevanagari-Regular.ttf"
# # if not os.path.exists(font_path):
# #     import urllib.request
# #     os.makedirs(os.path.dirname(font_path), exist_ok=True)
# #     url = "https://github.com/googlefonts/noto-fonts/blob/main/hinted/ttf/NotoSansDevanagari/NotoSansDevanagari-Regular.ttf?raw=true"
# #     urllib.request.urlretrieve(url, font_path)

# # devanagari_font = fm.FontProperties(fname=font_path)
# # plt.rcParams['font.family'] = devanagari_font.get_name()

# # Initialize wandb
# wandb.init(project="da6401_assignment3_v1", name="attention_heatmaps_plt", job_type="visualization")

# # Prepare model and data
# data.prepare_data()
# test_dataloader = data.test_dataloader()
# model.eval()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)

# # Collect 9 attention maps
# samples_to_plot = []
# with torch.no_grad():
#     for batch in test_dataloader:
#         src, tgt = batch
#         src, tgt = src.to(device), tgt.to(device)
#         outputs, attn_weights_all = model(src, tgt, teacher_forcing_ratio=0)
#         preds = outputs.argmax(dim=2)

#         for i in range(src.size(0)):
#             if len(samples_to_plot) >= 9:
#                 break

#             src_str = decode_sequence(src[i], data.source_chr_to_idx, data.source_idx_to_char)
#             tgt_str = decode_sequence(tgt[i][1:], data.target_chr_to_idx, data.target_idx_to_char)
#             pred_str = decode_sequence(preds[i], data.target_chr_to_idx, data.target_idx_to_char)
#             attn_weights = attn_weights_all[i].cpu().numpy()

#             src_tokens = list(src_str)
#             tgt_tokens = list(pred_str)

#             attn_weights = attn_weights[:len(tgt_tokens), :len(src_tokens)]

#             samples_to_plot.append((src_tokens, tgt_tokens, attn_weights, src_str, tgt_str, pred_str))
#         if len(samples_to_plot) >= 9:
#             break

# # Plot using matplotlib
# fig, axes = plt.subplots(3, 3, figsize=(15, 12))
# for idx, (src_tokens, tgt_tokens, attn, src_str, tgt_str, pred_str) in enumerate(samples_to_plot):
#     ax = axes[idx // 3][idx % 3]
#     im = ax.imshow(attn, aspect='auto', cmap='viridis')
#     ax.set_xticks(range(len(src_tokens)))
#     ax.set_yticks(range(len(tgt_tokens)))
#     ax.set_xticklabels(src_tokens, fontproperties=devanagari_font, rotation=90, fontsize=8)
#     ax.set_yticklabels(tgt_tokens, fontproperties=devanagari_font, fontsize=8)
#     ax.set_title(f"Src: {src_str}\nTgt: {tgt_str}\nPred: {pred_str}", fontproperties=devanagari_font, fontsize=10)
#     fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# plt.tight_layout()

# # Log to wandb
# buf = BytesIO()
# plt.savefig(buf, format='png')
# buf.seek(0)
# image = Image.open(buf)
# wandb.log({"Attention Heatmaps (3x3 Grid)": wandb.Image(image)})

# plt.close()
# wandb.finish()


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, format='png')
  plt.savefig(buf, 