In [None]:
# Functions for training the model and calculating accuracy

def calculate_accuracy(model, dataloader, device):
    """Return the top-1 classification accuracy of ``model`` as a percentage.

    Args:
        model: Module called as ``model(images)``; its output is reduced with
            ``torch.max(..., 1)``, so dimension 1 is treated as the class axis.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``correct / total * 100`` as a float, or ``0.0`` when the dataloader
        yields no samples (previously this raised ``ZeroDivisionError``).

    Note:
        The model is switched to evaluation mode and is left in that mode;
        callers that continue training must call ``model.train()`` again.
    """
    model.eval()  # evaluation mode: Dropout is disabled, BatchNorm uses learned statistics
    total_correct = 0
    total_images = 0
    with torch.no_grad():  # no autograd bookkeeping is needed for scoring
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Index of the largest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_images == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0

    return total_correct / total_images * 100

def train_model(model, train_loader, criterion, optimizer, scheduler, epochs, device):
    """Train ``model`` for ``epochs`` epochs and record per-epoch statistics.

    Each epoch runs one optimisation pass over ``train_loader``, then measures
    training accuracy with ``calculate_accuracy`` (a second pass over the same
    loader), steps ``scheduler`` once, and prints a one-line summary.

    Args:
        model: Module being trained; called as ``model(inputs)``.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        epochs: Number of epochs to run.
        device: Device each batch is moved to.

    Returns:
        ``(train_losses, train_accuracies)``: two lists with one entry per
        epoch. The loss entry is the mean of the per-batch losses.
    """
    loss_history, accuracy_history = [], []

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        epoch_start = time.time()

        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_loss = criterion(model(batch_inputs), batch_labels)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()

        # Average over the number of batches, not the number of samples.
        mean_loss = loss_sum / len(train_loader)
        loss_history.append(mean_loss)

        epoch_accuracy = calculate_accuracy(model, train_loader, device)
        accuracy_history.append(epoch_accuracy)
        scheduler.step()

        summary = "Epoch: {} | Loss: {:.4f} | Training accuracy: {:.3f}% ".format(epoch, mean_loss, epoch_accuracy)
        # The elapsed time covers both the training pass and the accuracy pass.
        elapsed = time.time() - epoch_start
        print(summary + "Epoch Time: {:.2f} secs".format(elapsed))

    return loss_history, accuracy_history

In [None]:
# Function for choosing parameters to learn

def get_params_to_update(model, feature_extraction, layers_num=0):
    """Set ``requires_grad`` on every parameter of ``model`` and return the trainable ones.

    Args:
        model: Module whose ``named_parameters()`` are frozen or unfrozen in place.
        feature_extraction: If truthy, only parameters whose name contains
            ``"fc"`` stay trainable and ``layers_num`` is ignored. Otherwise
            the last ``layers_num`` parameter tensors are made trainable.
        layers_num: Number of trailing entries of ``named_parameters()`` to
            unfreeze in fine-tuning mode. This counts parameter tensors, not
            layers (a layer's weight and bias are separate entries). ``0``, or
            any value of at least the number of parameters, unfreezes every
            parameter.

    Returns:
        List of the parameters left with ``requires_grad=True``, in
        ``named_parameters()`` order. Their names are printed as a side effect.

    Raises:
        ValueError: If ``layers_num`` is negative in fine-tuning mode.
    """
    if not feature_extraction and layers_num < 0:
        raise ValueError("layers_num must be non-negative, got {}".format(layers_num))

    # Materialise once; the original rebuilt a name->param dict on every iteration.
    named_params = list(model.named_parameters())

    print("Params to learn:")
    if feature_extraction:
        # Feature extraction: freeze everything except the "fc" parameters.
        for name, param in named_params:
            param.requires_grad = "fc" in name
    else:
        # Fine tuning: only the trailing `layers_num` parameter tensors train.
        # layers_num == 0 historically selected everything (a [-0:] slice is
        # the whole list); that behaviour is kept so existing callers work.
        if layers_num == 0:
            first_trainable = 0
        else:
            first_trainable = max(len(named_params) - layers_num, 0)
        for index, (_, param) in enumerate(named_params):
            param.requires_grad = index >= first_trainable

    # Collect and report the parameters that were left unfrozen.
    params_to_update = []
    for name, param in named_params:
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)

    return params_to_update

In [None]:
# Function for training the model, plotting its training history, and evaluating it on the validation set
def train_and_evaluate_with_graph_model(model, train_loader, val_loader, params_to_update, epochs, weight_decay, learning_rate, device):
    """Train ``model``, plot its training history and return validation accuracy.

    Builds a cross-entropy criterion, an Adam optimizer over
    ``params_to_update`` and a cosine-annealing scheduler with
    ``T_max=epochs``, then delegates the training loop to ``train_model``.

    Args:
        model: Module to train; moved to ``device`` in place.
        train_loader: Batches used for training.
        val_loader: Batches used for the final accuracy measurement.
        params_to_update: Parameters handed to the optimizer.
        epochs: Number of training epochs (also the scheduler's ``T_max``).
        weight_decay: Adam ``weight_decay``.
        learning_rate: Adam initial learning rate.
        device: Device used for training and evaluation.

    Returns:
        The value of ``calculate_accuracy(model, val_loader, device)``, which
        is also printed.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    optimizer = torch.optim.Adam(params_to_update, lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    train_losses, train_accuracies = train_model(model, train_loader, criterion, optimizer, scheduler, epochs, device)

    # Plotting: loss on the left, accuracy on the right. The original reserved
    # a 1x2 grid but only drew the loss panel, leaving train_accuracies unused.
    epochs_range = range(1, epochs + 1)
    fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(epochs_range, train_losses, label='Training Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Train Loss Curve')
    loss_ax.locator_params(axis='x', integer=True, tight=True)  # whole-number epoch ticks
    loss_ax.legend()

    accuracy_ax.plot(epochs_range, train_accuracies, label='Training Accuracy')
    accuracy_ax.set_xlabel('Epochs')
    accuracy_ax.set_ylabel('Accuracy (%)')
    accuracy_ax.set_title('Train Accuracy Curve')
    accuracy_ax.locator_params(axis='x', integer=True, tight=True)
    accuracy_ax.legend()

    fig.tight_layout()
    plt.show()

    val_accuracy = calculate_accuracy(model, val_loader, device)
    print("Accuracy: {:.3f}%".format(val_accuracy))
    return val_accuracy

In [None]:
# Hyperparameters: training schedule and batch size
epochs = 15
batch_size = 1024

# Hyperparameters: optimizer settings
learning_rate = 1e-3
weight_decay = 0.0001

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# DataLoaders
# NOTE(review): DataLoader, trainset and valset are not defined in this part of
# the notebook; presumably DataLoader is torch.utils.data.DataLoader and the
# datasets come from an earlier cell. Confirm before relying on this.
# Both loaders use the batch_size set in the hyperparameters cell.
# Training batches are shuffled; validation batches are not.
train_loader = DataLoader(trainset, batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size, shuffle=False)