In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt

Data Augmentation

In [None]:
# Per-channel statistics handed to transforms.Normalize in every pipeline below.
# NOTE(review): presumably measured on the Flowers102 training images - confirm.
mean = [0.4319, 0.3926, 0.3274]
std = [0.3181, 0.2624, 0.3108]


def _buildTrainTransform(cropStep, jitter=None, rotate=False):
    """Assemble one training augmentation pipeline.

    Steps are always applied in this order: crop, optional ColorJitter,
    optional RandomRotation(25), RandomHorizontalFlip, ToTensor, Normalize.

    Args:
        cropStep: the crop transform placed first (CenterCrop or RandomCrop).
        jitter: optional (brightness, contrast, saturation) triple passed to
            ColorJitter; hue jitter is always 0. None skips the jitter step.
        rotate: when True, add a random rotation of up to 25 degrees.

    Returns:
        A transforms.Compose wrapping the assembled steps.
    """
    steps = [cropStep]
    if jitter is not None:
        brightness, contrast, saturation = jitter
        steps.append(transforms.ColorJitter(brightness, contrast, saturation, 0))
    if rotate:
        steps.append(transforms.RandomRotation(25))
    steps.extend([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ])
    return transforms.Compose(steps)


# Centre-cropped pipelines: colour jitter only (one or two jitter axes).
brightTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0.15, 0, 0))
contrastTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0, 0.5, 0))
saturateTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0, 0, 0.5))
bcTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0.15, 0.5, 0))
csTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0, 0.5, 0.5))
bsTransform = _buildTrainTransform(transforms.CenterCrop(224), jitter=(0.15, 0, 0.5))

# Random crop with no colour jitter and no rotation.
cropTransform = _buildTrainTransform(transforms.RandomCrop(224))

# Random crop + rotation combined with every jitter combination.
comboTransform1 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0.15, 0.5, 0.5), rotate=True)
comboTransform2 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0.15, 0, 0), rotate=True)
comboTransform3 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0, 0.5, 0), rotate=True)
comboTransform4 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0, 0, 0.5), rotate=True)
comboTransform5 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0.15, 0, 0.5), rotate=True)
comboTransform6 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0.15, 0.5, 0), rotate=True)
comboTransform7 = _buildTrainTransform(transforms.RandomCrop(224), jitter=(0, 0.5, 0.5), rotate=True)





In [None]:
def _trainSplit(transform):
    """Return the Flowers102 'train' split (stored under 'data', downloaded if
    missing) with `transform` applied to each image."""
    return datasets.Flowers102(root='data', split='train', download=True, transform=transform)


# Each copy below is a separate view of the same training split. Because the
# pipelines are random, repeated copies yield differently augmented samples.

# Single-axis colour jitter: one copy each.
brightData1 = _trainSplit(brightTransform)
contrastData1 = _trainSplit(contrastTransform)
saturateData1 = _trainSplit(saturateTransform)

# Two-axis colour jitter and plain random crop: two copies each.
bcData1, bcData2 = (_trainSplit(bcTransform) for _ in range(2))
csData1, csData2 = (_trainSplit(csTransform) for _ in range(2))
bsData1, bsData2 = (_trainSplit(bsTransform) for _ in range(2))
cropData1, cropData2 = (_trainSplit(cropTransform) for _ in range(2))

# Crop + jitter + rotation combinations: three copies each.
combo1Data1, combo1Data2, combo1Data3 = (_trainSplit(comboTransform1) for _ in range(3))
combo2Data1, combo2Data2, combo2Data3 = (_trainSplit(comboTransform2) for _ in range(3))
combo3Data1, combo3Data2, combo3Data3 = (_trainSplit(comboTransform3) for _ in range(3))
combo4Data1, combo4Data2, combo4Data3 = (_trainSplit(comboTransform4) for _ in range(3))
combo5Data1, combo5Data2, combo5Data3 = (_trainSplit(comboTransform5) for _ in range(3))
combo6Data1, combo6Data2, combo6Data3 = (_trainSplit(comboTransform6) for _ in range(3))
combo7Data1, combo7Data2, combo7Data3 = (_trainSplit(comboTransform7) for _ in range(3))



In [None]:
# Stack every augmented copy of the training split into one dataset.
# The groups are concatenated in the same order the copies were created.
trainData = ConcatDataset(
    [brightData1, contrastData1, saturateData1]
    + [bcData1, bcData2]
    + [csData1, csData2]
    + [bsData1, bsData2]
    + [cropData1, cropData2]
    + [combo1Data1, combo1Data2, combo1Data3]
    + [combo2Data1, combo2Data2, combo2Data3]
    + [combo3Data1, combo3Data2, combo3Data3]
    + [combo4Data1, combo4Data2, combo4Data3]
    + [combo5Data1, combo5Data2, combo5Data3]
    + [combo6Data1, combo6Data2, combo6Data3]
    + [combo7Data1, combo7Data2, combo7Data3]
)

In [None]:
def _evalTransform():
    """Build the deterministic evaluation pipeline: centre crop to 224,
    convert to tensor, then normalize with the shared mean/std."""
    return transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ])


# Validation and test use the same (non-random) preprocessing, each with its
# own Compose instance.
valTransform = _evalTransform()
testTransform = _evalTransform()

valData = datasets.Flowers102(root='data', split='val', download=True, transform=valTransform)
testData = datasets.Flowers102(root='data', split='test', download=True, transform=testTransform)

In [None]:
# Preview nine randomly chosen samples from the augmented training set in a
# 3x3 grid.
#
# The tensors produced by the training pipelines have been mean/std-normalized,
# so their values fall outside [0, 1]. imshow clips float RGB data to that
# range, which distorts the colours, so the normalization is undone before
# plotting. No colormap is passed because imshow ignores it for RGB input.
figure = plt.figure(figsize=(10, 10))
cols, rows = 3, 3

# Reshaped to (3, 1, 1) so they broadcast over a channel-first image tensor.
displayMean = torch.tensor(mean).view(3, 1, 1)
displayStd = torch.tensor(std).view(3, 1, 1)

for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(trainData), size=(1,)).item()
    img, label = trainData[sample_idx]
    # Invert Normalize (x * std + mean); clamp guards against rounding spill.
    img = (img * displayStd + displayMean).clamp(0, 1)
    figure.add_subplot(rows, cols, i)
    plt.title(f"class {label}")
    plt.axis("off")
    # Channel-first tensor -> channel-last layout expected by imshow.
    plt.imshow(img.permute(1, 2, 0))
plt.show()

In [None]:
class MyCNN(nn.Module):
    """Six-stage convolutional classifier with 102 output logits.

    Every stage is Conv2d(3x3, stride 1, no padding) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2). With a 224x224 input the spatial size shrinks to 1x1 after
    the sixth stage, so the flattened feature vector has 512 entries, which is
    what the first Linear layer expects.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) of each convolutional stage, in order.
        stageChannels = [(3, 32), (32, 64), (64, 128), (128, 256), (256, 256), (256, 512)]

        convLayers = []
        for inChannels, outChannels in stageChannels:
            convLayers.extend([
                nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1),
                nn.BatchNorm2d(outChannels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
        # Unpacked into a single flat Sequential so the submodule indices
        # (and therefore the state_dict keys) run 0..23 without nesting.
        self.convStack = nn.Sequential(*convLayers)

        self.Flatten = nn.Flatten()

        # Classifier head with dropout before each fully connected layer.
        self.LinearStack = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=1024, out_features=102),
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images `x`."""
        features = self.convStack(x)
        return self.LinearStack(self.Flatten(features))

# Instantiate once so the notebook displays the layer summary.
MyCNN()
    


In [None]:
# Hyperparameters
BATCH_SIZE = 32      # samples per batch for all three loaders
NUM_EPOCHS = 200     # number of passes over the training set
LR = 0.0001          # initial Adam learning rate
GAMMA = 0.9          # multiplicative learning-rate decay applied per scheduler step

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Dataloaders
# The batch size now comes from BATCH_SIZE; it was previously hard-coded to 32
# in each loader, so editing the constant had no effect. Only the training
# loader shuffles; validation and test keep a fixed order.
trainLoader = DataLoader(trainData, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)
valLoader = DataLoader(valData, batch_size = BATCH_SIZE, shuffle = False, num_workers = 4)
testLoader = DataLoader(testData, batch_size = BATCH_SIZE, shuffle = False, num_workers = 4)

# Model
model = MyCNN()
model = model.to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)
# Learning rate is multiplied by GAMMA each time scheduler.step() is called.
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA, last_epoch=-1)

# Loss Function
loss = nn.CrossEntropyLoss()

In [None]:
losses = []    # per-epoch training loss (mean per sample), appended by trainModel
accuracy = []  # per-epoch training accuracy in percent, appended by trainModel
def trainModel(dataloader, model, lossFunction, optimizer):
    """Run one training epoch and record its loss and accuracy.

    Appends the per-sample mean loss to the module-level ``losses`` list and
    the accuracy (in percent) to ``accuracy``, then prints a summary line.

    Args:
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        model: the network to train; batches are moved to the device its
            parameters live on.
        lossFunction: criterion called as ``lossFunction(pred, targets)``.
        optimizer: optimizer that updates ``model``'s parameters.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()

    # Take the device from the model itself rather than from a global defined
    # in another cell, so inputs always land where the weights are.
    firstParam = next(model.parameters(), None)
    runDevice = firstParam.device if firstParam is not None else torch.device("cpu")

    # A 'sum'-reduced criterion already returns the batch total; otherwise the
    # value is treated as a batch mean and must be weighted by the batch size.
    sumReduced = getattr(lossFunction, "reduction", "mean") == "sum"

    currentLoss = 0.0   # sum of the raw per-batch loss values
    epochLoss = 0.0     # sum of per-sample losses over the whole epoch
    correct = 0
    total = 0
    numBatches = 0

    for X, y in dataloader:
        X, y = X.to(runDevice), y.to(runDevice)
        batchSize = y.size(0)

        optimizer.zero_grad()
        pred = model(X)
        loss = lossFunction(pred, y)
        loss.backward()
        optimizer.step()

        # Reuse the value already computed instead of evaluating the loss twice.
        batchLoss = loss.item()
        currentLoss += batchLoss
        epochLoss += batchLoss if sumReduced else batchLoss * batchSize
        correct += (pred.argmax(1) == y).sum().item()
        total += batchSize
        numBatches += 1

    if total == 0:
        raise ValueError("trainModel received a dataloader that yielded no samples")

    epochLoss = epochLoss/total
    correct = correct/total

    losses.append(epochLoss)
    accuracy.append(correct * 100)
    print(f'Training: Accuracy {correct * 100:>0.1f}%, Loss: {currentLoss / numBatches:.5f}, Epoch Loss: {epochLoss:.5f}')

In [None]:
def validateModel(dataloader, model, lossFunction):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        model: the network to evaluate; batches are moved to the device its
            parameters live on.
        lossFunction: criterion called as ``lossFunction(pred, targets)``.

    Returns:
        ``(epochLoss, correct)``: the per-sample mean loss and the accuracy in
        percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()

    # Take the device from the model itself rather than from a global defined
    # in another cell, so inputs always land where the weights are.
    firstParam = next(model.parameters(), None)
    runDevice = firstParam.device if firstParam is not None else torch.device("cpu")

    # A 'sum'-reduced criterion already returns the batch total; otherwise the
    # value is treated as a batch mean and must be weighted by the batch size.
    sumReduced = getattr(lossFunction, "reduction", "mean") == "sum"

    total = 0
    correct = 0
    epochLoss = 0.0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(runDevice), y.to(runDevice)
            pred = model(X)
            batchSize = y.size(0)
            batchLoss = lossFunction(pred, y).item()
            epochLoss += batchLoss if sumReduced else batchLoss * batchSize
            correct += (pred.argmax(1) == y).sum().item()
            total += batchSize

    if total == 0:
        raise ValueError("validateModel received a dataloader that yielded no samples")

    epochLoss = epochLoss/total
    correct = (correct/total) * 100

    print(f"Validation: Accuracy: {(correct):>0.1f}%, Avg loss: {epochLoss:>8f} \n")
    return epochLoss, correct

In [None]:
# Best validation result seen so far; "best" means lowest validation loss.
bestValAccuracy = 0.0
bestValLoss = float('inf')
bestEpoch = 0

# Per-epoch validation history.
valLosses, valAccuracies = [], []

for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}:')

    # One pass over the training data, then measure on the validation split.
    trainModel(trainLoader, model, loss, optimizer)
    valLoss, valAccuracy = validateModel(valLoader, model, loss)
    valLosses.append(valLoss)
    valAccuracies.append(valAccuracy)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Nothing to record unless this epoch strictly lowers the validation loss.
    if not (valLoss < bestValLoss):
        continue
    bestValLoss, bestValAccuracy, bestEpoch = valLoss, valAccuracy, epoch+1

print(f'Best Accuracy: {bestValAccuracy}. Best Loss: {bestValLoss} Best Epoch: {bestEpoch}')

Changed the transformations to remove duplicates.
* lr = 0.0001: 58%