In [1]:
import sys
# Extend the module search path with the MODELS directory of the attention-module
# checkout. Presumably this is what lets the later `from cbam import CBAM` resolve
# (verify the checkout location). The path is relative, so it is resolved against
# the notebook's current working directory.
sys.path.append("./attention-module/MODELS")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
from cbam import CBAM
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision.models.resnet import BasicBlock

In [3]:
# BasicBlock (imported from torchvision) is the basic residual block described in the ResNet paper: https://arxiv.org/pdf/1512.03385

class MyNet(nn.Module):
    """Image classifier built from a conv stem, ResNet basic blocks and CBAM attention.

    Layout, in the order applied by ``forward``:

    * stem: two unpadded 3x3 convolutions (3 -> 64 -> 128 channels), each followed
      by batch norm, ReLU and a 2x2 max-pool;
    * four 128-channel ``BasicBlock``s interleaved with 2x2 max-pools, with a CBAM
      module after the last of them;
    * one stride-2 ``BasicBlock`` widening to 256 channels, followed by CBAM;
    * dropout on the feature map, global average pooling, and a linear layer
      producing 200 logits.

    ``cbam1`` and ``cbam2`` are constructed (so their parameters are registered and
    appear in ``state_dict``) but are not called in ``forward``.
    """

    def __init__(self):
        super().__init__()

        # --- stem -----------------------------------------------------------
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.cbam1 = CBAM(64)  # registered only; not applied in forward()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared by every pooling step
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(num_features=128)
        self.cbam2 = CBAM(128)  # registered only; not applied in forward()

        # --- 128-channel residual stage ---------------------------------------
        self.resblock1 = BasicBlock(128, 128)
        self.cbamres1 = CBAM(128)
        self.resblock1a = BasicBlock(128, 128)
        self.resblock1b = BasicBlock(128, 128)
        self.resblock1c = BasicBlock(128, 128)

        # --- 256-channel residual stage ---------------------------------------
        # The identity branch needs a 1x1 stride-2 projection so that it matches
        # the main branch in both channel count and spatial size.
        shortcut_projection = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
        )
        self.resblock2 = BasicBlock(
            128, 256, stride=2, downsample=shortcut_projection
        )
        self.cbamres2 = CBAM(256)

        # --- classifier head --------------------------------------------------
        self.adaptive_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc1 = nn.Linear(in_features=256, out_features=200)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return the 200-way logits for a batch of 3-channel images ``x``."""
        # Stem: conv -> BN -> ReLU -> pool, twice (CBAM on the stem is disabled).
        stem = self.pool(F.relu(self.bn1(self.conv1(x))))
        stem = self.pool(F.relu(self.bn2(self.conv2(stem))))

        # 128-channel residual stage with pooling between the later blocks.
        feats = self.resblock1a(self.resblock1(stem))
        feats = self.resblock1b(self.pool(feats))
        feats = self.resblock1c(self.pool(feats))
        feats = self.pool(self.cbamres1(feats))

        # 256-channel residual stage followed by attention.
        feats = self.cbamres2(self.resblock2(feats))

        # Head: dropout is applied to the feature map *before* global pooling.
        pooled = self.adaptive_pool(self.dropout(feats))
        return self.fc1(torch.flatten(pooled, 1))

In [4]:
# ANSI escape sequences: switch bold on / reset formatting, used to make the
# final test accuracy stand out in the cell output.
bold_start = "\033[1m"
bold_end = "\033[0m"

# Seed PyTorch's global RNG before anything random happens, so the
# train/validation split, the weight initialisation and the shuffling order
# start from the same state on every Restart & Run All.
# (GPU kernels may still introduce small run-to-run differences.)
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resize every image to 224x224, convert it to a tensor, then normalise each
# channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])


batch_size = 32
epochs = 50
# NOTE: not referenced by the code visible in this notebook; MyNet hard-codes
# 200 outputs in its final linear layer.
num_classes = 200

def create_list_of_epochs(epochs):
    """Return the epoch indices ``[0, 1, ..., epochs - 1]`` as a list.

    Args:
        epochs: Number of epochs; a value of 0 or less yields an empty list.

    Returns:
        A new list of consecutive integers starting at 0, suitable as the
        x-axis when plotting one value per epoch.
    """
    return list(range(epochs))


# Read the labelled training images, then carve a validation subset out of them:
# 95% stays for training, the remainder is held out.
full_train_dataset = datasets.ImageFolder(
    root="CUB_200_2011_reorganized/train", transform=transform)
length_of_inital_train_dataset = int(len(full_train_dataset) * 0.95)
length_of_validation_dataset = (
    len(full_train_dataset) - length_of_inital_train_dataset)
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    lengths=[length_of_inital_train_dataset, length_of_validation_dataset],
)

# The test split lives in its own folder and goes through the same transform.
test_dataset = datasets.ImageFolder(
    root="CUB_200_2011_reorganized/test", transform=transform)

print(f"size of validation dataset {len(validation_dataset)}")

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(
    validation_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build the network, wrap it in DataParallel and move it to the selected device.
model = nn.DataParallel(MyNet()).to(device)

# Cross-entropy over the class logits, optimised with Adam at its default settings.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

# Learning-rate schedulers are currently disabled; the two variants are kept for reference.
#scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Per-epoch loss histories, filled by the training loop and plotted afterwards.
validation_loss_tracker, list_of_training_loss = [], []

#model.summary()
for epoch in range(epochs):
    # ---------------- training pass ----------------
    model.train()
    running_loss = 0.0
    for batch_images, batch_labels in tqdm(train_loader):
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        loss = loss_function(model(batch_images), batch_labels)
        loss.backward()
        optimizer.step()
        # The criterion returns a batch mean; weight it by the batch size so the
        # epoch average stays correct when the last batch is smaller.
        running_loss += loss.item() * batch_images.size(0)
    training_loss_for_epoch = running_loss/len(train_dataset)
    list_of_training_loss.append(training_loss_for_epoch)

    # ---------------- validation pass ----------------
    model.eval()
    validation_loss = 0
    correct_for_validation = 0
    total_for_validation = 0
    with torch.no_grad():
        for batch_images, batch_labels in validation_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            outputs = model(batch_images)
            batch_loss = loss_function(outputs, batch_labels)
            validation_loss += batch_loss.item() * batch_images.size(0)
            # Index of the largest logit per sample is the predicted class.
            predicted = torch.max(outputs, dim=1).indices
            total_for_validation += batch_labels.size(0)
            correct_for_validation += (predicted == batch_labels).sum().item()
    avg_val_loss = validation_loss / len(validation_dataset)
    validation_loss_tracker.append(avg_val_loss)
    accuracy = 100 * correct_for_validation / total_for_validation
    print(f"Epoch {epoch} Training Loss {training_loss_for_epoch}. Validation Loss {avg_val_loss} Accuracy {accuracy}")
    # Scheduler stepping is disabled together with the scheduler itself.
    #scheduler.step()


# Evaluate the trained model on the held-out test set.
correct = 0
total = 0
model.eval()  # set once; eval mode does not change between batches
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest logit per sample is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps the fractional part; floor division (`//`) would round
# the reported accuracy down to a whole percent.
test_accuracy = 100 * correct / total
print(f'{bold_start}Test Accuracy: {test_accuracy:.2f} % {bold_end}')


# Build the x-axis from the epochs that actually completed rather than from the
# configured `epochs`, so the plot still works if training stopped early.
list_of_epochs = create_list_of_epochs(len(list_of_training_loss))

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and validation loss per epoch")
plt.plot(list_of_epochs, list_of_training_loss, label="Training Loss")
# The validation history can be one entry shorter than the training history if
# a run stopped between the two passes, so trim the x-axis to match it.
plt.plot(list_of_epochs[:len(validation_loss_tracker)], validation_loss_tracker,
         label="Validation loss")
plt.legend()  # without this call the `label=` arguments above are never shown
plt.grid(True)
plt.show()


size of validation dataset 300


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 0 Training Loss 5.186112305051953. Validation Loss 5.051032962799073 Accuracy 1.6666666666666667


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 1 Training Loss 4.807513575491924. Validation Loss 4.492236652374268 Accuracy 3.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 2 Training Loss 4.448487560163684. Validation Loss 4.363850905100505 Accuracy 3.6666666666666665


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 3 Training Loss 4.180001729992678. Validation Loss 4.067730433146159 Accuracy 8.333333333333334


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 4 Training Loss 3.915256171633496. Validation Loss 3.737173341115316 Accuracy 12.333333333333334


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 5 Training Loss 3.6691769616495824. Validation Loss 3.580542573928833 Accuracy 13.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 6 Training Loss 3.4050442548563575. Validation Loss 3.6652168305714925 Accuracy 15.333333333333334


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 7 Training Loss 3.1762597550833314. Validation Loss 3.325432119369507 Accuracy 19.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 8 Training Loss 2.9646182162409125. Validation Loss 3.107712694803874 Accuracy 22.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 9 Training Loss 2.769148322963949. Validation Loss 3.239995126724243 Accuracy 22.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 10 Training Loss 2.536209732262679. Validation Loss 3.138105853398641 Accuracy 25.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 11 Training Loss 2.3203180322824632. Validation Loss 3.1017212931315106 Accuracy 26.666666666666668


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 12 Training Loss 2.096459613836226. Validation Loss 3.118009703954061 Accuracy 28.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 13 Training Loss 1.9186658805656902. Validation Loss 2.95381454149882 Accuracy 26.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 14 Training Loss 1.7164296691862373. Validation Loss 3.5202336343129477 Accuracy 23.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 15 Training Loss 1.495942937017199. Validation Loss 3.083222942352295 Accuracy 28.0


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 16 Training Loss 1.3324064606317185. Validation Loss 3.1985652923583983 Accuracy 29.666666666666668


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 17 Training Loss 1.1647544509074041. Validation Loss 3.4826162560780842 Accuracy 26.666666666666668


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 18 Training Loss 0.9832948446483163. Validation Loss 3.3762538274129232 Accuracy 26.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 19 Training Loss 0.8310496207074111. Validation Loss 3.4600459639231365 Accuracy 26.666666666666668


  0%|          | 0/178 [00:00<?, ?it/s]

Epoch 20 Training Loss 0.7280132149578104. Validation Loss 3.7588232549031577 Accuracy 25.333333333333332


  0%|          | 0/178 [00:00<?, ?it/s]

KeyboardInterrupt: 