In [36]:
import torch
import os
from torchvision import datasets, transforms
# Root folder of the image dataset and the compute device used throughout.
data = "./Fashion Images"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Convert images to tensors. Normalize with mean 0 and std 1 computes
# (x - 0) / 1, i.e. it leaves the tensor values unchanged.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.0,), (1,)),
    ]
)

# One ImageFolder per split sub-directory, built in train / test / valid order.
trainset, testset, validset = (
    datasets.ImageFolder(root=f"{data}/{split}", transform=transform)
    for split in ("train", "test", "valid")
)

# Batches of 64 everywhere; only the training loader shuffles.
trainloader, testloader, validloader = (
    torch.utils.data.DataLoader(dataset=split_set, batch_size=64, shuffle=do_shuffle)
    for split_set, do_shuffle in ((trainset, True), (testset, False), (validset, False))
)

In [37]:
# Display the class-name -> label-index mapping of the training set.
class_index_map = trainset.class_to_idx
print(class_index_map)

{'Female': 0, 'Male': 1}


In [38]:
# Pull a single batch from the training loader and show its tensor shape and
# labels; nothing is printed when the loader yields no batches.
first_batch = next(iter(trainloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.shape)
    print(labels)


torch.Size([64, 3, 640, 640])
tensor([1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0])


In [39]:
# First line: number of samples per split; second line: number of batches per loader.
print(*(len(split_set) for split_set in (trainset, testset, validset)))
print(*(len(loader) for loader in (trainloader, testloader, validloader)))

850 50 100
14 1 2


In [40]:
class MLP(torch.nn.Module):
    """Three-layer fully connected classifier over flattened inputs.

    Args:
        in_features: Number of values per flattened sample. Defaults to
            3*640*640, the per-sample size of the batches this notebook feeds in.
        hidden_features: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=3 * 640 * 640, hidden_features=16, num_classes=2):
        super().__init__()
        # Remembered so forward() flattens to exactly what fc1 expects.
        self.in_features = in_features
        # The attribute names fc1/fc2/fc3 are the state_dict keys; keep them
        # so previously saved checkpoints still load.
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc2 = torch.nn.Linear(hidden_features, hidden_features)
        self.fc3 = torch.nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Flatten ``x`` to rows of ``in_features`` values and return raw logits."""
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. a
        # permuted channels-last batch, copying only when it has to.
        x = x.reshape(-1, self.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


model = MLP()

In [41]:
# Adam (learning rate 1e-3) over every parameter of the model, paired with a
# cross-entropy loss on the model's output logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [42]:
# Folder that receives one weights file per epoch; created if it is missing.
checkpoint_dir = "checkpoints_mlp_fashion_images"
os.makedirs(name=checkpoint_dir, exist_ok=True)

def find_last_checkpoint(directory=None):
    """Locate the saved model checkpoint with the highest epoch number.

    Files are considered when their name starts with ``model_epoch_`` and the
    text between the last ``_`` and the following ``.`` parses as an integer
    (``model_epoch_12.pth`` -> 12). Other files are ignored rather than
    crashing the scan.

    Args:
        directory: Folder to scan. Defaults to the module-level
            ``checkpoint_dir``.

    Returns:
        tuple: ``(path, epoch)`` of the newest checkpoint, or ``(None, 0)``
        when the folder is missing or holds no usable checkpoint.
    """
    if directory is None:
        directory = checkpoint_dir
    if not os.path.isdir(directory):
        return None, 0

    found = []  # (epoch, filename) pairs for every parsable checkpoint
    for filename in os.listdir(directory):
        if not filename.startswith('model_epoch_'):
            continue
        try:
            epoch = int(filename.split('_')[-1].split('.')[0])
        except ValueError:
            # e.g. 'model_epoch_best.pth' -- not an epoch-numbered checkpoint.
            continue
        found.append((epoch, filename))

    if not found:
        return None, 0
    # A single max() pass replaces sorting the whole list.
    last_epoch, last_checkpoint = max(found, key=lambda pair: pair[0])
    return os.path.join(directory, last_checkpoint), last_epoch

In [43]:
def evaluate_model(testloader):
    """Measure classification accuracy of the global ``model`` on ``testloader``.

    When ``find_last_checkpoint()`` reports a checkpoint, its weights are
    loaded into the global ``model`` first. When there is none, the model is
    evaluated with whatever weights it currently holds, as the printed
    message states (previously this branch returned the string
    ``"No checkpoints"``, which callers could not use as a number).

    Side effects: may overwrite the global model's weights, moves the model to
    the global ``device``, and leaves it in eval mode.

    Args:
        testloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: Accuracy as a percentage; 0.0 if the loader yields no samples.
    """
    checkpoint_path, _ = find_last_checkpoint()
    if checkpoint_path:
        print(f"Loading latest checkpoint from {checkpoint_path}")
        # map_location lets a checkpoint saved on another device load here;
        # weights_only=True restricts unpickling to tensors/plain containers.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        print("No checkpoint found. Evaluating with the current model state.")

    # Inputs are moved to `device` below, so the model must live there too.
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit in each row.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Accuracy: {accuracy}%")
    return accuracy

In [44]:
# Pick the MPS backend when it is available, otherwise the CPU, and move the
# model onto that device.
use_mps = torch.backends.mps.is_available()
device = torch.device("mps") if use_mps else torch.device("cpu")
model.to(device)
def train_model(model, trainloader, criterion, optimizer, num_epochs=10):
    """Train ``model``, resuming from the newest checkpoint when one exists.

    After every epoch the model weights are written to
    ``model_epoch_<N>.pth`` and the optimizer state to
    ``optimizer_epoch_<N>.pth`` inside the global ``checkpoint_dir``; the
    model is then scored on the global ``validloader`` via ``evaluate_model``.
    The optimizer file name does not start with ``model_epoch_``, so
    ``find_last_checkpoint`` does not pick it up.

    Args:
        model: Network to optimise; expected to already be on the global ``device``.
        trainloader: Iterable yielding ``(inputs, labels)`` training batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epochs: Total epoch count to reach (not the number of additional
            epochs); nothing runs if the latest checkpoint is already at or
            past this number.
    """
    checkpoint_path, start_epoch = find_last_checkpoint()
    if checkpoint_path:
        print(f"Resuming training from {checkpoint_path}")
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device, weights_only=True)
        )
        # Restore the optimizer state saved alongside the weights, if present.
        # Without it a resumed run continues with a freshly initialised
        # optimizer. Older checkpoints have no such file and are still accepted.
        optimizer_path = os.path.join(checkpoint_dir, f'optimizer_epoch_{start_epoch}.pth')
        if os.path.exists(optimizer_path):
            optimizer.load_state_dict(
                torch.load(optimizer_path, map_location=device, weights_only=True)
            )
    else:
        print("No checkpoint found. Starting from scratch.")

    num_batches = len(trainloader)
    for epoch in range(start_epoch, num_epochs):
        # evaluate_model() switches the model to eval mode at the end of each
        # epoch, so training mode has to be re-enabled here.
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
        torch.save(model.state_dict(), checkpoint_path)
        torch.save(
            optimizer.state_dict(),
            os.path.join(checkpoint_dir, f'optimizer_epoch_{epoch + 1}.pth'),
        )
        print(f"Model saved at {checkpoint_path}")

        val_accuracy = float(evaluate_model(validloader))
        # Mean loss per batch; guard against a loader with no batches.
        mean_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}, Validation Accuracy: {val_accuracy:.2f}%, Loss: {mean_loss:.4f}")
# Start (or resume) training, running until epoch 30 has been completed.
train_model(
    model=model,
    trainloader=trainloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
)

Resuming training from checkpoints_mlp_fashion_images/model_epoch_15.pth


  model.load_state_dict(torch.load(checkpoint_path))


Model saved at checkpoints_mlp_fashion_images/model_epoch_16.pth
Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_16.pth


  model.load_state_dict(torch.load(checkpoint_path))


Accuracy: 23.0%
Epoch 16, Validation Accuracy: 23.00%, Loss: 61.7115
Model saved at checkpoints_mlp_fashion_images/model_epoch_17.pth
Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_17.pth
Accuracy: 80.0%
Epoch 17, Validation Accuracy: 80.00%, Loss: 10.0451
Model saved at checkpoints_mlp_fashion_images/model_epoch_18.pth
Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_18.pth
Accuracy: 58.0%
Epoch 18, Validation Accuracy: 58.00%, Loss: 4.1673
Model saved at checkpoints_mlp_fashion_images/model_epoch_19.pth
Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_19.pth
Accuracy: 80.0%
Epoch 19, Validation Accuracy: 80.00%, Loss: 1.2012
Model saved at checkpoints_mlp_fashion_images/model_epoch_20.pth
Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_20.pth
Accuracy: 80.0%
Epoch 20, Validation Accuracy: 80.00%, Loss: 0.4700
Model saved at checkpoints_mlp_fashion_images/model_epoch_21.pth
Loading l

In [45]:
# Score the model on the test loader, then show its layer structure.
test_accuracy = evaluate_model(testloader)
print(model)

Loading latest checkpoint from checkpoints_mlp_fashion_images/model_epoch_30.pth


  model.load_state_dict(torch.load(checkpoint_path))


Accuracy: 88.0%
MLP(
  (fc1): Linear(in_features=1228800, out_features=16, bias=True)
  (fc2): Linear(in_features=16, out_features=16, bias=True)
  (fc3): Linear(in_features=16, out_features=2, bias=True)
)
