In [3]:

# Startup banner written to stdout; the message text suggests this cell was
# ported from a standalone main.py script (assumption — confirm and reword if stale).
print("Running main.py")

## Imports
from lib.utils import *
from lib.models import MLP
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import os


## Load Data
# `load_raw_list` and `train_test_split` are not defined in this notebook; they
# presumably arrive through the wildcard import from lib.utils — verify there.
X, y = load_raw_list([20])

# Flatten, keep every 10th value, then regroup the result into rows of 500 values.
X = X.flatten()[::10].reshape(-1, 500)

# Shuffled split with 20% held out, stratified on the labels, fixed seed (0).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    shuffle=True,
    stratify=y,
    random_state=0,
)

Running main.py


In [None]:

## Datasets, model, loss and optimizer
NUM_CLASSES = 3   # width of the one-hot target vectors built below
BATCH_SIZE = 256  # mini-batch size shared by the train and test loaders

# One-hot float targets, reshaped to (n_samples, NUM_CLASSES).
# NOTE(review): `one_hot` is presumably torch.nn.functional.one_hot re-exported by
# lib.utils — confirm. Float targets shaped like the logits are treated by
# nn.CrossEntropyLoss as class probabilities.
y_train_one_hot = one_hot(y_train, num_classes=NUM_CLASSES).reshape(-1, NUM_CLASSES).float()
y_test_one_hot = one_hot(y_test, num_classes=NUM_CLASSES).reshape(-1, NUM_CLASSES).float()

# Only the training loader is shuffled; the test loader keeps a fixed order.
train_dataloader = DataLoader(TensorDataset(X_train, y_train_one_hot), batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(TensorDataset(X_test, y_test_one_hot), batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MLP()
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Per-epoch loss histories, appended to by the training cell.
training_losses = []
testing_losses = []

# Get initial loss: mean per-batch loss of the untrained model over the training set.
# This pass only measures, so gradient tracking is switched off. Batch variables
# get their own names so the dataset-level `X` / `y` from the load cell survive.
total_loss = 0
with torch.no_grad():
    for (batch_X, batch_y) in train_dataloader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        logits = model(batch_X)
        loss = criterion(logits, batch_y)
        total_loss += loss.item()
print("initial loss",total_loss/len(train_dataloader))


In [None]:

## Train, evaluate, plot and checkpoint
# Number of passes over the training data. `args` is never created in the visible
# notebook cells (argparse is imported but no parser is built), so fall back to a
# default instead of dying with NameError on a fresh kernel.
DEFAULT_EPOCHS = 100  # arbitrary fallback — tune as needed
try:
    num_epochs = args.epochs
except NameError:
    num_epochs = DEFAULT_EPOCHS

pbar = tqdm(range(num_epochs))
for epoch in pbar:
    # Re-enter training mode every epoch: the evaluation pass below switches the
    # model to eval mode, and without this every epoch after the first would
    # train with train-time layers (dropout, batch norm, ...) disabled.
    model.train()
    training_loss = 0
    for (batch_X, batch_y) in train_dataloader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        logits = model(batch_X)
        loss = criterion(logits, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    training_loss = training_loss/len(train_dataloader)
    training_losses.append(training_loss)

    # Evaluation pass: eval mode, and no autograd bookkeeping since nothing is
    # back-propagated here.
    model.eval()
    testing_loss = 0
    with torch.no_grad():
        for (batch_X, batch_y) in test_dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            loss = criterion(logits, batch_y)
            testing_loss += loss.item()
    testing_loss = testing_loss/len(test_dataloader)
    testing_losses.append(testing_loss)
    # ANSI colour codes: blue for the training loss, yellow for the test loss.
    pbar.set_description(f'\033[94mDev Loss: {training_loss:.4f}\033[93m Val Loss: {testing_loss:.2f}\033[0m')

# Loss curves, one point per epoch, written to loss.jpg in the working directory.
plt.plot(training_losses, label='train')
plt.plot(testing_losses, label='test')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig('loss.jpg')

# Save the weights twice: a timestamped copy under models/ and a fixed-name
# model.pt that is overwritten on every run.
from datetime import datetime
# NOTE(review): the timestamp keeps its ':' characters, which are not valid in
# Windows file names — switch to strftime if that platform matters.
current_date = str(datetime.now()).replace(' ','_')
os.makedirs('models', exist_ok=True)  # no shell call; no error if it already exists
torch.save(model.state_dict(), f=f'models/{current_date}.pt')
torch.save(model.state_dict(), f=f'model.pt')
