In [279]:
import pandas as pd
import torch
import os
from torch import nn as nn
import torchvision
from torch.utils.data import DataLoader, dataset, Dataset
from torchvision.io import read_image
from torchvision.datasets import ImageFolder    
from torchvision import transforms
from PIL import Image
import numpy as np
from transformers import MobileViTFeatureExtractor, MobileViTForImageClassification
from tqdm import tqdm

In [280]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [156]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [311]:
transforms_train = transforms.Compose([
        transforms.Resize((288,288)),
        transforms.RandomCrop(256),
        transforms.AutoAugment(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
    ])

transforms_test = transforms.Compose([
        transforms.Resize((288,288)),
        transforms.RandomCrop(256),
    transforms.Normalize(mean, std)
    ])

In [312]:
train_data_ = ImageFolder("/home/sahmaran/data/ImageNet/ILSVRC/Data/CLS-LOC/train", transform=  transforms_train);

In [313]:
train_data = DataLoader(train_data_, batch_size = 32, num_workers = 0, shuffle = True);
classes_dict = train_data_.class_to_idx

In [314]:
class test_data(Dataset):
    def __init__(self, csv_file = "/home/sahmaran/data/ImageNet/LOC_val_solution.csv",
                 root_dir = "/home/sahmaran/data/ImageNet/ILSVRC/Data/CLS-LOC/val",
                 classes_dict = classes_dict, transformations = None):
        super().__init__()

        self.root_dir = root_dir
        self.classes_dict = classes_dict
        ###
        file = pd.read_csv(csv_file)
        self.file_names = file.iloc[:,0]
        self.anotations = file.iloc[:,-1].apply(self.__split__)
        ###
        self.transformations = transformations
        ###
        
    def __len__(self):
        return len(self.anotations)
    def __getitem__(self, index):
        image_path = os.path.join(self.root_dir, self.file_names[index]+ ".JPEG")
        image = read_image(image_path, )
        image = image/255.0
        ###
        if image.shape[0] == 1:
            image = torch.cat((image, image, image), dim = 0)
        if self.transformations:
            image = self.transformations(image)
        ###
        return image, self.classes_dict[self.anotations[index]]
        
    def __split__(self, n):
        return n.split()[0]

In [315]:
test_data_ = DataLoader(test_data(transformations = transforms_test), batch_size = 32, num_workers = 0, );

In [303]:
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small")
model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")

In [304]:
model.classifier = nn.Sequential(nn.Dropout(0.2), 
                                 nn.Linear(320, 1000, bias = True))
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)


Let's use 2 GPUs!


DataParallel(
  (module): MobileViTForImageClassification(
    (mobilevit): MobileViTModel(
      (conv_stem): MobileViTConvLayer(
        (convolution): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): SiLUActivation()
      )
      (encoder): MobileViTEncoder(
        (layer): ModuleList(
          (0): MobileViTMobileNetLayer(
            (layer): ModuleList(
              (0): MobileViTInvertedResidual(
                (expand_1x1): MobileViTConvLayer(
                  (convolution): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
                  (normalization): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activation): SiLUActivation()
                )
                (conv_3x3): MobileViTConvLayer(
                  (convolution): Conv2d(32, 32, kernel_size

In [316]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

In [317]:
"""Train Loop starts here"""
for i in range(25):
    K = []
    loss = []
    Loss = []
    bar = tqdm(train_data)
    for x,y in bar:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        model.train()
        y_pred = model(x)["logits"]
        
        loss = nn.CrossEntropyLoss()(y_pred, y)
        loss.backward()
        optimizer.step()
        Loss.append(loss.item())
        bar.set_description(f"{loss.item()}")
    model.eval()
    with torch.no_grad():
        for x,y in test_data_:
            x = x.to(device)
            y = y.to(device)
            K.append(sum(torch.argmax(model(x)["logits"],-1) == y).item()/len(y))
    print(f"{i}th epoch passed, validation acc is {np.mean(K)}, train loss is {np.mean(Loss)}")

3.9134633541107178: 100%|███████████████| 40037/40037 [3:12:35<00:00,  3.46it/s]


0th epoch passed, validation acc is 0.3525271912987844, train loss is 4.52855388319099


3.4410736560821533: 100%|███████████████| 40037/40037 [3:10:48<00:00,  3.50it/s]


1th epoch passed, validation acc is 0.4599728087012156, train loss is 3.0142564069296682


2.2609126567840576: 100%|███████████████| 40037/40037 [3:10:42<00:00,  3.50it/s]


2th epoch passed, validation acc is 0.5080974088291746, train loss is 2.6269749096081436


2.0236284732818604: 100%|███████████████| 40037/40037 [3:10:42<00:00,  3.50it/s]


3th epoch passed, validation acc is 0.5235724568138196, train loss is 2.4217881180414307


2.8066599369049072: 100%|███████████████| 40037/40037 [3:10:34<00:00,  3.50it/s]


4th epoch passed, validation acc is 0.552643154190659, train loss is 2.286628150563588


3.6710119247436523: 100%|███████████████| 40037/40037 [3:12:29<00:00,  3.47it/s]


5th epoch passed, validation acc is 0.5634396992962252, train loss is 2.1867813916685375


1.6322369575500488: 100%|███████████████| 40037/40037 [3:10:29<00:00,  3.50it/s]


6th epoch passed, validation acc is 0.5750359884836852, train loss is 2.112310324328987


1.8335164785385132:  87%|█████████████  | 34702/40037 [2:45:04<25:22,  3.50it/s]


KeyboardInterrupt: 

In [322]:
torch.save(model.state_dict(), "/home/sahmaran/Dropbox/Machines_learning/vit/model_train/model.mod")

In [320]:
import os
os.getcwd()

'/home/sahmaran/Dropbox/Machines_learning/vit/model_train'

In [264]:
with torch.no_grad():
    for x,y in tqdm(test_data_):
        x = x.cuda(1)
        y = y.cuda(1)
        K.append(sum(torch.argmax(model(x)["logits"],-1) == y).item()/len(y))
print(f"{i}th epoch passed, validation acc is {np.mean(K)}, train loss is {np.mean(Loss)}")

100%|███████████████████████████████████████| 2084/2084 [06:07<00:00,  5.67it/s]

0th epoch passed, validation acc is 0.0009596928982725527, train loss is 6.8178387895415105



