# 10. Encoder-Decoder Architecture

In [1]:
import torch 
import torch.nn as nn
from torch.utils import data
from torch.nn import functional as F

import os
import re
import collections

import math
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from d2l import torch as d2l

## Encoder-Decoder

To handle inputs of **different lengths**, we implement the following **encoder-decoder architecture**, where the **encoder** takes a variable-length sequence as input and transforms it into a **fixed-shape encoded state**:

![](http://d2l.ai/_images/encoder-decoder.svg)

The **decoder** acts as a conditional language model that predicts the **subsequent token** in the target sequence by taking in the **encoded input** and the **leftwards context** of the target sequence.

## Encoder

In the **encoder interface**, we specify that the encoder takes **variable-length sequences** as input X. 

The actual implementation is provided by any model that **inherits** from this base Encoder class.

In [2]:
class Encoder(nn.Module):
    
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        
    def forward(self, X, *args):
        raise NotImplementedError
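
As a minimal sketch of how a model might implement this interface, the cell below subclasses the base class with a single-layer GRU. The class name `ToyGRUEncoder` and all sizes are hypothetical choices made for illustration, and the base class is repeated so the cell runs on its own. It feeds two batches of different sequence lengths through the encoder to show that the final state has the same shape in both cases.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Base encoder interface, repeated here so the sketch is self-contained."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class ToyGRUEncoder(Encoder):
    """Hypothetical encoder for illustration: embedding followed by a GRU."""

    def __init__(self, vocab_size, embed_size, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        # X: (batch_size, num_steps) tensor of token indices
        emb = self.embedding(X)  # (batch_size, num_steps, embed_size)
        # outputs: (batch_size, num_steps, num_hiddens)
        # state:   (1, batch_size, num_hiddens), independent of num_steps
        outputs, state = self.rnn(emb)
        return outputs, state


encoder = ToyGRUEncoder(vocab_size=10, embed_size=8, num_hiddens=16)
short_outputs, short_state = encoder(torch.zeros((4, 3), dtype=torch.long))
long_outputs, long_state = encoder(torch.zeros((4, 7), dtype=torch.long))
print(short_state.shape, long_state.shape)
```

The per-step outputs grow with the sequence length, while the final hidden state keeps a fixed shape, which is what the decoder will consume.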

## Decoder

In the following **decoder interface**, we add an **init_state** method that converts the **encoder output (enc_outputs)** into the encoded state.

Note that this step may require **extra inputs**, such as the valid length of the input. To generate a **variable-length sequence** token by token, at each time step the decoder maps an input (e.g., the **generated token at the previous time step**) and the **encoded state** to an output token at the current time step.

In [None]:
class Decoder(nn.Module):
    
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError
        
    def forward(self, X, state):
        raise NotImplementedError
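
The token-by-token generation described above can be sketched as follows. This is only an illustration under stated assumptions: `ToyGRUDecoder` is a hypothetical name, the encoder output is faked with zero tensors in the form `(outputs, final_hidden_state)`, index 0 is assumed to be the start token, and greedy argmax selection stands in for whatever decoding strategy a real model would use.

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Base decoder interface, repeated here so the sketch is self-contained."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class ToyGRUDecoder(Decoder):
    """Hypothetical decoder for illustration: embedding, GRU, output layer."""

    def __init__(self, vocab_size, embed_size, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        # Assumption: enc_outputs is (outputs, final_hidden_state);
        # the final hidden state becomes the decoder's initial state.
        return enc_outputs[1]

    def forward(self, X, state):
        emb = self.embedding(X)                # (batch, steps, embed_size)
        outputs, state = self.rnn(emb, state)  # (batch, steps, num_hiddens)
        return self.dense(outputs), state      # logits over the vocabulary


decoder = ToyGRUDecoder(vocab_size=10, embed_size=8, num_hiddens=16)

# Stand-in for a real encoder's output (all zeros, shapes only)
fake_enc_outputs = (torch.zeros(4, 7, 16), torch.zeros(1, 4, 16))
state = decoder.init_state(fake_enc_outputs)

# Greedy generation: feed the previous token back in at every step
token = torch.zeros((4, 1), dtype=torch.long)  # assumed start-token index 0
generated = []
for _ in range(5):
    logits, state = decoder(token, state)  # logits: (4, 1, 10)
    token = logits.argmax(dim=-1)          # (4, 1), next input token
    generated.append(token)
generated = torch.cat(generated, dim=1)    # (4, 5)
print(generated.shape)
```

Because the loop length is chosen at generation time, the same decoder can produce target sequences of any length from one fixed-shape state.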

## Combination

In [None]:
class EncoderDecoder(nn.Module):
    
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        
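    def forward(self, enc_X, dec_X, *args):
        # Encode the source sequence; *args carries optional extras such as valid lengths
        enc_outputs = self.encoder(enc_X, *args)
        # Convert the encoder output into the decoder's initial state
        dec_state = self.decoder.init_state(enc_outputs, *args)
        # Decode the target-side input conditioned on that state
        return self.decoder(dec_X, dec_state)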
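
To see the three pieces work together, here is a small end-to-end sketch. `ToyEncoder` and `ToyDecoder` are hypothetical GRU-based stand-ins that subclass `nn.Module` directly to keep the cell short; they follow the same interface as the base classes above. The vocabulary sizes, hidden size, and sequence lengths are arbitrary assumptions, and the `EncoderDecoder` class is repeated so the cell runs on its own.

```python
import torch
import torch.nn as nn


class EncoderDecoder(nn.Module):
    """Same wrapper as above, repeated so the sketch is self-contained."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)


class ToyEncoder(nn.Module):
    """Hypothetical encoder: returns (outputs, final_hidden_state)."""

    def __init__(self, vocab_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        return self.rnn(self.embedding(X))


class ToyDecoder(nn.Module):
    """Hypothetical decoder: starts from the encoder's final hidden state."""

    def __init__(self, vocab_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        outputs, state = self.rnn(self.embedding(X), state)
        return self.dense(outputs), state


# Source vocabulary of 10 tokens, target vocabulary of 12 tokens (arbitrary)
model = EncoderDecoder(ToyEncoder(10, 16), ToyDecoder(12, 16))
enc_X = torch.zeros((4, 7), dtype=torch.long)  # source: 7 time steps
dec_X = torch.zeros((4, 5), dtype=torch.long)  # target input: 5 time steps
logits, state = model(enc_X, dec_X)
print(logits.shape, state.shape)
```

The source and target lengths differ (7 and 5 steps), yet the model produces one vector of target-vocabulary logits per target time step.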