In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torch.utils.model_zoo import load_url as load_state_dict_from_url

from methods.CMAL.builder_resnet import Network_Wrapper
from train.train import Trainer

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torchvision.transforms as transforms
from utility.data.preprocessing import Autoaugment_preprocess
# Project preprocessing pipeline; judging by the argument names it presumably
# resizes to RESIZE_DIM and then crops to CROP_DIM for 3-channel input — confirm
# in utility.data.preprocessing.
RESIZE_DIM = (260, 260)
CROP_DIM = (224, 224)
transform = Autoaugment_preprocess(
    channels=3,
    resize_dim=RESIZE_DIM,
    crop_dim=CROP_DIM,
)

class TestDataset(Dataset):
    """Unlabelled images in a flat directory, yielded as ``(image, filename)``.

    Args:
        test_dir: directory to scan; only regular files directly inside it are
            used (sub-directories are skipped).
        transform: optional callable applied to each loaded image.
        loader: optional callable ``path -> image``. Defaults to torchvision's
            ``datasets.folder.default_loader`` (the loader ``ImageFolder`` uses),
            which with the default PIL backend returns an RGB image and closes
            the underlying file.
    """

    def __init__(self, test_dir, transform=None, loader=None):
        self.test_dir = test_dir
        # os.listdir order is arbitrary; sorting makes index -> file mapping
        # (and therefore prediction order) reproducible across runs/machines.
        self.image_files = sorted(
            name for name in os.listdir(test_dir)
            if os.path.isfile(os.path.join(test_dir, name))
        )
        self.transform = transform
        # Replaces a bare Image.open (PIL was never imported here, the file
        # handle stayed open lazily, and grayscale/RGBA files were not converted
        # to the 3 channels the preprocessing expects).
        self.loader = loader if loader is not None else torchvision.datasets.folder.default_loader

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        image = self.loader(os.path.join(self.test_dir, file_name))
        if self.transform:
            image = self.transform(image)
        return image, file_name

# Directory torchvision downloads FGVC-Aircraft into. `root` was used below but
# never defined in this notebook, so a fresh kernel raised NameError.
# Adjust to taste.
root = "./images"

train_dir = "./images/competition_data/train"
test_dir = "./images/competition_data/test"

# Alternative (disabled): local competition data in an ImageFolder layout with an
# 80/20 train/val split. Re-enabling it also requires importing `datasets` and
# `random_split`, which this notebook does not do.
#train_data = datasets.ImageFolder(root = train_dir, transform = transform.transform)
#num_train = int(len(train_data) * 0.8)
#num_val = len(train_data) - num_train
#trainset, valset = random_split(train_data, [num_train, num_val])

#testset = TestDataset(test_dir, transform = transform.transform)

BATCH_SIZE = 16
NUM_WORKERS = 2


def make_fgvc_split(split):
    """Return torchvision's FGVC-Aircraft `split` ('train', 'val' or 'test')."""
    return torchvision.datasets.FGVCAircraft(root=root, split=split,
                                             download=True, transform=transform.transform)


def make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch size/workers."""
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                       shuffle=shuffle, num_workers=NUM_WORKERS)


trainset = make_fgvc_split('train')
valset = make_fgvc_split('val')
testset = make_fgvc_split('test')

# Only the training split is shuffled; evaluation order stays fixed.
trainloader = make_loader(trainset, shuffle=True)
valloader = make_loader(valset, shuffle=False)
testloader = make_loader(testset, shuffle=False)

data_loaders = {
    "train_loader": trainloader,
    "val_loader": valloader,
    # Previously `valloader`, which left the FGVC-Aircraft test split unused.
    "test_loader": testloader,
}


In [4]:
# ImageNet-pretrained ResNet-50 (legacy torchvision V1 checkpoint) as the CMAL backbone.
backbone = torchvision.models.resnet50()
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
backbone.load_state_dict(state_dict)

# ResNet-50's children are conv1, bn1, relu, maxpool, layer1-4, avgpool, fc:
# keep the first eight (through layer4) and drop the pooling/classifier tail.
net_layers = list(backbone.children())[:8]

# Size the heads from the training split instead of a hand-edited constant
# (easy to forget when switching datasets — the commented-out ResNeXt cell
# uses 102). FGVC-Aircraft at torchvision's default variant level yields 100.
num_classes = len(trainset.classes)
model = Network_Wrapper(net_layers, num_classes)

In [5]:
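# Alternative backbone (disabled): ImageNet-pretrained ResNeXt-50 32x4d.
# NOTE: the class count passed to Network_Wrapper (102 here) must match the
# dataset in use; FGVC-Aircraft at torchvision's default variant level has 100.
#model = torchvision.models.resnext50_32x4d(weights='DEFAULT')

#net_layers = list(model.children())
#net_layers = net_layers[0:8]

#model = Network_Wrapper(net_layers, 102)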

In [6]:
# Move the model before building the optimizer: the torch.optim docs ask that
# parameters be on their final device before an optimizer is constructed.
model.to(device)

# Two learning rates: the CMAL blocks/classifiers train at 10x the rate of
# `model.Features` (presumably the pretrained ResNet layers handed to
# Network_Wrapper — confirm in methods.CMAL.builder_resnet).
HEAD_LR = 0.002
BACKBONE_LR = 0.0002

head_modules = [
    model.classifier_concat,
    model.conv_block1,
    model.classifier1,
    model.conv_block2,
    model.classifier2,
    model.conv_block3,
    model.classifier3,
]
param_groups = [{'params': module.parameters(), 'lr': HEAD_LR} for module in head_modules]
param_groups.append({'params': model.Features.parameters(), 'lr': BACKBONE_LR})

optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4)

In [7]:
# Cross-entropy criterion, handed to the Trainer as `loss_fn`.
CELoss = nn.CrossEntropyLoss()

# NOTE(review): StepLR with step_size=1 and gamma=0.01 divides every group's LR
# by 100 on each scheduler.step() (head LR 0.002 -> 2e-5 -> 2e-7 ...), which
# effectively halts learning after a couple of epochs. Also, `scheduler` is not
# passed to the Trainer constructed later, so unless Trainer reaches it some
# other way it is never stepped — confirm both are intended.
LR_STEP_SIZE = 1
LR_GAMMA = 0.01
scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [None]:
# Place all parameters and buffers on the selected device. nn.Module.to returns
# the module itself, so rebinding is harmless and the bare name shows its repr.
model = model.to(device)
model

In [11]:
# NOTE(review): `scheduler` is not passed here; check whether Trainer accepts
# one, otherwise the StepLR decay configured earlier never runs.
# NOTE(review): dataset_name is still "Mammalia" although the loaders hold
# FGVC-Aircraft — confirm how Trainer uses it before renaming it too.
training = Trainer(
    data_loaders=data_loaders,
    dataset_name="Mammalia",
    model=model,
    optimizer=optimizer,
    loss_fn=CELoss,
    device=device,
    seed=42,
    # Where to save the experiment. Resolved against the working directory (the
    # repo root, like the ./images paths used for the data) instead of a
    # hard-coded personal home path.
    exp_path=os.path.abspath("./data"),
    # Name of the experiment; was mislabelled "ResNet50_CMAL_Mammalia".
    exp_name="ResNet50_CMAL_FGVCAircraft",
    use_early_stopping=True,
)

In [None]:
# Run length and logging cadence; log_interval presumably counts batches
# between progress lines — confirm in train.train.Trainer.
EPOCHS = 5
LOG_INTERVAL = 20
training.main(epochs=EPOCHS, log_interval=LOG_INTERVAL)