# Library

In [78]:
import numpy as np
import urllib.request
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import lightning as L
from tqdm import tqdm

In [5]:
# data download from link
urllib.request.urlretrieve('http://www.gutenberg.org/files/11/11-0.txt', filename='gutenberg.txt')

('gutenberg.txt', <http.client.HTTPMessage at 0x1fd3f4ed670>)

# Proeprocessing

In [3]:
with open('./gutenberg.txt') as file:
    sentences = []
    for sentence in tqdm(file):
        sentence = sentence.strip()
        sentence = sentence.lower()
        sentence = sentence.encode().decode('ascii', 'ignore')
        if len(sentence) > 0:
            sentences.append(sentence)

3384it [00:00, 846111.76it/s]


In [4]:
all_sentences = ' '.join(sentences)
print(f'전체 문자열 길이: {len(all_sentences)}')

전체 문자열 길이: 140323


In [5]:
char_vocab = sorted(list(set(all_sentences)))
vocab_size = len(char_vocab)
print(f'vocab의 수: {vocab_size}')

vocab의 수: 43


In [6]:
char_to_index = dict((char, index) for index, char in enumerate(char_vocab))

index_to_char = {}
for key, value in char_to_index.items():
    index_to_char.update({value: key})

In [7]:
char_to_index

{' ': 0,
 '!': 1,
 "'": 2,
 '(': 3,
 ')': 4,
 '*': 5,
 ',': 6,
 '-': 7,
 '.': 8,
 '0': 9,
 '3': 10,
 ':': 11,
 ';': 12,
 '?': 13,
 '[': 14,
 ']': 15,
 '_': 16,
 'a': 17,
 'b': 18,
 'c': 19,
 'd': 20,
 'e': 21,
 'f': 22,
 'g': 23,
 'h': 24,
 'i': 25,
 'j': 26,
 'k': 27,
 'l': 28,
 'm': 29,
 'n': 30,
 'o': 31,
 'p': 32,
 'q': 33,
 'r': 34,
 's': 35,
 't': 36,
 'u': 37,
 'v': 38,
 'w': 39,
 'x': 40,
 'y': 41,
 'z': 42}

In [8]:
index_to_char

{0: ' ',
 1: '!',
 2: "'",
 3: '(',
 4: ')',
 5: '*',
 6: ',',
 7: '-',
 8: '.',
 9: '0',
 10: '3',
 11: ':',
 12: ';',
 13: '?',
 14: '[',
 15: ']',
 16: '_',
 17: 'a',
 18: 'b',
 19: 'c',
 20: 'd',
 21: 'e',
 22: 'f',
 23: 'g',
 24: 'h',
 25: 'i',
 26: 'j',
 27: 'k',
 28: 'l',
 29: 'm',
 30: 'n',
 31: 'o',
 32: 'p',
 33: 'q',
 34: 'r',
 35: 's',
 36: 't',
 37: 'u',
 38: 'v',
 39: 'w',
 40: 'x',
 41: 'y',
 42: 'z'}

In [9]:
seq_len = 60

n_samples = int(len(all_sentences) / seq_len)
print(f'샘플 수: {n_samples}')

샘플 수: 2338


In [10]:
data_X = []
data_y = []

for i in tqdm(range(n_samples)):
    temp = all_sentences[i*seq_len:(i+1)*seq_len]

    temp_encoded = [char_to_index[char] for char in temp]
    data_X.append(temp_encoded)

    temp_y = all_sentences[i*seq_len+1:(i+1)*seq_len+1]

    temp_encoded_y = [char_to_index[char] for char in temp_y]
    data_y.append(temp_encoded_y)


100%|██████████| 2338/2338 [00:00<00:00, 179839.40it/s]


In [11]:
train_x, train_y = data_X[:1400], data_y[:1400]
valid_x, valid_y = data_X[1400:1900], data_y[1400:1900]
test_x, test_y = data_X[1900:], data_y[1900:]

# Data Module

In [51]:
class GutenbergDataset(Dataset):
    def __init__(self, X: List[int], y: List[int], vocab_size: int):
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)
        self.vocab_size = vocab_size
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        X = self.X[index]
        y = self.y[index]

        return {
            'X': F.one_hot(X, num_classes=self.vocab_size),
            'y': y,
        }

In [196]:
class GutenbergDataModule(L.LightningDataModule):
    def __init__(self, batch_size: int):
        super().__init__()
        self.batch_size = batch_size
        self.save_hyperparameters()
    
    def prepare(
        self,
        train: Dataset,
        valid: Dataset,
        test: Dataset,
        ):
        self.train = train
        self.valid = valid
        self.test = test
    
    def setup(self, stage: str = None):
        if stage == 'fit':
            self.train_data = self.train
            self.valid_data = self.valid
        
        if stage == 'test':
            self.test_data = self.test
    
    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_data,
            batch_size=self.batch_size,
            shuffle=False,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.valid_data,
            batch_size=self.batch_size,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_data,
            batch_size=self.batch_size,
            shuffle=False,
        )

# LightningModule

In [197]:
class GutenbergLightningModule(L.LightningModule):
    def __init__(self, input_dim, hidden_dim, vocab_size, learning_rate: float = 2e-5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.learning_rate = learning_rate

        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        self.save_hyperparameters()

    def forward(self, x):       # (batch, seq_len, vocab_size)
        x, _ = self.lstm1(x)    # (batch, seq_len, hidden_dim)
        x, _ = self.lstm2(x)    # (batch, seq_len, hidden_dim)
        x = self.fc(x)          # (batch, seq_len, vocab_size)

        return x
    
    def training_step(self, batch, batch_idx):
        X = batch.get('X').float()
        y = batch.get('y')

        outputs = self(X)   # nn.Sequential([self.lstm1, self.lstm2, self.fc])(X)

        outputs = outputs.permute(0, 2, 1)  # (batch, vocab_size, seq_len)
        loss = F.cross_entropy(outputs, y)

        return loss

    def validation_step(self, batch, batch_idx):
        X = batch.get('X').float()
        y = batch.get('y')

        outputs = self(X)

        outputs = outputs.permute(0, 2, 1)
        loss = F.cross_entropy(outputs, y)

        return loss
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

# Train

## Hyperparameters

In [199]:
input_dim = len(char_to_index)
hidden_dim = 256
vocab_size = len(char_to_index)
batch_size = 256
seq_len = 64

## Dataset

In [200]:
train_dataset = GutenbergDataset(train_x, train_y, vocab_size)
valid_dataset = GutenbergDataset(valid_x, valid_y, vocab_size)
test_dataset = GutenbergDataset(test_x, test_y, vocab_size)

gutenberg_data_module = GutenbergDataModule(batch_size)
gutenberg_data_module.prepare(train_dataset, valid_dataset, test_dataset)

## Model

In [201]:
model = GutenbergLightningModule(input_dim, hidden_dim, vocab_size)

## Train

In [202]:
trainer = L.Trainer(max_epochs=10)
trainer.fit(
    model=model,
    datamodule=gutenberg_data_module,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | lstm1 | LSTM   | 308 K  | train
1 | lstm2 | LSTM   | 526 K  | train
2 | fc    | Linear | 11.1 K | train
-----------------------------------------
845 K     Trainable params
0         Non-trainable params
845 K     Total params
3.382     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\MJH\AppData\Roaming\Python\Python39\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
C:\Users\MJH\AppData\Roaming\Python\Python39\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
C:\Users\MJH\AppData\Roaming\Python\Python39\site-packages\lightning\pytorch\loops\fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


# Generation

In [203]:
def generate_sentence(model: L.LightningModule, max_length: int, vocab_size: int = vocab_size):
    idx = [np.random.randint(vocab_size)]

    sentences = [index_to_char.get(idx[-1])]
    print(f'{idx[-1]}번 문자 {sentences[-1]}로 예측 시작')

    X = torch.zeros((1, max_length, vocab_size))

    with torch.no_grad():
        for i in range(max_length):
            X[0][i][idx] = 1
            idx = model(X[:, :i+1, :])[0].argmax(axis=-1)
            idx = idx[-1].item()
            sentences.append(index_to_char.get(idx))
    
    return ''.join(sentences)

In [215]:
generate_sentence(model, 60)

8번 문자 .로 예측 시작


'.**                                                          '