In [9]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms
from torchvision.models.resnet import BasicBlock
from tqdm.notebook import tqdm


In [10]:
class MACNNLoss(nn.Module):
    """Cross-entropy plus a pairwise cosine-similarity penalty on attention outputs.

    For every unordered pair of attention heads (i, j) the batch-mean cosine
    similarity ``s_ij`` of their flattened outputs is computed, and the loss is

        CE + div_weight * sum_ij(s_ij) + abs_weight * sum_ij(|s_ij|)

    which pushes the heads to attend to different things. With three heads this
    is exactly the previous hard-coded formula; with more heads every pair is
    now included (previously only heads 0-2 were compared).

    Args:
        div_weight: weight on the signed similarity sum (default 0.33).
        abs_weight: weight on the absolute similarity sum (default 0.15).
    """

    def __init__(self, div_weight=0.33, abs_weight=0.15):
        super(MACNNLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.div_weight = div_weight
        self.abs_weight = abs_weight

    def forward(self, predicitions, labels, spatial_attention):
        """Compute the combined loss.

        Args:
            predicitions: logits for ``nn.CrossEntropyLoss`` (parameter name kept
                unchanged for backward compatibility).
            labels: class-index targets for ``nn.CrossEntropyLoss``.
            spatial_attention: sequence of per-head tensors. Each is flattened
                from dim 1 onward, so all heads must flatten to the same length.
                Fewer than two heads contribute no penalty.

        Returns:
            Scalar loss tensor.
        """
        entropy_loss = self.cross_entropy_loss(predicitions, labels)
        flat_maps = [attention.flatten(1) for attention in spatial_attention]

        sum_diff = entropy_loss.new_zeros(())
        sum_dis = entropy_loss.new_zeros(())
        # Each similarity is computed once and reused for both the signed and
        # the absolute term (the original evaluated every pair twice).
        for i in range(len(flat_maps)):
            for j in range(i + 1, len(flat_maps)):
                sim = torch.mean(F.cosine_similarity(flat_maps[i], flat_maps[j]))
                sum_diff = sum_diff + sim
                sum_dis = sum_dis + torch.abs(sim)

        return entropy_loss + self.div_weight * sum_diff + self.abs_weight * sum_dis

class ChannelAttention(nn.Module):
    """CBAM-style channel attention: rescales each channel of ``x`` by a learned gate.

    Global average-pooled and global max-pooled descriptors are each passed
    through a shared two-layer 1x1-conv MLP, summed, squashed with a sigmoid and
    multiplied back onto ``x``.

    Args:
        channel_amount: number of input channels C.
        reduction: bottleneck ratio for the MLP hidden width
            (``channel_amount // reduction``; default 3, as before). The hidden
            width is clamped to at least 1 so small channel counts do not produce
            a zero-channel conv.
    """

    def __init__(self, channel_amount, reduction=3):
        super(ChannelAttention, self).__init__()
        hidden = max(1, channel_amount // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(channel_amount, hidden, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden, channel_amount, kernel_size=1),
        )
        self.sig = nn.Sigmoid()

    def forward(self, x):
        # Both pooled descriptors are (N, C, 1, 1); the gate broadcasts over H, W.
        gate = self.sig(self.mlp(self.avg_pool(x)) + self.mlp(self.max_pool(x)))
        return gate * x


class SpatialAttention(nn.Module):
    """CBAM-style spatial attention: rescales each spatial location of ``x``.

    The channel-wise max map and channel-wise mean map (each (N, 1, H, W)) are
    concatenated, passed through a single-output conv and a sigmoid, giving an
    (N, 1, H, W) gate that is multiplied onto ``x`` (broadcast over channels).

    The previous implementation pooled over H and W instead of over channels.
    That produced an (N, 1, 1, 1) gate, i.e. one scalar per sample with no
    spatial information at all.

    Args:
        channel_amount: number of input channels. Not needed now that pooling
            runs over the channel axis; kept so ``SpatialAttention(C)`` calls
            still work.
        kernel_size: odd conv kernel size (default 7, as before); padding keeps
            H and W unchanged.
    """

    def __init__(self, channel_amount, kernel_size=7):
        super(SpatialAttention, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError("kernel_size must be odd to preserve spatial size")
        self.clayer = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        combine = torch.cat([max_pool, avg_pool], 1)
        return_x = self.sig(self.clayer(combine))
        return x * return_x

class MyCBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate followed by spatial gate.

    Args:
        channel_amount: number of input channels, forwarded to both the
            ``ChannelAttention`` and ``SpatialAttention`` sub-modules.
    """

    def __init__(self, channel_amount):
        super(MyCBAM, self).__init__()
        self.channel_attention = ChannelAttention(channel_amount)
        self.spatial_attention = SpatialAttention(channel_amount)

    def forward(self, x):
        # The channel gate is applied first; its output feeds the spatial gate.
        return self.spatial_attention(self.channel_attention(x))

In [None]:
class MyNet(nn.Module):
    """ResNet-style classifier with CBAM stages and MACNN-style attention heads.

    The layers run in this order:

    1. Stem: 7x7/stride-2 conv, BN, ReLU and max-pool.
    2. ``block_64``: 64 -> 128 channels.
    3. ``block_128``: 128 -> 256 channels.
    4. ``block_256``: three 256-channel BasicBlocks.
    5. ``num_attention_mechanisms`` parallel ``SpatialAttention`` heads. Each
       head is global-average-pooled and classified by its own ``nn.Linear``.

    The per-head logits are averaged.

    Args:
        num_classes: logits per head (default 200, as before).
        num_attention_mechanisms: number of attention heads (default 4, as before).
    """

    def __init__(self, num_classes=200, num_attention_mechanisms=4):
        super().__init__()
        # NOTE(review): torchvision's ResNet stem uses padding=3 (and bias=False)
        # on this conv; without padding a 448px input yields a 221px map. Unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block_64 = self.make_big_block(64, count=3)
        self.block_128 = self.make_big_block(128, count=3)
        self.block_256 = self.make_non_trans_big_block(256, count=3)

        self.num_attention_mechanisms = num_attention_mechanisms
        # Kept for attribute compatibility only. forward no longer pools before
        # the attention heads: that collapsed the map to 1x1, turning every head
        # into a scalar multiple of the same vector.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)
        # MACNN style attention mechanisms: one attention + pool + linear per head.
        self.spatial_attention_module_list = nn.ModuleList(
            [SpatialAttention(256) for _ in range(num_attention_mechanisms)])
        self.fc1_list = nn.ModuleList(
            [nn.Linear(256, num_classes) for _ in range(num_attention_mechanisms)])
        self.adaptive_pool_list = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(1) for _ in range(num_attention_mechanisms)])

        # NOTE(review): defined but never applied in forward.
        self.dropout = nn.Dropout(0.5)

    def make_big_block(self, channel_count, count = 3, kernel_size = 3):
        """Build a stage of ``count`` (BasicBlock, MyCBAM) pairs.

        The first ``count - 1`` BasicBlocks keep ``channel_count`` channels. The
        last one uses stride 2 with a 1x1-conv + BN projection shortcut, so the
        stage outputs ``2 * channel_count`` channels at half the resolution.
        ``kernel_size`` is accepted but unused.
        """
        downsample = nn.Sequential(
            nn.Conv2d(channel_count, channel_count*2, kernel_size=1, stride=2),
            nn.BatchNorm2d(channel_count*2),
        )
        ordering = []
        for i in range(count):
            if i != count - 1:
                ordering.append(BasicBlock(channel_count, channel_count))
                ordering.append(MyCBAM(channel_count))
            else:
                ordering.append(BasicBlock(channel_count, channel_count*2, stride=2, downsample=downsample))
                ordering.append(MyCBAM(channel_count*2))

        return nn.Sequential(*ordering)

    def make_non_trans_big_block(self, channel_count, count = 3, kernel_size = 3):
        """Build ``count`` BasicBlocks that keep ``channel_count`` channels and resolution.

        ``kernel_size`` is accepted but unused.
        """
        return nn.Sequential(*[BasicBlock(channel_count, channel_count) for _ in range(count)])

    def forward(self, x):
        """Run the network.

        Args:
            x: image batch with 3 input channels, (N, 3, H, W).

        Returns:
            ``(logits, attention_mechs)``:

            - ``logits`` is the mean of the per-head linear outputs, shape
              (N, num_classes).
            - ``attention_mechs`` is the list of per-head attended feature maps,
              each the same shape as the ``block_256`` output. It is consumed by
              ``MACNNLoss``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.block_64(x)
        x = self.block_128(x)
        x = self.block_256(x)
        # Attend on the full spatial feature map; each head pools only afterwards.
        attention_mechs = [head(x) for head in self.spatial_attention_module_list]
        outs = [
            fc(torch.flatten(pool(attended), 1))
            for attended, pool, fc in zip(attention_mechs, self.adaptive_pool_list, self.fc1_list)
        ]
        x = sum(outs) / len(outs)
        return x, attention_mechs


In [None]:
# ANSI escape codes used to bold the test-accuracy line in the output.
bold_start = "\033[1m"
bold_end = "\033[0m"

# Use the GPU when one is available; otherwise this silently falls back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet channel statistics, commonly used for ResNet-style models:
# https://discuss.pytorch.org/t/what-does-it-mean-to-normalize-images-for-resnet/96160
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (448, 448)

# Training pipeline: resize plus light augmentation (flip, small rotation).
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline (validation/test): deterministic resize + normalize only.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


batch_size = 32
epochs = 100
num_classes = 200

def create_list_of_epochs(epochs):
    """Return the epoch indices ``[0, 1, ..., epochs - 1]`` as a list (x-axis for the loss plot)."""
    return list(range(epochs))


# Split the official train folder 95/5 into train/validation.
# Two ImageFolder views of the same directory are built so that the validation
# subset uses the deterministic `transform`. The previous approach set
# `.transform` on the Subset returned by random_split, which has no effect; as a
# result, validation images were still augmented with train_transform.
SPLIT_SEED = 0
# Seeds the global torch RNG too (model init, DataLoader shuffling) for reproducibility.
torch.manual_seed(SPLIT_SEED)

full_train_dataset = datasets.ImageFolder("CUB_200_2011_reorganized/train",
                                          transform=train_transform)
full_eval_dataset = datasets.ImageFolder("CUB_200_2011_reorganized/train",
                                         transform=transform)
length_of_inital_train_dataset = int(0.95 * len(full_train_dataset))
length_of_validation_dataset = len(full_train_dataset) - length_of_inital_train_dataset
# Split indices once, then apply the same indices to both views.
train_split, validation_split = random_split(
    range(len(full_train_dataset)),
    [length_of_inital_train_dataset, length_of_validation_dataset],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
train_dataset = Subset(full_train_dataset, train_split.indices)
validation_dataset = Subset(full_eval_dataset, validation_split.indices)
test_dataset = datasets.ImageFolder("CUB_200_2011_reorganized/test",
                                    transform=transform)

print(f"size of validation dataset {len(validation_dataset)}")

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(validation_dataset, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

model = MyNet()
model = nn.DataParallel(model)
model = model.to(device)


loss_function = MACNNLoss()
# Adam; lr and weight decay taken from commonly suggested starting values.
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
validation_loss_tracker = []

list_of_training_loss = []

for epoch in range(epochs):
    # ---- train ----
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        prediction, maps = model(images)
        loss = loss_function(prediction, labels, maps)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample even when the
        # last batch is smaller.
        running_loss += loss.item() * images.size(0)
    training_loss_for_epoch = running_loss/len(train_dataset)
    list_of_training_loss.append(training_loss_for_epoch)

    # ---- validate ----
    model.eval()
    validation_loss = 0
    correct_for_validation = 0
    total_for_validation = 0
    with torch.no_grad():
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)
            outputs, maps = model(images)
            loss = loss_function(outputs, labels, maps)
            validation_loss += loss.item() * images.size(0)
            # argmax instead of torch.max(outputs.data, 1): `.data` is discouraged
            # and the max values themselves were discarded anyway.
            predicted = outputs.argmax(dim=1)
            total_for_validation += labels.size(0)
            correct_for_validation += (predicted == labels).sum().item()
    avg_val_loss = validation_loss / len(validation_dataset)
    validation_loss_tracker.append(avg_val_loss)
    accuracy = 100 * correct_for_validation / total_for_validation
    print(f"Epoch {epoch} Training Loss {training_loss_for_epoch}. Validation Loss {avg_val_loss} Accuracy {accuracy} correct {correct_for_validation} total {total_for_validation}")


# Final evaluation on the held-out test set.
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs, _ = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division: the previous `100 * correct // total` truncated (e.g. 73.9 -> 73).
test_accuracy = 100 * correct / total if total else 0.0
print(f'{bold_start}Test Accuracy: {test_accuracy:.2f} % {bold_end}')


# One x value per recorded epoch. Each series uses its own length rather than
# the `epochs` constant, so a run that stopped early still plots.
train_epochs = create_list_of_epochs(len(list_of_training_loss))
validation_epochs = create_list_of_epochs(len(validation_loss_tracker))

fig, ax = plt.subplots()
ax.plot(train_epochs, list_of_training_loss, label="Training Loss")
ax.plot(validation_epochs, validation_loss_tracker, label="Validation loss")
ax.set(xlabel="Epochs", ylabel="Training/Validation loss",
       title="Training and validation loss per epoch")
# The lines carry labels, so a legend is needed for them to be shown.
ax.legend()
ax.grid(True)
plt.show()


size of validation dataset 300


  0%|          | 0/178 [00:00<?, ?it/s]

KeyboardInterrupt: 