# Task 1: Add the sigmoid layer in the forward pass before returning the output.

In [43]:
import torch
import torch.nn as nn
import numpy as np
import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, rnn_type='rnn'):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        if rnn_type == 'gru':
            self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == "rnn":
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Invalid RNN type")
        # Output layer for binary classification (1 node)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        if self.rnn_type == "lstm":
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            out, _ = self.rnn(x, (h0, c0))
        else:
            out, _ = self.rnn(x, h0)

        # Extract the last hidden state output
        out = out[:, -1, :]

        # Pass the last hidden state output to fully connected layer
        out = self.fc(out)
        # Apply sigmoid function to convert to a probability
        out = torch.sigmoid(out)
        return out

# Example usage:
input_size = 10   # Number of input features
hidden_size = 10  # Number of features in the hidden state
num_layers = 1    # Number of stacked RNN layers

# Create the RNN model instance
model = SimpleRNN(input_size, hidden_size, num_layers, rnn_type='rnn')


print(model)

SimpleRNN(
  (rnn): RNN(10, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=1, bias=True)
)


# Task 2: Train the SimpleRNN using the Adam optimizer 
- Hint:  you can copy the code from pytorch_neural_network_solution.ipynb

In [44]:
import numpy as np
num_features = 10
num_instances = 1000
sequence_length = 100
x = torch.rand(num_instances, sequence_length, num_features)
x.shape


torch.Size([1000, 100, 10])

In [45]:
y = torch.tensor(np.random.randint(low=0, high=1, size = num_instances).reshape(-1, 1))
y.shape

torch.Size([1000, 1])

In [57]:
import wandb
def train(train_model, optimizer_name="adam", model_type = "RNN"):
    criterion = nn.BCELoss()
    lr = 0.01
    if optimizer_name == "adam":
        optimizer = torch.optim.Adam(train_model.parameters(), lr=lr)
    elif optimizer_name == "rmsprop":
        optimizer = torch.optim.RMSprop(train_model.parameters(), lr=lr)
    elif optimizer_name == "sgd":
        optimizer = torch.optim.SGD(train_model.parameters(), lr=lr)
    else:
        raise ValueError("Invalid optimizer.")
    wandb.init(project='dat550_rnn', reinit=True, name=f'run_{model_type}_{optimizer_name}')
    # Training loop (for demonstration, let's say 100 epochs)
    num_epochs = 100
    for epoch in range(num_epochs):
        pred = train_model(x)
        loss = criterion(pred, y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(epoch, loss.item())
        wandb.log({"epoch": epoch, "loss": loss})


    print(f'Final Loss: {loss.item()}')
    wandb.finish()
    return model

# Task 3: Change the SimpleRNN to use the GRU units instead of RNN units and trin.

In [59]:
input_size = 10   # Number of input features
hidden_size = 10  # Number of features in the hidden state
num_layers = 1    # Number of stacked RNN layers

simple_rnn = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="rnn")
train(simple_rnn, model_type="rnn")
simple_gru = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="gru")
train(simple_gru, model_type="gru")
simple_lstm = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="lstm")
train(simple_lstm, model_type="lstm")


0 0.5159330368041992
1 0.4592374861240387
2 0.4143798053264618
3 0.37853148579597473
4 0.34742528200149536
5 0.31895264983177185
6 0.29213947057724
7 0.2666225731372833
8 0.24241903424263
9 0.21971337497234344
10 0.19869738817214966
11 0.17948637902736664
12 0.16209867596626282
13 0.14647482335567474
14 0.1325090378522873
15 0.12007423490285873
16 0.10903366655111313
17 0.0992443859577179
18 0.09056036174297333
19 0.0828392505645752
20 0.07595061510801315
21 0.06978095322847366
22 0.0642346516251564
23 0.05923245847225189
24 0.05470894277095795
25 0.05060996487736702
26 0.04689041152596474
27 0.04351217299699783
28 0.040442511439323425
29 0.0376526340842247
30 0.03511669114232063
31 0.032810986042022705
32 0.030713604763150215
33 0.0288042351603508
34 0.027064133435487747
35 0.025476135313510895
36 0.0240247193723917
37 0.02269596979022026
38 0.021477488800883293
39 0.02035832591354847
40 0.019328786060214043
41 0.01838032901287079
42 0.017505383118987083
43 0.016697237268090248
44 0.0

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.00443


0 0.7958732843399048
1 0.732130229473114
2 0.6741362810134888
3 0.6202055215835571
4 0.5685790777206421
5 0.5180537700653076
6 0.46821895241737366
7 0.4193700850009918
8 0.372270792722702
9 0.3278682827949524
10 0.2870347201824188
11 0.2503839135169983
12 0.21818475425243378
13 0.19037209451198578
14 0.16663019359111786
15 0.14650258421897888
16 0.12948653101921082
17 0.11509449034929276
18 0.10288606584072113
19 0.09248106926679611
20 0.08356184512376785
21 0.07586889714002609
22 0.06919321417808533
23 0.06336737424135208
24 0.058257028460502625
25 0.053753722459077835
26 0.04976905509829521
27 0.0462302640080452
28 0.04307682067155838
29 0.040257979184389114
30 0.03773084282875061
31 0.03545888885855675
32 0.03341086581349373
33 0.03155994415283203
34 0.02988293021917343
35 0.02835976332426071
36 0.026972996070981026
37 0.02570742927491665
38 0.024549802765250206
39 0.023488493636250496
40 0.02251330204308033
41 0.02161525934934616
42 0.020786475390195847
43 0.020019972696900368
44 0

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▆▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.0069


0 0.8326718807220459
1 0.7734530568122864
2 0.7237474918365479
3 0.6793004870414734
4 0.6376040577888489
5 0.5973783135414124
6 0.5580827593803406
7 0.5194371938705444
8 0.4813520014286041
9 0.443863183259964
10 0.4070146381855011
11 0.37075528502464294
12 0.3349562883377075
13 0.299589604139328
14 0.2649940252304077
15 0.2320137470960617
16 0.2017279416322708
17 0.17487186193466187
18 0.15157173573970795
19 0.13158152997493744
20 0.11455937474966049
21 0.10014895349740982
22 0.08797892183065414
23 0.07768184691667557
24 0.06893030554056168
25 0.061456192284822464
26 0.055047355592250824
27 0.0495343841612339
28 0.04477847367525101
29 0.040663547813892365
30 0.037091612815856934
31 0.0339798778295517
32 0.03125860169529915
33 0.028869198635220528
34 0.026762597262859344
35 0.024897728115320206
36 0.023240182548761368
37 0.021761108189821243
38 0.02043626457452774
39 0.019245220348238945
40 0.018170710653066635
41 0.017198098823428154
42 0.016314946115016937
43 0.015510626137256622
44 0

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▆▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.00434


SimpleRNN(
  (rnn): RNN(10, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=1, bias=True)
)

In [60]:
simple_lstm = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="lstm")
train(simple_lstm, model_type="lstm", optimizer_name="sgd")
train(simple_lstm, model_type="lstm", optimizer_name="adam")
train(simple_lstm, model_type="lstm", optimizer_name="rmsprop")

0 0.5151149034500122
1 0.512709379196167
2 0.5103204846382141
3 0.5079483985900879
4 0.5055927634239197
5 0.5032533407211304
6 0.50093013048172
7 0.4986228346824646
8 0.4963313043117523
9 0.4940554201602936
10 0.49179500341415405
11 0.4895499646663666
12 0.4873200058937073
13 0.48510509729385376
14 0.4829050600528717
15 0.4807196855545044
16 0.47854891419410706
17 0.47639256715774536
18 0.4742504954338074
19 0.47212257981300354
20 0.4700086712837219
21 0.4679085910320282
22 0.4658223092556
23 0.46374958753585815
24 0.46169033646583557
25 0.4596444368362427
26 0.4576117992401123
27 0.45559218525886536
28 0.4535856246948242
29 0.4515919089317322
30 0.4496108889579773
31 0.4476425051689148
32 0.4456866979598999
33 0.4437432289123535
34 0.44181200861930847
35 0.4398930072784424
36 0.4379860460758209
37 0.43609100580215454
38 0.43420788645744324
39 0.4323364496231079
40 0.43047669529914856
41 0.42862841486930847
42 0.42679160833358765
43 0.42496612668037415
44 0.4231518805027008
45 0.421348

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
epoch,99.0
loss,0.33824


0 0.3369270861148834
1 0.30315300822257996
2 0.2687419354915619
3 0.23379497230052948
4 0.19974103569984436
5 0.16836576163768768
6 0.14096172153949738
7 0.11787916719913483
8 0.09882819652557373
9 0.08329785615205765
10 0.07070167362689972
11 0.060474976897239685
12 0.05214311182498932
13 0.04532453417778015
14 0.03971137851476669
15 0.035056304186582565
16 0.031164858490228653
17 0.02788684144616127
18 0.025106431916356087
19 0.0227332916110754
20 0.020695911720395088
21 0.018937058746814728
22 0.017410511150956154
23 0.016078738495707512
24 0.014911111444234848
25 0.013882533647119999
26 0.012972339987754822
27 0.012163433246314526
28 0.011441599577665329
29 0.010794967412948608
30 0.010213565081357956
31 0.009688973426818848
32 0.009214062243700027
33 0.00878275278955698
34 0.008389844559133053
35 0.00803087092936039
36 0.007701979484409094
37 0.007399834226816893
38 0.007121535018086433
39 0.006864554714411497
40 0.006626687478274107
41 0.006405987776815891
42 0.006200750824064016

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,0.00232


0 0.0022935240995138884
1 0.000688446220010519
2 0.00045197681174613535
3 0.0003501910832710564
4 0.000288977287709713
5 0.00024720310466364026
6 0.00021654876763932407
7 0.0001929496502270922
8 0.00017414698959328234
9 0.00015877214900683612
10 0.0001459411869291216
11 0.0001350554812233895
12 0.0001256935647688806
13 0.00011754941078834236
14 0.00011039504897780716
15 0.00010405679495306686
16 9.840004349825904e-05
17 9.331846376881003e-05
18 8.872717444319278e-05
19 8.455726492684335e-05
20 8.075240475591272e-05
21 7.72659623180516e-05
22 7.405901124002412e-05
23 7.109880243660882e-05
24 6.835759268142283e-05
25 6.581153866136447e-05
26 6.344034045469016e-05
27 6.122639024397358e-05
28 5.915436486247927e-05
29 5.721085472032428e-05
30 5.5384167353622615e-05
31 5.3664003644371405e-05
32 5.2041221351828426e-05
33 5.0507656851550564e-05
34 4.905612877337262e-05
35 4.768016151501797e-05
36 4.637398888007738e-05
37 4.513233943725936e-05
38 4.395056748762727e-05
39 4.282437657820992e-05
4

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
loss,2e-05


SimpleRNN(
  (rnn): RNN(10, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=1, bias=True)
)

# Task 4: Implement a SimpleLSTM using the LSTM unit. Note you also need to include the c0 in the forward pass.

In [48]:
simple_gru = SimpleRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type="lstm")
train(simple_gru)

0 0.79059898853302
1 0.7437970042228699
2 0.6992879509925842
3 0.6513519883155823
4 0.5989393591880798
5 0.5431981682777405
6 0.48646876215934753
7 0.4317106306552887
8 0.38138681650161743
9 0.3363927900791168
10 0.2963874340057373
11 0.26092633605003357
12 0.22980144619941711
13 0.20278501510620117
14 0.17947590351104736
15 0.15936899185180664
16 0.14197690784931183
17 0.12689155340194702
18 0.11378461867570877
19 0.1023854985833168
20 0.09246223419904709
21 0.08381124585866928
22 0.07625295221805573
23 0.06963000446557999
24 0.06380578130483627
25 0.05866270139813423
26 0.054100245237350464
27 0.05003276839852333
28 0.04638771340250969
29 0.04310378432273865
30 0.04012973606586456
31 0.037423331290483475
32 0.03495060279965401
33 0.032685115933418274
34 0.030606793239712715
35 0.028700310736894608
36 0.026953045278787613
37 0.02535315789282322
38 0.023888586089015007
39 0.022547060623764992
40 0.021316934376955032
41 0.02018791250884533
42 0.019151214510202408
43 0.01819920726120472


SimpleRNN(
  (rnn): RNN(10, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=1, bias=True)
)

# Task 5: Log the learning curve to wandb and repeat all experiments

In [49]:
pip install wandb

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [50]:
import wandb
# Register an account in https://wandb.ai/ and get your API key from https://wandb.ai/authorize
# wandb.login(key="add your key here")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvinays[0m ([33mfactiverse-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/vsetty/.netrc


True

In [56]:
wandb.init(project="dat550_rnn")