In [1]:
# %load train.py
import time
import torch
import torch.nn as nn
from models import VGG16
from dataset import IMAGE_Dataset
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path
import copy

# Run configuration. The earlier revision apparently took these from parsed
# command-line arguments (args.cuda_devices / args.path); they are fixed here.
DATASET_ROOT = './seg_train'
CUDA_DEVICES = 0

# Reproducibility: ask cuDNN for deterministic kernels, turn off its
# per-shape auto-tuner, and seed the torch random number generator.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(123)

def train(*, num_epochs: int = 20, batch_size: int = 10,
          learning_rate: float = 0.01, momentum: float = 0.9,
          num_workers: int = 1) -> None:
    """Train a VGG16 classifier on the images under DATASET_ROOT.

    Images are resized to 224x224, converted to tensors and normalized
    per channel. The model is optimized with SGD on a cross-entropy loss.
    After every epoch the average loss and accuracy over the training set
    are printed. The parameters of the epoch with the highest *training*
    accuracy (no validation split is used here) are restored at the end,
    and the model is written to
    ``model-<best_acc>-best_train_acc.pth`` in the working directory.

    Args:
        num_epochs: Number of full passes over the training set.
        batch_size: Number of samples per mini-batch.
        learning_rate: SGD learning rate.
        momentum: SGD momentum factor.
        num_workers: Number of DataLoader worker processes.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_set = IMAGE_Dataset(Path(DATASET_ROOT), data_transform)
    num_samples = len(train_set)
    if num_samples == 0:
        # Fail early with a clear message instead of dividing by zero
        # when the per-epoch averages are computed.
        raise ValueError(f'No training samples found under {DATASET_ROOT!r}')

    data_loader = DataLoader(dataset=train_set, batch_size=batch_size,
                             shuffle=True, num_workers=num_workers)

    model = VGG16(num_classes=train_set.num_classes)
    model = model.cuda(CUDA_DEVICES)
    model.train()

    # Snapshot of the best parameters seen so far; starts as the initial weights.
    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate,
                                momentum=momentum)

    for epoch in range(num_epochs):
        header = f'Epoch: {epoch + 1}/{num_epochs}'
        print(header)
        print('-' * len(header))

        training_loss = 0.0
        training_corrects = 0

        for inputs, labels in data_loader:
            # Tensors take part in autograd directly; the Variable wrapper
            # is no longer needed.
            inputs = inputs.cuda(CUDA_DEVICES)
            labels = labels.cuda(CUDA_DEVICES)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest logit; detached so the
            # bookkeeping does not touch the autograd graph.
            _, preds = torch.max(outputs.detach(), 1)

            # The criterion returns the batch mean, so weight it by the
            # batch size to get a correct per-sample average at epoch end
            # (the last batch may be smaller).
            training_loss += loss.item() * inputs.size(0)
            # .item() keeps the counter a plain Python int.
            training_corrects += (preds == labels).sum().item()

        training_loss = training_loss / num_samples
        training_acc = training_corrects / num_samples
        print(f'Training loss: {training_loss:.4f}\taccuracy: {training_acc:.4f}\n')

        if training_acc > best_acc:
            best_acc = training_acc
            best_model_params = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_params)
    # Kept as a whole-module save (not just the state_dict) so the on-disk
    # format matches what this script produced before.
    torch.save(model, f'model-{best_acc:.02f}-best_train_acc.pth')


if __name__ == '__main__':
    # Script entry point: run training and report the wall-clock duration.
    started_at = time.time()
    train()
    elapsed = time.time() - started_at
    print('Total time: {:5.2f}s'.format(elapsed))

Epoch: 1/20
-----------
Training loss: 1.5488	accuracy: 0.3236

Epoch: 2/20
-----------
Training loss: 1.0641	accuracy: 0.5672

Epoch: 3/20
-----------
Training loss: 0.9829	accuracy: 0.6032

Epoch: 4/20
-----------
Training loss: 0.9018	accuracy: 0.6415

Epoch: 5/20
-----------
Training loss: 0.8335	accuracy: 0.6764

Epoch: 6/20
-----------
Training loss: 0.8126	accuracy: 0.6852

Epoch: 7/20
-----------
Training loss: 0.7573	accuracy: 0.7129

Epoch: 8/20
-----------
Training loss: 0.6978	accuracy: 0.7329

Epoch: 9/20
-----------
Training loss: 0.6564	accuracy: 0.7570

Epoch: 10/20
------------
Training loss: 0.6082	accuracy: 0.7774

Epoch: 11/20
------------
Training loss: 0.5782	accuracy: 0.7880

Epoch: 12/20
------------
Training loss: 0.5498	accuracy: 0.7995

Epoch: 13/20
------------
Training loss: 0.5237	accuracy: 0.8092

Epoch: 14/20
------------
Training loss: 0.5336	accuracy: 0.8052

Epoch: 15/20
------------
Training loss: 0.5333	accuracy: 0.8078

Epoch: 16/20
------------
Tr