In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torchmetrics
!pip install pytorch_lightning
import torchmetrics
import pytorch_lightning as pl
from torch import optim
import os
os.chdir('/content/drive/MyDrive/Colab Notebooks/NLP_HW2/')
#wiki_dataset = __import__('/content/drive/MyDrive/Colab Notebooks/NLP_HW2/dataset/wiki_dataset')
from dataset import wiki_dataset
from dataloader import wiki_dataloader
import pytorch_lightning.loggers as pl_loggers
import nltk
nltk.download('punkt')

Collecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
[?25l[K     |▉                               | 10 kB 27.5 MB/s eta 0:00:01[K     |█▋                              | 20 kB 7.6 MB/s eta 0:00:01[K     |██▌                             | 30 kB 7.1 MB/s eta 0:00:01[K     |███▎                            | 40 kB 6.7 MB/s eta 0:00:01[K     |████▏                           | 51 kB 4.0 MB/s eta 0:00:01[K     |█████                           | 61 kB 4.2 MB/s eta 0:00:01[K     |█████▊                          | 71 kB 4.1 MB/s eta 0:00:01[K     |██████▋                         | 81 kB 4.6 MB/s eta 0:00:01[K     |███████▍                        | 92 kB 4.9 MB/s eta 0:00:01[K     |████████▎                       | 102 kB 4.1 MB/s eta 0:00:01[K     |█████████                       | 112 kB 4.1 MB/s eta 0:00:01[K     |██████████                      | 122 kB 4.1 MB/s eta 0:00:01[K     |██████████▊                     | 133 kB 4.1 MB/s eta 0:0

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [8]:
class LSTM1(pl.LightningModule):
    """Next-token LSTM language model trained with PyTorch Lightning.

    Token ids are embedded, run through a stacked ``nn.LSTM`` (batch_first),
    and the output at the last time step is projected to ``n_vocab`` logits
    that are scored with cross-entropy against ``label``.

    The LSTM (h, c) state is carried from one batch to the next within a stage
    (train / val / test) and detached, so gradients never flow across batches.
    Each stage keeps its own state, so validation cannot clobber training.
    The state is reset on the first batch of every epoch. It is also dropped
    whenever the batch size changes (e.g. a short final batch), because
    ``nn.LSTM`` rejects a hidden state whose batch dimension does not match
    the input.

    Args:
        n_vocab: number of token ids (embedding rows and logit width).
        embedding_size: embedding dimension fed into the LSTM.
        hidden_size: LSTM hidden dimension fed into the output projection.
        num_layers: number of stacked LSTM layers.
        seq_size: context window length; stored on the module, not used here.
        dropout: dropout between LSTM layers (default 0.5, the previous value).
        lr: SGD learning rate (default 1e-3, the previous value).
        tie_weights: if True, share the output projection weight with the
            embedding table. Requires ``hidden_size == embedding_size``.

    Raises:
        ValueError: if ``tie_weights`` is requested with mismatched sizes.
    """

    def __init__(self, n_vocab,
                 embedding_size,
                 hidden_size,
                 num_layers,
                 seq_size,
                 dropout=0.5,
                 lr=1e-3,
                 tie_weights=False):
        super(LSTM1, self).__init__()
        self.seq_size = seq_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.lr = lr
        self.lstm = nn.LSTM(input_size=embedding_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True, dropout=dropout)
        # Gradient clipping is configured on the Trainer (gradient_clip_val),
        # which applies it between backward() and the optimizer step.
        self.embed = nn.Embedding(n_vocab, embedding_size)
        self.loss = nn.CrossEntropyLoss()
        # The projection consumes LSTM outputs, so its input width must be
        # hidden_size (it used to be embedding_size, which only worked while
        # both sizes happened to be equal).
        self.fc = nn.Linear(hidden_size, n_vocab)
        if tie_weights:
            if hidden_size != embedding_size:
                raise ValueError(
                    "tie_weights requires hidden_size == embedding_size "
                    f"(got {hidden_size} and {embedding_size})")
            # Linear weight is (n_vocab, hidden_size) and the embedding table is
            # (n_vocab, embedding_size); with equal sizes they can be shared.
            self.fc.weight = self.embed.weight
        # Carried (h, c) for the training stage; the attribute name is kept for
        # backward compatibility. Validation and test use their own slots.
        self.prev_state = None
        self._eval_states = {"val": None, "test": None}

    def forward(self, x, prev_state=None):
        """Return logits for the token after ``x`` plus the new LSTM state.

        Args:
            x: token-id tensor; assumed (batch, seq) given batch_first=True and
                the [:, -1, :] indexing below — TODO confirm with the loader.
            prev_state: optional (h, c) pair from an earlier call; ``None``
                lets nn.LSTM start from zeros.

        Returns:
            (logits, state): logits of shape (batch, n_vocab) from the last
            time step, and the LSTM's (h, c) state.
        """
        x = self.embed(x)
        # Honour the state the caller passed in (the old code checked
        # self.prev_state instead, ignoring an explicit argument).
        x, state = self.lstm(x, prev_state)
        logits = self.fc(x[:, -1, :])  # only the last step predicts the target
        return logits, state

    def configure_optimizers(self):
        """Plain SGD over all parameters with learning rate ``self.lr``."""
        return optim.SGD(self.parameters(), lr=self.lr)

    def _carried_state(self, stage, batch_idx, batch_size):
        """Return the stored state of ``stage`` if it is reusable, else None."""
        state = self.prev_state if stage == "train" else self._eval_states[stage]
        # A new epoch starts from the beginning of its data: don't carry state.
        if state is None or batch_idx == 0:
            return None
        # h_n is (num_layers, batch, hidden_size) even with batch_first=True;
        # a different batch dimension would make nn.LSTM raise.
        if state[0].size(1) != batch_size:
            return None
        return state

    def _store_state(self, stage, state):
        """Detach ``state`` (keep values, drop the graph) and save it for ``stage``."""
        detached = tuple(s.detach() for s in state)
        if stage == "train":
            self.prev_state = detached
        else:
            self._eval_states[stage] = detached

    def _shared_step(self, batch, batch_idx, stage, log_name):
        """Run one batch for ``stage``, log the loss under ``log_name``, return it."""
        data, label = batch
        prev_state = self._carried_state(stage, batch_idx, data.size(0))
        logits, state = self(data, prev_state)
        self._store_state(stage, state)
        loss = self.loss(logits, label)
        tensorboard_logs = {'loss': {stage: loss.detach()}}
        self.log(log_name, loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss, "log": tensorboard_logs}

    def training_step(self, batch, batch_idx):
        """Lightning training hook: forward, loss, epoch-level logging."""
        return self._shared_step(batch, batch_idx, "train", "training loss")

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; uses its own carried state."""
        return self._shared_step(batch, batch_idx, "val", "validation loss")

    def test_step(self, batch, batch_idx):
        """Lightning test hook; uses its own carried state."""
        return self._shared_step(batch, batch_idx, "test", "test loss")



In [3]:

# Load the three splits; the token_map built on the training split is reused
# for the held-out splits so all three share one id mapping.
WINDOW = 30  # window length passed to every split

train = wiki_dataset('./wiki.train.txt', training=True, token_map='create', window=WINDOW)
valid, test = (
    wiki_dataset(split_path, training=False, token_map=train.token_map, window=WINDOW)
    for split_path in ('./wiki.valid.txt', './wiki.test.txt')
)
datasets = [train, valid, test]

In [4]:
# Hand all three splits to the project's loader with a fixed batch size
BATCH_SIZE = 20
dataloader = wiki_dataloader(datasets=datasets, batch_size=BATCH_SIZE)

In [5]:
# Build the model (training happens in the Trainer cell).
# seq_size matches the dataset window of 30; embedding and hidden sizes are equal.
model = LSTM1(
    n_vocab=len(train.unique_tokens),  # size of the training split's unique_tokens
    embedding_size=100,
    hidden_size=100,
    num_layers=2,
    seq_size=30,
)

In [10]:
# Train, then evaluate on the test split.
# NOTE(review): the run name "ff" looks carried over from a feed-forward
# experiment; consider "lstm" so runs are easy to tell apart in TensorBoard.
tb_logger = pl_loggers.TensorBoardLogger("./lightning_logs/", name="ff")
trainer = pl.Trainer(
    gradient_clip_val=0.5,  # gradient clipping threshold applied by Lightning
    logger=tb_logger,
    max_epochs=10,
    # Use one GPU when CUDA is available, otherwise run on CPU instead of
    # failing at Trainer construction.
    gpus=1 if torch.cuda.is_available() else 0,
)
trainer.fit(model, dataloader)
result = trainer.test(model, dataloader)
print(result)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | lstm  | LSTM             | 161 K 
1 | embed | Embedding        | 2.9 M 
2 | loss  | CrossEntropyLoss | 0     
3 | fc    | Linear           | 2.9 M 
-------------------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params
23.884    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

RuntimeError: ignored