In [4]:
from pathlib import Path
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import transforms


device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_root = Path('data')
data_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

data = ImageFolder(data_root, transform=data_transform)
print(type(data))
print(data.classes)
print(data.class_to_idx)


train_dataset, test_dataset = torch.utils.data.random_split(data, [0.8, 0.2])


<class 'torchvision.datasets.folder.ImageFolder'>
['Cat', 'Dog']
{'Cat': 0, 'Dog': 1}


In [169]:
class tinyVGG(nn.Module):
    def __init__(self, in_shape, hidden_units, out_shape):
        super().__init__()
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_shape, 
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1),
            nn.ReLU()
        )
        self.second_conv = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.third_conv = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=13 * 13 * hidden_units, 
                      out_features=out_shape)
        )

    def forward(self, x):
        return self.classifier((self.third_conv(self.second_conv(self.first_conv(x)))))

In [170]:
from torch.utils.data.dataloader import DataLoader

from torchinfo import summary

model = tinyVGG(3, 10, 1).to(device)

train_dataloader = DataLoader(train_dataset, 1, True)
test_dataloader = DataLoader(test_dataset, 1, False)


loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
summary(model)



Layer (type:depth-idx)                   Param #
tinyVGG                                  --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       280
│    └─ReLU: 2-2                         --
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       910
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
├─Sequential: 1-3                        --
│    └─Conv2d: 2-6                       910
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       910
│    └─ReLU: 2-9                         --
│    └─MaxPool2d: 2-10                   --
├─Sequential: 1-4                        --
│    └─Flatten: 2-11                     --
│    └─Linear: 2-12                      1,691
Total params: 4,701
Trainable params: 4,701
Non-trainable params: 0

In [171]:
import tqdm

model.train()
for epoch in tqdm.tqdm(range(20)):
    for n, (X, y) in enumerate(train_dataloader):
        X = X.to(device)
        y = torch.Tensor(y).to(device).type(torch.float32)
        pred_logits = model(X).squeeze(dim=0)

        loss = loss_fn(pred_logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

100%|██████████| 20/20 [23:52<00:00, 71.64s/it]


In [173]:
model.eval()
with torch.inference_mode():
    train_acc = 0
    for n, (X, y) in enumerate(test_dataloader):
        X = X.to(device)
        y = torch.Tensor(y).to(device).type(torch.float32)
        pred_logits = model(X)
        pred_class = torch.round(torch.sigmoid(pred_logits))
        train_acc += (pred_class==y).sum().item()/len(pred_logits)
        if n == 1000:
            break
    print(train_acc)

798.0


In [167]:
from pathlib import Path

models_path = Path('models')

model_name = '79acc.pth'

model_path = models_path / model_name
torch.save(obj=model.state_dict(), f=model_path)