In [None]:
import os
import random
import weakref

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [None]:
# Class name -> integer training target; indices follow the order of this tuple.
label_mapping = {
    name: index
    for index, name in enumerate(
        ('climbing', 'diving', 'fishing', 'racing', 'throwing', 'pole vaulting')
    )
}

In [None]:
def select_random_ints(l_bound, u_bound, quantity):
    """Draw ``quantity`` distinct integers uniformly from the closed range [l_bound, u_bound].

    Uses the module-level ``random`` generator, so results are reproducible after
    ``random.seed``. ``random.sample`` raises ValueError if ``quantity`` is negative or
    larger than the range.

    Returns:
        A set of the drawn integers.
    """
    candidates = range(l_bound, u_bound + 1)  # +1: the upper bound is inclusive
    picked = random.sample(candidates, quantity)
    return set(picked)

In [None]:
# Data locations.
train_dir = './data/train'
test_dir = './data/test'

# Seed every RNG the notebook draws from (validation-split sampling via `random`,
# weight initialization and DataLoader shuffling via torch) so that Restart & Run All
# produces the same split and the same initial model each time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Per-split file paths and integer labels, filled in by the following cells.
validation_image_paths = []
validation_image_labels = []
train_image_paths = []
train_image_labels = []
test_image_paths = []
test_image_labels = []

In [None]:
# Hold out about 20% of each class from train_dir as the validation split.
# label -> (largest image number in train_dir, how many images to hold out).
# Files are expected to be named '<label>_<number>.jpg' and numbered from 0; the
# assertion at the end of this cell checks that every sampled file actually exists.
validation_spec = {
    'climbing': (325, 65),
    'diving': (519, 104),
    'fishing': (162, 32),
    'racing': (335, 67),
    'throwing': (316, 63),
    'pole vaulting': (278, 55),
}

# Sampled in the table's order, so a given seed always yields the same split.
validation_nums_by_label = {
    label: select_random_ints(0, max_num, count)
    for label, (max_num, count) in validation_spec.items()
}

# Per-class names used by the training-split cell to exclude held-out images.
climbing_validation_nums = validation_nums_by_label['climbing']
diving_validation_nums = validation_nums_by_label['diving']
fishing_validation_nums = validation_nums_by_label['fishing']
racing_validation_nums = validation_nums_by_label['racing']
throwing_validation_nums = validation_nums_by_label['throwing']
vaulting_validation_nums = validation_nums_by_label['pole vaulting']

# Rebuild the lists from scratch: appending to lists from an earlier run would duplicate
# entries and, after re-sampling, leak former validation images into training.
validation_image_paths = []
validation_image_labels = []
for label, nums in validation_nums_by_label.items():
    for num in nums:
        validation_image_paths.append(os.path.join(train_dir, f'{label}_{num}.jpg'))
        validation_image_labels.append(label_mapping[label])

# Fail now rather than with a FileNotFoundError after training has already run.
missing = [path for path in validation_image_paths if not os.path.isfile(path)]
assert not missing, f'{len(missing)} sampled validation files do not exist, e.g. {missing[:3]}'

In [None]:
# Training split: every image in train_dir whose number was not held out for validation.
# The lists are rebuilt here so re-running this cell cannot append duplicates, and the
# listing is sorted so the dataset order (and therefore seeded shuffling) is reproducible.
train_image_paths = []
train_image_labels = []

held_out_nums = {
    'climbing': climbing_validation_nums,
    'diving': diving_validation_nums,
    'fishing': fishing_validation_nums,
    'racing': racing_validation_nums,
    'throwing': throwing_validation_nums,
    'pole vaulting': vaulting_validation_nums,
}

for filename in sorted(os.listdir(train_dir)):
    # Expected form: '<label>_<number>.<ext>', e.g. 'pole vaulting_12.jpg'.
    label, _, number_str = os.path.splitext(filename)[0].rpartition('_')
    if label not in held_out_nums or not number_str.isdigit():
        continue  # not an image of a known class (stray or malformed file)
    if int(number_str) in held_out_nums[label]:
        continue  # reserved for the validation split
    train_image_paths.append(os.path.join(train_dir, filename))
    train_image_labels.append(label_mapping[label])

In [None]:
# Test split: every image in test_dir, labeled by the '<label>' prefix of its
# '<label>_<number>.<ext>' file name. Rebuilt from scratch so re-running cannot
# duplicate entries, and sorted so the test-set order is reproducible.
test_image_paths = []
test_image_labels = []

for filename in sorted(os.listdir(test_dir)):
    label = os.path.splitext(filename)[0].rpartition('_')[0]
    if label not in label_mapping:
        raise ValueError(f"Cannot infer a known class from test file name {filename!r}")
    test_image_paths.append(os.path.join(test_dir, filename))
    test_image_labels.append(label_mapping[label])

In [None]:
# Preprocessing shared by every split: resize to 224x224, convert to a float tensor,
# then normalize each RGB channel with the standard ImageNet mean/std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
class BARImageDataset(Dataset):
    """Map-style dataset pairing image files with integer class labels.

    Args:
        image_paths: sequence of image file paths.
        image_labels: sequence of integer labels, aligned index-for-index with ``image_paths``.
        transform: optional callable applied to each RGB ``PIL.Image``; when ``None`` the
            PIL image itself is returned.

    Raises:
        ValueError: if ``image_paths`` and ``image_labels`` differ in length.
    """

    def __init__(self, image_paths, image_labels, transform=None):
        if len(image_paths) != len(image_labels):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(image_labels)} labels"
            )
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is converted to RGB first."""
        # The context manager closes the file handle deterministically; convert() returns a
        # new image that remains usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.image_labels[idx]

In [None]:
# Wrap each split in a Dataset and a DataLoader, all with the same preprocessing.
BATCH_SIZE = 32

train_dataset = BARImageDataset(train_image_paths, train_image_labels, transformer)
validation_dataset = BARImageDataset(validation_image_paths, validation_image_labels, transformer)
test_dataset = BARImageDataset(test_image_paths, test_image_labels, transformer)

# Only training benefits from shuffling. Evaluation loaders keep a fixed order so their
# results are reproducible and they do not draw from the global torch RNG.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, add shortcut, ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (and therefore of the whole block).
        downsample: optional module applied to the input to build the shortcut; it must be
            given whenever the input's shape differs from the main path's output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Convolutions are bias-free because each is immediately followed by BatchNorm.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + shortcut)

In [None]:
class ResNet18(nn.Module):
    """ResNet-18 classifier assembled from ``ResidualBlock``.

    Layout: a stem (7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool), four stages
    ``layer1``..``layer4`` of two blocks each with 64/128/256/512 channels (stages 2-4 start
    with stride 2), global average pooling, and a linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        # Channels of the feature map entering the next stage; _make_layer reads and advances it.
        self.in_channels = 64

        # Stem: two stride-2 ops reduce the spatial size 4x before the residual stages.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage; only its first block may change stride/channels (via a 1x1 projection)."""
        needs_projection = stride != 1 or self.in_channels != out_channels
        projection = (
            nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

        head = ResidualBlock(self.in_channels, out_channels, stride, projection)
        self.in_channels = out_channels
        tail = [ResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [None]:
class GradCAM:
    """Grad-CAM heatmaps for one named layer of a model.

    Constructing an instance hooks the layer ``target_layer_name``. On every forward pass
    through that layer its output is stored (detached) in ``self.activations``; when the
    output requires grad, a tensor hook stores the gradient that later flows back into it
    in ``self.gradients``. Typical use::

        cam = GradCAM(model, 'layer4')
        scores = model(x)            # x: a batch holding one image
        scores[0, k].backward()
        heatmap = cam.generate_heatmap(k)

    The module hook holds only a weak reference to the instance and is unregistered when the
    instance is garbage-collected, by ``remove()``, or on leaving a ``with`` block. Creating a
    fresh GradCAM per call therefore does not pile stale hooks onto the model.

    Args:
        model: the network to explain.
        target_layer_name: a name from ``model.named_modules()``, e.g. ``'layer4'``.

    Raises:
        KeyError: if ``target_layer_name`` is not a submodule name of ``model``.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer = dict(model.named_modules())[target_layer_name]
        self.gradients = None
        self.activations = None

        # The model must not keep this object alive through the hook; otherwise every GradCAM
        # ever created would stay registered and keep firing on each later forward pass.
        self_ref = weakref.ref(self)

        def forward_hook(module, input, output):
            cam = self_ref()
            if cam is not None:
                cam._forward_hook(module, input, output)

        self._hook_handle = self.target_layer.register_forward_hook(forward_hook)

    def _forward_hook(self, module, input, output):
        """Store the layer output and arrange for its gradient to be captured."""
        self.activations = output.detach()
        # A tensor hook on the output replaces the deprecated Module.register_backward_hook.
        # Skip it when the output does not require grad (e.g. under torch.no_grad()), where
        # register_hook would raise.
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def _save_gradient(self, grad):
        """Tensor hook: keep the gradient of the back-propagated score w.r.t. the layer output."""
        self.gradients = grad.detach()

    def remove(self):
        """Unregister the hook from the target layer. Safe to call more than once."""
        handle = getattr(self, '_hook_handle', None)  # absent if __init__ failed early
        if handle is not None:
            handle.remove()
            self._hook_handle = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove()

    def __del__(self):
        self.remove()

    def generate_heatmap(self, class_idx):
        """Combine the captured activations and gradients into a Grad-CAM heatmap.

        Args:
            class_idx: accepted for API compatibility but unused; the explained class is
                whichever score was back-propagated before this call.

        Returns:
            A numpy float array (the layer's spatial map, size-1 dims squeezed) scaled into
            [0, 1]. An all-constant map (e.g. everything removed by the ReLU) is returned as
            zeros instead of dividing by zero.

        Raises:
            RuntimeError: if no forward pass and backward pass have been captured yet.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM has no activations/gradients yet: run the model forward and call "
                "backward() on a class score before generate_heatmap()."
            )

        with torch.no_grad():
            # Channel weights = gradients averaged over the two spatial dims.
            weights = self.gradients.mean(dim=(2, 3), keepdim=True)
            heatmap = F.relu((weights * self.activations).sum(dim=1)).squeeze()

            heatmap = heatmap - heatmap.min()
            peak = heatmap.max()
            if peak > 0:
                heatmap = heatmap / peak

        return heatmap.cpu().numpy()


In [None]:
def overlay_heatmap(heatmap, image_path, alpha=0.5):
    """Blend a jet-colored heatmap over the image stored at ``image_path``.

    Args:
        heatmap: 2-D float array, expected in [0, 1] (e.g. from ``GradCAM.generate_heatmap``).
            NaNs are treated as 0 and values are clipped to [0, 1] before quantizing.
        image_path: path of the image to draw on; the heatmap is resized to its size.
        alpha: weight of the heatmap in the blend (0 = image only, 1 = heatmap only).

    Returns:
        An RGB ``PIL.Image`` with the same size as the original image.
    """
    with Image.open(image_path) as img:
        original_image = img.convert("RGB")

    # Sanitize before the uint8 cast: NaN or out-of-range floats would otherwise become
    # arbitrary byte values.
    safe_heatmap = np.clip(np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0), 0.0, 1.0)
    heatmap_resized = Image.fromarray(np.uint8(safe_heatmap * 255)).resize(
        original_image.size, Image.BILINEAR
    )

    # Map intensities through the jet colormap and drop its alpha channel.
    heatmap_colored = plt.cm.jet(np.asarray(heatmap_resized) / 255.0)[:, :, :3] * 255
    heatmap_colored = Image.fromarray(np.uint8(heatmap_colored))

    return Image.blend(original_image, heatmap_colored, alpha=alpha)

In [None]:
def generate_and_visualize_heatmap(image_path, model, target_layer_name, class_idx=None,
                                   transform=None):
    """Plot an image next to its Grad-CAM overlay.

    Args:
        image_path: path of the image to explain.
        model: trained classifier; this call switches it to eval mode.
        target_layer_name: layer to explain, a name from ``model.named_modules()``
            (e.g. ``'layer4'``).
        class_idx: class whose score is explained; ``None`` (default) explains the
            model's top-scoring class.
        transform: callable turning the RGB ``PIL.Image`` into an image tensor (a batch
            dimension is added here). ``None`` (default) uses the notebook-level
            ``transformer``, the same preprocessing the datasets apply, so the model sees
            the image exactly as it saw training and evaluation inputs.
    """
    model.eval()
    if transform is None:
        transform = transformer

    # Run on whichever device the model lives on, rather than relying on a global
    # `device` that is defined in a later cell.
    device = next(model.parameters()).device

    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    grad_cam = GradCAM(model, target_layer_name)

    # Grad-CAM needs gradients, so this forward pass is deliberately not under no_grad().
    outputs = model(input_tensor)
    if class_idx is None:
        class_idx = outputs.argmax(dim=1).item()

    model.zero_grad()
    outputs[0, class_idx].backward()

    heatmap = grad_cam.generate_heatmap(class_idx)
    overlayed_image = overlay_heatmap(heatmap, image_path)

    _, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))
    ax_image.set_title("Original Image")
    ax_image.imshow(image)
    ax_image.axis("off")

    ax_overlay.set_title("Heat Map Overlay")
    ax_overlay.imshow(overlayed_image)
    ax_overlay.axis("off")
    plt.show()

In [None]:
# ResNet-18 trained from scratch, with one output logit per class in label_mapping
# (derived from the mapping rather than hard-coded, so the two cannot drift apart).
model = ResNet18(num_classes=len(label_mapping))

# SGD with momentum over all model parameters.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# Takes raw logits plus integer class targets.
criterion = nn.CrossEntropyLoss()

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [None]:
num_epochs = 100

if len(train_loader) == 0:
    raise RuntimeError("train_loader is empty; check train_dir and the split cells")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    model.train()

    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by batch size so every sample counts
        # equally, including those in a smaller final batch.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    print(f"Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")


In [None]:
# Evaluate the trained model on the held-out validation split.
model.eval()

val_running_loss = 0.0  # sum of per-sample losses
val_correct = 0
val_total = 0

with torch.no_grad():
    for images, labels in validation_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # criterion returns the batch mean; weight by batch size so every sample counts equally.
        batch_size = labels.size(0)
        val_running_loss += loss.item() * batch_size
        val_total += batch_size
        val_correct += (outputs.argmax(dim=1) == labels).sum().item()

if val_total == 0:
    raise RuntimeError("validation_loader produced no samples")

val_loss = val_running_loss / val_total
val_acc = 100 * val_correct / val_total

print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")

In [None]:
# Final evaluation on the untouched test split.
model.eval()

test_running_loss = 0.0  # sum of per-sample losses
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # criterion returns the batch mean; weight by batch size so every sample counts equally.
        batch_size = labels.size(0)
        test_running_loss += loss.item() * batch_size
        test_total += batch_size
        test_correct += (outputs.argmax(dim=1) == labels).sum().item()

if test_total == 0:
    raise RuntimeError("test_loader produced no samples")

test_loss = test_running_loss / test_total
test_acc = 100 * test_correct / test_total

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

In [None]:
# Grad-CAM for one test image: explain the model's top prediction through the last
# residual stage.
target_layer_name = 'layer4'
image_path = './data/test/diving_606.jpg'
generate_and_visualize_heatmap(image_path, model, target_layer_name=target_layer_name)