<a href="https://colab.research.google.com/github/yingzibu/a_inhibitor_design/blob/main/examples/experiments/LSTM_text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

# load ascii text and covert to lowercase
filename = '/content/pg11.txt'
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()

# create mapping of unique chars to integers
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))

# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i:i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

# reshape X to be [samples, time steps, features]
X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
X = X / float(n_vocab)
y = torch.tensor(dataY)

class CharModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=256, num_layers=2, batch_first=True, dropout=0.2)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, n_vocab)
    def forward(self, x):
        x, _ = self.lstm(x)
        # take only the last output
        x = x[:, -1, :]
        # produce output
        x = self.linear(self.dropout(x))
        return x

n_epochs = 40
batch_size = 128
model = CharModel()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(reduction="sum")
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)

best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch.to(device))
        loss = loss_fn(y_pred, y_batch.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            loss += loss_fn(y_pred, y_batch.to(device))
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))

torch.save([best_model, char_to_int], "single-char.pth")

# Generation using the trained model
best_model, char_to_int = torch.load("single-char.pth")
n_vocab = len(char_to_int)
int_to_char = dict((i, c) for c, i in char_to_int.items())
model.load_state_dict(best_model)



Total Characters:  164093
Total Vocab:  65
Total Patterns:  163993
Epoch 0: Cross-entropy: 437746.6250
Epoch 1: Cross-entropy: 404203.5312
Epoch 2: Cross-entropy: 377130.8125
Epoch 3: Cross-entropy: 358123.3125
Epoch 4: Cross-entropy: 339874.5312
Epoch 5: Cross-entropy: 326531.5625
Epoch 6: Cross-entropy: 312470.2500
Epoch 7: Cross-entropy: 310198.8750
Epoch 8: Cross-entropy: 300224.6562
Epoch 9: Cross-entropy: 295012.7500
Epoch 10: Cross-entropy: 282412.5000
Epoch 11: Cross-entropy: 274227.4688
Epoch 12: Cross-entropy: 278698.2188
Epoch 13: Cross-entropy: 265955.4062
Epoch 14: Cross-entropy: 257718.2812
Epoch 15: Cross-entropy: 254925.1406
Epoch 16: Cross-entropy: 251602.4062
Epoch 17: Cross-entropy: 245785.7031
Epoch 18: Cross-entropy: 246324.1875
Epoch 19: Cross-entropy: 239376.0625
Epoch 20: Cross-entropy: 239084.9219
Epoch 21: Cross-entropy: 231531.5781
Epoch 22: Cross-entropy: 229064.5938
Epoch 23: Cross-entropy: 227330.9531
Epoch 24: Cross-entropy: 221778.4688
Epoch 25: Cross-en

FileNotFoundError: ignored

In [6]:
# randomly generate a prompt
# filename = "wonderland.txt"
# seq_length = 100
# raw_text = open(filename, 'r', encoding='utf-8').read()
# raw_text = raw_text.lower()
start = np.random.randint(0, len(raw_text)-seq_length)
prompt = raw_text[start:start+seq_length]
pattern = [char_to_int[c] for c in prompt]

model.eval()
print('Prompt: "%s"' % prompt)
with torch.no_grad():
    for i in range(1000):
        # format input array of int into PyTorch tensor
        x = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)
        # generate logits as output from the model
        prediction = model(x.to(device))
        # convert logits into one character
        index = int(prediction.argmax())
        # print(index)
        result = int_to_char[index]

        print(result, end="")
        # append the new character into the prompt for the next iteration
        pattern.append(index)
        pattern = pattern[1:]
print()
print("Done.")

Prompt: "eam of wonderland of long ago: and how she
would feel with all their simple sorrows, and find a plea"
ser to the teiesce. 
“what _ serpent!” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle.

“wes, i’ll seter sareer the sea,” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle. 
“wes, i thall say the thing ” said the mock turtle.

“wes, i’ll seter sareer the sea,” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle. 
“wes, i thall say the thing ” said the mock turtle.

“wes, i’ll seter sareer the sea,” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle. 
“wes, i thall say the thing ” said the mock turtle.

“wes, i’ll seter sareer the sea,” said the mock turtle.

“wes, i thall say the thing ” said the mock turtle. 
“wes, i thall say the thing ” said the mock turtle.

“wes, i’ll seter sareer the sea,” said the mock turtle.

“wes, i thal

In [None]:
! wget https://www.gutenberg.org/cache/epub/11/pg11.txt

--2023-08-24 01:23:08--  https://www.gutenberg.org/cache/epub/11/pg11.txt
Resolving www.gutenberg.org (www.gutenberg.org)... 152.19.134.47, 2610:28:3090:3000:0:bad:cafe:47
Connecting to www.gutenberg.org (www.gutenberg.org)|152.19.134.47|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 174580 (170K) [text/plain]
Saving to: ‘pg11.txt’


2023-08-24 01:23:09 (2.31 MB/s) - ‘pg11.txt’ saved [174580/174580]

