# Imports

In [None]:
import os
import json
import paths
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import SimpleCNN
from torchvision.models import alexnet, AlexNet_Weights

## Setting model parameters

In [None]:
# Training hyperparameters. `batch_size` and `num_workers` configure the data
# loaders, `epochs` bounds the training loop, and the whole dict is written to
# disk as JSON when the model is saved.
parameters = dict(
    lr=0.001,
    batch_size=64,
    epochs=40,
    num_workers=6,
)

## Use GPU for training if available

In [None]:
# Pick the compute backend in order of preference:
# CUDA GPU first, then Apple-silicon MPS, and the CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device("cpu")
device

## Normalizing the images

In [None]:
# Per-channel mean / std passed to Normalize below (the standard ImageNet
# statistics used with torchvision's pretrained models).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image, in order.
transform = transforms.Compose([
    # Convert to grayscale but keep 3 (identical) channels, so the tensor
    # still has the 3-channel layout that Normalize below is given stats for.
    transforms.Grayscale(num_output_channels=3),
    # Shrink the shorter side to 256 px, then cut out the central 224x224 patch.
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # PIL image -> float tensor, then standardise each channel.
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

### Loading the datasets from the training and testing folders inside **/model_inputs_outputs/inputs**

In [None]:
# Load the datasets. ImageFolder treats every sub-directory of `root` as one
# class and applies `transform` to each image it loads.
train_dataset = datasets.ImageFolder(root=paths.TRAIN_DIR, transform=transform)
test_dataset = datasets.ImageFolder(root=paths.TEST_DIR, transform=transform)

# Options shared by both loaders. DataLoader raises ValueError when
# persistent_workers=True is combined with num_workers=0, so workers are only
# kept alive between epochs when worker processes are actually used.
loader_options = {
    'batch_size': parameters['batch_size'],
    'num_workers': parameters['num_workers'],
    'persistent_workers': parameters['num_workers'] > 0,
}

# Dataloaders: shuffle the training data only; the test order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

### We store the class names. Their count is needed to adjust the last layer of the AlexNet model.<br> That layer should have dimensions (num_features, num_classes).

In [None]:
# Class names are the sub-directories of the training folder, sorted by name.
# Plain files (e.g. a stray .DS_Store) are skipped: counting them as classes
# would inflate num_classes. ImageFolder discovers its classes the same way.
with os.scandir(paths.TRAIN_DIR) as entries:
    class_names = sorted(entry.name for entry in entries if entry.is_dir())
class_names

In [None]:
num_classes = len(class_names)

# Start from AlexNet with ImageNet-pretrained weights.
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)

# torchvision's AlexNet has no `fc` attribute (that is the ResNet layout);
# its head is the `classifier` Sequential, whose entry at index 6 is the
# final Linear layer producing the class scores.
num_ftrs = model.classifier[6].in_features

# Swap that layer for a freshly initialised one sized for our classes.
model.classifier[6] = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

### Since this is a multiclass classification problem, cross entropy loss is a good choice. We use the SGD optimizer with momentum, together with a step learning-rate scheduler.<br> Feel free to try other choices like NLLLoss (negative log likelihood loss) or the Adam optimizer, which is one of the most used optimizers in deep learning.

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()

# SGD with momentum. The learning rate is read from `parameters` rather than
# hard-coded, so the hyperparameters saved to disk with the model match the
# ones that were actually used for training.
optimizer = optim.SGD(model.parameters(), lr=parameters['lr'], momentum=0.9)

# Multiply the learning rate by 0.1 after every 7 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

## Training loop

In [None]:
# Train for the configured number of epochs and record the mean batch loss of
# every epoch in `loss_per_epoch`.
print('Starting training...')
num_epochs = parameters['epochs']
loss_per_epoch = []
progress_bar = tqdm(range(num_epochs), desc='Total Epochs')
for epoch in progress_bar:
    # Put the model in training mode at the start of each epoch.
    model.train()
    running_loss = 0.0
    batch_bar = tqdm(train_loader, leave=True, desc='Epoch progress')
    for images, labels in batch_bar:
        # Move the batch to the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # Drop the gradients of the previous step before computing new ones.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float from the loss tensor.
        running_loss += loss.item()

    # Average over the number of batches, then advance the LR schedule once
    # per epoch.
    epoch_loss = running_loss / len(train_loader)
    scheduler.step()
    loss_per_epoch.append(epoch_loss)
    progress_bar.set_postfix({'Loss': epoch_loss})


## Loss plot

In [None]:
# Draw the values stored in loss_per_epoch as an orange line with a round
# marker on each point; the x axis is the index within the list.
plt.figure(figsize=(10, 5))
plt.plot(loss_per_epoch, color='orange', marker='o')
plt.title('Loss per epoch')
plt.show()

## Save the model

In [None]:
# Persist the whole model object (not only its state_dict) to MODEL_PATH.
torch.save(model, paths.MODEL_PATH)

# Write the hyperparameters as JSON to MODEL_PARAMS_PATH.
params_json = json.dumps(parameters)
with open(paths.MODEL_PARAMS_PATH, 'w') as params_file:
    params_file.write(params_json)