In [10]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

import pickle as pk

BATCH_SIZE = 32
EPOCHS = 5
VALIDATION_SPLIT = 0.02

train_data_path = "train_data.pk"

In [2]:
import pickle as pk

train_data = pk.load(open(train_data_path, "rb"))
X_train = train_data["X"]
X_train_oh = train_data["X_oh"]

alphabet_size = train_data["alphabet_size"]

len(X_train), len(X_train_oh)

(400000, 400000)

In [3]:
X_train, X_val, X_train_oh, X_val_oh = train_test_split(X_train, X_train_oh, test_size=VALIDATION_SPLIT)

In [4]:
class MyModel(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, hidden_dim: int):
        super().__init__()

        self.embedding = nn.Embedding(alphabet_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=False)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, alphabet_size)
        
    def forward(self, x: torch.Tensor):
        x = self.embedding(x)
        x = self.lstm(x)[0]
        x = F.relu(x)
        x = F.relu(self.linear(x))
        x = F.softmax(self.output(x))
        return x

class SequenceDataset(Dataset):
    def __init__(self, X_train, X_train_oh):
        super().__init__()
        self.X_train = X_train
        self.X_train_oh = X_train_oh
        
    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, idx):
        return torch.tensor(self.X_train[idx]), torch.tensor(self.X_train_oh[idx])

def get_model(input_shape, alphabet_size):
    OUTPUT_DIM = alphabet_size # sigmoid output

    input_layer = Input(shape=input_shape)
    x = Embedding(alphabet_size, 20)(input_layer)
    x = LSTM(64)(x)
    x = Dense(64, activation="relu")(x)
    x = Dense(OUTPUT_DIM, activation="softmax")(x)
    model = Model(input_layer, x)

    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["categorical_crossentropy"]
    )
    return model

In [5]:
model = MyModel(alphabet_size, 20, 64)

In [6]:
train_set = SequenceDataset(X_train, X_train_oh)
train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

for i in range(EPOCHS):
    print("Epoch: ", i)
    #optimizer.zero_grad()
    for j, (x, x_oh) in enumerate(train_dataloader):
        optimizer.zero_grad()
        #print(list(x.size()), list(x_oh.size()))
        #break
        x = torch.permute(x, dims=[1,0])
        x_oh = torch.permute(x_oh, dims=[1,0, 2])
        
        outputs = model(x)
        loss = loss_fn(torch.squeeze(outputs), torch.squeeze(x_oh))
        loss.backward()
        
        #break

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        divisor += float( list(x.size())[1] )
        #if j % 32 == 0:
        #    optimizer.step()
        #    optimizer.zero_grad()
        if j % 1000 == 0:
            last_loss = running_loss / divisor # loss per batch
            print('  batch {} loss: {}'.format(j, last_loss))
            running_loss = 0.
            divisor = 0.

Epoch:  0
  batch 0 loss: 4.253195762634277


  x = F.softmax(self.output(x))


  batch 1000 loss: 4.146942563524034
  batch 2000 loss: 4.149731180683407
  batch 3000 loss: 4.151008807408341
  batch 4000 loss: 4.144310755690301
  batch 5000 loss: 4.143533850642614
  batch 6000 loss: 4.149032811476766
  batch 7000 loss: 4.15057930670868
  batch 8000 loss: 4.143346159605184
  batch 9000 loss: 4.1491986294334575
  batch 10000 loss: 4.1400208547669175
  batch 11000 loss: 4.140362903287863
  batch 12000 loss: 4.14102695297467
  batch 13000 loss: 4.146002909694655
  batch 14000 loss: 4.14032620810824
  batch 15000 loss: 4.150905581521444
  batch 16000 loss: 4.146035178802957
  batch 17000 loss: 4.154010558908638
  batch 18000 loss: 4.142508838471031
  batch 19000 loss: 4.146937599236313
  batch 20000 loss: 4.150960569770776
  batch 21000 loss: 4.150744295973952
  batch 22000 loss: 4.146199816968664
  batch 23000 loss: 4.142450999781576
  batch 24000 loss: 4.142107960464858
  batch 25000 loss: 4.136277543129467
  batch 26000 loss: 4.152200092281713
  batch 27000 loss: 4.

KeyboardInterrupt: 

In [None]:
model.save("model.keras")