In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.datasets as datasets
from torch.utils.data import DataLoader, WeightedRandomSampler

import utils
import networks

In [4]:
# Load each split from its own folder tree with its own transform pipeline
# (utils.dirs / utils.transform are indexed by split name).
train_set = datasets.ImageFolder(root=utils.dirs['train'], transform=utils.transform['train'])
val_set = datasets.ImageFolder(root=utils.dirs['val'], transform=utils.transform['val'])

# ImageFolder builds its label indices independently for each split from the
# sorted sub-folder names, so a class folder missing from (or extra in) one
# split would silently shift the labels of every later class. Fail fast if the
# two splits disagree instead of training/validating on mismatched labels.
assert train_set.class_to_idx == val_set.class_to_idx, (
    f"train/val class folders differ: {train_set.classes} vs {val_set.classes}"
)

In [5]:
# Class-balanced sampling: every training image gets weight 1 / (size of its
# class), so each class carries the same total weight and is drawn about
# equally often per epoch despite the imbalance in the training folders.
SAMPLER_SEED = 0  # fixes the draw order of the sampler so runs are reproducible
# NOTE(review): only the sampler is seeded here; any randomness in
# utils.transform or model initialisation still uses the global RNG.

targets = torch.as_tensor(train_set.targets)
class_freq = targets.bincount()   # number of training images per class index
weight = 1 / class_freq           # per-class weight (true division -> float tensor)
samples_weight = weight[targets]  # per-image weight, aligned with train_set order
sampler = WeightedRandomSampler(
    samples_weight,
    num_samples=len(samples_weight),  # one epoch = as many draws as there are images
    replacement=True,                 # needed so minority-class images can be re-drawn
    generator=torch.Generator().manual_seed(SAMPLER_SEED),
)

In [None]:
# Build the classifier from pretrained weights. out_features is presumably the
# width of the final classification layer (see networks.get_resnet18). It is
# derived from the dataset instead of hard-coding 4, so it stays in sync with
# the class folders found by ImageFolder if a class is added or removed.
resnet18 = networks.get_resnet18(pretrained=True, out_features=len(train_set.classes))

In [8]:
# Train the model. The hyperparameters are named here so the run configuration
# is visible at a glance and can be changed in one place.
EPOCHS = 10
LEARNING_RATE = 3e-5     # small step size, since training starts from pretrained weights
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32    # presumably evaluation-only inside utils.fit, so a larger batch fits

utils.fit(
    epochs=EPOCHS,
    model=resnet18,
    criterion=nn.CrossEntropyLoss(),
    # The optimizer receives every parameter of resnet18.
    optimizer=optim.Adam(resnet18.parameters(), lr=LEARNING_RATE),
    # `sampler` already provides class-balanced random order, so `shuffle` must
    # stay unset (DataLoader rejects shuffle=True combined with a sampler).
    train_dl=DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, sampler=sampler),
    # Validation keeps the natural class distribution in a fixed, sequential order.
    valid_dl=DataLoader(val_set, batch_size=VALID_BATCH_SIZE),
)

Epoch [ 1/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.912, loss=0.12]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.244246	Avg validation loss: 0.056423
		val_loss decreased (inf --> 0.056423)  saving model...


Epoch [ 2/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.956, loss=0.0205]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.126862	Avg validation loss: 0.049716
		val_loss decreased (0.056423 --> 0.049716)  saving model...


Epoch [ 3/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.967, loss=0.0337]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.096389	Avg validation loss: 0.039880
		val_loss decreased (0.049716 --> 0.039880)  saving model...


Epoch [ 4/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.87it/s, acc=0.975, loss=0.00805]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.073555	Avg validation loss: 0.026281
		val_loss decreased (0.039880 --> 0.026281)  saving model...


Epoch [ 5/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.87it/s, acc=0.98, loss=0.0165]
Epoch [ 6/10]:   0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.063953	Avg validation loss: 0.032789


Epoch [ 6/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.86it/s, acc=0.98, loss=0.026]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.058138	Avg validation loss: 0.043833


Epoch [ 7/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.85it/s, acc=0.984, loss=0.028]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.048760	Avg validation loss: 0.027824


Epoch [ 8/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.85it/s, acc=0.986, loss=0.00897]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.040949	Avg validation loss: 0.022712
		val_loss decreased (0.026281 --> 0.022712)  saving model...


Epoch [ 9/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.86it/s, acc=0.988, loss=0.000634]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.036118	Avg validation loss: 0.035958


Epoch [10/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.88it/s, acc=0.99, loss=0.0028]


		Avg training loss: 0.031628	Avg validation loss: 0.016979
		val_loss decreased (0.022712 --> 0.016979)  saving model...
