In [1]:
import sys, os 
# Make the local flexai checkout importable from this notebook.
# NOTE(review): this is a machine-specific absolute path; the notebook will not
# find flexai on another machine unless the path is adjusted.
_flexai_root = os.path.abspath('C:/Users/Yazeed/Desktop/workspace/flexaibuild')
if _flexai_root not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(_flexai_root)

In [2]:
import shutil
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision.datasets import ImageFolder

from flexai import Learner, ActivationStatsManger # type: ignore
from utils import filetools
import config

In [3]:
# Kaggle download endpoint for the dataset archive, and the local directory
# the data is stored in.
url_dataset = "https://www.kaggle.com/api/v1/datasets/download/cashbowman/ai-generated-images-vs-real-images?datasetVersionNumber=1"
path_dataset = Path('data')

# Fetch and unpack only when there is no local `data` directory yet.
if not path_dataset.exists():
    path_comp = filetools.download_file(url_dataset, path_dataset)
    filetools.uncompress_and_remove(path_comp)

    # The code expects each class in a doubly nested folder (<name>/<name>):
    # move the inner folder to data/<label>, then delete the outer shell.
    for folder, label in (('AiArtData', 'AI'), ('RealArt', 'Real')):
        outer_dir = path_dataset / folder
        shutil.move(outer_dir / folder, path_dataset / label)
        shutil.rmtree(outer_dir)

In [4]:
# Load every image under `path_dataset`; ImageFolder derives the class label
# from the name of the sub-directory each image lives in.
root_ds = ImageFolder(
    path_dataset,
    transform=v2.Compose([
        v2.ToTensor(),
        v2.Resize((224, 224), antialias=True),
    ]),
)

# Seeded generator so the 80/10/10 split is the same on every run.
generator = torch.Generator().manual_seed(42)
train_ds, valid_ds, test_ds = torch.utils.data.random_split(root_ds, [0.8, 0.1, 0.1], generator)

# os.cpu_count() may return None when the count cannot be determined;
# DataLoader needs an int, so fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0
BATCH_SIZE = 32

# Only the training loader shuffles; evaluation order stays fixed.
dataloaders = {
    'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS),
    'valid': DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS),
    'test': DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS),
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Report each split under its own label (the valid/test sizes were swapped before).
print(f'train: {len(train_ds)}, valid: {len(valid_ds)}, test: {len(test_ds)}')

train: 779, valid: 97, test: 97




In [None]:
# Display the class-name -> integer-label mapping that ImageFolder built from
# the sub-directory names of the dataset.
root_ds.class_to_idx

In [None]:
from flexai.vision.utils import image_grid # type: ignore

# Preview one training batch. The grid helper receives an index -> class-name
# lookup, i.e. the inverse of ImageFolder's class_to_idx.
idx_to_class = {idx: name for name, idx in root_ds.class_to_idx.items()}
images, labels = next(iter(dataloaders['train']))
image_grid(images, labels, idx_to_class, figsize=(14,8))

In [None]:
import flexai.vision.transforms as transforms_fai # type: ignore


# Per-phase transforms handed to TransformCB further down.
#
# Normalize must be the LAST step of the training pipeline: AugMix's ops
# (posterize, solarize, equalize, ...) assume uint8 or float data in [0, 1].
# Running them on already-standardized tensors (negative values, values > 1)
# clamps or corrupts the images.
tfs = {
    'train': v2.Compose([
        v2.AugMix(),
        # Disabled augmentation experiments, kept for reference:
        # v2.RandomChannelPermutation(),
        # v2.RandomChoice([
        #     v2.CenterCrop(200),
        #     v2.CenterCrop(180),
        #     v2.CenterCrop(160),
        # ]),
        # v2.RandomGrayscale(),
        # v2.RandomChoice([
        #     v2.RandomErasing(),
        #     transforms_fai.RandomNoise(scales=[0.4, 0.4]),
        # ]),
        v2.Normalize(config.INPUT_MEAN, config.INPUT_STD),
    ]),
    # Evaluation data is only standardized, never augmented.
    'valid': v2.Normalize(config.INPUT_MEAN, config.INPUT_STD),
}

In [None]:
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from torcheval.metrics.functional import multiclass_accuracy
from flexai.callbacks import LRFinderCB, TransformCB, LoggerCB, MetricPlotterCB, ForwardHookCB, LSUVCB # type: ignore

model = config.MODEL

# Walk every module (parents are visited before their children) and switch
# gradients OFF for BatchNorm2d modules and ON for everything else.
# NOTE(review): this flag used to be called `freeze`, which reads as the
# opposite of what the code does. A later cell unfreezes the last three
# children, so the intent may have been "train only BatchNorm first" - confirm
# before relying on which parameters are trainable here.
for module in model.modules():
    trainable = type(module).__name__ != 'BatchNorm2d'
    module.requires_grad_(trainable)

model.to(device)

# The loss doubles as a tracked metric, alongside multiclass accuracy.
metrics = {'loss': CrossEntropyLoss(), 'accuracy': multiclass_accuracy}
optimizer = SGD(model.parameters(), lr=1e-3, weight_decay=0.1)

In [None]:
from torch.optim.lr_scheduler import ExponentialLR

# Learning-rate range test: the scheduler multiplies the LR by 1.33 each time
# it is stepped, and LRFinderCB is given the range 1e-7 .. 1 to sweep.
# NOTE(review): whether LRFinderCB restores the model/optimizer state afterwards
# is not visible here - confirm before training on the same objects.
lr_scheduler = ExponentialLR(optimizer, gamma=1.33)
lr_finder_callbacks = [
    LRFinderCB(lr_scheduler, start_lr=1e-7, max_lr=1, break_f=5),
    TransformCB(transform=tfs['train'], phase='train'),
    TransformCB(transform=tfs['valid'], phase='valid'),
]
Learner(
    model,
    dataloaders,
    optimizer,
    metrics,
    callbacks=lr_finder_callbacks,
    device=device,
).fit(3)

In [None]:
# Overwrite the learning rate of every parameter group in the optimizer.
# NOTE(review): 5e-3 was presumably read off the LR-finder plot - confirm.
for param_group in optimizer.param_groups:
    param_group['lr'] = 5e-3

In [None]:
# Activation-statistics manager built for the model with two layer names;
# its `register_stats` is wired in below through ForwardHookCB.
manager = ActivationStatsManger(model, ['ResLayer', 'DenseLayer'])

# Callbacks for the real training run: per-phase transforms, logging,
# metric plotting, and the forward hook feeding the stats manager.
training_callbacks = [
    TransformCB(transform=tfs['train'], phase='train'),
    TransformCB(transform=tfs['valid'], phase='valid'),
    LoggerCB(),
    MetricPlotterCB(),
    ForwardHookCB(manager.register_stats),
]

learner = Learner(
    model,
    dataloaders,
    optimizer,
    metrics,
    callbacks=training_callbacks,
    device=device
)

In [None]:
# First training stage: call fit(5) with the gradient flags set in the model-setup cell.
learner.fit(5)

In [None]:
# Turn gradients on for the last three top-level children of the model
# (requires_grad_ applies to every parameter inside each child).
last_three_children = list(model.children())[-3:]
for child in last_three_children:
    child.requires_grad_(True)

In [None]:
# Second training stage: call fit(5) again on the same learner, after the
# last three children had their gradients enabled.
learner.fit(5)

In [None]:
# Show the diagnostics of the activation-stats manager that was hooked into
# the forward pass during training.
# NOTE(review): judging by the names these presumably plot per-layer mean/std,
# an activation histogram ("color dim") and the share of dead activations -
# confirm against the flexai implementation.
manager.mean_std()
manager.color_dim()
manager.dead_chart()

In [None]:
# Evaluate accuracy on the held-out test split.
model.eval()
correct_weighted = 0.0  # sum of (batch accuracy * batch size)
n_seen = 0              # number of test samples processed
with torch.no_grad():  # inference only: no autograd graph needed
    for X, y in dataloaders['test']:
        # The model lives on `device`; inputs and targets must be there too.
        X, y = X.to(device), y.to(device)
        # Apply the same normalization the train/valid phases get through
        # TransformCB, so the model sees the input distribution it was trained on.
        y_pred = model(tfs['valid'](X))
        a = metrics['accuracy'](y_pred, y)
        # Weight by batch size: a plain mean of per-batch accuracies would
        # over-weight the (smaller) last batch.
        correct_weighted += a.item() * len(y)
        n_seen += len(y)
print(correct_weighted / n_seen)

In [None]:
# Write the learner's checkpoint to a path relative to the working directory.
# NOTE(review): unless save_checkpoint creates missing directories itself,
# the 'trained-model' folder has to exist beforehand - confirm.
learner.save_checkpoint('trained-model/checkpoint_my.pt')

In [None]:
# Save only the model's state_dict to the path taken from `config`.
# NOTE(review): `WEIGTHS_PATH` (sic) presumably matches the attribute name in
# the config module; fixing the spelling means renaming it there as well.
torch.save(model.state_dict(), config.WEIGTHS_PATH)