In [1]:
import numpy as np
import pandas as pd
import torch
import torchtext
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.functional import numericalize_tokens_from_iterator
from torch.utils.data import random_split, DataLoader

import logging
import os
import random

In [176]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
# filehandler = logging.FileHandler('my_log.log')
filehandler = logging.handlers.RotatingFileHandler('my_log.log', mode='a', maxBytes=20000, backupCount=3)
filehandler.setFormatter(formatter)

streamhandler = logging.StreamHandler()
streamhandler.setFormatter(formatter)

logger.addHandler(filehandler)
logger.addHandler(streamhandler)


In [177]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [178]:
df = pd.read_csv('artifacts/train_cleaned.csv')
df.head()

Unnamed: 0,tweet,sentiment,label
0,im getting on borderlands and i will murder yo...,Positive,3
1,I am coming to the borders and I will kill you...,Positive,3
2,im getting on borderlands and i will kill you ...,Positive,3
3,im coming on borderlands and i will murder you...,Positive,3
4,im getting on borderlands 2 and i will murder ...,Positive,3


In [179]:
len(df)

71656

In [180]:
df = df[:10000]
df

Unnamed: 0,tweet,sentiment,label
0,im getting on borderlands and i will murder yo...,Positive,3
1,I am coming to the borders and I will kill you...,Positive,3
2,im getting on borderlands and i will kill you ...,Positive,3
3,im coming on borderlands and i will murder you...,Positive,3
4,im getting on borderlands 2 and i will murder ...,Positive,3
...,...,...,...
9995,_. . The faster more efficient AMD Zen CPUs in...,Positive,3
9996,_.. The faster more efficient AMD Zen CPUs in ...,Positive,3
9997,_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ ...,Positive,3
9998,_.. The faster memory efficient AMD Zen CPUs i...,Positive,3


In [181]:
df['label'].unique()

array([3, 2, 1, 0], dtype=int64)

In [182]:
df_valid = pd.read_csv('artifacts/valid_cleaned.csv')
df_valid = df_valid[:10000]
df_valid 

Unnamed: 0,tweet,sentiment,label
0,I mentioned on Facebook that I was struggling ...,Irrelevant,0
1,BBC News - Amazon boss Jeff Bezos rejects clai...,Neutral,2
2,@Microsoft Why do I pay for WORD when it funct...,Negative,1
3,"CSGO matchmaking is so full of closet hacking,...",Negative,1
4,Now the President is slapping Americans in the...,Neutral,2
...,...,...,...
995,⭐️ Toronto is the arts and culture capital of ...,Irrelevant,0
996,tHIS IS ACTUALLY A GOOD MOVE TOT BRING MORE VI...,Irrelevant,0
997,Today sucked so it’s time to drink wine n play...,Positive,3
998,Bought a fraction of Microsoft today. Small wins.,Positive,3


In [183]:
eng_tokenizer = get_tokenizer("spacy")
eng_tokenizer



functools.partial(<function _spacy_tokenize at 0x000001E5189CCAF0>, spacy=<spacy.lang.en.English object at 0x000001E641C75480>)

In [184]:
def token_gen(text):
    for sentence in text:
        tokens = eng_tokenizer(sentence)
        yield tokens
        

In [185]:
vocab_size = 5000 #down from original vocab size 13k ish
max_length = 100

In [186]:
vocab = build_vocab_from_iterator(token_gen(df['tweet']), specials=['<SOS>', '<EOS>','<UNK>'], max_tokens=vocab_size)
vocab.set_default_index(vocab["<UNK>"])

In [187]:
len(vocab)

5000

In [188]:
vocab.get_stoi()

{'forget': 1464,
 'but': 38,
 'that': 25,
 'cheat': 2911,
 'Team': 3521,
 'bald': 2621,
 'a': 10,
 '!': 6,
 'XSX': 4632,
 'seems': 505,
 'her': 340,
 '<SOS>': 0,
 'hackers': 1645,
 'I': 5,
 'face': 728,
 'non': 2697,
 'Championship': 3199,
 'crashes': 870,
 '@YouTube': 719,
 'awesome': 383,
 'the': 4,
 'Nice': 681,
 'jewel': 3807,
 '100': 140,
 'Red': 2596,
 'much': 123,
 'USS': 4598,
 'in': 17,
 'customers': 871,
 'reasons': 1779,
 'mood': 1480,
 'probably': 260,
 'kicked': 2266,
 '/': 14,
 'monitors': 3846,
 '<EOS>': 1,
 'thread': 2314,
 'charge': 2909,
 "'m": 57,
 '.': 3,
 'advice': 2867,
 'happen': 685,
 'VR': 2013,
 'mix': 1766,
 'this': 20,
 'arcade': 2878,
 'taking': 831,
 'Miles': 4413,
 '<UNK>': 2,
 'DMs': 3222,
 ',': 7,
 'OR': 2590,
 'exciting': 232,
 'Always': 1001,
 'literal': 1654,
 'calendars': 2903,
 'TO': 414,
 'ball': 1448,
 'stunned': 3990,
 '1000': 4060,
 'i': 46,
 'tlkn': 4006,
 'counting': 2043,
 '10': 252,
 'credit': 958,
 'to': 8,
 'Americans': 3131,
 'or': 99,
 

In [189]:
token_gen(df['tweet'])

<generator object token_gen at 0x000001E641E1AD50>

In [190]:
sequence = numericalize_tokens_from_iterator(vocab, token_gen(df['tweet']))

count = 0
for ids in sequence:
    print(count, [num for num in ids])
    count += 1
    if count > 10:
        break

0 [46, 159, 240, 16, 101, 9, 46, 86, 2083, 21, 43, 7]
1 [5, 132, 407, 8, 4, 2, 9, 5, 86, 756, 21, 43, 7]
2 [46, 159, 240, 16, 101, 9, 46, 86, 756, 21, 43, 7]
3 [46, 159, 407, 16, 101, 9, 46, 86, 2083, 21, 43, 7]
4 [46, 159, 240, 16, 101, 55, 9, 46, 86, 2083, 21, 40, 43, 7]
5 [46, 159, 240, 227, 101, 9, 46, 89, 2083, 21, 43, 7]
6 [142, 5, 830, 10, 415, 290, 488, 238, 11, 102, 3, 3, 3, 164, 21, 47, 42, 175, 5, 132, 10, 2162, 199, 630, 9, 1233, 13, 72, 12, 26, 242, 433, 3, 142, 5, 799, 8, 184, 435, 10, 2, 11, 26, 294, 3, 3, 637, 13, 4, 935, 1886, 2, 4, 2, 5, 213, 834, 1422, 6, 2]
7 [142, 5, 830, 10, 1190, 12, 290, 387, 238, 11, 102, 18, 164, 21, 47, 42, 175, 25, 5, 57, 10, 526, 27, 34, 630, 9, 1233, 13, 72, 12, 26, 242, 433, 7, 5, 799, 8, 184, 10, 2, 11, 26, 294, 22, 637, 31, 4, 935, 2704, 957, 8, 4, 2, 5, 213, 834, 500, 102, 6, 323, 14, 2]
8 [142, 5, 830, 10, 415, 290, 387, 238, 11, 102, 18, 164, 21, 47, 42, 175, 5, 57, 10, 2162, 27, 34, 630, 9, 1233, 13, 72, 12, 26, 242, 433, 3]
9 [142,

In [191]:
example = numericalize_tokens_from_iterator(vocab, ["hi how are you", "what is your name?"])
list(next(example))

[1755, 46, 33, 1755, 3859, 883, 33, 10, 1024, 848, 33, 578, 3859, 381]

In [192]:
example = numericalize_tokens_from_iterator(vocab, ["hi how are you pouya", "what is your age?huh"])
list(next(example))

[1755,
 46,
 33,
 1755,
 3859,
 883,
 33,
 10,
 1024,
 848,
 33,
 578,
 3859,
 381,
 33,
 1272,
 3859,
 381,
 578,
 10]

In [193]:
sequence = numericalize_tokens_from_iterator(vocab, iterator=token_gen(df['tweet']))
sequence

<generator object numericalize_tokens_from_iterator at 0x000001E641E1BDF0>

In [194]:
len(df)

10000

In [195]:
logger.debug("welcome")

# option 1

In [196]:
# token_ids = []
# for i in range(len(df)):
#     token_id = vocab(eng_tokenizer(df['tweet'][i]))
#     token_ids.append(token_id)

    

In [197]:
# padded_text = pad_sequence([torch.tensor(x) for x in token_ids], batch_first=True)
# padded_text.shape

### option 2 iterator

In [198]:
# valid_token_ids = torch.tensor(token_ids) # this will throw an error, because all sequence are not of same length

sequence = numericalize_tokens_from_iterator(vocab, iterator=token_gen(df['tweet']))

token_ids = []
for i in range(len(df)):
    x = list(next(sequence))
#     if i % 1000 == 0:
#         logger.debug(f"{i}:{x}",)
    token_ids.append(x)
   
padded_text = pad_sequence([torch.tensor(x) for x in token_ids], padding_value=0, batch_first=True)
padded_text.shape



torch.Size([10000, 198])

In [199]:
padded_text = padded_text[:,:max_length]
padded_text.shape

torch.Size([10000, 100])

In [200]:
vocab(eng_tokenizer(df['tweet'][0]))

[46, 159, 240, 16, 101, 9, 46, 86, 2083, 21, 43, 7]

In [201]:
torch.tensor(vocab(eng_tokenizer('I will see you')))

tensor([ 5, 86, 64, 21])

In [202]:
vocab.get_stoi()

{'forget': 1464,
 'but': 38,
 'that': 25,
 'cheat': 2911,
 'Team': 3521,
 'bald': 2621,
 'a': 10,
 '!': 6,
 'XSX': 4632,
 'seems': 505,
 'her': 340,
 '<SOS>': 0,
 'hackers': 1645,
 'I': 5,
 'face': 728,
 'non': 2697,
 'Championship': 3199,
 'crashes': 870,
 '@YouTube': 719,
 'awesome': 383,
 'the': 4,
 'Nice': 681,
 'jewel': 3807,
 '100': 140,
 'Red': 2596,
 'much': 123,
 'USS': 4598,
 'in': 17,
 'customers': 871,
 'reasons': 1779,
 'mood': 1480,
 'probably': 260,
 'kicked': 2266,
 '/': 14,
 'monitors': 3846,
 '<EOS>': 1,
 'thread': 2314,
 'charge': 2909,
 "'m": 57,
 '.': 3,
 'advice': 2867,
 'happen': 685,
 'VR': 2013,
 'mix': 1766,
 'this': 20,
 'arcade': 2878,
 'taking': 831,
 'Miles': 4413,
 '<UNK>': 2,
 'DMs': 3222,
 ',': 7,
 'OR': 2590,
 'exciting': 232,
 'Always': 1001,
 'literal': 1654,
 'calendars': 2903,
 'TO': 414,
 'ball': 1448,
 'stunned': 3990,
 '1000': 4060,
 'i': 46,
 'tlkn': 4006,
 'counting': 2043,
 '10': 252,
 'credit': 958,
 'to': 8,
 'Americans': 3131,
 'or': 99,
 

In [203]:
{k:v for k,v in sorted(vocab.get_stoi().items(), key=lambda x:x[1])}

{'<SOS>': 0,
 '<EOS>': 1,
 '<UNK>': 2,
 '.': 3,
 'the': 4,
 'I': 5,
 '!': 6,
 ',': 7,
 'to': 8,
 'and': 9,
 'a': 10,
 'for': 11,
 'of': 12,
 'is': 13,
 '/': 14,
 'it': 15,
 'on': 16,
 'in': 17,
 '...': 18,
 '?': 19,
 'this': 20,
 'you': 21,
 '..': 22,
 '-': 23,
 'Amazon': 24,
 'that': 25,
 'my': 26,
 '@': 27,
 'with': 28,
 ':': 29,
 'so': 30,
 "'s": 31,
 'have': 32,
 ' ': 33,
 'Borderlands': 34,
 'game': 35,
 'be': 36,
 '3': 37,
 'but': 38,
 'are': 39,
 'me': 40,
 'The': 41,
 "n't": 42,
 'all': 43,
 'just': 44,
 'not': 45,
 'i': 46,
 'do': 47,
 'was': 48,
 '"': 49,
 'Overwatch': 50,
 'like': 51,
 'at': 52,
 '_': 53,
 'Xbox': 54,
 '2': 55,
 'from': 56,
 "'m": 57,
 'get': 58,
 'as': 59,
 'now': 60,
 'out': 61,
 'your': 62,
 'overwatch': 63,
 'see': 64,
 'really': 65,
 'if': 66,
 '…': 67,
 ')': 68,
 'has': 69,
 'It': 70,
 'play': 71,
 'one': 72,
 '(': 73,
 'an': 74,
 'we': 75,
 'more': 76,
 'time': 77,
 'about': 78,
 '>': 79,
 'love': 80,
 'up': 81,
 'they': 82,
 'some': 83,
 'Black': 84,

In [204]:
embed = torch.nn.Embedding(len(vocab), embedding_dim=5, padding_idx=0)
embed

Embedding(5000, 5, padding_idx=0)

In [205]:
df['tweet'][len(df)-1]

'_.. The faster to more energy efficient AMD Zen CPUs in operation both.. We either already had / have most incredible display visuals with current gen consoles..'

In [206]:
test_input = torch.tensor(vocab(eng_tokenizer(df['tweet'][0])))
test_input.shape, embed(test_input).shape

(torch.Size([12]), torch.Size([12, 5]))

In [207]:
padded_text.shape

torch.Size([10000, 100])

In [208]:
input_text = embed(padded_text)
input_text.shape

torch.Size([10000, 100, 5])

In [209]:
padded_text[1000]

tensor([ 430,   62,  737, 1037,    3,   34,   37,   11,  294,  407,    8, 1181,
           2, 2829, 1101,    2,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])

In [210]:
embed(padded_text[1000]).shape

torch.Size([100, 5])

In [211]:
# from collections import Counter
# counter = Counter()
# for sentence in df['tweet'].values:
#     counter.update(eng_tokenizer(sentence))
# counter

In [212]:
# from torchtext.vocab import vocab
# vocab_regular = vocab(counter, min_freq=3, specials=['<UNK>'])
# vocab_regular.get_stoi()

In [213]:
label = df['label']
label = torch.tensor(label)
label.shape

torch.Size([10000])

In [214]:
label.unique()

tensor([0, 1, 2, 3])

In [215]:
num_classes = len(label.unique())
num_classes

4

In [216]:
class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        self.rnn = nn.RNN(embed_dim, hidden_size, batch_first=False)
        self.linear = nn.Linear(hidden_size, num_classes)
        
#         self.init_weights()
        
    def init_weights(self):
        initrange = 0.5
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.rnn.weight_ih_l0.data.uniform_(-initrange, initrange)
        self.rnn.weight_hh_l0.data.uniform_(-initrange, initrange)
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        
        
    def forward(self, x):
        x = self.embed(x)
        
        out, hidden = self.rnn(x)
        
        out = out[:, -1, :]
#         logger.info('rnn last output shape:',out.shape)
        
        out = self.linear(out)
#         out = out.squeeze(0)
        
        return out

In [217]:
VOCAB_SIZE = len(vocab)
EMBED_DIM = 100
HIDDEN_SIZE = 32
EPOCHS = 30

In [218]:
rnn = RNNClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_SIZE).to(device)
rnn

RNNClassifier(
  (embed): Embedding(5000, 100, padding_idx=0)
  (rnn): RNN(100, 32)
  (linear): Linear(in_features=32, out_features=4, bias=True)
)

In [219]:
padded_text[0].shape

torch.Size([100])

In [220]:
rnn(padded_text[0].unsqueeze(0).to(device)).shape

torch.Size([1, 4])

In [221]:
torch.prod(torch.tensor(torch.randn(4, 2).shape))

tensor(8)

In [222]:
total_params = 0
for name, param in rnn.named_parameters():
    total_params += int(torch.prod(torch.tensor(param.shape)))
    logger.info(f'{name}:{param.shape} has {torch.prod(torch.tensor(param.shape))} parameters')
logger.info(f'total parameters: {total_params:0.2f}')

INFO:__main__:embed.weight:torch.Size([5000, 100]) has 500000 parameters
INFO:__main__:embed.weight:torch.Size([5000, 100]) has 500000 parameters
INFO:__main__:embed.weight:torch.Size([5000, 100]) has 500000 parameters
INFO:__main__:rnn.weight_ih_l0:torch.Size([32, 100]) has 3200 parameters
INFO:__main__:rnn.weight_ih_l0:torch.Size([32, 100]) has 3200 parameters
INFO:__main__:rnn.weight_ih_l0:torch.Size([32, 100]) has 3200 parameters
INFO:__main__:rnn.weight_hh_l0:torch.Size([32, 32]) has 1024 parameters
INFO:__main__:rnn.weight_hh_l0:torch.Size([32, 32]) has 1024 parameters
INFO:__main__:rnn.weight_hh_l0:torch.Size([32, 32]) has 1024 parameters
INFO:__main__:rnn.bias_ih_l0:torch.Size([32]) has 32 parameters
INFO:__main__:rnn.bias_ih_l0:torch.Size([32]) has 32 parameters
INFO:__main__:rnn.bias_ih_l0:torch.Size([32]) has 32 parameters
INFO:__main__:rnn.bias_hh_l0:torch.Size([32]) has 32 parameters
INFO:__main__:rnn.bias_hh_l0:torch.Size([32]) has 32 parameters
INFO:__main__:rnn.bias_hh_

In [223]:
type(padded_text)

torch.Tensor

In [224]:
rnn(padded_text.to(device)).shape

torch.Size([10000, 4])

### will try batch gradient descent

In [225]:
X, y = padded_text, label
len(X)

10000

In [226]:
x_train_size = int(0.8*len(X))
x_test_size = int(len(X) - 0.8*len(X))
x_test_size

2000

In [227]:
y_train_size = int(0.8*len(y))
y_test_size = int(len(X) - 0.8*len(y))

In [228]:
x_train, x_test = random_split(X, [x_train_size, x_test_size])
y_train, y_test = random_split(y, [y_train_size, y_test_size])
len(x_train), len(x_test), len(y_train), len(y_test)

(8000, 2000, 8000, 2000)

In [229]:
padded_text.shape

torch.Size([10000, 100])

In [230]:
# token_ids = []
# for i in range(len(df_valid)):
#     token_ids.append(vocab(eng_tokenizer(df_valid['tweet'][i])))
    
# token_ids
    

In [231]:
# padded_text_valid = pad_sequence([torch.tensor(x) for x in token_ids], batch_first=True)
# padded_text_valid.shape

In [232]:
# valid_labels = df_valid['label'].to_list()

In [233]:
vocab['<SOS>']

0

In [234]:
text_transform = lambda x: [vocab['<SOS>']] + vocab(eng_tokenizer(x)) + [vocab['<EOS>']]

In [235]:
text_transform("here is an example")

[0, 154, 13, 74, 2, 1]

In [236]:
def collate_batch(batch):
    label_list, text_list = [], []
    for _label, _text in batch:
        label_list.append(_label)
        text_list.append(torch.tensor(text_transform(_text)))
        
    return torch.tensor(label_list), pad_sequence(text_list)

In [237]:
df

Unnamed: 0,tweet,sentiment,label
0,im getting on borderlands and i will murder yo...,Positive,3
1,I am coming to the borders and I will kill you...,Positive,3
2,im getting on borderlands and i will kill you ...,Positive,3
3,im coming on borderlands and i will murder you...,Positive,3
4,im getting on borderlands 2 and i will murder ...,Positive,3
...,...,...,...
9995,_. . The faster more efficient AMD Zen CPUs in...,Positive,3
9996,_.. The faster more efficient AMD Zen CPUs in ...,Positive,3
9997,_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ ...,Positive,3
9998,_.. The faster memory efficient AMD Zen CPUs i...,Positive,3


In [238]:
train_dataset = [(x, y) for x,y in zip(df['label'], df['tweet'])]
test_dataset = [(x, y) for x,y in zip(df_valid['label'], df_valid['tweet'])]


In [239]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)
next(iter(train_dataloader))[1].shape, next(iter(test_dataloader))[1].shape[0]

(torch.Size([31, 8]), 37)

In [240]:
min([y.shape for x, y in train_dataloader]),max([y.shape for x, y in train_dataloader])


(torch.Size([16, 8]), torch.Size([200, 8]))

In [241]:
eng_tokenizer("this is test")

['this', 'is', 'test']

In [242]:
indices = [(idx, len(eng_tokenizer(i[1]))) for idx, i in enumerate(train_dataset)]
indices

[(0, 12),
 (1, 13),
 (2, 12),
 (3, 12),
 (4, 14),
 (5, 12),
 (6, 60),
 (7, 62),
 (8, 32),
 (9, 60),
 (10, 61),
 (11, 1),
 (12, 20),
 (13, 21),
 (14, 21),
 (15, 19),
 (16, 32),
 (17, 19),
 (18, 26),
 (19, 26),
 (20, 26),
 (21, 31),
 (22, 26),
 (23, 14),
 (24, 11),
 (25, 14),
 (26, 18),
 (27, 14),
 (28, 52),
 (29, 53),
 (30, 58),
 (31, 52),
 (32, 62),
 (33, 52),
 (34, 44),
 (35, 36),
 (36, 42),
 (37, 43),
 (38, 53),
 (39, 44),
 (40, 8),
 (41, 7),
 (42, 6),
 (43, 7),
 (44, 8),
 (45, 9),
 (46, 33),
 (47, 28),
 (48, 25),
 (49, 34),
 (50, 1),
 (51, 20),
 (52, 20),
 (53, 26),
 (54, 21),
 (55, 26),
 (56, 20),
 (57, 5),
 (58, 2),
 (59, 2),
 (60, 2),
 (61, 31),
 (62, 29),
 (63, 25),
 (64, 31),
 (65, 37),
 (66, 31),
 (67, 14),
 (68, 15),
 (69, 17),
 (70, 14),
 (71, 16),
 (72, 14),
 (73, 45),
 (74, 46),
 (75, 48),
 (76, 39),
 (77, 55),
 (78, 39),
 (79, 3),
 (80, 3),
 (81, 3),
 (82, 5),
 (83, 5),
 (84, 3),
 (85, 49),
 (86, 49),
 (87, 50),
 (88, 49),
 (89, 61),
 (90, 48),
 (91, 32),
 (92, 29),
 (93,

In [243]:
len(train_dataset)
batch_size = 8

In [244]:

indices = [(idx, len(eng_tokenizer(i[1]))) for idx, i in enumerate(train_dataset)]
random.shuffle(indices)

pooled_indices = []
for i in range(0, len(indices), batch_size*100):
    pooled_indices.extend(indices[i:i+batch_size*100])

# pooled_indices = [i[0] for i in pooled_indices]
# print(pooled_indices)

# for i in range(0, len(pooled_indices), batch_size):
#     print(pooled_indices[i:i+batch_size])
        

In [245]:
import random
def batch_sampler(dataset):
    indices = [(idx, len(eng_tokenizer(i[1]))) for idx, i in enumerate(dataset)]
    random.shuffle(indices)
    
    pooled_indices = []
    for i in range(0, len(train_dataset), batch_size*100):
        pooled_indices.extend(sorted(indices[i:i+batch_size*100], key=lambda x:x[1]))
        
    pooled_indices = [i[0] for i in pooled_indices]
    
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i+batch_size]
        


In [246]:
bucket_train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler(train_dataset), collate_fn=collate_batch )
bucket_test_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler(test_dataset), collate_fn=collate_batch )

next(iter(bucket_train_dataloader))[1].shape,next(iter(bucket_test_dataloader))[1].shape

(torch.Size([3, 8]), torch.Size([50, 8]))

In [247]:
for y, x in bucket_train_dataloader:
    print(x.shape, y.shape)


torch.Size([3, 8]) torch.Size([8])
torch.Size([4, 8]) torch.Size([8])
torch.Size([4, 8]) torch.Size([8])
torch.Size([5, 8]) torch.Size([8])
torch.Size([5, 8]) torch.Size([8])
torch.Size([6, 8]) torch.Size([8])
torch.Size([6, 8]) torch.Size([8])
torch.Size([6, 8]) torch.Size([8])
torch.Size([7, 8]) torch.Size([8])
torch.Size([7, 8]) torch.Size([8])
torch.Size([7, 8]) torch.Size([8])
torch.Size([7, 8]) torch.Size([8])
torch.Size([8, 8]) torch.Size([8])
torch.Size([8, 8]) torch.Size([8])
torch.Size([9, 8]) torch.Size([8])
torch.Size([9, 8]) torch.Size([8])
torch.Size([9, 8]) torch.Size([8])
torch.Size([9, 8]) torch.Size([8])
torch.Size([10, 8]) torch.Size([8])
torch.Size([10, 8]) torch.Size([8])
torch.Size([11, 8]) torch.Size([8])
torch.Size([11, 8]) torch.Size([8])
torch.Size([11, 8]) torch.Size([8])
torch.Size([12, 8]) torch.Size([8])
torch.Size([12, 8]) torch.Size([8])
torch.Size([13, 8]) torch.Size([8])
torch.Size([13, 8]) torch.Size([8])
torch.Size([14, 8]) torch.Size([8])
torch.Size

torch.Size([25, 8]) torch.Size([8])
torch.Size([26, 8]) torch.Size([8])
torch.Size([27, 8]) torch.Size([8])
torch.Size([27, 8]) torch.Size([8])
torch.Size([27, 8]) torch.Size([8])
torch.Size([28, 8]) torch.Size([8])
torch.Size([28, 8]) torch.Size([8])
torch.Size([29, 8]) torch.Size([8])
torch.Size([30, 8]) torch.Size([8])
torch.Size([30, 8]) torch.Size([8])
torch.Size([31, 8]) torch.Size([8])
torch.Size([32, 8]) torch.Size([8])
torch.Size([33, 8]) torch.Size([8])
torch.Size([33, 8]) torch.Size([8])
torch.Size([34, 8]) torch.Size([8])
torch.Size([35, 8]) torch.Size([8])
torch.Size([36, 8]) torch.Size([8])
torch.Size([36, 8]) torch.Size([8])
torch.Size([38, 8]) torch.Size([8])
torch.Size([39, 8]) torch.Size([8])
torch.Size([40, 8]) torch.Size([8])
torch.Size([41, 8]) torch.Size([8])
torch.Size([42, 8]) torch.Size([8])
torch.Size([43, 8]) torch.Size([8])
torch.Size([45, 8]) torch.Size([8])
torch.Size([46, 8]) torch.Size([8])
torch.Size([46, 8]) torch.Size([8])
torch.Size([47, 8]) torch.Si

torch.Size([29, 8]) torch.Size([8])
torch.Size([29, 8]) torch.Size([8])
torch.Size([30, 8]) torch.Size([8])
torch.Size([30, 8]) torch.Size([8])
torch.Size([31, 8]) torch.Size([8])
torch.Size([31, 8]) torch.Size([8])
torch.Size([32, 8]) torch.Size([8])
torch.Size([33, 8]) torch.Size([8])
torch.Size([33, 8]) torch.Size([8])
torch.Size([34, 8]) torch.Size([8])
torch.Size([35, 8]) torch.Size([8])
torch.Size([36, 8]) torch.Size([8])
torch.Size([37, 8]) torch.Size([8])
torch.Size([37, 8]) torch.Size([8])
torch.Size([38, 8]) torch.Size([8])
torch.Size([39, 8]) torch.Size([8])
torch.Size([40, 8]) torch.Size([8])
torch.Size([41, 8]) torch.Size([8])
torch.Size([42, 8]) torch.Size([8])
torch.Size([43, 8]) torch.Size([8])
torch.Size([44, 8]) torch.Size([8])
torch.Size([46, 8]) torch.Size([8])
torch.Size([47, 8]) torch.Size([8])
torch.Size([48, 8]) torch.Size([8])
torch.Size([50, 8]) torch.Size([8])
torch.Size([52, 8]) torch.Size([8])
torch.Size([54, 8]) torch.Size([8])
torch.Size([55, 8]) torch.Si

torch.Size([35, 8]) torch.Size([8])
torch.Size([36, 8]) torch.Size([8])
torch.Size([37, 8]) torch.Size([8])
torch.Size([38, 8]) torch.Size([8])
torch.Size([40, 8]) torch.Size([8])
torch.Size([41, 8]) torch.Size([8])
torch.Size([43, 8]) torch.Size([8])
torch.Size([44, 8]) torch.Size([8])
torch.Size([45, 8]) torch.Size([8])
torch.Size([46, 8]) torch.Size([8])
torch.Size([48, 8]) torch.Size([8])
torch.Size([49, 8]) torch.Size([8])
torch.Size([51, 8]) torch.Size([8])
torch.Size([52, 8]) torch.Size([8])
torch.Size([53, 8]) torch.Size([8])
torch.Size([54, 8]) torch.Size([8])
torch.Size([56, 8]) torch.Size([8])
torch.Size([57, 8]) torch.Size([8])
torch.Size([59, 8]) torch.Size([8])
torch.Size([60, 8]) torch.Size([8])
torch.Size([62, 8]) torch.Size([8])
torch.Size([64, 8]) torch.Size([8])
torch.Size([68, 8]) torch.Size([8])
torch.Size([74, 8]) torch.Size([8])
torch.Size([200, 8]) torch.Size([8])
torch.Size([3, 8]) torch.Size([8])
torch.Size([3, 8]) torch.Size([8])
torch.Size([4, 8]) torch.Size

In [248]:
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [249]:
next(rnn.parameters())

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3573,  0.4489, -1.5330,  ..., -1.3240,  0.6332,  0.0051],
        [ 0.8510, -0.4441,  0.6427,  ...,  0.3839,  0.9075, -0.9272],
        ...,
        [ 0.0428, -0.6111, -0.1993,  ..., -1.0839,  1.4500, -1.1630],
        [ 0.9792, -1.3638, -0.5155,  ...,  0.2063,  0.7979, -0.0670],
        [-1.6107,  0.5769,  0.9368,  ..., -2.0328,  0.7715, -0.2605]],
       device='cuda:0', requires_grad=True)

In [250]:
len(train_dataset)/8

1250.0

In [251]:
logit = rnn(torch.randint(0,10, (60,8)).to(device))
logit.shape

torch.Size([60, 4])

In [252]:
next(iter(x_test)).shape

torch.Size([100])

In [253]:
rnn(torch.randint(1,10, (9,8)).to(device)).shape

torch.Size([9, 4])

In [254]:
# loss = loss_fn(logit, torch.tensor(3))
# loss

In [255]:
for epoch in range(EPOCHS):
    train_loss, train_acc = 0, 0
    rnn.train()
    
    for label, text in bucket_train_dataloader:
          
        
        label, text = label.to(device), text.to(device)
        y_logits = rnn(text)
        
        loss = loss_fn(y_logits, label)
        train_loss += loss.item()
        
        y_probs = torch.softmax(y_logits, dim=1)
        y_preds = torch.argmax(y_probs, dim=1)
        
        acc = torch.eq(y_preds==label).sum().item()
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
             
        
    train_loss /= len(train_dataset)/8
    train_acc /= len(test_dataset)/8
        
    print(f'Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}') 
    
    
    with torch.inference_mode():
        test_loss,test_acc = 0,0
        rnn.eval()
        for test_label, test_text in bucket_test_dataloader:
            
            test_label, test_text = test_label.to(device), test_text.to(device)
            
            y_logits = rnn(test_text)
            
            loss = loss_fn(y_logits, test_label)
            test_loss += loss.item()
            
            acc = torch.eq(y_preds==test_label).sum().item()
            test_acc += acc
            
    train_loss /= len(train_dataset)/8
    train_acc /= len(test_dataset)/8
    
    print(f'Epoch {epoch+1}/{EPOCHS}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
    print('--'*50)
        

Epoch 1/30, Train Loss: 0.0000, Train Accuracy: 0.0000


ValueError: Expected input batch_size (63) to match target batch_size (8).