In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as tvt

In [None]:
import re
from pathlib import Path
from PIL import Image

In [None]:
def select_device(preferred_index=1):
    """Return a torch device string for a GPU that actually exists, else ``"cpu"``.

    Args:
        preferred_index: CUDA ordinal to use when the machine has that many GPUs.

    Returns:
        ``"cuda:<preferred_index>"`` when that ordinal is valid, ``"cuda:0"`` when
        CUDA is available but the preferred ordinal is out of range, and ``"cpu"``
        when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return "cpu"
    # is_available() only says that *some* GPU is usable; asking for "cuda:1" on a
    # single-GPU host would fail later, when the model is moved to the device.
    if 0 <= preferred_index < torch.cuda.device_count():
        return f"cuda:{preferred_index}"
    return "cuda:0"


device = select_device()
print(f"Using {device} device")

In [None]:
class HW4Dataset(Dataset):
    """Image-classification dataset read from a flat folder of image files.

    File stems are expected to contain ``<split>-<label>-<number>``. Only files
    whose ``<split>`` equals the requested ``dataset`` are kept, and ``<label>``
    is mapped to its position in :attr:`LABELS`.
    """

    LABELS = ['airplane', 'bus', 'cat', 'dog', 'pizza']
    # File-name convention, compiled once instead of on every lookup.
    _STEM_PATTERN = re.compile(r'(\w+)-(\w+)-(\d+)')

    def __init__(self, path, dataset, root='/home/tam') -> None:
        """Collect the image files that belong to one split.

        Args:
            path: Folder holding the images, joined onto ``root`` (pathlib
                keeps an absolute ``path`` as-is).
            dataset: Split name to keep, compared with the first field of each
                file stem (e.g. ``'train'``).
            root: Base directory that ``path`` is joined onto. Defaults to the
                location that used to be hard-coded.
        """
        super().__init__()
        # define a folder
        self.folder = Path(root) / path
        self.filenames = self._collect_filenames(self.folder, dataset)  # keep filename

        self.augment = tvt.Compose([
            tvt.ToTensor(),
            tvt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])

    @classmethod
    def _parse_stem(cls, stem):
        """Return ``(split, label)`` parsed from a file stem, or ``None`` if it does not match."""
        match = cls._STEM_PATTERN.search(stem)
        if match is None:
            return None
        return match.group(1), match.group(2)

    @classmethod
    def _collect_filenames(cls, folder, dataset):
        """Return the sorted files in ``folder`` whose stem names the ``dataset`` split."""
        selected = []
        for filename in Path(folder).iterdir():
            parsed = cls._parse_stem(filename.stem)
            # Stray entries (.DS_Store, README, sub-folders) carry no usable
            # split field; skip them instead of crashing on the first one.
            if parsed is None or parsed[0] != dataset or not filename.is_file():
                continue
            selected.append(filename)
        # iterdir() yields entries in arbitrary order; sorting makes indices reproducible.
        return sorted(selected)

    def __len__(self):
        """Return the number of files selected for this split."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.filenames``.

        Returns:
            ``(tensor, label)``: the result of ``self.augment`` applied to the
            RGB image, and the integer index of the file's label in
            :attr:`LABELS`.

        Raises:
            ValueError: If the label parsed from the file name is not in
                :attr:`LABELS`.
        """
        filename = self.filenames[index]
        # The context manager closes the file handle once the pixels have been
        # copied into the tensor, so long runs do not leak descriptors.
        with Image.open(filename) as image:
            rgb = image if image.mode == 'RGB' else image.convert(mode='RGB')
            tensor = self.augment(rgb)
        label = self._parse_stem(filename.stem)[1]
        return tensor, self.LABELS.index(label)

In [None]:
# Build the 'train' split from the homework image folder.
dataset = HW4Dataset(
    path='git/ece60146/data/hw4_dataset',
    dataset='train',
)

In [None]:
# Evaluate the collected file list; the trailing ';' keeps the notebook from echoing it.
dataset.filenames;

In [None]:
# Show how many image files were picked up for this split.
sample_count = len(dataset)
print(sample_count)

In [None]:
# Serve the dataset as shuffled mini-batches of 8 samples.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
)

In [None]:
# Show how many mini-batches one pass over the loader yields.
batch_count = len(dataloader)
print(batch_count)

In [None]:
# Smoke check: pull a single batch through the loader (which calls
# HW4Dataset.__getitem__); the trailing ';' suppresses the notebook echo.
next(iter(dataloader));

In [None]:
class HW4Net(nn.Module):
    """Two conv/pool stages followed by a two-layer classifier head.

    Args:
        image_size: Spatial size of the input images, either one int for
            square inputs or an ``(height, width)`` pair. The default of 64
            reproduces the original 6272-feature flatten size.
        num_classes: Number of output logits. Defaults to 5.

    Raises:
        ValueError: If the input is too small to survive both conv/pool stages.
    """

    def __init__(self, image_size=64, num_classes=5) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)

        # Derive the flatten size from the input size instead of hard-coding
        # 6272, which is only correct for 64x64 images.
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        height, width = (self._after_two_stages(side) for side in image_size)
        if height < 1 or width < 1:
            raise ValueError(
                f'image_size {tuple(image_size)} is too small for two conv/pool stages'
            )
        self.fc1 = nn.Linear(32 * height * width, 64)
        self.fc2 = nn.Linear(64, num_classes)

    @staticmethod
    def _after_two_stages(side):
        """Return the spatial extent left after two (3x3 unpadded conv, 2x2 stride-2 pool) stages."""
        for _ in range(2):
            side = (side - 2) // 2
        return side

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image batch to ``(batch, num_classes)`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything but the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Instantiate the network on the selected device, together with its loss and optimizer.
model = HW4Net()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.9, 0.99))

In [None]:
# One pass over the training loader, reporting the mean loss of every 100 batches.
model.train()
running_loss = 0
for batch, data in enumerate(dataloader):
    # Each batch is an (images, labels) pair; move both to the compute device.
    images, labels = (item.to(device) for item in data)
    prediction = model(images)
    loss = loss_fn(prediction, labels)

    # Clear stale gradients, back-propagate, then update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if (batch+1) % 100 != 0:
        continue
    print(f'Batch {batch+1:4}: Loss = {running_loss/100:5.3f}')
    running_loss = 0