<a href="https://colab.research.google.com/github/sandonli/Gecko-Binary-Classifier/blob/main/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a float
# tensor, then normalize each RGB channel with the ImageNet mean/std.
# The normalization matters because this notebook fine-tunes an
# ImageNet-pretrained ResNet-18, whose weights were trained on inputs
# normalized this way (see the torchvision pretrained-model documentation).
xform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ImageFolder takes its class labels from the sub-directory names under the root.
dataset_full = datasets.ImageFolder('/content/gdrive/MyDrive/geckos', transform=xform)

In [None]:
# 80/20 train/test split of the full dataset. A dedicated, seeded generator
# keeps the split identical from run to run.
n_all = len(dataset_full)
n_train = int(0.8 * n_all)
n_test = n_all - n_train  # remainder, so the two sizes always sum to n_all
rng = torch.Generator().manual_seed(2000)
dataset_train, dataset_test = torch.utils.data.random_split(dataset_full, [n_train, n_test], rng)
# Training batches are reshuffled every epoch.
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=4, shuffle=True)
# Evaluation only aggregates per-sample metrics, so the order is irrelevant;
# leaving it unshuffled keeps evaluation deterministic and avoids drawing from
# the global RNG.
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=4, shuffle=False)

In [None]:
# Start from a pretrained ResNet-18 and replace its final fully connected layer
# with a fresh 2-output head (one logit per class).
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)
# Only the new head's weight is re-initialized; its bias keeps nn.Linear's default.
torch.nn.init.xavier_uniform_(model.fc.weight)
# Use the first GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a runtime without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Classification loss shared by run_train and run_test. It is built with default
# arguments; PyTorch documents the default reduction as 'mean' (per-batch
# average), which is why run_train rescales each batch loss by the batch size.
criterion = nn.CrossEntropyLoss()

def run_test(model):
  """Evaluate `model` on the held-out test split.

  Reads the module-level `loader_test`, `dataset_test`, `device` and
  `criterion`. The model is switched to eval mode (and left in it) and
  gradients are disabled for the whole pass.

  Args:
    model: network to evaluate; called on each batch after the batch is moved
      to `device`.

  Returns:
    A `(loss, accuracy)` tuple of Python floats, each averaged per sample over
    `len(dataset_test)`.
  """
  nsamples_test = len(dataset_test)
  loss, correct = 0.0, 0
  model.eval()
  with torch.no_grad():
    for samples, labels in loader_test:
      samples = samples.to(device)
      labels = labels.to(device)
      outs = model(samples)
      # Assumes `criterion` returns the batch mean (the nn.CrossEntropyLoss
      # default). Weighting by the batch size turns it back into a per-batch
      # sum, so dividing by the sample count below gives a true per-sample
      # average (the last batch may be smaller) that matches run_train.
      loss += criterion(outs, labels).item() * samples.size(0)
      # Predicted class = index of the largest logit along dim 1.
      _, preds = torch.max(outs, 1)
      correct += (preds == labels).sum().item()
  return loss / nsamples_test, correct / nsamples_test

In [None]:
# SGD with momentum over everything returned by model.parameters().
optimizer = optim.SGD(
  params=model.parameters(),
  lr=1e-3,
  momentum=0.9,
)
# Step decay: the learning rate is multiplied by 0.1 after every 5 scheduler
# steps (run_train steps the scheduler once per call).
scheduler = optim.lr_scheduler.StepLR(
  optimizer=optimizer,
  step_size=5,
  gamma=0.1,
)

In [None]:
def run_train(model, opt, sched):
  """Train `model` for one pass over the training split.

  Reads the module-level `loader_train`, `dataset_train`, `device` and
  `criterion`. The model is switched to train mode (and left in it).

  Args:
    model: network to optimize.
    opt: optimizer; its gradients are zeroed and it is stepped once per batch.
    sched: learning-rate scheduler; stepped once, after all batches.

  Returns:
    A `(loss, accuracy)` tuple of Python floats, each averaged per sample over
    `len(dataset_train)`.
  """
  nsamples_train = len(dataset_train)
  loss_sofar, correct_sofar = 0.0, 0
  model.train()
  with torch.enable_grad():
    for samples, labels in loader_train:
      samples = samples.to(device)
      labels = labels.to(device)
      opt.zero_grad()
      outs = model(samples)
      # Predicted class = index of the largest logit; detached so the
      # bookkeeping never takes part in the autograd graph.
      _, preds = torch.max(outs.detach(), 1)
      loss = criterion(outs, labels)
      loss.backward()
      opt.step()
      # The batch loss is rescaled by the batch size so the final division
      # gives a per-sample average even when the last batch is smaller.
      loss_sofar += loss.item() * samples.size(0)
      # .item() keeps the running count a plain int (as run_test does), so the
      # returned accuracy is a float rather than a device tensor.
      correct_sofar += (preds == labels).sum().item()
  sched.step()
  return loss_sofar / nsamples_train, correct_sofar / nsamples_train

def run_all(model, optimizer, scheduler, n_epochs):
  """Run `n_epochs` rounds of training followed by evaluation.

  Each round calls run_train and then run_test, and prints one summary line
  with the epoch index and the train/test loss and accuracy to 4 decimals.
  Nothing is returned.
  """
  for epoch in range(n_epochs):
    loss_train, acc_train = run_train(model, optimizer, scheduler)
    loss_test, acc_test = run_test(model)
    # Assemble the report from a train half and a test half.
    train_part = f"train loss {loss_train:.4f} acc {acc_train:.4f}"
    test_part = f"test loss {loss_test:.4f} acc {acc_test:.4f}"
    print(f"epoch {epoch}: {train_part}, {test_part}")