In [308]:
import torch
from torchvision import datasets
import random
import torch.nn as nn
import numpy as np

In [329]:
device = 'gpu' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [625]:
def get_sequence(n_timesteps, limit_const=4):
    x = torch.tensor([random.random() for _ in range(n_timesteps)])
    limit = len(x)/limit_const
    y = torch.tensor([0 if i < limit else 1 for i in np.cumsum(x)])
    x = x.reshape(1, 1, n_timesteps)
    y = y.reshape(n_timesteps)
    return x, y

In [627]:
get_sequence(10)[1].shape

torch.Size([10])

In [793]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=True, bidirectional=True
        )
        self.fc = nn.Linear(hidden_size * 2 , num_classes)
        self.softmax  = nn.Sigmoid()
        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers * 2 , 1, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * 2 , 1, self.hidden_size).to(device)
        out, (h,c) = self.lstm(x, (h0, c0))
        out = self.softmax(self.fc(out[:, -1 , :]))        
        return out , (h,c)

In [794]:
num_epochs = 1000
input_size =  10
hidden_size = 5
num_layers = 2
num_classes = 10

In [795]:
model = LSTM(input_size, hidden_size, num_layers, num_classes)
out, (h,c) = model(data)
h.shape, c.shape

(torch.Size([4, 1, 5]), torch.Size([4, 1, 5]))

In [796]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [797]:
for epoch in range(num_epochs):
    data, targets = get_sequence(10)
    output, (h,c) = model(data)
    loss = criterion(output.reshape(10), targets.type(torch.float32).reshape(10))
    print(f"{epoch}/{num_epochs} , loss : {loss.item()}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0/1000 , loss : 0.703644335269928
1/1000 , loss : 0.6906753182411194
2/1000 , loss : 0.7064939141273499
3/1000 , loss : 0.6773227453231812
4/1000 , loss : 0.7227950096130371
5/1000 , loss : 0.6994580030441284
6/1000 , loss : 0.7039663195610046
7/1000 , loss : 0.7032570838928223
8/1000 , loss : 0.7045519351959229
9/1000 , loss : 0.6985527276992798
10/1000 , loss : 0.7176883220672607
11/1000 , loss : 0.697902500629425
12/1000 , loss : 0.7011364102363586
13/1000 , loss : 0.7010620832443237
14/1000 , loss : 0.694125771522522
15/1000 , loss : 0.6903647184371948
16/1000 , loss : 0.7135426998138428
17/1000 , loss : 0.7116078734397888
18/1000 , loss : 0.7110579013824463
19/1000 , loss : 0.6868621110916138
20/1000 , loss : 0.7092596292495728
21/1000 , loss : 0.6890504956245422
22/1000 , loss : 0.6847356557846069
23/1000 , loss : 0.6702001094818115
24/1000 , loss : 0.6884044408798218
25/1000 , loss : 0.6851843595504761
26/1000 , loss : 0.6854568123817444
27/1000 , loss : 0.6745012402534485
28/10

249/1000 , loss : 0.21871662139892578
250/1000 , loss : 0.2097748965024948
251/1000 , loss : 0.21207848191261292
252/1000 , loss : 0.20743532478809357
253/1000 , loss : 0.3053520917892456
254/1000 , loss : 0.2108294665813446
255/1000 , loss : 0.3136206269264221
256/1000 , loss : 0.32159027457237244
257/1000 , loss : 0.2265622913837433
258/1000 , loss : 0.6917799711227417
259/1000 , loss : 0.22715112566947937
260/1000 , loss : 0.3170347511768341
261/1000 , loss : 0.21417179703712463
262/1000 , loss : 0.3153170049190521
263/1000 , loss : 0.19899973273277283
264/1000 , loss : 0.47621506452560425
265/1000 , loss : 0.3088899552822113
266/1000 , loss : 0.7019287347793579
267/1000 , loss : 0.21403351426124573
268/1000 , loss : 0.29710879921913147
269/1000 , loss : 0.19996745884418488
270/1000 , loss : 0.22723062336444855
271/1000 , loss : 0.29953452944755554
272/1000 , loss : 0.20133984088897705
273/1000 , loss : 0.2190624475479126
274/1000 , loss : 0.191707581281662
275/1000 , loss : 0.21498

494/1000 , loss : 0.7537037134170532
495/1000 , loss : 0.268876850605011
496/1000 , loss : 0.16287373006343842
497/1000 , loss : 0.15277399122714996
498/1000 , loss : 0.16281087696552277
499/1000 , loss : 0.269034206867218
500/1000 , loss : 0.1533907949924469
501/1000 , loss : 0.1629982441663742
502/1000 , loss : 0.1512502282857895
503/1000 , loss : 0.26905012130737305
504/1000 , loss : 0.16181489825248718
505/1000 , loss : 0.2821763753890991
506/1000 , loss : 0.1500031054019928
507/1000 , loss : 0.16091704368591309
508/1000 , loss : 0.4926950931549072
509/1000 , loss : 0.15370938181877136
510/1000 , loss : 0.16257162392139435
511/1000 , loss : 1.084477186203003
512/1000 , loss : 0.16112975776195526
513/1000 , loss : 0.16139109432697296
514/1000 , loss : 0.1660228818655014
515/1000 , loss : 0.28339022397994995
516/1000 , loss : 0.28401249647140503
517/1000 , loss : 0.2678297758102417
518/1000 , loss : 0.16171717643737793
519/1000 , loss : 0.15067847073078156
520/1000 , loss : 0.1525563

758/1000 , loss : 0.1478336602449417
759/1000 , loss : 0.14727748930454254
760/1000 , loss : 0.26004886627197266
761/1000 , loss : 0.14621511101722717
762/1000 , loss : 0.1493508517742157
763/1000 , loss : 0.14555026590824127
764/1000 , loss : 0.27480432391166687
765/1000 , loss : 0.14701370894908905
766/1000 , loss : 0.2604275345802307
767/1000 , loss : 0.14504432678222656
768/1000 , loss : 0.1455933153629303
769/1000 , loss : 0.1484624147415161
770/1000 , loss : 0.14347201585769653
771/1000 , loss : 0.1446138322353363
772/1000 , loss : 0.1473298817873001
773/1000 , loss : 0.14440114796161652
774/1000 , loss : 0.2746252417564392
775/1000 , loss : 0.26124322414398193
776/1000 , loss : 0.142429918050766
777/1000 , loss : 0.1468619555234909
778/1000 , loss : 0.142147958278656
779/1000 , loss : 0.27451449632644653
780/1000 , loss : 0.14345195889472961
781/1000 , loss : 0.14291295409202576
782/1000 , loss : 0.27392998337745667
783/1000 , loss : 0.14145109057426453
784/1000 , loss : 0.26313

In [802]:
input_data = torch.round(model(data)[0]).reshape(10)

In [804]:
for i in range(len(targets)):
    print(f"Expected : {targets[i]} , Predicted : {input_data[i]}")

Expected : 0 , Predicted : 0.0
Expected : 0 , Predicted : 0.0
Expected : 0 , Predicted : 0.0
Expected : 0 , Predicted : 0.0
Expected : 0 , Predicted : 1.0
Expected : 0 , Predicted : 1.0
Expected : 1 , Predicted : 1.0
Expected : 1 , Predicted : 1.0
Expected : 1 , Predicted : 1.0
Expected : 1 , Predicted : 1.0
