In [1]:
import sys
# Add the sibling src/ directory to the import search path. The path is relative,
# so it resolves against the directory the kernel was started in.
# NOTE(review): presumably this is what makes the project-local imports
# (data, training, vizualization) resolvable — confirm against the repo layout.
sys.path.append('../src')

import os
from dotenv import load_dotenv

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

from data.handler import load_and_transform_data, get_data_loader
from training.train_funcs import train_clean_model
from vizualization.tensors import imshow

In [2]:
# Load variables from a .env file into the process environment.
# NOTE(review): CACHE_DIR, read later via os.getenv, is presumably defined there — confirm.
load_dotenv()

True

In [3]:
# Choose the compute backend in order of preference: CUDA, then MPS, else CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [4]:
# Experiment settings.
# Hugging Face keeps downloads under ~/.cache/huggingface/datasets by default;
# CACHE_DIR (read from the environment) is passed on as the download directory.
cache_dir = os.environ.get("CACHE_DIR")  # None when CACHE_DIR is not set
batch_size = 32
dataset_name = "cifar10"

In [5]:
# Build the train and test splits with identical settings (no augmentation),
# loading 'train' first and 'test' second.
train_dataset, test_dataset = (
    load_and_transform_data(dataset_name, split, augment=False, download_dir=cache_dir)
    for split in ('train', 'test')
)

In [6]:
# Wrap both splits in batch loaders.
# Training batches are shuffled; the test loader keeps a fixed order so that any
# evaluation over it is reproducible (shuffling held-out data has no benefit).
# NOTE(review): assumes `shuffle` is forwarded to torch's DataLoader — confirm in data.handler.
train_dataloader = get_data_loader(train_dataset, batch_size, shuffle=True)
test_dataloader = get_data_loader(test_dataset, batch_size, shuffle=False)

In [7]:
# Randomly initialised ResNet-50 (weights=None replaces the deprecated pretrained=False).
# Keep the network in training mode: in eval mode the BatchNorm layers would normalise
# with their untrained running statistics and never update them during training.
# NOTE(review): the default classification head has 1000 outputs while CIFAR-10 has 10
# classes; consider num_classes=10 (this would change the saved checkpoint's shape).
model = models.resnet50(weights=None).to(device).train()



In [8]:
# Optimisation setup: SGD with momentum over all model parameters,
# scored with cross-entropy loss.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)
criterion = nn.CrossEntropyLoss()

In [9]:
# Run the project's training routine on the training loader for 5 epochs
# and keep the model it returns.
trained_model = train_clean_model(
    model,
    criterion,
    optimizer,
    train_dataloader,
    device,
    num_epochs=5,
)

In [None]:
# Persist only the learned parameters (state_dict).
# Create the output directory first so the save cannot fail on a missing folder
# after the training run has already completed.
checkpoint_path = '../models/trained_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(trained_model.state_dict(), checkpoint_path)