In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class Encoder(nn.Module):
    def __init__(self,input_dim,emb_dim,hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim,emb_dim)
        self.rnn=nn.GRU(emb_dim,hidden_dim)
    def forward(self,src):
        embedded = self.embedding(src)
        outputs,hidden = self.rnn(embedded)
        return hidden
    

In [4]:
class Decoder(nn.Module):
    def __init__(self,output_dim,emb_dim,hidden_dim):
        super().__init__()
        self.embedding= nn.Embedding(output_dim,emb_dim)
        self.rnn=nn.GRU(emb_dim,hidden_dim)
        self.fc=nn.Linear(hidden_dim,output_dim)

    def forward(self,input,hidden):
        input=input.unsqueeze(0)
        embedded=self.embedding(input)
        output,hidden=self.rnn(embedded,hidden)
        prediction=self.fc(output.squeeze(0))
        return prediction,hidden

In [6]:
class Seq2SeqModel(nn.Module):
    def __init__(self,encoder,decoder,device):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.device=device

    def forward(self,src,trg=None,max__len=10,teacher_forcing_ratio=0.5):
        batch_size-=src.shape[1]
        trg_vocab_size=self.decoder.fc.out_features
        outputs=[]

        hidden=self.encoder(src)

        input=torch.zeros(batch_size,dtype=torch.long).to(self.device)

        for t in range(max__len):
            output,hidden=self.decoder(input,hidden)
            top1=output.argmax(1)
            outputs.append(top1.unsqueeze(0))
            if trg is not None and t < trg.shape[0] and torch.rand(1).item() < teacher_forcing_ratio:
                input=trg[t]
            else:
                input=top1
        outputs=torch.cat(outputs,dim=0)
        return outputs
            

In [7]:
VOCAB_SIZE=10
EMB_DIM=8 
HID_DIOM=16
SEQ_LEN=5 
BATCH_SIZE=2

encoder=Encoder(VOCAB_SIZE,EMB_DIM,HID_DIOM)
decoder=Decoder(VOCAB_SIZE,EMB_DIM,HID_DIOM)
device=torch.device('cpu')
model=Seq2SeqModel(encoder,decoder,device).to(device)
src=torch.randint(1,VOCAB_SIZE,(SEQ_LEN,BATCH_SIZE)).to(device)
trg=torch.randint(1,VOCAB_SIZE,(SEQ_LEN,BATCH_SIZE)).to(device)

outputs=model(src,trg,max__len=SEQ_LEN,teacher_forcing_ratio=0.75)

print(src.T)

print(outputs.T)

print(trg.T)


UnboundLocalError: cannot access local variable 'batch_size' where it is not associated with a value