In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
import numpy as np

Goals: Test the following

- nn.Module
- nn.Embedding
- nn.LSTM
- nn.Linear
- nn.Conv (**TODO**: not covered yet)

## Testing the modules

In [43]:
# Pick the best available accelerator instead of hard-coding "mps", which only
# works on Apple-silicon builds of PyTorch. Fall back to CUDA, then CPU, so the
# notebook survives Restart & Run All on any machine.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

### 1.1. nn.Embedding

Situation: We have tokenized sentences and would like to embed each token

Notes: 
- [docs](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html)
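- `num_embeddings` is the vocabulary size - how many distinct token ids can be embedded
- `embedding_dim` is the size of each embedding vector - how many features per token
- `padding_idx=0`: the vector for token id 0 is initialised to zeros and is not updated during training
- The output of `nn.Embedding` has shape [BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM]; in this example:
    - 400 sentences
    - 20 tokens per sentence
    - each token is a 128-dim vector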

In [59]:
VOCAB_SIZE = 100
BATCH_SIZE = 400  # num_sentences
SEQ_LENGTH = 20
SEED = 0

# Seed torch's RNGs so that re-running the notebook draws the same token ids.
torch.manual_seed(SEED)

# Token ids are drawn from [1, VOCAB_SIZE): id 0 is used as padding_idx by the
# embedding layer, so it should not show up as a "real" token mid-sentence.
tokenized_sentences = torch.randint(
    1, VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device
)

In [60]:
EMBEDDING_DIM = 128

# Use the EMBEDDING_DIM constant rather than a duplicated literal so the two can
# never drift apart. padding_idx=0: the row for token id 0 starts as zeros and
# receives no gradient updates.
embed = nn.Embedding(
    num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM, padding_idx=0
).to(device)

In [61]:
# Look up one EMBEDDING_DIM vector per token id: (batch, seq) -> (batch, seq, dim).
t = embed(input=tokenized_sentences)

In [62]:
t.size()  # expected: (BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM)

torch.Size([400, 20, 128])

### 1.2. nn.LSTM

Situation: Run an LSTM over the embedded sentences

In [None]:
# Pipeline overview (kept as comments so this code cell runs cleanly):
#
# Token IDs (integers)
#     ↓
# nn.Embedding
#     ↓
# Embedding vectors
#     ↓
# nn.LSTM
#     ↓
# Hidden states / outputs
#     ↓
# classifier

Notes:
- [docs](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)
- By default the expected input is [SEQ_LENGTH, BATCH_SIZE, EMBEDDING_DIM]
    - but we can use `batch_first=True` to get input and output as (batch, seq_len, feature)
- `outputs`: the last layer's hidden state at every time step => shape: (seq_len, batch, hidden_dim), or (batch, seq_len, hidden_dim) with `batch_first=True` (as used below)
- `hn`: final hidden state for each sequence => shape: (num_layers * num_directions, batch, hidden_dim), i.e. (1, batch, hidden_dim) here; `batch_first` does not affect it
- `cn`: final cell state for each sequence => same shape as `hn`

In [66]:
HIDDEN_DIM = 256

# Single-layer LSTM reading EMBEDDING_DIM features per step.
lstm = nn.LSTM(
    input_size=EMBEDDING_DIM,
    hidden_size=HIDDEN_DIM,
    batch_first=True,  # input/output tensors are (batch, seq, feature)
)
lstm = lstm.to(device)

In [75]:
# No initial (h0, c0) is passed, so PyTorch starts from zero hidden/cell states.
outputs, (hn, cn) = lstm(input=t)

### 1.3. nn.Linear

Situation: We want to classify each sentence into one of `NUM_CLASSES` classes from its final LSTM hidden state

Notes:
- [docs](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
- `logits`: score per class for each sentence => shape: (batch, num_classes)
- `probs`: probability of each class for each sentence => shape: (batch, num_classes)
- `preds`: predicted class index for each sentence => shape: (batch,)

In [79]:
NUM_CLASSES = 10

# Classifier head: maps a HIDDEN_DIM feature vector to NUM_CLASSES raw scores.
linear = nn.Linear(in_features=HIDDEN_DIM, out_features=NUM_CLASSES)
linear = linear.to(device)

In [95]:
# hn has shape (num_layers * num_directions, batch, HIDDEN_DIM). hn[-1] picks the
# last slice, i.e. the top layer's final hidden state, shape (batch, HIDDEN_DIM).
# Unlike hn.squeeze(0) (a no-op once num_layers > 1), it stays 2-D for stacked
# LSTMs. NOTE(review): for a bidirectional LSTM hn[-1] is only the backward
# direction; concatenate hn[-2] and hn[-1] in that case.
logits = linear(hn[-1])  # (1, 400, 256) -> hn[-1]: (400, 256) -> logits: (400, 10)

In [96]:
logits.size()  # expected: (BATCH_SIZE, NUM_CLASSES)

torch.Size([400, 10])

In [104]:
# Softmax over the class dimension turns raw scores into a per-sentence
# distribution. Softmax is monotonic, so argmax over the logits picks the same
# class as argmax over probs (up to floating-point ties).
probs = logits.softmax(dim=1)
preds = logits.argmax(dim=1)

In [108]:
# nn.Linear acts on the last dimension, so applying it to every time step's
# output yields per-token class scores: (batch, seq_len, NUM_CLASSES).
linear(outputs).size()

torch.Size([400, 20, 10])