<a href="https://colab.research.google.com/github/wilberquito/AMLProject/blob/main/AMLProject.wil.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning Project: Image Classification
## Advanced Machine Learning


> Wilber E. Bermeo Quito 
>
> Judit Quintana Massana
>
> April 2023

In [1]:
import zipfile
from pathlib import Path
import torch
import matplotlib.pyplot as plt

In [2]:
# Detect whether the notebook is running on Google Colab. Only a missing
# module means "not Colab": a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and unrelated errors raised during the import.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

In [3]:
if IN_COLAB:
    ! pip install torchvision
    ! pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
# On Colab, mount Google Drive and unpack the dataset and the project's
# `modular` package into the working directory. Any previous copy is removed
# first so re-running the cell always starts from a clean extraction.
if IN_COLAB:
    import shutil
    from google.colab import drive

    drive.mount('/content/drive')
    drive_dir = Path('/content/drive/MyDrive/AML')

    # Dataset archive -> ./data
    shutil.rmtree('data', ignore_errors=True)
    data_path = drive_dir / 'dataset_CIFAR10.zip'
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall('data')

    # Project package archive -> extracted into '.', so it presumably contains
    # a top-level `modular/` folder (TODO confirm).
    shutil.rmtree('modular', ignore_errors=True)
    modular_zip_path = drive_dir / 'modular.zip'
    with zipfile.ZipFile(modular_zip_path, 'r') as zip_ref:
        zip_ref.extractall('.')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


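## AML ResNet50 v0

Baseline model: a ResNet50 with a 10-class output for CIFAR-10. Only its classification head (`fc`) is trainable, as the parameter table below shows.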

In [7]:
import modular.models as models
from prettytable import PrettyTable

# CIFAR-10 has 10 classes, hence out_dim=10.
amlresnet50_v0 = models.AMLResnet50_V0(out_dim=10)

def count_parameters(model):
    """Print a per-tensor table of trainable parameters and return their total.

    Parameters whose ``requires_grad`` is False are left out of both the
    table and the total.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: number of trainable scalar parameters.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])

    total = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total}")
    return total
    
# The total is already printed by count_parameters; the trailing `;` stops the
# notebook from echoing the returned value a second time as Out[].
count_parameters(amlresnet50_v0);

+-------------+------------+
|   Modules   | Parameters |
+-------------+------------+
| fc.0.weight |   262144   |
|  fc.0.bias  |    128     |
| fc.2.weight |    1280    |
|  fc.2.bias  |     10     |
+-------------+------------+
Total Trainable Params: 263562


263562

In [11]:
import torch
import modular.datasets as datasets
import modular.models as models 
from pathlib import Path
from modular.engine import train
import torchvision.transforms as transforms
from modular.utils import set_seeds

# Seed before building the model and transforms so the run is repeatable
# (which RNGs are seeded is defined by modular.utils.set_seeds).
set_seeds(seed=42)

# Model: a fresh instance, created after seeding, with a 10-class output.
amlresnet50_v0 = models.AMLResnet50_V0(out_dim=10)

# Reuse the model's own `transforms` as the base preprocessing for both splits.
train_transforms, validate_transforms = amlresnet50_v0.transforms, amlresnet50_v0.transforms

# Training-only augmentation, applied before the model's preprocessing.
# NOTE: ColorJitter() with no arguments is a no-op (brightness, contrast,
# saturation and hue all default to 0), so explicit mild strengths are given.
augmentation = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomGrayscale(),
    transforms.RandomHorizontalFlip(),
])

train_transforms = transforms.Compose([augmentation, train_transforms])

# Optimizer: only the classification head's (`fc`) parameters are handed to
# Adam, so nothing else in the network is updated.
optimizer = torch.optim.Adam(params=amlresnet50_v0.fc.parameters())

# Criterion: multi-class cross-entropy (takes unnormalised scores).
criterion = torch.nn.CrossEntropyLoss()

# Number of epochs to train the model
epochs = 20

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Where the model is saved. Create the folder up front: torch.save does not
# create missing directories, and a failure at save time would only surface
# after every epoch has run (assuming `train` saves via torch.save - TODO confirm).
save_as = Path('trained/amlresnet50_v0.pth')
save_as.parent.mkdir(parents=True, exist_ok=True)

# Mini-batch size shared by the training and validation loaders
batch_size = 800

# Build one loader per split; only the training split is shuffled.
# `suffle` (sic) is the keyword this project's get_dataloader is called with
# here - presumably its real parameter name, so do not "fix" it locally.
train_dataloader = datasets.get_dataloader(
    folder_root='data/train',
    transformer=train_transforms,
    batch_size=batch_size,
    suffle=True,
)
validate_dataloader = datasets.get_dataloader(
    folder_root='data/validation',
    transformer=validate_transforms,
    batch_size=batch_size,
    suffle=False,
)

In [12]:
# Run training. The validation split is passed through train's
# `test_dataloader` parameter; `save_as` tells train where to write the model.
results = train(
    model=amlresnet50_v0,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
    device=device,
    train_dataloader=train_dataloader,
    test_dataloader=validate_dataloader,
    save_as=save_as,
)

  0%|          | 0/20 [00:00<?, ?it/s]

0it [00:00, ?it/s]

KeyboardInterrupt: ignored

In [None]:
# Display whatever `train` returned (see modular.engine.train for its structure).
results