In [2]:
import fire
import torch
from torch import nn, optim
import numpy as np
from tqdm.notebook import tqdm
import os
import pandas as pd
import pickle as pkl
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

In [3]:
use_cuda = True

In [5]:
data_partition_path = 'dataset/lllt_split/0.pkl'

with open(data_partition_path, 'rb') as fin:
    data_partition = pkl.load(fin)

In [4]:
data_partition[0].keys()

dict_keys(['id', 'prompt', 'iter', 'X', 'Y', 'train'])

In [5]:
with open(data_partition_path, 'rb') as fin:
    test_partition = pkl.load(fin)

test_partition = [x for x in test_partition if not x['train']]

In [6]:
use_cuda = False
X_prompt_train, Y_prompt_train, X_prompt_test, Y_prompt_test = list(), list(), list(), list()

for dat in tqdm(data_partition):

    if dat['train']:
        X_, Y_ = X_prompt_train, Y_prompt_train
    else:
        X_, Y_ = X_prompt_test, Y_prompt_test
    
    X_.append(np.average(dat['prompt'], axis=0).reshape((1, -1)))
    Y_.append(dat['iter'])

X_prompt_train = np.concatenate(X_prompt_train, axis=0)
X_prompt_test = np.concatenate(X_prompt_test, axis=0)
Y_prompt_train, Y_prompt_test = np.array(Y_prompt_train), np.array(Y_prompt_test)

idx = np.random.permutation(len(X_prompt_train))
X_prompt_train = X_prompt_train[idx]
Y_prompt_train = Y_prompt_train[idx]

Y_prompt_train = np.clip(Y_prompt_train, 0, 512)
Y_prompt_test = np.clip(Y_prompt_test, 0, 512)

Y_prompt_train = np.digitize(Y_prompt_train, np.linspace(0, 512, 10))
Y_prompt_test = np.digitize(Y_prompt_test, np.linspace(0, 512, 10))

X_train_torch = torch.tensor(X_prompt_train, dtype=torch.float32)
Y_train_torch = torch.tensor(Y_prompt_train, dtype=torch.long)
X_test_torch = torch.tensor(X_prompt_test, dtype=torch.float32)
Y_test_torch = torch.tensor(Y_prompt_test, dtype=torch.long)

model = nn.Sequential(
    nn.Linear(X_train_torch.shape[1], 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)

if torch.cuda.is_available() and use_cuda:
    model.to('cuda')
    X_train_torch = X_train_torch.to('cuda')
    Y_train_torch = Y_train_torch.to('cuda')
    X_test_torch = X_test_torch.to('cuda')
    Y_test_torch = Y_test_torch.to('cuda')

batch_size = 32
num_epochs = 30
initial_lr = 0.01
final_lr = 0.

# Define loss function (CE) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=0.01)

# cosine decay of learning rate
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)

# Train the model
for epoch in range(num_epochs):  # number of epochs
    model.train()
    with tqdm(total=len(X_train_torch), desc=f"Epoch {epoch}") as pbar:
        for batch_idx in range(0, len(X_train_torch), batch_size):
            X_batch = X_train_torch[batch_idx:batch_idx+batch_size]
            Y_batch = Y_train_torch[batch_idx:batch_idx+batch_size]
            optimizer.zero_grad()
            Y_batch_pred = model(X_batch)
            loss = criterion(Y_batch_pred, Y_batch)
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Epoch {epoch} Loss {loss.item()}")
            pbar.update(batch_size)

        lr_scheduler.step()
    
    model.eval()
    with torch.no_grad():
        Y_test_pred = model(X_test_torch)
        Y_test_pred = torch.argmax(Y_test_pred, dim=1)
        err = (Y_test_pred != Y_test_torch).sum().item() / len(Y_test_torch)
        print(f'Error: {err}, Epoch: {epoch}, lr: {optimizer.param_groups[0]["lr"]}')

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.593, Epoch: 0, lr: 0.009972609476841367


Epoch 1:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.584, Epoch: 1, lr: 0.009890738003669028


Epoch 2:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5675, Epoch: 2, lr: 0.009755282581475769


Epoch 3:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.568, Epoch: 3, lr: 0.009567727288213004


Epoch 4:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.548, Epoch: 4, lr: 0.009330127018922194


Epoch 5:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.539, Epoch: 5, lr: 0.009045084971874737


Epoch 6:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.535, Epoch: 6, lr: 0.00871572412738697


Epoch 7:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5375, Epoch: 7, lr: 0.008345653031794291


Epoch 8:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5335, Epoch: 8, lr: 0.007938926261462365


Epoch 9:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5265, Epoch: 9, lr: 0.007499999999999999


Epoch 10:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5235, Epoch: 10, lr: 0.007033683215379001


Epoch 11:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5225, Epoch: 11, lr: 0.0065450849718747366


Epoch 12:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.53, Epoch: 12, lr: 0.006039558454088796


Epoch 13:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5295, Epoch: 13, lr: 0.0055226423163382676


Epoch 14:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5265, Epoch: 14, lr: 0.005000000000000001


Epoch 15:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.525, Epoch: 15, lr: 0.0044773576836617335


Epoch 16:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5235, Epoch: 16, lr: 0.003960441545911203


Epoch 17:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5205, Epoch: 17, lr: 0.003454915028125263


Epoch 18:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5235, Epoch: 18, lr: 0.0029663167846209998


Epoch 19:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5245, Epoch: 19, lr: 0.002500000000000001


Epoch 20:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.515, Epoch: 20, lr: 0.0020610737385376348


Epoch 21:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.515, Epoch: 21, lr: 0.0016543469682057104


Epoch 22:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.511, Epoch: 22, lr: 0.0012842758726130299


Epoch 23:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5225, Epoch: 23, lr: 0.0009549150281252634


Epoch 24:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5155, Epoch: 24, lr: 0.0006698729810778065


Epoch 25:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5095, Epoch: 25, lr: 0.00043227271178699516


Epoch 26:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5135, Epoch: 26, lr: 0.00024471741852423234


Epoch 27:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.5135, Epoch: 27, lr: 0.00010926199633097157


Epoch 28:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.513, Epoch: 28, lr: 2.7390523158632995e-05


Epoch 29:   0%|          | 0/8000 [00:00<?, ?it/s]

Error: 0.512, Epoch: 29, lr: 0.0


In [7]:
Y_prompt_test_pred = model(X_test_torch).cpu().detach().numpy()
for idx in range(Y_prompt_test_pred.shape[0]):
    test_partition[idx]['prompt_label_uniform_bin'] = Y_prompt_test[idx]
    test_partition[idx]['prompt_pred_uniform_bin'] = Y_prompt_test_pred[idx, :]

In [8]:
X_train, Y_train, X_test, Y_test = list(), list(), list(), list()

for dat in tqdm(data_partition):

    if dat['train']:
        X_, Y_ = X_train, Y_train
    else:
        X_, Y_ = X_test, Y_test

    X_.append(dat['X'])
    Y_.append(dat['Y'])

X_train = np.concatenate(X_train, axis=0)
X_test = np.concatenate(X_test, axis=0)
Y_train, Y_test = np.concatenate(Y_train), np.concatenate(Y_test)

idx = np.random.permutation(len(X_train))
X_train = X_train[idx]
Y_train = Y_train[idx]

Y_train = np.clip(Y_train, 0, 512)
Y_test = np.clip(Y_test, 0, 512)

Y_train = np.digitize(Y_train, np.linspace(0, 512, 10))
Y_test = np.digitize(Y_test, np.linspace(0, 512, 10))

X_train_torch = torch.tensor(X_train, dtype=torch.float32)
Y_train_torch = torch.tensor(Y_train, dtype=torch.long)
X_test_torch = torch.tensor(X_test, dtype=torch.float32)
Y_test_torch = torch.tensor(Y_test, dtype=torch.long)

model = nn.Sequential(
    nn.Linear(X_train_torch.shape[1], 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)

if torch.cuda.is_available() and use_cuda:
    model.to('cuda')
    X_train_torch = X_train_torch.to('cuda')
    Y_train_torch = Y_train_torch.to('cuda')
    X_test_torch = X_test_torch.to('cuda')
    Y_test_torch = Y_test_torch.to('cuda')

batch_size = 10_000
num_epochs = 20
initial_lr = 0.01
final_lr = 0.

# Define loss function (CE) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=0.01)

# cosine decay of learning rate
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)

# Train the model
for epoch in range(num_epochs):  # number of epochs
    model.train()
    with tqdm(total=len(X_train_torch), desc=f"Epoch {epoch}") as pbar:
        for batch_idx in range(0, len(X_train_torch), batch_size):
            X_batch = X_train_torch[batch_idx:batch_idx+batch_size]
            Y_batch = Y_train_torch[batch_idx:batch_idx+batch_size]
            optimizer.zero_grad()
            Y_batch_pred = model(X_batch)
            loss = criterion(Y_batch_pred, Y_batch)
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Epoch {epoch} Loss {loss.item()}")
            pbar.update(batch_size)

        lr_scheduler.step()
    
    model.eval()
    with torch.no_grad():
        Y_test_pred = model(X_test_torch)
        Y_test_pred = torch.argmax(Y_test_pred, dim=1)
        err = (Y_test_pred != Y_test_torch).sum().item() / len(Y_test_torch)
        print(f'Error: {err}, Epoch: {epoch}, lr: {optimizer.param_groups[0]["lr"]}')

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6350397824577121, Epoch: 0, lr: 0.009938441702975689


Epoch 1:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6250019795081319, Epoch: 1, lr: 0.009755282581475769


Epoch 2:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6204434550674277, Epoch: 2, lr: 0.00945503262094184


Epoch 3:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.617875750233582, Epoch: 3, lr: 0.009045084971874739


Epoch 4:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6195996190295207, Epoch: 4, lr: 0.008535533905932738


Epoch 5:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6207511271884877, Epoch: 5, lr: 0.007938926261462366


Epoch 6:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6208167337437136, Epoch: 6, lr: 0.007269952498697735


Epoch 7:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6209977173443372, Epoch: 7, lr: 0.006545084971874738


Epoch 8:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6226039467998706, Epoch: 8, lr: 0.005782172325201155


Epoch 9:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6232554877621151, Epoch: 9, lr: 0.005


Epoch 10:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6239432254444844, Epoch: 10, lr: 0.004217827674798847


Epoch 11:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6251897499937787, Epoch: 11, lr: 0.0034549150281252636


Epoch 12:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6250585368833267, Epoch: 12, lr: 0.0027300475013022664


Epoch 13:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6256603073553998, Epoch: 13, lr: 0.0020610737385376348


Epoch 14:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6257983073508752, Epoch: 14, lr: 0.0014644660940672626


Epoch 15:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.625843553251031, Epoch: 15, lr: 0.0009549150281252633


Epoch 16:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6256422089953374, Epoch: 16, lr: 0.0005449673790581611


Epoch 17:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6260154876716234, Epoch: 17, lr: 0.00024471741852423234


Epoch 18:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6264046024129638, Epoch: 18, lr: 6.15582970243117e-05


Epoch 19:   0%|          | 0/1688832 [00:00<?, ?it/s]

Error: 0.6266195204387043, Epoch: 19, lr: 0.0


In [22]:
Y_test_pred = model(X_test_torch).cpu().detach().numpy()
Y_test_buffer = Y_test
Y_test_pred_buffer = Y_test_pred

for dat in test_partition:

    sp = dat['X'].shape[0]
    Y_test_curr = Y_test_buffer[:sp]
    Y_test_pred_curr = Y_test_pred_buffer[:sp, :]

    dat['Y_label_uniform_bin'] = Y_test_curr
    dat['Y_pred_uniform_bin'] = Y_test_pred_curr

    Y_test_buffer = Y_test_buffer[sp:]
    Y_test_pred_buffer = Y_test_pred_buffer[sp:, :]

In [23]:
with open('eval/lllt_split0_testeval.pkl', 'wb') as fout:
    pkl.dump(test_partition, fout)

In [None]:
with open('eval/lllt_split0_testeval.pkl', 'rb') as fin:
    eval_result = pkl.load(fin)