In [1]:
import torch
import argparse
from torch.utils.data import DataLoader
from litdata import litdata
import torchvision . transforms as T
from torch import nn
import timm

from useful_functions import seed_everything, train, evaluate, train_epochs
from model import Stacked_ViT, Big_model

import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Train model")
    parser.add_argument("--batch_size", default = 256)
    parser.add_argument("--lr", default = 0.000001)
    parser.add_argument("--weight_decay", default = 0.2)
    parser.add_argument("--clip", default = 100)
    parser.add_argument("--epochs", default = 100)
    parser.add_argument("--seed", default = 5310)
    parser.add_argument("--save", default = True)
    parser.add_argument("--fox", default = True)
    parser.add_argument("--train", default = True)
    parser.add_argument("--test_model", default = False)
    parser.add_argument("--imagenet", default = True)
    parser.add_argument("--num_pred", default = 12)
    args = parser.parse_args()


    seed_everything(args.seed)

    print("Loading data ...") 
    print(f"batch size = {args.batch_size}") 
    if args.fox:
        data_path = "/fp/projects01/ec232/data/"

    else:
        #katinka
        #data_path = "C:\Users\laila\Documents\Studium\5.Semester\AdvancedDeepLearning\g05-p3"
        #coco
        data_path = "../../../../../../Desktop/"
        #amir
        #data_path = "/Users/amir/Documents/UiO/IN5310 – Advanced Deep Learning for Image Analysis/project3/"

    in_mean = [0.485, 0.456, 0.406]
    in_std = [0.229, 0.224, 0.225]

    postprocess = (
        T.Compose([
            T.ToTensor(),
            T.Resize((224, 224), antialias=True),
            T.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
            T.Normalize(in_mean, in_std),
        ]),
        nn.Identity(), 
    )

    if args.imagenet:
        dataset = "IN1k"
        num_classes = 1000
    else:
        dataset = "Caltech256"
        num_classes = 257

    print(f"\ndataset = {dataset}\nnumber of classes = {num_classes}\nnumber of predictions = {args.num_pred}\n")
    traindata = litdata.LITDataset(dataset, data_path, override_extensions = ["jpg", "cls"] ).map_tuple(*postprocess)
    valdata = litdata.LITDataset(dataset, data_path, train=False, override_extensions = ["jpg", "cls"]).map_tuple(*postprocess)

    train_loader = DataLoader(traindata, shuffle=True, batch_size = args.batch_size)
    val_loader = DataLoader(valdata, shuffle=False, batch_size = args.batch_size)
        
    print("Loading data done") 

    if args.train:
        print("Loading model ...")
        args.num_classes = num_classes

        tiny = 'vit_tiny_patch16_224'
        base = "vit_base_patch16_224"

        pretrained_model = tiny

        print(f"\npretrained model = {pretrained_model}\n")

        model = timm.create_model(pretrained_model, pretrained=True, num_classes = num_classes).to(device)
        loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr = args.lr, weight_decay = args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

        stacked_vit = Stacked_ViT(model, num_classes, args.num_pred)
        stacked_vit = torch.nn.DataParallel(stacked_vit)

        print("Loading model done")

        print("Training model ...")
        if args.imagenet:
            evaluate(model, val_loader, args.num_pred)
        else:                                                                                                               
            train_acc, val_acc = train_epochs(args, model, loss, optimizer, scheduler, train_loader, val_loader, args.num_pred)
        print("Training model done")

usage: ipykernel_launcher.py [-h] [--batch_size BATCH_SIZE] [--lr LR]
                             [--weight_decay WEIGHT_DECAY] [--clip CLIP]
                             [--epochs EPOCHS] [--seed SEED] [--save SAVE]
                             [--fox FOX] [--train TRAIN]
                             [--test_model TEST_MODEL] [--imagenet IMAGENET]
                             [--num_pred NUM_PRED]
ipykernel_launcher.py: error: unrecognized arguments: -f /Users/corneliusbencsik/Library/Jupyter/runtime/kernel-ac43cea8-faac-4292-a99e-f854bb1b4e69.json


SystemExit: 2