# Read the data

In [28]:
!pip install lightning wandb

[0m

In [29]:
!WANDB_API_KEY=8c780297be240a84f5c8b7d669cb158839b2637a

In [1]:
import pandas as pd
import torch 
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import lightning as pl
from pytorch_lightning.loggers import WandbLogger
import random
import wandb

In [31]:
!WANDB_API_KEY=8c780297be240a84f5c8b7d669cb158839b2637a wandb login

[34m[1mwandb[0m: Currently logged in as: [33mcs20b075[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
path = "aksharantar_sampled/hin"
train_path = path + "/hin_train.csv"
valid_path = path + "/hin_valid.csv"
test_path = path + "/hin_test.csv"

In [3]:
def get_data(path):
    dataset = pd.read_csv(path, header=None)
    dataset = dataset.values
    input = dataset[:, 0]
    output = dataset[:, 1]
    return input, output

In [4]:
train_dataset = get_data(train_path)
val_dataset = get_data(valid_path)

In [5]:
def convert_word_to_tensor(word, lang):
    lang_to_int = {'SOS': 0, 'EOS': 1, 'PAD': 2}
    if lang == 'eng':
        lang_to_int.update({chr(i): i-94 for i in range(97, 123)})
    elif lang == 'hin':
        lang_to_int.update({chr(i): i-2300 for i in range(2304, 2432)})
    
    a = [lang_to_int['SOS']]

    for i in word:
        a.append(lang_to_int[i])

    a.append(lang_to_int['EOS'])
    if len(a) < 24:
        a.extend([lang_to_int['PAD']]*(24-len(a)))
    
    return torch.tensor(a)

In [6]:
class AksharantarDataset(Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.input = dataset[0]
        self.output = dataset[1]
        mask = np.array([len(elem) < 21 for elem in self.input]) & np.array([len(elem) < 21 for elem in self.output])
        self.input = self.input[mask]
        self.output = self.output[mask]
        self.len = len(self.input)
    
    def __getitem__(self, index):
        return convert_word_to_tensor(self.input[index], 'eng'), convert_word_to_tensor(self.output[index], 'hin')
    
    def __len__(self):
        return self.len

In [7]:
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset, val_dataset, batch_size=32):
        super().__init__()
        self.dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        dataset = AksharantarDataset(self.dataset)
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=2)
    def val_dataloader(self):
        dataset = AksharantarDataset(self.val_dataset)
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=2)

In [8]:
train_loader = CustomDataModule(train_dataset, val_dataset, 32)
# val_loader = CustomDataModule(val_dataset, 32)

# Encoder model

In [9]:
class Encoder(pl.LightningModule):
    def __init__(self, input_size, hidden_size, cell_type, num_layers=1, dropout=0, bidirectional=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.cell_type = cell_type
        if cell_type == 'LSTM':
            self.rnn = nn.LSTM
        elif cell_type == 'GRU':
            self.rnn = nn.GRU
        else:
            self.rnn = nn.RNN
        self.direction = 2 if bidirectional else 1
        self.first_cell = self.rnn(hidden_size, hidden_size, bidirectional=bidirectional, batch_first=True)
        self.rnns = nn.ModuleList([self.rnn(hidden_size*self.direction, hidden_size, bidirectional=bidirectional, batch_first=True)]*(num_layers-1))
        self.num_layers = num_layers

    def forward(self, input, hidden):
        embedded = self.embedding(input)
        # embedded = embedded.view(1, 1, -1)
        output = embedded
        output, hidden = self.first_cell(output, hidden)
        for i in range(self.num_layers-1):
            output, hidden = self.rnns[i](output, hidden)
        return output, hidden

    def init_hidden(self):
        if self.cell_type == 'LSTM':
            return torch.zeros(self.direction, self.hidden_size), torch.zeros(self.direction, self.hidden_size)
        return torch.zeros(self.direction, self.hidden_size, device=self.device)

# Decoder

In [10]:
class Decoder(pl.LightningModule):
    def __init__(self, output_size, hidden_size, cell_type, num_layers=1, bidirectional=False, dropout=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        if cell_type == 'LSTM':
            self.cell_type = nn.LSTM
        elif cell_type == 'GRU':
            self.cell_type = nn.GRU
        else:
            self.cell_type = nn.RNN
        self.first_cell = self.cell_type(hidden_size, hidden_size, bidirectional=bidirectional, batch_first=True)
        self.direction = 2 if bidirectional else 1
        self.rnns= nn.ModuleList([self.cell_type(hidden_size*self.direction, hidden_size, bidirectional=bidirectional, batch_first=True)]*(num_layers-1))
        self.out = nn.Linear(hidden_size*self.direction, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.num_layers = num_layers

    def forward(self, input, hidden):
        output = self.embedding(input)
        output = nn.functional.relu(output)
        output, hidden = self.first_cell(output, hidden)
        for i in range(self.num_layers-1):
            output, hidden = self.rnns[i](output, hidden)
        linear_output = self.out(output)
        output = self.softmax(self.out(output))
        if output.shape[0] == 1:
            output = output.squeeze(0)
        return output, hidden

# Seq2seq model

In [11]:
class Seq2seq(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)

    def forward(self, input):
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        decoder_input = target[:, 0].unsqueeze(1)
        for j in range(target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_input = decoder_output.argmax(dim=-1)
        output_sequence = torch.tensor(output_sequence, device=self.device)
        if not batched:
            output_sequence = output_sequence.squeeze(0)
        return output_sequence
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        loss = 0
        correct_words = 0
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        if random.random() < 0.5: 
            decoder_input = target[:, 0].unsqueeze(1)
            correct = None
            for j in range(target_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                squeezed_output = decoder_output.squeeze(1)
                for i in range(batch_size):
                    loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
                decoder_input = target[:, j].unsqueeze(1)
                if correct is None:
                    correct = decoder_output.argmax(dim=-1) == target[:, j]
                else:
                    correct = (decoder_output.argmax(dim=-1) == target[:, j]) & correct
            correct_words = correct.sum()

        else:
            decoder_input = target[:, 0].unsqueeze(1)
            correct = None
            for j in range(target_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                squeezed_output = decoder_output.squeeze(1)
                for i in range(batch_size):
                    loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
                decoder_input = decoder_output.argmax(dim=-1)
                if correct is None:
                    correct = decoder_input == target[:, j]
                else:
                    correct = (decoder_input == target[:, j]) & correct
            correct_words = correct.sum()

        # print(correct_words, batch_size, correct_words/batch_size)
        reported_loss = loss / (batch_size * target_length)
        self.log('train_loss', reported_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', correct_words/batch_size, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        input, target = batch
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        loss = 0
        correct_words = 0
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        decoder_input = target[:, 0].unsqueeze(1)
        correct = None
        for j in range(target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            squeezed_output = decoder_output.squeeze(1)
            for i in range(batch_size):
                loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
            decoder_input = decoder_output.argmax(dim=-1)
            if correct is None:
                correct = decoder_input == target[:, j]
            else:
                correct = (decoder_input == target[:, j]) & correct
        correct_words = correct.sum()

        # for i in range(batch_size):
        #     if self.encoder.cell_type == 'LSTM':
        #         decoder_hidden = encoder_hidden_outputs[i].view(self.decoder.direction, -1), encoder_output_gate[i].view(self.decoder.direction, -1)
        #     else:
        #         decoder_hidden = encoder_hidden_outputs[i].view(self.decoder.direction, -1)
        #     decoder_input = target[i, 0].unsqueeze(0)
        #     correct = True
        #     for j in range(target_length):
        #         decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        #         loss += nn.functional.nll_loss(decoder_output, target[i, j])
        #         decoder_input = torch.tensor([decoder_output.argmax().item()]).to(self.device)
        #         if correct and target[i, j]!= decoder_output.argmax().item():
        #             correct = False
        #     if correct:
        #         correct_words  += 1
        reported_loss = loss / (batch_size * target_length)
        self.log('val_loss', reported_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', correct_words/batch_size, on_epoch=True, prog_bar=True, logger=True)
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [12]:
encoder = Encoder(30, 64, cell_type="LSTM", num_layers=2, bidirectional=True, dropout=0.1)
decoder = Decoder(150, 64, cell_type="LSTM", num_layers=3, bidirectional=True)
model = Seq2seq(encoder, decoder)

In [66]:
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 167 K 
1 | decoder | Decoder | 194 K 
------------------------------------
362 K     Trainable params
0         Non-trainable params
362 K     Total params
1.451     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/1599 [13:38<?, ?it/s] 4.61it/s, v_num=186, train_loss_step=1.220, train_acc_step=0.000]
Epoch 0:   0%|          | 2/1599 [12:46<170:02:08, 383.30s/it, v_num=175, train_loss_step=4.800, train_acc_step=0.000]
Epoch 0:   0%|          | 2/1599 [12:14<162:51:33, 367.12s/it, v_num=176, train_loss_step=4.780, train_acc_step=0.000]
Epoch 0:   0%|          | 0/1599 [09:47<?, ?it/s]
Epoch 0:   0%|          | 1/1599 [08:56<238:04:40, 536.35s/it, v_num=178, train_loss_step=4.980, train_acc_step=0.000]
Epoch 0:   0%|          | 0/1599 [05:13<?, ?it/s]
Epoch 0:   0%|          | 0/1599 [04:55<?, ?it/s]
Epoch 0: 100%|██████████| 1599/1599 [06:32<00:00,  4.07it/s, v_num=186, train_loss_step=1.060, train_acc_step=0.000, val_loss_step=1.150, val_loss_epoch=1.200, val_acc=0.000, train_loss_epoch=1.310, train_acc_epoch=0.000]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1599/1599 [06:32<00:00,  4.07it/s, v_num=186, train_loss_step=1.060, train_acc_step=0.000, val_loss_step=1.150, val_loss_epoch=1.200, val_acc=0.000, train_loss_epoch=1.310, train_acc_epoch=0.000]


In [16]:
def convert_tensor_to_word(tensor, lang):
    int_to_lang = {0: 'SOS', 1: 'EOS', 2: 'PAD'}
    if lang == 'eng':
        int_to_lang.update({i-94: chr(i) for i in range(97, 123)})
    elif lang == 'hin':
        int_to_lang.update({i-2300: chr(i) for i in range(2304, 2432)})
    
    word = ''
    for i in tensor:
        word += int_to_lang[i.item()]
    return word

In [25]:
convert_tensor_to_word(model(convert_word_to_tensor('gharelu', 'eng')), 'hin')

'SOSSOSघरेलूEOSPADPADPADPADPADPADPADPADPADPADPADPADPADPADPADPADPAD'

In [26]:
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'hidden_size': {
            'values': [64, 128, 256],
        },
        'encoder_num_layers': {
            'values': [1, 2, 3],
        },
        'decoder_num_layers': {
            'values': [1, 2, 3],
        },
        'bidirectional': {
            'values': [True, False],
        },
        'cell_type': {
            'values': ['LSTM', 'GRU'],
        },
    }
}

In [31]:
def sweep_fn():
    wandb.init()
    config = wandb.config
    dropout_val = 0
    encoder=Encoder(30, config.hidden_size, config.cell_type, num_layers=config.encoder_num_layers, bidirectional=config.bidirectional)
    decoder = Decoder(150, config.hidden_size, config.cell_type, num_layers=config.decoder_num_layers, bidirectional=config.bidirectional)
    model = Seq2seq(encoder, decoder)
    logger = WandbLogger(project='CS6910 Assignment 3', entity='cs20b075')
    trainer = pl.Trainer(max_epochs=5, precision=16, logger=logger)
    trainer.fit(model, train_loader)

In [28]:
wandb.login(key="8c780297be240a84f5c8b7d669cb158839b2637a")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcs20b075[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/sooraj/.netrc


True

In [None]:
sweep_id = wandb.sweep(sweep=sweep_config, project="CS6910 Assignment 3")
wandb.agent(sweep_id=sweep_id, function=sweep_fn, count=10)

In [None]:
wandb.agent(sweep_id="1aw4o8ik", function=sweep_fn, count=10, project="CS6910 Assignment 3")

In [None]:
wandb.finish()

# Adding attention to the Seq2Seq model

In [97]:
class AttnDecoder(pl.LightningModule):
    def __init__(self, output_size, hidden_size, attention_size, cell_type, num_layers=1, bidirectional=False, dropout=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        if cell_type == 'LSTM':
            self.cell_type = nn.LSTM
        elif cell_type == 'GRU':
            self.cell_type = nn.GRU
        else:
            self.cell_type = nn.RNN
        self.first_cell = self.cell_type(hidden_size, hidden_size, bidirectional=bidirectional, batch_first=True)
        self.direction = 2 if bidirectional else 1
        self.rnns= nn.ModuleList([self.cell_type(hidden_size*self.direction, hidden_size, bidirectional=bidirectional, batch_first=True)]*(num_layers-1))
        self.out = nn.Linear(hidden_size*self.direction, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.num_layers = num_layers

        self.Uattn = nn.Linear(hidden_size*self.direction, attention_size)
        self.Wattn = nn.Linear(hidden_size*self.direction, attention_size)
        self.Vattn = nn.Linear(attention_size, 1)

        self.attn_combine = nn.Linear(hidden_size + hidden_size*self.direction, hidden_size)

    def forward(self, input, hidden, encoder_outputs):
        # print("Am in the decoder")
        # print("Printing the shapes of everything here:")
        # print("Input shape:", input.shape)
        # print("Hidden shape:", hidden.shape)
        # print("Encoder outputs shape:", encoder_outputs.shape)
        encoder_outputs_flat = encoder_outputs.transpose(1, 2).flatten(2)
        hidden_flat = None
        if self.cell_type == nn.LSTM:
            hidden_flat = hidden[0].transpose(0, 1).flatten(1)
        else:
            hidden_flat = hidden.transpose(0, 1).flatten(1)
        # print("Flattened shapes:", encoder_outputs_flat.shape, hidden_flat.shape)
        # print("HIdden shapes:", self.Uattn.shape, self.Wattn.shape, self.Vattn.shape)
        encoder_part = self.Uattn(encoder_outputs_flat)
        # print("got past Uattn", encoder_outputs_flat.shape, encoder_part.shape)
        decoder_part = self.Wattn(hidden_flat.repeat(encoder_outputs.shape[0], 1, 1))
        # print("got past Wattn", decoder_part.shape)
        # ejt = torch.tanh(self.Uattn(encoder_outputs) + self.Wattn(hidden[0].repeat(1, encoder_outputs.shape[1], 1)))
        ejt = torch.tanh(encoder_part + decoder_part)
        at = self.Vattn(ejt).squeeze(-1)
        # print(at.shape)
        at = nn.functional.softmax(at, dim=0)
        at = at.transpose(0, 1).unsqueeze(1)
        # print("Attention", at.shape)
        encoder_outputs_flat = encoder_outputs_flat.transpose(0, 1)
        # print("Encoder outputs flat", encoder_outputs_flat.shape)
        context = torch.bmm(at, encoder_outputs_flat).squeeze(1)
        # print("Context", context.shape)
        
        output = self.embedding(input)
        # print("Output", output.shape)
        output = nn.functional.relu(output)
        output = torch.cat((output.squeeze(1), context), dim=-1).unsqueeze(1)
        output = self.attn_combine(output)
        output, hidden = self.first_cell(output, hidden)
        for i in range(self.num_layers-1):
            output, hidden = self.rnns[i](output, hidden)
        linear_output = self.out(output)
        output = self.softmax(self.out(output))
        if output.shape[0] == 1:
            output = output.squeeze(0)
        return output, hidden

In [98]:
class AttnSeq2seq(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)

    def forward(self, input):
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        decoder_input = target[:, 0].unsqueeze(1)
        for j in range(target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_hidden_outputs)
            decoder_input = decoder_output.argmax(dim=-1)
        output_sequence = torch.tensor(output_sequence, device=self.device)
        if not batched:
            output_sequence = output_sequence.squeeze(0)
        return output_sequence
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        loss = 0
        correct_words = 0
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        if random.random() < 0.5: 
            decoder_input = target[:, 0].unsqueeze(1)
            correct = None
            for j in range(target_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_hidden_outputs)
                squeezed_output = decoder_output.squeeze(1)
                for i in range(batch_size):
                    loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
                decoder_input = target[:, j].unsqueeze(1)
                if correct is None:
                    correct = decoder_output.argmax(dim=-1) == target[:, j]
                else:
                    correct = (decoder_output.argmax(dim=-1) == target[:, j]) & correct
            correct_words = correct.sum()

        else:
            decoder_input = target[:, 0].unsqueeze(1)
            correct = None
            for j in range(target_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_hidden_outputs)
                squeezed_output = decoder_output.squeeze(1)
                for i in range(batch_size):
                    loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
                decoder_input = decoder_output.argmax(dim=-1)
                if correct is None:
                    correct = decoder_input == target[:, j]
                else:
                    correct = (decoder_input == target[:, j]) & correct
            correct_words = correct.sum()

        # print(correct_words, batch_size, correct_words/batch_size)
        reported_loss = loss / (batch_size * target_length)
        self.log('train_loss', reported_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', correct_words/batch_size, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        input, target = batch
        
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)
        
        batched = True if len(input.shape) > 1 else False
        if not batched:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        input = input.to(self.device)
        target = target.to(self.device)
        batch_size = input.shape[0]
        input_length = input.shape[1]
        target_length = target.shape[1]

        encoder_hidden = None
        encoder_hidden_outputs = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        encoder_output_gate = torch.zeros(input_length, self.encoder.direction, batch_size, self.encoder.hidden_size, device=self.device)
        if self.encoder.cell_type == 'LSTM':
            a, b = [torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size)]*2
            encoder_hidden = a.to(self.device), b.to(self.device)
        else:
            encoder_hidden = torch.zeros(self.encoder.direction, batch_size, self.encoder.hidden_size).to(self.device)
        for i in range(input_length):
            # print(input[:, i].shape, encoder_hidden.shape)
            _, encoder_hidden_out = self.encoder(input[:, i].unsqueeze(1), encoder_hidden)
            if self.encoder.cell_type == 'LSTM':
                encoder_hidden_outputs[i] = encoder_hidden_out[0]
                encoder_output_gate[i] = encoder_hidden_out[1]
            else:
                encoder_hidden_outputs[i] = encoder_hidden_out
        loss = 0
        correct_words = 0
        if self.encoder.cell_type == 'LSTM':
            decoder_hidden = encoder_hidden_outputs[-1], encoder_output_gate[-1]
        else:
            decoder_hidden = encoder_hidden_outputs[-1]
        decoder_input = target[:, 0].unsqueeze(1)
        correct = None
        for j in range(target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_hidden_outputs)
            squeezed_output = decoder_output.squeeze(1)
            for i in range(batch_size):
                loss += nn.functional.nll_loss(squeezed_output[i], target[i, j])
            decoder_input = decoder_output.argmax(dim=-1)
            if correct is None:
                correct = decoder_input == target[:, j]
            else:
                correct = (decoder_input == target[:, j]) & correct
        correct_words = correct.sum()

        # for i in range(batch_size):
        #     if self.encoder.cell_type == 'LSTM':
        #         decoder_hidden = encoder_hidden_outputs[i].view(self.decoder.direction, -1), encoder_output_gate[i].view(self.decoder.direction, -1)
        #     else:
        #         decoder_hidden = encoder_hidden_outputs[i].view(self.decoder.direction, -1)
        #     decoder_input = target[i, 0].unsqueeze(0)
        #     correct = True
        #     for j in range(target_length):
        #         decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        #         loss += nn.functional.nll_loss(decoder_output, target[i, j])
        #         decoder_input = torch.tensor([decoder_output.argmax().item()]).to(self.device)
        #         if correct and target[i, j]!= decoder_output.argmax().item():
        #             correct = False
        #     if correct:
        #         correct_words  += 1
        reported_loss = loss / (batch_size * target_length)
        self.log('val_loss', reported_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', correct_words/batch_size, on_epoch=True, prog_bar=True, logger=True)
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [99]:
encoder = Encoder(30, 64, cell_type="LSTM", num_layers=1, bidirectional=True, dropout=0.1)
decoder = AttnDecoder(150, 64, 72, cell_type="LSTM", num_layers=1, bidirectional=True)
model = AttnSeq2seq(encoder, decoder)

In [100]:
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type        | Params
----------------------------------------
0 | encoder | Encoder     | 68.5 K
1 | decoder | AttnDecoder | 126 K 
----------------------------------------
194 K     Trainable params
0         Non-trainable params
194 K     Total params
0.780     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 1/1599 [09:35<255:20:40, 575.24s/it, v_num=213, train_loss_step=5.100, train_acc_step=0.000]
Epoch 0: 100%|██████████| 1599/1599 [04:33<00:00,  5.84it/s, v_num=216, train_loss_step=1.150, train_acc_step=0.000, val_loss_step=0.970, val_loss_epoch=0.994, val_acc=0.000489, train_loss_epoch=1.220, train_acc_epoch=0.000]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1599/1599 [04:33<00:00,  5.84it/s, v_num=216, train_loss_step=1.150, train_acc_step=0.000, val_loss_step=0.970, val_loss_epoch=0.994, val_acc=0.000489, train_loss_epoch=1.220, train_acc_epoch=0.000]


In [106]:
sweep_attn_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'hidden_size': {
            'values': [64, 128, 256],
        },
        'encoder_num_layers': {
            'values': [1, 2, 3],
        },
        'bidirectional': {
            'values': [True, False],
        },
        'cell_type': {
            'values': ['LSTM', 'GRU'],
        },
    }
}

In [107]:
def sweep_attn_fn():
    wandb.init()
    config = wandb.config
    dropout_val = 0
    encoder=Encoder(30, config.hidden_size, config.cell_type, num_layers=config.encoder_num_layers, bidirectional=config.bidirectional)
    decoder = AttnDecoder(150, config.hidden_size, 64, config.cell_type, num_layers=1, bidirectional=config.bidirectional)
    model = AttnSeq2seq(encoder, decoder)
    logger = WandbLogger(project='CS6910 Assignment 3', entity='cs20b075')
    trainer = pl.Trainer(max_epochs=5, precision=16, logger=logger)
    trainer.fit(model, train_loader)

In [108]:
sweep_id = wandb.sweep(sweep=sweep_attn_config, project="CS6910 Assignment 3")
wandb.agent(sweep_id=sweep_id, function=sweep_attn_fn, count=10)

Create sweep with ID: 1t3u8y0a
Sweep URL: https://wandb.ai/cs20b075/CS6910%20Assignment%203/sweeps/1t3u8y0a


[34m[1mwandb[0m: Agent Starting Run: ai7hogtt with config:
[34m[1mwandb[0m: 	bidirectional: False
[34m[1mwandb[0m: 	cell_type: LSTM
[34m[1mwandb[0m: 	encoder_num_layers: 2
[34m[1mwandb[0m: 	hidden_size: 256
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type        | Params
----------------------------------------
0 | encoder | Encoder     | 1.1 M 
1 | decoder | AttnDecoder | 767 K 
----------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.312     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:  10%|▉         | 155/1599 [01:14<11:37,  2.07it/s, v_num=ogtt, train_loss_step=1.090, train_acc_step=0.000]