In [10]:
# !git clone https://github.com/soumitrapy/templatecv.git project
# %cd project

In [11]:
# !wget https://storage.googleapis.com/wandb_datasets/nature_12K.zip
# !unzip nature_12K.zip

In [1]:
import yaml
import os
from datetime import datetime
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset,DataLoader, random_split

from torchvision import transforms
#from torchvision.io import read_image

In [2]:
#config = yaml.safe_load(open("config/default.yaml"))
config = yaml.safe_load(open("config/pretrained.yaml"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [3]:
class CustomDataset(Dataset):
    def __init__(self, path, class_names = None, transform=None, target_transform=None):
        super().__init__()
        self.path = path
        self.class_names = class_names
        self.transform = transform
        self.target_transform = target_transform

        if self.class_names is None:
            self.class_names =[x for x in os.listdir(path) if os.path.isdir(os.path.join(path, x))]
        
        self.images = []
        self.labels = []

        for i, cls in enumerate(self.class_names):
            img_dir = os.path.join(self.path, cls)
            for f in os.listdir(img_dir):
                if f.endswith(('.jpg','.png')):
                    self.images.append(os.path.join(img_dir, f))
                    self.labels.append(i)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert("RGB")
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


### DataLoader Creation

In [4]:
# transform = transforms.Compose([
#         transforms.Resize(tuple(config['dataset']['img_size'])),
#         transforms.ToTensor()
#     ])
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
target_transform = None
train_path = config['dataset']['path']+"train"
test_path = config['dataset']['path']+"val"

trainds = CustomDataset(path=train_path,
                        class_names=config['dataset']['class_names'],
                        transform=data_transforms['train'],
                        target_transform=target_transform
                        )

val_split = int(0.2 * len(trainds))
trainds, valds = random_split(trainds, [len(trainds) - val_split, val_split])
testds = CustomDataset(path=test_path,
                        class_names=config['dataset']['class_names'],
                        transform=data_transforms['val'],
                        target_transform=target_transform
                        )
traindl = DataLoader(trainds, batch_size=config['dataset']['batch_size'])
valdl = DataLoader(valds, batch_size=config['dataset']['batch_size'])
testdl = DataLoader(testds, batch_size=config['dataset']['batch_size'])
print(len(trainds), len(valds), len(testds))

8000 1999 2000


In [5]:
# # Data augmentation and normalization for training
# # Just normalization for validation
# data_transforms = {
#     'train': transforms.Compose([
#         transforms.RandomResizedCrop(224),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
#     'val': transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
# }

# data_dir = 'data/hymenoptera_data'
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
#                                           data_transforms[x])
#                   for x in ['train', 'val']}
# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
#                                              shuffle=True, num_workers=4)
#               for x in ['train', 'val']}
# dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# class_names = image_datasets['train'].classes

# # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# # such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

# device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
# print(f"Using {device} device")

### Model

In [None]:
from models.simplemodel import SmallCNN
model = SmallCNN(config['model'])

### Pretrained Model

In [6]:
from models.pretrained import get_model
model = get_model(config['model'])
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

In [7]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.01)

### Training Loop

In [None]:
from train import train, train_one_epoch, val_one_epoch
# train(model=model, traindl=traindl, optimizer=optimizer, loss_fn=loss_fn, config=config['train'], scheduler=None, valdl=valdl, device=device)
cfg = config['train']
model.to(device)
best_loss = float('inf')
os.makedirs('checkpoints', exist_ok=True)
for epoch in range(cfg['epochs']):
    model.train()
    train_loss, train_acc = train_one_epoch(model, traindl, optimizer, loss_fn, config['train'], epoch=epoch, device=device)
    if valdl and (epoch+1)%cfg['val_interval']==0:
        val_loss, val_acc = val_one_epoch(model, valdl, loss_fn, cfg, epoch=epoch, device=device)
        if val_loss<best_loss:
            best_loss = val_loss
            model_name = type(model).__name__+'_'+device.type+str(datetime.now())[:15]
            model_path = os.path.join('checkpoints', model_name)
            torch.save(model.state_dict(), model_path)
            
    if scheduler:
        scheduler.step()


### Prediction

In [None]:
from train import visualize_model
visualize_model(model, )

AttributeError: 'Subset' object has no attribute 'class_names'

torch.Size([5, 10])