In [12]:
import os
import pandas as pd
import json
from PIL import Image
from sklearn.preprocessing import OneHotEncoder
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [2]:
# 定义基础路径
basic_root = 'cassava-leaf-disease-classification'
image_path = os.path.join(basic_root, 'train_images')
csv_path = os.path.join(basic_root, 'train.csv')
json_path = os.path.join(basic_root, 'label_num_to_disease_map.json')

In [3]:
# 读取CSV和JSON文件
image_csv = pd.read_csv(csv_path)
image_json = json.load(open(json_path))
image_json_tolabel = {int(i): j for i, j in image_json.items()}
image_csv['label_name'] = image_csv['label'].map(image_json_tolabel)

# 只处理前3000张图像
image_csv = image_csv.iloc[:3000]

In [4]:
# 数据增强和归一化
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
class CassavaDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = csv_file
        self.img_dir = img_dir
        self.transform = transform
        self.one_hot_encoder = OneHotEncoder(sparse=False)
        self.labels = self.one_hot_encoder.fit_transform(self.annotations[['label']].values)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(self.labels[index], dtype=torch.float)
        return image, label

dataset = CassavaDataset(image_csv, image_path, transform=transform)

In [13]:
dataset

<__main__.CassavaDataset at 0x12eff4cee80>

In [6]:
# 拆分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [7]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 32 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 128),
            nn.ReLU(),

            nn.Linear(128, 5)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 128 * 32 * 32)
        x = self.fc_layers(x)
        return x

model = CNN()

In [8]:


# 检查CUDA是否可用并选择设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 定义模型、损失函数和优化器
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

best_accuracy = 0.0
epochs = 10

In [9]:
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch'):
        # 将数据移到GPU
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, torch.max(labels, 1)[1])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_dataloader):.4f}')
    
    # 在验证集上评估模型
    model.eval()
    val_predictions = []
    val_labels = []
    with torch.no_grad():
        for images, labels in tqdm(val_dataloader, desc='Validating', unit='batch'):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            val_predictions.extend(predicted.cpu().numpy())
            val_labels.extend(torch.max(labels, 1)[1].cpu().numpy())
    
    # 计算验证集上的准确率
    accuracy = accuracy_score(val_labels, val_predictions)
    print(f'验证集上的准确率: {accuracy:.4f}')
    
    # 保存具有最高验证准确率的模型
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')

Epoch 1/10: 100%|██████████| 75/75 [00:46<00:00,  1.60batch/s]


Epoch [1/10], Loss: 1.3208


Validating: 100%|██████████| 19/19 [00:09<00:00,  2.08batch/s]


验证集上的准确率: 0.6183


Epoch 2/10: 100%|██████████| 75/75 [00:38<00:00,  1.95batch/s]


Epoch [2/10], Loss: 1.1664


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.13batch/s]


验证集上的准确率: 0.6183


Epoch 3/10: 100%|██████████| 75/75 [00:39<00:00,  1.92batch/s]


Epoch [3/10], Loss: 1.1476


Validating: 100%|██████████| 19/19 [00:09<00:00,  2.10batch/s]


验证集上的准确率: 0.6183


Epoch 4/10: 100%|██████████| 75/75 [00:38<00:00,  1.95batch/s]


Epoch [4/10], Loss: 1.1240


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.13batch/s]


验证集上的准确率: 0.6150


Epoch 5/10: 100%|██████████| 75/75 [00:38<00:00,  1.94batch/s]


Epoch [5/10], Loss: 1.0730


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.12batch/s]


验证集上的准确率: 0.6167


Epoch 6/10: 100%|██████████| 75/75 [00:38<00:00,  1.93batch/s]


Epoch [6/10], Loss: 1.0662


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.12batch/s]


验证集上的准确率: 0.6383


Epoch 7/10: 100%|██████████| 75/75 [00:39<00:00,  1.89batch/s]


Epoch [7/10], Loss: 1.0403


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.11batch/s]


验证集上的准确率: 0.6400


Epoch 8/10: 100%|██████████| 75/75 [00:39<00:00,  1.92batch/s]


Epoch [8/10], Loss: 1.0260


Validating: 100%|██████████| 19/19 [00:09<00:00,  2.08batch/s]


验证集上的准确率: 0.6417


Epoch 9/10: 100%|██████████| 75/75 [00:41<00:00,  1.79batch/s]


Epoch [9/10], Loss: 0.9925


Validating: 100%|██████████| 19/19 [00:09<00:00,  1.94batch/s]


验证集上的准确率: 0.6417


Epoch 10/10: 100%|██████████| 75/75 [01:02<00:00,  1.19batch/s]


Epoch [10/10], Loss: 1.0114


Validating: 100%|██████████| 19/19 [00:08<00:00,  2.12batch/s]

验证集上的准确率: 0.6317





In [10]:
# 加载测试集数据
test_csv_path = 'cassava-leaf-disease-classification/sample_submission.csv'
test_csv = pd.read_csv(test_csv_path)
test_image_path = os.path.join(basic_root, 'test_images')

class CassavaTestDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = csv_file
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

test_dataset = CassavaTestDataset(test_csv, test_image_path, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 加载最佳模型
model.load_state_dict(torch.load('best_model.pth'))

<All keys matched successfully>

In [1]:
# 预测并保存结果
model.eval()
predictions = []
with torch.no_grad():
    for images in tqdm(test_dataloader, desc='Predicting', unit='batch'):
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())

test_csv['label'] = predictions
test_csv['label_name'] = test_csv['label'].map(image_json_tolabel)
print(test_csv)

NameError: name 'model' is not defined