In [1]:
import os
import glob
import random
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import scipy.io as scp
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import skimage.io as skio
from torch.utils.data import Dataset, DataLoader 
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


In [2]:
train_set = torchvision.datasets.Flowers102(root="data", download=True)
test_set = torchvision.datasets.Flowers102(root="data", download=True, split="test")
val_set = torchvision.datasets.Flowers102(root="data", download=True, split="val")

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to data\flowers-102\102flowers.tgz


100%|██████████| 344862509/344862509 [00:07<00:00, 45634971.67it/s]


Extracting data\flowers-102\102flowers.tgz to data\flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to data\flowers-102\imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 501438.58it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to data\flowers-102\setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 7494149.80it/s]


In [9]:
dir(train_set)
from collections import Counter
print(train_set._labels)
Counter(train_set._labels)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27,

Counter({0: 10,
         1: 10,
         2: 10,
         3: 10,
         4: 10,
         5: 10,
         6: 10,
         7: 10,
         8: 10,
         9: 10,
         10: 10,
         11: 10,
         12: 10,
         13: 10,
         14: 10,
         15: 10,
         20: 10,
         16: 10,
         17: 10,
         18: 10,
         19: 10,
         21: 10,
         22: 10,
         23: 10,
         24: 10,
         25: 10,
         26: 10,
         27: 10,
         28: 10,
         29: 10,
         30: 10,
         31: 10,
         32: 10,
         33: 10,
         34: 10,
         35: 10,
         36: 10,
         37: 10,
         38: 10,
         39: 10,
         40: 10,
         41: 10,
         42: 10,
         43: 10,
         44: 10,
         45: 10,
         46: 10,
         47: 10,
         48: 10,
         49: 10,
         50: 10,
         51: 10,
         52: 10,
         53: 10,
         54: 10,
         55: 10,
         56: 10,
         57: 10,
         58: 10,
       