# In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from transformers import ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt
from tqdm import tqdm
from huggingface_hub import notebook_login
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import requests
import pandas as pd
from sklearn.metrics import recall_score

def _build_model(model_name, num_classes):
    """Load a pretrained backbone and swap its head for a ``num_classes``-way linear layer.

    Args:
        model_name: One of 'resnet50', 'resnet18', 'efficientnet_b3', 'vgg16',
            'vgg19', 'vit_base', 'vit_huge', 'vit_s', 'vit_swag'.
        num_classes: Output size of the replacement classification layer.

    Returns:
        The model with its final classification layer replaced.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    # Hugging Face checkpoint used for each ViT variant.
    # NOTE(review): 'vit_s' currently loads the *base* checkpoint, exactly as
    # the original code did; confirm whether a small checkpoint was intended.
    vit_checkpoints = {
        'vit_base': 'google/vit-base-patch16-224',
        'vit_huge': 'google/vit-huge-patch14-224-in21k',
        'vit_s': 'google/vit-base-patch16-224',
        'vit_swag': 'google/vit-huge-patch14-224-swag',
    }

    # NOTE(review): `pretrained=True` is kept for compatibility with the
    # torchvision version this file was written against; newer releases
    # prefer the `weights=` argument.
    if model_name in ('resnet50', 'resnet18'):
        model = getattr(models, model_name)(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == 'efficientnet_b3':
        model = models.efficientnet_b3(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif model_name in ('vgg16', 'vgg19'):
        model = getattr(models, model_name)(pretrained=True)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif model_name in vit_checkpoints:
        # Every ViT variant goes through ViTForImageClassification so that the
        # forward pass returns an object with `.logits`, which the training
        # loop reads. A bare ViTModel has no such attribute and would never
        # call a `classifier` layer attached to it afterwards.
        model = ViTForImageClassification.from_pretrained(vit_checkpoints[model_name])
        model.classifier = nn.Linear(model.config.hidden_size, num_classes)
    else:
        raise ValueError("Unknown model name")
    return model


def _plot_and_save(df, series, ylabel, title, out_path):
    """Plot DataFrame columns against 'Epoch', save the figure, show it and close it.

    Args:
        df: DataFrame holding an 'Epoch' column plus every column named in ``series``.
        series: Iterable of ``(column_name, plot_kwargs)`` pairs; the column
            name is also used as the legend label.
        ylabel: Label for the y axis.
        title: Figure title.
        out_path: File path the figure is written to.
    """
    plt.figure(figsize=(10, 5))
    for column, style in series:
        plt.plot(df['Epoch'], df[column], label=column, **style)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(out_path)
    plt.show()
    plt.close()


def train_and_evaluate_model(model_name, train_loader, val_loader, num_classes, device, base_save_path, epochs=40, lr=0.0002, patience=50):
    """Fine-tune one pretrained classifier and record its per-epoch metrics.

    The model is trained with Adam and cross-entropy loss. After every epoch
    it is evaluated on ``val_loader``; whenever validation accuracy strictly
    improves, the state dict is written to
    ``<base_save_path>/best_model_<model_name>.pth``. Training stops early
    once ``patience`` consecutive epochs bring no improvement. Afterwards the
    per-epoch records are written to a CSV file, and a loss curve and an
    accuracy/recall curve are saved as PNG files in ``base_save_path``.

    Args:
        model_name: Architecture key understood by ``_build_model``.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Iterable yielding ``(images, labels)`` validation batches;
            must expose ``.dataset`` so accuracy can be normalised.
        num_classes: Number of target classes.
        device: Device the model and batches are moved to.
        base_save_path: Directory for the checkpoint, CSV and plots; created
            if it does not exist.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        patience: Number of consecutive non-improving epochs tolerated before
            early stopping.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    model = _build_model(model_name, num_classes).to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Make sure the output directory exists before anything is written to it.
    if base_save_path:
        os.makedirs(base_save_path, exist_ok=True)
    save_path = os.path.join(base_save_path, f'best_model_{model_name}.pth')

    # Hugging Face ViT models wrap their output; torchvision models return
    # the logits tensor directly.
    is_vit = model_name.startswith('vit')

    best_acc = 0.0
    epochs_no_improve = 0  # consecutive epochs without a new best accuracy
    records = []  # one (epoch, train_loss, val_loss, val_accuracy, val_recall) row per epoch

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = model(images.to(device))
            logits = outputs.logits if is_vit else outputs
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)

        train_loss = running_loss / len(train_loader)

        # ---- validation pass (loss, accuracy and macro recall) ----
        model.eval()
        val_running_loss = 0.0
        correct = 0
        all_val_labels = []
        all_predictions = []
        with torch.no_grad():
            for val_images, val_labels in tqdm(val_loader, file=sys.stdout):
                val_labels = val_labels.to(device)
                outputs = model(val_images.to(device))
                logits = outputs.logits if is_vit else outputs
                val_running_loss += loss_function(logits, val_labels).item()
                predict_y = logits.argmax(dim=1)
                correct += torch.eq(predict_y, val_labels).sum().item()
                all_val_labels.extend(val_labels.cpu().numpy())
                all_predictions.extend(predict_y.cpu().numpy())

        val_loss = val_running_loss / len(val_loader)
        val_accuracy = correct / len(val_loader.dataset)
        val_recall = recall_score(all_val_labels, all_predictions, average='macro')

        # The stored epoch index is 0-based, while the console message below is 1-based.
        records.append((epoch, train_loss, val_loss, val_accuracy, val_recall))

        print('[epoch %d] train_loss: %.3f  val_loss: %.3f val_accuracy: %.3f val_recall: %.3f' % (epoch + 1, train_loss, val_loss, val_accuracy, val_recall))

        # ---- checkpointing and early stopping ----
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model.state_dict(), save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print('Early stopping triggered after {} epochs with no improvement'.format(patience))
                break

    # Persist the per-epoch history and the two diagnostic plots.
    df = pd.DataFrame(records, columns=['Epoch', 'Train Loss', 'Validation Loss', 'Validation Accuracy', 'Validation Recall'])
    df.to_csv(os.path.join(base_save_path, f'{model_name}_training_records.csv'), index=False)

    _plot_and_save(
        df,
        [('Train Loss', {}), ('Validation Loss', {})],
        'Loss',
        f'Training and Validation Loss - {model_name}',
        os.path.join(base_save_path, f'{model_name}_loss_curve.png'),
    )
    _plot_and_save(
        df,
        [('Validation Accuracy', {}), ('Validation Recall', {'linestyle': '--'})],
        'Metrics',
        f'Validation Accuracy and Recall - {model_name}',
        os.path.join(base_save_path, f'{model_name}_accuracy_recall_curve.png'),
    )

    print(f'Finished Training {model_name}')

# Main program
def main():
    """Train each configured architecture in turn and save its artefacts.

    NOTE(review): ``train_dataset``, ``val_dataset`` and ``full_dataset`` are
    not created in this function; they are presumably module-level names
    defined elsewhere (e.g. in another notebook cell) - confirm before running.
    """
    base_save_path = '/content'  # root directory for checkpoints, CSVs and plots
    batch_size = 32
    epochs = 600

    # Prefer the first CUDA device when one is available.
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print("Using {} device.".format(device))

    # Data transforms: random crop + horizontal flip for training, a plain
    # resize for validation; both convert to tensors and normalise each
    # channel with mean 0.5 and std 0.5.
    # NOTE(review): `data_transform` is built here but not referenced again in
    # this function - presumably the datasets were created with equivalent
    # transforms elsewhere; verify.
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    data_transform = {"train": train_transform, "val": val_transform}

    # Wrap the datasets in loaders; only the training loader is shuffled.
    shared_loader_options = {"batch_size": batch_size, "num_workers": 4}
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared_loader_options)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared_loader_options)

    num_classes = len(full_dataset.classes)

    # Architectures to train. 'vit_huge', 'vit_s' and 'vit_swag' are supported
    # by train_and_evaluate_model but are left out of this run.
    model_names = ('resnet50', 'resnet18', 'efficientnet_b3', 'vit_base', 'vgg16', 'vgg19')
    for model_name in model_names:
        print(f'Training {model_name}...')
        train_and_evaluate_model(
            model_name,
            train_loader,
            val_loader,
            num_classes,
            device,
            base_save_path,
            epochs,
        )

# Start the training run only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
