In [None]:
import copy
import multiprocessing

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

from Data import load_data # Data.py
from cyclegan import CycleGan # model.py
from cyclegan import adversarial_loss # model.py
from cyclegan import cycle_loss # model.py
from cyclegan import identity_loss # model.py

# Load the data loaders. The original comment said CIFAR-10; load_data lives in
# Data.py (not visible here), so the dataset and return order are assumed.
# NOTE(review): load_data also returns a batch size, which rebinds batch_size
# below — confirm whether it can differ from the requested 128.
batch_size = 128
train_loader, validation_loader, test_loader, classes, batch_size = load_data(batch_size)

# Hyperparameters
epochs = 50 # number of epochs main() runs
learning_rate = 0.001 # learning rate passed to Adam below
# Commented-out alternative that selects the MPS backend instead of CUDA.
#device = 'mps:0' if torch.backends.mps.is_available() else 'cpu' # change this part if using 'cuda'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # CUDA if available, otherwise CPU

# Initialize the model, loss function, and optimizer.
# NOTE(review): neither EfficientNet nor version is imported or defined in this
# file, so this line raises NameError as written. The imports bring in CycleGan
# from cyclegan, which is presumably the intended model — confirm its
# constructor arguments before switching.
model = EfficientNet(version=version, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss() # classification loss; main() passes it to train() and validation()
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Adam over all model parameters

# Training
def train(model, trainloader, criterion1, criterion2, criterion3, optimizer, device):
    """Run one training epoch and return the per-epoch mean of three losses.

    Every batch is passed through ``model`` once, and all three criteria are
    evaluated on the same ``(outputs, targets)`` pair. Only ``criterion1`` is
    backpropagated and used for the optimizer step; ``criterion2`` and
    ``criterion3`` are only accumulated for reporting.

    Args:
        model: network to train; put into training mode for the epoch.
        trainloader: iterable of ``(inputs, targets)`` batches. It does not
            need ``__len__``; batches are counted while iterating.
        criterion1: loss that drives the gradient step (labelled adversarial).
        criterion2: monitored loss (labelled cycle consistency).
        criterion3: monitored loss (labelled identity).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Tuple ``(losses1, losses2, losses3)``. Each is a one-element list
        holding that criterion's mean batch loss over the epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    model.train()  # Model in training mode
    running_loss1 = 0.0  # adversarial (main)
    running_loss2 = 0.0  # cycle consistency
    running_loss3 = 0.0  # identity
    num_batches = 0  # counted explicitly so loaders without __len__ work

    for inputs, targets in tqdm(trainloader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step

        # Forward pass
        outputs = model(inputs)
        loss1 = criterion1(outputs, targets)
        loss2 = criterion2(outputs, targets)
        loss3 = criterion3(outputs, targets)

        # Backward pass.
        # NOTE(review): only loss1 drives the update. A CycleGAN generator
        # objective usually combines adversarial, cycle and identity terms —
        # confirm whether loss2/loss3 are meant to be monitoring-only.
        loss1.backward()
        optimizer.step()

        running_loss1 += loss1.item()
        running_loss2 += loss2.item()
        running_loss3 += loss3.item()
        num_batches += 1

    if num_batches == 0:
        # Previously this surfaced as an unexplained ZeroDivisionError.
        raise ValueError("trainloader yielded no batches")

    # Keep the original return shape: three single-element lists.
    return (
        [running_loss1 / num_batches],
        [running_loss2 / num_batches],
        [running_loss3 / num_batches],
    )

# Validation
def validation(model, testloader, criterion, device):
    """Evaluate ``model`` on ``testloader`` with gradient tracking disabled.

    Args:
        model: network to evaluate; switched to evaluation mode.
        testloader: iterable of ``(inputs, targets)`` batches. It does not
            need ``__len__``; batches are counted while iterating.
        criterion: loss applied as ``criterion(outputs, targets)`` per batch.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)``:

        - ``avg_loss`` is the mean per-batch loss.
        - The other four are scikit-learn scores computed from the index of
          the maximum output along dim 1 against ``targets``.
        - Precision, recall and F1 use ``average="weighted"``.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    model.eval()  # Evaluation mode
    val_loss = 0.0
    num_batches = 0  # counted explicitly so loaders without __len__ work
    true_labels = []
    pred_labels = []

    # Disable gradient calculation for validation
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            val_loss += criterion(outputs, targets).item()
            num_batches += 1

            # Predicted class = index of the largest output along dim 1.
            _, predicted = outputs.max(1)
            pred_labels.extend(predicted.cpu().numpy())
            true_labels.extend(targets.cpu().numpy())

    if num_batches == 0:
        # Previously this failed inside scikit-learn or with ZeroDivisionError.
        raise ValueError("testloader yielded no batches")

    accuracy = accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels, average="weighted")
    recall = recall_score(true_labels, pred_labels, average="weighted")
    f1 = f1_score(true_labels, pred_labels, average="weighted")

    return val_loss / num_batches, accuracy, precision, recall, f1


def visualize_metrics(train_losses, val_losses, train_accuracies, val_accuracies,
                      train_precisions, val_precisions, train_recalls, val_recalls,
                      train_f1s, val_f1s):
    """Plot per-epoch training and validation loss, save it, then show it.

    The figure is written to ``loss.png`` in the working directory before
    ``plt.show()`` is called. Only ``train_losses`` and ``val_losses`` are
    drawn. The accuracy, precision, recall and F1 histories are accepted for
    interface compatibility but are currently unused.
    """
    _fig, axes = plt.subplots()
    curves = ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss'))
    for history, caption in curves:
        axes.plot(history, label=caption)

    axes.set_xlabel('Epochs')
    axes.set_ylabel('Loss')
    axes.legend()
    axes.set_title('Loss over Epochs')

    plt.savefig('loss.png')
    plt.show()
    



def main():
    """Train for ``epochs`` epochs, log metrics, and save the best weights.

    Each epoch:

    - Metrics are appended to ``training_log.txt`` (overwritten on every run)
      and printed.
    - The weights from the epoch with the lowest validation loss are kept.

    After the loop:

    - The kept weights are saved to ``best_efficientnet_epoch<N>.pth``.
    - The collected metric histories are passed to ``visualize_metrics``.

    Relies on the module-level ``model``, ``criterion``, ``optimizer``,
    ``train_loader``, ``validation_loader``, ``device`` and ``epochs``.
    """
    best_val_loss = float('inf')
    best_model_state = None
    best_epoch = 0

    # Per-epoch histories for visualize_metrics(). These were referenced but
    # never defined before, so the function failed with NameError.
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    train_precisions, val_precisions = [], []
    train_recalls, val_recalls = [], []
    train_f1s, val_f1s = [], []

    with open("training_log.txt", "w", encoding="utf-8") as f:
        f.write("Epoch, Train Loss, Val Loss, Train Acc, Val Acc, Train Prec, Val Prec, Train Recall, Val Recall, Train F1, Val F1\n")

        # Main training loop
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')

            # NOTE(review): train() is declared as
            #   train(model, trainloader, criterion1, criterion2, criterion3, optimizer, device)
            # and returns three one-element loss lists. This 5-argument call
            # and 5-value unpacking therefore raise TypeError. Reconcile the
            # two (e.g. pass the adversarial/cycle/identity losses imported
            # from cyclegan) once the intended training objective is settled.
            train_loss, train_acc, train_prec, train_recall, train_f1 = train(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc, val_prec, val_recall, val_f1 = validation(model, validation_loader, criterion, device)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_accuracies.append(train_acc)
            val_accuracies.append(val_acc)
            train_precisions.append(train_prec)
            val_precisions.append(val_prec)
            train_recalls.append(train_recall)
            val_recalls.append(val_recall)
            train_f1s.append(train_f1)
            val_f1s.append(val_f1)

            # Log metrics to file
            f.write(f"{epoch + 1}, {train_loss:.4f}, {val_loss:.4f}, {train_acc:.4f}, {val_acc:.4f}, "
                    f"{train_prec:.4f}, {val_prec:.4f}, {train_recall:.4f}, {val_recall:.4f}, "
                    f"{train_f1:.4f}, {val_f1:.4f}\n")

            # Display metrics
            print(f'Epoch {epoch + 1}')
            print(f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}, Training Precision: {train_prec:.4f}, Training Recall: {train_recall:.4f}, Training F1: {train_f1:.4f}')
            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}, Validation Precision: {val_prec:.4f}, Validation Recall: {val_recall:.4f}, Validation F1: {val_f1:.4f}')

            # Remember the weights from the epoch with the lowest validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns tensors that share storage with the live
                # parameters, and later optimizer steps update those in place.
                # Without a deep copy, the "best" checkpoint would actually hold
                # the final epoch's weights.
                best_model_state = copy.deepcopy(model.state_dict())
                best_epoch = epoch + 1

    if best_model_state is not None:
        torch.save(best_model_state, f"best_efficientnet_epoch{best_epoch}.pth")
        print("Best model saved with validation loss:", best_val_loss)

    visualize_metrics(train_losses, val_losses, train_accuracies, val_accuracies,
                      train_precisions, val_precisions, train_recalls, val_recalls, train_f1s, val_f1s)

if __name__ == "__main__":
    # Required when the script is frozen into a Windows executable that uses
    # multiprocessing (e.g. DataLoader workers); it has no effect otherwise.
    multiprocessing.freeze_support()
    main()

In [2]:
# NOTE(review): looks like leftover scratch output from the notebook session;
# consider removing this cell.
print('ghgg')

ghgg
