In [1]:
random_n=42
import torch
from torchvision.models import regnet_y_8gf
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
path="/content/drive/MyDrive/strawberry/berry_jpg_cropped"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device={device}")
import torchvision.transforms.v2 as v2
class MyDataset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, index):
        x, y,z = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y,z
    def __len__(self):
        return len(self.subset)
class StrawberryDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = self._generate_labels()
    def _generate_labels(self):
        labels = []
        for i in range(1, 271):
            if i <= 90:
                labels.append(0)  # C
            elif i <= 180:
                labels.append(1)  # B
            else:
                labels.append(2)  # A
        return labels
    def __len__(self):
        return len(os.listdir(self.img_dir))
    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, f'strawberry{idx+1:03d}.jpg')
        image = Image.open(img_name)
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label, img_name
transform_train = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
whole_ds = StrawberryDataset(img_dir=path)
l1=np.repeat(np.arange(3),30)
print(l1)
train1,test1=train_test_split(range(90), test_size=30, random_state=random_n, stratify=l1)
print(test1)
train_idx=[]
for i in train1:
  train_idx.append(i*3)
  train_idx.append(i*3+1)
  train_idx.append(i*3+2)
test_idx=[]
for i in test1:
  test_idx.append(i*3)
  test_idx.append(i*3+1)
  test_idx.append(i*3+2)
print(test_idx)
train_split = Subset(whole_ds, train_idx)
test_split = Subset(whole_ds, test_idx)
train_dataset = MyDataset(train_split, transform=transform_train)
test_dataset = MyDataset(test_split, transform=transform_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
cutmix = v2.CutMix(num_classes=3)
mixup = v2.MixUp(num_classes=3)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
model = regnet_y_8gf(weights="IMAGENET1K_V2")
model.fc = torch.nn.Linear(model.fc.in_features, 3)
T_max=20
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001,weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
model.to(device)
num_epochs = T_max
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels, img_names in train_loader:
        images, labels = images.to(device), labels.to(device)
        images, labels = cutmix_or_mixup(images, labels)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()
    model.eval()
    correct_test = 0
    total_test = 0
    misclassified = []
    with torch.no_grad():
        for images, labels, img_names in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()
            for img_name, label, pred in zip(img_names, labels, predicted):
                if label != pred:
                    misclassified.append((img_name, label.item(), pred.item()))
    test_accuracy = 100 * correct_test / total_test
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, '
          f'Testing Accuracy: {test_accuracy:.2f}%')
print("\nMisclassified images:")
for img_name, true_label, pred_label in misclassified:
    print(f'{img_name}: True Label: {["C", "B", "A"][true_label]}, Predicted Label: {["C", "B", "A"][pred_label]}')

device=cuda
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
[58, 49, 14, 48, 21, 71, 87, 19, 26, 7, 18, 76, 63, 44, 77, 66, 32, 55, 74, 73, 20, 29, 39, 45, 10, 57, 41, 67, 6, 62]
[174, 175, 176, 147, 148, 149, 42, 43, 44, 144, 145, 146, 63, 64, 65, 213, 214, 215, 261, 262, 263, 57, 58, 59, 78, 79, 80, 21, 22, 23, 54, 55, 56, 228, 229, 230, 189, 190, 191, 132, 133, 134, 231, 232, 233, 198, 199, 200, 96, 97, 98, 165, 166, 167, 222, 223, 224, 219, 220, 221, 60, 61, 62, 87, 88, 89, 117, 118, 119, 135, 136, 137, 30, 31, 32, 171, 172, 173, 123, 124, 125, 201, 202, 203, 18, 19, 20, 186, 187, 188]


Downloading: "https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth" to /root/.cache/torch/hub/checkpoints/regnet_y_8gf-dc2b1b54.pth
100%|██████████| 151M/151M [00:00<00:00, 199MB/s]


Epoch [1/20], Loss: 1.0372, Testing Accuracy: 85.56%
Epoch [2/20], Loss: 0.8505, Testing Accuracy: 93.33%
Epoch [3/20], Loss: 0.6096, Testing Accuracy: 95.56%
Epoch [4/20], Loss: 0.4846, Testing Accuracy: 97.78%
Epoch [5/20], Loss: 0.5475, Testing Accuracy: 96.67%
Epoch [6/20], Loss: 0.5924, Testing Accuracy: 95.56%
Epoch [7/20], Loss: 0.5791, Testing Accuracy: 97.78%
Epoch [8/20], Loss: 0.5388, Testing Accuracy: 97.78%
Epoch [9/20], Loss: 0.5200, Testing Accuracy: 97.78%
Epoch [10/20], Loss: 0.4969, Testing Accuracy: 97.78%
Epoch [11/20], Loss: 0.5178, Testing Accuracy: 97.78%
Epoch [12/20], Loss: 0.4369, Testing Accuracy: 97.78%
Epoch [13/20], Loss: 0.4583, Testing Accuracy: 97.78%
Epoch [14/20], Loss: 0.4262, Testing Accuracy: 97.78%
Epoch [15/20], Loss: 0.3998, Testing Accuracy: 97.78%
Epoch [16/20], Loss: 0.5327, Testing Accuracy: 97.78%
Epoch [17/20], Loss: 0.6169, Testing Accuracy: 97.78%
Epoch [18/20], Loss: 0.4413, Testing Accuracy: 97.78%
Epoch [19/20], Loss: 0.6237, Testing 