In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as tvt
from torchinfo import summary
from torchvision.ops import generalized_box_iou_loss

In [None]:
import json
import os
import re
from pathlib import Path

from PIL import Image

In [None]:
# Prefer the second GPU (the original choice), but fall back to the first GPU
# when only one is visible -- "cuda:1" does not exist on single-GPU machines --
# and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
class HW5Dataset(Dataset):
    '''
    Image dataset for HW5: one class label and one bounding box per image.

    Directory layout read by ``__init__``::

        <path>/hw5_dataset/metadata.json   # maps image stem -> dict with 'label' and 'bbox01'
        <path>/hw5_dataset/no_box/         # raw images, stems like '<split>-<number>'

    Only images whose ``<split>`` prefix equals ``dataset`` (e.g. ``'training'``)
    are kept; files whose stem does not contain ``<word>-<digits>`` (e.g. hidden
    files such as ``.DS_Store``) are ignored. Files are kept in sorted order so
    indexing is reproducible across runs and machines.

    Each item is a dict with keys:
        'filename': str path of the image (for debugging),
        'image':    3xHxW float tensor, normalized from [0, 1] to [-1, 1],
        'bbox':     float tensor [x1, y1, x2, y2] built from 'bbox01'
                    (presumably normalized to [0, 1], given the name -- confirm),
        'label':    int index into ``LABELS``.
    '''
    LABELS = ['bus', 'cat', 'pizza'] # Labels for this task
    # Nominal image size. NOTE(review): no resize is applied below, so images are
    # presumably stored pre-resized to this size -- confirm against the data.
    TARGET_SIZE = 256
    # Stem pattern '<split>-<number>'; group 1 is the split name.
    _STEM_PATTERN = re.compile(r'(\w+)-(\d+)')

    def __init__(self, path, dataset) -> None:
        '''
        Args:
            path: Directory that contains ``hw5_dataset/`` (``Path`` or str).
            dataset: Split name to keep, matched against each image stem's prefix.
        '''
        super().__init__()
        path = Path(path) / 'hw5_dataset'
        # Ground-truth bboxes and labels, keyed by image stem
        with open(path / 'metadata.json') as fp:
            self.meta = json.load(fp)
        self.image_folder = path / 'no_box' # Location for raw images
        self.filenames = self._select_files(self.image_folder.iterdir(), dataset)
        self.augment = tvt.Compose([
            tvt.ToTensor(), # Convert to tensor in [0, 1]
            tvt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # Map to [-1, 1]
        ])

    @classmethod
    def _split_of(cls, stem):
        '''Return the split prefix of an image stem, or None when the stem does not match.'''
        match = cls._STEM_PATTERN.search(stem)
        return match.group(1) if match else None

    @classmethod
    def _select_files(cls, files, dataset):
        '''Return the paths in ``files`` that belong to split ``dataset``, sorted.'''
        return sorted(f for f in files if cls._split_of(f.stem) == dataset)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]
        # Close the file handle promptly, and force 3 channels: grayscale or RGBA
        # images would otherwise break the 3-channel Normalize above and the
        # 3-input-channel convolutions downstream.
        with Image.open(filename) as image:
            tensor = self.augment(image.convert('RGB'))
        meta = self.meta[filename.stem]
        label = self.LABELS.index(meta['label']) # Read label
        return {
            'filename': str(filename), # For debug
            'image': tensor,
            # It comes as [[x1, x2], [y1, y2]] -> transpose then flatten
            # So it will be [x1, y1, x2, y2]
            'bbox': torch.tensor(meta['bbox01'], dtype=torch.float).T.flatten(),
            'label': label
        }

In [None]:
# Folder that contains `hw5_dataset/`. Set the ECE60146_DATA environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_DIR = Path(os.environ.get('ECE60146_DATA', '/home/tam/git/ece60146/data'))
dataset = HW5Dataset(DATA_DIR, 'training')

In [None]:
# Mini-batches of 8 samples drawn in random order.
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

In [None]:
class HW4Net(nn.Module):
    '''
    Small two-conv CNN with a classification head and a bounding-box head.

    ``forward(x)`` takes a (B, 3, input_size, input_size) batch and returns
    ``(logits, boxes)``:
      * ``logits``: (B, num_classes) raw class scores (for ``nn.CrossEntropyLoss``).
      * ``boxes``:  (B, 4) boxes ``[x1, y1, x2, y2]`` in [0, 1] with x1 <= x2 and
        y1 <= y2, the well-formed format ``generalized_box_iou_loss`` expects.

    Args:
        input_size: Side length of the square input images (default 256). It sizes
            ``fc1``; 256 gives the original 32 * 62 * 62 = 123008 features.
        num_classes: Number of classification outputs (default 3).

    Raises:
        ValueError: If ``input_size`` leaves no spatial extent after the conv/pool stack.
    '''
    def __init__(self, input_size: int = 256, num_classes: int = 3) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        # Spatial side after conv1 (3x3, no padding) -> pool -> conv2 -> pool.
        side = ((input_size - 2) // 2 - 2) // 2
        if side < 1:
            raise ValueError(f'input_size={input_size} is too small for HW4Net')
        self.fc1 = nn.Linear(32 * side * side, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.fc3 = nn.Linear(64, 4)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        xc = self.fc2(x)
        # Squash coordinates into [0, 1] (assumes normalized box targets) and order
        # each corner pair so every predicted box is well-formed for the GIoU loss.
        raw = torch.sigmoid(self.fc3(x))
        x1 = torch.minimum(raw[:, 0], raw[:, 2])
        x2 = torch.maximum(raw[:, 0], raw[:, 2])
        y1 = torch.minimum(raw[:, 1], raw[:, 3])
        y2 = torch.maximum(raw[:, 1], raw[:, 3])
        xb = torch.stack([x1, y1, x2, y2], dim=1)
        return xc, xb
# One loss per head: cross-entropy for the class logits, GIoU for the boxes.
loss_fn_c = nn.CrossEntropyLoss()
loss_fn_b = generalized_box_iou_loss
# Baseline network on the selected device, optimized with Adam.
model = HW4Net().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.9, 0.99))
# Layer-by-layer shapes for a batch of 8 RGB 256x256 images.
summary(model, input_size=(8, 3, 256, 256))

In [None]:
model.train() # Set to training mode
total_loss_c = 0.0 # Classification loss, summed over batches
total_loss_b = 0.0 # Box-regression (GIoU) loss, summed over batches
for data in dataloader:
    images = data['image'].to(device)
    labels = data['label'].to(device)
    bboxes = data['bbox'].to(device)
    pred_labels, pred_boxes = model(images) # Get prediction
    loss_c = loss_fn_c(pred_labels, labels) # Calculate classification loss
    # generalized_box_iou_loss defaults to reduction='none' (one value per box);
    # reduce to a scalar so .item() and .backward() work for batches > 1.
    loss_b = loss_fn_b(pred_boxes, bboxes, reduction='mean')
    optimizer.zero_grad() # Reset gradient
    # A single backward over the summed loss yields the same gradients as two
    # separate backward calls, without keeping the graph alive via retain_graph.
    (loss_c + loss_b).backward()
    optimizer.step() # Update parameters
    total_loss_c += loss_c.item()
    total_loss_b += loss_b.item()
# Average loss over all batches (guard against an empty dataloader)
num_batches = max(len(dataloader), 1)
total_loss_c /= num_batches
total_loss_b /= num_batches
total_loss_c, total_loss_b

In [None]:
class HW5Net(nn.Module):
    '''
    Convolutional feature extractor for HW5 (backbone only; no output heads yet).

    Stack: reflection padding + 7x7 conv (preserves H and W), then
    ``n_downsampling`` stride-2 3x3 convs, each halving H and W (rounding up)
    and doubling the channel count. Every conv is followed by BatchNorm + ReLU.

    With the defaults, a (B, 3, H, W) input yields a
    (B, 256, ceil(H/16), ceil(W/16)) feature map.

    Args:
        in_channels: Channels of the input images (default 3).
        base_channels: Channels produced by the first conv (default 16).
        n_downsampling: Number of stride-2 stages (default 4).
    '''
    def __init__(self, in_channels: int = 3, base_channels: int = 16,
                 n_downsampling: int = 4) -> None:
        super().__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base_channels, 7),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True)
        ]
        for i in range(n_downsampling):
            mult = 2**i
            layers.extend([
                nn.Conv2d(base_channels*mult, base_channels*mult*2, 3, stride=2, padding=1),
                nn.BatchNorm2d(base_channels*mult*2),
                nn.ReLU(True)
            ])
        # Attribute name kept as `model` so existing state_dict keys still match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        '''Return the feature map computed by the conv stack.'''
        return self.model(x)

# Replace the notebook's `model` with the HW5 backbone and show its layer shapes
# for a batch of 8 RGB 256x256 images.
model = HW5Net()
model = model.to(device)
summary(model, input_size=(8, 3, 256, 256))

In [None]:
# NOTE: this counts parameter tensors (every weight and bias separately), not
# layers in the architectural sense; the printed label is kept unchanged.
num_layers = sum(1 for _ in model.parameters())
print("\nThe number of layers in the model: %d\n\n" % num_layers)