In [7]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import datasets, transforms
from facenet_pytorch import InceptionResnetV1, fixed_image_standardization, training
from tqdm import tqdm

In [8]:
batch_size = 32
epochs = 30
workers = 8
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_path = 'training_data'

In [9]:
transform = transforms.Compose([np.float32, transforms.ToTensor(), fixed_image_standardization])
dataset = datasets.ImageFolder(data_path, transform=transform)
label_dict = dataset.class_to_idx

In [10]:
model = InceptionResnetV1(pretrained='vggface2', classify=True, num_classes=len(label_dict)).to(device)
weights_path = 'code.pt'

try:
    model.load_state_dict(torch.load(weights_path))
    print(f'Loaded weights from {weights_path}')
except FileNotFoundError:
    print(f'No existing weights found at {weights_path}. Using pretrained weights.')

model.logits = nn.Sequential(nn.Linear(512, 1024),
                            nn.ReLU(),
                            nn.Linear(1024, 1024),
                            nn.ReLU(),
                            nn.Linear(1024, len(label_dict)),
                            nn.Dropout(p=0.2)
                            ).to(device)

No existing weights found at code.pt. Using pretrained weights.


In [11]:
print(model)

InceptionResnetV1(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_4a): 

In [12]:
optimizer = torch.optim.Adam(model.logits.parameters())
scheduler = MultiStepLR(optimizer, [5, 10])

img_idxs = np.arange(len(dataset))
np.random.shuffle(img_idxs)
train_idxs = img_idxs[:int(0.8 * len(img_idxs))]
val_idxs = img_idxs[int(0.8 * len(img_idxs)):]

train_loader = DataLoader(dataset,
                        num_workers=workers,
                        batch_size=batch_size,
                        sampler=SubsetRandomSampler(train_idxs))
valid_loader = DataLoader(dataset,
                        num_workers=workers,
                        batch_size=batch_size,
                        sampler=SubsetRandomSampler(val_idxs))

loss_fn = torch.nn.CrossEntropyLoss()
metrics = {'acc': training.accuracy}

optimizer.zero_grad()

for epoch in tqdm(range(epochs)):
    print('-'*60)
    print(f'Epoch {epoch + 1}/{epochs}:')
    model.train()
    training.pass_epoch(model=model,
                        loss_fn=loss_fn,
                        loader=train_loader,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        batch_metrics=metrics,
                        show_running=True,
                        device=device)
    model.eval()
    training.pass_epoch(model=model,
                        loss_fn=loss_fn,
                        loader=valid_loader,
                        scheduler=scheduler,
                        batch_metrics=metrics,
                        show_running=True,
                        device=device)

  0%|          | 0/30 [00:00<?, ?it/s]

------------------------------------------------------------
Epoch 1/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


  3%|▎         | 1/30 [00:24<11:49, 24.46s/it]

------------------------------------------------------------
Epoch 2/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


  7%|▋         | 2/30 [00:31<06:37, 14.19s/it]

------------------------------------------------------------
Epoch 3/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 10%|█         | 3/30 [00:38<04:56, 10.98s/it]

------------------------------------------------------------
Epoch 4/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 13%|█▎        | 4/30 [00:45<04:06,  9.48s/it]

------------------------------------------------------------
Epoch 5/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 17%|█▋        | 5/30 [00:53<03:37,  8.70s/it]

------------------------------------------------------------
Epoch 6/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 20%|██        | 6/30 [01:00<03:16,  8.17s/it]

------------------------------------------------------------
Epoch 7/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 23%|██▎       | 7/30 [01:07<03:00,  7.84s/it]

------------------------------------------------------------
Epoch 8/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 27%|██▋       | 8/30 [01:14<02:45,  7.50s/it]

------------------------------------------------------------
Epoch 9/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 30%|███       | 9/30 [01:21<02:34,  7.38s/it]

------------------------------------------------------------
Epoch 10/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 33%|███▎      | 10/30 [01:28<02:25,  7.27s/it]

------------------------------------------------------------
Epoch 11/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 37%|███▋      | 11/30 [01:34<02:12,  6.97s/it]

------------------------------------------------------------
Epoch 12/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 40%|████      | 12/30 [01:41<02:06,  7.02s/it]

------------------------------------------------------------
Epoch 13/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 43%|████▎     | 13/30 [01:48<01:58,  6.94s/it]

------------------------------------------------------------
Epoch 14/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 47%|████▋     | 14/30 [01:54<01:48,  6.76s/it]

------------------------------------------------------------
Epoch 15/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 50%|█████     | 15/30 [02:01<01:42,  6.85s/it]

------------------------------------------------------------
Epoch 16/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 53%|█████▎    | 16/30 [02:08<01:35,  6.82s/it]

------------------------------------------------------------
Epoch 17/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 57%|█████▋    | 17/30 [02:15<01:27,  6.73s/it]

------------------------------------------------------------
Epoch 18/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 60%|██████    | 18/30 [02:21<01:20,  6.73s/it]

------------------------------------------------------------
Epoch 19/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 63%|██████▎   | 19/30 [02:29<01:15,  6.86s/it]

------------------------------------------------------------
Epoch 20/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 67%|██████▋   | 20/30 [02:35<01:07,  6.77s/it]

------------------------------------------------------------
Epoch 21/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 70%|███████   | 21/30 [02:42<01:00,  6.73s/it]

------------------------------------------------------------
Epoch 22/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 73%|███████▎  | 22/30 [02:48<00:53,  6.65s/it]

------------------------------------------------------------
Epoch 23/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 77%|███████▋  | 23/30 [02:55<00:46,  6.65s/it]

------------------------------------------------------------
Epoch 24/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 80%|████████  | 24/30 [03:02<00:40,  6.83s/it]

------------------------------------------------------------
Epoch 25/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 83%|████████▎ | 25/30 [03:09<00:33,  6.78s/it]

------------------------------------------------------------
Epoch 26/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 87%|████████▋ | 26/30 [03:16<00:27,  6.86s/it]

------------------------------------------------------------
Epoch 27/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 90%|█████████ | 27/30 [03:22<00:20,  6.79s/it]

------------------------------------------------------------
Epoch 28/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 93%|█████████▎| 28/30 [03:29<00:13,  6.71s/it]

------------------------------------------------------------
Epoch 29/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


 97%|█████████▋| 29/30 [03:35<00:06,  6.62s/it]

------------------------------------------------------------
Epoch 30/30:
Train |     1/1    | loss:    0.0000 | acc:    1.0000   
Valid |     1/1    | loss:    0.0000 | acc:    1.0000   


100%|██████████| 30/30 [03:42<00:00,  7.42s/it]


In [13]:
torch.save(model.state_dict(), 'code.pt')