In [2]:
import torch
from pathlib import Path
import torch


In [3]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
device

'cuda'

In [4]:
from modular.data_setup import  setup_fsl_train_test_dataloaders
data_path = Path("data/")

train_loader, test_loader, fsl_dataset,(train_n_way, test_n_way) = setup_fsl_train_test_dataloaders(data_path,train_episodes=400,test_episodes=300,n_way=5,k_shot=3,q_query=9)

len(fsl_dataset), len(fsl_dataset.classes), fsl_dataset.class_to_idx, fsl_dataset[0][0].shape, fsl_dataset[0][1] , train_n_way,test_n_way

Dataset class distribution: {0: 38, 1: 31, 2: 42, 3: 31, 4: 38, 5: 34, 6: 37, 7: 35, 8: 28, 9: 25, 10: 30}
Found 11 valid classes with at least 12 samples each
Automatically split classes:
  Training classes (6): Rotor-0, A&C&B10, A&B50, A&C30, Noload, A30
  Testing classes (5): A10, Fan, A&C&B30, A&C10, A50
Training class distribution: {0: 38, 1: 31, 4: 38, 6: 37, 9: 25, 10: 30}
Testing class distribution: {2: 42, 3: 31, 5: 34, 7: 35, 8: 28}
Using n_way=5 for training, n_way=5 for testing
Valid classes for sampling: 6
Classes with counts: [(0, 38), (1, 31), (4, 38), (6, 37), (9, 25), (10, 30)]
Valid classes for sampling: 5
Classes with counts: [(2, 42), (3, 31), (5, 34), (7, 35), (8, 28)]


(369,
 11,
 {'A&B50': 0,
  'A&C&B10': 1,
  'A&C&B30': 2,
  'A&C10': 3,
  'A&C30': 4,
  'A10': 5,
  'A30': 6,
  'A50': 7,
  'Fan': 8,
  'Noload': 9,
  'Rotor-0': 10},
 torch.Size([3, 128, 128]),
 0,
 5,
 5)

In [5]:
# from modular.utils import display_random_images
# display_random_images(fsl_dataset, classes=fsl_dataset.classes, n=10, display_shape=True, seed=42)
# #             

In [6]:
from modular.model_builder  import PrototypicalNetwork
model = PrototypicalNetwork(embedding_dim=128)

In [7]:
from torchinfo import summary
# Print a summary using torchinfo (uncomment for actual output)
summary(model=model, 
        input_size=(32, 3, 128, 128), # make sure this is "input_size", not "input_shape"
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
) 

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
PrototypicalNetwork (PrototypicalNetwork)     [32, 3, 128, 128]    [32, 128]            --                   Partial
├─Sequential (encoder)                        [32, 3, 128, 128]    [32, 512, 1, 1]      --                   False
│    └─Conv2d (0)                             [32, 3, 128, 128]    [32, 64, 64, 64]     (9,408)              False
│    └─BatchNorm2d (1)                        [32, 64, 64, 64]     [32, 64, 64, 64]     (128)                False
│    └─ReLU (2)                               [32, 64, 64, 64]     [32, 64, 64, 64]     --                   --
│    └─MaxPool2d (3)                          [32, 64, 64, 64]     [32, 64, 32, 32]     --                   --
│    └─Sequential (4)                         [32, 64, 32, 32]     [32, 64, 32, 32]     --                   False
│    │    └─BasicBlock (0)                    [32, 64, 32, 32]     [32, 64, 32, 

In [8]:
# Validate dataset and dataloaders
for batch in train_loader:
    support_images, support_labels, query_images, query_labels = batch
    print(f"Support set: {support_images.shape}, {support_labels.shape}")
    print(f"Query set: {query_images.shape}, {query_labels.shape}")
    break



Support set: torch.Size([15, 3, 128, 128]), torch.Size([15])
Query set: torch.Size([45, 3, 128, 128]), torch.Size([45])


In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [10]:
# from modular.engine import prototypical_loss
# # Fetch a single batch from the train_loader
# for batch in train_loader:
#     support_images, support_labels, query_images, query_labels = batch
#     break  # Only take the first batch

# # Move tensors to the same device as the model
# support_images = support_images.to(device)
# support_labels = support_labels.to(device)
# query_images = query_images.to(device)
# query_labels = query_labels.to(device)

# # Pass the support and query images through the model
# support_embeddings = model(support_images)
# query_embeddings = model(query_images)

# # Calculate the loss and accuracy
# loss, accuracy = prototypical_loss(
#     query_embeddings=query_embeddings,
#     support_embeddings=support_embeddings,
#     query_labels=query_labels,
#     support_labels=support_labels,
#     n_way=5  # Adjust based on your setup
# )

# # Print the results
# print(f"Loss: {loss.item():.4f}")
# print(f"Accuracy: {accuracy:.4f}")

In [11]:
from collections import Counter

class_counts = Counter(fsl_dataset.labels)
print("Class distribution:", class_counts)

Class distribution: Counter({2: 42, 0: 38, 4: 38, 6: 37, 7: 35, 5: 34, 1: 31, 3: 31, 10: 30, 8: 28, 9: 25})


In [12]:
from collections import Counter

train_class_counts = Counter([fsl_dataset.labels[i] for i in train_loader.batch_sampler.labels])
print("Training class distribution:", train_class_counts)

Training class distribution: Counter({0: 199})


In [13]:
from modular.engine import train_prototype_network

results = train_prototype_network(
    model=model,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    optimizer=optimizer,
    n_way=5,  # Adjust based on your setup
    epochs=10,  # Adjust as needed
    device=device,
)


Epoch 1/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.6093 | Train Acc: 0.9132 | Test Loss: 0.7694 | Test Acc: 0.9327

Epoch 2/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5670 | Train Acc: 0.9355 | Test Loss: 0.7335 | Test Acc: 0.9353

Epoch 3/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5594 | Train Acc: 0.9401 | Test Loss: 0.7369 | Test Acc: 0.9286

Epoch 4/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5465 | Train Acc: 0.9444 | Test Loss: 0.7355 | Test Acc: 0.9365

Epoch 5/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5483 | Train Acc: 0.9465 | Test Loss: 0.7304 | Test Acc: 0.9387

Epoch 6/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5376 | Train Acc: 0.9511 | Test Loss: 0.7381 | Test Acc: 0.9380

Epoch 7/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5350 | Train Acc: 0.9517 | Test Loss: 0.7164 | Test Acc: 0.9355

Epoch 8/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5349 | Train Acc: 0.9506 | Test Loss: 0.7212 | Test Acc: 0.9330

Epoch 9/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5301 | Train Acc: 0.9518 | Test Loss: 0.7276 | Test Acc: 0.9274

Epoch 10/10


Training:   0%|          | 0/400 [00:00<?, ?it/s]

Testing:   0%|          | 0/300 [00:00<?, ?it/s]

Train Loss: 0.5307 | Train Acc: 0.9507 | Test Loss: 0.7032 | Test Acc: 0.9365
New best model saved with test accuracy: 0.9365


In [14]:
results

{'train_loss': [0.6093176437914372,
  0.5670165229588747,
  0.5594360480457544,
  0.5465258422493935,
  0.5483134586364031,
  0.5375832176953554,
  0.5349967773258686,
  0.5348615871369838,
  0.5301328288763761,
  0.5306673165410757],
 'train_acc': [0.9131666718423367,
  0.9355000038444996,
  0.940111114680767,
  0.9444444477558136,
  0.9465000031888485,
  0.9510555584728718,
  0.951722225099802,
  0.9505555585026741,
  0.9518333362042903,
  0.9506666696071625],
 'test_loss': [0.7694217917323113,
  0.7335426843166352,
  0.7369341949621836,
  0.7354630064964295,
  0.7304144008954366,
  0.7381244313716888,
  0.71643303891023,
  0.721181761821111,
  0.727556278804938,
  0.7032174928983053],
 'test_acc': [0.932740744749705,
  0.9352592631181081,
  0.9285925968488058,
  0.9365185223023097,
  0.9386666703224182,
  0.938000003695488,
  0.9354814853270849,
  0.9330370410283406,
  0.9274074117342631,
  0.9365185223023097]}