In [None]:
import random
from random import choice, randint

import torch
import torch.nn as nn

from torchvision import models, utils
from torchvision.datasets import ImageFolder, FakeData
from torchvision.transforms.v2 import Compose, Resize, ToImage, ToDtype

### Configure the Dataset

In [None]:
# Image geometry shared by the background generator and the card resize below.
channel = 3
width = 224
height = 224

# Seed both RNGs used in this notebook: Python's `random` (background choice and
# paste position) and torch's (DataLoader shuffling, new layer initialization).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Pool of synthetic noise images used as backgrounds. FakeData takes its
# image_size as (C, H, W), so height must come before width.
fake_images = FakeData(1000, (channel, height, width))

class RandomBackground(torch.nn.Module):
    """Paste an image onto a randomly chosen background image.

    Args:
        backgrounds: indexable, sized collection whose items are
            ``(image, label)`` pairs (e.g. a torchvision ``FakeData``). When
            ``None`` (the default) the module-level ``fake_images`` dataset is
            used, looked up each time the transform is called.
    """

    def __init__(self, backgrounds=None):
        super().__init__()
        self.backgrounds = backgrounds

    def forward(self, img):
        """Return a copy of a random background with ``img`` pasted onto it.

        ``img`` and the backgrounds must be PIL-like images: they need
        ``width``/``height`` attributes and ``copy``/``paste`` methods.
        """
        source = fake_images if self.backgrounds is None else self.backgrounds

        # Work on a copy so a background that is stored (rather than generated
        # fresh on each access) is never modified in place across calls.
        background = choice(source)[0].copy()

        # Pick a random top-left corner that keeps at least the top-left quarter
        # of the image on the background. Clamp at 0 so an oversized image is
        # pinned to the corner instead of making randint raise ValueError.
        max_x = max(0, background.width - img.width // 2)
        max_y = max(0, background.height - img.height // 2)
        x, y = randint(0, max_x), randint(0, max_y)

        # Put the card on the background with its top-left corner at (x, y).
        background.paste(img, (x, y))
        return background


# Root folder of the card images (ImageFolder expects one sub-folder per class)
# and the mini-batch size shared by the loaders below.
DATA_DIR = '../data/bicycle_cards'
BATCH_SIZE = 8

add_background = Compose([
    # Resize the card to 3/4 of the background height and half of its width
    # (Resize takes (height, width)) so it fits on the background image.
    Resize((height * 3 // 4, width // 2)),

    # Add a random background to the image
    RandomBackground(),

    # Convert the image to a tensor scaled between 0 and 1
    ToImage(),
    ToDtype(torch.float32, scale=True),
])

cards_with_background = ImageFolder(DATA_DIR, transform=add_background)
train_loader = torch.utils.data.DataLoader(cards_with_background, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): validation reuses the same card images as training; only the random
# background and paste position differ, so validation accuracy measures robustness
# to backgrounds rather than generalization to unseen cards.
val_loader = torch.utils.data.DataLoader(cards_with_background, batch_size=BATCH_SIZE, shuffle=False)

### Configure Model

In [None]:
import lightning as L
import torch.optim as optim

class lightningClassifier(L.LightningModule):
    """Lightning wrapper that trains ``model`` as a multi-class classifier.

    Args:
        model: network mapping a batch of images to per-class logits.
        lr: learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, model, lr: float = 1e-3) -> None:
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Return (and log) the cross-entropy loss of one ``(images, labels)`` batch."""
        x, y = batch
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy of one ``(images, labels)`` batch; return the accuracy."""
        x, y = batch
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        # Fraction of correct predictions in this batch. Averaging over the actual
        # batch (instead of dividing by a fixed 52) keeps the value in [0, 1] for
        # any batch size, including a smaller final batch.
        acc = (logits.argmax(-1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("acc", acc, prog_bar=True)
        return acc

    def configure_optimizers(self):
        """Optimize all parameters with Adam at learning rate ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

# Backbone: ResNet-18 initialised from its ImageNet-1k (V1) pretrained weights.
# NOTE(review): the input pipeline does not apply ImageNet mean/std normalization;
# confirm this is intentional for fine-tuning these weights.
model_ft = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Swap the final fully-connected layer for one with one output per card class.
num_ftrs = model_ft.fc.in_features
n_card_classes = len(cards_with_background.classes)
model_ft.fc = nn.Linear(num_ftrs, n_card_classes)

# Lightning module that drives training/validation of the network above.
classifier = lightningClassifier(model_ft)

### Train Model

In [None]:
# Train for a fixed number of epochs. Passing val_loader is what makes Lightning
# run validation_step (and log "acc"); without it validation never happens.
MAX_EPOCHS = 100
trainer = L.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(model=classifier, train_dataloaders=train_loader, val_dataloaders=val_loader)

### Visualize Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Inference device: the first CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): neither `device` nor `data_loader` is used by the cells below.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Shuffled loader over the card dataset in batches of 24, for browsing examples.
data_loader = torch.utils.data.DataLoader(
    cards_with_background,
    batch_size=24,
    shuffle=True,
)

def imshow(inp, title=None):
    """Display a channel-first (C, H, W) image tensor with matplotlib.

    The tensor may live on any device and may require grad: it is detached and
    copied to the CPU before conversion to NumPy. Pixel values are clipped to
    [0, 1] before display.

    Args:
        inp: image tensor in (C, H, W) layout.
        title: optional title for the current axes.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

def visualize_model(model, num_images=6, data_loader=None, class_names=None):
    """Plot ``num_images`` images from a loader, each titled with the model's prediction.

    Args:
        model: classifier returning one logit per class for a batch of images.
        num_images: how many images to draw, laid out two per row.
        data_loader: loader yielding ``(images, labels)`` batches; defaults to the
            global ``val_loader``.
        class_names: sequence mapping class index to display name; defaults to
            ``cards_with_background.classes``.

    The model is put in eval mode for inference, and its previous train/eval
    mode is restored on exit, even if plotting raises.
    """
    if num_images <= 0:
        return
    loader = val_loader if data_loader is None else data_loader
    names = cards_with_background.classes if class_names is None else class_names

    # Send inputs to whichever device holds the model's weights, so this also
    # works when the model lives on a GPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # ceil(num_images / 2) rows: num_images // 2 is one row short for odd counts,
    # which makes plt.subplot raise on the last image.
    n_rows = (num_images + 1) // 2

    was_training = model.training
    model.eval()
    plt.figure()
    images_so_far = 0
    try:
        with torch.no_grad():
            for inputs, _ in loader:
                outputs = model(inputs.to(model_device))
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {names[preds[j].item()]}')
                    imshow(inputs[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

# Show one validation batch as a single image grid, titled with its true labels
# (the grid was previously built but never displayed).
inputs, classes = next(iter(val_loader))
out = utils.make_grid(inputs)
plt.figure(figsize=(16, 4))
imshow(out, title=", ".join(cards_with_background.classes[c] for c in classes.tolist()))

# Then plot individual images with the model's predictions.
visualize_model(model_ft, 6)