In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import torch.nn.functional as F
from collections import defaultdict
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
from dataset import CustomImageDataset

In [2]:
# Logical CPU count, a reference point for choosing DataLoader `num_workers`.
n_cpus = os.cpu_count()
n_cpus

8

In [3]:
root_folder = "Indian_bovine_breeds"  # change this to your dataset path

# Number of regular files in each class sub-directory, in os.listdir order.
class_counts = {}
for class_name in os.listdir(root_folder):
    class_dir = os.path.join(root_folder, class_name)
    if not os.path.isdir(class_dir):
        continue
    class_counts[class_name] = sum(
        1 for entry in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, entry))
    )
print(class_counts)

{'Vechur': 140, 'Mehsana': 95, 'Hallikar': 186, 'Amritmahal': 94, 'Kankrej': 179, 'Sahiwal': 439, 'Surti': 64, 'Jersey': 203, 'Pulikulam': 125, 'Nagpuri': 187, 'Nagori': 89, 'Malnad_gidda': 107, 'Dangi': 82, 'Murrah': 173, 'Jaffrabadi': 102, 'Red_Dane': 167, 'Krishna_Valley': 136, 'Guernsey': 119, 'Kherigarh': 36, 'Rathi': 149, 'Khillari': 113, 'Bargur': 94, 'Banni': 109, 'Holstein_Friesian': 328, 'Toda': 124, 'Alambadi': 99, 'Deoni': 99, 'Kangayam': 91, 'Kenkatha': 55, 'Kasargod': 95, 'Nimari': 84, 'Tharparkar': 217, 'Bhadawari': 86, 'Ongole': 191, 'Red_Sindhi': 166, 'Hariana': 130, 'Umblachery': 76, 'Gir': 372, 'Ayrshire': 234, 'Brown_Swiss': 225, 'Nili_Ravi': 89}


In [4]:
# Classes ordered from most to fewest images (ties keep their dict order,
# since sorted() stays stable even with reverse=True).
sorted_class_counts = sorted(
    ([name, count] for name, count in class_counts.items()),
    key=lambda pair: pair[1],
    reverse=True,
)
sorted_class_counts

[['Sahiwal', 439],
 ['Gir', 372],
 ['Holstein_Friesian', 328],
 ['Ayrshire', 234],
 ['Brown_Swiss', 225],
 ['Tharparkar', 217],
 ['Jersey', 203],
 ['Ongole', 191],
 ['Nagpuri', 187],
 ['Hallikar', 186],
 ['Kankrej', 179],
 ['Murrah', 173],
 ['Red_Dane', 167],
 ['Red_Sindhi', 166],
 ['Rathi', 149],
 ['Vechur', 140],
 ['Krishna_Valley', 136],
 ['Hariana', 130],
 ['Pulikulam', 125],
 ['Toda', 124],
 ['Guernsey', 119],
 ['Khillari', 113],
 ['Banni', 109],
 ['Malnad_gidda', 107],
 ['Jaffrabadi', 102],
 ['Alambadi', 99],
 ['Deoni', 99],
 ['Mehsana', 95],
 ['Kasargod', 95],
 ['Amritmahal', 94],
 ['Bargur', 94],
 ['Kangayam', 91],
 ['Nagori', 89],
 ['Nili_Ravi', 89],
 ['Bhadawari', 86],
 ['Nimari', 84],
 ['Dangi', 82],
 ['Umblachery', 76],
 ['Surti', 64],
 ['Kenkatha', 55],
 ['Kherigarh', 36]]

In [5]:
def split_dataset_paths(root_folder, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=1337,
                        min_per_split=5):
    """Split every class folder under ``root_folder`` into train/val/test file lists.

    Each immediate sub-directory of ``root_folder`` is one class; every regular,
    non-hidden file inside it is one sample.

    Args:
        root_folder: directory containing one sub-directory per class.
        train_ratio: accepted for backward compatibility but not used; the train
            split receives whatever remains after val and test are carved out.
        val_ratio: fraction of each class reserved for validation.
        test_ratio: fraction of each class reserved for testing.
        seed: seed for the shuffle. This seeds NumPy's *global* RNG via
            ``np.random.seed``; the batch samplers in this notebook also draw
            from ``np.random``.
        min_per_split: minimum number of files placed in val and in test for
            each class, as far as the class has enough files.

    Returns:
        dict mapping class name (sorted) -> {'train': [...], 'val': [...], 'test': [...]}.
        The three lists of a class are disjoint and together hold every file of
        that class exactly once.
    """
    # Global seeding is kept so the np.random-based samplers used later stay
    # reproducible too.
    np.random.seed(seed)
    dataset_splits = {}

    for cls in sorted(os.listdir(root_folder)):
        cls_path = os.path.join(root_folder, cls)
        if not os.path.isdir(cls_path):
            continue
        # Sort before shuffling: os.listdir order is filesystem-dependent, so
        # without it the same seed can yield different splits on different
        # machines. Hidden files (e.g. .DS_Store) are not samples.
        images = sorted(
            os.path.join(cls_path, f)
            for f in os.listdir(cls_path)
            if not f.startswith('.') and os.path.isfile(os.path.join(cls_path, f))
        )
        np.random.shuffle(images)

        n_total = len(images)
        # Clamp val/test to what the class actually has; otherwise n_train goes
        # negative for tiny classes and the slices below overlap.
        n_val = min(max(int(n_total * val_ratio), min_per_split), n_total)
        n_test = min(max(int(n_total * test_ratio), min_per_split), n_total - n_val)
        n_train = n_total - n_val - n_test

        dataset_splits[cls] = {
            'train': images[:n_train],
            'val': images[n_train:n_train + n_val],
            'test': images[n_train + n_val:],
        }

    return dataset_splits

In [6]:
# Reuse the dataset path configured in the class-count cell instead of repeating it.
dataset = split_dataset_paths(root_folder)

In [7]:
tuple(len(dataset['Kherigarh'][split]) for split in ('train', 'val', 'test'))  # smallest class

(26, 5, 5)

In [8]:
tuple(len(dataset['Sahiwal'][split]) for split in ('train', 'val', 'test'))  # largest class

(353, 43, 43)

In [9]:
def create_file_list(splits_dict, split_name):
    """Flatten one split of ``splits_dict`` into ``(path, label)`` pairs.

    The label of a class is its position in ``splits_dict``'s iteration order,
    so labels agree across splits built from the same dict.

    Args:
        splits_dict: mapping class name -> {split_name: [paths, ...], ...}.
        split_name: which split to extract, e.g. 'train', 'val' or 'test'.

    Returns:
        list of (path, label) tuples, grouped by class in dict order.
    """
    # enumerate yields each class's index directly instead of re-scanning the
    # key list with .index() for every class (quadratic in the class count).
    return [
        (path, label)
        for label, splits in enumerate(splits_dict.values())
        for path in splits[split_name]
    ]

# Torchvision preprocessing: resize every image to 224x224, then convert it to a tensor.
transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

class BalancedBatchSampler(Sampler):
    """Batch sampler that yields class-balanced batches of dataset indices.

    Every batch holds exactly ``samples_per_class`` indices from each class in
    ``labels``. Indices are reshuffled within each class at the start of every
    epoch, and each batch is shuffled before it is yielded. The smallest class
    bounds the number of batches per epoch, so indices of larger classes beyond
    that point are not visited in a given epoch.

    Args:
        labels: list or array of class labels aligned with dataset indices.
        samples_per_class: number of indices per class in each batch.
    """

    def __init__(self, labels, samples_per_class=5):
        self.labels = np.array(labels)
        self.samples_per_class = samples_per_class

        # Group dataset indices by class label.
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)

        # The smallest class decides how many full batches fit in one epoch;
        # default=0 keeps an empty label list from raising in min().
        self.batches_per_epoch = min(
            (len(idxs) // samples_per_class for idxs in self.class_indices.values()),
            default=0,
        )
        self.classes = list(self.class_indices.keys())

    def __iter__(self):
        # Fresh within-class order every epoch (shuffled in place).
        for c in self.classes:
            np.random.shuffle(self.class_indices[c])

        for batch_idx in range(self.batches_per_epoch):
            start = batch_idx * self.samples_per_class
            end = start + self.samples_per_class
            # Build a NEW list per batch. Re-using one list (clear + refill)
            # makes list(sampler) return aliases of the last batch, and may
            # corrupt batches DataLoader has queued but not yet consumed.
            batch = []
            for c in self.classes:
                batch.extend(self.class_indices[c][start:end])
            np.random.shuffle(batch)  # mix classes within the batch
            yield batch

    def __len__(self):
        return self.batches_per_epoch

In [10]:
# Plain (unbalanced) loaders over every split; only the training loader shuffles.
# Note: the next cell rebinds all three loaders with class-balanced samplers.
train_files, val_files, test_files = (
    create_file_list(dataset, split) for split in ('train', 'val', 'test')
)

train_dataset, val_dataset, test_dataset = (
    CustomImageDataset(files, transform=transform)
    for files in (train_files, val_files, test_files)
)

_loader_opts = dict(batch_size=32, prefetch_factor=2, pin_memory=True, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, **_loader_opts)
test_loader = DataLoader(test_dataset, **_loader_opts)
del _loader_opts

In [11]:
# Rebuild every split with class-balanced batches: each batch holds exactly
# `samples_per_class` images from every class.
def _balanced_split(split_name, samples_per_class):
    """Return (files, labels, dataset, sampler, loader) for one split."""
    files = create_file_list(dataset, split_name)
    labels = [label for _, label in files]
    split_dataset = CustomImageDataset(files, transform=transform)
    sampler = BalancedBatchSampler(labels, samples_per_class=samples_per_class)
    loader = DataLoader(split_dataset, batch_sampler=sampler)
    return files, labels, split_dataset, sampler, loader


train_files, train_labels, train_dataset, train_sampler, train_loader = _balanced_split('train', 1)
val_files, val_labels, val_dataset, val_sampler, val_loader = _balanced_split('val', 3)
test_files, test_labels, test_dataset, test_sampler, test_loader = _balanced_split('test', 3)

In [12]:
tuple(len(loader) for loader in (train_loader, val_loader, test_loader))  # batches per epoch

(26, 1, 1)

In [13]:
class SimpleCNN(nn.Module):
    """Baseline three-stage CNN classifier.

    Each stage is conv3x3 (padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    side length is floor-divided by 8 before the fully connected head.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input images. Defaults to 224,
            matching the 224x224 transform; pass e.g. 112 for smaller inputs.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        # Three 2x2 max-pools with floor rounding: floor(floor(floor(n/2)/2)/2) == n // 8.
        feature_side = input_size // 8
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feature_side * feature_side, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [14]:
# def train(model, train_loader, val_loader, epochs=10, lr=0.001, device='cuda'):
#     model = model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=lr)

#     for epoch in range(epochs):
#         model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         for images, labels in tqdm(train_loader):
#             images, labels = images.to(device), labels.to(device)

#             optimizer.zero_grad()
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item() * images.size(0)
#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#         train_loss = running_loss / total
#         train_acc = correct / total

#         # Validation
#         model.eval()
#         val_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for images, labels in tqdm(val_loader):
#                 images, labels = images.to(device), labels.to(device)
#                 outputs = model(images)
#                 loss = criterion(outputs, labels)
#                 val_loss += loss.item() * images.size(0)
#                 _, predicted = torch.max(outputs, 1)
#                 val_total += labels.size(0)
#                 val_correct += (predicted == labels).sum().item()

#         val_loss /= val_total
#         val_acc = val_correct / val_total

#         print(f"Epoch [{epoch + 1}/{epochs}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} Val Acc: {val_acc:.4f}")

In [15]:
# num_classes = 41  # number of classes
# model = SimpleCNN(num_classes)
# train(model, train_loader, val_loader, epochs=10, lr=0.001, device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Episodic Sampler for N-way K-shot + Q query setup
class EpisodicBatchSampler(Sampler):
    """Yield one episode per batch: ``n_way`` classes x (``k_shot`` + ``q_query``) indices.

    Indices are grouped by class. The first ``k_shot + q_query`` indices of a
    batch belong to the first sampled class, the next block to the second, and
    so on; ``prototypical_loss`` relies on this layout. Only classes with at
    least ``k_shot + q_query`` samples are eligible. Within an episode, classes
    and per-class indices are drawn without replacement from NumPy's global RNG.

    Args:
        labels: class label per dataset index.
        n_way: number of classes per episode.
        k_shot: support samples per class.
        q_query: query samples per class.
        episodes_per_epoch: number of batches yielded per iteration.

    Raises:
        ValueError: if fewer than ``n_way`` classes have enough samples.
    """

    def __init__(self, labels, n_way, k_shot, q_query, episodes_per_epoch):
        self.labels = np.array(labels)
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.episodes_per_epoch = episodes_per_epoch
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label].append(idx)
        per_class = k_shot + q_query
        self.classes = [c for c in self.class_to_indices if len(self.class_to_indices[c]) >= per_class]
        # Fail at construction with a clear message instead of letting
        # np.random.choice(..., replace=False) raise on the first episode.
        if len(self.classes) < n_way:
            raise ValueError(
                f"n_way={n_way} but only {len(self.classes)} classes have "
                f">= k_shot + q_query = {per_class} samples"
            )

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        per_class = self.k_shot + self.q_query
        for _ in range(self.episodes_per_epoch):
            batch_indices = []
            selected_classes = np.random.choice(self.classes, self.n_way, replace=False)
            for cls in selected_classes:
                indices = np.random.choice(self.class_to_indices[cls], per_class, replace=False)
                # Built-in ints rather than np.int64 scalars for dataset indexing.
                batch_indices.extend(int(i) for i in indices)
            yield batch_indices

# Simple CNN embedding model
class ProtoNet(nn.Module):
    """Convolutional encoder that maps images to flat embedding vectors.

    Five stages of conv3x3 (no bias) -> BatchNorm -> ReLU -> 2x2 max-pool with
    channel widths 3 -> 8 -> 16 -> 32 -> 64 -> ``embedding_dim``, then global
    average pooling, so any input resolution gives one vector per image. Each
    max-pool halves height and width with floor rounding
    (e.g. 112 -> 56 -> 28 -> 14 -> 7 -> 3).

    Args:
        embedding_dim: output channels of the last stage, i.e. embedding length.
    """

    def __init__(self, embedding_dim=64):
        super(ProtoNet, self).__init__()

        widths = [3, 8, 16, 32, 64, embedding_dim]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(2),
                )
            )
        # Global average pooling collapses whatever spatial size remains to 1x1.
        self.encoder = nn.Sequential(*stages, nn.AdaptiveAvgPool2d(1))

    def forward(self, x):
        pooled = self.encoder(x)             # (B, embedding_dim, 1, 1)
        return pooled.view(len(pooled), -1)  # (B, embedding_dim)

        
# Prototypical loss function
def prototypical_loss(embeddings, targets, n_way, k_shot, q_query):
    """Prototypical-network loss and accuracy for a single episode.

    The batch must be laid out class by class. For class ``c`` (0..n_way-1),
    rows ``c*(k_shot+q_query)`` through ``c*(k_shot+q_query)+k_shot-1`` are
    support samples, and the following ``q_query`` rows are queries. Each class
    prototype is the mean of its support embeddings. Queries are scored by
    negative Euclidean distance to every prototype.

    Args:
        embeddings: (batch_size, embedding_dim) embeddings for one episode.
        targets: not read; labels are implied by the row layout above.
        n_way, k_shot, q_query: episode shape.

    Returns:
        (loss, acc): the scalar NLL loss over the queries, and the fraction of
        queries whose nearest prototype is their own class.
    """
    per_class = k_shot + q_query
    support_rows = [c * per_class + j for c in range(n_way) for j in range(k_shot)]
    query_rows = [c * per_class + k_shot + j for c in range(n_way) for j in range(q_query)]

    # One prototype per class: the mean of that class's support embeddings.
    prototypes = embeddings[support_rows].view(n_way, k_shot, -1).mean(dim=1)

    # Closer prototype -> larger logit.
    logits = -torch.cdist(embeddings[query_rows], prototypes)  # (n_way*q_query, n_way)
    log_p_y = F.log_softmax(logits, dim=1)

    # Query rows are ordered by class, so the true label of query i is i // q_query.
    query_labels = torch.arange(n_way).repeat_interleave(q_query).to(embeddings.device)

    loss = F.nll_loss(log_p_y, query_labels)
    acc = (log_p_y.argmax(dim=1) == query_labels).float().mean()
    return loss, acc

# Parameters for prototypical (episodic) training: each batch is one episode of
# n_way classes x (k_shot support + q_query query) images.
n_way = 5
k_shot = 20
q_query = 2
episodes_per_epoch = 100

# Smaller inputs than the 224x224 baseline; the ProtoNet encoder ends in global
# average pooling, so it accepts any resolution.
transform = transforms.Compose([transforms.Resize((112, 112)), transforms.ToTensor()])

# `dataset` is the per-class split dict from split_dataset_paths; entries are (image_path, label).
train_files = create_file_list(dataset, 'train')
train_labels = [label for _, label in train_files]

train_dataset = CustomImageDataset(train_files, transform=transform)
train_sampler = EpisodicBatchSampler(train_labels, n_way, k_shot, q_query, episodes_per_epoch)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=4,            # 4 worker processes (os.cpu_count() reported 8 above)
    prefetch_factor=2,        # batches loaded in advance per worker
    persistent_workers=True,  # keep workers alive between epochs
)

In [17]:
# Use CUDA when available, otherwise the CPU (Apple MPS is not considered here).
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Using device:", device)
model = ProtoNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# # Training loop
# for epoch in range(10):
#     model.train()
#     total_loss, total_acc = 0.0, 0.0
#     for batch in tqdm(train_loader):
#         images, labels = batch
#         images, labels = images.to(device), labels.to(device)

#         optimizer.zero_grad()
#         embeddings = model(images)

#         # Remap labels to 0..n_way-1 for the sampled classes within the episode
#         # This is necessary because batch includes only n_way classes but labels have global indices
#         unique_labels = torch.unique(labels)
#         label_map = {l.item(): i for i, l in enumerate(unique_labels)}
#         mapped_labels = torch.tensor([label_map[l.item()] for l in labels]).to(device)

#         loss, acc = prototypical_loss(embeddings, mapped_labels, n_way, k_shot, q_query)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         total_acc += acc.item()

#     print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}, Acc: {total_acc / len(train_loader):.4f}")

Using device: cpu


In [18]:
# # Before training loop
# class_prototypes = torch.zeros(41, 64, device=device)
# class_counts = torch.zeros(41, device=device)

# for epoch in range(20):
#     model.train()
#     total_loss, total_acc = 0.0, 0.0
#     class_correct = defaultdict(int)
#     class_total = defaultdict(int)

#     for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
#         images, labels = batch
#         images, labels = images.to(device), labels.to(device)

#         optimizer.zero_grad()
#         embeddings = model(images)  # (B, embedding_dim)

#         # --- Compute distances to ALL prototypes ---
#         dists = torch.cdist(embeddings, class_prototypes)  # (B, num_classes)
#         logits = -dists

#         # --- Loss & optimization ---
#         loss = F.cross_entropy(logits, labels)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         preds = logits.argmax(dim=1)
#         total_acc += (preds == labels).float().mean().item()

#         # --- Per-class accuracy tracking ---
#         for t, p in zip(labels, preds):
#             class_total[t.item()] += 1
#             if t == p:
#                 class_correct[t.item()] += 1

#         # --- Update prototypes safely ---
#         with torch.no_grad():
#             for c in labels.unique():
#                 emb = embeddings[labels == c].mean(dim=0).detach()
#                 count = (labels == c).sum()
#                 class_prototypes[c] = (class_prototypes[c] * class_counts[c] + emb * count) / (class_counts[c] + count)
#                 class_counts[c] += count

#     # --- End of epoch summary ---
#     avg_loss = total_loss / len(train_loader)
#     avg_acc = total_acc / len(train_loader)
#     print(f"\nEpoch {epoch + 1} Summary:")
#     print(f"  Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

#     print("  Per-class accuracy:")
#     for cls in range(41):
#         if class_total[cls] > 0:
#             acc_cls = class_correct[cls] / class_total[cls]
#             print(f"    Class {cls:02d}: {acc_cls:.4f} ({class_correct[cls]}/{class_total[cls]})")
#         else:
#             print(f"    Class {cls:02d}: N/A (not sampled this epoch)")


In [19]:
# --- Define model with learnable prototypes ---
class ProtoClassifier(nn.Module):
    """Classify by distance from a backbone embedding to learned class prototypes.

    Args:
        backbone: module mapping a batch of inputs to (B, embedding_dim) embeddings.
        num_classes: number of prototypes / output logits.
        embedding_dim: length of each prototype; must match the backbone output.

    ``forward`` returns ``(logits, embeddings)``. ``logits[b, c]`` is the negative
    Euclidean distance between sample ``b``'s embedding and prototype ``c``.
    """

    def __init__(self, backbone, num_classes=41, embedding_dim=64):
        super().__init__()
        self.backbone = backbone
        # One trainable prototype per class, initialised from a standard normal.
        initial = torch.randn(num_classes, embedding_dim)
        self.class_prototypes = nn.Parameter(initial)

    def forward(self, x):
        embeddings = self.backbone(x)
        # Nearer prototype -> larger logit.
        logits = torch.cdist(embeddings, self.class_prototypes).neg()
        return logits, embeddings


# --- Setup ---
torch.manual_seed(1337)  # reproducible backbone/prototype initialisation
num_classes = 41
embedding_dim = 64
num_epochs = 20
model = ProtoClassifier(ProtoNet(embedding_dim), num_classes, embedding_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- Training loop ---
for epoch in range(num_epochs):
    model.train()
    total_loss, total_acc = 0.0, 0.0
    # Per-class counters live on the training device and are updated with
    # bincount; the previous per-sample .item() loop forced one host/device
    # sync per element on GPU.
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, embeddings = model(images)

        # --- Loss & optimization ---
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        # --- Batch accuracy ---
        preds = logits.argmax(dim=1)
        hits = preds == labels
        total_loss += loss.item()
        total_acc += hits.float().mean().item()

        # --- Per-class accuracy tracking ---
        class_total += torch.bincount(labels, minlength=num_classes)
        class_correct += torch.bincount(labels[hits], minlength=num_classes)

    # --- End of epoch summary ---
    avg_loss = total_loss / len(train_loader)
    avg_acc = total_acc / len(train_loader)
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

    print("  Per-class accuracy:")
    totals, corrects = class_total.tolist(), class_correct.tolist()
    for cls in range(num_classes):
        if totals[cls] > 0:
            acc_cls = corrects[cls] / totals[cls]
            print(f"    Class {cls:02d}: {acc_cls:.4f} "
                  f"({corrects[cls]}/{totals[cls]})")
        else:
            print(f"    Class {cls:02d}: N/A (not sampled this epoch)")


Epoch 1: 100%|██████████| 100/100 [00:43<00:00,  2.29it/s]



Epoch 1 Summary:
  Loss: 3.5277, Accuracy: 0.0883
  Per-class accuracy:
    Class 00: 0.0000 (0/300)
    Class 01: 0.0000 (0/264)
    Class 02: 0.0461 (13/282)
    Class 03: 0.0867 (26/300)
    Class 04: 0.0111 (3/270)
    Class 05: 0.0062 (2/324)
    Class 06: 0.1361 (49/360)
    Class 07: 0.3276 (114/348)
    Class 08: 0.0000 (0/324)
    Class 09: 0.0000 (0/252)
    Class 10: 0.0833 (24/288)
    Class 11: 0.0000 (0/282)
    Class 12: 0.2660 (83/312)
    Class 13: 0.2092 (59/282)
    Class 14: 0.1667 (56/336)
    Class 15: 0.0000 (0/300)
    Class 16: 0.0000 (0/306)
    Class 17: 0.0000 (0/318)
    Class 18: 0.0061 (2/330)
    Class 19: 0.6369 (214/336)
    Class 20: 0.1727 (57/330)
    Class 21: 0.0000 (0/318)
    Class 22: 0.0000 (0/288)
    Class 23: 0.0621 (22/354)
    Class 24: 0.0148 (4/270)
    Class 25: 0.1091 (36/330)
    Class 26: 0.0595 (15/252)
    Class 27: 0.0000 (0/366)
    Class 28: 0.0000 (0/288)
    Class 29: 0.0068 (2/294)
    Class 30: 0.2654 (86/324)
    Class 31

Epoch 2: 100%|██████████| 100/100 [00:41<00:00,  2.40it/s]



Epoch 2 Summary:
  Loss: 3.2174, Accuracy: 0.1774
  Per-class accuracy:
    Class 00: 0.0387 (13/336)
    Class 01: 0.0648 (21/324)
    Class 02: 0.2667 (80/300)
    Class 03: 0.2901 (94/324)
    Class 04: 0.2400 (72/300)
    Class 05: 0.0152 (4/264)
    Class 06: 0.2133 (64/300)
    Class 07: 0.5614 (192/342)
    Class 08: 0.0063 (2/318)
    Class 09: 0.0128 (4/312)
    Class 10: 0.1950 (55/282)
    Class 11: 0.0727 (24/330)
    Class 12: 0.1932 (51/264)
    Class 13: 0.5452 (193/354)
    Class 14: 0.1512 (49/324)
    Class 15: 0.0000 (0/312)
    Class 16: 0.0636 (21/330)
    Class 17: 0.0096 (3/312)
    Class 18: 0.0316 (11/348)
    Class 19: 0.4479 (129/288)
    Class 20: 0.6701 (197/294)
    Class 21: 0.0035 (1/288)
    Class 22: 0.0200 (6/300)
    Class 23: 0.1967 (59/300)
    Class 24: 0.3529 (108/306)
    Class 25: 0.1277 (36/282)
    Class 26: 0.1954 (68/348)
    Class 27: 0.0000 (0/288)
    Class 28: 0.0654 (20/306)
    Class 29: 0.1700 (51/300)
    Class 30: 0.4874 (155/318)

Epoch 3: 100%|██████████| 100/100 [00:40<00:00,  2.47it/s]



Epoch 3 Summary:
  Loss: 3.0067, Accuracy: 0.2518
  Per-class accuracy:
    Class 00: 0.1293 (38/294)
    Class 01: 0.3148 (102/324)
    Class 02: 0.4138 (144/348)
    Class 03: 0.2788 (87/312)
    Class 04: 0.6242 (191/306)
    Class 05: 0.0544 (16/294)
    Class 06: 0.1961 (60/306)
    Class 07: 0.5667 (170/300)
    Class 08: 0.0278 (8/288)
    Class 09: 0.0000 (0/300)
    Class 10: 0.2540 (64/252)
    Class 11: 0.1451 (47/324)
    Class 12: 0.3459 (110/318)
    Class 13: 0.5986 (176/294)
    Class 14: 0.2138 (59/276)
    Class 15: 0.0111 (3/270)
    Class 16: 0.1250 (30/240)
    Class 17: 0.0494 (16/324)
    Class 18: 0.0586 (19/324)
    Class 19: 0.4236 (122/288)
    Class 20: 0.7787 (271/348)
    Class 21: 0.0145 (4/276)
    Class 22: 0.0632 (22/348)
    Class 23: 0.2853 (89/312)
    Class 24: 0.3050 (86/282)
    Class 25: 0.4248 (130/306)
    Class 26: 0.1803 (53/294)
    Class 27: 0.0163 (5/306)
    Class 28: 0.2186 (80/366)
    Class 29: 0.4375 (147/336)
    Class 30: 0.4746 (

Epoch 4: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s]



Epoch 4 Summary:
  Loss: 2.8501, Accuracy: 0.3017
  Per-class accuracy:
    Class 00: 0.1489 (42/282)
    Class 01: 0.3105 (95/306)
    Class 02: 0.4048 (119/294)
    Class 03: 0.2848 (94/330)
    Class 04: 0.6789 (167/246)
    Class 05: 0.1321 (42/318)
    Class 06: 0.2424 (80/330)
    Class 07: 0.6384 (203/318)
    Class 08: 0.0426 (12/282)
    Class 09: 0.0248 (7/282)
    Class 10: 0.2364 (61/258)
    Class 11: 0.2530 (85/336)
    Class 12: 0.3571 (105/294)
    Class 13: 0.5667 (153/270)
    Class 14: 0.2340 (66/282)
    Class 15: 0.0233 (6/258)
    Class 16: 0.2327 (74/318)
    Class 17: 0.1019 (33/324)
    Class 18: 0.0852 (23/270)
    Class 19: 0.3129 (92/294)
    Class 20: 0.9167 (341/372)
    Class 21: 0.0621 (19/306)
    Class 22: 0.0723 (23/318)
    Class 23: 0.3500 (105/300)
    Class 24: 0.4599 (149/324)
    Class 25: 0.5345 (186/348)
    Class 26: 0.3278 (118/360)
    Class 27: 0.0243 (7/288)
    Class 28: 0.3366 (103/306)
    Class 29: 0.5247 (170/324)
    Class 30: 0.47

Epoch 5: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s]



Epoch 5 Summary:
  Loss: 2.7207, Accuracy: 0.3509
  Per-class accuracy:
    Class 00: 0.2619 (77/294)
    Class 01: 0.4712 (147/312)
    Class 02: 0.4333 (117/270)
    Class 03: 0.3482 (117/336)
    Class 04: 0.7305 (206/282)
    Class 05: 0.2340 (73/312)
    Class 06: 0.2345 (83/354)
    Class 07: 0.7049 (203/288)
    Class 08: 0.1190 (35/294)
    Class 09: 0.0340 (11/324)
    Class 10: 0.3878 (114/294)
    Class 11: 0.3302 (105/318)
    Class 12: 0.4259 (138/324)
    Class 13: 0.6122 (180/294)
    Class 14: 0.3814 (119/312)
    Class 15: 0.0747 (26/348)
    Class 16: 0.3523 (93/264)
    Class 17: 0.1964 (66/336)
    Class 18: 0.1273 (42/330)
    Class 19: 0.4253 (148/348)
    Class 20: 0.9113 (257/282)
    Class 21: 0.0850 (25/294)
    Class 22: 0.0988 (32/324)
    Class 23: 0.5278 (171/324)
    Class 24: 0.5804 (195/336)
    Class 25: 0.3962 (126/318)
    Class 26: 0.3227 (91/282)
    Class 27: 0.0341 (9/264)
    Class 28: 0.4434 (141/318)
    Class 29: 0.6069 (193/318)
    Class 3

Epoch 6:  53%|█████▎    | 53/100 [00:23<00:20,  2.24it/s]


KeyboardInterrupt: 