Basic data exploration: count the images in each class directory of the train and test splits.

In [8]:
# get number of files in each directory of yoga_poses
import os

def count_files(directory):
    """Return a mapping from every directory under `directory` (itself included) to the
    number of files it directly contains (files in subdirectories are not counted).

    Keys are the paths produced by `os.walk`, in walk order.
    """
    return {walk_root: len(walk_files) for walk_root, _subdirs, walk_files in os.walk(directory)}

# Usage
yoga_poses_path = "yoga_poses"
file_counts = count_files(yoga_poses_path)

# Print the number of files in each directory. Label each directory by its path relative to
# the dataset's parent (e.g. "yoga_poses/train/cobra") so that identically named class folders
# in the train and test splits can be told apart; os.path also handles the platform separator.
dataset_parent = os.path.dirname(os.path.abspath(yoga_poses_path))
for directory, count in file_counts.items():
    name = os.path.relpath(directory, start=dataset_parent)
    print(f"{name}: {count} files")


yoga_poses:0 files
test:0 files
shoudler_stand:9 files
traingle:10 files
cobra:232 files
dog:180 files
chair:168 files
no_pose:2 files
tree:192 files
warrior:218 files
train:0 files
shoudler_stand:50 files
traingle:45 files
cobra:400 files
dog:400 files
chair:400 files
no_pose:26 files
tree:418 files
warrior:400 files


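The dataset is imbalanced: `no_pose`, `shoudler_stand` and `traingle` (directory names as spelled on disk) have far fewer images than the other classes (e.g. 26, 50 and 45 training images versus about 400 for the rest). To balance it, we will scrape additional images with a script and apply data augmentation.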

In [27]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Dataset

In [64]:
class CustomDataset(Dataset):
    """Image-folder dataset that applies one transform to selected classes and another to the rest.

    Wraps `torchvision.datasets.ImageFolder` (loaded without its own transform) and chooses
    the transform per sample from its class index.

    Args:
        root: Directory laid out as `root/<class_name>/<image>` for ImageFolder.
        transform: Applied to samples whose class index is in `augment_classes`.
            If None, those samples are returned untransformed.
        alt_transform: Applied to all other samples. If None, they are returned untransformed.
        augment_classes: Collection of class indices that receive `transform`. If None,
            every sample receives `alt_transform`.
    """

    def __init__(self, root, transform=None, alt_transform=None, augment_classes=None):
        self.dataset = datasets.ImageFolder(root=root)
        # These were previously dropped, so __getitem__ always failed with AttributeError.
        self.transform = transform
        self.alt_transform = alt_transform
        self.augment_classes = augment_classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return `(sample, class_index)` with the class-appropriate transform applied."""
        x, y = self.dataset[index]

        if self.augment_classes is not None and y in self.augment_classes:
            chosen_transform = self.transform
        else:
            chosen_transform = self.alt_transform
        # Honour the None defaults instead of trying to call None.
        if chosen_transform is not None:
            x = chosen_transform(x)
        return x, y

In [65]:
# Side length (pixels) of every image tensor fed to the model.
IMAGE_SIZE = 32

# Heavier augmentation, applied to the classes selected for augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
])

# Lighter augmentation for the remaining classes. It must also resize to IMAGE_SIZE x IMAGE_SIZE:
# otherwise tensors keep each image's native size and DataLoader cannot stack a batch that
# mixes image sizes, or mixes these samples with the 32x32 crops from `transform`.
alt_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])



In [66]:
from torch.utils.data import DataLoader

# ImageFolder assigns class indices in sorted order of the class directory names
# (chair=0, cobra=1, dog=2, no_pose=3, shoudler_stand=4, traingle=5, tree=6, warrior=7),
# so the former hard-coded indices [0, 1, 5] selected chair, cobra and traingle -- two of the
# largest classes. Select the under-represented classes by name instead.
# NOTE(review): names follow the on-disk directory spelling; confirm this is the intended set.
AUGMENT_CLASS_NAMES = ["no_pose", "shoudler_stand", "traingle"]
class_to_idx = datasets.ImageFolder(root="yoga_poses/train").class_to_idx
augment_classes = [class_to_idx[name] for name in AUGMENT_CLASS_NAMES]

train_dataset = CustomDataset(root="yoga_poses/train", transform=transform, alt_transform=alt_transform, augment_classes=augment_classes)

# Evaluation data must not be randomly augmented: only resize to the training resolution
# (32x32, matching RandomResizedCrop in the training transform) and convert to a tensor.
eval_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
test_dataset = CustomDataset(root="yoga_poses/test", transform=eval_transform, alt_transform=eval_transform)

In [67]:
BATCH_SIZE = 32

# Shuffle only the training data; a fixed order keeps evaluation passes reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [68]:
# Number of samples in the train and test splits, displayed as a (train, test) tuple.
n_train, n_test = len(train_dataset), len(test_dataset)
(n_train, n_test)

(2188, 1011)