In [1]:
import numpy as np
import pandas as pd
import torch
import os
from torch import nn
import sys
from torch.utils.data import Dataset, DataLoader
import wandb
import regex as re
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import wandb
import lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger

In [2]:
train_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.train.tsv"
valid_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.dev.tsv"
test_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.test.tsv"

train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')

train_df.head()

Unnamed: 0,native,latin,n_annot
0,ஃபியட்,fiat,2
1,ஃபியட்,phiyat,1
2,ஃபியட்,piyat,1
3,ஃபிரான்ஸ்,firaans,1
4,ஃபிரான்ஸ்,france,2


In [3]:
train_df = train_df[~train_df['latin'].isna()]
valid_df = valid_df[~valid_df['latin'].isna()]
test_df = test_df[~test_df['latin'].isna()]

In [4]:
class NativeTokenizer():
    def __init__(self, train_path, valid_path, test_path, special_tokens={'START': '<start>','END':'<end>', 'PAD':'<pad>'}):
        
        self.train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.special_tokens = special_tokens
        # Build vocabulary
        self._build_vocab(add_special_tokens=True)
        
        # Id to token mapping
        self.id_to_latin = {i: char for i, char in enumerate(self.latin_vocab)}
        self.id_to_native = {i: char for i, char in enumerate(self.native_vocab)}

        self.latin_vocab_size = len(self.latin_vocab)
        self.nat_vocab_size = len(self.native_vocab)

    # Build vocabulary
    def _build_vocab(self, add_special_tokens=True):
        self.nat_set = set()
        self.latin_set = set()
        for lat, nat in zip(self.train_df['latin'], self.train_df['native']):
            nat_chars = re.findall(r'\X' , nat)
            try:
                lat_chars = list(lat)
            except:
                print(f"Invalid latin string: {lat}, skipping....")
            
            for char in nat_chars:
                self.nat_set.add(char)
            for char in lat_chars:
               self.latin_set.add(char.lower())
            
        self.nat_set = sorted(list(self.nat_set))
        self.latin_set = sorted(list(self.latin_set))
        
        if add_special_tokens:
            self.nat_set = list(self.special_tokens.values()) + self.nat_set
            self.latin_set = [self.special_tokens['PAD']] + self.latin_set   

        self.latin_vocab = {char: i for i, char in enumerate(self.latin_set)}
        self.native_vocab = {char: i for i, char in enumerate(self.nat_set)}

    def tokenize(self, text, lang='latin'):
        if type(text) != str:
            print("Invalid text:", text)
            print("Language must be a string, but got", type(text))
        if lang == 'latin':
            return [self.latin_vocab[char] for char in text]
        elif lang == 'native':
            return [self.native_vocab['<start>']] + [self.native_vocab[char] for char in re.findall('\X', text)] + [self.native_vocab['<end>']]
        else:
            raise ValueError("Language must be either 'latin' or 'native'.")




In [5]:
tokenizer = NativeTokenizer(train_path, valid_path, test_path)
print(f"Latin vocab size: {tokenizer.latin_vocab_size}")
print(f"Native vocab size: {tokenizer.nat_vocab_size}")

Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Latin vocab size: 27
Native vocab size: 253


In [6]:
class LatNatDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        entry = self.df.iloc[idx]
        latin_word = entry['latin']
        native_word = entry['native']
               
        # Tokenize and convert to IDs
        #latin_ids = [self.tokenizer.latin_vocab[i] for i in latin_word]
        #native_ids = [self.tokenizer.native_vocab[i] for i in re.findall(r'\X' , native_word)]
        latin_ids = self.tokenizer.tokenize(latin_word, lang='latin')
        native_ids = self.tokenizer.tokenize(native_word, lang='native')


        return (torch.tensor(latin_ids),
            torch.tensor(native_ids))

    def collate_fn(self, batch):
        x,y = zip(*batch)
        x_len = [len(seq) for seq in x]
        y_len = [len(seq) for seq in y]

        padded_x = pad_sequence(x, batch_first=True, padding_value=self.tokenizer.latin_vocab['<pad>'])
        padded_y = pad_sequence(y, batch_first=True, padding_value=self.tokenizer.native_vocab['<pad>'])
        
        x_len, perm_idx = torch.tensor(x_len).sort(0, descending=True)
        padded_x = padded_x[perm_idx]

        y_len = torch.tensor(y_len).sort(0, descending=True)
        padded_y = padded_y[perm_idx]

        return padded_x, x_len, padded_y, y_len



In [7]:
train_dataset = LatNatDataset(train_df, tokenizer)
valid_dataset = LatNatDataset(valid_df, tokenizer)
test_dataset = LatNatDataset(test_df, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=train_dataset.collate_fn, num_workers=2)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=valid_dataset.collate_fn , num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=test_dataset.collate_fn, num_workers=2)

In [31]:
class Encoder(torch.nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, cell, num_layers, dropout, activation='tanh'):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(num_embeddings=input_size, embedding_dim=embedding_size, )
        if cell =='rnn':
            self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers, dropout=dropout, nonlinearity=activation)
        elif cell == 'LSTM':
            self.rnn = torch.nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers,dropout=dropout)
        elif cell == 'GRU':
            self.rnn = torch.nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers, dropout=dropout)
    
    def forward(self, seq, seq_len):
        embedding = self.embedding(input=seq)
        packed = pack_padded_sequence(input=embedding, lengths=seq_len, batch_first=True, enforce_sorted=True)
        output, hidden = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        return output, hidden

In [32]:
class Decoder(torch.nn.Module):
    def __init__(self, output_size, embedding_size, hidden_size, cell, num_layers, dropout, activation='tanh'):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(num_embeddings=output_size, embedding_dim=embedding_size)
        if cell == 'rnn':
            self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers, nonlinearity=activation, dropout=dropout)
        elif cell == 'LSTM':
            self.rnn = torch.nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers, dropout=dropout)
        elif cell == 'GRU':
            self.rnn = torch.nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layers, dropout=dropout)
        self.out = torch.nn.Linear(hidden_size, output_size)
        self.softmax = torch.nn.LogSoftmax(dim=2)  

    def forward(self, input_step, hidden):
        # input_step: (batch_size, 1) [a single timestep]
        embedded = self.embedding(input_step)  # (batch_size, 1, hidden_size)

        rnn_output, hidden = self.rnn(embedded, hidden)  # output: (batch_size, 1, hidden_size)
        output = self.out(rnn_output)  # (batch_size, 1, output_size)
        return output, hidden


In [33]:
from torch import nn
def train(input_tensor, input_lengths, target_tensor, target_lengths, encoder, decoder, 
          encoder_optimizer, decoder_optimizer, criterion, max_target_len, teacher_forcing_ratio=0.5):
    special_tokens = {key: val for key, val in tokenizer.native_vocab.items() if key in ['<start>', '<end>', '<pad>']}

    encoder.train()
    decoder.train()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    batch_size = input_tensor.size(0)
    loss = 0

    _, encoder_hidden = encoder(input_tensor, input_lengths)

    # Prepare decoder input and hidden state

    decoder_input = target_tensor[:, :-1].detach().clone()
    decoder_target = target_tensor[:, 1:].detach().clone()

    #decoder_input = torch.tensor(([SOS_token]*batch_size)).unsqueeze(1)
    
    #print("decoder input", decoder_input.shape)

    
    decoder_hidden = encoder_hidden # directly use last hidden state from encoder
    # Feed the target as the next input
    for i in range(target_tensor.shape[1]-1):

        decoder_output, decoder_hidden = decoder(decoder_input[:, i].unsqueeze(1), decoder_hidden)
        #decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)

        #decoder_input = decoder_output.argmax(dim=2)
        #print("decoder output: ", decoder_output.shape)
        loss += criterion(decoder_output.squeeze(1), decoder_target[:, i])
        # Append the new column
        if i == 0:
            preds = decoder_output.argmax(dim=2).cpu().numpy()
        else:
            preds = np.hstack((preds, decoder_output.argmax(dim=2).cpu().numpy()))

    mask = ~torch.isin(decoder_target[:,:-1], torch.tensor(list(special_tokens.values())))
    masked_preds = torch.tensor(preds[:, :-1]).masked_fill(~mask, -1)
    masked_targets = decoder_target[:, :-1].masked_fill(~mask, -1)

    exact_matches = (masked_preds == masked_targets).all(dim=1)
    accuracy = exact_matches.float().mean()
    words = "".join([tokenizer.id_to_native[i] for i in preds[0]])
    #print(decoder_target[0:1, :-1].tolist())
    truth = "".join([tokenizer.id_to_native[i] for i in decoder_target[0:1, :-1].tolist()[0]])
    #print(f"Pred: {words}, Truth: {truth}")
    
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / max_target_len, accuracy


In [10]:
def beam_search(decoder, encoder_outputs, beam_width, max_len, start_token, end_token):
    # Each beam stores (sequence, score)
    best_seqs = []
    for i in range(encoder_outputs.size(0)):
        beams = [(torch.tensor([start_token]), 0.0)]  # Start with start token

        for _ in range(max_len):
            all_candidates = []
            for seq, score in beams:
                if seq[-1].item() == end_token:
                    # Already ended; keep as is
                    all_candidates.append((seq, score))
                    continue
                
                # Get decoder output probabilities for last token
                input_token = seq[-1].unsqueeze(0).unsqueeze(0)  # shape (1,1)
                output_logits, hidden = decoder(input_token)
                log_probs = torch.log_softmax(output_logits, dim=-1).squeeze(0).squeeze(0)

                # Get top-k tokens and their log probabilities
                topk_log_probs, topk_tokens = torch.topk(log_probs, beam_width)

                for k in range(beam_width):
                    next_token = topk_tokens[k].unsqueeze(0)
                    new_seq = torch.cat([seq, next_token])
                    new_score = score + topk_log_probs[k].item()
                    all_candidates.append((new_seq, new_score))

            # Select top beam_width sequences
            beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

            # Optional: break early if all beams ended with end_token
            if all(seq[-1].item() == end_token for seq, _ in beams):
                break

        # Return best sequence (highest score)
        best_seq = beams[0][0]
        best_seqs.append(best_seq)


In [36]:
INPUT_SIZE = tokenizer.latin_vocab_size
OUTPUT_SIZE = tokenizer.nat_vocab_size
EMBEDDING_SIZE = 128
HIDDEN_SIZE = 256
MAX_TARGET_LEN = 28  # Set this to the maximum length of your target sequences
SOS_token = tokenizer.native_vocab['<start>']
PAD_TOKEN = tokenizer.native_vocab['<pad>']
EOS_token = tokenizer.native_vocab['<end>']
encoder = Encoder(input_size=INPUT_SIZE, embedding_size= EMBEDDING_SIZE,hidden_size=HIDDEN_SIZE, cell='rnn', num_layers=2, dropout=0.1)
decoder = Decoder(output_size=OUTPUT_SIZE, embedding_size= EMBEDDING_SIZE,hidden_size=HIDDEN_SIZE, cell='rnn', num_layers=2, dropout=0.1)

# Optimizers and loss
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.native_vocab['<pad>'])  # Assume PAD token is 0



def train_iters(encoder, decoder, n_epochs, training_data, encoder_optimizer, decoder_optimizer, 
                criterion, max_target_len, batch_size=32, teacher_forcing_ratio=0.5, print_every=1):
    
    for epoch in range(1, n_epochs + 1):
        print(f"Epoch {epoch}/{n_epochs}")
        total_loss = 0
        accuracy = []
        # Assuming training_data is a list of (input_tensor, input_length, target_tensor, target_length)
        for batch in training_data:
            input_tensor, input_lengths, target_tensor, target_lengths = batch
            loss, acc = train(input_tensor, input_lengths, target_tensor, target_lengths,
                         encoder, decoder, encoder_optimizer, decoder_optimizer,
                         criterion, max_target_len=max_target_len,
                         teacher_forcing_ratio=teacher_forcing_ratio)
            total_loss += loss

            accuracy.append(acc)

        if epoch % print_every == 0:
            avg_loss = total_loss / len(training_data)
            print(f"Epoch {epoch}/{n_epochs}, Loss: {avg_loss:.4f}")
            print(f"Epoch {epoch}/{n_epochs}, avg acc: {np.mean(accuracy):.4f}")
    



In [38]:
train_iters(
    encoder=encoder,
    decoder=decoder,
    n_epochs=10,
    training_data=train_dataloader,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    criterion=criterion,
    max_target_len=MAX_TARGET_LEN,
    batch_size=32,
    teacher_forcing_ratio=1,
    
    print_every=1
)

Epoch 1/10
Epoch 1/10, Loss: 0.7298
Epoch 1/10, avg acc: 0.0008
Epoch 2/10
Epoch 2/10, Loss: 0.5653
Epoch 2/10, avg acc: 0.0019
Epoch 3/10
Epoch 3/10, Loss: 0.4895
Epoch 3/10, avg acc: 0.0043
Epoch 4/10
Epoch 4/10, Loss: 0.4533
Epoch 4/10, avg acc: 0.0073
Epoch 5/10
Epoch 5/10, Loss: 0.4381
Epoch 5/10, avg acc: 0.0089
Epoch 6/10
Epoch 6/10, Loss: 0.4182
Epoch 6/10, avg acc: 0.0110
Epoch 7/10
Epoch 7/10, Loss: 0.3916
Epoch 7/10, avg acc: 0.0159
Epoch 8/10
Epoch 8/10, Loss: 0.3701
Epoch 8/10, avg acc: 0.0228
Epoch 9/10
Epoch 9/10, Loss: 0.3566
Epoch 9/10, avg acc: 0.0269
Epoch 10/10
Epoch 10/10, Loss: 0.3522
Epoch 10/10, avg acc: 0.0295


In [31]:
class RNN_light(pl.LightningModule):
    def __init__(self, input_sizes, embedding_size, hidden_size, cell, layers, dropout, activation, beam_size, optim, special_tokens, lr):
        super().__init__()
        self.optim = optim
        self.save_hyperparameters()
        self.beam_size = beam_size
        if layers == 1:
            print("Dropout is not applied for 1 layer")
            dropout = 0 
        self.encoder = Encoder(input_sizes[0], embedding_size, hidden_size, cell=cell, num_layers=layers, dropout=dropout, activation=activation)
        self.decoder = Decoder(input_sizes[1], embedding_size, hidden_size, cell=cell, num_layers=layers, dropout=dropout, activation=activation)
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=special_tokens['<pad>'])
        self.special_tokens = special_tokens   
        self.beam_size = beam_size 
    def forward(self, input_tensor=[], input_lengths=[], decoder_input=[], decoder_hidden= [], encoder=False):
        if encoder:
            _, decoder_hidden = self.encoder(input_tensor, input_lengths)
            print("decoder input: ", decoder_input.shape)
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

        else:
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        return decoder_output, decoder_hidden

    
    def training_step(self, batch, batch_idx):
        input_tensor, input_lengths, target_tensor, target_lengths = batch
        decoder_input = target_tensor[:, :-1].detach().clone()
        decoder_target = target_tensor[:, 1:].detach().clone()
        loss = 0
        _, decoder_hidden = self.encoder(input_tensor, input_lengths)


        for i in range(target_tensor.shape[1]-1):
            if i ==0:
                # first step
                decoder_output, decoder_hidden = self.decoder(decoder_input[:, i].unsqueeze(1), decoder_hidden)
                #decoder_output, decoder_hidden = self(input_tensor = input_tensor, input_lengths=input_lengths, decoder_input = decoder_input[:, i].unsqueeze(1), encoder=True)
   
                loss += self.loss_fn(decoder_output.squeeze(1), decoder_target[:, i])
                preds = decoder_output.argmax(dim=2).cpu().numpy()
            else:
                # rest of the steps
                decoder_output, decoder_hidden = self.decoder(decoder_input[:, i].unsqueeze(1), decoder_hidden)
                #decoder_output, decoder_hidden = self(decoder_input=decoder_input[:, i].unsqueeze(1), decoder_hidden=decoder_hidden)
 
                loss += self.loss_fn(decoder_output.squeeze(1), decoder_target[:, i])
                preds = np.hstack((preds, decoder_output.argmax(dim=2).cpu().numpy()))
        
        # masking pad tokens and end tokens for accuracy calculation
        
        mask = ~torch.isin(decoder_target[:,:-1], torch.tensor(list(self.special_tokens.values())))

        masked_preds = torch.tensor(preds[:, :-1]).masked_fill(~mask, self.special_tokens['<pad>'])
        masked_targets = decoder_target[:, :-1].masked_fill(~mask, self.special_tokens['<pad>'])
        exact_matches = (masked_preds == masked_targets).all(dim=1)
        accuracy = exact_matches.float().mean()
 
        self.log("train loss", loss, on_step = False, on_epoch = True)
        self.log("train accuracy", accuracy, on_step = False, on_epoch = True)

        return loss

    def validation_step(self, batch, batch_idx):
        input_tensor, input_lengths, target_tensor, target_lengths = batch
        decoder_input = target_tensor[:, :-1].detach().clone()
        decoder_target_batch = target_tensor[:, 1:].detach().clone()
        loss = 0
        _, decoder_hidden_batch = self.encoder(input_tensor, input_lengths)


        best_seqs = []
        for i in range(len(batch)):
            beams = [(torch.tensor([self.special_tokens['<start>']]), 0.0)]  # Start with start token
            decoder_hidden = decoder_hidden_batch[:, i, :].unsqueeze(1)  
            decoder_target = decoder_target_batch[i].unsqueeze(0)
            for j in range(decoder_target.shape[1]-1):
                all_candidates = []
                for seq, score in beams:
                    if seq[-1].item() == self.special_tokens['<end>']:
                        all_candidates.append((seq, score))
                        continue
                    
                    
                    input_token = seq[-1].unsqueeze(0).unsqueeze(0)  # shape (1,1)
                    decoder_output, decoder_hidden = self(decoder_input=input_token, decoder_hidden=decoder_hidden[:,0,:].unsqueeze(1))
                    loss += self.loss_fn(decoder_output.squeeze(1), decoder_target[0, j].unsqueeze(0))

                    log_probs = torch.log_softmax(decoder_output, dim=-1).squeeze(0).squeeze(0)

                    # Get top-k tokens and their log probabilities
                    topk_log_probs, topk_tokens = torch.topk(log_probs, self.beam_size)

                    for k in range(self.beam_size):
                        next_token = topk_tokens[k].unsqueeze(0)
                        new_seq = torch.cat([seq, next_token])
                        new_score = score + topk_log_probs[k].item()
                        all_candidates.append((new_seq, new_score))

                # Select top beam_width sequences
                beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:self.beam_size]

                # Optional: break early if all beams ended with end_token
                if all(seq[-1].item() == self.special_tokens['<end>'] for seq, _ in beams):
                    break

            # Return best sequence (highest score)
            best_seq = beams[0][0]
            best_seqs.append(best_seq)
        
                # Convert best_seqs to tensor

        preds = pad_sequence(best_seqs, batch_first=True, padding_value=self.special_tokens['<pad>'])
        print(preds.shape)
        #preds = torch.stack(best_seqs)
        pred_len = decoder_target_batch.shape[1]
        print(best_seqs[0])
        accuracy  = 0
        #preds = torch.tensor([seq + torch.tensor(self.special_tokens['<pad>'])*(pred_len - len(seq)) for seq in preds])
        #print("valdation")
        #print("preds shape", preds.shape)
        #print("decoder target shape", decoder_target.shape)

        #mask = ~torch.isin(decoder_target[:,:-1], torch.tensor(list(self.special_tokens.values())))
        #masked_preds = preds[:, :-1].masked_fill(~mask, -1)
        #masked_targets = decoder_target[:, :-1].masked_fill(~mask, -1)
        #exact_matches = (masked_preds == masked_targets).all(dim=1)
        #accuracy = exact_matches.float().mean()
    
        self.log("train loss", loss, on_step = False, on_epoch = True)
        self.log("train accuracy", accuracy, on_step = False, on_epoch = True)

        return loss

    def test_step(self, batch, batch_idx):
        input_tensor, input_lengths, target_tensor, target_lengths = batch
        decoder_input = target_tensor[:, :-1].detach().clone()
        decoder_target_batch = target_tensor[:, 1:].detach().clone()
        loss = 0
        _, decoder_hidden_batch = self.encoder(input_tensor, input_lengths)
        

        best_seqs = []
        for i in range(len(batch)):
            beams = [(torch.tensor([self.special_tokens['<start>']]), 0.0)]  # Start with start token
            decoder_hidden = decoder_hidden_batch[i].unsqueeze(0)
            decoder_target = decoder_target_batch[i].unsqueeze(0)
            for j in range(decoder_target.shape[1]-1):
                all_candidates = []
                for seq, score in beams:
                    if seq[-1].item() == self.special_tokens['<end>']:
                        all_candidates.append((seq, score))
                        continue
                    
                    
                    input_token = seq[-1].unsqueeze(0).unsqueeze(0)  # shape (1,1)
                    decoder_output, decoder_hidden = self(decoder_input=input_token, decoder_hidden=decoder_hidden[:,0,:].unsqueeze(1))
                    loss += self.loss_fn(decoder_output.squeeze(1), decoder_target[0, j].unsqueeze(0))

                    log_probs = torch.log_softmax(decoder_output, dim=-1).squeeze(0).squeeze(0)

                    # Get top-k tokens and their log probabilities
                    topk_log_probs, topk_tokens = torch.topk(log_probs, self.beam_size)

                    for k in range(self.beam_size):
                        next_token = topk_tokens[k].unsqueeze(0)
                        new_seq = torch.cat([seq, next_token])
                        new_score = score + topk_log_probs[k].item()
                        all_candidates.append((new_seq, new_score))

                # Select top beam_width sequences
                beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:self.beam_size]

                # Optional: break early if all beams ended with end_token
                if all(seq[-1].item() == self.special_tokens['<end>'] for seq, _ in beams):
                    break

            # Return best sequence (highest score)
            best_seq = beams[0][0]
            best_seqs.append(best_seq)
        
                # Convert best_seqs to tensor

        preds = pad_sequence(best_seqs, batch_first=True, padding_value=self.special_tokens['<pad>'])

        #preds = torch.stack(best_seqs)
        pred_len = decoder_target_batch.shape[1]
        print(preds.shape)
        #preds = [seq + torch.tensor(self.special_tokens['<pad>'])*(pred_len - len(seq)) for seq in preds]

        mask = ~torch.isin(decoder_target_batch[:,:-1], torch.tensor(list(self.special_tokens.values())))
        masked_preds = preds[:, :-1].masked_fill(~mask, -1)
        masked_targets = decoder_target_batch[:, :-1].masked_fill(~mask, -1)
        exact_matches = (masked_preds == masked_targets).all(dim=1)
        accuracy = exact_matches.float().mean()
 
        self.log("train loss", loss, on_step = False, on_epoch = True)
        self.log("train accuracy", accuracy, on_step = False, on_epoch = True)

        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        return logits

    def configure_optimizers(self):
        if self.optim == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9)
        elif self.optim == 'adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer
    


In [25]:
INPUT_SIZE = tokenizer.latin_vocab_size
OUTPUT_SIZE = tokenizer.nat_vocab_size
EMBEDDING_SIZE = 128
HIDDEN_SIZE = 256
MAX_TARGET_LEN = 28  # Set this to the maximum length of your target sequences
SOS_token = tokenizer.native_vocab['<start>']
PAD_TOKEN = tokenizer.native_vocab['<pad>']
EOS_token = tokenizer.native_vocab['<end>']

In [30]:
os.environ['WANDB_API_KEY'] = "761e2f0f9986fd2e6ee9f21ef44a2665e0bc8618"
wandb.login(key=os.getenv("WANDB_API_KEY"))
special_tokens = {key: val for key, val in tokenizer.native_vocab.items() if key in ['<start>', '<end>', '<pad>']}

model = RNN_light(
    input_sizes=(tokenizer.latin_vocab_size, tokenizer.nat_vocab_size),
    embedding_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    cell='rnn',
    layers=3,
    dropout=0.1,
    activation='tanh',
    beam_size=3,
    optim='adam',
    special_tokens=special_tokens,
    lr=0.001
    
)
logger= WandbLogger(project= 'rnntest', name = "test",resume="never")
trainer = pl.Trainer(max_epochs=5,  accelerator="auto",logger=logger, profiler='simple',  precision="16-mixed",)
trainer.fit(model, train_dataloader,  valid_dataloader)
trainer.test(model, dataloaders=test_dataloader)
#trainer.save_checkpoint("trained_model.ckpt")

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | Encoder          | 365 K  | train
1 | decoder | Decoder          | 459 K  | train
2 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]torch.Size([4, 9])
tensor([  0, 175, 103,   0, 111,  90,  50,  86, 178])
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 19.83it/s]torch.Size([4, 10])
tensor([  0,  10, 155,   1])
Epoch 0: 100%|██████████| 2132/2132 [06:36<00:00,  5.37it/s, v_num=qrxt]   torch.Size([4, 5])
tensor([  0, 120,  93,  86,   1])
torch.Size([4, 6])
tensor([  0,   6, 133, 124, 146,   1])
torch.Size([4, 6])
tensor([  0, 216,  28,  15, 171,   1])
torch.Size([4, 6])
tensor([  0, 216,  28,  15, 204,   1])
torch.Size([4, 5])
tensor([ 0,  6, 93, 84,  1])
torch.Size([4, 5])
tensor([ 0,  6, 93, 84,  1])
torch.Size([4, 6])
tensor([  0,   6,  93,  84, 119,   1])
torch.Size([4, 6])
tensor([  0, 216,  28,  15,   1])
torch.Size([4, 6])
tensor([  0,   4,  84,  15, 207,   1])
torch.Size([4, 5])
tensor([ 0,  6, 93,  1])
torch.Size([4, 7])
tensor([  0, 120,  28,  15, 171,   1])
torch.Size([4, 5])
tensor([  0,   6, 106,  82,   1])
torch.Size([4,

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 2132/2132 [06:45<00:00,  5.26it/s, v_num=qrxt]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                          	|  -              	|  398991         	|

Testing DataLoader 0:   0%|          | 0/215 [00:00<?, ?it/s]

TypeError: 'tuple' object cannot be interpreted as an integer