### Imports

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
import os
import json
import matplotlib.pyplot as plt

### EarlyStopping

In [None]:
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Each call compares the new validation loss with the best one seen so far.
    A loss that is lower than the best by at least ``delta`` (with the default
    ``delta=0`` an equal loss also qualifies) counts as an improvement: the
    model's ``state_dict`` is saved to ``path`` and the patience counter is
    reset. Otherwise the counter grows, and once it reaches ``patience``,
    ``early_stop`` is set to ``True``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: Print counter and checkpoint messages when ``True``.
        delta: Minimum decrease in validation loss that counts as improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # negated best validation loss; None until the first call
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        """Record ``val_loss`` and checkpoint ``model`` to ``path`` on improvement."""
        # Work with the negated loss so that "higher score is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``path`` and remember ``val_loss`` as the best."""
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        torch.save(model.state_dict(), path)
        self.val_loss_min = val_loss

### Model Preparation

In [2]:
# Pretrained ResNet-50 backbone, used as a frozen feature extractor.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of the `weights=` argument — confirm against the installed version.
resnet50 = models.resnet50(pretrained=True)

# Module.requires_grad_(False) disables gradients for every existing parameter.
resnet50.requires_grad_(False)

# Replace the classification head with a 2-output linear layer. It is created
# after the freeze above, so its parameters remain trainable.
num_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(num_features, 2)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/sit3kk/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 34.6MB/s]


### Transformation and Data Preparation

In [5]:
# Preprocessing applied to both splits: resize the shorter side to 256, take
# the central 224x224 crop, convert to a tensor, then normalize with the
# mean/std values commonly used for ImageNet-pretrained torchvision models.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Full train/val splits, read from class-per-subdirectory folders by ImageFolder.
train_dataset_full, val_dataset_full = (
    ImageFolder(root="../data/processed/train", transform=transform),
    ImageFolder(root="../data/processed/val", transform=transform),
)


def limit_dataset(dataset, limit=0.1, seed=None):
    """Return a random subset containing a ``limit`` fraction of ``dataset``.

    Indices are drawn without replacement. A non-empty dataset always keeps
    at least one sample, so per-epoch averages computed over the subset never
    divide by zero.

    Args:
        dataset: Any sized, indexable dataset accepted by ``Subset``.
        limit: Fraction of samples to keep, in ``(0, 1]``.
        seed: Optional seed for a reproducible selection. ``None`` keeps the
            previous behaviour of drawing from NumPy's global random state.

    Returns:
        A ``torch.utils.data.Subset`` over the chosen indices.

    Raises:
        ValueError: If ``limit`` is outside ``(0, 1]``.
    """
    if not 0 < limit <= 1:
        raise ValueError(f"limit must be in (0, 1], got {limit!r}")

    total = len(dataset)
    # int() truncates, which would round a small fraction down to zero samples.
    size = min(total, max(1, int(total * limit)))
    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.choice(total, size, replace=False)
    return Subset(dataset, indices)


# Keep a random 10% of each split (train is sampled first, then val).
train_dataset, val_dataset = (
    limit_dataset(train_dataset_full, 0.1),
    limit_dataset(val_dataset_full, 0.1),
)

# Batches of 32; only the training loader reshuffles each epoch.
train_loader, val_loader = (
    DataLoader(dataset=train_dataset, batch_size=32, shuffle=True),
    DataLoader(dataset=val_dataset, batch_size=32, shuffle=False),
)

### Preparation of the device, loss criterion and optimizer

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

resnet50 = resnet50.to(device)
criterion = nn.CrossEntropyLoss()
# Only the parameters of the replacement head (resnet50.fc) are handed to Adam.
optimizer = torch.optim.Adam(resnet50.fc.parameters(), lr=0.001)

In [4]:
checkpoint_dir = "checkpoints/ResNet/"
saved_model_dir = "saved_models/ResNet/"

# exist_ok=True replaces the racy exists-then-create check and makes re-running
# this cell a no-op when the directories are already there.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(saved_model_dir, exist_ok=True)

# The EarlyStopping checkpoint and the final exported weights are stored in
# separate directories.
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")
saved_model_path = os.path.join(saved_model_dir, "model.pth")

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    With an ``optimizer``, the model is put in training mode and updated after
    every batch. Without one, the model is put in evaluation mode and
    gradients are disabled.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    # NOTE(review): train mode also lets BatchNorm layers update their running
    # statistics even when their parameters are frozen — confirm that is intended.
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # The criterion returns a batch mean; re-weight it by the batch size.
            total_loss += loss.item() * batch_size
            # .item() keeps the counts as Python ints (JSON-serializable, device-free).
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

    if total_seen == 0:
        raise ValueError("DataLoader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


def train_model(
    model,
    dataloaders,
    criterion,
    optimizer,
    num_epochs=25,
    patience=5,
    checkpoint_path="checkpoint.pth",
):
    """Train ``model`` with early stopping and restore the best checkpoint.

    Args:
        model: Network to train; batches are moved to the device of its parameters.
        dataloaders: Mapping with ``"train"`` and ``"val"`` DataLoaders.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the trainable parameters.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation-loss improvement before stopping.
        checkpoint_path: File where EarlyStopping saves the best ``state_dict``.

    Returns:
        ``(model, history)``. ``model`` has the best checkpoint loaded.
        ``history`` maps ``"train_loss"``, ``"train_acc"``, ``"val_loss"`` and
        ``"val_acc"`` to per-epoch lists of Python floats.
    """
    # Follow the model's own device instead of relying on a notebook-global `device`.
    device = next(model.parameters()).device
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _run_epoch(
            model, dataloaders["train"], criterion, device, optimizer
        )
        print(
            f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}"
        )
        history["train_loss"].append(epoch_loss)
        history["train_acc"].append(epoch_acc)

        val_epoch_loss, val_epoch_acc = _run_epoch(
            model, dataloaders["val"], criterion, device
        )
        print(
            f"Validation Loss: {val_epoch_loss:.4f}, Validation Acc: {val_epoch_acc:.4f}"
        )
        history["val_loss"].append(val_epoch_loss)
        history["val_acc"].append(val_epoch_acc)

        early_stopping(val_epoch_loss, model, checkpoint_path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # A checkpoint exists only if EarlyStopping was called at least once
    # (num_epochs may be 0).
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model, history


# Split name -> DataLoader, in the form train_model expects.
dataloaders = dict(train=train_loader, val=val_loader)

### Model training

In [None]:
# Train for up to 10 epochs. train_model reloads the weights saved at
# checkpoint_path by early stopping before returning.
model, history = train_model(
    model=resnet50,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    patience=5,
    checkpoint_path=checkpoint_path,
)

# Also export the returned weights to the saved-model location.
torch.save(model.state_dict(), saved_model_path)

In [None]:
def load_history(history_path):
    """Load a training-history dict previously written as JSON to ``history_path``."""
    # `json` is already imported at module level; no function-local import needed.
    with open(history_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_history(history, history_path):
    """Write a training-history dict to ``history_path`` as JSON.

    Values may be plain Python numbers or tensor/NumPy scalars and arrays
    (for example accuracies accumulated with ``torch.sum``). ``json`` cannot
    encode those directly, so they are converted with ``.tolist()``.

    Raises:
        TypeError: If a value is neither JSON-native nor exposes ``.tolist()``.
    """

    def _to_builtin(obj):
        # json calls this only for objects it cannot encode natively.
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, default=_to_builtin)


# Round-trip the history through JSON, then plot from the reloaded copy.
save_history(history, "training_history.json")
history = load_history("training_history.json")


def _plot_curves(series, ylabel):
    """Draw one new figure with a line per (history key, legend label) pair."""
    plt.figure()
    for key, label in series:
        plt.plot(history[key], label=label)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


_plot_curves([("train_loss", "Train Loss"), ("val_loss", "Validation Loss")], "Loss")
_plot_curves([("val_acc", "Validation Accuracy")], "Accuracy")