In [1]:
import pandas as pd
import re
import math
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence

In [2]:
if torch.cuda.is_available():
    print("CUDA is available!")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. PyTorch will use the CPU.")

CUDA is available!
CUDA device count: 1
Current CUDA device name: NVIDIA GeForce RTX 4080 SUPER


In [3]:
import sentencepiece as spm

from tqdm import tqdm
import random

file = "amazon_review.csv"

data = pd.read_csv(file).dropna(ignore_index=True)

data["reviewText"].to_csv("amazon_reviews.txt", index=False, header=False)
data["overall"] = data["overall"] -1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = 16000
start_token = 16000
end_token = 16001

'''
spm.SentencePieceTrainer.train(
    input='amazon_reviews.txt',
    model_prefix='amazon_reviews',
    vocab_size=vocab_size,
    model_type='unigram',
    character_coverage=1.0
)
'''

tok = spm.SentencePieceProcessor(model_file='amazon_reviews.model')

print(tok.encode("This book is amazing!", out_type=int))

ls  = [len(tok.encode(i, out_type=int)) for i in data["reviewText"]]

data["lengths"] = ls

data_trunc = data[data["lengths"]<=80]

v = data_trunc["overall"].value_counts()

balanced_data = (
    data_trunc.groupby("overall")
      .sample(n=min(v), random_state=42)
      .reset_index(drop=True)
)

filter_ = 1


df_shuffled = balanced_data.sample(frac=filter_, random_state=42).reset_index(drop=True)

# Split the shuffled DataFrame
train_size = 0.8
train_df = df_shuffled.sample(frac=train_size, random_state=42).reset_index(drop=True)
test_df = df_shuffled.drop(train_df.index).reset_index(drop=True)

len_train = len(train_df)
len_test = len(test_df)

print("Train Size: ", len(train_df))
print("Test Size: ", len(test_df))

num_rows = 2000

n = [0, 1, 2, 3, 5, 7, 10]
def getEncoding(df, i):
    row_input = [16000] + tok.encode(df["reviewText"].iloc[i], out_type=int) + [16001]
    row_output = tok.encode(str(df["overall"].iloc[i]), out_type = int)
    return row_input, row_output

annos_x = {}
for i in n:
    annos_x[i]=[]
annos_y = []
for i in tqdm(range(num_rows)):
    k = random.randint(0,len(test_df) - 1)
    #input_fin = test_df["reviewText"][k]
    #output_fin = test_df["overall"][k]
    input_fin, output_fin = getEncoding(test_df, i)
    icl_df = test_df.drop(test_df.index[k])
    for j in n:
        s= icl_df.sample(n=j)
        in_ = []
        for shot in range(j):
            inp, outp = getEncoding(s, shot)
            in_.extend(inp)
            in_.extend(outp)
            
            print("in: ", inp)
            print(str(s["overall"].iloc[shot]))
            print("out: ", outp)
        in_.extend(input_fin)
        annos_x[j].append(in_)
    annos_y.extend(output_fin)





[109, 2071, 13, 828, 34]
Train Size:  15396
Test Size:  3849


  2%|▊                                       | 42/2000 [00:00<00:04, 419.82it/s]

in:  [16000, 266, 5, 5033, 75, 6, 23, 5, 397, 369, 57, 11, 25, 3, 1498, 364, 24, 9, 110, 2690, 25, 4, 264, 9, 86, 82, 11, 23, 790, 4, 69, 46, 1794, 11, 8, 103, 73, 7, 18, 141, 256, 5393, 65, 5, 1359, 727, 16001]
out:  [259, 3, 2517]
in:  [16000, 336, 38, 6, 22, 113, 603, 8, 120, 37, 2512, 3, 139, 49, 7, 8, 120, 1091, 25, 55, 136, 930, 494, 148, 3, 1145, 8, 229, 17, 24, 16, 225, 240, 7, 22, 19, 241, 121, 60, 42, 10, 221, 170, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 1803, 425, 56, 98, 50, 4, 27, 149, 19, 5, 161, 30, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 255, 148, 6, 55, 53, 117, 124, 16001]
out:  [468, 3, 2517]
in:  [16000, 4397, 46, 1637, 7291, 157, 126, 3, 1598, 10, 29, 27, 200, 1430, 71, 5, 73, 6, 37, 826, 11762, 14, 71, 873, 3, 3655, 665, 1766, 7286, 5, 98, 75, 33, 874, 5, 161, 3, 1511, 15, 65, 2445, 97, 24, 19, 2375, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 38, 26, 274, 40, 6, 40, 49, 16001]
out:  [259, 3, 2517]
in:  [16000, 4, 318, 8, 381, 3, 2424, 10, 29, 145, 353, 6

  4%|█▋                                      | 86/2000 [00:00<00:04, 418.56it/s]

in:  [16000, 616, 10, 29, 2427, 34, 16001]
out:  [186, 3, 2517]
in:  [16000, 483, 61, 16001]
out:  [303, 3, 2517]
in:  [16000, 70, 13, 51, 106, 3, 478, 33, 53, 121, 271, 684, 5, 1674, 11, 65, 37, 4, 64, 9, 737, 14, 143, 16001]
out:  [259, 3, 2517]
in:  [16000, 4, 398, 5, 4907, 30, 3, 1057, 437, 186, 276, 53, 117, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 39, 30, 330, 7152, 7, 22, 1410, 10, 29, 145, 633, 3, 710, 438, 67, 380, 282, 17, 153, 276, 7152, 7, 22, 26, 38, 3, 39, 280, 1868, 4, 10, 76, 743, 3, 45, 284, 7152, 23, 66, 43, 3157, 67, 7152, 10, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4, 660, 10, 29, 179, 21, 406, 6, 23, 22, 19, 87, 11, 5, 12, 2880, 2975, 5316, 8360, 84, 4, 58, 46, 9856, 681, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 269, 7, 38, 168, 16001]
out:  [303, 3, 2517]
in:  [16000, 45, 86, 82, 23, 19, 2070, 16001]
out:  [468, 3, 2517]
in:  [16000, 4, 10, 108, 299, 206, 35, 59, 50, 11, 1956, 101, 3, 4, 43, 21, 176, 56, 749, 10, 15, 11, 59, 7, 22, 10, 131, 114, 7377, 5, 6

  6%|██▍                                    | 128/2000 [00:00<00:04, 418.71it/s]

in:  [16000, 39, 148, 13, 869, 28, 65, 23, 5, 1019, 13, 672, 53, 103, 7, 212, 11, 92, 4, 219, 3, 706, 14, 10, 15, 556, 3, 4, 183, 14, 11, 8, 59, 41, 25, 14, 13, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 188, 74, 117, 124, 237, 4, 69, 150, 642, 10, 29, 30, 57, 3, 4, 10, 76, 8, 30, 438, 419, 23, 4, 41, 380, 3, 4, 81, 110, 380, 7, 14, 33, 117, 28, 65, 16001]
out:  [259, 3, 2517]
in:  [16000, 139, 49, 6, 23, 2339, 379, 180, 3, 266, 11, 103, 3116, 351, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 269, 26, 6, 411, 49, 50, 11, 18, 182, 10, 15, 194, 464, 3, 3022, 144, 90, 105, 3, 139, 636, 32, 210, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 45, 86, 7, 169, 38, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 336, 87, 239, 73, 34, 16001]
out:  [303, 3, 2517]
in:  [16000, 109, 13, 18, 335, 54, 17, 79, 15257, 6, 22, 157, 3146, 65, 51, 11, 177, 198, 101, 351, 16, 6392, 6, 11, 6202, 6190, 6, 71, 5, 1544, 6, 4, 2397, 1355, 79, 6, 22, 19, 40, 363, 7, 86, 51, 28, 18, 141, 3, 16001]
out:  [259, 3, 2517]
i

 13%|█████                                  | 258/2000 [00:00<00:04, 425.26it/s]

in:  [16000, 4, 88, 1730, 9, 5, 30, 803, 23, 22, 63, 8, 74, 78, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 47, 16001]
out:  [303, 3, 2517]
in:  [16000, 2529, 143, 346, 30, 398, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 1356, 442, 6, 68, 5, 93, 8, 74, 354, 23, 448, 21, 9, 1322, 1143, 171, 123, 35, 98, 211, 27, 16, 5, 518, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 181, 196, 23, 36, 2212, 11, 253, 339, 3, 2917, 141, 58, 1096, 16, 21, 7, 14, 10, 221, 1097, 9233, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 45, 80, 5, 464, 918, 3, 4, 169, 43, 22, 86, 43, 455, 656, 85, 908, 19, 128, 7, 22, 10, 131, 8, 74, 769, 56, 92, 4, 763, 3, 45, 10, 131, 172, 8, 120, 702, 6, 672, 56, 8, 74, 1016, 43, 4, 69, 448, 94, 8, 194, 1158, 17, 44, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 70, 10, 15, 8, 47, 158, 6, 23, 14, 496, 40, 78, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 45, 19, 167, 53, 78, 6, 23, 5, 346, 93, 35, 51, 6, 345, 4, 88, 250, 36, 1049, 506, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 717,

 17%|██████▋                                | 345/2000 [00:00<00:03, 425.07it/s]

in:  [16000, 70, 13, 675, 6922, 560, 348, 954, 6, 23, 5, 143, 17, 5, 245, 10349, 2294, 9, 5, 143, 17, 18, 919, 1129, 65, 8, 3400, 3, 70, 193, 82, 283, 4, 401, 14, 69, 27, 648, 77, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4, 137, 5, 156, 23, 22, 19, 140, 53, 78, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 1134, 233, 73, 6, 146, 35, 105, 16001]
out:  [303, 3, 2517]
in:  [16000, 96, 153, 19, 102, 82, 3, 272, 38, 97, 1913, 257, 371, 26, 23, 36, 53, 371, 97, 25, 15, 92, 5574, 163, 11, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 95, 969, 133, 5, 104, 5, 115, 1642, 347, 13, 5, 1386, 28, 5, 245, 275, 36, 8294, 51, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 1492, 474, 2995, 42, 72, 290, 320, 151, 2219, 7631, 4402, 16, 62, 126, 17, 291, 164, 4, 58, 256, 91, 584, 165, 16001]
out:  [3210, 3, 2517]
in:  [16000, 540, 6, 223, 17, 14, 35, 8, 1054, 547, 90, 5, 905, 13, 1711, 6, 193, 455, 7, 13, 455, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 4952, 403, 12372, 34, 4, 64, 9, 996, 32, 198, 30, 980, 

 22%|████████▍                              | 432/2000 [00:01<00:03, 427.26it/s]

in:  [16000, 6140, 9, 4206, 681, 152, 210, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 4, 162, 138, 698, 8, 952, 454, 10, 15, 283, 492, 6, 24, 63, 53, 156, 3, 4, 64, 9, 248, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 109, 158, 33, 53, 2110, 11, 18, 1033, 768, 5216, 271, 400, 1648, 77, 547, 94, 18, 744, 16, 8, 759, 71, 5, 2812, 17, 713, 812, 3, 39, 158, 316, 617, 10, 29, 53, 117, 6, 100, 71, 5, 161, 73, 203, 33, 8, 74, 8262, 71, 5, 297, 3, 716, 8, 1062, 26, 3, 70, 10, 15, 8, 241, 158, 6, 23, 36, 5, 135, 505, 11, 65, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 139, 220, 50, 6, 344, 19, 8, 120, 78, 23, 117, 191, 11, 8, 3824, 67, 1343, 11559, 6326, 348, 3, 1167, 286, 51, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 4725, 452, 3, 236, 24, 63, 400, 30, 1015, 6, 262, 5, 349, 13, 532, 174, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 82, 104, 11, 112, 89, 23, 5, 10473, 146, 53, 4479, 7, 4, 107, 36, 8, 1769, 800, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 83, 11, 498, 34, 16001]
out:  [303, 3, 

 26%|██████████                             | 518/2000 [00:01<00:03, 426.84it/s]

in:  [16000, 70, 10, 15, 8, 82, 104, 6, 23, 4, 27, 1102, 2966, 5, 73, 913, 3, 39, 143, 585, 17, 5, 335, 694, 13, 37, 414, 25, 14, 193, 43, 5, 1667, 694, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 147, 23, 1003, 6, 646, 4, 170, 8017, 15, 9, 801, 28, 5, 2634, 9, 189, 94, 2323, 77, 3, 1356, 14321, 7974, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 4, 64, 6286, 25, 24, 7120, 65, 16, 5, 378, 340, 5, 1093, 3, 725, 56, 25, 4, 102, 43, 14, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 96, 26, 65, 102, 51, 3, 147, 25, 203, 19, 344, 3, 39, 246, 13, 82, 53, 3, 3, 4, 69, 4208, 91, 165, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 4, 43, 5, 167, 5, 75, 193, 3, 181, 11, 449, 447, 141, 6, 55, 639, 379, 180, 3, 4, 209, 16, 110, 16, 5860, 29, 7, 22, 19, 217, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 4, 170, 9, 248, 31, 130, 11, 8, 199, 30, 4, 190, 36, 1505, 24, 19, 1512, 16, 2379, 276, 16001]
out:  [468, 3, 2517]
in:  [16000, 1567, 15, 43, 14, 13, 8, 533, 545, 923, 56, 4153, 16001]
out:  [468, 3, 2517]


 30%|███████████▊                           | 606/2000 [00:01<00:03, 430.67it/s]

in:  [16000, 79, 10010, 35, 278, 6, 695, 9, 65, 35, 10549, 3, 2436, 89, 6, 666, 10549, 130, 3, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4750, 159, 199, 56, 4, 105, 6, 23, 22, 19, 37, 220, 4, 58, 189, 21, 1116, 3, 95, 154, 54, 17, 2203, 1830, 929, 821, 733, 534, 101, 17, 3276, 423, 17, 797, 164, 3, 3, 2521, 15, 1925, 607, 6, 79, 14725, 320, 4880, 7656, 15, 6, 7, 915, 1751, 28, 5, 2541, 6, 2280, 6, 7, 482, 164, 79, 8011, 259, 74, 1638, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 96, 26, 484, 156, 6, 23, 5, 453, 2033, 15, 321, 18, 254, 7, 5, 1094, 3254, 77, 3, 224, 6, 7303, 6, 5, 26, 13, 8666, 3, 377, 6, 22, 19, 567, 35, 51, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4, 68, 5, 508, 537, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 608, 3253, 156, 28, 65, 6, 2852, 5, 1118, 26, 13, 5, 558, 3, 431, 138, 439, 26, 450, 73, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 2104, 5, 346, 246, 23, 613, 9, 189, 14, 6, 617, 10, 29, 8, 528, 4172, 3, 1481, 152, 8, 126, 8, 123, 21, 4, 944, 5, 703, 99, 

 35%|█████████████▌                         | 693/2000 [00:01<00:03, 427.41it/s]IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

 70%|██████████████████████████▌           | 1400/2000 [00:03<00:01, 440.12it/s]

in:  [16000, 139, 291, 311, 274, 1385, 3, 374, 72, 169, 239, 1797, 67, 3402, 28, 5, 3080, 3, 266, 92, 4, 105, 16, 8, 1182, 75, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 564, 18, 50, 19, 276, 418, 125, 7, 22, 26, 51, 3, 96, 19, 167, 53, 78, 43, 8, 381, 4, 1408, 1238, 77, 538, 5, 30, 5009, 15, 37, 121, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 95, 273, 731, 54, 16, 360, 101, 6, 37, 4, 43, 92, 4, 43, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 45, 166, 78, 16001]
out:  [468, 3, 2517]
in:  [16000, 266, 657, 349, 6, 90, 88, 14, 330, 800, 10, 15, 156, 23, 388, 35, 8, 2599, 156, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 12306, 1164, 6, 4, 190, 209, 16, 291, 506, 23, 1828, 8985, 364, 441, 16001]
out:  [186, 3, 2517]
in:  [16000, 272, 35, 105, 3, 7838, 16, 2076, 33, 195, 32, 21, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 1582, 4004, 973, 2658, 764, 11002, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4, 88, 5, 466, 348, 4656, 149, 4, 361, 41, 32, 98, 50, 3, 39, 125, 26, 53, 184, 37, 4, 64,

 74%|████████████████████████████▎         | 1489/2000 [00:03<00:01, 435.24it/s]

in:  [16000, 31, 13, 8, 40, 2549, 106, 158, 3, 39, 2149, 3171, 113, 36, 2138, 702, 1137, 5, 730, 3, 39, 271, 13, 455, 7, 567, 3, 4, 69, 36, 183, 31, 130, 9, 612, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 267, 332, 24, 11, 8, 886, 9, 41, 9, 59, 3, 45, 86, 38, 3, 39, 390, 320, 5, 868, 7, 1631, 33, 8, 74, 312, 680, 14, 40, 567, 9, 41, 3, 45, 63, 452, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 3426, 178, 3147, 4534, 106, 134, 1273, 79, 6085, 1894, 3, 860, 10, 29, 91, 14, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 109, 13, 720, 8, 322, 2095, 784, 3, 4, 69, 145, 284, 14, 13, 1190, 56, 8, 432, 784, 3, 860, 10, 29, 80, 65, 346, 6, 4, 107, 140, 195, 32, 18, 7424, 210, 3, 4, 69, 284, 18, 115, 601, 13, 25, 22, 792, 14, 35, 322, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 380, 419, 2701, 146, 43, 380, 419, 747, 3, 1370, 53, 327, 9, 46, 49, 3, 1266, 7, 58, 936, 9, 91, 9476, 5854, 2630, 16, 880, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 83, 153, 407, 2291, 115, 9, 878, 3366, 3, 16001]
out: 

 79%|█████████████████████████████▉        | 1577/2000 [00:03<00:00, 433.39it/s]

in:  [16000, 1440, 44, 88, 5, 30, 6, 930, 11, 5, 1124, 6, 353, 3643, 63, 62, 7, 8, 313, 276, 66, 3, 2801, 62, 30, 491, 34, 2951, 4395, 3592, 6, 28, 8, 254, 19, 2011, 6, 5, 1124, 13, 583, 3, 815, 5, 335, 433, 62, 30, 199, 3, 742, 372, 16001]
out:  [259, 3, 2517]
in:  [16000, 272, 13, 217, 7, 5, 297, 13, 484, 491, 7041, 7, 567, 56, 938, 297, 595, 6, 23, 4, 33, 838, 11, 338, 66, 1118, 7, 1826, 71, 5, 297, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 4461, 8, 59, 525, 3, 39, 542, 2029, 19, 212, 3, 45, 19, 35, 792, 35, 8, 371, 26, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 137, 8, 30, 331, 7, 140, 53, 78, 16001]
out:  [3210, 3, 2517]
in:  [16000, 39, 143, 402, 99, 671, 16, 5, 2149, 17, 5, 462, 3, 139, 567, 283, 36, 2055, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 477, 683, 3542, 4, 398, 153, 25, 63, 4711, 159, 5962, 711, 3, 39, 4033, 1181, 5, 657, 30, 6, 23, 5, 1176, 1151, 33, 906, 346, 3, 543, 125, 33, 1766, 9, 466, 3, 188, 1552, 689, 33, 1766, 9, 1069, 3, 109, 13, 2444, 9, 933, 32, 6, 2

 83%|███████████████████████████████▋      | 1665/2000 [00:03<00:00, 432.84it/s]

in:  [16000, 39, 98, 359, 700, 10, 29, 5330, 6, 31, 104, 99, 8, 3124, 4175, 3, 237, 4, 111, 10, 29, 27, 110, 601, 32, 8855, 14, 913, 6, 7, 14, 10, 15, 36, 102, 25, 16, 13960, 17, 8, 4175, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 628, 11, 454, 13, 87, 6, 11, 472, 13, 9, 312, 124, 946, 146, 7, 4, 27, 387, 1381, 3, 1114, 5, 3675, 378, 6, 5, 10608, 15, 738, 5, 3422, 15, 26, 28, 192, 98, 930, 402, 16001]
out:  [468, 3, 2517]
in:  [16000, 4, 33, 37, 1421, 9, 41, 18, 836, 6, 37, 57, 665, 3133, 4, 579, 21, 16, 18, 141, 3, 3, 45, 516, 102, 47, 6, 7, 3, 4427, 560, 1851, 3, 4, 1362, 3934, 4, 106, 5, 3, 3208, 13, 28, 9, 210, 2924, 3, 4, 10, 771, 188, 12067, 1340, 293, 16001]
out:  [259, 3, 2517]
in:  [16000, 95, 1432, 3, 4, 88, 500, 1472, 3, 2282, 4580, 6, 58, 996, 3, 147, 5, 1049, 506, 93, 293, 139, 82, 310, 16001]
out:  [303, 3, 2517]
in:  [16000, 496, 199, 56, 98, 572, 16001]
out:  [468, 3, 2517]
in:  [16000, 535, 4, 2419, 675, 156, 6, 1781, 8, 120, 23, 262, 218, 986, 4, 27, 81, 17, 1471, 10

 88%|█████████████████████████████████▎    | 1753/2000 [00:04<00:00, 430.87it/s]

in:  [16000, 606, 156, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 139, 212, 7, 650, 3, 4, 10, 76, 5698, 7, 3370, 1055, 124, 10758, 14, 69, 26, 176, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 272, 199, 125, 30, 16001]
out:  [186, 3, 2517]
in:  [16000, 266, 47, 11, 1436, 32, 259, 1044, 212, 141, 3, 16001]
out:  [3210, 3, 2517]
in:  [16000, 259, 808, 4154, 808, 946, 26, 18, 3588, 213, 185, 182, 3, 70, 13, 62, 793, 7, 40, 668, 9, 26, 202, 870, 744, 345, 5, 327, 678, 1689, 3, 39, 890, 13, 1031, 7, 146, 18, 360, 774, 185, 265, 3, 39, 2431, 3495, 1250, 527, 17, 892, 57, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 181, 61, 6, 300, 6, 49, 6, 38, 26, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 542, 246, 190, 36, 921, 1217, 16001]
out:  [186, 3, 2517]
in:  [16000, 147, 14, 964, 83, 11, 482, 41, 3, 3, 3927, 7, 1883, 3, 3, 884, 322, 28, 5, 549, 3, 3, 1113, 557, 3, 3, 10906, 7, 3310, 3, 3, 5883, 119, 151, 66, 11, 5338, 3353, 2464, 16001]
out:  [303, 3, 2517]
in:  [16000, 181, 130, 10601, 34, 16001

 92%|██████████████████████████████████▉   | 1841/2000 [00:04<00:00, 432.23it/s]

in:  [16000, 9200, 159, 184, 1065, 5, 297, 17, 18, 1101, 10, 15, 254, 3, 175, 99, 8, 432, 711, 254, 3, 175, 286, 8, 432, 447, 75, 3, 45, 436, 43, 82, 211, 3, 95, 1101, 33, 429, 24, 216, 10, 29, 26, 3, 181, 697, 230, 6, 228, 248, 7, 116, 750, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 4, 81, 21, 35, 8, 589, 323, 3, 16001]
out:  [186, 3, 2517]
in:  [16000, 206, 19, 61, 208, 686, 807, 65, 8815, 42, 6, 5, 50, 19, 55, 8, 74, 205, 56, 105, 3, 188, 74, 66, 650, 173, 5, 355, 56, 4, 43, 23, 3056, 4, 58, 36, 5307, 21, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 2761, 119, 9, 91, 200, 54, 16001]
out:  [303, 3, 2517]
in:  [16000, 4, 107, 532, 429, 16, 31, 130, 3, 4, 81, 24, 44, 11, 18, 142, 35, 8, 323, 79, 3, 175, 299, 21, 114, 259, 348, 7, 28, 5, 154, 62, 5, 44, 1023, 894, 174, 906, 3, 39, 75, 13, 4216, 178, 3, 175, 172, 1600, 5, 2035, 520, 17, 202, 75, 9, 998, 14, 7, 894, 8, 793, 17, 3, 266, 8, 195, 1175, 16001]
out:  [3210, 3, 2517]
in:  [16000, 806, 9, 41, 6, 291, 311, 6, 55, 43, 90, 42, 337,

 96%|████████████████████████████████████▋ | 1928/2000 [00:04<00:00, 424.63it/s]

in:  [16000, 1076, 26, 35, 105, 7, 138, 319, 40, 1860, 3, 95, 182, 13, 40, 195, 32, 202, 50, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 96, 986, 26, 2237, 94, 192, 531, 4, 27, 287, 179, 3, 4, 207, 18, 640, 30, 6, 23, 5, 980, 542, 671, 3541, 980, 28, 18, 542, 7, 5, 986, 19, 675, 650, 71, 5, 615, 846, 22, 26, 217, 28, 297, 3, 45, 63, 1131, 461, 23, 4, 170, 9, 160, 77, 226, 9, 248, 21, 3, 16001]
out:  [468, 3, 2517]
in:  [16000, 45, 86, 82, 23, 19, 2070, 16001]
out:  [468, 3, 2517]
in:  [16000, 70, 10, 15, 136, 1509, 206, 208, 59, 148, 71, 8, 47, 89, 3, 4, 10, 221, 91, 66, 3, 4, 69, 342, 14, 273, 440, 97, 5, 168, 63, 8, 120, 1091, 43, 151, 910, 206, 215, 4, 27, 3, 16001]
out:  [259, 3, 2517]
in:  [16000, 9186, 64, 8, 610, 11223, 3, 4, 4017, 21, 16, 564, 5464, 11, 2167, 8463, 140, 190, 36, 318, 5, 11223, 547, 3, 1481, 4, 10588, 144, 564, 135, 28, 21, 274, 4017, 11, 186, 339, 3, 2090, 648, 34, 16001]
out:  [186, 3, 2517]
in:  [16000, 266, 289, 9, 30, 16001]
out:  [3210, 3, 2517]
in:  [1600

100%|██████████████████████████████████████| 2000/2000 [00:04<00:00, 429.51it/s]

in:  [16000, 1110, 117, 4, 72, 10, 29, 41, 21, 7, 22, 319, 1071, 16001]
out:  [3210, 3, 2517]
in:  [16000, 4, 27, 295, 282, 17, 5, 44, 16, 225, 240, 3, 45, 19, 18, 365, 150, 9, 367, 44, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 4, 68, 5, 50, 23, 2348, 25, 4, 201, 138, 8, 199, 30, 11, 5, 371, 26, 16001]
out:  [259, 3, 2517]
in:  [16000, 83, 44, 16001]
out:  [303, 3, 2517]
in:  [16000, 751, 13, 87, 3, 236, 42, 138, 231, 5, 50, 1138, 480, 85, 11635, 22, 58, 162, 26, 3, 4, 223, 112, 99, 114, 360, 240, 178, 3, 358, 555, 1933, 202, 98, 572, 9, 5821, 3, 39, 531, 1450, 8153, 7, 69, 4601, 28, 2029, 3, 96, 19, 8, 439, 26, 6, 36, 1118, 3, 16001]
out:  [303, 3, 2517]
in:  [16000, 139, 819, 9, 5, 62, 18, 559, 605, 65, 114, 360, 101, 446, 6, 23, 491, 634, 3, 4, 763, 5, 2868, 245, 177, 5, 98, 625, 6, 249, 4, 41, 14, 28, 18, 135, 549, 6, 7, 24, 19, 11087, 11, 515, 571, 3, 39, 12, 10897, 84, 30, 245, 275, 169, 55, 8, 120, 184, 6, 23, 5, 156, 30, 1430, 1226, 9, 46, 13001, 6, 7, 69, 376, 1863, 5608, 17




In [None]:
class GPTBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        self.ln1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )

        self.ln2 = nn.LayerNorm(emb_dim)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, emb_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # self-attention only
        h = self.ln1(x)
        h, _ = self.attn(
            h, h, h,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask
        )
        x = x + self.dropout(h)

        # feedforward
        h2 = self.ln2(x)
        h2 = self.ff(h2)
        x = x + self.dropout(h2)

        return x

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=1024, num_heads=8,
                 num_layers=12, ff_dim=512, max_len=800, dropout=0.5):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos_embed = nn.Embedding(max_len, emb_dim)

        self.blocks = nn.ModuleList([
            GPTBlock(emb_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(emb_dim)
        self.fc = nn.Linear(emb_dim, 5)

    def forward(self, x, lengths):
        B, L = x.shape

        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.embedding(x) + self.pos_embed(pos)

        causal_mask = torch.triu(
            torch.ones(L, L, device=x.device), diagonal=1
        ).bool()

        pad_mask = (x == 0)

        for block in self.blocks:
            h = block(h, attn_mask=causal_mask, key_padding_mask=pad_mask)

        h = self.ln_f(h)

        idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, h.size(-1))
        last_hidden = h.gather(1, idx).squeeze(1)

        return self.fc(last_hidden)


def collate_fn(batch):
    sequences, labels = zip(*batch)

    lengths = torch.tensor([len(seq) for seq in sequences])

    padded = pad_sequence(
        sequences,
        batch_first=True, 
        padding_value=0
    )

    labels = torch.stack([l for l in labels])
    return padded, lengths, labels


class ReviewDataset(Dataset):
    def __init__(self, train_dframe, spm_model):
        self.df = train_dframe
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(spm_model)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.loc[idx, "reviewText"]
        label = torch.tensor(self.df.loc[idx, "overall"], dtype = torch.long)
        ids = self.sp.encode(text, out_type=int)
        ids_encoded = [start_token] + ids + [end_token]
        ids_encoded = torch.tensor(ids_encoded, dtype=torch.long)
        return ids_encoded, label


class ICLDataset(Dataset):
    def __init__(self, shot_dict, y, shots):
        self.ICL_shot = shot_dict[shots]
        self.y = y

    def __len__(self):
        return len(self.ICL_shot)

    def __getitem__(self, idx):
        text = self.ICL_shot[idx]
        label = torch.tensor(self.y[idx], dtype=torch.long)
        
        #label = self.df.loc[idx, "overall"].item()
        #one_hot_encoded_label = F.one_hot(label, num_classes=5).squeeze(0)
        #ids = self.sp.encode(text, out_type=int)
        ids = torch.tensor(text, dtype=torch.long)

        return ids, label

shot_ex = 10
icl_dataset = ICLDataset(annos_x, annos_y, shot_ex)
icl_loader = DataLoader(
    icl_dataset,
    batch_size=50,
    shuffle=True,
    collate_fn=collate_fn
)



train_dataset = ReviewDataset(train_df, "amazon_reviews.model")
train_loader = DataLoader(
    train_dataset,
    batch_size=50,
    shuffle=True,
    collate_fn=collate_fn
)

test_dataset = ReviewDataset(test_df, "amazon_reviews.model")
test_loader = DataLoader(
    test_dataset,
    batch_size=50,
    shuffle=True,
    collate_fn=collate_fn
)

max_len = 0
for padded, lengths, labels in icl_loader:
    batch_max = lengths.max().item()
    max_len = max(max_len, batch_max)
    
for padded, lengths, labels in train_loader:
    batch_max = lengths.max().item()
    max_len = max(max_len, batch_max)



for padded, lengths, labels in test_loader:
    batch_max = lengths.max().item()
    max_len = max(max_len, batch_max)


print("max_len", max_len)




In [None]:
# Loading Data

model = DecoderOnlyTransformer(vocab_size+2, max_len = max_len)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
epochs = 15

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {total_params}")


def train(model, dataloader, optimizer, criterion, epochs):
    model.train()
    track_loss = []
    total_loss = 0
    for i in range(epochs):
        model.train()
        total_loss = 0
        
        train_data_loader = tqdm(dataloader)
        
        for padded, lengths, labels in train_data_loader:
            padded = padded.to(device)
            lengths = lengths.to(device)
            labels = labels.to(device)
    
            optimizer.zero_grad()
            outputs = model(padded, lengths)


            loss = criterion(outputs, labels)
            track_loss.append(loss.item())
            avg = sum(track_loss[-10:])/10
    
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()
            train_data_loader.set_postfix(loss=avg)
            del outputs, loss
            torch.cuda.empty_cache()
        model.eval()
        total_val_loss = 0
        corr =0
        ll = 0

        with torch.no_grad():
            for padded, lengths, labels in test_loader:
                padded = padded.to(device)
                lengths = lengths.to(device)
                labels = labels.to(device)
                outputs = model(padded, lengths)


                loss = criterion(outputs, labels)

                v = torch.argmax(outputs, dim =1)
                delt = torch.sum(v==labels)
                ll+= len(v)
                corr += delt

                total_val_loss += loss.item()

        print("Epoch", i+1, "track loss: ", total_loss / len(train_data_loader), " val loss: ", total_val_loss / len(test_loader), "val acc: ", corr / len_test )
    return total_loss, track_loss

In [27]:
loss, tracked_loss = train(model, train_loader, optimizer, criterion, epochs)


100%|███████████████████████████████| 308/308 [00:32<00:00,  9.44it/s, loss=1.4]


Epoch 1 track loss:  1.4813482227263512  val loss:  1.3415442612264064 val acc:  tensor(0.4102, device='cuda:0')


100%|██████████████████████████████| 308/308 [00:32<00:00,  9.43it/s, loss=1.29]


Epoch 2 track loss:  1.2899153698961456  val loss:  1.1755781227892095 val acc:  tensor(0.4879, device='cuda:0')


100%|██████████████████████████████| 308/308 [00:32<00:00,  9.56it/s, loss=1.24]


Epoch 3 track loss:  1.195942890334439  val loss:  1.100661570375616 val acc:  tensor(0.5222, device='cuda:0')


100%|███████████████████████████████| 308/308 [00:32<00:00,  9.57it/s, loss=1.1]


Epoch 4 track loss:  1.1215855320552728  val loss:  0.9657062028909659 val acc:  tensor(0.5916, device='cuda:0')


100%|██████████████████████████████| 308/308 [00:32<00:00,  9.39it/s, loss=1.06]


Epoch 5 track loss:  1.0380210839695745  val loss:  0.9021536862695372 val acc:  tensor(0.6277, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.38it/s, loss=0.959]


Epoch 6 track loss:  0.9647899444227095  val loss:  0.8091903739161306 val acc:  tensor(0.6885, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.35it/s, loss=0.845]


Epoch 7 track loss:  0.884493437293288  val loss:  0.7802973726353089 val acc:  tensor(0.6994, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.39it/s, loss=0.784]


Epoch 8 track loss:  0.7898134664668666  val loss:  0.721585747870532 val acc:  tensor(0.7228, device='cuda:0')


100%|███████████████████████████████| 308/308 [00:32<00:00,  9.35it/s, loss=0.8]


Epoch 9 track loss:  0.709128240099201  val loss:  0.6030549205742873 val acc:  tensor(0.7864, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.40it/s, loss=0.673]


Epoch 10 track loss:  0.6416751087292448  val loss:  0.5789106128277717 val acc:  tensor(0.8137, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:33<00:00,  9.28it/s, loss=0.501]


Epoch 11 track loss:  0.572035406607312  val loss:  0.572690664947807 val acc:  tensor(0.8288, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.41it/s, loss=0.517]


Epoch 12 track loss:  0.5218664079711035  val loss:  0.5722493951196794 val acc:  tensor(0.8264, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.36it/s, loss=0.532]


Epoch 13 track loss:  0.47456367400946553  val loss:  0.5642753854974524 val acc:  tensor(0.8452, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:33<00:00,  9.31it/s, loss=0.545]


Epoch 14 track loss:  0.44322648267080256  val loss:  0.5428722953641569 val acc:  tensor(0.8478, device='cuda:0')


100%|█████████████████████████████| 308/308 [00:32<00:00,  9.35it/s, loss=0.497]


Epoch 15 track loss:  0.39783660820745803  val loss:  0.5318011021265736 val acc:  tensor(0.8605, device='cuda:0')


In [36]:
annos_x[1]

[[16000,
  4,
  72,
  10,
  29,
  41,
  24,
  44,
  3,
  45,
  19,
  167,
  53,
  767,
  3,
  4,
  605,
  21,
  547,
  3,
  45,
  19,
  82,
  44,
  23,
  22,
  111,
  10,
  29,
  26,
  3,
  16001,
  3210,
  3,
  2517,
  16000,
  751,
  352,
  121,
  199,
  745,
  9,
  98,
  572,
  17,
  5,
  161,
  30,
  3,
  3325,
  614,
  8,
  125,
  30,
  331,
  9,
  290,
  92,
  25,
  13,
  43,
  3,
  1885,
  7,
  704,
  301,
  217,
  3,
  6702,
  22,
  58,
  334,
  57,
  51,
  7,
  46,
  1262,
  172,
  3,
  16001],
 [16000,
  1507,
  7,
  40,
  49,
  134,
  43,
  351,
  28,
  1158,
  34,
  4,
  2943,
  41,
  5,
  161,
  30,
  116,
  1195,
  92,
  5,
  75,
  6,
  37,
  911,
  7,
  8,
  74,
  429,
  25,
  5,
  467,
  28,
  297,
  19,
  8,
  74,
  184,
  7,
  631,
  26,
  13,
  66,
  327,
  3,
  39,
  143,
  671,
  462,
  402,
  1075,
  7,
  146,
  38,
  134,
  116,
  1580,
  3,
  16001,
  186,
  3,
  2517,
  16000,
  45,
  19,
  375,
  49,
  7,
  1348,
  1029,
  2126,
  22,
  86,
  38,
  28,
  65,
 

In [28]:
torch.cuda.empty_cache()

In [29]:
model.eval()
total_val_loss = 0
corr =0
ll = 0

In [30]:
model.eval()

def predict(text):
    ids = torch.tensor(tok.encode(text, out_type=int)).unsqueeze(0)
    lengths = torch.tensor([len(ids)])
    ids = ids.to(device)
    lengths = lengths.to(device)
    with torch.no_grad():
        out = model(ids, lengths)
        ao = torch.argmax(out)
    return ao

#print("Prediction for <I love this book>", predict("I love this book"))

model.eval()
total_val_loss = 0
corr =0

with torch.no_grad():
    for padded, lengths, labels in test_loader:
        padded = padded.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)
        outputs = model(padded, lengths)


        loss = criterion(outputs, labels)

        v = torch.argmax(outputs, dim =1)
        delt = torch.sum(v==labels)
        corr += delt

        total_val_loss += loss.item()
        del outputs, loss
        torch.cuda.empty_cache()

print("val acc: ", corr / len_test )

val acc:  tensor(0.8605, device='cuda:0')


In [31]:

total_train_loss = 0
corr =0

with torch.no_grad():
    for padded, lengths, labels in train_loader:
        padded = padded.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)
        outputs = model(padded, lengths)


        loss = criterion(outputs, labels)

        v = torch.argmax(outputs, dim =1)
        delt = torch.sum(v==labels)
        corr += delt

        total_train_loss += loss.item()
        del outputs, loss
        torch.cuda.empty_cache()

print("train acc: ", corr / len_train)


train acc:  tensor(0.9472, device='cuda:0')


In [32]:
class ICLDataset(Dataset):
    def __init__(self, shot_dict, y, shots):
        self.ICL_shot = shot_dict[shots]
        self.y = y

    def __len__(self):
        return len(self.ICL_shot)

    def __getitem__(self, idx):
        text = self.ICL_shot[idx]
        label = torch.tensor(self.y[idx], dtype=torch.long)
        
        #label = self.df.loc[idx, "overall"].item()
        #one_hot_encoded_label = F.one_hot(label, num_classes=5).squeeze(0)
        #ids = self.sp.encode(text, out_type=int)
        ids = torch.tensor(text, dtype=torch.long)

        return ids, label



In [33]:
n

[0, 1, 2, 3, 5, 7, 10]

In [35]:
for shot_ex in n:
    icl_dataset = ICLDataset(annos_x, annos_y, shot_ex)
    icl_loader = DataLoader(
        icl_dataset,
        batch_size=50,
        shuffle=True,
        collate_fn=collate_fn
    )
    
    max_len = 0
    for padded, lengths, labels in icl_loader:
        batch_max = lengths.max().item()
        max_len = max(max_len, batch_max)
    
    print("max_len", max_len)
    
    """
    ll = 0
    corr = 0
    total_icl_loss = 0
    with torch.no_grad():
        for padded, lengths, labels in icl_loader:
            lengths = lengths.to(device)
            labels =  labels.to(device)
            padded = padded.to(device)
    
            outputs = model(padded, lengths)
            v = torch.argmax(outputs, dim = 1)
            corr += sum(v==labels)
            ll += len(v)
            loss = criterion(outputs, labels)
            total_icl_loss += loss.item()
    
    torch.cuda.empty_cache()
    avg_icl = total_icl_loss / len(icl_loader)
    icl_acc = corr / ll
    print(shot_ex, avg_icl, icl_acc)
    """

max_len 82
max_len 158
max_len 221
max_len 263
max_len 407
max_len 445
max_len 588


In [None]:
max_len = 0
for padded, lengths, labels in icl_loader:
    batch_max = lengths.max().item()
    max_len = max(max_len, batch_max)

print("max_len", max_len)


In [10]:
class ICLDataset(Dataset):
    def __init__(self, shot_dict, y, shots):
        self.ICL_shot = shot_dict[shots]
        self.y = y

    def __len__(self):
        return len(self.ICL_shot)

    def __getitem__(self, idx):
        text = self.ICL_shot[idx]
        label = torch.tensor(self.y[idx], dtype=torch.long)
        
        #label = self.df.loc[idx, "overall"].item()
        #one_hot_encoded_label = F.one_hot(label, num_classes=5).squeeze(0)
        #ids = self.sp.encode(text, out_type=int)
        ids = torch.tensor(text, dtype=torch.long)

        return ids, label



In [11]:
n

[0, 1, 2, 3, 5, 7, 10]