In [12]:
# standard data science
import numpy as np
import pandas as pd

# standard pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# PyTorch data utilities
import torchvision
import utils
from resnet import resnet20, resnet32, resnet44

# logging + I/O
import sys, copy, os, shutil, time
from importlib import reload

# FOR NOTEBOOKS ONLY
from tqdm.notebook import tqdm
import torch.nn.functional as F

In [13]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [14]:
# reload our scripts
reload(utils)

# how many epochs are we training for? no more than 50. also what's our batch size?
epochs, batch_size = 100, 256

# command-line arguments
dataset = "FashionMNIST"
variant = 0 # int(sys.argv[2])
seed = 0 # int(sys.argv[3])

In [17]:
torch.cat([inputs, inputs, inputs], dim=1)

torch.Size([256, 3, 32, 32])

In [18]:
# load our data
trainloader, testloader, data_dim = utils.load_data(dataset, batch_size)

# set a seed, instantiate our model + define loss function, optimizer
torch.manual_seed(seed)
model = resnet44(); model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

'''
# create a directory in our models folder for this run
if dataset not in os.listdir("models"):
    os.mkdir(f"models/{dataset}")
model_name = f"cnn_num-modules={num_modules}_seed={str(seed).zfill(3)}"
if model_name not in os.listdir(f"models/{dataset}"):
    os.mkdir(f"models/{dataset}/{model_name}")
'''

    
###### OUR TRAINING PIPELINE

# metrics to record for each epoch
train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []

# iterate through our epochs
for epoch in tqdm(range(epochs)):
    
    # initialize the RUNNING TRAINING LOSS list and TRAIN accuracy list
    running_loss, accuracy = [], [0, 0] # [# correct, total # of samples seen]
    
    # iterate through training data: train + collect train loss
    for data in tqdm(trainloader):
        
        # reset gradient + get our data for this batch
        optimizer.zero_grad()
        inputs, labels = data
        inputs = torch.cat([inputs, inputs, inputs], dim=1) # to make resnet-compatible!
        inputs, labels = inputs.to(device), labels.to(device)
        
        # forward prop, backward prop, make incremental step
        outputs = model(inputs); loss = loss_func(outputs, labels)
        loss.backward(); optimizer.step()

        # update our running_loss (TRAINING!)
        running_loss.append(loss.item())
        
         # calculate + record our train acuracy (TRAINING!)
        with torch.no_grad():

            # get our predictions with the current weights + count no. of correct
            _, predicted = torch.max(outputs.data, 1)
            accuracy[1] += labels.size(0)
            accuracy[0] += (predicted == labels).sum().item()
    
    # compute mean train loss across batches, also test accuracy + mean test loss across batches
    with torch.no_grad():
                
        # add our training loss + accuracy to our lists
        train_losses.append(np.mean(np.array(running_loss)))
        train_accuracies.append((accuracy[0] / accuracy[1]))

        # initialize the RUNNING TEST LOSS + test accuracy list
        running_test_loss, test_accuracy = [], [0, 0]

        # compute test set metrics
        for test_data in testloader:

            # make test predictions + record running loss
            test_images, test_labels = test_data
            test_images = torch.cat([test_images, test_images, test_images], dim=1) # RESNET COMPATIBILITY!
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_outputs = model(test_images)
            test_loss = loss_func(test_outputs, test_labels)
            running_test_loss.append(float(test_loss))
            _, test_predicted = torch.max(test_outputs.data, 1)
            test_accuracy[1] += test_labels.size(0)
            test_accuracy[0] += (test_predicted == test_labels).sum().item()

        # add our test loss/accuracies to our lists
        test_losses.append(np.mean(np.array(running_test_loss)))
        test_accuracies.append((test_accuracy[0] / test_accuracy[1]))
        
        print(epoch, test_accuracies[-1])
        
        # save our weights at every epoch! models/{dataset}/{model_name}/
        torch.save(obj=model.state_dict(), f=f"{str(epoch).zfill(3)}.pth")
        
# at the very end, save our logs for this model
logs = pd.DataFrame(data=np.array([list(np.arange(len(train_losses))), train_losses, 
                                   test_losses, train_accuracies, test_accuracies]).T,
                    columns=["epoch", "train_loss", "test_loss", "train_acc", "test_acc"])
logs.to_csv(f"models/{dataset}/{model_name}/logs.csv", index=False)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

0 0.8611


  0%|          | 0/235 [00:00<?, ?it/s]

1 0.8914


  0%|          | 0/235 [00:00<?, ?it/s]

2 0.9032


  0%|          | 0/235 [00:00<?, ?it/s]

3 0.9045


  0%|          | 0/235 [00:00<?, ?it/s]

4 0.9074


  0%|          | 0/235 [00:00<?, ?it/s]

5 0.915


  0%|          | 0/235 [00:00<?, ?it/s]

6 0.917


  0%|          | 0/235 [00:00<?, ?it/s]

7 0.9148


  0%|          | 0/235 [00:00<?, ?it/s]

8 0.9229


  0%|          | 0/235 [00:00<?, ?it/s]

9 0.9224


  0%|          | 0/235 [00:00<?, ?it/s]

10 0.922


  0%|          | 0/235 [00:00<?, ?it/s]

11 0.9263


  0%|          | 0/235 [00:00<?, ?it/s]

12 0.923


  0%|          | 0/235 [00:00<?, ?it/s]

13 0.9263


  0%|          | 0/235 [00:00<?, ?it/s]

14 0.9298


  0%|          | 0/235 [00:00<?, ?it/s]

15 0.9283


  0%|          | 0/235 [00:00<?, ?it/s]

16 0.9324


  0%|          | 0/235 [00:00<?, ?it/s]

17 0.9262


  0%|          | 0/235 [00:00<?, ?it/s]

18 0.9271


  0%|          | 0/235 [00:00<?, ?it/s]

19 0.9316


  0%|          | 0/235 [00:00<?, ?it/s]

20 0.9284


  0%|          | 0/235 [00:00<?, ?it/s]

21 0.9302


  0%|          | 0/235 [00:00<?, ?it/s]

22 0.9313


  0%|          | 0/235 [00:00<?, ?it/s]

23 0.9287


  0%|          | 0/235 [00:00<?, ?it/s]

24 0.9312


  0%|          | 0/235 [00:00<?, ?it/s]

25 0.9325


  0%|          | 0/235 [00:00<?, ?it/s]

26 0.933


  0%|          | 0/235 [00:00<?, ?it/s]

27 0.9357


  0%|          | 0/235 [00:00<?, ?it/s]

28 0.9322


  0%|          | 0/235 [00:00<?, ?it/s]

29 0.9318


  0%|          | 0/235 [00:00<?, ?it/s]

30 0.9324


  0%|          | 0/235 [00:00<?, ?it/s]

31 0.9365


  0%|          | 0/235 [00:00<?, ?it/s]

32 0.9359


  0%|          | 0/235 [00:00<?, ?it/s]

33 0.9288


  0%|          | 0/235 [00:00<?, ?it/s]

34 0.9357


  0%|          | 0/235 [00:00<?, ?it/s]

35 0.9379


  0%|          | 0/235 [00:00<?, ?it/s]

36 0.9347


  0%|          | 0/235 [00:00<?, ?it/s]

37 0.9391


  0%|          | 0/235 [00:00<?, ?it/s]

38 0.9362


  0%|          | 0/235 [00:00<?, ?it/s]

39 0.9347


  0%|          | 0/235 [00:00<?, ?it/s]

40 0.9341


  0%|          | 0/235 [00:00<?, ?it/s]

41 0.9334


  0%|          | 0/235 [00:00<?, ?it/s]

42 0.9353


  0%|          | 0/235 [00:00<?, ?it/s]

43 0.9372


  0%|          | 0/235 [00:00<?, ?it/s]

44 0.934


  0%|          | 0/235 [00:00<?, ?it/s]

45 0.9374


  0%|          | 0/235 [00:00<?, ?it/s]

46 0.9325


  0%|          | 0/235 [00:00<?, ?it/s]

47 0.936


  0%|          | 0/235 [00:00<?, ?it/s]

48 0.9344


  0%|          | 0/235 [00:00<?, ?it/s]

49 0.9381


  0%|          | 0/235 [00:00<?, ?it/s]

50 0.9363


  0%|          | 0/235 [00:00<?, ?it/s]

51 0.9375


  0%|          | 0/235 [00:00<?, ?it/s]

52 0.9367


  0%|          | 0/235 [00:00<?, ?it/s]

53 0.9353


  0%|          | 0/235 [00:00<?, ?it/s]

54 0.9371


  0%|          | 0/235 [00:00<?, ?it/s]

55 0.9362


  0%|          | 0/235 [00:00<?, ?it/s]

56 0.9338


  0%|          | 0/235 [00:00<?, ?it/s]

57 0.9364


  0%|          | 0/235 [00:00<?, ?it/s]

58 0.9385


  0%|          | 0/235 [00:00<?, ?it/s]

KeyboardInterrupt: 