In [1]:
!wget https://www.manythings.org/anki/rus-eng.zip
!unzip rus-eng.zip

--2023-03-25 11:08:53--  https://www.manythings.org/anki/rus-eng.zip
Resolving www.manythings.org (www.manythings.org)... 173.254.30.110
Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15374406 (15M) [application/zip]
Saving to: ‘rus-eng.zip’


2023-03-25 11:08:57 (5.48 MB/s) - ‘rus-eng.zip’ saved [15374406/15374406]

Archive:  rus-eng.zip
  inflating: rus.txt                 
  inflating: _about.txt              


In [2]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import sys, os

sys.path.append(os.path.join(os.getcwd(), "./src"))

from models.rnn import Seq2SeqRNN
from data.datamodule import DataManager

In [3]:
eng_prefixes = (
    "i am ",
    "i m ",
    "he is",
    "he s ",
    "she is",
    "she s ",
    "you are",
    "you re ",
    "we are",
    "we re ",
    "they are",
    "they re ",
)

def filter_func(x):
    len_filter = lambda x: max(len(x[0].split(" ")), len(x[1].split(" "))) <= 5
    prefix_filter = lambda x: x[0].startswith(eng_prefixes)
    return len_filter(x) and prefix_filter(x)

config = {
    "batch_size": 256,          # <--- batch size
    "num_workers": 16,          # <--- number of CPU workers used by the dataloader
    "filter": filter_func,      # <--- callable used to filter sentence pairs
    "filename": "./rus.txt",    # <--- path to the file with sentences
    "lang1": "en",              # <--- name of the first language
    "lang2": "ru",              # <--- name of the second language
    "reverse": False,           # <--- direct or reverse order in pairs
    "train_size": 0.8,          # <--- share of sentence pairs used for training
    "run_name": "tutorial",     # <--- run name for the logger and checkpoints
    "quantile": 0.95,           # <--- the longest (1 - quantile) share of sentences is removed
}
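
To make the effect of `filter_func` concrete, here is a small self-contained sketch that applies the same two checks to a few made-up sentence pairs. It assumes the pairs are already lowercased with punctuation stripped (the `"i m "`-style prefixes suggest this, but the normalization inside `DataManager` is not shown here). The sample pairs and the name `keep_pair` are hypothetical.

```python
# Prefixes copied from the cell above
eng_prefixes = (
    "i am ", "i m ", "he is", "he s ", "she is", "she s ",
    "you are", "you re ", "we are", "we re ", "they are", "they re ",
)

def keep_pair(pair, max_words=5):
    """Same logic as filter_func: both sentences short, English side starts with a known prefix."""
    eng, rus = pair
    short_enough = max(len(eng.split(" ")), len(rus.split(" "))) <= max_words
    return short_enough and eng.startswith(eng_prefixes)

# Hypothetical, already-normalized pairs (not taken from rus.txt)
samples = [
    ("i am cold", "мне холодно"),                          # kept
    ("tom is here", "том здесь"),                          # dropped: no matching prefix
    ("he is a very tall man indeed", "он очень высокий"),  # dropped: 7 English words
]
kept = [p for p in samples if keep_pair(p)]
print(kept)
```

Note that prefixes without a trailing space, such as `"he is"`, also match longer words that merely start with those characters, so the filter is slightly more permissive for them.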

In [4]:
# Data manager
dm = DataManager(config)
dm.prepare_data()
dm.setup()

input_lang_n_words = dm.input_lang_n_words
output_lang_n_words = dm.output_lang_n_words

Reading from file: 100%|██████████| 464010/464010 [00:05<00:00, 80902.12it/s]


In [5]:
model = Seq2SeqRNN(
    encoder_vocab_size=input_lang_n_words,
    encoder_embedding_size=256,
    decoder_embedding_size=256,
    decoder_output_size=output_lang_n_words,
    lr=1e-3,
    output_lang_index2word=dm.train_dataset.output_lang.index2word,
)

In [6]:
# TB Logger
logger = TensorBoardLogger("lightning_logs", name=config["run_name"])

# Callbacks
checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    dirpath="runs/{}/".format(config["run_name"]),
    filename="{epoch:02d}-{step:d}-{val_loss:.4f}",
    verbose=True,
    every_n_epochs=1,
)
lr_monitor = LearningRateMonitor(logging_interval="step")

# Initialize a Trainer
trainer = pl.Trainer(
    accelerator="cuda",
    devices=[2],
    precision=16,
    max_epochs=50,
    min_epochs=1,
    callbacks=[lr_monitor, checkpoint_callback],
    check_val_every_n_epoch=1,
    logger=logger,
    log_every_n_steps=1,
)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(model, dm)

Reading from file: 100%|██████████| 464010/464010 [00:05<00:00, 82091.54it/s]
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/tutorial
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | EncoderRNN | 1.1 M 
1 | decoder   | DecoderRNN | 3.3 M 
2 | criterion | NLLLoss    | 0     
-----------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.465    Total estimated model params size (MB)
2023-03-25 11:09:28.243083: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
проблема не в ---> чрезвычайно ответственность расстроены пиво
я жду обеда ---> завидую находчивый твердая олимпиаду
мы так счастливы ---> тактична делать худой шляпе
я на вашей ---> голодны спокойная пьёт помолвлена
я предвзят ---> завидую находчивый твердая олимпиаду
вы владельцы ---> строен запад запад гольф
мы не в ---> чрезвычайно ответственность расстроены пиво
завтра мы не ---> мальчики шарф завидуете культурист
ей семнадцать лет ---> чрезвычайно ответственность расстроены пиво
ты совершенно здорова ---> тупорылый стреляю океанограф ведёте
мы все счастливы ---> чрезвычайно ответственность расстроены пиво
он ленивый ---> завидую находчивый твердая олимпиаду
мы не такие ---> тупорылый ожидаем говорите говорите
я устал от ---> строен запад запад гольф
он оскорблён ---> влюблена сердиты усердный выживем
Epoch 0: 100%|██████████| 41/41 [00:05<00:00,  7.33it/s, v_num=0, train_bleu=0.000, train_loss=3.120]проблема не в 

Epoch 0, global step 41: 'val_loss' reached 3.12502 (best 3.12502), saving model to '/home/toomuch/somov/machine-translation/fork/pytorch-machine-translation/runs/tutorial/epoch=00-step=41-val_loss=3.1250.ckpt' as top 3


Epoch 1:  44%|████▍     | 18/41 [00:04<00:06,  3.77it/s, v_num=0, train_bleu=0.000, train_loss=2.960]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


You can now view the TensorBoard logs in two ways: from the command line (CLI) or with the TensorBoard extension for Jupyter notebooks.

CLI:
1. Run `tensorboard --logdir=./lightning_logs --port=6006`
2. If you are working on a remote machine, forward the chosen port over SSH (for example `ssh -L 6006:localhost:6006 <user>@<host>`) and open `localhost:6006` in your local browser; otherwise just open `localhost:6006` in the browser


Jupyter:
1. Load extension: `%load_ext tensorboard`
2. Launch built-in tensorboard: `%tensorboard --logdir=./lightning_logs`

In [None]:
%load_ext tensorboard
%tensorboard --logdir=./lightning_logs

# Hints

#### Load model from checkpoint
Calling `self.save_hyperparameters()` in the `__init__` of a `pl.LightningModule` stores the constructor arguments in the checkpoint, so the model can be restored without passing them again:
```python
from models.rnn import Seq2SeqRNN  # same import as at the top of the notebook

model = Seq2SeqRNN.load_from_checkpoint('checkpoint.ckpt')
```

Alternatively, you can load the weights the plain PyTorch way:
```python
import torch

model = Seq2SeqRNN(*args, **kwargs)  # same constructor arguments as in training
model.load_state_dict(torch.load('checkpoint.ckpt')['state_dict'])  # weights are stored under 'state_dict'
```

#### Add custom metrics to logger
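Call `self.log(name, value)` inside `training_step` or `validation_step` to send a metric to the logger. See the logging guide for details:
https://lightning.ai/docs/pytorch/stable/extensions/logging.html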


#### Enable grad accumulation
In this example, gradients are accumulated as follows (epochs are zero-indexed):
1. epochs 0 to 14: accumulate 4 batches
2. epochs 15 to 24: accumulate 2 batches
3. from epoch 25 onwards: accumulate 1 batch (no accumulation)

```python
from pytorch_lightning.callbacks import GradientAccumulationScheduler

accumulator = GradientAccumulationScheduler(scheduling={0: 4, 15: 2, 25: 1})
trainer = pl.Trainer(
    ...
    callbacks=[..., accumulator],
    ...
)
```
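
As a rough illustration of how such a schedule maps an epoch to an accumulation factor, the sketch below resolves the factor in plain Python. It is not Lightning's implementation: it assumes zero-indexed epochs and that the entry with the largest start epoch not exceeding the current epoch applies. The helper name `accumulation_factor` is hypothetical, and the batch size of 256 is taken from `config` above.

```python
def accumulation_factor(scheduling, epoch):
    """Return the factor of the latest schedule entry whose start epoch is <= epoch."""
    factor = 1  # default when no entry has started yet
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor

scheduling = {0: 4, 15: 2, 25: 1}
batch_size = 256

for epoch in (0, 14, 15, 24, 25, 49):
    factor = accumulation_factor(scheduling, epoch)
    # Effective batch size = samples contributing to one optimizer step
    print(epoch, factor, batch_size * factor)
```

Under these assumptions, accumulating 4 batches of 256 gives an effective batch of 1024 samples per optimizer step in the early epochs, shrinking to 256 from epoch 25 on.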

#### Configure learning rate scheduler

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=0,
        threshold=1e-2,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-09,
        verbose=True
    )
    lr_dict = {
        "scheduler": lr_scheduler,
        "interval": "epoch",
        "frequency": 1,
        "monitor": "val_loss"
    }
    return [optimizer], [lr_dict]
```
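
To see what these settings imply, here is a simplified pure-Python re-implementation of the plateau rule for `mode='min'` and `threshold_mode='rel'`. It is a sketch of the documented behaviour, not PyTorch's code, and it ignores `cooldown` and `eps`. The function name `simulate_plateau` and the loss values are made up for illustration.

```python
import math

def simulate_plateau(losses, lr=1e-3, factor=0.1, patience=0, threshold=1e-2, min_lr=0.0):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = math.inf
    bad_epochs = 0
    history = []
    for loss in losses:
        # 'rel' threshold in 'min' mode: the loss must beat the best value by at least 1%
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce the learning rate once the number of bad epochs exceeds the patience
        if bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history

# Hypothetical validation losses: the third epoch improves by less than 1%
print(simulate_plateau([3.0, 2.0, 1.99, 1.0]))
```

With `patience=0`, a single epoch without a relative improvement of at least 1% is enough to cut the learning rate tenfold, which is a fairly aggressive setting; a larger `patience` tolerates more noise in `val_loss`.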

#### Other logger: WandB
You can also use the popular Weights & Biases logger, which pytorch-lightning supports natively:
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html#module-lightning.pytorch.loggers.wandb