In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os

def get_data_loader(data_dir, batch_size=256, shuffle=True, train_split=0.8):
    """
    Build the train and test data loaders for an image-folder dataset.

    The images found under ``data_dir`` are randomly partitioned into a training subset
    and a test subset. Training images are resized, randomly flipped and normalized;
    test images are only resized and normalized, so evaluation results do not depend
    on random augmentation.

    :param data_dir: root directory where the dataset is (read with ``torchvision.datasets.ImageFolder``)
    :param batch_size: size of the batch
    :param shuffle: whether the training loader reshuffles its samples; the test loader is never shuffled
    :param train_split: fraction of the data (between 0 and 1) to be used for training
    :return: tuple ``(train_dataset_loader, test_dataset_loader)``
    :raises ValueError: if ``train_split`` is not within [0, 1]
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be between 0 and 1, got {train_split}")

    # Steps shared by both phases
    resize = transforms.Resize([224, 224])  # every image is brought to the fixed 224 x 224 input size
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    # Augmentation is applied to the training split only
    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),  # Flip the data horizontally
        # TODO: Add random crop if needed
        to_tensor,
        normalize,
    ])
    eval_transform = transforms.Compose([resize, to_tensor, normalize])

    # Two views of the same folder: identical samples in identical order, different transforms
    train_source = torchvision.datasets.ImageFolder(root=data_dir, transform=train_transform)
    eval_source = torchvision.datasets.ImageFolder(root=data_dir, transform=eval_transform)

    # Calculate the size of the train set; the remainder is the test set
    num_samples = len(train_source)
    train_size = int(train_split * num_samples)

    # One random permutation shared by both views keeps the two subsets disjoint
    permutation = torch.randperm(num_samples).tolist()
    train_dataset = torch.utils.data.Subset(train_source, permutation[:train_size])
    test_dataset = torch.utils.data.Subset(eval_source, permutation[train_size:])

    # Create data loaders
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
    )
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,  # sample order has no effect on evaluation metrics
        num_workers=4,
    )

    return train_dataset_loader, test_dataset_loader

In [None]:
# Build the train/test loaders from the images stored under "Data/".
train_data_loader, test_data_loader = get_data_loader("Data/", shuffle=True)

# Number of batches in the training loader (shown as the cell output).
len(train_data_loader)

In [None]:
# Pull a single batch from the training loader and show its image shape and labels.
batch_x, batch_y = next(iter(train_data_loader))
print(batch_x.shape, batch_y)

# Model Training

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os

def get_data_loader(data_dir, batch_size=256, shuffle=True, train_split=0.8):
    """
    Build the train and test data loaders for an image-folder dataset.

    The images found under ``data_dir`` are randomly partitioned into a training subset
    and a test subset. Training images are resized, randomly flipped and normalized;
    test images are only resized and normalized, so evaluation results do not depend
    on random augmentation.

    :param data_dir: root directory where the dataset is (read with ``torchvision.datasets.ImageFolder``)
    :param batch_size: size of the batch
    :param shuffle: whether the training loader reshuffles its samples; the test loader is never shuffled
    :param train_split: fraction of the data (between 0 and 1) to be used for training
    :return: tuple ``(train_dataset_loader, test_dataset_loader)``
    :raises ValueError: if ``train_split`` is not within [0, 1]
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be between 0 and 1, got {train_split}")

    # Steps shared by both phases
    resize = transforms.Resize([224, 224])  # every image is brought to the fixed 224 x 224 input size
    to_tensor = transforms.ToTensor()
    # Per-channel statistics commonly used with ImageNet-pretrained backbones
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Augmentation is applied to the training split only
    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),  # Flip the data horizontally
        # TODO: Add random crop if needed
        to_tensor,
        normalize,
    ])
    eval_transform = transforms.Compose([resize, to_tensor, normalize])

    # Two views of the same folder: identical samples in identical order, different transforms
    train_source = torchvision.datasets.ImageFolder(root=data_dir, transform=train_transform)
    eval_source = torchvision.datasets.ImageFolder(root=data_dir, transform=eval_transform)

    # Calculate the size of the train set; the remainder is the test set
    num_samples = len(train_source)
    train_size = int(train_split * num_samples)

    # One random permutation shared by both views keeps the two subsets disjoint
    permutation = torch.randperm(num_samples).tolist()
    train_dataset = torch.utils.data.Subset(train_source, permutation[:train_size])
    test_dataset = torch.utils.data.Subset(eval_source, permutation[train_size:])

    # Create data loaders. num_workers / pin_memory stay at the DataLoader defaults
    # (they were commented out in the previous revision of this cell).
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
    )
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,  # sample order has no effect on evaluation metrics
    )

    return train_dataset_loader, test_dataset_loader

In [4]:
from dataloader import get_data_loader
import torchvision

# Sanity-check the train / test / val sizes produced by a 75% training split,
# with the remaining samples divided between test and validation.
train_split = 0.75
full_dataset = torchvision.datasets.ImageFolder(root="Data/")
n_total = len(full_dataset)

train_size = int(train_split * n_total)
test_size = (n_total - train_size) // 2  # half of the held-out samples, rounded down
val_size = n_total - (train_size + test_size)  # whatever is left after train and test

train_size, test_size, val_size

(4353, 725, 726)

In [2]:
# Build the three loaders and print the image / label tensor shapes of every training batch.
train_dataloader, test_dataloader, val_dataloader = get_data_loader(
    data_dir="Data/",
    batch_size=128,
    shuffle=True,
)
for batch in train_dataloader:
    inputs, labels = batch[0], batch[1]
    print(inputs.shape, labels.shape)

5804
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128

In [None]:
import ssl

# Replace the factory used for default HTTPS contexts for the rest of this process.
# NOTE(review): `ssl._create_stdlib_context` is a private CPython name that appears to be
# an alias for the *unverified* context factory, which would switch off TLS certificate
# verification for every HTTPS request made from this kernel — confirm, and prefer
# fixing the local certificate store over keeping this override.
ssl._create_default_https_context = ssl._create_stdlib_context

# import packages 
import os 
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

import torch 
import torch.nn as nn

#import your model here
from log import create_logger
from dataloader import get_data_loader
from models.resnet import resnet18

# Add your models here
models = {'resnet18': resnet18,}

# RUN DETAILS
run_name = "jly_0131_resnet_lr1e-3"
model_base = 'resnet18'  # key into `models`
num_epochs = 20
lr = 1e-3
random_seed = 42
save_chks = range(num_epochs) # iterable of epochs for which to save the model

# Prefer CUDA, then Apple MPS, then fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
if device == 'mps':
    torch.mps.empty_cache()

# set up run dir: checkpoints, the training log and the loss plot are all written here
run_dir = os.path.join('saved_models', run_name)
os.makedirs(run_dir, exist_ok = True)
log, logclose = create_logger(log_filename=os.path.join(run_dir, 'train.log'), display = False)
log(f'using device: {device}')
log(f'saving models to: {run_dir}')
log(f'using base model: {model_base}')
log(f'learning rate: {lr}')
log(f'random seed: {random_seed}')

# seed randoms and make deterministic
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
# random.seed(random_seed)
torch.backends.cudnn.enabled=False
torch.backends.cudnn.deterministic=True

# dataloader: the imported get_data_loader returns three loaders (train, test, val).
# The previous `test,_dataloader` typo unpacked into four names and raised ValueError.
train_dataloader, test_dataloader, val_dataloader = get_data_loader(data_dir="Data/", shuffle=True)

# define model: built from the configured key so the logged base model is the one actually used
model = models[model_base]()
model.to(device)

# define optimizer and criterion
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# training loop
# Per-epoch history; `test_loss` is filled by the test-evaluation cell that follows.
train_loss = []
val_loss = []
test_loss = []
train_metrics = []
val_metrics = []
for epoch in range(num_epochs):
    print(f"epoch: {epoch}")
    log(f'epoch {epoch}')

    # training
    model.train()
    batch_loss = []
    batch_metric = []
    for _data, _target in tqdm(train_dataloader):
        data = _data.to(device)
        target = _target.to(device)
        optimizer.zero_grad()
        pred = model(data)
        loss = criterion(pred, target)
        loss.backward()  # compute gradients; without this optimizer.step() never changes the weights
        optimizer.step()
        batch_loss.append(loss.item())
        # fraction of samples in this batch whose arg-max prediction equals the target
        batch_metric.append((torch.argmax(pred, dim=1) == target).sum().item() / len(target))
    train_loss.append(np.mean(batch_loss))  # mean of the per-batch losses
    log(f'\ttrain loss: {train_loss[-1]}')
    train_metrics.append(np.mean(batch_metric)) #TODO: add metrics
    # drop references to the last batch's tensors before validation
    del data
    del target
    del pred
    del loss

    # validation
    with torch.no_grad():
        model.eval()
        batch_loss = []
        batch_metric = []  # start fresh so training batches do not leak into the validation accuracy
        for _data, _target in tqdm(val_dataloader):
            data = _data.to(device)
            target = _target.to(device)
            pred = model(data)
            loss = criterion(pred, target)
            batch_loss.append(loss.item())
            batch_metric.append((torch.argmax(pred, dim=1) == target).sum().item() / len(target))
        val_loss.append(np.mean(batch_loss))
        log(f'\tval loss: {val_loss[-1]}')
        val_metrics.append(np.mean(batch_metric)) #TODO: add metrics

    if epoch in save_chks:
        torch.save(model.state_dict(), os.path.join(run_dir, f'{epoch}.chkpt'))

    # redraw the loss / accuracy curves after every epoch, overwriting the same file in run_dir
    plt.plot(train_loss, label='train loss')
    plt.plot(val_loss, label='val loss')
    plt.plot(train_metrics, label='train accuracy')
    plt.plot(val_metrics, label='val accuracy')
    plt.xlabel('epoch')
    plt.ylabel('loss, accuracy')
    plt.legend()
    plt.savefig(os.path.join(run_dir, 'loss'))
    plt.close()
    del data
    del target
    del pred
    del loss

    if device == 'mps':
        torch.mps.empty_cache()


# testing
# Only switches the model to eval mode here; the test-set evaluation itself is in the next cell.
with torch.no_grad():
    model.eval()

In [None]:
# Evaluate the trained model on the test loader and log its loss and accuracy.
with torch.no_grad():
    model.eval()
    batch_loss = []
    batch_metric = []  # start fresh so metrics from earlier loops do not leak into the test accuracy
    for _data, _target in tqdm(test_dataloader):
        data = _data.to(device)
        target = _target.to(device)
        pred = model(data)
        loss = criterion(pred, target)
        batch_loss.append(loss.item())
        # fraction of samples in this batch whose arg-max prediction equals the target
        batch_metric.append((torch.argmax(pred, dim=1) == target).sum().item() / len(target))
    test_loss.append(np.mean(batch_loss))  # mean of the per-batch losses
    test_accuracy = np.mean(batch_metric)
    # Report the test results themselves; the validation history (val_loss / val_metrics) is left untouched.
    log(f'\ttest loss: {test_loss[-1]}')
    log(f'\ttest accuracy: {test_accuracy}')