# Task 1: Use LSTM for text classification. Sentiment classification using sample movie review dataset. 
- Represent words using 1-hot encoding
- Encode words as 1-hot encoded vectors and feed them to an LSTM

In [4]:
import torch
import torch.nn as nn
import numpy as np
import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, rnn_type='rnn'):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        if rnn_type == 'gru':
            self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == "rnn":
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Invalid RNN type")
        # Output layer for binary classification (1 node)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        if self.rnn_type == "lstm":
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            out, _ = self.rnn(x, (h0, c0))
        else:
            out, _ = self.rnn(x, h0)

        # Extract the last hidden state output
        out = out[:, -1, :]

        # Pass the last hidden state output to fully connected layer
        out = self.fc(out)
        # Apply sigmoid function to convert to a probability
        out = torch.sigmoid(out)
        return out

# Example usage:
input_size = 10   # Number of input features
hidden_size = 10  # Number of features in the hidden state
num_layers = 1    # Number of stacked RNN layers

# Create the RNN model instance
model = SimpleRNN(input_size, hidden_size, num_layers, rnn_type='lstm')


print(model)

In [29]:
from collections import Counter
# Example sentiment analysis data
texts = [
    "I loved the movie",  # Positive sentiment
    "I hated the movie",  # Negative sentiment
    "A great movie. Fantastic characters.",  # Positive sentiment
    "Poor plot, boring scenes.",  # Negative sentiment
]
Y = torch.tensor([1, 0, 1, 0], dtype=torch.float32)

# Create vocabulary
token_counter = Counter(token for text in texts for token in text.split())
vocab = {token: i for i, token in enumerate(token_counter)}
VOCAB_SIZE = len(vocab)
# One-Hot Encode Texts
def one_hot_encode(text, vocab):
    encoding = np.zeros((len(text.split()), len(vocab)), dtype=int)
    for i, token in enumerate(text.split()):
        if token in vocab:  # Check if token is in vocab
            encoding[i, vocab[token]] = 1
    return encoding

In [30]:
# Prepare data
encoded_texts = [one_hot_encode(text, vocab) for text in texts]
max_length = max(len(text) for text in encoded_texts)

# Pad the sequences and convert X to a tensor
padded_texts = [np.pad(text, ((0, max_length - len(text)), (0, 0)), 'constant') for text in encoded_texts]
X_padded = torch.tensor(padded_texts, dtype=torch.float32)  # Shape: [batch_size, max_length, VOCAB_SIZE]



In [31]:
X_padded.shape

torch.Size([4, 5, 14])

In [32]:
Y.shape

torch.Size([4])

In [23]:
input_size = VOCAB_SIZE
hidden_size = 50  
num_layers = 1
learning_rate = 0.01
num_epochs = 100
model = SimpleRNN(input_size, hidden_size, num_layers, rnn_type='lstm')

In [24]:
# pip install wandb
import wandb
wandb.login()



True

In [42]:
import wandb
def train(train_model, x, y, optimizer_name="adam", model_type = "RNN"):
    criterion = nn.BCELoss()
    lr = 0.01
    if optimizer_name == "adam":
        optimizer = torch.optim.Adam(train_model.parameters(), lr=lr)
    elif optimizer_name == "rmsprop":
        optimizer = torch.optim.RMSprop(train_model.parameters(), lr=lr)
    elif optimizer_name == "sgd":
        optimizer = torch.optim.SGD(train_model.parameters(), lr=lr)
    else:
        raise ValueError("Invalid optimizer.")
    wandb.init(project='dat550_rnn_text', reinit=True, name=f'run_{model_type}_{optimizer_name}')
    # Training loop (for demonstration, let's say 100 epochs)
    num_epochs = 100
    for epoch in range(num_epochs):
        pred = train_model(x)
        loss = criterion(pred.squeeze(), y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(epoch, loss.item())
        wandb.log({"epoch": epoch, "loss": loss})


    print(f'Final Loss: {loss.item()}')
    wandb.finish()
    return model

In [43]:

train(model, x=X_padded, y=Y, model_type="lstm")


0 2.1154362173092522e-07
1 1.8160166348479834e-07
2 1.8146941727081867e-07
3 1.5154509469539335e-07
4 1.5143002940476435e-07
5 1.2152069928106357e-07
6 1.2141916272412345e-07
7 9.152162050440893e-08
8 9.143045076598355e-08
9 9.134417666700756e-08
10 9.126316768970355e-08
11 6.138522223864129e-08
12 6.131274687959376e-08
13 6.124383844507975e-08
14 6.117872430877469e-08
15 6.111746131409745e-08
16 3.1257727073352726e-08
17 3.1202191053125716e-08
18 3.1148818635529096e-08
19 3.109779811438784e-08
20 3.1049246729253355e-08
21 3.100320000726242e-08
22 3.095966150112872e-08
23 3.091859213100179e-08
24 3.0879917289894365e-08
25 3.084355526539184e-08
26 3.080940658151121e-08
27 3.077736110412843e-08
28 3.0747312251833137e-08
29 3.071915699592864e-08
30 3.069277099143619e-08
31 3.06680512096591e-08
32 3.064490528004171e-08
33 3.0623215963032635e-08
34 3.0602897993503575e-08
35 3.058386610632624e-08
36 7.637002696903039e-10
37 7.457945927491494e-10
38 7.279356561973316e-10
39 7.102199939268417e

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.0


SimpleRNN(
  (rnn): LSTM(14, 50, batch_first=True)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)

# Train LSTM using the SGD, Adam and RMSProp optimizers.

In [44]:
simple_lstm = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="lstm")
train(simple_lstm, x=X_padded, y=Y, model_type="lstm", optimizer_name="sgd")
train(simple_lstm, x=X_padded, y=Y, model_type="lstm", optimizer_name="adam")
train(simple_lstm, x=X_padded, y=Y, model_type="lstm", optimizer_name="rmsprop")

0 0.6960824728012085
1 0.6960626840591431
2 0.6960430145263672
3 0.6960232853889465
4 0.6960036754608154
5 0.6959840655326843
6 0.6959645748138428
7 0.6959450244903564
8 0.6959255337715149
9 0.6959060430526733
10 0.6958866715431213
11 0.6958673000335693
12 0.6958478093147278
13 0.6958285570144653
14 0.6958092451095581
15 0.6957899332046509
16 0.6957707405090332
17 0.6957516670227051
18 0.6957324743270874
19 0.6957133412361145
20 0.6956942081451416
21 0.6956751346588135
22 0.6956561803817749
23 0.6956372261047363
24 0.695618212223053
25 0.6955991983413696
26 0.695580244064331
27 0.695561408996582
28 0.695542573928833
29 0.695523738861084
30 0.6955049633979797
31 0.6954861879348755
32 0.695467472076416
33 0.6954487562179565
34 0.6954301595687866
35 0.6954114437103271
36 0.6953927874565125
37 0.6953742504119873
38 0.6953556537628174
39 0.6953371167182922
40 0.6953186392784119
41 0.6953001618385315
42 0.6952817440032959
43 0.6952633261680603
44 0.6952449083328247
45 0.6952265501022339
46 0

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
epoch,99.0
loss,0.69427


0 0.6942522525787354
1 0.6838241815567017
2 0.67153400182724
3 0.6529476642608643
4 0.6238479614257812
5 0.5792327523231506
6 0.5127093195915222
7 0.4224400222301483
8 0.3173147439956665
9 0.20733854174613953
10 0.11310096085071564
11 0.04737548530101776
12 0.01903168298304081
13 0.0082006910815835
14 0.0042136237025260925
15 0.002495621331036091
16 0.001624763011932373
17 0.0011288704117760062
18 0.0008224203484132886
19 0.0006213767919689417
20 0.0004834167193621397
21 0.00038535590283572674
22 0.00031367799965664744
23 0.00026005582185462117
24 0.00021916825789958239
25 0.00018744374392554164
26 0.00016245830920524895
27 0.00014252823893912137
28 0.00012645219976548105
29 0.0001133006953750737
30 0.00010251045023323968
31 9.34990675887093e-05
32 8.597419946454465e-05
33 7.962190284160897e-05
34 7.421668851748109e-05
35 6.95547932991758e-05
36 6.559822213603184e-05
37 6.214310269569978e-05
38 5.91148818784859e-05
39 5.6476830650353804e-05
40 5.419855369837023e-05
41 5.216540012042969

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,██▇▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,3e-05


0 2.9370654374361038e-05
1 3.8123111724853516
2 5.704899787902832
3 1.0666062831878662
4 1.337188482284546
5 0.33512210845947266
6 0.2433239221572876
7 0.17142261564731598
8 0.097152940928936
9 0.050504766404628754
10 0.030599597841501236
11 0.023140035569667816
12 0.018418574705719948
13 0.013997294008731842
14 0.012892886996269226
15 0.012012801133096218
16 0.01124551985412836
17 0.010521749965846539
18 0.009757516905665398
19 0.00885146763175726
20 0.008013758808374405
21 0.007507922127842903
22 0.007119504734873772
23 0.006766036618500948
24 0.006415664218366146
25 0.0060475110076367855
26 0.005667844321578741
27 0.005323977209627628
28 0.005040121730417013
29 0.004789636004716158
30 0.004507836885750294
31 0.0041071041487157345
32 0.0039116558618843555
33 0.0037967958487570286
34 0.0036948174238204956
35 0.0036010455805808306
36 0.003513556206598878
37 0.003431259887292981
38 0.0033535375259816647
39 0.003279798896983266
40 0.003209623508155346
41 0.003142738714814186
42 0.0030787

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.0012


SimpleRNN(
  (rnn): LSTM(14, 50, batch_first=True)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)

# Task 3: Implement a SimpleLSTMAttention which adds an attention layer on top of LSTM.

In [45]:
class LSTMWithAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTMWithAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Attention layer
        self.attention = nn.Linear(hidden_size, 1)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
        
        # Forward propagate LSTM
        lstm_out, _ = self.lstm(x, (h0, c0))  # shape of lstm_out: [batch_size, seq_len, hidden_size]
        
        # Apply attention
        attention_weights = torch.softmax(self.attention(lstm_out).squeeze(-1), dim=1)  # shape: [batch_size, seq_len]
        # Multiply weights by LSTM outputs (element-wise)
        context_vector = torch.einsum('ij,ijk->ik', attention_weights, lstm_out)  # shape: [batch_size, hidden_size]

        # Pass the context vector through the linear layer
        out = self.fc(context_vector)
        return torch.sigmoid(out)

# Task 4: Log the learning curve to wandb and repeat all experiments with LSTMWithAttention

In [46]:
simple_lstm_attn = LSTMWithAttention(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
train(simple_lstm_attn, x=X_padded, y=Y, model_type="lstm_attn", optimizer_name="sgd")
train(model, x=X_padded, y=Y, model_type="lstm_attn", optimizer_name="adam")
train(model, x=X_padded, y=Y, model_type="lstm_attn", optimizer_name="rmsprop")

0 0.6953690052032471
1 0.6953569650650024
2 0.6953449249267578
3 0.6953328847885132
4 0.6953209638595581
5 0.6953089237213135
6 0.6952969431877136
7 0.6952850818634033
8 0.6952731609344482
9 0.6952612996101379
10 0.6952494382858276
11 0.6952375173568726
12 0.6952257752418518
13 0.695214033126831
14 0.6952022314071655
15 0.6951904892921448
16 0.695178747177124
17 0.6951671242713928
18 0.6951553821563721
19 0.6951436996459961
20 0.6951320767402649
21 0.6951204538345337
22 0.6951088905334473
23 0.6950972676277161
24 0.6950857639312744
25 0.6950742602348328
26 0.6950627565383911
27 0.6950512528419495
28 0.6950397491455078
29 0.6950283050537109
30 0.6950169205665588
31 0.695005476474762
32 0.6949940919876099
33 0.6949827075004578
34 0.6949713826179504
35 0.6949601173400879
36 0.6949487328529358
37 0.6949374675750732
38 0.6949261426925659
39 0.6949149966239929
40 0.6949037909507751
41 0.6948925256729126
42 0.6948813199996948
43 0.694870114326477
44 0.6948590278625488
45 0.6948478817939758
46

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,███▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
epoch,99.0
loss,0.69427


0 2.37337843822516e-10
1 2.3466201204414006e-10
2 2.3203186594322744e-10
3 2.2944735000862693e-10
4 2.2690724299501142e-10
5 2.244127800254958e-10
6 2.2196013083064514e-10
7 2.1955147422314525e-10
8 2.1718467302367372e-10
9 2.1486062928843808e-10
10 2.1257763604953794e-10
11 2.1033536024006594e-10
12 2.0813301082611702e-10
13 2.0597021310742036e-10
14 2.038466201392808e-10
15 2.0176069148725162e-10
16 1.997121218400011e-10
17 1.9770053649725838e-10
18 1.9572526932520873e-10
19 1.9378565419003735e-10
20 1.9188062250208304e-10
21 1.9001060447276785e-10
22 1.8817356006728403e-10
23 1.863688231518168e-10
24 1.8459722639363463e-10
25 1.82857146091564e-10
26 1.8114830468984877e-10
27 1.7946946706537403e-10
28 1.778204528068983e-10
29 1.7620034598042622e-10
30 1.7460899393029194e-10
31 1.7304553623365138e-10
32 1.7150950104571905e-10
33 1.700000695770143e-10
34 1.685164507936321e-10
35 1.6705853367326995e-10
36 1.6562586024893022e-10
37 1.642176811200713e-10
38 1.6283326076393934e-10
39 1.614

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.0


0 1.0746765977120987e-10
1 1.0691365154302801e-10
2 1.0636534014674126e-10
3 1.0582307252704481e-10
4 1.0528617561122999e-10
5 1.0475499634399199e-10
6 1.0422949309196738e-10
7 1.0370898584355359e-10
8 1.0319402971026292e-10
9 1.0268418754177944e-10
10 1.0217938301027019e-10
11 1.0167979652697667e-10
12 1.0118494930821953e-10
13 1.0069520217648176e-10
14 1.0021012492034131e-10
15 9.972966202864697e-11
16 9.925397309595851e-11
17 9.878299567223081e-11
18 9.831668118520653e-11
19 9.785445370669166e-11
20 9.739679895925946e-11
21 9.69433144870635e-11
22 9.64943541736929e-11
23 9.604909922966698e-11
24 9.560846558898106e-11
25 9.51716622177301e-11
26 9.473902218282149e-11
27 9.43103373174381e-11
28 9.38855521104287e-11
29 9.34648261563531e-11
30 9.304756964922944e-11
31 9.263430994499444e-11
32 9.222463764890776e-11
33 9.181871235552919e-11
34 9.14161316090123e-11
35 9.101722847626448e-11
36 9.062196826281621e-11
37 9.022980279604909e-11
38 8.984139820977788e-11
39 8.945635898705007e-11
40

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,██▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁

0,1
epoch,99.0
loss,0.0


SimpleRNN(
  (rnn): LSTM(14, 50, batch_first=True)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)

In [50]:
import wandb
# Register an account in https://wandb.ai/ and get your API key from https://wandb.ai/authorize
# wandb.login(key="add your key here")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvinays[0m ([33mfactiverse-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/vsetty/.netrc


True

In [56]:
wandb.init(project="dat550_rnn")