In [1]:
import os
import random

import torch
from pytorch_lightning import Trainer

from textgen.data.modules import SentenceCompletionIterableSplitFromStrings
from textgen.model.transformer import Transformer, TransformerModel
from textgen.generation import TransformerGreedySentenceGenerator, TransformerProbabilisticSentenceGenerator

In [2]:
# Toy Lorem-ipsum corpora for the train / validation / test splits.
# NOTE: each triple-quoted string keeps its line breaks, the 16-space continuation
# indent and the trailing spaces as part of the text; presumably the textgen
# tokenizer ignores this whitespace — TODO confirm.
# NOTE(review): val_text / test_text contain words that never occur in train_text
# (e.g. "metus", "placerat", "molestie", "ligula"), which may contribute to the
# gap between training loss and test_loss reported further down.
train_text = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Duis in sapien a turpis ullamcorper iaculis. 
                Vivamus ut mi sed nisl maximus vehicula. Aliquam non augue eget enim tempus posuere ut id dui. 
                Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. 
                Mauris sed bibendum lorem. Nam laoreet nibh ac rutrum ultrices. 
                Fusce quis porttitor orci, a vestibulum justo. Maecenas varius sapien justo, id fringilla risus lobortis rhoncus. 
                Vestibulum nec consectetur arcu. Maecenas varius non diam eu laoreet.
                Quisque eu mattis magna. Donec rhoncus ante sit amet augue egestas imperdiet. 
                Nulla facilisi. Vestibulum eget orci sed ante euismod tempus sit amet eu ante. 
                Cras vestibulum eget eros non rhoncus. Curabitur sit amet nisl leo. 
                Mauris egestas sapien tristique sagittis ultricies. Cras cursus quam id dui cursus, nec tincidunt diam fermentum. 
                Cras a lectus lorem. Etiam vitae tincidunt odio. Aliquam sapien ex, eleifend at ante non, gravida vestibulum eros.
                Morbi et elementum lacus. Morbi vel sagittis urna. Sed eu scelerisque libero. 
                Suspendisse volutpat iaculis velit vel finibus. Pellentesque eleifend lacus ac auctor ultricies. 
                Vivamus fermentum pulvinar viverra. Pellentesque luctus facilisis lacinia. 
                Sed efficitur, tellus nec rutrum cursus, neque magna sollicitudin enim, quis efficitur massa justo et urna."""
val_text =   """Etiam eget ipsum tincidunt, lobortis metus id, eleifend arcu. Pellentesque consectetur placerat quam ut varius. 
                Proin porta, elit et volutpat accumsan, quam est eleifend nisi, in tincidunt lacus ipsum quis mauris. 
                Nulla condimentum eu nibh sed ornare."""
test_text =  """Etiam ligula velit, molestie et semper vel, finibus id sem. 
                Donec pharetra nisl erat, nec scelerisque neque congue non. Praesent blandit elit sed ipsum porttitor maximus. 
                Mauris quis ipsum mollis, ullamcorper ante sit amet, efficitur risus. 
                Maecenas ipsum arcu, aliquam ut orci non, laoreet porta risus. 
                Nullam euismod blandit libero, quis consectetur est lacinia sed. Quisque consectetur ante ut orci iaculis feugiat. 
                Mauris blandit, dui vel maximus luctus, ligula eros tempus mi, vitae dapibus dui elit in purus. 
                Maecenas sed euismod orci. Nunc auctor sit amet diam et malesuada. 
                Vestibulum eget lorem dapibus eros pellentesque rutrum. Aliquam gravida tellus vel porttitor tincidunt. 
                Vivamus libero lectus, egestas vitae gravida a, lobortis feugiat leo."""

In [3]:
# Data-loading configuration (also reused by TransformerModel further down).
max_length = 20  # presumably the maximum tokens per example — confirm in textgen
batch_size = 64
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to
# 0 (load in the main process) so the datamodule always receives a valid int.
# NOTE(review): the datamodule's name suggests an IterableDataset. With multiple
# workers, each worker replays the full stream unless the dataset shards itself
# by worker (torch.utils.data.get_worker_info) — confirm textgen does this,
# otherwise every batch is duplicated num_workers times.
num_workers = os.cpu_count() or 0

# Seed Python's and torch's RNGs (torch.manual_seed seeds every device, CUDA
# included). Weight init, dropout and sampling are then repeatable on
# Restart & Run All, as far as GPU nondeterminism allows.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [4]:
# Build the datamodule holding the three text splits.
dm = SentenceCompletionIterableSplitFromStrings(
    train_text,
    val_text,
    test_text,
    max_length=max_length,
    batch_size=batch_size,
    num_workers=num_workers,
)
# Vocabulary size is taken from the training split's token inventory.
# NOTE(review): how val/test words missing from this vocabulary are handled
# (e.g. mapped to an unknown token) depends on textgen — confirm.
train_tokens = dm.train_dataset.sentences.tokens
vocab_size = len(train_tokens)

In [5]:
# Transformer hyper-parameters, consumed by TransformerModel in the next cell.
d_model, num_heads = 64, 4
num_layers = 4
# NOTE(review): the feed-forward width is usually larger than d_model (often 4x);
# here d_ff (16) < d_model (64) — confirm this is intended.
d_ff = 16
drop_out_rate = 0.01

In [6]:
# Core network. vocab_size is passed twice, presumably as the source and target
# vocabularies, which share the training vocabulary here.
network = TransformerModel(vocab_size, vocab_size, d_model, d_ff, num_heads,
                           num_layers, drop_out_rate, max_length)
# Lightning module wrapping the network (the summary printed by fit lists its
# NLLLoss and Accuracy sub-modules). pad_id is presumably ignored by the loss — TODO confirm.
# NOTE(review): lr=0.01 is high for typical Transformer training; tune if the loss diverges.
t = Transformer(network, dm.train_dataset.pad_id, lr=0.01)

In [7]:
# Train on a single GPU when CUDA is available, otherwise on CPU.
# NOTE: `gpus=` is the legacy Lightning argument (deprecated in 1.7, removed in 2.0);
# on newer versions use accelerator="auto", devices=1 instead.
use_gpu = torch.cuda.is_available()
trainer = Trainer(max_epochs=30, gpus=1 if use_gpu else 0)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [8]:
# Fit the model; dm supplies the train/validation dataloaders. Passing it by keyword
# makes its role explicit instead of relying on Lightning re-routing a datamodule
# given in the positional dataloader slot.
trainer.fit(t, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type             | Params
--------------------------------------------
0 | model  | TransformerModel | 248 K 
1 | loss   | NLLLoss          | 0     
2 | metric | Accuracy         | 0     
--------------------------------------------
248 K     Trainable params
0         Non-trainable params
248 K     Total params
0.993     Total estimated model params size (MB)


Epoch 0: : 8it [00:00, 11.29it/s, loss=3.44, v_num=8]                 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A
Epoch 0: : 10it [00:00, 10.20it/s, loss=3.44, v_num=8]
Epoch 0: : 17it [00:01, 15.10it/s, loss=3.44, v_num=8]
Epoch 1: : 8it [00:00, 11.97it/s, loss=2.71, v_num=8] 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A
Epoch 1: : 12it [00:00, 12.65it/s, loss=2.71, v_num=8]
Epoch 1: : 17it [00:01, 15.94it/s, loss=2.71, v_num=8]
Epoch 2: : 8it [00:00, 12.30it/s, loss=1.88, v_num=8] 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A
Epoch 2: : 12it [00:00, 12.44it/s, loss=1.88, v_num=8]
Epoch 2: : 17it [00:01, 15.84it/s, loss=1.88, v_num=8]
Epoch 3: : 8it [00:00, 12.11it/s, loss=1.21, v_num=8] 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A
Epoch 3: : 12it [00:00, 12.46it/s, loss=1.21, v_num=8]
Epoch 3: : 17it [00:01, 15.88it/s, loss=1.21, v_num=8]
Epoch 4: : 8it [00:00, 11.93it/s, loss=0.797, v_num=

In [9]:
# Evaluate on the test split of the datamodule attached by trainer.fit.
# NOTE: with no arguments, which weights Lightning evaluates (current vs. best
# checkpoint) depends on the installed version's ckpt_path default.
# NOTE(review): the recorded test_loss (~4.0) is far above the last logged
# training loss (~0.8), suggesting overfitting to the tiny training corpus.
test_results = trainer.test()
test_results

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 8it [00:00, 18.64it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 3.9968996047973633, 'test_metric': 0.25229737162590027}
--------------------------------------------------------------------------------


[{'test_metric': 0.25229737162590027, 'test_loss': 3.9968996047973633}]

In [10]:
# Evaluation mode disables dropout before generating text.
# NOTE(review): confirm the generators run under torch.no_grad().
t.eval()
# Both generators decode with the trained model and the training dataset
# (presumably for its vocabulary / tokenisation).
train_dataset = dm.train_dataset
g_greedy = TransformerGreedySentenceGenerator(t, train_dataset)
g_probabilistic = TransformerProbabilisticSentenceGenerator(t, train_dataset)

In [11]:
# Greedy decoding presumably takes the most likely token at each step,
# so the same prompt yields the same completion.
prompt = "Lorem ipsum"
g_greedy.generate(prompt)

'Lorem ipsum.'

In [12]:
# Sampling-based decoding presumably draws tokens at random, so the completion
# varies between runs unless the RNG it uses is seeded right before this call.
prompt = "Lorem ipsum"
g_probabilistic.generate(prompt)

'Lorem ipsum nisl rutrum pulvinar non a.'