## Module install
First, let's install the required Python modules.

In [7]:
###########################
##   IMGPROC MODULES
###########################
%pip install pillow

###########################
##   ML MODULES
###########################
%pip install torch torchvision scikit-learn

###########################
##   OTHER MODULES
###########################
%pip install gdown # gDrive
%pip install matplotlib

###########################
##   LOGGING
###########################
import logging
import logging.config

# - Root handler format: timestamp, level name, message
logging.basicConfig(
  format="%(asctime)-15s %(levelname)s - %(message)s",
  datefmt='%Y-%m-%d %H:%M:%S',
)

# - Shared notebook logger used by the following cells (INFO level and above)
logger = logging.getLogger("gmnist-classifier")
logger.setLevel("INFO")

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


Note: you may need to restart the kernel to use updated packages.


### Project folders

In [9]:
import os
from pathlib import Path

def create_dir(dirname):
    """ Create directory `dirname`, including any missing parent directories.

    Does nothing (apart from logging) if the directory already exists.
    """
    logger.info("Creating directory %s ...", dirname)
    Path(dirname).mkdir(parents=True, exist_ok=True)
    logger.info("Run directory %s created successfully ...", dirname)

# - Set project directory (all paths below are built from the launch directory)
topdir = os.getcwd()
logger.info("topdir=%s", topdir)

############################
##   DATASET URL
############################
# - Local dataset folder (expected to come from the archive) and archive file name
dataset_name = "galaxy_mnist-dataset"
dataset_dir = os.path.join(topdir, dataset_name)
dataset_filename = 'galaxy_mnist-dataset.tar.gz'

# - Google Drive download URL of the archive
dataset_url = 'https://drive.google.com/uc?export=download&id=1OprJ_NQIFyQSRWqjGLFQsAMumHvJ-tMB'

# - JSON datalists: 1-channel variant of the train/test splits ...
filename_train, filename_test = (
    os.path.join(dataset_dir, "train/1chan/datalist_train.json"),
    os.path.join(dataset_dir, "test/1chan/datalist_test.json"),
)
# - ... and 3-channel variant of the train/test splits
filename_train_3chan, filename_test_3chan = (
    os.path.join(dataset_dir, "train/3chan/datalist_train.json"),
    os.path.join(dataset_dir, "test/3chan/datalist_test.json"),
)

2025-03-06 12:23:37 INFO - topdir=/home/riggi/Analysis/MLProjects/usc8-ai-workshop


## Dataset
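The dataset for this tutorial is a Galaxy MNIST image dataset, distributed as a single `.tar.gz` archive. It provides `train` and `test` splits, each in a 1-channel (`1chan`) and a 3-channel (`3chan`) version, indexed by JSON datalist files. **TODO:** add a full dataset description (data source, class labels, image size).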

### Dataset Download
We download the dataset archive from Google Drive and extract it into the project's top directory.

In [12]:
import os
import gdown
import tarfile

#################################
##      DOWNLOAD FILES
#################################
def download_files_from_gdrive(url, outfile, force=False):
  """ Download a file from Google Drive with gdown.

  Arguments:
    url (str): Google Drive download URL.
    outfile (str): Local destination path.
    force (bool): If True, always download. Otherwise an existing regular
      file at `outfile` is kept and no download happens.
  """
  # - Reuse an existing local copy unless a fresh download is requested
  if not force and os.path.isfile(outfile):
    return
  gdown.download(url, outfile, quiet=False)
   
def untar_file(filename, outdir='.'):
  """ Extract a (possibly compressed) tar archive.

  Arguments:
    filename (str): Path to the tar archive; compression is auto-detected
      by tarfile.open().
    outdir (str): Destination directory. Defaults to the current working
      directory, which is where extraction always happened before.
  """
  # - Context manager closes the archive even if extraction fails midway
  with tarfile.open(filename) as fp:
    if hasattr(tarfile, 'data_filter'):
      # - The archive comes from the internet: the 'data' filter refuses unsafe
      #   members (e.g. absolute paths or '..' components escaping outdir)
      fp.extractall(outdir, filter='data')
    else:
      # - Python without extraction filters: unfiltered extraction as before
      fp.extractall(outdir)

# - Enter top directory (the archive is downloaded and extracted here)
os.chdir(topdir)

# - Download dataset archive (~1.3 GB). The archive is fetched only when it is
#   not already present. Set force_download=True to re-download it anyway
#   (e.g. if the local copy is corrupt or the remote archive has changed).
force_download = False
logger.info("Downloading file from url %s ..." % (dataset_url))
download_files_from_gdrive(dataset_url, dataset_filename, force=force_download)
logger.info("DONE!")

# - Untar dataset
logger.info("Unzipping dataset file %s ..." % (dataset_filename))
untar_file(dataset_filename)
logger.info("DONE")

2025-03-06 13:01:56 INFO - Downloading file from url https://drive.google.com/uc?export=download&id=1OprJ_NQIFyQSRWqjGLFQsAMumHvJ-tMB ...
Downloading...
From (original): https://drive.google.com/uc?export=download&id=1OprJ_NQIFyQSRWqjGLFQsAMumHvJ-tMB
From (redirected): https://drive.google.com/uc?export=download&id=1OprJ_NQIFyQSRWqjGLFQsAMumHvJ-tMB&confirm=t&uuid=7d64857c-68d1-4d75-b07e-7773dd0b0259
To: /home/riggi/Analysis/MLProjects/usc8-ai-workshop/galaxy_mnist-dataset.tar.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.33G/1.33G [01:40<00:00, 13.2MB/s]
2025-03-06 13:03:41 INFO - DONE!
2025-03-06 13:03:41 INFO - Unzipping dataset file galaxy_mnist-dataset.tar.gz ...
2025-03-06 13:03:46 INFO - DONE


### Create PyTorch Dataset
Let's create a custom PyTorch dataset by subclassing the base `torch.utils.data.Dataset` class.

In [22]:
from torch.utils.data import Dataset
from torchvision.datasets.vision import VisionDataset
import json
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


class GMNISTDataset(Dataset):
  """ Galaxy MNIST dataset read from a JSON datalist.

  The metadata file must contain a top-level "data" list. Each entry provides
  a 'filepaths' list (only the first path is loaded as the image) and an 'id'
  used as the class label.
  """

  def __init__(
      self, 
      metadata_file: Union[str, Path], 
      transform: Optional[Callable] = None,
      target_transform: Optional[Callable] = None,
  ):
    """
      Arguments:
        metadata_file (string or Path): Path to the json file with annotations.
        transform (callable, optional): Optional transform applied to the RGB image.
        target_transform (callable, optional): Optional transform applied to the class id.
    """
    
    logger.info("Reading dataset metadata from file %s ..." % (metadata_file))
    self.__read_metadata(metadata_file)
    self.transform = transform
    self.target_transform = target_transform

  def __read_metadata(self, filename):
    """ Load the "data" list of the json metadata file into self.datalist """

    # - `with` closes the file handle even if parsing fails; JSON text is UTF-8
    with open(filename, "r", encoding="utf-8") as f:
      self.datalist= json.load(f)["data"]
  
  def __len__(self):
    """ Return the number of samples listed in the metadata """
    return len(self.datalist)

  def __getitem__(self, idx):
    """ Return the (image, target) pair for sample `idx`.

    The image is the first entry of 'filepaths' converted to RGB (then passed
    through `transform`, if set); target is the entry's 'id' (then passed
    through `target_transform`, if set).
    """
    
    # - Read image path & class id
    entry= self.datalist[idx]
    img_path= entry['filepaths'][0]
    target= entry['id'] # class id
    
    # - Read image as RGB. Image.open() only accepts mode='r' (passing
    #   mode='RGB' raises ValueError), so convert after opening. convert()
    #   loads the pixel data, so the source file can be closed right away.
    # NOTE(review): `Image` is not imported in the visible cells; this needs
    #   `from PIL import Image` (pillow is installed in the first cell).
    with Image.open(img_path) as pil_img:
      img = pil_img.convert('RGB')

    # - Transform image?
    if self.transform is not None:
      img = self.transform(img)

    # - Transform target?
    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target


###############################
##     LOAD DATASETS
###############################
# - Train/cross-validation split (3-channel datalist)
logger.info("Read train-cv dataset from file %s ...", filename_train_3chan)
dataset_traincv = GMNISTDataset(metadata_file=filename_train_3chan)
logger.info("DONE!")

# - Test split (3-channel datalist)
logger.info("Read test dataset from file %s ...", filename_test_3chan)
dataset_test = GMNISTDataset(metadata_file=filename_test_3chan)
logger.info("DONE!")


2025-03-06 13:28:15 INFO - Read train-cv dataset from file /home/riggi/Analysis/MLProjects/usc8-ai-workshop/galaxy_mnist-dataset/train/3chan/datalist_train.json ...
2025-03-06 13:28:15 INFO - Reading dataset metadata from file /home/riggi/Analysis/MLProjects/usc8-ai-workshop/galaxy_mnist-dataset/train/3chan/datalist_train.json ...
2025-03-06 13:28:15 INFO - DONE!
2025-03-06 13:28:15 INFO - Read test dataset from file /home/riggi/Analysis/MLProjects/usc8-ai-workshop/galaxy_mnist-dataset/test/3chan/datalist_test.json ...
2025-03-06 13:28:15 INFO - Reading dataset metadata from file /home/riggi/Analysis/MLProjects/usc8-ai-workshop/galaxy_mnist-dataset/test/3chan/datalist_test.json ...
2025-03-06 13:28:15 INFO - DONE!
