In [None]:
# Mount Google Drive so the project folder (helper modules, requirements.txt) is reachable.
from google.colab import drive
drive.mount("/content/drive")
import os
import sys
from datetime import datetime

# Put the Drive project folder on sys.path — presumably where the project-local
# modules imported below (data_utils, config_utils, custom_math) live.
drive_project_root = "/content/drive/MyDrive/#fastcampus"
sys.path.append(drive_project_root)
# NOTE(review): this path repeats drive_project_root; keep the two in sync.
!pip install -r "/content/drive/MyDrive/#fastcampus/requirements.txt"

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [None]:
# For data loading.
from typing import List
from typing import Dict
from typing import Union
from typing import Any
from typing import Optional
from typing import Iterable
from abc import abstractmethod
from abc import ABC
from datetime import datetime
from functools import partial
from collections import Counter
from collections import OrderedDict
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pprint import pprint

from torchtext import data
from torchtext import datasets
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import Vocab, build_vocab_from_iterator, vocab
import spacy

# For configuration
from omegaconf import DictConfig
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore

# For logger
from torch.utils.tensorboard import SummaryWriter
import wandb
# Ask wandb to start its run using a thread rather than a subprocess
# (NOTE(review): commonly needed in notebooks/Colab — confirm against the wandb docs).
os.environ["WANDB_START_METHOD"]="thread"

In [None]:
from data_utils import dataset_split
from config_utils import flatten_dict
from config_utils import register_config
from config_utils import configure_optimizers_from_cfg
from config_utils import get_loggers
from config_utils import get_callbacks
from custom_math import softmax

In [None]:
# Download the English and German spaCy models used by get_tokenizer("spacy", ...).
# NOTE(review): on spaCy v3+ the short names `en` / `de` are deprecated aliases for the
# full package names downloaded on the following lines, so those two calls are redundant there.
!python -m spacy download en
!python -m spacy download en_core_web_sm
!python -m spacy download de
!python -m spacy download de_core_news_sm

In [None]:
# Practice run: start by setting up the data config.

# Data config for German -> English translation with spaCy tokenizers.
# Special symbols are listed in index order (UNK=0, PAD=1, BOS=2, EOS=3) so that
# inserting them first into the vocab gives each one exactly its configured index.
data_spacy_de_en_cfg = dict(
    name="spacy_de_en",
    data_root=os.path.join(os.getcwd(), "data"),
    tokenizer="spacy",
    src_lang="de",
    tgt_lang="en",
    src_index=0,
    tgt_index=1,
    vocab=dict(
        special_symbol2index={
            symbol: position
            for position, symbol in enumerate(['<unk>', '<pad>', '<bos>', '<eos>'])
        },
        special_first=True,
        min_freq=2,
    ),
)

data_cfg = OmegaConf.create(data_spacy_de_en_cfg)



In [None]:
# Download (or reuse) all three Multi30k splits under data_cfg.data_root.
# Only the test split is turned into an indexable, map-style dataset here.
train_data, valid_data, test_data = Multi30k(root=data_cfg.data_root)
test_data = to_map_style_dataset(test_data)


In [None]:
# Build one tokenizer per language (source and target). The spaCy models must be
# installed first, e.g.:
#   pip install -U spacy
#   python -m spacy download en_core_web_sm
#   python -m spacy download de_core_news_sm

def get_token_transform(data_cfg: DictConfig) -> dict:
    """Return ``{lang_code: tokenizer}`` for the configured source and target languages."""
    return {
        lang_code: get_tokenizer(data_cfg.tokenizer, language=lang_code)
        for lang_code in (data_cfg.src_lang, data_cfg.tgt_lang)
    }

token_transform = get_token_transform(data_cfg)

In [None]:
# helper function to yield list of tokens
def yield_tokens(
    data_iter: Iterable,
    lang: str,
    lang2index: Dict[str, int],
    tokenizer_map: Optional[Dict[str, Any]] = None,
) -> Iterable[List[str]]:
    """Lazily tokenize the ``lang`` side of every sample in ``data_iter``.

    This is a generator: it yields one token list per sample, as expected by
    ``build_vocab_from_iterator``.

    Args:
        data_iter: iterable of samples indexable by position, e.g. (src, tgt) pairs.
        lang: language code whose sentence should be tokenized.
        lang2index: maps a language code to its position inside each sample.
        tokenizer_map: optional ``{lang: tokenizer}`` mapping. Defaults to the
            notebook-level ``token_transform``, so existing calls are unchanged.

    Yields:
        The list of tokens for each sample, in iteration order.
    """
    tokenizers = token_transform if tokenizer_map is None else tokenizer_map
    for data_sample in data_iter:
        yield tokenizers[lang](data_sample[lang2index[lang]])

def get_vocab_transform(data_cfg: DictConfig) -> dict:
    """Build one torchtext ``Vocab`` per language from the Multi30k training split.

    For ``data_cfg.src_lang`` and ``data_cfg.tgt_lang``, each vocab:
    - keeps tokens seen at least ``data_cfg.vocab.min_freq`` times;
    - inserts the configured special symbols, at the front when
      ``data_cfg.vocab.special_first`` is true;
    - maps unknown tokens to the ``<unk>`` index instead of raising on lookup.

    Returns:
        ``{lang_code: Vocab}``.
    """
    language_pair = (data_cfg.src_lang, data_cfg.tgt_lang)
    # Position of each language's sentence inside a (src, tgt) sample.
    lang2index = {
        data_cfg.src_lang: data_cfg.src_index,
        data_cfg.tgt_lang: data_cfg.tgt_index,
    }
    specials = list(data_cfg.vocab.special_symbol2index.keys())
    unk_index = data_cfg.vocab.special_symbol2index["<unk>"]

    # Read Multi30k from the same root the dataset cell downloads into, so the corpus
    # is not fetched a second time into torchtext's default cache directory.
    data_root = data_cfg.data_root if "data_root" in data_cfg else None
    multi30k_kwargs = {"root": data_root} if data_root else {}

    vocab_transform = {}
    for ln in language_pair:
        # A fresh iterator per language: the previous one has been consumed.
        train_iter = Multi30k(split='train', language_pair=language_pair, **multi30k_kwargs)
        lang_vocab = build_vocab_from_iterator(
            yield_tokens(train_iter, ln, lang2index),
            min_freq=data_cfg.vocab.min_freq,
            specials=specials,
            special_first=data_cfg.vocab.special_first,
        )
        # Without a default index, looking up an unseen token raises RuntimeError.
        lang_vocab.set_default_index(unk_index)
        vocab_transform[ln] = lang_vocab
    return vocab_transform

vocab_transform = get_vocab_transform(data_cfg)

In [None]:
# Sanity check: <unk> sits at the configured index in both vocabs, and any word
# not in the vocab falls back to that index via set_default_index.
for lang_code in ("de", "en"):
    print(vocab_transform[lang_code]["<unk>"])
print(vocab_transform["en"]["hello"], vocab_transform["en"]["world"])

In [None]:
# helper: chain several callables into one pipeline
def sequential_transforms(*transforms):
    """Compose ``transforms`` left to right into a single callable.

    ``sequential_transforms(f, g)(x)`` returns ``g(f(x))``. With no transforms,
    the returned callable is the identity.
    """
    def func(txt_input):
        current = txt_input
        for step in transforms:
            current = step(current)
        return current
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int], bos_index: int, eos_index: int):
    """Wrap ``token_ids`` with BOS/EOS and return them as a 1-D LongTensor.

    Returns:
        ``tensor([bos_index, *token_ids, eos_index])`` with dtype ``torch.long``.
    """
    # An explicit integer dtype matters for empty input: torch.tensor([]) is float32,
    # so the concatenation would fail or come out as floats, and float ids cannot be
    # fed to nn.Embedding downstream.
    return torch.cat((
        torch.tensor([bos_index], dtype=torch.long),
        torch.tensor(token_ids, dtype=torch.long),
        torch.tensor([eos_index], dtype=torch.long),
    ))
    
# src and tgt language text transforms: raw string -> tensor of token indices
def get_text_transform(data_cfg):
    """Build a per-language pipeline: tokenize -> numericalize -> add BOS/EOS as a tensor.

    Uses the notebook-level ``token_transform`` and ``vocab_transform``.

    Returns:
        ``{lang_code: callable(str) -> tensor}`` for the source and target languages.
    """
    specials = data_cfg.vocab.special_symbol2index
    add_bos_eos = partial(
        tensor_transform,
        bos_index=specials["<bos>"],
        eos_index=specials["<eos>"],
    )
    return {
        ln: sequential_transforms(
            token_transform[ln],  # Tokenization
            vocab_transform[ln],  # Numericalization
            add_bos_eos,          # Add BOS/EOS and create tensor
        )
        for ln in (data_cfg.src_lang, data_cfg.tgt_lang)
    }

text_transform = get_text_transform(data_cfg)

In [None]:
# Quick check of the English pipeline: each output begins with the <bos> index
# and ends with the <eos> index; the ids in between come from the vocab.
for sample_text in ("hello", "hello,", "hello, how", "hello, how are you ?"):
    print(text_transform["en"](sample_text))

In [None]:
# function to collate data samples into batch tensors
def collate_fn(batch, data_cfg: DictConfig):
    """Turn raw ``(src_sentence, tgt_sentence)`` pairs into padded id tensors.

    For each sentence:
    - strip the trailing newline;
    - convert it with the notebook-level ``text_transform`` for its language;
    - pad it with the configured ``<pad>`` index.

    ``pad_sequence`` runs with its default ``batch_first=False``, so both returned
    tensors are ``[max_len, batch_size]``.
    """
    pairs = list(batch)
    to_src_ids = text_transform[data_cfg.src_lang]
    to_tgt_ids = text_transform[data_cfg.tgt_lang]
    src_seqs = [to_src_ids(src_text.rstrip("\n")) for src_text, _ in pairs]
    tgt_seqs = [to_tgt_ids(tgt_text.rstrip("\n")) for _, tgt_text in pairs]

    pad_index = data_cfg.vocab.special_symbol2index["<pad>"]
    return (
        pad_sequence(src_seqs, padding_value=pad_index),
        pad_sequence(tgt_seqs, padding_value=pad_index),
    )

def get_collate_fn(cfg: DictConfig):
    """Bind ``cfg.data`` to ``collate_fn`` so it can be handed to a DataLoader."""
    return partial(collate_fn, data_cfg=cfg.data)

def get_multi30k_dataloader(
    split_mode: str,
    language_pair,
    batch_size: int,
    collate_fn,
    shuffle: Optional[bool] = None,
    root: Optional[str] = None,
):
    """Create a DataLoader over one Multi30k split.

    Args:
        split_mode: "train", "valid" or "test".
        language_pair: (src_lang, tgt_lang) tuple passed to Multi30k.
        batch_size: number of samples per batch.
        collate_fn: callable turning a list of raw sentence pairs into tensors.
        shuffle: whether to reshuffle every epoch. ``None`` (the default) shuffles
            only the training split; evaluation splits keep their original order.
        root: dataset directory. ``None`` keeps torchtext's default location.

    Returns:
        A ``torch.utils.data.DataLoader`` over a map-style copy of the split.
    """
    if shuffle is None:
        # Without shuffling, every training epoch sees batches in the same fixed order.
        shuffle = split_mode == "train"
    multi30k_kwargs = {} if root is None else {"root": root}
    raw_split = Multi30k(split=split_mode, language_pair=language_pair, **multi30k_kwargs)
    # Map-style (indexable) datasets are required for DataLoader shuffling.
    dataset = to_map_style_dataset(raw_split)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )


In [None]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU and current CUDA
# device) and ask cuDNN for deterministic kernels, to make runs repeatable.
SEED = 1234

for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)
torch.backends.cudnn.deterministic = True

In [None]:
def _text_postprocessing(res: List[str]) -> str:
    if "<eos>" in res:
        res = res[:res.index("<eos>")]
    if "<pad>" in res:
        res = res[:res.index("<pad>")]
    res = " ".join(res).replace("<bos>", "")
    return res

class BaseTranslateLightningModule(pl.LightningModule):
    """Shared training/evaluation logic for seq2seq translation LightningModules.

    Batches are ``(src, tgt)`` id tensors whose first dimension is the sequence
    (the ``tgt[:-1, :]`` slicing and the ``transpose(…, 0, 1)`` below rely on that),
    e.g. the ``[seq_len, batch]`` tensors produced by ``collate_fn``.

    Subclasses implement ``forward(src, tgt_inputs, teacher_forcing_ratio=...)``.
    It receives ``tgt`` without its last token and must return logits of shape
    ``[len(tgt) - 1, batch, vocab]``, aligned position by position with ``tgt[1:]``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        # Padded positions are excluded from the loss.
        self.loss_function = torch.nn.CrossEntropyLoss(
            ignore_index=cfg.data.vocab.special_symbol2index["<pad>"]
        )

    def configure_optimizers(self):
        """Build optimizers and LR schedulers from ``self.cfg`` via config_utils."""
        self._optimizers, self._schedulers = configure_optimizers_from_cfg(
            self.cfg, self
        )
        return self._optimizers, self._schedulers

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Return per-position logits; see the class docstring for the contract."""
        raise NotImplementedError()

    def _ids_to_texts(self, ids, lang: str) -> List[str]:
        """Decode a ``[seq_len, batch]`` id tensor into one cleaned string per sample."""
        lang_vocab = vocab_transform[lang]
        return [
            _text_postprocessing(lang_vocab.lookup_tokens(sample_ids))
            for sample_ids in torch.transpose(ids, 0, 1).detach().cpu().numpy().tolist()
        ]

    def _forward(self, src, tgt, mode: str, teacher_forcing_ratio: float = 0.5):
        """Compute the loss for one batch and, for val/test, decode example texts.

        Returns:
            ``({f"{mode}_loss": loss}, logs_detail)``. ``logs_detail`` holds the raw
            tensors and, for "val"/"test", the decoded src/tgt/prediction strings.
        """
        assert mode in ["train", "val", "test"]

        # Teacher-forcing layout: the model reads tgt without its last token and is
        # scored against tgt without its first (<bos>) token.
        tgt_inputs = tgt[:-1, :]
        outputs = self(src, tgt_inputs, teacher_forcing_ratio=teacher_forcing_ratio)
        tgt_outputs = tgt[1:, :]

        loss = self.loss_function(
            outputs.reshape(-1, outputs.shape[-1]),
            tgt_outputs.reshape(-1)
        )

        logs_detail = {
            f"{mode}_src": src,
            f"{mode}_tgt": tgt,
            f"{mode}_results": outputs,
        }

        if mode in ["val", "test"]:
            _, tgt_results = torch.max(outputs, dim=2)

            # Always decode through self.cfg so the module does not depend on a
            # notebook-global `cfg`.
            src_texts = self._ids_to_texts(src, self.cfg.data.src_lang)
            tgt_texts = self._ids_to_texts(tgt, self.cfg.data.tgt_lang)
            res_texts = self._ids_to_texts(tgt_results, self.cfg.data.tgt_lang)

            text_result_summary = {
                f"{mode}_src_text": src_texts,
                f"{mode}_tgt_text": tgt_texts,
                f"{mode}_results_text": res_texts,
            }
            print(f"{self.global_step} step: \n src_text: {src_texts[0]}, \n tgt_text: {tgt_texts[0]}, \n result_text:{res_texts[0]}")
            logs_detail.update(text_result_summary)

        return {f"{mode}_loss": loss}, logs_detail

    def training_step(self, batch, batch_idx):
        src, tgt = batch[0], batch[1]
        logs, logs_detail = self._forward(src, tgt, "train", self.cfg.model.teacher_forcing_ratio)
        self.log_dict(logs)
        # Lightning optimizes the value stored under "loss".
        logs["loss"] = logs["train_loss"]
        return logs

    def validation_step(self, batch, batch_idx):
        src, tgt = batch[0], batch[1]
        # No teacher forcing: the model decodes from its own predictions.
        logs, logs_detail = self._forward(src, tgt, "val", 0.0)
        self.log_dict(logs)
        logs["loss"] = logs["val_loss"]
        logs.update(logs_detail)
        return logs

    def test_step(self, batch, batch_idx):
        src, tgt = batch[0], batch[1]
        logs, logs_detail = self._forward(src, tgt, "test", 0.0)
        self.log_dict(logs)
        logs["loss"] = logs["test_loss"]
        logs.update(logs_detail)
        return logs


In [None]:
# weight initialization
def init_weights(model: Union[nn.Module, pl.LightningModule]):
    """Overwrite every parameter of ``model`` in place with draws from U(-0.08, 0.08).

    It is also used through ``module.apply(init_weights)``, which calls it once per
    submodule. Nested parameters are therefore simply re-drawn several times.
    """
    for weight in model.parameters():
        nn.init.uniform_(weight.data, a=-0.08, b=0.08)


In [None]:
class LSTMEncoder(nn.Module):
    """Embed a source id sequence and run it through a (multi-layer) LSTM.

    Only the final hidden and cell states are returned; they seed the decoder.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float
    ):
        super().__init__()

        # Exposed so LSTMSeq2Seq can check encoder/decoder compatibility.
        self.hidden_dim, self.n_layers = hidden_dim, n_layers

        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

        self.apply(init_weights)

    def forward(self, src):
        """Encode ``src``.

        Args:
            src: source ids, ``[src_len, batch]``.

        Returns:
            ``(hidden, cell)``, each ``[n_layers, batch, hidden_dim]`` for this
            unidirectional LSTM.
        """
        embedded = self.embedding(src)        # [src_len, batch, embed_dim]
        embedded = self.dropout(embedded)
        # Per-step outputs are not needed: the decoder starts from the final states.
        _, (hidden, cell) = self.rnn(embedded)
        return hidden, cell


class LSTMDecoder(nn.Module):
    """Single-step LSTM decoder: consumes one token per sequence and scores the next."""

    def __init__(
        self,
        output_dim: int,
        embed_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float,
    ):
        super().__init__()

        # output_dim doubles as the target vocabulary size.
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        """Run one decoding step.

        Args:
            input: current token ids, ``[batch]``.
            hidden: previous hidden state, ``[n_layers, batch, hidden_dim]``.
            cell: previous cell state, ``[n_layers, batch, hidden_dim]``.

        Returns:
            ``(logits [batch, output_dim], new_hidden, new_cell)``.
        """
        step_tokens = input.unsqueeze(0)                          # [1, batch]
        step_embedded = self.dropout(self.embedding(step_tokens))  # [1, batch, embed_dim]
        rnn_out, (new_hidden, new_cell) = self.rnn(step_embedded, (hidden, cell))
        # rnn_out is [1, batch, hidden_dim]; drop the length-1 time axis.
        logits = self.fc_out(rnn_out.squeeze(0))
        return logits, new_hidden, new_cell


class LSTMSeq2Seq(BaseTranslateLightningModule):
    """LSTM encoder-decoder translator (no attention).

    ``forward(src, trg)`` follows the base class's teacher-forcing convention:
    - ``trg`` is the decoder input sequence and starts with the start token;
    - ``outputs[t]`` is the prediction for the token *after* ``trg[t]``.

    So when the base class passes ``tgt[:-1]``, the logits line up with ``tgt[1:]``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)

        self.encoder = LSTMEncoder(**cfg.model.enc)
        self.decoder = LSTMDecoder(**cfg.model.dec)

        # The encoder's final states are fed straight into the decoder, so the
        # state shapes must match.
        assert self.encoder.hidden_dim == self.decoder.hidden_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert self.encoder.n_layers == self.decoder.n_layers, \
            "Encoder and decoder must have equal number of layers!"

        self.apply(init_weights)

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        """Decode ``trg_len`` steps, optionally with teacher forcing.

        Args:
            src: source ids, ``[src_len, batch]``.
            trg: decoder input ids, ``[trg_len, batch]``; ``trg[0]`` is the start token.
            teacher_forcing_ratio: per-step probability of feeding the ground-truth
                next input ``trg[t + 1]`` instead of the model's own argmax prediction.

        Returns:
            Logits ``[trg_len, batch, vocab]``; ``outputs[t]`` predicts the token that
            follows ``trg[t]``.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Allocate on the input's device (robust to DP replicas and non-Lightning use).
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=src.device)

        # The encoder's final states initialise the decoder.
        hidden, cell = self.encoder(src)

        step_input = trg[0, :]
        for t in range(trg_len):
            # The prediction made after reading step_input (== trg[t] under teacher
            # forcing) is the score for the token at position t + 1.
            output, hidden, cell = self.decoder(step_input, hidden, cell)
            outputs[t] = output

            if t + 1 < trg_len:
                teacher_force = random.random() < teacher_forcing_ratio
                step_input = trg[t + 1] if teacher_force else output.argmax(1)

        return outputs


In [None]:

# data configs
# NOTE(review): this cell repeats the data-config, dataset, tokenizer and vocab cells
# from the top of the notebook (presumably so this section can be re-run on its own).
# It rebinds data_cfg, token_transform and vocab_transform; vocab_transform is read as
# a global by BaseTranslateLightningModule._forward, and the model config below sizes
# its layers from it. `test_data` is rebound to the raw split here, not to the
# map-style dataset built earlier.
data_spacy_de_en_cfg = {
    "name": "spacy_de_en",
    "data_root": os.path.join(os.getcwd(), "data"),
    "tokenizer": "spacy",
    "src_lang": "de",
    "tgt_lang": "en",
    "src_index": 0,
    "tgt_index": 1,
    "vocab": {
        "special_symbol2index": {
            # Define special symbols and indices
            # UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
            # Make sure the tokens are in order of their indices to properly insert them in vocab
            '<unk>': 0,
            '<pad>': 1,
            '<bos>': 2,
            '<eos>': 3,
        },
        "special_first": True,
        "min_freq": 2
    }
}

data_cfg = OmegaConf.create(data_spacy_de_en_cfg)

# Download (or reuse) the Multi30k splits under data_cfg.data_root.
train_data, valid_data, test_data = Multi30k(data_cfg.data_root)

token_transform = get_token_transform(data_cfg)
vocab_transform = get_vocab_transform(data_cfg)

In [None]:
# model configs
# Layer sizes come from the vocabularies built above: the encoder embeds
# source-language ids and the decoder emits target-language logits.
model_translate_lstm_seq2seq_cfg = {
    "name": "LSTMSeq2Seq",
    # Output size of the translator = target vocabulary size (the same value as
    # dec.output_dim). Not read by LSTMSeq2Seq itself, which only uses enc/dec.
    "out_dim": len(vocab_transform[data_cfg.tgt_lang]),
    "enc": {
        "input_dim": len(vocab_transform[data_cfg.src_lang]),
        "embed_dim": 256,
        "hidden_dim": 256,
        "n_layers": 2,
        "dropout": 0.5,
    },
    "dec": {
        "output_dim": len(vocab_transform[data_cfg.tgt_lang]),
        "embed_dim": 256,
        "hidden_dim": 256,
        "n_layers": 2,
        "dropout": 0.5,
    },
    # Per-step probability of feeding the ground-truth token during training.
    "teacher_forcing_ratio": 0.5
}

# optimizer configs: one RAdam optimizer. The LR scheduler entry has name None —
# presumably meaning "no scheduler" in config_utils.configure_optimizers_from_cfg;
# its warmup kwargs are kept for when one is enabled.
opt_cfg = dict(
    optimizers=[
        dict(name="RAdam", kwargs=dict(lr=1e-3)),
    ],
    lr_schedulers=[
        dict(name=None, kwargs=dict(warmup_end_steps=1000)),
    ],
)

# Named presets: each key bundles the optimizer, data and model configs defined above.
_merged_cfg_presets = {
    "LSTM_seq2seq_de_en_translate": {
        "opt": opt_cfg,
        "data": data_spacy_de_en_cfg,
        "model": model_translate_lstm_seq2seq_cfg,
    },
}

# Reset Hydra's global state first so this cell can be re-run
# (hydra.initialize refuses to initialize an already-initialized GlobalHydra).
hydra.core.global_hydra.GlobalHydra.instance().clear()

# Register the presets — presumably into Hydra's ConfigStore under their keys, so
# hydra.compose can look them up by name (see config_utils.register_config).
register_config(_merged_cfg_presets)

# Initialize Hydra without a config directory and compose the chosen preset.
# Pick the preset by changing the name passed to hydra.compose (a key of _merged_cfg_presets).
hydra.initialize(config_path=None)
cfg = hydra.compose("LSTM_seq2seq_de_en_translate")

# Unique run name: ISO timestamp (seconds) + model name + data name.
run_name = f"{datetime.now().isoformat(timespec='seconds')}-{cfg.model.name}-{cfg.data.name}"

# Define the train and logging configs, then merge them into cfg below.
# All run artifacts go under <drive project>/runs/de_en_translate/<run_name>.
project_root_dir = os.path.join(
    drive_project_root, "runs", "de_en_translate"
)
# NOTE(review): save_dir equals run_root_dir and is not read by any visible cell.
save_dir = os.path.join(project_root_dir, run_name)
run_root_dir = os.path.join(project_root_dir, run_name)

train_cfg = {
    "train_batch_size": 128,
    "val_batch_size": 32,
    "test_batch_size": 32,
    # NOTE(review): not read by any visible cell; the dataloaders use Multi30k's own splits.
    "train_val_split": [0.9, 0.1],
    "run_root_dir": run_root_dir,
    # Passed verbatim to pl.Trainer(**...). NOTE(review): "accelerator": "dp" and
    # flush_logs_every_n_steps may be deprecated/removed in newer Lightning versions.
    "trainer_kwargs": {
        "accelerator": "dp",
        "gpus": "0",
        "max_epochs": 50,
        "val_check_interval": 1.0,
        "log_every_n_steps": 100,
        "flush_logs_every_n_steps": 100,
    }
}

# logger config — presumably consumed by config_utils.get_loggers / get_callbacks.
log_cfg = {
    "loggers": {
        "WandbLogger": {
            "project": "fastcampus_de_en_translate_tutorials",
            "name": run_name,
            "tags": ["fastcampus_de_en_translate_tutorials"],
            "save_dir": run_root_dir,
        },
        "TensorBoardLogger": {
            "save_dir": project_root_dir,
            "name": run_name,
            "log_graph": True,
        }
    },
    "callbacks": {
        # Keep the 3 checkpoints with the lowest val_loss.
        "ModelCheckpoint": {
            "save_top_k": 3,
            "monitor": "val_loss",
            "mode": "min",
            "verbose": True,
            "dirpath": os.path.join(run_root_dir, "weights"),
            "filename": "{epoch}-{val_loss:.3f}",
        },
        # Stop after 3 validation checks without val_loss improvement.
        "EarlyStopping": {
            "monitor": "val_loss",
            "mode": "min",
            "patience": 3,
            "verbose": True,
        }
    }
}

# Struct mode forbids adding new keys, so turn it off to attach train/log sections.
OmegaConf.set_struct(cfg, False)
cfg.train = train_cfg 
cfg.log = log_cfg

# Re-enable struct mode so typos in later key accesses raise instead of passing silently.
OmegaConf.set_struct(cfg, True)
print(OmegaConf.to_yaml(cfg))

In [None]:
# dataloader def: one DataLoader per Multi30k split, all with the configured
# (src_lang, tgt_lang) pair and the cfg-bound collate function.
train_dataloader, val_dataloader, test_dataloader = (
    get_multi30k_dataloader(
        split_name,
        (cfg.data.src_lang, cfg.data.tgt_lang),
        split_batch_size,
        collate_fn=get_collate_fn(cfg),
    )
    for split_name, split_batch_size in (
        ("train", cfg.train.train_batch_size),
        ("valid", cfg.train.val_batch_size),
        ("test", cfg.train.test_batch_size),
    )
)

In [None]:
# model definition
def get_pl_model(cfg: DictConfig, checkpoint_path: Optional[str] = None):
    """Build the LightningModule named by ``cfg.model.name``, optionally from a checkpoint.

    Args:
        cfg: composed config; ``cfg.model.name`` selects the model class.
        checkpoint_path: if given, weights are restored from this Lightning checkpoint.

    Raises:
        NotImplementedError: if ``cfg.model.name`` is not a known model.
    """
    model_classes = {"LSTMSeq2Seq": LSTMSeq2Seq}
    model_cls = model_classes.get(cfg.model.name)
    if model_cls is None:
        raise NotImplementedError("not implemented model")

    if checkpoint_path is not None:
        # load_from_checkpoint is a classmethod that constructs a new instance. The
        # modules never call save_hyperparameters(), so cfg has to be passed again
        # for __init__.
        return model_cls.load_from_checkpoint(checkpoint_path=checkpoint_path, cfg=cfg)
    return model_cls(cfg)

model = get_pl_model(cfg)
print(model)

In [None]:
# pytorch-lightning trainer def: loggers and callbacks come from the project helpers
# (presumably reading cfg.log); run directory and trainer kwargs come from cfg.train.
logger = get_loggers(cfg)
callbacks = get_callbacks(cfg)

trainer = pl.Trainer(
    default_root_dir=cfg.train.run_root_dir,
    logger=logger,
    callbacks=callbacks,
    num_sanity_val_steps=2,
    **cfg.train.trainer_kwargs,
)

In [None]:
# Train. Validation on val_dataloader runs once per epoch (val_check_interval=1.0), and
# the ModelCheckpoint / EarlyStopping callbacks both monitor its val_loss.
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
def post_evaluate(cfg, model, src, device: str = "cpu"):
    """Greedy-decode every batch of a dataloader and return readable text triples.

    Args:
        cfg: composed config; ``cfg.data.src_lang`` / ``tgt_lang`` select the vocabs.
        model: translator called as ``model(src, tgt, teacher_forcing_ratio=0.0)``.
        src: iterable of ``(src_ids, tgt_ids)`` batches, each ``[seq_len, batch]``
            (e.g. ``test_dataloader``). ``None`` falls back to ``test_dataloader``.
        device: device the model and each batch are moved to.

    Returns:
        ``[(src_text, tgt_text, predicted_text), ...]`` in dataloader order.
    """
    dataloader = test_dataloader if src is None else src
    model = model.to(device)
    model.eval()

    src_vocab = vocab_transform[cfg.data.src_lang]
    tgt_vocab = vocab_transform[cfg.data.tgt_lang]

    def _decode(ids, lang_vocab) -> List[str]:
        # [seq_len, batch] ids -> one cleaned string per sample, in batch order.
        return [
            _text_postprocessing(lang_vocab.lookup_tokens(sample_ids))
            for sample_ids in torch.transpose(ids, 0, 1).detach().cpu().tolist()
        ]

    src_texts, tgt_texts, res_texts = [], [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for src_ids, tgt_ids in dataloader:
            src_ids = src_ids.to(device)
            tgt_ids = tgt_ids.to(device)

            # teacher_forcing_ratio=0.0: every step is fed the model's own prediction.
            logits = model(src_ids, tgt_ids, teacher_forcing_ratio=0.0)
            _, tgt_results = torch.max(logits, dim=2)

            src_texts.extend(_decode(src_ids, src_vocab))
            tgt_texts.extend(_decode(tgt_ids, tgt_vocab))
            res_texts.extend(_decode(tgt_results, tgt_vocab))

    assert len(src_texts) == len(tgt_texts) == len(res_texts)
    return list(zip(src_texts, tgt_texts, res_texts))


# NOTE(review): this must point at a checkpoint from an earlier run; its filename does
# not follow the current ModelCheckpoint template "{epoch}-{val_loss:.3f}".
checkpoint_path = os.path.join(project_root_dir, "2021-08-16T10:06:43-LSTMSeq2Seq-spacy_de_en", "weights", "epoch=22-val_loss=3.856-val_acc=0.000.ckpt")
# load_from_checkpoint builds a new instance; cfg is passed because the module does
# not store its hyperparameters in the checkpoint.
model = model.load_from_checkpoint(cfg=cfg, checkpoint_path=checkpoint_path)
test_translations = post_evaluate(cfg, model, test_dataloader)
pprint(test_translations[:5])

In [None]:

def _seq2seq_batch_loss(model, src, tgt, loss_fn, device, **model_kwargs):
    """Move one ``(src, tgt)`` batch to ``device`` and return its teacher-forcing loss."""
    src = src.to(device)
    tgt = tgt.to(device)
    tgt_input = tgt[:-1, :]  # decoder input: drop the final token
    tgt_out = tgt[1:, :]     # target: drop the leading start token
    logits = model(src, tgt_input, **model_kwargs)
    return loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))


def train_epoch(model, optimizer, dataloader=None, loss_fn=None, device=None):
    """Run one training epoch with a plain PyTorch loop and return the mean batch loss.

    Args:
        model: seq2seq module called as ``model(src, tgt[:-1])``.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: iterable of ``(src, tgt)`` batches; defaults to the notebook's
            ``train_dataloader``.
        loss_fn: defaults to ``model.loss_function`` (set by BaseTranslateLightningModule).
        device: defaults to the device of the model's first parameter.

    Returns:
        Mean loss over the batches seen; 0.0 if the dataloader yields nothing.
    """
    if dataloader is None:
        dataloader = train_dataloader
    if loss_fn is None:
        loss_fn = model.loss_function
    if device is None:
        device = next(model.parameters()).device

    model.train()
    total_loss, n_batches = 0.0, 0
    for src, tgt in dataloader:
        optimizer.zero_grad()
        loss = _seq2seq_batch_loss(model, src, tgt, loss_fn, device)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches instead of len(dataloader): not every iterable has a length.
    return total_loss / max(n_batches, 1)


def evaluate(model, dataloader=None, loss_fn=None, device=None):
    """Return the mean loss of ``model`` on a dataloader without updating it.

    Decoding uses ``teacher_forcing_ratio=0.0``, as ``validation_step`` does, so the
    model must accept that keyword.

    Args:
        model: seq2seq module called as ``model(src, tgt[:-1], teacher_forcing_ratio=0.0)``.
        dataloader: defaults to the notebook's ``val_dataloader``.
        loss_fn: defaults to ``model.loss_function``.
        device: defaults to the device of the model's first parameter.

    Returns:
        Mean loss over the batches seen; 0.0 if the dataloader yields nothing.
    """
    if dataloader is None:
        dataloader = val_dataloader
    if loss_fn is None:
        loss_fn = model.loss_function
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for src, tgt in dataloader:
            loss = _seq2seq_batch_loss(
                model, src, tgt, loss_fn, device, teacher_forcing_ratio=0.0
            )
            total_loss += loss.item()
            n_batches += 1
    return total_loss / max(n_batches, 1)

In [None]:
# Manual (non-Lightning) training loop, an alternative to trainer.fit above.
# TODO(review): NUM_EPOCHS, timer, optimizer and translate are not defined in any
# visible cell, so this cell raises NameError on a fresh kernel. `timer` is presumably
# `from timeit import default_timer as timer`, as in the PyTorch translation tutorial.
# NOTE(review): desc=f"epoch_{i}:" is formatted once, before the loop starts, so the
# progress bar label always reads "epoch_0:".
i = 0
for i, epoch in enumerate(tqdm(range(1, NUM_EPOCHS+1), position=0, leave=True, desc=f"epoch_{i}:")):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    val_loss = evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

    # Translate a fixed German sentence to eyeball progress after each epoch.
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))